use std::{
path::PathBuf,
sync::atomic::{AtomicU64, Ordering},
};
use crate::error::MissingFieldPayload;
use super::*;
use crate::{
array::Array,
lm::{cache::KvCache, generate::FinishReason},
};
/// Renders a small `config.json` document for the mock architecture.
///
/// `model_type` is interpolated verbatim into the `"model_type"` field. All
/// other fields are fixed, including the non-standard `"mock_extra": 7` key
/// that the mock constructor reads back.
fn mock_config_json(model_type: &str) -> String {
    let arch = model_type;
    format!(
        r#"{{
"model_type": "{arch}",
"hidden_size": 8,
"num_hidden_layers": 2,
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 2,
"rope_theta": 10000.0,
"vocab_size": 5,
"tie_word_embeddings": false,
"mock_extra": 7
}}"#
    )
}
/// Stand-in model built by the mock constructors in these tests.
struct MockLoadedModel {
    /// Vocabulary size; becomes the last dimension of the logits from `forward`.
    vocab: i32,
    /// Value of the model-specific `mock_extra` config key. It is only stored,
    /// never read back, hence the lint allowance.
    #[allow(dead_code)]
    mock_extra: i64,
}
impl Model for MockLoadedModel {
    /// Returns all-zero `f32` logits shaped `(batch, seq, vocab)`.
    ///
    /// Rank-1 tokens `[S]` are treated as a batch of one; rank-2 tokens
    /// `[B, S]` are used as-is. Any other rank produces
    /// `Error::RankMismatch`. The cache argument is ignored.
    fn forward(&self, tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
        let dims = tokens.shape();
        let (batch, seq) = match dims.as_slice() {
            [seq] => (1, *seq),
            [batch, seq] => (*batch, *seq),
            unsupported => {
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "MockLoadedModel::forward: tokens must be rank-1 [S] or rank-2 [B, S]",
                    unsupported.len() as u32,
                    unsupported.to_vec(),
                )));
            }
        };
        let vocab = self.vocab as usize;
        let zeros = vec![0.0_f32; batch * seq * vocab];
        Array::from_slice::<f32>(&zeros, &(batch, seq, vocab))
    }
}
fn mock_constructor() -> ModelConstructor {
Box::new(|loaded: &LoadedModel| -> Result<Box<dyn Model>> {
assert!(
!loaded.weights.is_empty(),
"constructor should receive the loaded weights"
);
let raw: serde_json::Value = serde_json::from_str(&loaded.config_json).map_err(|e| {
Error::Parse(crate::error::ParsePayload::new(
"mock ctor: bad config json",
"config.json",
Box::new(e) as Box<dyn std::error::Error + Send + Sync>,
))
})?;
let mock_extra = raw
.get("mock_extra")
.and_then(serde_json::Value::as_i64)
.ok_or(Error::MissingField(MissingFieldPayload::new(
"mock ctor",
"mock_extra",
)))?;
Ok(Box::new(MockLoadedModel {
vocab: loaded.config.vocab_size,
mock_extra,
}))
})
}
/// Creates an empty, uniquely named scratch directory under the system temp dir.
///
/// The directory name combines `tag`, the current process id, and a
/// per-process counter, so repeated calls never share a path. Any directory
/// left over at that path is removed first (best effort).
///
/// # Panics
/// Panics if the directory cannot be created.
fn fresh_dir(tag: &str) -> PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let serial = COUNTER.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let leaf = format!("mlxrs-lm-factory-{tag}-{pid}-{serial}");
    let path = std::env::temp_dir().join(leaf);
    // Removal is allowed to fail: usually there is nothing to remove.
    let _ = std::fs::remove_dir_all(&path);
    std::fs::create_dir_all(&path).unwrap();
    path
}
/// Writes a three-word (`a`, `b`, `c`) word-level `tokenizer.json` into `dir`.
///
/// The tokenizer splits on whitespace and uses `"a"` as its unknown token.
///
/// # Panics
/// Panics if the tokenizer cannot be built or saved.
fn write_tokenizer(dir: &Path) {
    use tokenizers::{
        Tokenizer as HfTokenizer, models::wordlevel::WordLevel,
        pre_tokenizers::whitespace::Whitespace,
    };
    let entries = [("a", 0u32), ("b", 1), ("c", 2)];
    let vocab = entries
        .iter()
        .map(|&(word, id)| (word.to_string(), id))
        .collect();
    let model = WordLevel::builder()
        .vocab(vocab)
        .unk_token("a".to_string())
        .build()
        .unwrap();
    let mut tokenizer = HfTokenizer::new(model);
    tokenizer.with_pre_tokenizer(Some(Whitespace {}));
    let target = dir.join("tokenizer.json");
    tokenizer.save(target, false).unwrap();
}
/// Populates `dir` with a mock `config.json` and a one-tensor
/// `model.safetensors`, but no tokenizer files.
///
/// # Panics
/// Panics if either file cannot be written.
fn write_model_dir_no_tokenizer(dir: &Path, model_type: &str) {
    let config_path = dir.join("config.json");
    std::fs::write(config_path, mock_config_json(model_type)).unwrap();

    // A single 2x2 tensor gives the loader a non-empty weights map.
    let tensor = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    let mut weights: Weights = HashMap::new();
    weights.insert("mock.weight".to_owned(), tensor);

    let weights_path = dir.join("model.safetensors");
    crate::io::save_safetensors(&weights_path, &weights).unwrap();
}
/// Populates `dir` with a complete mock model: config, weights, and tokenizer.
///
/// The config and weights are written first, then the tokenizer.
fn write_model_dir(dir: &Path, model_type: &str) {
    // Config + weights.
    write_model_dir_no_tokenizer(dir, model_type);
    // Tokenizer alongside them in the same directory.
    write_tokenizer(dir);
}
/// `load` on a complete mock directory dispatches to the constructor registered
/// for `"mockarch"` and returns a working model, tokenizer, and parsed config.
#[test]
fn load_dispatches_to_registered_mock_and_returns_model_and_tokenizer() {
    let dir = fresh_dir("dispatch");
    write_model_dir(&dir, "mockarch");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let config = ModelConfiguration::from_directory(&dir);
    let ctx = load(&config, &registry).expect("load should succeed");

    // Parsed config reflects what was written to disk.
    assert_eq!(ctx.config.model_type(), "mockarch");
    assert_eq!(ctx.config.vocab_size, 5);

    // The model is the mock: logits are (batch, seq, vocab).
    let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
    let tokens = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).unwrap();
    let logits = ctx.model.forward(&tokens, &mut cache).unwrap();
    assert_eq!(logits.shape(), vec![1, 3, 5]);

    // The tokenizer written next to the model was loaded too.
    let ids = ctx.tokenizer.encode("a b c", false).unwrap();
    assert_eq!(ids.len(), 3);
}
/// `ModelConfiguration::from_id` given a filesystem path string resolves to
/// that directory, and `load` then succeeds from it.
#[test]
fn from_id_resolves_as_local_path() {
    let dir = fresh_dir("idpath");
    write_model_dir(&dir, "mockarch");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let config = ModelConfiguration::from_id(dir.to_str().unwrap());
    assert_eq!(config.model_directory(), dir.as_path());
    let ctx = load(&config, &registry).expect("id-as-local-path load should succeed");
    assert_eq!(ctx.config.model_type(), "mockarch");
}
/// A constructor can read a model-specific key (`mock_extra`) from the raw
/// `config_json` that `load` hands it.
#[test]
fn constructor_reads_model_specific_raw_config_key() {
    let dir = fresh_dir("rawkey");
    write_model_dir(&dir, "mockarch");
    let registry = ModelTypeRegistry::new().with("mockarch", {
        Box::new(|loaded: &LoadedModel| -> Result<Box<dyn Model>> {
            let raw: serde_json::Value = serde_json::from_str(&loaded.config_json).unwrap();
            // The value written by `mock_config_json` must be visible here.
            assert_eq!(raw.get("mock_extra").and_then(|v| v.as_i64()), Some(7));
            Ok(Box::new(MockLoadedModel {
                vocab: loaded.config.vocab_size,
                mock_extra: 7,
            }))
        }) as ModelConstructor
    });
    let config = ModelConfiguration::from_directory(&dir);
    let ctx = load(&config, &registry).expect("load");
    let _ = ctx.model;
}
/// An unregistered `model_type` makes `load` return an error (not panic) whose
/// message says "unsupported model type" and names the offending type.
#[test]
fn unknown_model_type_is_recoverable_error() {
    let dir = fresh_dir("unknown");
    write_model_dir(&dir, "nope");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let config = ModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("unknown model_type must error");
    };
    let msg = err.to_string();
    assert!(msg.contains("unsupported model type"), "got: {msg}");
    assert!(msg.contains("nope"), "error should name the type: {msg}");
}
/// Loading from an empty directory returns an error whose message mentions
/// `config.json`.
#[test]
fn missing_config_json_is_recoverable_error() {
    // Deliberately left empty: no config, weights, or tokenizer.
    let dir = fresh_dir("noconfig");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let config = ModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("missing config.json must error");
    };
    assert!(
        err.to_string().contains("config.json"),
        "error should name config.json: {err}"
    );
}
/// Registering under `"mistral"` makes the registry report both `"mistral"`
/// and `"llama"` as present (and not `"qwen3"`); `remap_model_type` maps
/// `"mistral"` to `"llama"` and leaves `"qwen3"` unchanged.
#[test]
fn registry_contains_and_remapping() {
    let reg = ModelTypeRegistry::new().with("mistral", mock_constructor());

    let has_mistral = reg.contains("mistral");
    let has_llama = reg.contains("llama");
    let has_qwen3 = reg.contains("qwen3");
    assert!(has_mistral);
    assert!(has_llama);
    assert!(!has_qwen3);

    assert_eq!(remap_model_type("mistral"), "llama");
    assert_eq!(remap_model_type("qwen3"), "qwen3");
}
/// `register` returns `None` for a new key and `Some(_)` when it replaces an
/// existing entry.
#[test]
fn register_replaces_and_returns_previous() {
    let mut reg = ModelTypeRegistry::new();

    let first = reg.register("mockarch", mock_constructor());
    assert!(first.is_none());

    let second = reg.register("mockarch", mock_constructor());
    assert!(second.is_some());
}
/// With `with_tokenizer_source`, `load` reads the tokenizer from a different
/// directory than the one holding the config and weights.
#[test]
fn tokenizer_source_loads_from_separate_directory() {
    // Model directory: config + weights only.
    let model_dir = fresh_dir("split-model");
    write_model_dir_no_tokenizer(&model_dir, "mockarch");
    assert!(!model_dir.join("tokenizer.json").exists());

    // Tokenizer lives elsewhere.
    let tok_dir = fresh_dir("split-tok");
    write_tokenizer(&tok_dir);

    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let config = ModelConfiguration::from_directory(&model_dir).with_tokenizer_source(&tok_dir);
    assert_eq!(config.tokenizer_directory(), tok_dir.as_path());

    let ctx = load(&config, &registry).expect("split-tokenizer load should succeed");
    let ids = ctx.tokenizer.encode("a b c", false).unwrap();
    assert_eq!(ids.len(), 3);
}
/// An unsupported `model_type` is rejected before weights or tokenizer are
/// touched.
///
/// The directory holds a valid config for an unregistered type, a corrupt
/// `model.safetensors`, and no tokenizer. The error must be the
/// "unsupported model type" one and must not mention weights.
#[test]
fn unsupported_model_type_does_not_touch_weights_or_tokenizer() {
    let dir = fresh_dir("unsupported-cheap");
    std::fs::write(dir.join("config.json"), mock_config_json("nope")).unwrap();
    // Garbage weights: parsing them would fail with a different error.
    std::fs::write(
        dir.join("model.safetensors"),
        b"this is not a safetensors file",
    )
    .unwrap();
    assert!(!dir.join("tokenizer.json").exists());

    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let config = ModelConfiguration::from_directory(&dir);
    let Err(err) = load(&config, &registry) else {
        panic!("unsupported model_type must error");
    };
    let msg = err.to_string();
    assert!(
        msg.contains("unsupported model type"),
        "expected the unsupported-model error before any weight load, got: {msg}"
    );
    assert!(msg.contains("nope"), "error should name the type: {msg}");
    assert!(
        !msg.contains("safetensors") && !msg.contains("weights"),
        "weights must not have been loaded, but the error mentions them: {msg}"
    );
}
/// The raw `config_json` given to the constructor agrees with the parsed
/// `config` and is byte-identical to `config.json` on disk.
#[test]
fn raw_config_json_matches_parsed_config() {
    use std::sync::{Arc, Mutex};

    let dir = fresh_dir("raw-consistency");
    write_model_dir(&dir, "mockarch");
    let on_disk = std::fs::read_to_string(dir.join("config.json")).unwrap();

    // Shared slot through which the constructor reports the raw JSON it saw.
    let captured: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let captured_in_ctor = Arc::clone(&captured);
    let registry = ModelTypeRegistry::new().with("mockarch", {
        Box::new(move |loaded: &LoadedModel| -> Result<Box<dyn Model>> {
            let raw: serde_json::Value = serde_json::from_str(&loaded.config_json).unwrap();
            assert_eq!(
                raw.get("model_type").and_then(|v| v.as_str()),
                Some(loaded.config.model_type())
            );
            assert_eq!(
                raw.get("vocab_size").and_then(|v| v.as_i64()),
                Some(loaded.config.vocab_size as i64)
            );
            *captured_in_ctor.lock().unwrap() = Some(loaded.config_json.clone());
            Ok(Box::new(MockLoadedModel {
                vocab: loaded.config.vocab_size,
                mock_extra: 7,
            }))
        }) as ModelConstructor
    });

    let config = ModelConfiguration::from_directory(&dir);
    let _ctx = load(&config, &registry).expect("load");
    let seen = captured.lock().unwrap().clone().expect("ctor ran");
    assert_eq!(seen, on_disk);
}
use crate::lm::{generate::GenConfig, model::MockModel};
/// Loads the tokenizer stored under `<crate root>/tests/fixtures`.
///
/// # Panics
/// Panics if the fixture cannot be loaded.
fn fixture_tokenizer() -> Tokenizer {
    let manifest_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
    let fixtures = manifest_root.join("tests").join("fixtures");
    Tokenizer::from_path(&fixtures, None).expect("load fixture tokenizer")
}
/// Parses a `Config` for the `"mockarch"` model type with the given vocabulary
/// size and layer count; all other fields are fixed.
///
/// # Panics
/// Panics if the generated JSON does not parse.
fn mock_config(vocab: i32, num_layers: i32) -> Config {
    let json = format!(
        r#"{{
"model_type": "mockarch",
"hidden_size": 8,
"num_hidden_layers": {num_layers},
"num_attention_heads": 4,
"num_key_value_heads": 2,
"head_dim": 2,
"rope_theta": 10000.0,
"vocab_size": {vocab},
"tie_word_embeddings": false
}}"#
    );
    Config::from_json(&json).expect("mock config parses")
}
/// Assembles a `ModelContext` from a `MockModel`, the fixture tokenizer, and a
/// mock config sharing the same vocabulary size.
fn mock_context(vocab: i32, num_layers: i32) -> ModelContext {
    let model = Box::new(MockModel::new(vocab as usize));
    let tokenizer = fixture_tokenizer();
    let config = mock_config(vocab, num_layers);
    ModelContext::new(model, tokenizer, config)
}
/// The `config()`, `tokenizer()`, and `model()` accessors each expose a usable
/// component of the context.
#[test]
fn context_owns_and_exposes_model_tokenizer_config() {
    let context = mock_context(8, 2);

    // Config accessor.
    let cfg = context.config();
    assert_eq!(cfg.model_type(), "mockarch");
    assert_eq!(cfg.vocab_size, 8);
    assert_eq!(cfg.num_hidden_layers, 2);

    // Tokenizer accessor.
    let encoded = context.tokenizer().encode("the quick brown", false).unwrap();
    assert_eq!(encoded.len(), 3);

    // Model accessor.
    let mut kv: Vec<Box<dyn crate::lm::cache::KvCache>> = Vec::new();
    let input = Array::from_slice::<i32>(&[1, 2, 3], &(1usize, 3)).unwrap();
    let output = context.model().forward(&input, &mut kv).unwrap();
    assert_eq!(output.shape(), vec![1, 3, 8]);
}
/// `ModelContext::encode` yields the same ids as calling the tokenizer directly.
#[test]
fn encode_forwards_to_tokenizer() {
    let context = mock_context(8, 1);
    let input = "the quick brown world";

    let from_context = context.encode(input, false).unwrap();
    let from_tokenizer = context.tokenizer().encode(input, false).unwrap();

    assert_eq!(from_context, from_tokenizer);
    assert_eq!(from_context.len(), 4);
}
/// `ModelContext::decode` matches the tokenizer's own `decode`, and decoding
/// the encoding of `"hello world"` gives the text back.
#[test]
fn decode_forwards_to_tokenizer_and_round_trips_encode() {
    let context = mock_context(8, 1);
    let token_ids = context.encode("hello world", false).unwrap();

    let from_context = context.decode(&token_ids, true).unwrap();
    let from_tokenizer = context.tokenizer().decode(&token_ids, true).unwrap();

    assert_eq!(from_context, from_tokenizer);
    assert_eq!(from_context, "hello world");
}
/// `ModelContext::apply_chat_template` renders the same string as the
/// tokenizer's method, and that string matches the fixture's expected output.
#[test]
fn apply_chat_template_forwards_to_tokenizer() {
    let context = mock_context(8, 1);
    let chat = serde_json::json!([
        {"role": "user", "content": "hello"}
    ]);

    let from_context = context
        .apply_chat_template(&chat, None, true, false, None)
        .unwrap();
    let from_tokenizer = context
        .tokenizer()
        .apply_chat_template(&chat, None, true, false, None)
        .unwrap();

    assert_eq!(from_context, from_tokenizer);
    assert_eq!(from_context, "<s><|user|>hello<|assistant|>");
}
/// `apply_chat_template_ids` on the context matches the tokenizer's method and
/// equals rendering the template to text and then encoding that text.
#[test]
fn apply_chat_template_ids_forwards_and_equals_render_then_encode() {
    let context = mock_context(8, 1);
    let chat = serde_json::json!([
        {"role": "user", "content": "the quick"}
    ]);

    let ids_from_context = context
        .apply_chat_template_ids(&chat, None, true, false, None)
        .unwrap();
    let ids_from_tokenizer = context
        .tokenizer()
        .apply_chat_template_ids(&chat, None, true, false, None)
        .unwrap();
    assert_eq!(ids_from_context, ids_from_tokenizer);

    // Two-step path: render first, then encode the rendered string.
    let text = context
        .apply_chat_template(&chat, None, true, false, None)
        .unwrap();
    let reencoded = context.encode(&text, false).unwrap();
    assert_eq!(ids_from_context, reencoded);
}
/// Passing `true` for both boolean flags of `apply_chat_template` is an error
/// whose message mentions `continue_final_message`.
///
/// Per the test name, the two flags are presumably "add generation prompt" and
/// "continue final message" (confirm against the method signature).
#[test]
fn apply_chat_template_rejects_generation_prompt_with_continue() {
    let context = mock_context(8, 1);
    let chat = serde_json::json!([{"role": "user", "content": "hello"}]);

    let outcome = context.apply_chat_template(&chat, None, true, true, None);
    let failure = outcome.expect_err("gen-prompt + continue must error");

    assert!(
        failure.to_string().contains("continue_final_message"),
        "got: {failure}"
    );
}
/// `ModelContext::generate` with `max_tokens: 3` produces exactly three tokens
/// and reports the prompt length in its stats.
///
/// The expected text is the decoding of `[7, 7, 7]`; presumably `MockModel`
/// with vocab 8 always emits token 7 (confirm against `MockModel`).
#[test]
fn generate_forwards_and_runs_to_length() {
    let context = mock_context(8, 2);
    let prompt_ids = context.encode("hello world", false).unwrap();
    let limits = GenConfig {
        max_tokens: 3,
        ..Default::default()
    };

    let (output, stats) = context.generate(&prompt_ids, limits).expect("generate");

    assert_eq!(stats.generation_tokens, 3);
    assert_eq!(stats.prompt_tokens, prompt_ids.len());
    let expected = context.decode(&[7, 7, 7], true).unwrap();
    assert_eq!(output, expected);
}
/// With a vocabulary of 3, generation ends after a single token and yields no
/// text even though `max_tokens` is 16.
///
/// Presumably the one token emitted is the fixture tokenizer's EOS id (confirm
/// against `MockModel` and the fixture).
#[test]
fn generate_stops_on_eos_token() {
    let context = mock_context(3, 1);
    let prompt_ids = context.encode("hello", false).unwrap();
    let limits = GenConfig {
        max_tokens: 16,
        ..Default::default()
    };

    let (output, stats) = context.generate(&prompt_ids, limits).expect("generate");

    assert!(
        output.is_empty(),
        "eos token contributes no text, got {output:?}"
    );
    assert_eq!(stats.generation_tokens, 1);
}
/// `stream_generate` yields one response per generated token: no finish reason
/// on the first, `FinishReason::Length` on the last, and the concatenated
/// text equals what `generate` returns for the same prompt and limit.
#[test]
fn stream_generate_forwards_and_yields_per_token_responses() {
    let context = mock_context(8, 2);
    let prompt_ids = context.encode("the quick", false).unwrap();
    let stream_limits = GenConfig {
        max_tokens: 4,
        ..Default::default()
    };

    let mut finish_reasons = Vec::new();
    let mut streamed_text = String::new();
    for step in context.stream_generate(&prompt_ids, stream_limits) {
        let response = step.expect("stream step");
        streamed_text.push_str(&response.text);
        finish_reasons.push(response.finish_reason);
    }

    assert_eq!(finish_reasons.len(), 4);
    assert_eq!(finish_reasons[0], None);
    assert_eq!(finish_reasons[3], Some(FinishReason::Length));

    // The non-streaming path must produce the same text.
    let batch_limits = GenConfig {
        max_tokens: 4,
        ..Default::default()
    };
    let (batch_text, _) = context.generate(&prompt_ids, batch_limits).unwrap();
    assert_eq!(streamed_text, batch_text);
}
/// Converting the result of `load` into a `ModelContext` via `.into()` keeps
/// the config, model, and tokenizer usable through the context's accessors.
#[test]
fn from_loaded_model_context_wraps_the_triple() {
    let dir = fresh_dir("ctx-from-loaded");
    write_model_dir(&dir, "mockarch");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let configuration = ModelConfiguration::from_directory(&dir);
    let loaded = load(&configuration, &registry).expect("load");
    let ctx: ModelContext = loaded.into();

    assert_eq!(ctx.config().model_type(), "mockarch");

    let mut cache: Vec<Box<dyn crate::lm::cache::KvCache>> = Vec::new();
    let tokens = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).unwrap();
    let logits = ctx.model().forward(&tokens, &mut cache).unwrap();
    assert_eq!(logits.shape(), vec![1, 3, 5]);

    assert_eq!(ctx.encode("a b c", false).unwrap().len(), 3);
}
/// `ModelContext::load` builds a working context directly from a configuration
/// and registry, exposing the parsed config and a usable tokenizer.
#[test]
fn context_load_convenience_equals_load_then_into() {
    let dir = fresh_dir("ctx-load");
    write_model_dir(&dir, "mockarch");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let configuration = ModelConfiguration::from_directory(&dir);
    let ctx = ModelContext::load(&configuration, &registry).expect("ModelContext::load");

    assert_eq!(ctx.config().model_type(), "mockarch");
    assert_eq!(ctx.config().vocab_size, 5);
    assert_eq!(ctx.encode("a b c", false).unwrap().len(), 3);
}
/// `ModelContext::load` returns the "unsupported model type" error when the
/// directory's `model_type` is not registered.
#[test]
fn context_load_propagates_unknown_model_type_error() {
    let dir = fresh_dir("ctx-load-unknown");
    write_model_dir(&dir, "nope");
    let registry = ModelTypeRegistry::new().with("mockarch", mock_constructor());
    let configuration = ModelConfiguration::from_directory(&dir);
    let Err(err) = ModelContext::load(&configuration, &registry) else {
        panic!("unknown model_type must error");
    };
    assert!(
        err.to_string().contains("unsupported model type"),
        "got: {err}"
    );
}
/// `into_parts` returns the model, tokenizer, and config, and passing them back
/// to `ModelContext::new` gives a working context again.
#[test]
fn into_parts_round_trips_new() {
    let original = mock_context(8, 3);

    let (model, tokenizer, config) = original.into_parts();
    assert_eq!(config.num_hidden_layers, 3);

    let reassembled = ModelContext::new(model, tokenizer, config);
    assert_eq!(reassembled.config().vocab_size, 8);
    let hello_ids = reassembled.encode("hello", false).unwrap();
    assert_eq!(hello_ids.len(), 1);
}