use super::*;
use crate::lm::cache::{CacheConfig, KvCache, make_prompt_cache};
/// Loads the tokenizer stored under `<crate root>/tests/fixtures`.
///
/// # Panics
/// Panics if the fixture tokenizer cannot be loaded.
fn fixture_tokenizer() -> crate::tokenizer::Tokenizer {
    let fixtures: std::path::PathBuf = [env!("CARGO_MANIFEST_DIR"), "tests", "fixtures"]
        .iter()
        .collect();
    crate::tokenizer::Tokenizer::from_path(&fixtures, None).expect("load fixture tokenizer")
}
/// Deterministic mock model that replays a fixed token script.
struct ScriptModel {
    /// Width of the logits row produced for every position.
    vocab: usize,
    /// Number of prompt tokens; subtracted from the cache offset to find the
    /// current index into `script`.
    prompt_len: usize,
    /// Token ids to predict, in order, once the prompt has been consumed.
    script: Vec<u32>,
}
impl Model for ScriptModel {
    /// Writes constant K/V entries into every cache layer, then returns
    /// logits that favour the scripted token for this step.
    ///
    /// `tokens` may be rank-1 `[S]` (treated as batch 1) or rank-2 `[B, S]`.
    /// The scripted index is the first layer's `offset()`, read after this
    /// step's `update`, minus `prompt_len`; the offset is 0 if there are no
    /// layers. When that index is negative or past the end of `script`, id
    /// `0` is predicted.
    ///
    /// Every one of the `B * S` rows in the returned `[B, S, vocab]` logits
    /// gets `10.0` at the predicted id and `0.0` elsewhere. If the predicted
    /// id is not below `vocab`, all logits are zero.
    ///
    /// # Errors
    /// - `RankMismatch` for any other token rank.
    /// - Errors from `Array::from_slice` or `KvCache::update` are propagated.
    fn forward(&self, tokens: &Array, cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
        let shape = tokens.shape();
        let (batch, seq) = match shape.as_slice() {
            &[b, s] => (b, s),
            &[s] => (1usize, s),
            other => {
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "ScriptModel::forward: tokens must be rank-1 [S] or rank-2 [B, S]",
                    other.len() as u32,
                    other.to_vec(),
                )))
            }
        };

        let positions = batch * seq;
        let kv_shape = (batch, 1usize, seq, 1usize);
        cache.iter_mut().try_for_each(|layer| -> Result<()> {
            let keys = Array::from_slice::<f32>(&vec![1.0_f32; positions], &kv_shape)?;
            let values = Array::from_slice::<f32>(&vec![2.0_f32; positions], &kv_shape)?;
            layer.update(&keys, &values)?;
            Ok(())
        })?;

        let offset = cache.first().map_or(0, |layer| layer.offset());
        let predicted = offset
            .checked_sub(self.prompt_len)
            .and_then(|step| self.script.get(step).copied())
            .unwrap_or(0);
        let hot = predicted as usize;

        let mut logits = vec![0.0_f32; positions * self.vocab];
        // `hot < vocab` implies `vocab > 0`, so `chunks_exact_mut` is safe here.
        if hot < self.vocab {
            for row in logits.chunks_exact_mut(self.vocab) {
                row[hot] = 10.0;
            }
        }
        Array::from_slice::<f32>(&logits, &(batch, seq, self.vocab))
    }
}
/// Streams generation from a `ScriptModel` that replays `script` after
/// `prompt`, and collects what was streamed.
///
/// The run uses:
/// - the fixture tokenizer;
/// - a 16-wide vocabulary;
/// - a one-layer prompt cache with no sliding window.
///
/// Returns the concatenation of every response's text and each response's
/// `finish_reason`, in stream order.
///
/// # Panics
/// Panics if any stream step yields an error.
fn run(
    prompt: &[u32],
    script: Vec<u32>,
    max_tokens: usize,
    stop_strings: Vec<String>,
) -> (String, Vec<Option<FinishReason>>) {
    let tokenizer = fixture_tokenizer();
    let model = ScriptModel {
        vocab: 16,
        prompt_len: prompt.len(),
        script,
    };
    let cache = make_prompt_cache(&CacheConfig {
        num_hidden_layers: 1,
        sliding_window: None,
    });
    let config = GenConfig {
        max_tokens,
        stop_strings,
        ..Default::default()
    };

    let mut streamed = String::new();
    let mut finish = Vec::new();
    for step in stream_generate(&model, &tokenizer, prompt, cache, config) {
        let response = step.expect("stream step");
        streamed.push_str(&response.text);
        finish.push(response.finish_reason);
    }
    (streamed, finish)
}
/// Decodes `ids` with a freshly loaded fixture tokenizer.
///
/// The second `decode` argument is `false`; presumably this is
/// "skip special tokens" — confirm against `Tokenizer::decode`.
///
/// # Panics
/// Panics if the tokenizer fails to load or to decode.
fn decode(ids: &[u32]) -> String {
    let tokenizer = fixture_tokenizer();
    tokenizer.decode(ids, false).expect("decode")
}
/// With no stop strings, the scripted id 2 ends the run with `Eos`.
/// Only the tokens before it are emitted, and there is exactly one finish reason.
#[test]
fn empty_stop_strings_is_eos_only_unchanged() {
    let prompt = [1u32, 3];
    let (text, reasons) = run(&prompt, vec![4u32, 5, 2, 6, 7], 32, Vec::new());

    assert_eq!(text, decode(&[4, 5]));
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Eos));
    let finished = reasons.iter().flatten().count();
    assert_eq!(finished, 1);
}
/// A stop string equal to one token's text ends generation.
/// The output is cut at the first occurrence of that string.
#[test]
fn single_token_stop_string_stops_and_trims() {
    let prompt = [1u32, 3];
    let stop = decode(&[4]);
    let (text, reasons) = run(&prompt, vec![3u32, 4, 5, 6, 7], 32, vec![stop.clone()]);

    let full = decode(&[3, 4, 5]);
    let cut = full.find(&stop).expect("stop substring present in decode");
    let expected = &full[..cut];
    assert_eq!(text, expected);
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Stop(stop)));
}
/// A stop string covering two consecutive tokens (5 then 6) is detected
/// across the token boundary. The output is trimmed to just before it.
#[test]
fn multi_token_stop_spanning_boundary_stops_and_trims() {
    let prompt = [1u32, 3];
    let stop = decode(&[5, 6]);
    let (text, reasons) = run(&prompt, vec![3u32, 5, 6, 7, 8], 32, vec![stop.clone()]);

    let full = decode(&[3, 5, 6, 7]);
    let cut = full.find(&stop).expect("multi-token stop present");
    let expected = &full[..cut];
    assert_eq!(text, expected);
    assert!(!text.is_empty());
    assert!(!text.contains(&stop));
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Stop(stop)));
}
/// A stop string whose first token matches but whose continuation diverges
/// must not end generation. The run reaches `max_tokens` and reports only
/// `Length`.
#[test]
fn partial_match_then_diverge_does_not_stop() {
    let prompt = [1u32, 3];
    // The stop is decode([5, 6]); the script emits 5, then diverges to 8.
    let script = vec![3u32, 5, 8, 4, 7];
    let stop = decode(&[5, 6]);
    let (text, reasons) = run(&prompt, script, 5, vec![stop.clone()]);

    assert!(!text.contains(&stop), "stop string must not appear");
    assert_eq!(text, decode(&[3, 5, 8, 4, 7]));
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Length));
    // A premature stop would surface as `Stop(_)`, not just `Eos`, so reject both.
    assert!(
        reasons
            .iter()
            .all(|r| !matches!(r, Some(FinishReason::Stop(_)) | Some(FinishReason::Eos))),
        "no premature stop on the partial match"
    );
    assert_eq!(
        reasons.iter().filter(|r| r.is_some()).count(),
        1,
        "exactly one finish reason, on the final response"
    );
}
/// A stop string that completes partway through a token's text is trimmed
/// at the stop's start. The run reports `Stop` with the matched string.
#[test]
fn stop_completes_mid_token_trims_at_char_boundary() {
    let prompt = [1u32, 3];
    let script = vec![3u32, 6, 7, 8, 4];
    let quick = decode(&[6]);
    let trimmed = quick.trim_start();
    // Count characters, not bytes: the cut below must land on a char boundary.
    assert!(trimmed.chars().count() >= 3, "need a multi-char token to cut");
    // Drop the final character, whatever its UTF-8 width, so the stop string
    // ends strictly inside token 6's text. Byte slicing could panic here.
    let (last_char_start, _) = trimmed
        .char_indices()
        .last()
        .expect("token text is non-empty after the length check");
    let stop = trimmed[..last_char_start].to_string();
    let (text, reasons) = run(&prompt, script, 32, vec![stop.clone()]);

    let full = decode(&[3, 6, 7]);
    let cut = full.find(&stop).expect("mid-token stop prefix present");
    assert_eq!(text, full[..cut].to_string());
    assert!(!text.contains(&stop));
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Stop(stop)));
}
/// With several stop strings configured, whichever completes first in the
/// output ends generation, regardless of its position in the list.
#[test]
fn multiple_stop_strings_first_completion_wins() {
    let prompt = [1u32, 3];
    let early = decode(&[4]);
    let late = decode(&[6]);
    // `late` is listed first on purpose: list order must not pick the winner.
    let stops = vec![late.clone(), early.clone()];
    let (text, reasons) = run(&prompt, vec![3u32, 4, 5, 6, 7], 32, stops);

    let full = decode(&[3, 4]);
    let cut = full.find(&early).expect("early stop present");
    let expected = &full[..cut];
    assert_eq!(text, expected);
    assert!(!text.contains(&early));
    assert!(!text.contains(&late));
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Stop(early)));
}
/// A stop-string match is reported as `FinishReason::Stop` with the matched
/// string, and it is the only finish reason in the stream.
#[test]
fn finish_reason_is_stop_on_stop_string_match() {
    let stop = decode(&[5]);
    let (_, reasons) = run(&[1u32, 3], vec![3u32, 4, 5, 6, 7], 32, vec![stop.clone()]);

    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Stop(stop)));
    let finished = reasons.iter().flatten().count();
    assert_eq!(finished, 1);
}
/// Test double for a streaming detokenizer that can hold back the most
/// recently pushed chunk until `finalize`.
#[derive(Default)]
struct WithholdDetokenizer {
    /// Text already visible through `text()`.
    text: String,
    /// Most recent chunk, held back from `text` until the next push or finalize.
    pending: String,
    /// Synthetic token ids, one per `push`.
    tokens: Vec<u32>,
    /// Value stored via `set_offset`; the mock itself never advances it.
    offset: usize,
}
impl WithholdDetokenizer {
    /// Simulates the arrival of one token whose decoded text is `s`.
    ///
    /// Steps:
    /// 1. Any chunk previously held back is committed to `text`.
    /// 2. `s` is committed immediately when `withhold` is false; otherwise it
    ///    is held in `pending`.
    /// 3. A synthetic token id equal to the number of earlier pushes is recorded.
    fn push(&mut self, s: &str, withhold: bool) {
        let carried = std::mem::take(&mut self.pending);
        self.text.push_str(&carried);
        let next_id = self.tokens.len() as u32;
        if withhold {
            self.pending.push_str(s);
        } else {
            self.text.push_str(s);
        }
        self.tokens.push(next_id);
    }
}
impl crate::tokenizer::StreamingDetokenizer for WithholdDetokenizer {
    /// Returns the mock to its freshly constructed state.
    fn reset(&mut self) {
        *self = Self::default();
    }

