use std::{
path::PathBuf,
sync::atomic::{AtomicU64, Ordering},
};
use super::*;
use crate::{
array::Array,
error::{FileOp, MissingFieldPayload, RankMismatchPayload},
lm::{cache::KvCache, generate::GenConfig, model::Model as LmModel},
vlm::{
generate::{VlmGenConfig, vlm_generate},
image::{ColorOrder, ImageProcessorConfig, ResizeFilter},
prompt::MarkerPolicy,
},
};
/// Renders a flat mock `config.json` body for `model_type`.
///
/// Besides `model_type`, the body carries a top-level `vocab_size` of 5 and a
/// `mock_extra` of 11, the two fields the mock model constructor reads.
fn mock_config_json(model_type: &str) -> String {
    let mut body = String::from("{\n");
    body.push_str(&format!("\"model_type\": \"{model_type}\",\n"));
    body.push_str("\"vocab_size\": 5,\n");
    body.push_str("\"mock_extra\": 11\n");
    body.push('}');
    body
}
/// Renders a nested mock `config.json` body for `model_type`.
///
/// The language-model fields (including `vocab_size`) live under
/// `text_config`, a `vision_config` object sits beside it, and only
/// `model_type` and `mock_extra` appear at the top level.
fn mock_nested_config_json(model_type: &str) -> String {
    let model_type_line = format!("\"model_type\": \"{model_type}\",");
    let lines = [
        "{",
        model_type_line.as_str(),
        "\"text_config\": {",
        "\"hidden_size\": 8,",
        "\"num_hidden_layers\": 2,",
        "\"num_attention_heads\": 4,",
        "\"num_key_value_heads\": 2,",
        "\"head_dim\": 2,",
        "\"rope_theta\": 10000.0,",
        "\"vocab_size\": 5,",
        "\"tie_word_embeddings\": false",
        "},",
        "\"vision_config\": {",
        "\"hidden_size\": 16,",
        "\"num_hidden_layers\": 1,",
        "\"image_size\": 224",
        "},",
        "\"mock_extra\": 11",
        "}",
    ];
    lines.join("\n")
}
/// Renders a mock processor-config body.
///
/// The body names `processor_class` and stores `image_size` under the
/// `mock_image_size` key that the mock processor constructor reads back.
fn mock_preprocessor_config_json(processor_class: &str, image_size: u32) -> String {
    let class_line = format!("\"processor_class\": \"{processor_class}\",");
    let size_line = format!("\"mock_image_size\": {image_size}");
    ["{", class_line.as_str(), size_line.as_str(), "}"].join("\n")
}
/// Minimal model stand-in that the mock model constructor in this file builds
/// from a parsed `config.json`.
struct MockVlmModel {
/// Vocabulary size; `forward` uses it as the last dimension of the logits it returns.
vocab: i32,
// Populated by the constructor from the `mock_extra` config key but never read
// back, hence the lint allowance.
#[allow(dead_code)]
mock_extra: i64,
}
impl LmModel for MockVlmModel {
    /// Returns all-zero logits shaped `[batch, seq, vocab]`.
    ///
    /// Rank-2 `[B, S]` tokens are taken as-is and rank-1 `[S]` tokens are
    /// treated as a batch of one. Any other rank yields `Error::RankMismatch`.
    /// The cache argument is ignored.
    fn forward(&self, tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
        let dims = tokens.shape();
        let (batch, seq) = match dims.as_slice() {
            &[seq] => (1, seq),
            &[batch, seq] => (batch, seq),
            _ => {
                let rank = dims.len() as u32;
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "MockVlmModel::forward expects [B, S] (rank 1 or 2)",
                    rank,
                    dims,
                )));
            }
        };
        let vocab = self.vocab as usize;
        let zeros = vec![0.0_f32; batch * seq * vocab];
        Array::from_slice::<f32>(&zeros, &(batch, seq, vocab))
    }
}
impl crate::vlm::model::Model for MockVlmModel {
    /// Returns all-zero embeddings shaped `[B, T, 8]` for rank-2 `[B, T]`
    /// tokens; any other rank yields `Error::RankMismatch`.
    fn embed_tokens(&self, tokens: &Array) -> Result<Array> {
        // Width of the fake embedding vectors.
        const HIDDEN: usize = 8;

        let dims = tokens.shape();
        let &[batch, len] = dims.as_slice() else {
            let rank = dims.len() as u32;
            return Err(Error::RankMismatch(RankMismatchPayload::new(
                "MockVlmModel::embed_tokens expects [B, T] (rank 2)",
                rank,
                dims,
            )));
        };
        let zeros = vec![0.0_f32; batch * len * HIDDEN];
        Array::from_slice::<f32>(&zeros, &(batch, len, HIDDEN))
    }

    /// Ignores the image and returns a fixed all-zero `[1, 8]` feature array.
    fn encode_image(&self, _image: &Array) -> Result<Array> {
        let zeros = [0.0_f32; 8];
        Array::from_slice::<f32>(&zeros, &(1usize, 8usize))
    }
}
/// Minimal processor stand-in that the mock processor constructor in this file
/// builds from a parsed processor config.
struct MockVlmProcessor {
// Copied from the loaded processor class by the constructor but never read
// back, hence the lint allowance.
#[allow(dead_code)]
processor_class: String,
/// Side length reported (for both dimensions) by `image_processor_config`.
image_size: u32,
}
impl Processor for MockVlmProcessor {
    /// Builds an image-processor config whose target size is the square
    /// `image_size × image_size`; every other setting is a fixed constant.
    fn image_processor_config(&self) -> ImageProcessorConfig {
        let side = self.image_size;
        let sized = ImageProcessorConfig::new().with_size((side, side));
        let normalized = sized
            .with_mean([0.5, 0.5, 0.5])
            .with_std([0.5, 0.5, 0.5])
            .with_rescale_factor(1.0 / 255.0);
        let toggled = normalized
            .with_do_resize(true)
            .with_do_rescale(true)
            .with_do_normalize(true);
        toggled
            .with_resample(ResizeFilter::Bilinear)
            .with_color_order(ColorOrder::Rgb)
    }

