use super::*;
use crate::lm::{cache::BatchKvCache, model::Model};
/// Deterministic stand-in for a batched language model used by the batch-generation tests.
///
/// Each row's next-token prediction is read from `scripts`, indexed by how far the KV
/// cache offset has advanced past `max_len`.
struct MockBatchModel {
    /// Logits template of width `vocab`; the scripted token is raised on top of a copy of it.
    canned: Vec<f32>,
    /// Width of the last logits axis.
    vocab: usize,
    /// Offset subtracted from the cache offset to obtain the script step index.
    max_len: usize,
    /// `scripts[row][step]` is the token this mock favours for `row` at `step`.
    scripts: Vec<Vec<u32>>,
}
impl MockBatchModel {
    /// Builds a mock whose logits template is `vocab` zeros.
    fn new(vocab: usize, max_len: usize, scripts: Vec<Vec<u32>>) -> Self {
        let canned = vec![0.0; vocab];
        Self { canned, vocab, max_len, scripts }
    }
}
impl Model for MockBatchModel {
    /// Appends one `[B, 1, S, 1]` K (all 1.0) / V (all 2.0) block to every cache layer,
    /// then returns `[B, S, vocab]` logits.
    ///
    /// Every position of row `r` holds the `canned` template with the scripted token
    /// raised to 10.0. The script step is `offset - max_len`, where `offset` is the first
    /// cache layer's offset read *after* the update (0 when there is no cache). It is
    /// presumably 0 right after prefill; this assumes the cache counts padded positions.
    /// Token 0 is used when the offset is below `max_len`, the row has no script, or the
    /// script is exhausted. Tokens `>= vocab` leave the row at the template.
    ///
    /// # Errors
    /// `RankMismatch` when `tokens` is not rank-2. Otherwise propagates errors from
    /// `Array::from_slice` and `KvCache::update`.
    fn forward(
        &self,
        tokens: &Array,
        cache: &mut [Box<dyn crate::lm::cache::KvCache>],
    ) -> Result<Array> {
        let shape = tokens.shape();
        let (batch, seq) = match shape.as_slice() {
            [b, s] => (*b, *s),
            other => {
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "MockBatchModel::forward: tokens must be rank-2 [B, S]",
                    other.len() as u32,
                    other.to_vec(),
                )));
            }
        };

        // The K/V contents are identical for every layer, so build the host buffers once.
        let elems = batch * seq;
        let k_data = vec![1.0_f32; elems];
        let v_data = vec![2.0_f32; elems];
        for layer in cache.iter_mut() {
            let k = Array::from_slice::<f32>(&k_data, &(batch, 1usize, seq, 1usize))?;
            let v = Array::from_slice::<f32>(&v_data, &(batch, 1usize, seq, 1usize))?;
            layer.update(&k, &v)?;
        }

        let cache_offset = cache.first().map(|c| c.offset()).unwrap_or(0);
        let script_idx = cache_offset.checked_sub(self.max_len);
        let mut data: Vec<f32> = Vec::with_capacity(batch * seq * self.vocab);
        for row in 0..batch {
            let pred = script_idx
                .and_then(|i| self.scripts.get(row).and_then(|s| s.get(i).copied()))
                .unwrap_or(0);
            // Every position of a row shares the same logits; build them once per row.
            let mut row_logits = self.canned.clone();
            if (pred as usize) < self.vocab {
                row_logits[pred as usize] = 10.0;
            }
            for _ in 0..seq {
                data.extend_from_slice(&row_logits);
            }
        }
        Array::from_slice::<f32>(&data, &(batch, seq, self.vocab))
    }
}
#[test]
fn batch_generate_left_pads_and_emits_per_row_sequences() {
    // Row 0 never produces eos and stops on the length budget; row 1 produces the
    // eos token (5) as its third token. That token is not collected.
    let prompts: Vec<&[u32]> = vec![&[1u32, 2, 3], &[7u32]];
    let pads = batch_left_padding(&prompts);
    assert_eq!(pads, vec![0, 2]);

    let max_len = 3;
    let scripts = vec![vec![11, 12, 13, 14, 15], vec![21, 22, 5, 99, 99]];
    let model = MockBatchModel::new(32, max_len, scripts);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens: 5, eos: vec![5], ..Default::default() };

    let mut emitted: Vec<Vec<u32>> = vec![Vec::new(); 2];
    let mut finished: Vec<Option<FinishReason>> = vec![None; 2];
    for result in batch_generate_step(&model, &prompts, 0, cache, cfg) {
        let step = result.expect("step error");
        let is_eos = matches!(&step.finish_reason, Some(r) if r.is_eos());
        if !is_eos {
            emitted[step.row].push(step.token);
        }
        if let Some(reason) = step.finish_reason {
            finished[step.row] = Some(reason);
        }
    }

    assert_eq!(emitted[0], vec![11, 12, 13, 14, 15]);
    assert_eq!(finished[0], Some(FinishReason::Length));
    assert_eq!(emitted[1], vec![21, 22]);
    assert_eq!(finished[1], Some(FinishReason::Eos));
}
#[test]
fn batch_left_padding_three_ragged_rows() {
    // The longest prompt has 4 tokens, so each row's pad is 4 minus its own length.
    let rows: Vec<&[u32]> = vec![&[1u32, 2, 3, 4], &[5u32, 6], &[7u32]];
    assert_eq!(batch_left_padding(&rows), vec![0, 2, 3]);
}
#[test]
fn batch_generate_per_row_eos_independent_finish() {
    // Rows finish independently: row 1 hits eos (5) right after its first token,
    // while rows 0 and 2 exhaust the 3-token budget.
    let prompts: Vec<&[u32]> = vec![&[1u32, 2], &[3u32, 4], &[5u32]];
    let pads = batch_left_padding(&prompts);
    assert_eq!(pads, vec![0, 0, 1]);

    let model = MockBatchModel::new(
        64,
        2,
        vec![
            vec![10, 11, 12, 99, 99],
            vec![20, 5, 99, 99, 99],
            vec![30, 31, 32, 99, 99],
        ],
    );
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens: 3, eos: vec![5], ..Default::default() };

    let mut tokens_by_row: Vec<Vec<u32>> = vec![Vec::new(); 3];
    let mut reason_by_row: Vec<Option<FinishReason>> = vec![None; 3];
    let steps = batch_generate_step(&model, &prompts, 0, cache, cfg);
    for step in steps.map(|item| item.expect("step error")) {
        if !matches!(&step.finish_reason, Some(r) if r.is_eos()) {
            tokens_by_row[step.row].push(step.token);
        }
        if let Some(reason) = step.finish_reason {
            reason_by_row[step.row] = Some(reason);
        }
    }

    assert_eq!(tokens_by_row[0], vec![10, 11, 12]);
    assert_eq!(reason_by_row[0], Some(FinishReason::Length));
    assert_eq!(tokens_by_row[1], vec![20]);
    assert_eq!(reason_by_row[1], Some(FinishReason::Eos));
    assert_eq!(tokens_by_row[2], vec![30, 31, 32]);
    assert_eq!(reason_by_row[2], Some(FinishReason::Length));
}
#[test]
fn batch_generate_step_empty_prompts_is_err() {
    // An empty batch yields exactly one Err and then the iterator is exhausted.
    let no_prompts: Vec<&[u32]> = Vec::new();
    let model = MockBatchModel::new(8, 0, Vec::new());
    let no_cache: Vec<Box<dyn crate::lm::cache::KvCache>> = vec![];
    let mut steps = batch_generate_step(&model, &no_prompts, 0, no_cache, GenConfig::default());
    let first = steps.next().unwrap();
    assert!(first.is_err());
    assert!(steps.next().is_none());
}
#[test]
fn batch_generate_step_empty_row_is_err() {
    // A batch containing a zero-length prompt yields one Err and then terminates.
    let prompts: Vec<&[u32]> = vec![&[]];
    let model = MockBatchModel::new(8, 0, vec![Vec::new()]);
    let no_cache: Vec<Box<dyn crate::lm::cache::KvCache>> = vec![];
    let mut steps = batch_generate_step(&model, &prompts, 0, no_cache, GenConfig::default());
    let first = steps.next().unwrap();
    assert!(first.is_err());
    assert!(steps.next().is_none());
}
#[test]
fn batch_generate_step_zero_max_tokens_emits_nothing_and_skips_prefill() {
    // A zero token budget must short-circuit before any forward pass. `BatchFailModel`
    // counts its `forward` calls and returns an Err from each one. A zero call count
    // therefore proves the prefill was skipped, not merely that no tokens came out.
    let prompts: Vec<&[u32]> = vec![&[1u32, 2, 3], &[7u32]];
    let left_pad = batch_left_padding(&prompts);
    let model = BatchFailModel {
        calls: std::cell::RefCell::new(0),
    };
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&left_pad))];
    assert_eq!(cache[0].offset(), 0);
    let cfg = GenConfig {
        max_tokens: 0,
        eos: vec![5],
        ..Default::default()
    };
    let batch_gen = batch_generate_step(&model, &prompts, 0, cache, cfg);
    assert_eq!(batch_gen.count(), 0);
    assert_eq!(
        *model.calls.borrow(),
        0,
        "max_tokens == 0 must not run the prefill forward pass"
    );
}
#[test]
fn batch_generate_zero_max_tokens_returns_empty_vec_per_row() {
    // With a zero budget no Err is yielded and every row's collected output stays empty.
    let prompts: Vec<&[u32]> = vec![&[1u32, 2, 3], &[7u32], &[9u32, 10]];
    let row_count = prompts.len();
    let pads = batch_left_padding(&prompts);
    let model = MockBatchModel::new(16, 3, vec![Vec::new(); row_count]);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens: 0, ..Default::default() };

    let mut results: Vec<Vec<u32>> = vec![Vec::new(); row_count];
    for item in batch_generate_step(&model, &prompts, 0, cache, cfg) {
        let step = item.expect("zero-budget guard must not yield Err");
        if !matches!(&step.finish_reason, Some(r) if r.is_eos()) {
            results[step.row].push(step.token);
        }
    }
    assert_eq!(results, vec![Vec::<u32>::new(); row_count]);
}
#[test]
fn batch_stream_generate_finished_row_not_re_emitted() {
    // Row 0 hits eos on its first step and must never be emitted again. Row 1 never
    // hits eos and must be emitted exactly `max_tokens` times, finishing on length.
    // NOTE(review): despite the name, this drives `batch_generate_step` directly.
    let max_tokens = 5;
    let prompts: Vec<&[u32]> = vec![&[1u32, 2], &[3u32, 4]];
    let pads = batch_left_padding(&prompts);
    assert_eq!(pads, vec![0, 0]);

    let model = MockBatchModel::new(
        64,
        2,
        vec![vec![5u32, 99, 99, 99, 99], vec![20u32, 21, 22, 23, 24]],
    );
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens, eos: vec![5], ..Default::default() };

    let mut emits_per_row = [0usize; 2];
    let mut finish_per_row: Vec<Option<FinishReason>> = vec![None; 2];
    for item in batch_generate_step(&model, &prompts, 0, cache, cfg) {
        let step = item.expect("step error");
        let row = step.row;
        emits_per_row[row] += 1;
        if let Some(r) = step.finish_reason {
            let prior = &finish_per_row[row];
            assert!(
                prior.is_none(),
                "row {} got a second finish_reason emit: prior={:?}, new={:?}",
                row,
                prior,
                r,
            );
            finish_per_row[row] = Some(r);
        }
    }

    assert_eq!(
        emits_per_row[0], 1,
        "row 0 finished on step 1 but was re-emitted on later steps (got {} emits, expected 1)",
        emits_per_row[0]
    );
    assert_eq!(finish_per_row[0], Some(FinishReason::Eos));
    assert_eq!(
        emits_per_row[1], max_tokens,
        "row 1 expected {max_tokens} emits, got {}",
        emits_per_row[1]
    );
    assert_eq!(finish_per_row[1], Some(FinishReason::Length));
}
/// Test model whose `forward` always fails and records how many times it was invoked.
struct BatchFailModel {
    /// Number of `forward` calls observed so far.
    calls: std::cell::RefCell<usize>,
}
impl Model for BatchFailModel {
fn forward(
&self,
_tokens: &Array,
_cache: &mut [Box<dyn crate::lm::cache::KvCache>],
) -> Result<Array> {
*self.calls.borrow_mut() += 1;
Err(Error::InvariantViolation(
crate::error::InvariantViolationPayload::new(
"BatchFailModel::forward",
"mock batch forward failure (test fixture)",
),
))
}
}
#[test]
fn batch_generate_step_propagates_validate_err_before_forward() {
    // An invalid config (negative temp) must be rejected before the model is called.
    // `BatchFailModel` counts its calls and would otherwise surface its own error.
    let model = BatchFailModel { calls: std::cell::RefCell::new(0) };
    let prompts: Vec<&[u32]> = vec![&[1u32, 2, 3], &[4u32]];
    let pads = batch_left_padding(&prompts);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { temp: -1.0, max_tokens: 4, ..GenConfig::default() };

    let mut steps = batch_generate_step(&model, &prompts, 0, cache, cfg);
    let err = steps
        .next()
        .expect("iterator yields at least one item")
        .expect_err("validation Err must propagate");
    let msg = format!("{err:?}");

    assert!(
        msg.contains("temp"),
        "yielded validation error, not the forward error (validate ran BEFORE forward): {msg}"
    );
    assert!(
        !msg.contains("mock batch forward failure"),
        "model.forward must NOT have been called (validate fail-fast): {msg}"
    );
    let calls = *model.calls.borrow();
    assert_eq!(
        calls, 0,
        "model.forward was called {} time(s) — validate gate did not fail-fast",
        calls
    );
    assert!(steps.next().is_none(), "iterator fuses after the yielded Err");
}
#[test]
fn left_pad_rows_rejects_empty_inputs() {
    // Both "no rows at all" and "only empty rows" are rejected as EmptyInput.
    let no_rows: Vec<&[u32]> = Vec::new();
    let blank_rows: Vec<&[u32]> = vec![&[], &[]];
    for prompts in [&no_rows, &blank_rows] {
        let outcome = left_pad_rows(prompts, 0);
        assert!(matches!(outcome, Err(Error::EmptyInput(_))));
    }
}
#[test]
fn left_pad_rows_rejects_ragged_empty_row() {
    // One empty row among non-empty ones is still rejected, with an "every prompt" context.
    let ragged: Vec<&[u32]> = vec![&[1u32, 2], &[]];
    let err = left_pad_rows(&ragged, 0).unwrap_err();
    let names_every_prompt =
        matches!(err, Error::EmptyInput(ref p) if p.context().contains("every prompt"));
    assert!(
        names_every_prompt,
        "a ragged empty row ⇒ EmptyInput(every prompt), got {err:?}"
    );
}
#[test]
fn left_pad_rows_pads_and_preserves_tail() {
    // Shorter rows are filled on the left with the pad id; their tokens stay at the end.
    let prompts: Vec<&[u32]> = vec![&[1u32, 2, 3], &[7u32]];
    let outcome = left_pad_rows(&prompts, 99).unwrap();
    let (padded, max_len) = outcome;
    assert_eq!(max_len, 3);
    assert_eq!(padded, vec![vec![1, 2, 3], vec![99, 99, 7]]);
}
/// Loads the test tokenizer from `<crate root>/tests/fixtures`, panicking if it cannot be loaded.
fn fixture_tokenizer() -> crate::tokenizer::Tokenizer {
    let mut dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    dir.push("tests");
    dir.push("fixtures");
    crate::tokenizer::Tokenizer::from_path(&dir, None).expect("load fixture tokenizer")
}
#[test]
fn batch_generate_drops_eos_keeps_length_tokens() {
    // Row 0 emits 7 and then the tokenizer eos (2), which must be dropped from the
    // collected output. Row 1 stops on length and keeps all three of its tokens.
    let tok = fixture_tokenizer();
    let prompts: Vec<&[u32]> = vec![&[1u32, 1], &[1u32, 1]];
    let pads = batch_left_padding(&prompts);
    assert_eq!(pads, vec![0, 0]);

    let model = MockBatchModel::new(32, 2, vec![vec![7u32, 2, 99, 99], vec![8u32, 9, 10, 99]]);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens: 3, ..Default::default() };

    let out = batch_generate(&model, &tok, &prompts, 0, cache, cfg).expect("batch_generate ok");
    assert_eq!(out.len(), 2, "one output row per prompt");
    assert_eq!(out[0], vec![7], "row 0: token 7 kept, eos(2) dropped");
    assert_eq!(out[1], vec![8, 9, 10], "row 1: length-finish token kept");
}
#[test]
fn batch_generate_zero_max_tokens_empty_rows() {
    // The collecting wrapper returns one empty Vec per prompt when the budget is zero.
    let tok = fixture_tokenizer();
    let row: &[u32] = &[1u32, 1];
    let prompts: Vec<&[u32]> = vec![row; 3];
    let pads = batch_left_padding(&prompts);
    let model = MockBatchModel::new(16, 2, vec![Vec::new(); 3]);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens: 0, ..Default::default() };

    let out = batch_generate(&model, &tok, &prompts, 0, cache, cfg).unwrap();
    assert_eq!(out, vec![Vec::<u32>::new(); 3]);
}
#[test]
fn batch_stream_generate_uses_tokenizer_eos() {
    // `cfg.eos` is left at its default, so the stop on token 2 must come from the
    // tokenizer's eos set.
    let tok = fixture_tokenizer();
    let prompts: Vec<&[u32]> = vec![&[1u32, 1]];
    let pads = batch_left_padding(&prompts);
    let model = MockBatchModel::new(16, 2, vec![vec![5u32, 2, 99]]);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig { max_tokens: 5, ..Default::default() };

    let mut tokens = Vec::new();
    let mut last_reason: Option<FinishReason> = None;
    for item in batch_stream_generate(&model, &tok, &prompts, 0, cache, cfg) {
        let step = item.expect("step ok");
        if matches!(&step.finish_reason, Some(r) if r.is_eos()) {
            last_reason = step.finish_reason.clone();
        } else {
            tokens.push(step.token);
        }
    }

    assert_eq!(tokens, vec![5], "token 5 emitted before eos");
    assert_eq!(
        last_reason,
        Some(FinishReason::Eos),
        "tokenizer eos {{2}} drove the stop even with empty cfg.eos"
    );
}
#[test]
fn batch_generate_with_repetition_penalty_runs_per_row_processor() {
    // No scripted output token repeats within a row, so a repetition penalty of 2.0
    // must leave both scripted sequences intact.
    let tok = fixture_tokenizer();
    let prompts: Vec<&[u32]> = vec![&[1u32, 1], &[1u32, 1]];
    let pads = batch_left_padding(&prompts);
    let model =
        MockBatchModel::new(32, 2, vec![vec![10u32, 11, 12, 99], vec![20u32, 21, 22, 99]]);
    let cache: Vec<Box<dyn crate::lm::cache::KvCache>> =
        vec![Box::new(BatchKvCache::new(&pads))];
    let cfg = GenConfig {
        max_tokens: 3,
        repetition_penalty: Some(2.0),
        ..Default::default()
    };

    let out = batch_generate(&model, &tok, &prompts, 0, cache, cfg).expect("ok");
    assert_eq!(out[0], vec![10, 11, 12], "row 0 unaffected by no-repeat penalty");
    assert_eq!(out[1], vec![20, 21, 22], "row 1 unaffected by no-repeat penalty");
}