use super::*;
use crate::{lm::model::MockModel, tokenizer::StreamingDetokenizer};
/// Loads the tokenizer checked in under `<crate root>/tests/fixtures`.
///
/// Panics if the fixture files cannot be loaded.
fn fixture_tokenizer() -> Tokenizer {
    let mut fixtures = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    fixtures.push("tests");
    fixtures.push("fixtures");
    Tokenizer::from_path(&fixtures, None).expect("load fixture tokenizer")
}
/// Cache geometry shared by every test in this file: two decoder layers and
/// no sliding window.
fn cache_config() -> CacheConfig {
    let num_hidden_layers = 2;
    CacheConfig {
        sliding_window: None,
        num_hidden_layers,
    }
}
/// Builds a standard (non-speculative) session over `MockModel::new(11)` and
/// the fixture tokenizer that generates at most `max_tokens` tokens per turn.
///
/// Panics if the builder rejects the configuration.
fn session(max_tokens: usize) -> ChatSession {
    let params = GenConfig {
        max_tokens,
        ..Default::default()
    };
    ChatSession::builder(Box::new(MockModel::new(11)), fixture_tokenizer(), cache_config())
        .generate_params(params)
        .build()
        .expect("build")
}
/// A freshly built session has no history and no realised cache; the first
/// `respond` call produces a non-empty reply and realises the cache.
#[test]
fn fresh_session_has_no_cache_until_first_turn() {
    let mut chat = session(4);

    // Nothing has been generated yet, so nothing is cached or recorded.
    assert!(!chat.has_cache(), "fresh session: cache unrealised");
    assert!(chat.current_cache().is_none());
    assert!(chat.history().is_empty());

    let answer = chat.respond("hello").expect("respond");
    assert!(!answer.is_empty(), "MockModel produces a non-empty reply");
    assert!(chat.has_cache(), "cache realised after the first turn");
}
/// Two consecutive turns each append a user/assistant pair to the history.
/// The second turn extends the cache left by the first instead of starting
/// over, so the first layer's offset strictly grows.
#[test]
fn multi_turn_reuses_cache_and_accumulates_history() {
    let mut chat = session(3);

    // Turn 1: one user message followed by one assistant reply.
    let _ = chat.respond("hello world").expect("turn 1");
    assert_eq!(chat.history().len(), 2);
    assert_eq!(chat.history()[0].role, Role::User);
    assert_eq!(chat.history()[1].role, Role::Assistant);
    let offset_after_turn_1 = {
        let layers = chat.current_cache().expect("cache realised");
        layers.first().expect(">=1 layer").offset()
    };
    assert!(offset_after_turn_1 > 0, "turn 1 advanced the cache");

    // Turn 2: a second pair is appended after the first.
    let _ = chat.respond("the quick fox").expect("turn 2");
    assert_eq!(chat.history().len(), 4);
    let second_user = &chat.history()[2];
    assert_eq!(second_user.role, Role::User);
    assert_eq!(second_user.content(), "the quick fox");
    assert_eq!(chat.history()[3].role, Role::Assistant);
    let offset_after_turn_2 = {
        let layers = chat.current_cache().expect("cache realised");
        layers.first().expect("layer").offset()
    };
    assert!(
        offset_after_turn_2 > offset_after_turn_1,
        "turn 2 reused + extended the cache (offset {offset_after_turn_1} -> {offset_after_turn_2})"
    );
}
/// After turn 1, turn 2 must prefill only the part of its rendered prompt
/// that extends past the cached prefix. The cache grows by
/// `suffix_len - 1 + max_tokens`, never by the full render.
#[test]
fn turn_two_prefills_only_the_new_suffix_not_the_whole_history() {
    let max_tokens = 3;
    let mut s = session(max_tokens);
    let _ = s.respond("hello world").expect("turn 1");
    let off1 = s.current_cache().expect("cache realised")[0].offset();
    assert!(off1 > 0, "turn 1 advanced the cache");

    // Render turn 2 without running it, to learn the full prompt length.
    let full_render2 = s
        .build_turn_prompt("the quick fox", Role::User)
        .expect("render turn 2")
        .0
        .len();
    assert!(
        full_render2 > off1,
        "turn-2 render ({full_render2}) extends past the cached prefix ({off1})"
    );
    let suffix_len = full_render2 - off1;

    let _ = s.respond("the quick fox").expect("turn 2");
    let growth = s.current_cache().expect("cache realised")[0].offset() - off1;
    // NOTE(review): the `- 1` presumably reflects that the last sampled token
    // is never fed back into the cache; confirm against the generator.
    let expected_growth = suffix_len - 1 + max_tokens;
    assert_eq!(
        growth,
        expected_growth,
        "turn 2 grew the cache by the new suffix ({suffix_len}) + generated \
         ({max_tokens}) only, not the whole render"
    );
    assert!(
        growth < full_render2,
        "turn 2 did NOT re-prefill the whole conversation (grew {}, full \
         render is {full_render2})",
        growth
    );
}
/// Changing the system instructions between turns invalidates the cached
/// prefix. Turn 2 rebuilds the cache from scratch: its offset equals a fresh
/// prefill of the whole render, not an extension of the stale cache.
#[test]
fn instructions_change_forces_a_cache_rebuild_not_wrong_output() {
    let max_tokens = 3;
    let mut s = session(max_tokens);
    let _ = s.respond("hello").expect("turn 1");
    let off1 = s.current_cache().expect("realised")[0].offset();

    s.set_instructions(Some("hello world the quick".to_string()));
    let full_render2 = s
        .build_turn_prompt("world", Role::User)
        .expect("render turn 2")
        .0
        .len();
    let _ = s.respond("world").expect("turn 2");
    let off2 = s.current_cache().expect("realised")[0].offset();

    // Offset a cache would have after prefilling the whole turn-2 render
    // from empty and generating `max_tokens` tokens.
    let fresh_prefill = full_render2 - 1 + max_tokens;
    assert_eq!(
        off2,
        fresh_prefill,
        "instructions change rebuilt the cache from scratch (offset reset, \
         full render re-prefilled)"
    );
    assert!(
        off2 != off1 + fresh_prefill,
        "the stale cache was discarded, not extended"
    );
}
/// After a turn there is one cache per decoder layer (two, per
/// `cache_config`), and every layer reports the same non-zero offset.
#[test]
fn every_cache_layer_advances_in_lockstep() {
    let mut chat = session(3);
    let _ = chat.respond("hello").expect("turn");
    let layers = chat.current_cache().expect("realised");
    assert_eq!(layers.len(), 2, "one cache per decoder layer");
    let reference = layers[0].offset();
    assert!(reference > 0);
    assert!(
        !layers.iter().any(|layer| layer.offset() != reference),
        "all layers advance in lockstep"
    );
}
/// Streaming and non-streaming `respond` over identical sessions and prompts
/// produce the same reply text and record the same assistant turn.
#[test]
fn streaming_and_non_streaming_respond_are_consistent() {
    let mut a = session(5);
    let non_streaming = a.respond("hello world").expect("non-streaming");
    let mut b = session(5);
    let mut streamed = String::new();
    {
        let stream = b.stream_respond("hello world").expect("stream");
        for resp in stream {
            streamed.push_str(&resp.expect("stream step").text);
        }
    }
    assert_eq!(
        non_streaming, streamed,
        "streaming and non-streaming respond produce the same text"
    );
    assert_eq!(a.history().len(), b.history().len());
    // Compare the recorded assistant turns through the same public accessor
    // on both sides, as every other history check in this file does.
    assert_eq!(a.history()[1].content(), b.history()[1].content());
}
/// After a fully drained stream, the assistant turn recorded in the history
/// is exactly the concatenation of the streamed text chunks.
#[test]
fn streaming_reply_matches_recorded_history() {
    let mut chat = session(4);
    // The stream is consumed (and dropped) by `fold`, releasing `chat`.
    let streamed = chat
        .stream_respond("hello")
        .expect("stream")
        .fold(String::new(), |mut acc, resp| {
            acc.push_str(&resp.expect("step").text);
            acc
        });
    assert_eq!(chat.history().len(), 2);
    assert_eq!(
        chat.history()[1].content(),
        streamed,
        "the recorded assistant turn equals the streamed text"
    );
}
/// With `max_tokens = 3` the stream yields exactly three responses. Only the
/// last one carries a finish reason, and that reason is `Length`.
#[test]
fn finish_reason_is_length_when_max_tokens_reached() {
    let mut chat = session(3);
    let reasons: Vec<_> = chat
        .stream_respond("hello")
        .expect("stream")
        .map(|resp| resp.expect("step").finish_reason)
        .collect();
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Length));
    let finished = reasons.iter().filter(|reason| reason.is_some()).count();
    assert_eq!(finished, 1);
    assert_eq!(reasons.len(), 3, "max_tokens responses produced");
}
/// `clear()` discards the realised cache and the history but keeps the
/// configured instructions. A later turn starts a fresh history.
#[test]
fn clear_drops_cache_and_history_keeps_instructions() {
    let params = GenConfig {
        max_tokens: 3,
        ..Default::default()
    };
    let mut chat = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .instructions("be terse")
    .generate_params(params)
    .build()
    .expect("build");

    let _ = chat.respond("hello").expect("turn");
    assert!(chat.has_cache());
    assert!(!chat.history().is_empty());

    chat.clear();
    assert!(!chat.has_cache(), "clear() drops the cache");
    assert!(chat.history().is_empty(), "clear() drops the history");
    assert_eq!(
        chat.instructions(),
        Some("be terse"),
        "clear() preserves instructions"
    );

    // The session remains usable; the next turn realises a new cache.
    let _ = chat.respond("world").expect("post-clear turn");
    assert!(chat.has_cache());
    assert_eq!(chat.history().len(), 2, "history restarts after clear");
}
/// Dropping a stream after its first token still commits the turn. The cache
/// is realised, a user/assistant pair is recorded (the reply starts with the
/// streamed text), and the next turn extends that cache.
#[test]
fn early_drop_of_stream_still_records_partial_turn() {
    let mut s = session(10);
    let mut streamed = String::new();
    {
        let mut stream = s.stream_respond("hello").expect("stream");
        let first = stream.next().expect("first token").expect("ok");
        streamed.push_str(&first.text);
        // With `max_tokens = 10` the first response must not be terminal.
        // Otherwise the drop below would not actually interrupt generation,
        // and this test would not exercise the early-drop path at all.
        assert!(
            first.finish_reason.is_none(),
            "the first of 10 tokens must not finish the stream (got {:?})",
            first.finish_reason
        );
        // `stream` is dropped here, mid-generation.
    }
    assert!(s.has_cache(), "interrupted turn still realised the cache");
    assert_eq!(s.history().len(), 2, "interrupted turn still recorded");
    assert_eq!(s.history()[1].role, Role::Assistant);
    assert!(
        s.history()[1].content().starts_with(&streamed),
        "recorded reply ({:?}) includes the streamed text ({streamed:?})",
        s.history()[1].content()
    );
    let off_before = s.current_cache().unwrap()[0].offset();
    let _ = s.respond("world").expect("follow-up turn");
    let off_after = s.current_cache().unwrap()[0].offset();
    assert!(off_after > off_before, "follow-up reused the cache");
}
/// A stream dropped after two tokens still leaves a cache whose recorded
/// prefix lets the follow-up turn prefill only its new suffix.
#[test]
fn early_drop_then_followup_does_incremental_prefill() {
    let max_tokens = 8;
    let mut chat = session(max_tokens);
    {
        let mut stream = chat.stream_respond("hello world").expect("stream");
        for label in ["token 1", "token 2"] {
            let _ = stream.next().expect(label).expect("ok");
        }
        // Dropping `stream` here interrupts the turn after two tokens.
    }
    let before = chat.current_cache().expect("realised")[0].offset();
    assert!(before > 0, "interrupted turn advanced the cache");

    let full_render = chat
        .build_turn_prompt("the quick fox", Role::User)
        .expect("render turn 2")
        .0
        .len();
    assert!(
        full_render > before,
        "turn-2 render extends past the interrupted cache"
    );
    let new_suffix = full_render - before;

    let _ = chat.respond("the quick fox").expect("turn 2");
    let growth = chat.current_cache().expect("realised")[0].offset() - before;
    assert_eq!(
        growth,
        new_suffix - 1 + max_tokens,
        "follow-up after an early drop still prefilled only the new suffix"
    );
    assert!(
        growth < full_render,
        "follow-up did not re-prefill the whole conversation"
    );
}
/// A session seeded via `.history(..)` reports no cache and an empty live
/// history until its first turn. That turn folds the seeded messages in
/// ahead of the new user/assistant pair.
#[test]
fn history_seeded_session_replays_then_realises_cache() {
    let prior = vec![ChatMessage::user("hello"), ChatMessage::assistant("world")];
    let params = GenConfig {
        max_tokens: 3,
        ..Default::default()
    };
    let mut chat = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .history(prior)
    .generate_params(params)
    .build()
    .expect("build");

    // Before any turn, the seeded messages do not show up in `history()`.
    assert!(!chat.has_cache(), "history-seeded: cache unrealised pre-turn");
    assert!(chat.history().is_empty(), "live history empty pre-turn");

    let _ = chat.respond("the fox").expect("first turn");
    assert!(chat.has_cache(), "cache realised after the first turn");
    assert_eq!(chat.history().len(), 4, "replayed history folded in");
    for (idx, text) in ["hello", "world", "the fox"].iter().enumerate() {
        assert_eq!(chat.history()[idx].content(), *text);
    }
    assert_eq!(chat.history()[3].role, Role::Assistant);
}
/// `save_cache` on a session that has not generated anything fails with the
/// "no KV cache" error.
#[test]
fn save_cache_errors_before_any_generation() {
    let chat = session(3);
    let target = std::env::temp_dir().join("mlxrs-l11-chat-session-nocache.safetensors");
    let err = chat.save_cache(&target).expect_err("no cache yet");
    let rendered = err.to_string();
    assert!(
        rendered.contains("no KV cache"),
        "noCacheAvailable surfaced: {err}"
    );
}
/// Configured system instructions are part of the rendered prompt. The same
/// one-token "hello" turn leaves a larger cache offset when instructions are
/// set than when they are not.
#[test]
fn instructions_are_rendered_into_the_prompt() {
    /// Runs one "hello" turn on `chat` and reports layer 0's cache offset.
    fn offset_after_hello(mut chat: ChatSession) -> usize {
        let _ = chat.respond("hello").expect("turn");
        chat.current_cache().unwrap()[0].offset()
    }

    let instructed = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .instructions("hello world the quick brown fox")
    .generate_params(GenConfig {
        max_tokens: 1,
        ..Default::default()
    })
    .build()
    .expect("build");

    let with = offset_after_hello(instructed);
    let without = offset_after_hello(session(1));
    assert!(
        with > without,
        "the system instructions lengthened the prompt ({without} -> {with})"
    );
}
/// The instruction and generation-parameter accessors reflect every setter
/// call and in-place mutation.
#[test]
fn set_instructions_and_generate_params_accessors() {
    let mut chat = session(3);

    // Instructions start unset and follow each `set_instructions` call.
    assert!(chat.instructions().is_none());
    chat.set_instructions(Some("be brief".to_string()));
    assert_eq!(chat.instructions(), Some("be brief"));
    chat.set_instructions(None);
    assert!(chat.instructions().is_none());

    // Generation parameters are readable and can be mutated in place.
    assert_eq!(chat.generate_params().max_tokens, 3);
    let params = chat.generate_params_mut();
    params.max_tokens = 7;
    assert_eq!(chat.generate_params().max_tokens, 7);
}
/// A speculative session (target plus draft `MockModel`) keeps the
/// multi-turn contract: no cache before the first turn, non-empty replies,
/// and one user/assistant pair appended per turn.
#[test]
fn speculative_session_runs_multi_turn_and_accumulates_history() {
    let speculative = SpeculativeDecodingConfig::new(Rc::new(MockModel::new(11)), cache_config());
    let params = GenConfig {
        max_tokens: 4,
        ..Default::default()
    };
    let mut chat = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .speculative(speculative)
    .generate_params(params)
    .build()
    .expect("build");
    assert!(
        !chat.has_cache(),
        "speculative session: cache unrealised pre-turn"
    );

    let first_reply = chat.respond("hello").expect("speculative turn 1");
    assert!(!first_reply.is_empty(), "speculative decoding produced a reply");
    assert_eq!(chat.history().len(), 2);

    let second_reply = chat.respond("world").expect("speculative turn 2");
    assert!(!second_reply.is_empty());
    assert_eq!(chat.history().len(), 4);
    assert_eq!(chat.history()[2].content(), "world");
    assert_eq!(chat.history()[3].role, Role::Assistant);
}
/// A speculative session records its turns but never exposes a realised
/// cache. `save_cache` fails with a speculative-specific error (not the
/// pre-generation "no KV cache" error) and writes no file.
#[test]
fn speculative_session_does_not_expose_a_saveable_cache() {
    let mut s = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .speculative(SpeculativeDecodingConfig::new(
        Rc::new(MockModel::new(11)),
        cache_config(),
    ))
    .generate_params(GenConfig {
        max_tokens: 3,
        ..Default::default()
    })
    .build()
    .expect("build");
    let _ = s.respond("hello").expect("speculative turn");
    assert_eq!(s.history().len(), 2, "the turn was still recorded");
    assert!(
        !s.has_cache(),
        "a speculative session never exposes a realised cache"
    );
    assert!(
        s.current_cache().is_none(),
        "no current cache for a speculative session"
    );
    // Use a per-process file name and remove any leftover up front. A file
    // left behind by another run (or a concurrently running test binary)
    // would otherwise make the final `!path.exists()` check fail for reasons
    // unrelated to this session.
    let path = std::env::temp_dir().join(format!(
        "mlxrs-l11-chat-session-spec-nocache-{}.safetensors",
        std::process::id()
    ));
    let _ = std::fs::remove_file(&path);
    let err = s
        .save_cache(&path)
        .expect_err("speculative cache not saveable");
    let msg = format!("{err}");
    assert!(
        msg.contains("speculative"),
        "speculative-specific save error surfaced: {msg}"
    );
    assert!(
        !msg.contains("call respond"),
        "not the pre-generation noCacheAvailable error: {msg}"
    );
    assert!(
        !path.exists(),
        "no cache file was written for the speculative session"
    );
}
/// Returns a KV cache (shaped by `cache_config()`) that `MockModel::new(11)`
/// has already advanced by `n_tokens` positions.
///
/// It stands in for a restored cache whose token contents the session cannot
/// see ("opaque").
///
/// # Panics
/// If `n_tokens` is 0, if the prefill forward pass fails, or if layer 0 does
/// not end up at offset `n_tokens`.
fn prefilled_opaque_cache(n_tokens: usize) -> Vec<Box<dyn KvCache>> {
    assert!(n_tokens > 0, "an opaque cache must have a non-empty prefix");
    let mut layers = make_prompt_cache(&cache_config());
    // Ids cycle through 0..11 so each one fits the mock's 11-entry vocabulary.
    let token_ids: Vec<i32> = (0..n_tokens as i32).map(|i| i % 11).collect();
    let input = crate::array::Array::from_slice::<i32>(&token_ids, &(1usize, n_tokens))
        .expect("opaque-prefill token window");
    let prefill_model = MockModel::new(11);
    let _ = prefill_model
        .forward(&input, &mut layers)
        .expect("opaque prefill");
    assert_eq!(layers[0].offset(), n_tokens, "opaque cache pre-advanced");
    layers
}
/// Builds a standard session whose builder was handed a cache pre-advanced
/// by `opaque_len` unknown tokens (from `prefilled_opaque_cache`).
/// Generation is capped at `max_tokens` tokens per turn.
fn cache_restored_session(opaque_len: usize, max_tokens: usize) -> ChatSession {
    let restored = prefilled_opaque_cache(opaque_len);
    let params = GenConfig {
        max_tokens,
        ..Default::default()
    };
    ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .generate_params(params)
    .cache(restored)
    .build()
    .expect("build")
}
/// Covers a restored cache whose opaque prefix is LONGER than the first
/// turn's render.
///
/// The session must append the whole new render onto that prefix rather
/// than rebuilding the cache from empty. The following turn must then
/// extend the cache incrementally.
#[test]
fn cache_restore_short_prompt_reuses_opaque_prefix_not_rebuilds() {
    let max_tokens = 3;
    let opaque_len = 10;
    // Measure the render length on a throwaway plain session.
    let p_short = session(max_tokens)
        .build_turn_prompt("hello", Role::User)
        .expect("render short prompt")
        .0
        .len();
    // Precondition: the scenario only means something when the render is
    // shorter than the opaque prefix.
    assert!(
        p_short < opaque_len,
        "the short render ({p_short}) must be shorter than the opaque prefix \
         ({opaque_len}) to exercise the rebuilt-from-empty bug"
    );
    let mut s = cache_restored_session(opaque_len, max_tokens);
    assert!(s.has_cache(), "cache-restored session: cache realised");
    let _ = s.respond("hello").expect("first turn over restored cache");
    let off = s.current_cache().expect("cache realised")[0].offset();
    // None of the render is assumed to be in the opaque region, so the whole
    // render lands on top of it (using the file's `P - 1 + generated` growth).
    assert_eq!(
        off,
        opaque_len + p_short - 1 + max_tokens,
        "the full new render ({p_short}) was fed onto the opaque prefix \
         ({opaque_len}); offset = opaque + P - 1 + generated"
    );
    assert!(
        off > opaque_len + max_tokens,
        "the restored opaque prefix was REUSED, not rebuilt-from-empty \
         (offset {off} retains the opaque {opaque_len} tokens)"
    );
    // Turn 2 must grow the cache by less than its full render (incremental).
    let (prompt2, _) = s
        .build_turn_prompt("world", Role::User)
        .expect("render turn 2");
    let full_render2 = prompt2.len();
    let _ = s.respond("world").expect("second turn");
    let off2 = s.current_cache().expect("realised")[0].offset();
    assert!(
        off2 > off && off2 - off < full_render2,
        "turn 2 reused the cache (grew {} < full render {full_render2})",
        off2 - off
    );
}
/// Covers a restored cache whose opaque prefix is SHORTER than the first
/// turn's render.
///
/// Every token of the new render must be fed. None may be skipped on the
/// mistaken assumption that the first `opaque_len` render tokens were
/// already cached.
#[test]
fn cache_restore_long_prompt_feeds_full_new_prompt_no_dropped_tokens() {
    let max_tokens = 4;
    let opaque_len = 10;
    // Measure the render length on a throwaway plain session.
    let p_long = session(max_tokens)
        .build_turn_prompt("hello world the quick brown fox", Role::User)
        .expect("render long prompt")
        .0
        .len();
    // Precondition: the render must be longer than the opaque prefix.
    assert!(
        p_long > opaque_len,
        "the long render ({p_long}) must exceed the opaque prefix ({opaque_len}) \
         to exercise the dropped-first-tokens bug"
    );
    let mut s = cache_restored_session(opaque_len, max_tokens);
    let _ = s
        .respond("hello world the quick brown fox")
        .expect("first turn over restored cache");
    let off = s.current_cache().expect("cache realised")[0].offset();
    assert_eq!(
        off,
        opaque_len + p_long - 1 + max_tokens,
        "the FULL new render ({p_long} tokens) was fed; no first-{opaque_len} \
         tokens dropped"
    );
    // Restates the check from the opaque side: what remains after removing
    // this turn's contribution is exactly the opaque prefix.
    assert_eq!(
        off - (p_long - 1 + max_tokens),
        opaque_len,
        "offset retains the full opaque prefix; the dropped-tokens bug would \
         lose exactly {opaque_len} tokens"
    );
}
/// Builds a `WordLevel` tokenizer with a `ByteLevel` decoder whose extra
/// token `â` is meant to make the streaming BPE detokenizer withhold output.
///
/// NOTE(review): under the GPT-2 byte-to-unicode table `â` presumably
/// decodes to the lone byte 0xE2, an incomplete UTF-8 lead byte, which is why
/// streamed text lags until the detokenizer is finalized. Confirm against the
/// `ByteLevel` decoder.
///
/// The tokenizer files (`tokenizer.json` plus a `tokenizer_config.json` with
/// a minimal chat template) are written to a temp directory unique to this
/// process and call. The directory is not removed afterwards.
///
/// Returns `(tokenizer, id of "â", vocabulary size)`.
fn bpe_withholding_tokenizer() -> (Tokenizer, u32, usize) {
    let a_id = 11u32;
    // `â` carries the largest id in the vocab below, so the vocabulary holds
    // ids `0..=a_id`. Deriving the size keeps it in sync if `a_id` changes.
    let vocab_size = a_id as usize + 1;
    let tokenizer_json = json!({
        "version": "1.0",
        "truncation": Value::Null,
        "padding": Value::Null,
        "added_tokens": [
            { "id": 0, "content": "<unk>", "single_word": false, "lstrip": false,
              "rstrip": false, "normalized": false, "special": true },
            { "id": 1, "content": "<s>", "single_word": false, "lstrip": false,
              "rstrip": false, "normalized": false, "special": true },
            { "id": 2, "content": "</s>", "single_word": false, "lstrip": false,
              "rstrip": false, "normalized": false, "special": true }
        ],
        "normalizer": Value::Null,
        "pre_tokenizer": { "type": "Whitespace" },
        "post_processor": Value::Null,
        "decoder": { "type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true,
                     "use_regex": true },
        "model": {
            "type": "WordLevel",
            "vocab": {
                "<unk>": 0, "<s>": 1, "</s>": 2, "hello": 3, "world": 4, "the": 5,
                "quick": 6, "brown": 7, "fox": 8, "<think>": 9, "</think>": 10,
                "â": a_id
            },
            "unk_token": "<unk>"
        }
    });
    let config_json = json!({
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
        "clean_up_tokenization_spaces": false,
        "chat_template":
            "{{ bos_token }}{% for m in messages %}{{ '<|' + m['role'] + '|>' }}\
             {{ m['content'] }}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}"
    });
    // Process id plus a per-call counter gives every call its own directory.
    use std::sync::atomic::{AtomicU64, Ordering};
    static SEQ: AtomicU64 = AtomicU64::new(0);
    let seq = SEQ.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!(
        "mlxrs-l11-bpe-withhold-{}-{}",
        std::process::id(),
        seq
    ));
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).expect("temp tokenizer dir");
    std::fs::write(
        dir.join("tokenizer.json"),
        serde_json::to_string(&tokenizer_json).expect("serialize tokenizer.json"),
    )
    .expect("write tokenizer.json");
    std::fs::write(
        dir.join("tokenizer_config.json"),
        serde_json::to_string(&config_json).expect("serialize tokenizer_config.json"),
    )
    .expect("write tokenizer_config.json");
    let tok = Tokenizer::from_path(&dir, None).expect("load BPE-decoder tokenizer");
    (tok, a_id, vocab_size)
}
/// Covers a speculative stream dropped after three tokens.
///
/// The recorded assistant turn must equal what a finalized detokenizer
/// produces for those tokens. Any tail the streaming detokenizer was still
/// holding back must be flushed on commit, not lost.
#[test]
fn speculative_interrupted_stream_flushes_detokenizer_tail() {
    let (tok, a_id, vocab) = bpe_withholding_tokenizer();
    // All logit mass on `â`, identical for target and draft, so every step
    // samples `â`.
    let mut canned = vec![0.0_f32; vocab];
    canned[a_id as usize] = 10.0;
    let target = MockModel {
        canned: canned.clone(),
        n_kv_heads: 1,
        head_dim: 2,
    };
    let draft = MockModel {
        canned,
        n_kv_heads: 1,
        head_dim: 2,
    };
    let mut s = ChatSession::builder(Box::new(target), tok, cache_config())
        .speculative(SpeculativeDecodingConfig::new(
            Rc::new(draft),
            cache_config(),
        ))
        .generate_params(GenConfig {
            max_tokens: 12,
            ..Default::default()
        })
        .build()
        .expect("build");
    let mut produced_tokens: Vec<u32> = Vec::new();
    let mut streamed = String::new();
    {
        let mut stream = s.stream_respond("hello").expect("speculative stream");
        for _ in 0..3 {
            let r = stream.next().expect("token").expect("ok");
            produced_tokens.push(r.token);
            streamed.push_str(&r.text);
        }
        // `stream` is dropped here, well before `max_tokens` is reached.
    }
    assert_eq!(produced_tokens.len(), 3, "drained 3 tokens before drop");
    assert!(
        produced_tokens.iter().all(|&t| t == a_id),
        "the mock samples the withheld `â` token every step"
    );
    assert_eq!(s.history().len(), 2, "interrupted turn still recorded");
    let recorded = s.history()[1].content();
    // Reference text: an independent BPE streaming detokenizer fed the same
    // tokens and then finalized.
    let reference = {
        let mut d = crate::tokenizer::BpeStreamingDetokenizer::new(
            vec![("â".to_string(), a_id)],
            false,
        );
        for &t in &produced_tokens {
            d.add_token(t);
        }
        d.finalize();
        d.last_segment()
    };
    assert_eq!(
        *recorded, reference,
        "the interrupted speculative turn recorded token-complete text \
         (detokenizer tail flushed): recorded {recorded:?} == finalized {reference:?}"
    );
    // Guards against a vacuous pass: the streamed chunks really were shorter
    // than the flushed recording.
    assert!(
        !reference.is_empty() && recorded.len() > streamed.len(),
        "the BPE detok genuinely withheld a tail that `commit()` flushed \
         (streamed {streamed:?}, recorded {recorded:?})"
    );
}
/// Test double that delegates to an inner `MockModel`, but whose `forward`
/// starts failing once a fixed budget of successful calls is used up.
struct ErringAfterModel {
    /// Model that computes the logits while budget remains.
    inner: MockModel,
    /// Remaining number of `forward` calls allowed to succeed. `Cell` lets
    /// `forward(&self, ..)` decrement it.
    ok_calls: std::cell::Cell<usize>,
}
impl ErringAfterModel {
    /// Wraps `inner`. Exactly `ok_calls` `forward` calls succeed; every
    /// later call fails.
    fn new(inner: MockModel, ok_calls: usize) -> Self {
        let budget = std::cell::Cell::new(ok_calls);
        Self {
            ok_calls: budget,
            inner,
        }
    }
}
impl Model for ErringAfterModel {
fn forward(
&self,
tokens: &crate::array::Array,
cache: &mut [Box<dyn KvCache>],
) -> Result<crate::array::Array> {
let remaining = self.ok_calls.get();
if remaining == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"ErringAfterModel::forward",
"budget exhausted (test fixture)",
)));
}
self.ok_calls.set(remaining - 1);
self.inner.forward(tokens, cache)
}
fn forward_embeddings(
&self,
embeddings: &crate::array::Array,
cache: &mut [Box<dyn KvCache>],
) -> Result<crate::array::Array> {
self.inner.forward_embeddings(embeddings, cache)
}
}
/// `build()` rejects combining a restored `.cache(..)` with `.speculative(..)`,
/// whichever order the builder methods are called in.
///
/// The error names both builder methods. `.speculative(..)` on its own still
/// builds.
#[test]
fn build_rejects_cache_plus_speculative_combination() {
    // Order 1: `.cache(..)` then `.speculative(..)`.
    let restored = prefilled_opaque_cache(8);
    let res = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .cache(restored)
    .speculative(SpeculativeDecodingConfig::new(
        Rc::new(MockModel::new(11)),
        cache_config(),
    ))
    .build();
    // Match instead of `expect_err`, which would require `ChatSession: Debug`.
    let err = match res {
        Ok(_) => panic!("cache + speculative must be rejected at build()"),
        Err(e) => e,
    };
    let msg = format!("{err}");
    assert!(
        msg.contains("speculative") && msg.contains("cache"),
        "error names the rejected combination: {msg}"
    );
    assert!(
        msg.contains(".cache") && msg.contains(".speculative"),
        "error points at both workarounds: {msg}"
    );
    // Order 2: `.speculative(..)` then `.cache(..)`.
    let restored2 = prefilled_opaque_cache(8);
    let res2 = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .speculative(SpeculativeDecodingConfig::new(
        Rc::new(MockModel::new(11)),
        cache_config(),
    ))
    .cache(restored2)
    .build();
    assert!(
        res2.is_err(),
        "rejection is independent of builder-method order"
    );
    // Control: speculative without a restored cache is accepted.
    let ok = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .speculative(SpeculativeDecodingConfig::new(
        Rc::new(MockModel::new(11)),
        cache_config(),
    ))
    .build();
    assert!(
        ok.is_ok(),
        ".speculative(..) alone still builds (only the cache+speculative combo is unsupported)"
    );
}
/// Covers a standard stream that ends with an `Err` from the model.
///
/// The turn must still be recorded, and the recorded reply must equal a
/// finalized detokenizer over the tokens that did stream. Any withheld
/// detokenizer tail must be flushed on commit.
#[test]
fn standard_error_terminated_stream_flushes_detokenizer_tail() {
    let (tok, a_id, vocab) = bpe_withholding_tokenizer();
    // All logit mass on `â`, so every sampled token is `â`.
    let mut canned = vec![0.0_f32; vocab];
    canned[a_id as usize] = 10.0;
    let inner = MockModel {
        canned,
        n_kv_heads: 1,
        head_dim: 2,
    };
    // Only four `forward` calls succeed; later calls error, so the stream
    // fails well before `max_tokens`.
    let ok_steps: usize = 4;
    let model = ErringAfterModel::new(inner, ok_steps);
    let mut s = ChatSession::builder(Box::new(model), tok, cache_config())
        .generate_params(GenConfig {
            max_tokens: 32,
            ..Default::default()
        })
        .build()
        .expect("build");
    let mut produced_tokens: Vec<u32> = Vec::new();
    let mut streamed = String::new();
    let mut saw_err = false;
    {
        let stream = s.stream_respond("hello").expect("stream");
        for resp in stream {
            match resp {
                Ok(r) => {
                    produced_tokens.push(r.token);
                    streamed.push_str(&r.text);
                }
                Err(_) => {
                    saw_err = true;
                    break;
                }
            }
        }
    }
    assert!(saw_err, "the stream MUST yield an Err mid-generation");
    assert!(
        !produced_tokens.is_empty(),
        "at least one token must have streamed before the Err"
    );
    assert!(
        produced_tokens.iter().all(|&t| t == a_id),
        "every sampled token is the withheld `â`"
    );
    assert_eq!(s.history().len(), 2, "error-terminated turn still recorded");
    let recorded = s.history()[1].content();
    // Reference text: an independent detokenizer fed the same tokens, then
    // finalized.
    let reference = {
        let mut d =
            crate::tokenizer::BpeStreamingDetokenizer::new(vec![("â".to_string(), a_id)], false);
        for &t in &produced_tokens {
            d.add_token(t);
        }
        d.finalize();
        d.last_segment()
    };
    assert_eq!(
        *recorded, reference,
        "the error-terminated standard turn recorded token-complete text \
         (detok tail flushed in commit): recorded {recorded:?} == finalized {reference:?}"
    );
    // Guards against a vacuous pass: something really was withheld.
    assert!(
        !reference.is_empty() && recorded.len() > streamed.len(),
        "the BPE detok genuinely withheld a tail commit() flushed \
         (streamed {streamed:?}, recorded {recorded:?})"
    );
}
/// Speculative counterpart of the standard error-terminated test.
///
/// A failing model ends the speculative stream with an `Err`. The turn must
/// still be recorded with the detokenizer tail flushed.
#[test]
fn speculative_error_terminated_stream_flushes_detokenizer_tail() {
    let (tok, a_id, vocab) = bpe_withholding_tokenizer();
    // Target and draft both put all logit mass on `â`.
    let mut canned = vec![0.0_f32; vocab];
    canned[a_id as usize] = 10.0;
    let target_inner = MockModel {
        canned: canned.clone(),
        n_kv_heads: 1,
        head_dim: 2,
    };
    let draft_inner = MockModel {
        canned,
        n_kv_heads: 1,
        head_dim: 2,
    };
    // The target gets a generous budget (64) and the draft a small one (7),
    // so the mid-stream error is expected to come from the draft model.
    let target = ErringAfterModel::new(target_inner, 64);
    let draft = ErringAfterModel::new(draft_inner, 7);
    let mut s = ChatSession::builder(Box::new(target), tok, cache_config())
        .speculative(SpeculativeDecodingConfig::new(
            Rc::new(draft),
            cache_config(),
        ))
        .generate_params(GenConfig {
            max_tokens: 32,
            ..Default::default()
        })
        .build()
        .expect("build");
    let mut produced_tokens: Vec<u32> = Vec::new();
    let mut streamed = String::new();
    let mut saw_err = false;
    {
        let stream = s.stream_respond("hello").expect("speculative stream");
        for resp in stream {
            match resp {
                Ok(r) => {
                    produced_tokens.push(r.token);
                    streamed.push_str(&r.text);
                }
                Err(_) => {
                    saw_err = true;
                    break;
                }
            }
        }
    }
    assert!(
        saw_err,
        "the speculative stream MUST yield an Err mid-generation"
    );
    assert!(
        !produced_tokens.is_empty(),
        "at least one token must have streamed before the speculative Err"
    );
    assert!(
        produced_tokens.iter().all(|&t| t == a_id),
        "every sampled token is the withheld `â`"
    );
    assert_eq!(s.history().len(), 2, "error-terminated turn still recorded");
    let recorded = s.history()[1].content();
    // Reference text: an independent detokenizer fed the same tokens, then
    // finalized.
    let reference = {
        let mut d =
            crate::tokenizer::BpeStreamingDetokenizer::new(vec![("â".to_string(), a_id)], false);
        for &t in &produced_tokens {
            d.add_token(t);
        }
        d.finalize();
        d.last_segment()
    };
    assert_eq!(
        *recorded, reference,
        "the error-terminated speculative turn recorded token-complete text: \
         recorded {recorded:?} == finalized {reference:?}"
    );
    // Guards against a vacuous pass: something really was withheld.
    assert!(
        !reference.is_empty() && recorded.len() > streamed.len(),
        "speculative BPE detok genuinely withheld a tail commit() flushed \
         (streamed {streamed:?}, recorded {recorded:?})"
    );
}
/// Every `Role` variant maps to its lowercase wire name. `Display` agrees
/// with `as_str`, and only `Role::Tool` reports `is_tool()`.
#[test]
fn role_as_str_round_trips_every_variant() {
    let expected = [
        (Role::System, "system"),
        (Role::User, "user"),
        (Role::Assistant, "assistant"),
        (Role::Tool, "tool"),
    ];
    for (role, name) in expected {
        assert_eq!(role.as_str(), name);
    }
    assert_eq!(Role::Tool.to_string(), "tool");
    assert!(Role::Tool.is_tool());
    assert!(!Role::User.is_tool());
}
/// Each `ChatMessage` convenience constructor sets the matching role and
/// content. `ChatMessage::new` with an owned `String` equals the
/// corresponding convenience constructor.
#[test]
fn chat_message_constructors_set_role_and_content() {
    let tool_msg = ChatMessage::tool("tool result payload");
    assert_eq!(tool_msg.role, Role::Tool);
    assert_eq!(tool_msg.content(), "tool result payload");

    let system_msg = ChatMessage::system("sys");
    assert_eq!(system_msg.role, Role::System);
    assert_eq!(system_msg.content(), "sys");

    assert_eq!(ChatMessage::user("u").role, Role::User);
    assert_eq!(ChatMessage::assistant("a").role, Role::Assistant);

    let explicit = ChatMessage::new(Role::Tool, String::from("owned"));
    assert_eq!(explicit, ChatMessage::tool("owned"));
}
/// Each `ChatSessionError` variant renders a distinct, actionable message.
/// `NoCacheAvailable` reports no underlying `source()`.
#[test]
fn chat_session_error_display_messages_are_distinct_and_actionable() {
    let no_cache = ChatSessionError::NoCacheAvailable.to_string();
    assert!(
        no_cache.contains("no KV cache") && no_cache.contains("save_cache"),
        "NoCacheAvailable Display points at the missing-generation cause: {no_cache}"
    );
    let spec_save = ChatSessionError::SpeculativeCacheUnsupported.to_string();
    assert!(
        spec_save.contains("speculative") && spec_save.contains("consumes its KV caches"),
        "SpeculativeCacheUnsupported Display explains the consumed-cache reason: {spec_save}"
    );
    let spec_restore = ChatSessionError::SpeculativeCacheRestoreUnsupported.to_string();
    assert!(
        spec_restore.contains(".cache(") && spec_restore.contains(".speculative("),
        "SpeculativeCacheRestoreUnsupported Display points at both workarounds: {spec_restore}"
    );

    // All three messages are pairwise different.
    let rendered = [&no_cache, &spec_save, &spec_restore];
    for (i, earlier) in rendered.iter().enumerate() {
        for later in &rendered[i + 1..] {
            assert_ne!(earlier, later);
        }
    }

    let as_error: &dyn std::error::Error = &ChatSessionError::NoCacheAvailable;
    assert!(as_error.source().is_none());
}
/// After a real turn, `save_cache` writes a file that
/// `load_prompt_cache` reads back with one entry per decoder layer.
#[test]
fn save_cache_writes_a_safetensors_file_after_a_real_turn() {
    let mut chat = session(3);
    let _ = chat.respond("hello").expect("turn");
    assert!(chat.has_cache(), "the turn realised the cache");

    // Per-process file name; clear any leftover before writing.
    let target = std::env::temp_dir().join(format!(
        "mlxrs-l11-chat-session-savecache-{}.safetensors",
        std::process::id()
    ));
    let _ = std::fs::remove_file(&target);

    chat.save_cache(&target).expect("save the realised cache");
    assert!(target.exists(), "save_cache wrote the safetensors file");

    let (reloaded, _meta) =
        crate::lm::cache::load_prompt_cache(&target).expect("reload the saved cache");
    assert_eq!(
        reloaded.len(),
        cache_config().num_hidden_layers,
        "the saved cache has one entry per decoder layer"
    );
    let _ = std::fs::remove_file(&target);
}
/// Writes a minimal `WordLevel` tokenizer (vocab `<unk>`, `hello`, `world`)
/// into a fresh temp directory and loads it.
///
/// Only `tokenizer.json` is written; there is no `tokenizer_config.json`, so
/// no chat template is supplied. The directory name combines the process id
/// with a per-call counter, and it is not removed afterwards.
/// NOTE(review): `Some(&[2u32])` is passed as `from_path`'s second argument
/// (presumably the EOS ids; here id 2 is `world`). Confirm.
fn templateless_tokenizer() -> Tokenizer {
    use std::sync::atomic::{AtomicU64, Ordering};
    static SEQ: AtomicU64 = AtomicU64::new(0);

    let spec = json!({
        "version": "1.0",
        "truncation": Value::Null,
        "padding": Value::Null,
        "added_tokens": [],
        "normalizer": Value::Null,
        "pre_tokenizer": { "type": "Whitespace" },
        "post_processor": Value::Null,
        "decoder": Value::Null,
        "model": {
            "type": "WordLevel",
            "vocab": { "<unk>": 0, "hello": 1, "world": 2 },
            "unk_token": "<unk>"
        }
    });
    let call_index = SEQ.fetch_add(1, Ordering::Relaxed);
    let scratch = std::env::temp_dir().join(format!(
        "mlxrs-l11-no-template-{}-{}",
        std::process::id(),
        call_index
    ));
    let _ = std::fs::remove_dir_all(&scratch);
    std::fs::create_dir_all(&scratch).expect("temp tokenizer dir");
    let serialized = serde_json::to_string(&spec).expect("serialize tokenizer.json");
    std::fs::write(scratch.join("tokenizer.json"), serialized).expect("write tokenizer.json");
    Tokenizer::from_path(&scratch, Some(&[2u32])).expect("load templateless tokenizer")
}
/// Without a chat template, rendering a turn fails with `Error::Parse` that
/// names the chat-template stage. `respond` propagates the same error kind.
#[test]
fn build_turn_prompt_maps_a_template_failure_to_a_parse_error() {
    let renderer = ChatSession::builder(
        Box::new(MockModel::new(3)),
        templateless_tokenizer(),
        cache_config(),
    )
    .generate_params(GenConfig {
        max_tokens: 1,
        ..Default::default()
    })
    .build()
    .expect("build");
    let err = renderer
        .build_turn_prompt("hello", Role::User)
        .expect_err("templateless render must fail");
    assert!(
        matches!(err, Error::Parse(_)),
        "the template failure is surfaced as Error::Parse, got {err:?}"
    );
    let msg = err.to_string();
    assert!(
        msg.contains("chat template"),
        "the parse error names the chat-template stage: {msg}"
    );

    // `respond` goes through the same render step.
    let mut responder = ChatSession::builder(
        Box::new(MockModel::new(3)),
        templateless_tokenizer(),
        cache_config(),
    )
    .build()
    .expect("build");
    let respond_err = responder
        .respond("hello")
        .expect_err("respond must propagate it");
    assert!(matches!(respond_err, Error::Parse(_)));
}
/// When the model fails on its first `forward` call, `respond` returns that
/// error unchanged. The user/assistant pair is still recorded.
#[test]
fn respond_propagates_a_generation_error_after_recording_the_turn() {
    // A zero budget makes the very first `forward` call fail.
    let failing = ErringAfterModel::new(MockModel::new(11), 0);
    let params = GenConfig {
        max_tokens: 8,
        ..Default::default()
    };
    let mut chat = ChatSession::builder(Box::new(failing), fixture_tokenizer(), cache_config())
        .generate_params(params)
        .build()
        .expect("build");

    let err = chat
        .respond("hello")
        .expect_err("generation error must propagate");
    assert!(
        matches!(err, Error::InvariantViolation(_)),
        "the model's error is propagated verbatim, got {err:?}"
    );

    assert_eq!(chat.history().len(), 2, "the failed turn is still recorded");
    let (user, assistant) = (&chat.history()[0], &chat.history()[1]);
    assert_eq!(user.role, Role::User);
    assert_eq!(user.content(), "hello");
    assert_eq!(assistant.role, Role::Assistant);
}
/// With `max_tokens = 0` a standard stream yields nothing, yet the turn is
/// recorded with an empty reply. The cache is realised at offset 0 because
/// prefill never runs.
#[test]
fn standard_zero_max_tokens_yields_nothing_and_realises_the_cache() {
    let mut chat = session(0);
    let yielded = chat
        .stream_respond("hello")
        .expect("stream")
        .map(|resp| resp.expect("no error on the empty path"))
        .count();
    assert_eq!(yielded, 0, "max_tokens = 0 yields no tokens");
    assert!(chat.has_cache(), "the empty turn still realised the cache");
    assert_eq!(chat.history().len(), 2);
    assert_eq!(chat.history()[1].role, Role::Assistant);
    assert_eq!(chat.history()[1].content(), "", "no reply was produced");
    assert_eq!(
        chat.current_cache().expect("realised")[0].offset(),
        0,
        "prefill never ran on the zero-token path"
    );
}
/// With `max_tokens = 0` a speculative stream yields nothing and records an
/// empty assistant turn. The session still exposes no cache.
#[test]
fn speculative_zero_max_tokens_yields_nothing_and_marks_spent() {
    let speculative = SpeculativeDecodingConfig::new(Rc::new(MockModel::new(11)), cache_config());
    let mut chat = ChatSession::builder(
        Box::new(MockModel::new(11)),
        fixture_tokenizer(),
        cache_config(),
    )
    .speculative(speculative)
    .generate_params(GenConfig {
        max_tokens: 0,
        ..Default::default()
    })
    .build()
    .expect("build");

    let yielded = chat
        .stream_respond("hello")
        .expect("speculative stream")
        .map(|resp| resp.expect("no error on the empty speculative path"))
        .count();
    assert_eq!(yielded, 0, "max_tokens = 0 yields no speculative tokens");
    assert_eq!(chat.history().len(), 2, "the empty turn is recorded");
    assert_eq!(chat.history()[1].content(), "");
    assert!(!chat.has_cache());
    assert!(chat.current_cache().is_none());
}
/// When the model's first sample is the EOS token (id 2), the stream yields
/// exactly that token with `FinishReason::Eos` and still commits the turn.
#[test]
fn standard_turn_finishes_on_eos_with_a_stop_reason() {
    // Put all the logit mass on id 2 so the very first sample is EOS.
    let mut canned = vec![0.0_f32; 11];
    canned[2] = 10.0;
    let eos_model = MockModel {
        canned,
        n_kv_heads: 1,
        head_dim: 2,
    };
    let params = GenConfig {
        max_tokens: 32,
        ..Default::default()
    };
    let mut chat = ChatSession::builder(Box::new(eos_model), fixture_tokenizer(), cache_config())
        .generate_params(params)
        .build()
        .expect("build");

    let (reasons, tokens): (Vec<_>, Vec<_>) = chat
        .stream_respond("hello")
        .expect("stream")
        .map(|resp| {
            let step = resp.expect("step");
            (step.finish_reason, step.token)
        })
        .unzip();
    assert_eq!(reasons, vec![Some(FinishReason::Eos)]);
    assert_eq!(
        tokens,
        vec![2],
        "the eos token (id 2) was the yielded token"
    );
    assert!(chat.has_cache());
    assert_eq!(chat.history().len(), 2);
}
/// Covers a speculative session whose slot is `Realised` with no draft
/// cache.
///
/// `take_cache` must return the main cache unchanged, allocate a fresh
/// draft cache, and pass through the (empty) cached-token record.
#[test]
fn take_cache_allocates_the_draft_cache_for_a_realised_speculative_slot() {
    // Built directly from fields (bypassing the builder) to reach a state
    // where the main cache is realised but `draft_cache` is `None`.
    let mut s = ChatSession {
        model: Box::new(MockModel::new(11)),
        tokenizer: fixture_tokenizer(),
        cache_config: cache_config(),
        instructions: None,
        generate_params: GenConfig::default(),
        speculative: Some(SpeculativeDecodingConfig::new(
            Rc::new(MockModel::new(11)),
            cache_config(),
        )),
        cache: CacheSlot::Realised {
            cache: make_prompt_cache(&cache_config()),
            draft_cache: None,
            cached: CachedTokens::empty(),
        },
        history: Vec::new(),
    };
    let (main, draft, cached) = s.take_cache();
    assert_eq!(
        main.len(),
        cache_config().num_hidden_layers,
        "the main cache is returned unchanged"
    );
    let draft = draft.expect("a speculative session's draft cache is allocated");
    assert_eq!(
        draft.len(),
        cache_config().num_hidden_layers,
        "the freshly-built draft cache has one entry per draft-model layer"
    );
    assert_eq!(cached.opaque_len, 0, "an empty cached-token record carried");
    assert!(cached.known.is_empty());
}
/// `RcModel` forwards both `forward` and `forward_embeddings` to the shared
/// inner model.
#[test]
fn rc_model_forwards_and_forward_embeddings_delegate_to_the_inner_model() {
    let shared: Rc<dyn Model> = Rc::new(MockModel::new(5));
    let wrapper = RcModel(Rc::clone(&shared));

    // Token path: logits come back with shape [1, 3, 5], and every cache
    // layer advances by the three tokens fed.
    let mut cache = make_prompt_cache(&cache_config());
    let tokens =
        crate::array::Array::from_slice::<i32>(&[0, 1, 2], &(1usize, 3)).expect("tokens");
    let logits = wrapper
        .forward(&tokens, &mut cache)
        .expect("forward delegates");
    assert_eq!(logits.shape(), vec![1, 3, 5]);
    assert!(cache.iter().all(|c| c.offset() == 3));

    // Embedding path: the inner mock's default implementation errors, and
    // the wrapper surfaces that error.
    let mut empty_cache: Vec<Box<dyn KvCache>> = Vec::new();
    let embeddings =
        crate::array::Array::from_slice::<f32>(&[0.0, 1.0], &(1usize, 1, 2)).expect("emb");
    assert!(
        wrapper
            .forward_embeddings(&embeddings, &mut empty_cache)
            .is_err(),
        "forward_embeddings forwards to the inner model's erroring default seam"
    );
}
/// Covers a cache already 20 tokens ahead when the stream only names one
/// prompt token and nothing generated.
///
/// `commit()` must store the cache and record all 20 positions as opaque
/// rather than claiming a known token region.
#[test]
fn commit_falls_back_to_opaque_when_the_cache_outruns_the_named_tokens() {
    let model = MockModel::new(11);
    let m: &dyn Model = &model;
    let advanced = prefilled_opaque_cache(20);
    assert_eq!(advanced[0].offset(), 20);
    let generator = build_generator(m, &[3u32], advanced, GenConfig::default());
    let driver: Driver<'_> = Driver::Standard(Box::new(StandardTurn {
        generator,
        draft_cache: None,
    }));
    let mut slot = CacheSlot::Empty;
    let mut history: Vec<ChatMessage> = Vec::new();
    {
        // Stream assembled by hand: one named prompt id (3), nothing
        // generated, no opaque prefix recorded, not yet committed.
        let mut stream = ChatResponseStream {
            cache_slot: &mut slot,
            history: &mut history,
            driver: Some(driver),
            detok: fixture_tokenizer().detokenizer(),
            eos: vec![2],
            max_tokens: 4,
            prompt_tokens: 1,
            produced: 0,
            reply: String::new(),
            prompt_ids: vec![3],
            opaque_len: 0,
            generated: Vec::new(),
            finished: false,
            detok_finalized: false,
            committed: false,
        };
        stream.commit();
    }
    match slot {
        CacheSlot::Realised { cache, cached, .. } => {
            assert_eq!(cache[0].offset(), 20, "the advanced cache was stored");
            assert_eq!(
                cached.opaque_len, 20,
                "the cache that outran the named tokens is treated as fully opaque"
            );
            assert!(
                cached.known.is_empty(),
                "no known region recorded for the opaque fallback"
            );
        }
        _ => panic!("commit() must realise the cache"),
    }
    // The hand-built stream started with an empty history, so commit() is
    // expected to add only the assistant reply.
    assert_eq!(history.len(), 1);
    assert_eq!(history[0].role, Role::Assistant);
}