    /// Ignored: the mock is driven through `push` rather than token ids.
    fn add_token(&mut self, _token: u32) {}

    /// Commits any held-back chunk to the visible text.
    fn finalize(&mut self) {
        let tail = std::mem::take(&mut self.pending);
        self.text += &tail;
    }

    /// Visible text only; a held-back chunk is excluded until committed.
    fn text(&self) -> std::borrow::Cow<'_, str> {
        self.text.as_str().into()
    }

    fn tokens(&self) -> &[u32] {
        self.tokens.as_slice()
    }

    fn offset(&self) -> usize {
        self.offset
    }

    fn set_offset(&mut self, offset: usize) {
        self.offset = offset;
    }
}
/// Sanity check on the mock: a held-back chunk is invisible in `text()`
/// until `finalize` commits it.
#[test]
fn mock_withholds_tail_until_finalize() {
    use crate::tokenizer::StreamingDetokenizer;
    let mut detok = WithholdDetokenizer::default();
    detok.push("hello", false);
    detok.push(" ", true);
    assert_eq!(detok.text().as_ref(), "hello");

    detok.finalize();
    assert_eq!(detok.text().as_ref(), "hello ");
}
/// On EOS, a finalized tail that completes a stop string is suppressed, and
/// the reason becomes `Stop` with the matched string.
#[test]
fn finalized_tail_completes_stop_on_eos_trims_and_reports_stop() {
    let matcher = crate::lm::stop::StopMatcher::new(vec![" ".to_string()]);
    let mut detok = WithholdDetokenizer::default();
    detok.push("hello", false);
    detok.push(" ", true);
    // Only "hello" has been streamed so far; the space is still held back.
    let mut emitted = "hello".len();
    crate::tokenizer::StreamingDetokenizer::finalize(&mut detok);

    let (tail, reason) = finalize_active_tail(&detok, &matcher, &mut emitted, FinishReason::Eos);
    assert_eq!(tail, "");
    assert_eq!(reason, FinishReason::Stop(" ".to_string()));
    assert!(!tail.contains(' '), "the bare space must not be emitted");
}
/// When the token budget runs out, a stop string completed by the finalized
/// tail takes precedence over `Length`.
#[test]
fn finalized_tail_completes_stop_on_max_tokens_wins_over_length() {
    let matcher = crate::lm::stop::StopMatcher::new(vec![" ".to_string()]);
    let mut detok = WithholdDetokenizer::default();
    detok.push("hi", false);
    detok.push(" ", true);
    let mut emitted = "hi".len();
    crate::tokenizer::StreamingDetokenizer::finalize(&mut detok);

    let (tail, reason) =
        finalize_active_tail(&detok, &matcher, &mut emitted, FinishReason::Length);
    assert_eq!(
        reason,
        FinishReason::Stop(" ".to_string()),
        "finalized-tail stop must win over length and carry the matched payload"
    );
    assert_eq!(tail, "");
    assert!(!tail.contains(' '));
}
/// Without a stop match, the finalized tail is emitted unchanged. The
/// caller-supplied default reason (`Length`, then `Eos`) is passed through.
#[test]
fn finalized_tail_no_stop_emits_tail_with_default_reason() {
    let matcher = crate::lm::stop::StopMatcher::new(vec!["ZZZ".to_string()]);
    for (default, expected) in [
        (FinishReason::Length, FinishReason::Length),
        (FinishReason::Eos, FinishReason::Eos),
    ] {
        let mut detok = WithholdDetokenizer::default();
        detok.push("hi", false);
        detok.push(" ", true);
        let mut emitted = "hi".len();
        crate::tokenizer::StreamingDetokenizer::finalize(&mut detok);

        let (tail, reason) = finalize_active_tail(&detok, &matcher, &mut emitted, default);
        assert_eq!(tail, " ");
        assert_eq!(reason, expected);
    }
}
/// A multi-character stop that begins in already-emitted text ("ab") and
/// ends in the finalized tail ("c") is detected. The tail is suppressed and
/// `emitted_len` stays at the emitted prefix length.
#[test]
fn finalized_tail_completes_multichar_stop_spanning_into_tail() {
    let matcher = crate::lm::stop::StopMatcher::new(vec!["abc".to_string()]);
    let mut detok = WithholdDetokenizer::default();
    detok.push("ab", false);
    detok.push("c", true);
    let mut emitted = "ab".len();
    crate::tokenizer::StreamingDetokenizer::finalize(&mut detok);

    let (tail, reason) =
        finalize_active_tail(&detok, &matcher, &mut emitted, FinishReason::Length);
    assert_eq!(reason, FinishReason::Stop("abc".to_string()));
    assert_eq!(tail, "");
    assert_eq!(emitted, 2);
}
/// A configured stop string that never matches must not interfere with EOS.
#[test]
fn eos_with_active_non_matching_matcher_reports_eos() {
    let stops = vec!["ZZZ".to_string()];
    let (text, reasons) = run(&[1u32, 3], vec![4u32, 5, 2, 6, 7], 32, stops);

    assert_eq!(text, decode(&[4, 5]));
    assert_eq!(reasons.last().unwrap(), &Some(FinishReason::Eos));
    assert_eq!(reasons.iter().flatten().count(), 1);
}
/// `generate` concatenates the streamed text and reports token counts. The
/// EOS-bearing final response is counted as a generation token.
#[test]
fn generate_collects_text_and_reports_stats_on_eos() {
    let prompt = [1u32, 3];
    let tokenizer = fixture_tokenizer();
    let model = ScriptModel {
        vocab: 16,
        prompt_len: prompt.len(),
        script: vec![4u32, 5, 2, 6, 7],
    };
    let cache = make_prompt_cache(&CacheConfig {
        num_hidden_layers: 1,
        sliding_window: None,
    });
    let cfg = GenConfig {
        max_tokens: 32,
        ..Default::default()
    };

    let (text, stats) = generate(&model, &tokenizer, &prompt, cache, cfg).expect("generate ok");
    assert_eq!(text, decode(&[4, 5]), "eos token contributes no text");
    assert_eq!(stats.prompt_tokens, 2, "prompt was 2 tokens");
    assert_eq!(
        stats.generation_tokens, 3,
        "two emitted tokens + the eos-bearing final response (n + 1)"
    );
    let rates_non_negative = stats.prompt_tps >= 0.0 && stats.generation_tps >= 0.0;
    assert!(rates_non_negative);
}
/// With `max_tokens == 0`, `generate` produces no text.
/// It still reports the prompt length, and both throughput rates are zero.
#[test]
fn generate_zero_max_tokens_empty_text_zero_stats() {
    let prompt = [1u32, 3, 4];
    let tokenizer = fixture_tokenizer();
    let model = ScriptModel {
        vocab: 16,
        prompt_len: prompt.len(),
        script: vec![4u32, 5, 6],
    };
    let cache = make_prompt_cache(&CacheConfig {
        num_hidden_layers: 1,
        sliding_window: None,
    });
    let cfg = GenConfig {
        max_tokens: 0,
        ..Default::default()
    };

    let (text, stats) = generate(&model, &tokenizer, &prompt, cache, cfg).expect("generate ok");
    assert_eq!(text, "", "no tokens produced ⇒ empty output");
    assert_eq!(stats.generation_tokens, 0);
    assert_eq!(
        stats.prompt_tokens, 3,
        "prompt_tokens preserved on the empty run"
    );
    assert_eq!(stats.prompt_tps, 0.0);
    assert_eq!(stats.generation_tps, 0.0);
}
/// A failing `forward` must surface as `Err` from `generate` rather than
/// being swallowed.
#[test]
fn generate_propagates_step_error() {
    /// Model whose every forward pass fails.
    struct FailModel;
    impl Model for FailModel {
        fn forward(&self, _tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
            let payload = crate::error::InvariantViolationPayload::new(
                "FailModel::forward",
                "mock forward failure",
            );
            Err(Error::InvariantViolation(payload))
        }
    }

    let tokenizer = fixture_tokenizer();
    let cache = make_prompt_cache(&CacheConfig {
        num_hidden_layers: 1,
        sliding_window: None,
    });
    let cfg = GenConfig {
        max_tokens: 4,
        ..Default::default()
    };

    let outcome = generate(&FailModel, &tokenizer, &[1u32, 3], cache, cfg);
    assert!(
        outcome.is_err(),
        "a forward failure surfaces as Err from generate"
    );
}