    /// Exposes the processor as `Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Mutable counterpart of [`Self::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
fn mock_vlm_constructor() -> VlmModelConstructor {
Box::new(|loaded: &LoadedVlmModel| -> Result<Box<dyn VlmModel>> {
assert!(
!loaded.weights_ref().is_empty(),
"constructor should receive the loaded weights"
);
let raw: serde_json::Value = serde_json::from_str(loaded.config_json_ref())
.map_err(|e| Error::Parse(ParsePayload::new("mock vlm ctor: config.json", "JSON", e)))?;
let vocab = raw
.get("vocab_size")
.or_else(|| raw.get("text_config").and_then(|t| t.get("vocab_size")))
.and_then(serde_json::Value::as_i64)
.and_then(|x| i32::try_from(x).ok())
.ok_or_else(|| {
Error::MissingField(MissingFieldPayload::new(
"mock vlm ctor",
"vocab_size (top-level or text_config.vocab_size)",
))
})?;
let mock_extra = raw
.get("mock_extra")
.and_then(serde_json::Value::as_i64)
.ok_or(Error::MissingField(MissingFieldPayload::new(
"mock vlm ctor",
"mock_extra",
)))?;
Ok(Box::new(MockVlmModel { vocab, mock_extra }))
})
}
/// Builds the processor constructor that the tests register in a
/// `VlmProcessorTypeRegistry`.
///
/// The returned closure:
/// * takes the preprocessor-config text, falling back to the processor-config
///   text, and fails with `Error::MissingField` when neither is present;
/// * parses that text as JSON, mapping failures to `Error::Parse`;
/// * reads `mock_image_size` as a `u32` (`Error::MissingField` otherwise);
/// * encodes `"a"` with the supplied tokenizer and panics if that fails.
fn mock_processor_constructor() -> ProcessorConstructor {
    Box::new(
        |loaded: &LoadedProcessor<'_>| -> Result<Box<dyn Processor>> {
            let body = loaded
                .preprocessor_config_json
                .or(loaded.processor_config_json)
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "mock vlm processor ctor",
                        "preprocessor_config_json or processor_config_json",
                    ))
                })?;
            let raw: serde_json::Value = serde_json::from_str(body).map_err(|e| {
                Error::Parse(ParsePayload::new(
                    "mock vlm processor ctor: processor config",
                    "JSON",
                    e,
                ))
            })?;
            // Build the error lazily so the success path never constructs a
            // payload, matching the config-body lookup above.
            let image_size = raw
                .get("mock_image_size")
                .and_then(serde_json::Value::as_u64)
                .and_then(|x| u32::try_from(x).ok())
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "mock vlm processor ctor",
                        "mock_image_size",
                    ))
                })?;
            // Only the fact that encoding works matters; the ids are discarded.
            let _ = loaded
                .tokenizer
                .encode("a", false)
                .expect("processor constructor must receive a working tokenizer");
            Ok(Box::new(MockVlmProcessor {
                processor_class: loaded.processor_class.to_owned(),
                image_size,
            }))
        },
    )
}
/// Creates an empty scratch directory under the system temp dir and returns
/// its path.
///
/// The directory name combines `tag`, the current process id, and a counter
/// that increases on every call, so repeated calls yield distinct paths.
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let mut path = std::env::temp_dir();
    path.push(format!("mlxrs-vlm-factory-{tag}-{pid}-{seq}"));
    // Best-effort removal of anything left at this path; failure (for example
    // because nothing exists there) is deliberately ignored.
    let _ = std::fs::remove_dir_all(&path);
    std::fs::create_dir_all(&path).unwrap();
    path
}
/// Writes a three-word (`a`, `b`, `c`) word-level `tokenizer.json` into `dir`.
///
/// The tokenizer uses a whitespace pre-tokenizer and `"a"` as its unknown
/// token. Panics if building or saving fails.
fn write_tokenizer(dir: &Path) {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel,
        pre_tokenizers::whitespace::Whitespace,
    };

    let model = WordLevel::builder()
        .vocab(
            [("a", 0u32), ("b", 1), ("c", 2)]
                .iter()
                .map(|&(word, id)| (word.to_owned(), id))
                .collect(),
        )
        .unk_token("a".to_owned())
        .build()
        .unwrap();

    let mut tokenizer = HfTokenizer::new(model);
    tokenizer.with_pre_tokenizer(Some(Whitespace {}));

    let target = dir.join("tokenizer.json");
    tokenizer.save(target, false).unwrap();
}
/// Populates `dir` with everything a mock VLM needs except the tokenizer.
///
/// Writes a flat `config.json` for `model_type`, a processor config named
/// `processor_filename` carrying `processor_class` and `image_size`, and a
/// `model.safetensors` holding a single 2×2 tensor called `mock.weight`.
/// Panics on any I/O or serialization failure.
fn write_vlm_dir_no_tokenizer(
    dir: &Path,
    model_type: &str,
    processor_filename: &str,
    processor_class: &str,
    image_size: u32,
) {
    let config_body = mock_config_json(model_type);
    std::fs::write(dir.join("config.json"), config_body).unwrap();

    let processor_body = mock_preprocessor_config_json(processor_class, image_size);
    std::fs::write(dir.join(processor_filename), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    let weights_path = dir.join("model.safetensors");
    crate::io::save_safetensors(&weights_path, &weights).unwrap();
}
/// Populates `dir` with a complete mock VLM layout: the config, processor
/// config, and weights, followed by `tokenizer.json`.
fn write_vlm_dir(
    dir: &Path,
    model_type: &str,
    processor_filename: &str,
    processor_class: &str,
    image_size: u32,
) {
    // Shared files first; the tokenizer is the only addition over the
    // tokenizer-less layout.
    write_vlm_dir_no_tokenizer(dir, model_type, processor_filename, processor_class, image_size);
    write_tokenizer(dir);
}
/// Happy path: `load` dispatches to the registered mock constructors and the
/// returned bundle exposes a usable config, model, processor, and tokenizer.
#[test]
fn load_dispatches_to_registered_mocks_and_returns_full_bundle() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    let model_dir = fresh_dir("dispatch");
    write_vlm_dir(&model_dir, "mockvlm", "preprocessor_config.json", "MockProc", 64);
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors).expect("load should succeed");

    // Base config: only `model_type` is set by the flat mock config.
    assert_eq!(bundle.config_ref().model_type(), "mockvlm");
    assert_eq!(bundle.config_ref().eos_token_id().cloned(), None);
    assert_eq!(bundle.config_ref().quantization(), None);

    // Model: three tokens in, logits with the mock vocab of 5 out.
    let tokens = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).unwrap();
    let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
    let logits = LmModel::forward(bundle.model(), &tokens, &mut cache).unwrap();
    assert_eq!(logits.shape(), vec![1, 3, 5]);

    // Processor: the size written into the processor config comes back.
    let image_cfg = bundle.processor().image_processor_config();
    assert_eq!(image_cfg.size(), (64, 64));

    // Tokenizer: three whitespace-separated words encode to three ids.
    let encoded = bundle.tokenizer().encode("a b c", false).unwrap();
    assert_eq!(encoded.len(), 3);
}
/// A loaded trait-object model can be handed straight to `vlm_generate` for a
/// text-only prompt; the all-zero mock logits yield token 0 at every step.
#[test]
fn loaded_model_drives_vlm_generate_end_to_end() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    let model_dir = fresh_dir("e2e-generate");
    write_vlm_dir(&model_dir, "mockvlm", "preprocessor_config.json", "MockProc", 64);
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors).expect("load should succeed");

    let gen_cfg = VlmGenConfig::new(
        GenConfig::default().with_max_tokens(4),
        99,
        3,
        MarkerPolicy::Required,
    );
    let image_cfg = bundle.processor().image_processor_config();
    let prompt = [0_u32, 1, 2];
    let steps = vlm_generate(bundle.model(), &image_cfg, &prompt, &[], Vec::new(), gen_cfg)
        .expect("vlm_generate constructs against the loaded trait-object model");

    let mut generated: Vec<u32> = Vec::new();
    for step in steps {
        generated.push(step.expect("each generation step succeeds").token);
    }
    assert_eq!(generated, vec![0_u32, 0, 0, 0]);
}
/// The image size handed to `encode_image` must come from the loaded
/// processor config (48×48 here), not from the model's own default config
/// (224×224). A recording model captures the shape it was given.
#[test]
fn loaded_processor_config_drives_image_preprocessing_not_model_default() {
    use std::sync::{Arc, Mutex};

    /// Slot shared between the test body and the model it registers.
    type SeenShape = Arc<Mutex<Option<Vec<usize>>>>;

    /// Model that remembers the shape of the last image it was asked to encode.
    struct RecordingVlmModel {
        seen_image_shape: SeenShape,
    }

    impl LmModel for RecordingVlmModel {
        fn forward(&self, tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
            let dims = tokens.shape();
            let (batch, seq) = match dims.as_slice() {
                &[seq] => (1, seq),
                &[batch, seq] => (batch, seq),
                _ => {
                    let rank = dims.len() as u32;
                    return Err(Error::RankMismatch(RankMismatchPayload::new(
                        "RecordingVlmModel::forward expects [B, S] (rank 1 or 2)",
                        rank,
                        dims,
                    )));
                }
            };
            let zeros = vec![0.0_f32; batch * seq * 5];
            Array::from_slice::<f32>(&zeros, &(batch, seq, 5usize))
        }

        fn forward_embeddings(
            &self,
            embeddings: &Array,
            _cache: &mut [Box<dyn KvCache>],
        ) -> Result<Array> {
            let dims = embeddings.shape();
            let (batch, len) = match dims.as_slice() {
                &[batch, len, _] => (batch, len),
                _ => {
                    let rank = dims.len() as u32;
                    return Err(Error::RankMismatch(RankMismatchPayload::new(
                        "RecordingVlmModel::forward_embeddings expects [B, T, D] (rank 3)",
                        rank,
                        dims,
                    )));
                }
            };
            let zeros = vec![0.0_f32; batch * len * 5];
            Array::from_slice::<f32>(&zeros, &(batch, len, 5usize))
        }
    }

    impl crate::vlm::model::Model for RecordingVlmModel {
        fn embed_tokens(&self, tokens: &Array) -> Result<Array> {
            let dims = tokens.shape();
            let (batch, len) = match dims.as_slice() {
                &[batch, len] => (batch, len),
                _ => {
                    let rank = dims.len() as u32;
                    return Err(Error::RankMismatch(RankMismatchPayload::new(
                        "RecordingVlmModel::embed_tokens expects [B, T] (rank 2)",
                        rank,
                        dims,
                    )));
                }
            };
            let zeros = vec![0.0_f32; batch * len * 8];
            Array::from_slice::<f32>(&zeros, &(batch, len, 8usize))
        }

        fn encode_image(&self, image: &Array) -> Result<Array> {
            // Record what preprocessing produced before returning dummy features.
            let mut slot = self.seen_image_shape.lock().unwrap();
            *slot = Some(image.shape());
            drop(slot);
            Array::from_slice::<f32>(&[0.0_f32; 8], &(1usize, 8usize))
        }
    }

    let recorded: SeenShape = Arc::new(Mutex::new(None));

    // The constructor hands every model it builds a handle to `recorded`.
    let sink = Arc::clone(&recorded);
    let models = VlmTypeRegistry::new().with(
        "recordingvlm",
        Box::new(
            move |loaded: &LoadedVlmModel| -> Result<Box<dyn VlmModel>> {
                assert!(
                    !loaded.weights_ref().is_empty(),
                    "constructor should receive the loaded weights"
                );
                Ok(Box::new(RecordingVlmModel {
                    seen_image_shape: Arc::clone(&sink),
                }))
            },
        ),
    );
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    let model_dir = fresh_dir("loaded-proc-drives-preprocess");
    write_vlm_dir(
        &model_dir,
        "recordingvlm",
        "preprocessor_config.json",
        "MockProc",
        48,
    );
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors).expect("load should succeed");

    // The two candidate configs disagree, so the test can tell which one was used.
    let loaded_img_cfg = bundle.processor().image_processor_config();
    assert_eq!(loaded_img_cfg.size(), (48, 48));
    assert_eq!(
        crate::vlm::model::Model::image_processor_config(bundle.model()).size(),
        (224, 224),
        "the recording model uses the trait-default 224×224 — the value that must NOT drive preprocessing"
    );

    // A small 10×7 gradient PNG serves as the prompt image.
    let img_path = model_dir.join("prompt.png");
    let pixels = ::image::RgbImage::from_fn(10, 7, |x, y| {
        ::image::Rgb([(x * 20) as u8, (y * 30) as u8, 64])
    });
    ::image::DynamicImage::ImageRgb8(pixels)
        .save_with_format(&img_path, ::image::ImageFormat::Png)
        .unwrap();

    let gen_cfg = VlmGenConfig::new(
        GenConfig::default().with_max_tokens(2),
        99,
        1,
        MarkerPolicy::Required,
    );
    let prompt = [0_u32, 99, 1];
    let steps = vlm_generate(
        bundle.model(),
        &loaded_img_cfg,
        &prompt,
        std::slice::from_ref(&img_path),
        Vec::new(),
        gen_cfg,
    )
    .expect("vlm_generate constructs against the loaded model + loaded processor config");

    let mut generated: Vec<u32> = Vec::new();
    for step in steps {
        generated.push(step.expect("each generation step succeeds").token);
    }
    assert_eq!(generated, vec![0_u32, 0]);

    let seen = recorded
        .lock()
        .unwrap()
        .clone()
        .expect("encode_image must have run on the single image prompt");
    assert_eq!(
        seen,
        vec![48, 48, 3],
        "image preprocessing must use the loaded processor config's size (48×48), \
not the model's default 224×224"
    );
}
/// With both processor config files present, `preprocessor_config.json` wins:
/// only its class (`Preferred`) is registered and its size (32) comes back.
#[test]
fn preprocessor_config_is_preferred_over_processor_config() {
    let model_dir = fresh_dir("prefer-preprocessor");

    let preferred_body = mock_preprocessor_config_json("Preferred", 32);
    let fallback_body = mock_preprocessor_config_json("Fallback", 999);
    std::fs::write(model_dir.join("config.json"), mock_config_json("mockvlm")).unwrap();
    std::fs::write(model_dir.join("preprocessor_config.json"), preferred_body).unwrap();
    std::fs::write(model_dir.join("processor_config.json"), fallback_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("Preferred", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("load should succeed using the preferred preprocessor_config.json");

    let size = bundle.processor().image_processor_config().size();
    assert_eq!(size, (32, 32));
}
/// When only `processor_config.json` exists, the loader falls back to it.
#[test]
fn processor_config_is_used_when_only_fallback_present() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    let model_dir = fresh_dir("fallback-processor-config");
    write_vlm_dir(&model_dir, "mockvlm", "processor_config.json", "MockProc", 48);
    // Precondition: the preferred file really is absent.
    let preferred = model_dir.join("preprocessor_config.json");
    assert!(!preferred.exists());

    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("fallback processor_config load");
    let size = bundle.processor().image_processor_config().size();
    assert_eq!(size, (48, 48));
}
/// `from_id` given a local directory path resolves to that directory, and the
/// resulting configuration loads.
#[test]
fn from_id_resolves_as_local_path() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    let model_dir = fresh_dir("idpath");
    write_vlm_dir(&model_dir, "mockvlm", "preprocessor_config.json", "MockProc", 24);

    let id = model_dir.to_str().unwrap();
    let configuration = VlmModelConfiguration::from_id(id);
    assert_eq!(configuration.model_directory(), model_dir.as_path());

    let bundle = load(&configuration, &models, &processors)
        .expect("id-as-local-path load should succeed");
    assert_eq!(bundle.config_ref().model_type(), "mockvlm");
}
/// The tokenizer may live in a directory other than the model directory when
/// the configuration names it via `with_tokenizer_source`.
#[test]
fn tokenizer_source_loads_from_separate_directory() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    // Model directory deliberately lacks a tokenizer.
    let model_dir = fresh_dir("split-model");
    write_vlm_dir_no_tokenizer(
        &model_dir,
        "mockvlm",
        "preprocessor_config.json",
        "MockProc",
        16,
    );
    assert!(!model_dir.join("tokenizer.json").exists());

    // The tokenizer gets a directory of its own.
    let tok_dir = fresh_dir("split-tok");
    write_tokenizer(&tok_dir);

    let configuration =
        VlmModelConfiguration::from_directory(&model_dir).with_tokenizer_source(&tok_dir);
    assert_eq!(configuration.tokenizer_directory(), tok_dir.as_path());

    let bundle = load(&configuration, &models, &processors).expect("split-tokenizer load");
    let encoded = bundle.tokenizer().encode("a b c", false).unwrap();
    assert_eq!(encoded.len(), 3);
}
/// An unregistered `model_type` must surface as `Error::MissingKey` carrying
/// that type. The directory holds a corrupt weights file and no tokenizer or
/// processor config, and the error text must mention neither weights nor
/// tokenizer.
#[test]
fn unknown_model_type_is_recoverable_error_with_no_io_beyond_config() {
    let model_dir = fresh_dir("unknown-model-cheap");
    std::fs::write(model_dir.join("config.json"), mock_config_json("nope")).unwrap();
    let bogus_weights: &[u8] = b"this is not a safetensors file";
    std::fs::write(model_dir.join("model.safetensors"), bogus_weights).unwrap();
    assert!(!model_dir.join("tokenizer.json").exists());
    assert!(!model_dir.join("preprocessor_config.json").exists());

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let err = match load(&configuration, &models, &processors) {
        Err(err) => err,
        Ok(_) => panic!("unknown VLM model_type must error"),
    };

    let Error::MissingKey(payload) = &err else {
        panic!("expected MissingKey, got: {err:?}");
    };
    assert_eq!(payload.key(), "nope", "error should carry the model_type as key");
    assert!(
        payload.context().contains("model_type"),
        "context should name model_type, got: {}",
        payload.context()
    );

    let msg = err.to_string();
    let mentions_later_stage = msg.contains("safetensors") || msg.contains("tokenizer.json");
    assert!(
        !mentions_later_stage,
        "weights/tokenizer must not have been loaded, got: {msg}"
    );
}
/// An unregistered `processor_class` must surface as `Error::MissingKey`
/// carrying that class. The directory holds a corrupt weights file and no
/// tokenizer, and the error text must mention neither.
#[test]
fn unknown_processor_class_is_recoverable_error_with_no_weight_io() {
    let model_dir = fresh_dir("unknown-processor-cheap");
    std::fs::write(model_dir.join("config.json"), mock_config_json("mockvlm")).unwrap();
    let processor_body = mock_preprocessor_config_json("WrongProc", 16);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();
    let bogus_weights: &[u8] = b"this is not a safetensors file";
    std::fs::write(model_dir.join("model.safetensors"), bogus_weights).unwrap();
    assert!(!model_dir.join("tokenizer.json").exists());

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let err = match load(&configuration, &models, &processors) {
        Err(err) => err,
        Ok(_) => panic!("unknown processor class must error"),
    };

    let Error::MissingKey(payload) = &err else {
        panic!("expected MissingKey, got: {err:?}");
    };
    assert_eq!(
        payload.key(),
        "WrongProc",
        "error should carry the processor_class as key"
    );
    assert!(
        payload.context().contains("processor_class"),
        "context should name processor_class, got: {}",
        payload.context()
    );

    let msg = err.to_string();
    let mentions_later_stage = msg.contains("safetensors") || msg.contains("tokenizer.json");
    assert!(
        !mentions_later_stage,
        "weights/tokenizer must not have been loaded, got: {msg}"
    );
}
/// With no processor config file at all, `load` must fail with a `FileIo`
/// open error whose path names `processor_config.json`.
#[test]
fn missing_processor_config_is_recoverable_error() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    // Only `config.json` exists in the directory.
    let model_dir = fresh_dir("no-proc-config");
    std::fs::write(model_dir.join("config.json"), mock_config_json("mockvlm")).unwrap();

    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let err = match load(&configuration, &models, &processors) {
        Err(err) => err,
        Ok(_) => panic!("missing processor config must error"),
    };

    let Error::FileIo(payload) = &err else {
        panic!("expected FileIo for missing processor config, got: {err:?}");
    };
    assert_eq!(payload.op(), FileOp::Open);
    let s = payload.path().to_string_lossy();
    assert!(
        s.contains("processor_config.json"),
        "FileIo path should be the fallback processor_config.json, got: {s}"
    );
}
/// For `model_type` `mistral3`, a processor config declaring
/// `PixtralProcessor` must dispatch to the constructor registered as
/// `Mistral3Processor` (the only one registered here).
#[test]
fn processor_class_override_applies_for_mistral3() {
    let models = VlmTypeRegistry::new().with("mistral3", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("Mistral3Processor", mock_processor_constructor());

    let model_dir = fresh_dir("mistral3-override");
    write_vlm_dir(
        &model_dir,
        "mistral3",
        "preprocessor_config.json",
        "PixtralProcessor",
        40,
    );

    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("mistral3 override should dispatch to Mistral3Processor");
    let size = bundle.processor().image_processor_config().size();
    assert_eq!(size, (40, 40));
}
/// Registering under `lfm2-vl` makes the entry reachable under both
/// `lfm2-vl` and `lfm2_vl`; `remap_vlm_model_type` maps the hyphenated name
/// to the underscored one and leaves `qwen3_vl` unchanged.
#[test]
fn vlm_remap_applies_on_registration_and_lookup() {
    let registry = VlmTypeRegistry::new().with("lfm2-vl", mock_vlm_constructor());
    for known in ["lfm2-vl", "lfm2_vl"] {
        assert!(registry.contains(known));
    }
    assert!(!registry.contains("qwen3_vl"));

    for (input, expected) in [("lfm2-vl", "lfm2_vl"), ("qwen3_vl", "qwen3_vl")] {
        assert_eq!(remap_vlm_model_type(input), expected);
    }
}
/// `register` on either registry returns `None` for a new key and `Some` when
/// an entry under that key already existed.
#[test]
fn register_replaces_and_returns_previous() {
    let mut models = VlmTypeRegistry::new();
    let first_model = models.register("mockvlm", mock_vlm_constructor());
    assert!(first_model.is_none());
    let second_model = models.register("mockvlm", mock_vlm_constructor());
    assert!(second_model.is_some());

    let mut processors = VlmProcessorTypeRegistry::new();
    let first_processor = processors.register("MockProc", mock_processor_constructor());
    assert!(first_processor.is_none());
    let second_processor = processors.register("MockProc", mock_processor_constructor());
    assert!(second_processor.is_some());
}
/// Both mock constructors parse the raw JSON they are handed, so a successful
/// load that reports the written image size shows the raw text reached them.
#[test]
fn raw_config_and_processor_json_reach_constructors() {
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());

    let model_dir = fresh_dir("raw-dispatch");
    write_vlm_dir(&model_dir, "mockvlm", "preprocessor_config.json", "MockProc", 24);

    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors).expect("load");
    let size = bundle.processor().image_processor_config().size();
    assert_eq!(size, (24, 24));
}
/// A config whose language-model fields live only under `text_config` still
/// loads, and the mock model picks up the nested `vocab_size` of 5.
#[test]
fn load_succeeds_for_nested_vlm_config_with_no_top_level_lm_fields() {
    let model_dir = fresh_dir("nested-vlm-config");
    let config_body = mock_nested_config_json("mockvlm");
    std::fs::write(model_dir.join("config.json"), config_body).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 32);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("nested-config VLM should load (no top-level LM fields)");
    assert_eq!(bundle.config_ref().model_type(), "mockvlm");

    let tokens = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).unwrap();
    let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
    let logits = LmModel::forward(bundle.model(), &tokens, &mut cache).unwrap();
    assert_eq!(logits.shape(), vec![1, 3, 5]);
}
/// A top-level `eos_token_id` list in `config.json` must appear on the base
/// config and be installed as the tokenizer's EOS set.
#[test]
fn eos_token_id_on_vlm_config_flows_to_tokenizer() {
    let dir = fresh_dir("eos-from-config");
    let cfg = r#"{
"model_type": "mockvlm",
"eos_token_id": [1, 2],
"text_config": {
"hidden_size": 8,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 2,
"rope_theta": 10000.0,
"vocab_size": 5,
"tie_word_embeddings": false
},
"mock_extra": 11
}"#;
    std::fs::write(dir.join("config.json"), cfg).unwrap();
    std::fs::write(
        dir.join("preprocessor_config.json"),
        mock_preprocessor_config_json("MockProc", 24),
    )
    .unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "mock.weight".to_owned(),
        Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap(),
    );
    crate::io::save_safetensors(&dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&dir);
    let model_registry = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processor_registry =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&dir);
    let ctx = load(&configuration, &model_registry, &processor_registry).expect("eos config load");
    assert_eq!(
        ctx.config_ref().eos_token_id().cloned(),
        Some(EosTokenId::Many(vec![1, 2])),
        "base config should carry the top-level eos_token_id list"
    );
    // The assertion is about set membership, so normalize the iteration order
    // before comparing. The multi-id promotion test in this file sorts for the
    // same reason; without it the check depends on how the tokenizer happens to
    // order its ids.
    let mut eos_vec: Vec<u32> = ctx.tokenizer().eos_token_ids_iter().collect();
    eos_vec.sort_unstable();
    assert_eq!(
        eos_vec,
        vec![1u32, 2],
        "tokenizer eos set should be exactly the resolved {{1, 2}}"
    );
}
/// `generation_config.json`'s `eos_token_id` (2) must replace the one in
/// `config.json` (1), both on the base config and in the tokenizer.
#[test]
fn generation_config_eos_overrides_vlm_base_config_eos() {
    let model_dir = fresh_dir("eos-generation-override");
    let cfg = r#"{
"model_type": "mockvlm",
"eos_token_id": 1,
"text_config": {
"hidden_size": 8,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 2,
"rope_theta": 10000.0,
"vocab_size": 5,
"tie_word_embeddings": false
},
"mock_extra": 11
}"#;
    let generation_cfg = r#"{"eos_token_id": 2}"#;
    std::fs::write(model_dir.join("config.json"), cfg).unwrap();
    std::fs::write(model_dir.join("generation_config.json"), generation_cfg).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("eos generation override load");

    let resolved = bundle.config_ref().eos_token_id().cloned();
    assert_eq!(
        resolved,
        Some(EosTokenId::Single(2)),
        "generation_config.json eos_token_id should override config.json"
    );
    let tokenizer_eos: Vec<u32> = bundle.tokenizer().eos_token_ids_iter().collect();
    assert_eq!(
        tokenizer_eos,
        vec![2u32],
        "tokenizer eos set should be the overridden {{2}}"
    );
}
/// A body containing only `model_type` parses as a `VlmBaseConfig` but is
/// rejected by the language-model `Config`, whose error names what is missing.
#[test]
fn vlm_base_config_parses_without_top_level_lm_fields() {
    let body = r#"{ "model_type": "qwen2_vl" }"#;

    let parsed = VlmBaseConfig::from_json(body).expect("VLM base config should parse");
    assert_eq!(parsed.model_type(), "qwen2_vl");
    assert_eq!(parsed.eos_token_id().cloned(), None);
    assert_eq!(parsed.quantization(), None);

    let msg = crate::lm::load::Config::from_json(body)
        .expect_err("LM Config should reject a model_type-only body")
        .to_string();
    let names_missing_field = msg.contains("hidden_size") || msg.contains("missing field");
    assert!(
        names_missing_field,
        "LM Config parse error should name the missing LM field, got: {msg}"
    );
}
/// Writes a `tokenizer_config.json` into `dir` that declares `"c"` as the EOS
/// token. Panics if the file cannot be written.
fn write_tokenizer_config_with_eos_c(dir: &Path) {
    let target = dir.join("tokenizer_config.json");
    let body = r#"{ "eos_token": "c" }"#;
    std::fs::write(target, body).unwrap();
}
/// With no top-level `eos_token_id`, the list under `text_config` is promoted
/// onto the base config and becomes the tokenizer's EOS set, even though a
/// `tokenizer_config.json` naming `"c"` as EOS is also present.
#[test]
fn nested_text_config_eos_promotes_to_tokenizer() {
    let model_dir = fresh_dir("nested-text-config-eos");
    let cfg = r#"{
"model_type": "mockvlm",
"text_config": {
"hidden_size": 8,
"vocab_size": 5,
"eos_token_id": [42, 50]
},
"mock_extra": 11
}"#;
    std::fs::write(model_dir.join("config.json"), cfg).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);
    write_tokenizer_config_with_eos_c(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("nested text_config.eos_token_id should promote");

    let resolved = bundle.config_ref().eos_token_id().cloned();
    assert_eq!(
        resolved,
        Some(EosTokenId::Many(vec![42, 50])),
        "VlmBaseConfig should carry the promoted text_config.eos_token_id list"
    );

    // Sorted before comparing so the check does not depend on iteration order.
    let mut tokenizer_eos: Vec<u32> = bundle.tokenizer().eos_token_ids_iter().collect();
    tokenizer_eos.sort_unstable();
    assert_eq!(
        tokenizer_eos,
        vec![42u32, 50],
        "tokenizer eos set should be exactly the promoted {{42, 50}}, not the tokenizer-config fallback"
    );
}
/// When both a top-level `eos_token_id` (7) and a nested
/// `text_config.eos_token_id` ([42, 50]) exist, the top-level value wins on
/// the base config and in the tokenizer.
#[test]
fn top_level_eos_wins_over_nested_text_config_eos() {
    let model_dir = fresh_dir("top-eos-wins-over-nested");
    let cfg = r#"{
"model_type": "mockvlm",
"eos_token_id": 7,
"text_config": {
"hidden_size": 8,
"vocab_size": 5,
"eos_token_id": [42, 50]
},
"mock_extra": 11
}"#;
    std::fs::write(model_dir.join("config.json"), cfg).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("top-level eos with nested present");

    let resolved = bundle.config_ref().eos_token_id().cloned();
    assert_eq!(
        resolved,
        Some(EosTokenId::Single(7)),
        "top-level eos_token_id must win over nested text_config.eos_token_id"
    );
    let tokenizer_eos: Vec<u32> = bundle.tokenizer().eos_token_ids_iter().collect();
    assert_eq!(
        tokenizer_eos,
        vec![7u32],
        "tokenizer eos set must be the top-level {{7}}, not nested {{42, 50}}"
    );
}
/// `generation_config.json`'s `eos_token_id` (9) must override a value that
/// would otherwise be promoted from `text_config.eos_token_id` ([42, 50]).
#[test]
fn generation_config_eos_overrides_promoted_nested_eos() {
    let model_dir = fresh_dir("gen-cfg-overrides-promoted-nested");
    let cfg = r#"{
"model_type": "mockvlm",
"text_config": {
"hidden_size": 8,
"vocab_size": 5,
"eos_token_id": [42, 50]
},
"mock_extra": 11
}"#;
    let generation_cfg = r#"{"eos_token_id": 9}"#;
    std::fs::write(model_dir.join("config.json"), cfg).unwrap();
    std::fs::write(model_dir.join("generation_config.json"), generation_cfg).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("generation_config override over promoted nested eos");

    let resolved = bundle.config_ref().eos_token_id().cloned();
    assert_eq!(
        resolved,
        Some(EosTokenId::Single(9)),
        "generation_config.json eos_token_id must override the promoted nested value"
    );
    let tokenizer_eos: Vec<u32> = bundle.tokenizer().eos_token_ids_iter().collect();
    assert_eq!(
        tokenizer_eos,
        vec![9u32],
        "tokenizer eos set must be the post-override {{9}}, not the promoted nested set"
    );
}
/// With no `text_config` and no top-level `eos_token_id`, the list under
/// `llm_config` is promoted onto the base config.
#[test]
fn nested_llm_config_eos_promotes_when_text_config_absent() {
    let model_dir = fresh_dir("nested-llm-config-eos");
    let cfg = r#"{
"model_type": "mockvlm",
"vocab_size": 5,
"llm_config": {
"hidden_size": 8,
"eos_token_id": [11, 13]
},
"mock_extra": 11
}"#;
    std::fs::write(model_dir.join("config.json"), cfg).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("nested llm_config.eos_token_id alias should promote");

    let resolved = bundle.config_ref().eos_token_id().cloned();
    assert_eq!(
        resolved,
        Some(EosTokenId::Many(vec![11, 13])),
        "VlmBaseConfig should carry the promoted llm_config.eos_token_id list"
    );
}
/// When both `text_config` and `llm_config` carry an `eos_token_id`, the
/// `text_config` value ([42, 50]) is the one promoted onto the base config.
#[test]
fn text_config_eos_wins_over_llm_config_alias_when_both_present() {
    let model_dir = fresh_dir("text-config-wins-over-llm-config");
    let cfg = r#"{
"model_type": "mockvlm",
"text_config": {
"hidden_size": 8,
"vocab_size": 5,
"eos_token_id": [42, 50]
},
"llm_config": {
"eos_token_id": [11, 13]
},
"mock_extra": 11
}"#;
    std::fs::write(model_dir.join("config.json"), cfg).unwrap();
    let processor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(model_dir.join("preprocessor_config.json"), processor_body).unwrap();

    let tensor = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);
    crate::io::save_safetensors(&model_dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let bundle = load(&configuration, &models, &processors)
        .expect("text_config should win over llm_config alias when both present");

    let resolved = bundle.config_ref().eos_token_id().cloned();
    assert_eq!(
        resolved,
        Some(EosTokenId::Many(vec![42, 50])),
        "text_config.eos_token_id must take precedence over llm_config.eos_token_id"
    );
}
/// A scalar `0` for `text_config.eos_token_id` is treated as falsy: nothing is
/// promoted onto the base config, and the tokenizer's eos set is `[2]`.
#[test]
fn falsy_nested_eos_does_not_promote() {
    let model_dir = fresh_dir("falsy-nested-eos");

    // The nested eos id is the scalar 0.
    let config_body = r#"{
"model_type": "mockvlm",
"text_config": {
"hidden_size": 8,
"vocab_size": 5,
"eos_token_id": 0
},
"mock_extra": 11
}"#;
    let config_path = model_dir.join("config.json");
    std::fs::write(config_path, config_body).unwrap();

    let preprocessor_path = model_dir.join("preprocessor_config.json");
    let preprocessor_body = mock_preprocessor_config_json("MockProc", 24);
    std::fs::write(preprocessor_path, preprocessor_body).unwrap();

    // A single one-element tensor, saved as the model's safetensors file.
    let mut tensors: Weights = HashMap::new();
    let placeholder = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    tensors.insert("mock.weight".to_owned(), placeholder);
    let weights_path = model_dir.join("model.safetensors");
    crate::io::save_safetensors(&weights_path, &tensors).unwrap();
    write_tokenizer(&model_dir);
    // NOTE(review): presumably writes a tokenizer_config whose `eos_token`
    // resolves to id 2 (see the final assertion) — confirm against the helper.
    write_tokenizer_config_with_eos_c(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &models, &processors).expect("falsy eos");

    let promoted = ctx.config_ref().eos_token_id().cloned();
    assert_eq!(
        promoted,
        None,
        "scalar 0 nested eos must not promote (falsy)"
    );

    let tokenizer_eos: Vec<u32> = ctx.tokenizer().eos_token_ids_iter().collect();
    assert_eq!(
        tokenizer_eos,
        vec![2u32],
        "tokenizer should fall back to its tokenizer_config `eos_token` when nested is falsy"
    );
}
/// Builds a `preprocessor_config.json` body that carries image metadata
/// (`image_mean`, `image_std`, `mock_image_size`) but no `processor_class`.
fn mock_image_only_preprocessor_config_json(image_size: u32) -> String {
    format!(
        "{{\n\"image_mean\": [0.5, 0.5, 0.5],\n\"image_std\": [0.5, 0.5, 0.5],\n\"mock_image_size\": {}\n}}",
        image_size
    )
}
/// Builds a `processor_config.json` body whose only key is `processor_class`.
/// The class name is interpolated verbatim (no JSON escaping).
fn mock_processor_class_only_config_json(processor_class: &str) -> String {
    ["{ \"processor_class\": \"", processor_class, "\" }"].concat()
}
/// Builds a `processor_config.json` body with a `processor_class` plus one
/// non-class field, `image_seq_len`.
fn mock_processor_config_with_seq_len(processor_class: &str, image_seq_len: u32) -> String {
    format!(
        "{{ \"processor_class\": \"{}\", \"image_seq_len\": {} }}",
        processor_class, image_seq_len
    )
}
/// `preprocessor_config.json` has no `processor_class`, so the dispatch class
/// must be taken from `processor_config.json`, while the constructor still
/// sees the image metadata from the preprocessor file.
#[test]
fn processor_class_falls_back_to_processor_config_when_preprocessor_has_none() {
    let model_dir = fresh_dir("split-dispatch-preprocessor-no-class");

    let config_path = model_dir.join("config.json");
    std::fs::write(config_path, mock_config_json("mockvlm")).unwrap();

    // Image metadata only: no `processor_class` key in this file.
    let preprocessor_path = model_dir.join("preprocessor_config.json");
    let preprocessor_body = mock_image_only_preprocessor_config_json(24);
    std::fs::write(preprocessor_path, preprocessor_body).unwrap();

    // The class name is supplied by the sibling processor config instead.
    let processor_path = model_dir.join("processor_config.json");
    let processor_body = mock_processor_class_only_config_json("MockProc");
    std::fs::write(processor_path, processor_body).unwrap();

    // A single one-element tensor, saved as the model's safetensors file.
    let mut tensors: Weights = HashMap::new();
    let placeholder = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    tensors.insert("mock.weight".to_owned(), placeholder);
    let weights_path = model_dir.join("model.safetensors");
    crate::io::save_safetensors(&weights_path, &tensors).unwrap();
    write_tokenizer(&model_dir);

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors =
        VlmProcessorTypeRegistry::new().with("MockProc", mock_processor_constructor());
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &models, &processors)
        .expect("dispatch class from processor_config.json + body from preprocessor_config.json");

    let reported_size = ctx.processor().image_processor_config().size();
    assert_eq!(
        reported_size,
        (24, 24),
        "constructor must see preprocessor_config.json body (image-preprocessor metadata)"
    );
}
/// Split layout: `preprocessor_config.json` holds image metadata without a
/// `processor_class`, while `processor_config.json` names the class and carries
/// `image_seq_len`.
///
/// Checks two things: `load_processor_config` returns both file bodies, and
/// `load` hands both bodies to the registered processor constructor.
#[test]
fn split_layout_carries_both_preprocessor_and_processor_config_bodies() {
    let dir = fresh_dir("split-carries-both-bodies");
    std::fs::write(
        dir.join("preprocessor_config.json"),
        mock_image_only_preprocessor_config_json(24),
    )
    .unwrap();
    std::fs::write(
        dir.join("processor_config.json"),
        mock_processor_config_with_seq_len("MockProc", 256),
    )
    .unwrap();

    // Part 1: resolve the processor config directly, before the rest of the
    // model directory (config.json, weights, tokenizer) has been written.
    let (proc_config, preprocessor_body, processor_body, filename) =
        load_processor_config(&dir).expect("split-layout processor config must resolve");
    assert_eq!(
        proc_config.processor_class(),
        "MockProc",
        "dispatch class must come from processor_config.json"
    );
    assert_eq!(
        filename, "preprocessor_config.json",
        "primary-body filename is the preprocessor file (image-preprocessor metadata source)"
    );
    let preprocessor_body =
        preprocessor_body.expect("preprocessor_config.json body must be carried");
    assert!(
        preprocessor_body.contains("mock_image_size"),
        "preprocessor body must carry the image-preprocessor metadata, got: {preprocessor_body}"
    );
    let processor_body =
        processor_body.expect("processor_config.json body must be carried, not discarded");
    assert!(
        processor_body.contains("image_seq_len"),
        "processor_config.json body must survive with its non-class field, got: {processor_body}"
    );

    // Part 2: complete the model directory and run the full `load` path.
    std::fs::write(dir.join("config.json"), mock_config_json("mockvlm")).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "mock.weight".to_owned(),
        Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap(),
    );
    crate::io::save_safetensors(&dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&dir);

    // The constructor fails unless BOTH bodies arrive and `image_seq_len`
    // comes through as 256, so a successful `load` proves the hand-off.
    let asserting_processor_ctor: ProcessorConstructor = Box::new(
        |loaded: &LoadedProcessor<'_>| -> Result<Box<dyn Processor>> {
            let preprocessor = loaded.preprocessor_config_json.ok_or_else(|| {
                Error::MissingField(MissingFieldPayload::new(
                    "LoadedProcessor",
                    "preprocessor_config_json",
                ))
            })?;
            let processor = loaded.processor_config_json.ok_or_else(|| {
                Error::MissingField(MissingFieldPayload::new(
                    "LoadedProcessor (the carried body)",
                    "processor_config_json",
                ))
            })?;
            let pre: serde_json::Value = serde_json::from_str(preprocessor).map_err(|e| {
                Error::Parse(ParsePayload::new(
                    "asserting_processor_ctor: preprocessor body",
                    "JSON",
                    e,
                ))
            })?;
            let image_size = pre
                .get("mock_image_size")
                .and_then(serde_json::Value::as_u64)
                .and_then(|x| u32::try_from(x).ok())
                // Build the error lazily, like every other lookup in this closure.
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "preprocessor body",
                        "mock_image_size",
                    ))
                })?;
            let proc: serde_json::Value = serde_json::from_str(processor).map_err(|e| {
                Error::Parse(ParsePayload::new(
                    "asserting_processor_ctor: processor_config.json body",
                    "JSON",
                    e,
                ))
            })?;
            let seq_len = proc
                .get("image_seq_len")
                .and_then(serde_json::Value::as_u64)
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "processor_config.json body (the discarded field)",
                        "image_seq_len",
                    ))
                })?;
            if seq_len != 256 {
                return Err(Error::LengthMismatch(
                    crate::error::LengthMismatchPayload::new(
                        "asserting_processor_ctor: image_seq_len round-trip",
                        256,
                        // Saturate rather than `as`-truncate where `usize` is
                        // narrower than `u64`, so the reported value cannot
                        // wrap around to something misleading.
                        usize::try_from(seq_len).unwrap_or(usize::MAX),
                    ),
                ));
            }
            Ok(Box::new(MockVlmProcessor {
                processor_class: loaded.processor_class.to_owned(),
                image_size,
            }))
        },
    );
    let model_registry = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processor_registry =
        VlmProcessorTypeRegistry::new().with("MockProc", asserting_processor_ctor);
    let configuration = VlmModelConfiguration::from_directory(&dir);
    let ctx = load(&configuration, &model_registry, &processor_registry)
        .expect("split-layout load must surface BOTH config bodies to the processor constructor");
    assert_eq!(ctx.processor().image_processor_config().size(), (24, 24));
}
/// Preferred-class layout: `preprocessor_config.json` names `MockProc` and
/// `processor_config.json` names a different class (`OtherProc`) plus
/// `image_seq_len`.
///
/// Checks that dispatch uses the preprocessor file's class, that
/// `load_processor_config` still returns the processor file's body, and that
/// `load` hands both bodies to the registered processor constructor.
#[test]
fn preferred_class_layout_still_carries_processor_config_body() {
    let dir = fresh_dir("preferred-class-carries-processor-body");
    std::fs::write(
        dir.join("preprocessor_config.json"),
        mock_preprocessor_config_json("MockProc", 24),
    )
    .unwrap();
    std::fs::write(
        dir.join("processor_config.json"),
        mock_processor_config_with_seq_len("OtherProc", 256),
    )
    .unwrap();

    // Part 1: resolve the processor config directly, before the rest of the
    // model directory (config.json, weights, tokenizer) has been written.
    let (proc_config, preprocessor_body, processor_body, filename) =
        load_processor_config(&dir).expect("preferred-class processor config must resolve");
    assert_eq!(
        proc_config.processor_class(),
        "MockProc",
        "dispatch class must come from preprocessor_config.json (precedence unchanged)"
    );
    assert_eq!(
        filename, "preprocessor_config.json",
        "primary-body filename is the preprocessor file"
    );
    let preprocessor_body =
        preprocessor_body.expect("preprocessor_config.json body must be carried");
    assert!(
        preprocessor_body.contains("mock_image_size"),
        "preprocessor body must carry the image-preprocessor metadata, got: {preprocessor_body}"
    );
    let processor_body = processor_body
        .expect("processor_config.json body must be carried in the preferred-class path too");
    assert!(
        processor_body.contains("image_seq_len"),
        "processor_config.json body must survive with its non-class field, got: {processor_body}"
    );

    // Part 2: complete the model directory and run the full `load` path.
    std::fs::write(dir.join("config.json"), mock_config_json("mockvlm")).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert(
        "mock.weight".to_owned(),
        Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap(),
    );
    crate::io::save_safetensors(&dir.join("model.safetensors"), &weights).unwrap();
    write_tokenizer(&dir);

    // Registered under "MockProc" only. The constructor fails unless BOTH
    // bodies arrive and `image_seq_len` comes through as 256.
    let asserting_processor_ctor: ProcessorConstructor = Box::new(
        |loaded: &LoadedProcessor<'_>| -> Result<Box<dyn Processor>> {
            let preprocessor = loaded.preprocessor_config_json.ok_or_else(|| {
                Error::MissingField(MissingFieldPayload::new(
                    "LoadedProcessor",
                    "preprocessor_config_json",
                ))
            })?;
            let processor = loaded.processor_config_json.ok_or_else(|| {
                Error::MissingField(MissingFieldPayload::new(
                    "LoadedProcessor (the carried body)",
                    "processor_config_json",
                ))
            })?;
            let pre: serde_json::Value = serde_json::from_str(preprocessor).map_err(|e| {
                Error::Parse(ParsePayload::new(
                    "asserting_processor_ctor: preprocessor body",
                    "JSON",
                    e,
                ))
            })?;
            let image_size = pre
                .get("mock_image_size")
                .and_then(serde_json::Value::as_u64)
                .and_then(|x| u32::try_from(x).ok())
                // Build the error lazily, like every other lookup in this closure.
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "preprocessor body",
                        "mock_image_size",
                    ))
                })?;
            let proc: serde_json::Value = serde_json::from_str(processor).map_err(|e| {
                Error::Parse(ParsePayload::new(
                    "asserting_processor_ctor: processor_config.json body",
                    "JSON",
                    e,
                ))
            })?;
            let seq_len = proc
                .get("image_seq_len")
                .and_then(serde_json::Value::as_u64)
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "processor_config.json body (the discarded field)",
                        "image_seq_len",
                    ))
                })?;
            if seq_len != 256 {
                return Err(Error::LengthMismatch(
                    crate::error::LengthMismatchPayload::new(
                        "asserting_processor_ctor: image_seq_len round-trip",
                        256,
                        // Saturate rather than `as`-truncate where `usize` is
                        // narrower than `u64`, so the reported value cannot
                        // wrap around to something misleading.
                        usize::try_from(seq_len).unwrap_or(usize::MAX),
                    ),
                ));
            }
            Ok(Box::new(MockVlmProcessor {
                processor_class: loaded.processor_class.to_owned(),
                image_size,
            }))
        },
    );
    let model_registry = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processor_registry =
        VlmProcessorTypeRegistry::new().with("MockProc", asserting_processor_ctor);
    let configuration = VlmModelConfiguration::from_directory(&dir);
    let ctx = load(&configuration, &model_registry, &processor_registry).expect(
        "preferred-class load must dispatch on preprocessor_config.json's class and surface \
         BOTH config bodies",
    );
    assert_eq!(ctx.processor().image_processor_config().size(), (24, 24));
}
/// If neither `preprocessor_config.json` nor `processor_config.json` declares a
/// `processor_class`, `load` must return `Error::MissingField` naming that
/// field rather than succeed.
#[test]
fn neither_processor_config_file_has_processor_class_is_recoverable_error() {
    let model_dir = fresh_dir("neither-has-processor-class");

    let config_path = model_dir.join("config.json");
    std::fs::write(config_path, mock_config_json("mockvlm")).unwrap();

    // Both processor-side files exist, but neither has a `processor_class` key.
    let preprocessor_path = model_dir.join("preprocessor_config.json");
    let preprocessor_body = mock_image_only_preprocessor_config_json(16);
    std::fs::write(preprocessor_path, preprocessor_body).unwrap();
    let processor_path = model_dir.join("processor_config.json");
    std::fs::write(processor_path, r#"{ "some_other_key": 1 }"#).unwrap();

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors = VlmProcessorTypeRegistry::new();
    let configuration = VlmModelConfiguration::from_directory(&model_dir);

    let err = match load(&configuration, &models, &processors) {
        Err(err) => err,
        Ok(_) => panic!("missing-processor-class across both files must error"),
    };
    let Error::MissingField(payload) = &err else {
        panic!("expected MissingField for missing processor_class, got: {err:?}");
    };
    assert_eq!(
        payload.field(),
        "processor_class",
        "MissingField should name processor_class as the field"
    );
    assert!(
        payload.type_name().contains("ProcessorConfig"),
        "type_name should name ProcessorConfig, got: {}",
        payload.type_name()
    );
}
/// Test-only processor type carrying a marker value, used to check that a
/// loaded processor can be downcast back to its concrete type.
struct MockConcreteProcessor {
    /// Marker value read back through [`MockConcreteProcessor::mock_special`].
    special: u32,
}

impl MockConcreteProcessor {
    /// Returns the marker value stored at construction.
    fn mock_special(&self) -> u32 {
        let Self { special } = self;
        *special
    }
}
impl Processor for MockConcreteProcessor {
    /// Returns a fixed image-processor configuration: 1x1 size, 0.5 mean/std
    /// per channel, 1/255 rescale, all three steps enabled, bilinear, RGB.
    fn image_processor_config(&self) -> ImageProcessorConfig {
        // Numeric parameters first...
        let numeric = ImageProcessorConfig::new()
            .with_size((1, 1))
            .with_mean([0.5; 3])
            .with_std([0.5; 3])
            .with_rescale_factor(1.0 / 255.0);
        // ...then the step toggles and pixel-format options, in the same
        // builder-call order as before.
        numeric
            .with_do_resize(true)
            .with_do_rescale(true)
            .with_do_normalize(true)
            .with_resample(ResizeFilter::Bilinear)
            .with_color_order(ColorOrder::Rgb)
    }

    /// Exposes `self` as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Mutable counterpart of [`Processor::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }
}
/// The processor returned by `load` must be downcastable, via `as_any`, to the
/// concrete type its registered constructor produced.
#[test]
fn loaded_processor_downcasts_to_concrete_per_model_type() {
    let model_dir = fresh_dir("processor-downcast");
    write_vlm_dir(
        &model_dir,
        "mockvlm",
        "preprocessor_config.json",
        "MockConcreteProc",
        64,
    );

    // Constructor that ignores its input and returns a processor with a known marker.
    let concrete_ctor: ProcessorConstructor = Box::new(
        |_loaded: &LoadedProcessor<'_>| -> Result<Box<dyn Processor>> {
            Ok(Box::new(MockConcreteProcessor { special: 4242 }))
        },
    );
    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors = VlmProcessorTypeRegistry::new().with("MockConcreteProc", concrete_ctor);
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &models, &processors)
        .expect("load should construct the concrete processor");

    let as_any = ctx.processor().as_any();
    let concrete = as_any
        .downcast_ref::<MockConcreteProcessor>()
        .expect("loaded processor must downcast to its concrete per-model type");
    assert_eq!(concrete.mock_special(), 4242);
}
/// The processor constructor must receive the model's `config.json` body via
/// `LoadedProcessor.config_json`. The constructor here derives its image size
/// from `text_config.hidden_size` (8) instead of the preprocessor file's
/// `mock_image_size` (999), so the reported size shows which source was read.
#[test]
fn loaded_processor_reads_model_config_json_only_arch_field() {
    let model_dir = fresh_dir("processor-reads-model-config-json");

    let config_path = model_dir.join("config.json");
    std::fs::write(config_path, mock_nested_config_json("mockvlm")).unwrap();

    // 999 is a decoy: the constructor below never reads this value.
    let preprocessor_path = model_dir.join("preprocessor_config.json");
    let preprocessor_body = mock_preprocessor_config_json("MockProc", 999);
    std::fs::write(preprocessor_path, preprocessor_body).unwrap();

    // A single one-element tensor, saved as the model's safetensors file.
    let mut tensors: Weights = HashMap::new();
    let placeholder = Array::from_slice::<f32>(&[1.0], &(1usize,)).unwrap();
    tensors.insert("mock.weight".to_owned(), placeholder);
    let weights_path = model_dir.join("model.safetensors");
    crate::io::save_safetensors(&weights_path, &tensors).unwrap();
    write_tokenizer(&model_dir);

    let config_reading_ctor: ProcessorConstructor = Box::new(
        |loaded: &LoadedProcessor<'_>| -> Result<Box<dyn Processor>> {
            let parsed: std::result::Result<serde_json::Value, _> =
                serde_json::from_str(loaded.config_json);
            let model_config = parsed.map_err(|e| {
                Error::Parse(ParsePayload::new(
                    "processor ctor: model config_json",
                    "JSON",
                    e,
                ))
            })?;

            // Walk text_config -> hidden_size and narrow the value to u32.
            let raw_hidden = model_config
                .get("text_config")
                .and_then(|text| text.get("hidden_size"))
                .and_then(serde_json::Value::as_u64);
            let hidden_size = raw_hidden
                .and_then(|value| u32::try_from(value).ok())
                .ok_or_else(|| {
                    Error::MissingField(MissingFieldPayload::new(
                        "processor ctor: LoadedProcessor.config_json (config.json-only arch field)",
                        "text_config.hidden_size",
                    ))
                })?;

            Ok(Box::new(MockVlmProcessor {
                processor_class: loaded.processor_class.to_owned(),
                image_size: hidden_size,
            }))
        },
    );

    let models = VlmTypeRegistry::new().with("mockvlm", mock_vlm_constructor());
    let processors = VlmProcessorTypeRegistry::new().with("MockProc", config_reading_ctor);
    let configuration = VlmModelConfiguration::from_directory(&model_dir);
    let ctx = load(&configuration, &models, &processors)
        .expect("load must surface the model config.json to the processor constructor");

    let reported_size = ctx.processor().image_processor_config().size();
    assert_eq!(
        reported_size,
        (8, 8),
        "processor must have read hidden_size=8 off LoadedProcessor.config_json, \
         not mock_image_size=999 off the processor config"
    );
}