use super::*;
use crate::lm::{cache::KvCache, model::MockModel};
/// Cache configuration shared by these tests: zero hidden layers and no
/// sliding window (presumably meaning no KV cache is set up — confirm against
/// `CacheConfig`'s consumers).
fn no_cache() -> CacheConfig {
    let sliding_window = None;
    CacheConfig {
        sliding_window,
        num_hidden_layers: 0,
    }
}
/// Reference perplexity, computed in `f64`, for a model whose logits are the
/// same `canned` vector at every position.
///
/// Because every position shares one log-sum-exp, the loss for target token
/// `t` is `lse - canned[t]`. The first token of each window only provides
/// context and is not scored. Returns `exp(mean loss)` over all scored tokens.
///
/// # Panics
/// Panics if a window is empty or a token is out of range for `canned`.
fn expected_ppl(canned: &[f32], windows: &[&[i32]]) -> f64 {
    let peak = canned.iter().copied().fold(f32::NEG_INFINITY, f32::max) as f64;
    let log_sum_exp = peak + canned.iter().map(|&x| (x as f64 - peak).exp()).sum::<f64>().ln();
    let (sum_loss, scored) = windows
        .iter()
        .flat_map(|w| w[1..].iter())
        .fold((0.0f64, 0usize), |(acc, k), &tok| {
            (acc + (log_sum_exp - canned[tok as usize] as f64), k + 1)
        });
    (sum_loss / scored as f64).exp()
}
/// Packs equal-length token rows into a rank-2 `i32` `Array` of shape
/// `(rows.len(), rows[0].len())`, in row-major order.
///
/// # Panics
/// Panics if `rows` is empty, if any row's length differs from the first
/// ("ragged test matrix"), or if `Array::from_slice` fails.
fn matrix(rows: &[&[i32]]) -> Array {
    let height = rows.len();
    let width = rows[0].len();
    let flat: Vec<i32> = rows
        .iter()
        .flat_map(|r| {
            assert_eq!(r.len(), width, "ragged test matrix");
            r.iter().copied()
        })
        .collect();
    Array::from_slice::<i32>(&flat, &(height, width)).unwrap()
}
/// Single 5-token window: perplexity, mean loss, and every per-token loss
/// must match the `f64` hand trace from `expected_ppl`.
#[test]
fn hand_traced_single_window_matches_reference() {
    let model = MockModel::new(5);
    let row: &[i32] = &[0, 1, 2, 3, 4];
    let data = matrix(&[row]);
    let res = perplexity(&model, &data, 8, &no_cache()).unwrap();
    // A 5-token window scores 4 next-token targets; the first token is context only.
    assert_eq!(res.num_tokens, 4);
    let want = expected_ppl(&model.canned, &[row]);
    assert!(
        (res.perplexity as f64 - want).abs() < 1e-4,
        "ppl {} != hand-traced {want}",
        res.perplexity
    );
    assert!((res.mean_loss as f64 - want.ln()).abs() < 1e-4);

    // One loss per scored target. Each value must equal the hand-traced loss for
    // its (context, target) pair: ln of the reference perplexity of that 2-token slice.
    let mut losses = res.losses;
    let per_token = losses.to_vec::<f32>().unwrap();
    assert_eq!(per_token.len(), 4);
    for (i, (pair, &got)) in row.windows(2).zip(&per_token).enumerate() {
        let want_i = expected_ppl(&model.canned, &[pair]).ln();
        assert!(
            (got as f64 - want_i).abs() < 1e-4,
            "token {i}: loss {got} != hand-traced {want_i}"
        );
    }
}
/// All-zero canned logits give a uniform softmax over `vocab` classes, so the
/// perplexity should equal the vocabulary size up to float error.
#[test]
fn uniform_logits_ppl_approximates_vocab_size() {
    for vocab in [2usize, 7, 50] {
        let model = MockModel {
            n_kv_heads: 1,
            head_dim: 2,
            canned: vec![0.0; vocab],
        };
        let tokens: Vec<i32> = (0..vocab as i32).collect();
        let res = perplexity(&model, &matrix(&[&tokens]), 4, &no_cache()).unwrap();
        let gap = (res.perplexity as f64 - vocab as f64).abs();
        assert!(gap < 1e-3, "uniform vocab {vocab}: ppl {} != V", res.perplexity);
    }
}
/// Five 4-token windows: batch size 2 must aggregate to the hand-traced
/// perplexity and agree with a single oversized batch.
#[test]
fn multi_window_batched_matches_unbatched_aggregation() {
    let model = MockModel::new(6);
    let rows: Vec<&[i32]> = vec![
        &[0, 1, 2, 3],
        &[5, 4, 3, 2],
        &[1, 1, 5, 0],
        &[2, 3, 4, 5],
        &[5, 5, 5, 5],
    ];
    let data = matrix(&rows);

    // 5 windows with batch size 2: the last batch is (presumably) a partial one.
    let batched = perplexity(&model, &data, 2, &no_cache()).unwrap();
    // 5 windows x 3 scored targets each.
    assert_eq!(batched.num_tokens, 15);
    let want = expected_ppl(&model.canned, &rows);
    assert!(
        (batched.perplexity as f64 - want).abs() < 1e-4,
        "batched ppl {} != hand-traced {want}",
        batched.perplexity
    );

    // A batch size larger than the dataset covers every window at once.
    let whole = perplexity(&model, &data, 64, &no_cache()).unwrap();
    let diff = (batched.perplexity as f64 - whole.perplexity as f64).abs();
    assert!(diff < 1e-5);
}
/// A batch size of 0 must not fail; the result must match the hand trace.
#[test]
fn batch_size_zero_is_treated_as_one() {
    let windows: [&[i32]; 2] = [&[0, 1, 2, 3], &[3, 2, 1, 0]];
    let model = MockModel::new(4);
    let res = perplexity(&model, &matrix(&windows), 0, &no_cache()).unwrap();
    // 2 windows x 3 scored targets each.
    assert_eq!(res.num_tokens, 6);
    let want = expected_ppl(&model.canned, &windows);
    assert!((res.perplexity as f64 - want).abs() < 1e-4);
}
/// `perplexity` must reject a rank-1 token array rather than guess a window shape.
#[test]
fn rejects_non_rank2_data() {
    let model = MockModel::new(4);
    let vector = Array::from_slice::<i32>(&[0, 1, 2, 3], &(4usize,)).unwrap();
    let outcome = perplexity(&model, &vector, 8, &no_cache());
    assert!(outcome.is_err());
}
/// Two windows of a single token each leave no next-token target to score,
/// so `perplexity` must return an error.
#[test]
fn rejects_too_short_window() {
    let model = MockModel::new(4);
    let one_token_windows = Array::from_slice::<i32>(&[0, 1], &(2usize, 1)).unwrap();
    let outcome = perplexity(&model, &one_token_windows, 8, &no_cache());
    assert!(outcome.is_err());
}
/// Seven tokens split into windows of 3 give two full rows; the leftover
/// token `6` is dropped.
#[test]
fn make_windows_drops_ragged_tail_and_reshapes() {
    let tokens: Vec<i32> = (0..7).collect();
    let mut w = make_windows(&tokens, 3).unwrap();
    assert_eq!(w.shape(), vec![2, 3]);
    let flat = w.to_vec::<i32>().unwrap();
    assert_eq!(flat, vec![0, 1, 2, 3, 4, 5]);
}
/// `make_windows` must error when the input is shorter than one window and
/// when the window length is 1 (no next-token target per window).
#[test]
fn make_windows_rejects_short_input_and_tiny_window() {
    let too_short = make_windows(&[0, 1], 5);
    assert!(too_short.is_err());
    let too_tiny = make_windows(&[0, 1, 2, 3], 1);
    assert!(too_tiny.is_err());
}
/// Unreduced cross-entropy on 2x3 logits must equal `logsumexp(row) - row[target]`
/// for each row.
#[test]
fn cross_entropy_none_matches_logsumexp_minus_score() {
    fn log_sum_exp(xs: &[f64]) -> f64 {
        let hi = xs.iter().fold(f64::NEG_INFINITY, |a, &b| f64::max(a, b));
        hi + xs.iter().map(|&x| (x - hi).exp()).sum::<f64>().ln()
    }

    let logits = Array::from_slice::<f32>(&[2.0, -1.0, 0.0, -1.0, 2.0, 0.5], &(2usize, 3)).unwrap();
    let targets = Array::from_slice::<i32>(&[0, 1], &(2usize,)).unwrap();
    let mut per_row = cross_entropy_none(&logits, &targets).unwrap();
    let got = per_row.to_vec::<f32>().unwrap();

    // Row 0 targets class 0 and row 1 targets class 1; both target logits are 2.0.
    let expected = [
        log_sum_exp(&[2.0, -1.0, 0.0]) - 2.0,
        log_sum_exp(&[-1.0, 2.0, 0.5]) - 2.0,
    ];
    for (i, want) in expected.iter().enumerate() {
        assert!((got[i] as f64 - want).abs() < 1e-5);
    }
}
/// Targets with the same rank as the logits (2-D here) must be rejected.
/// Elsewhere in this module, valid targets drop the vocab axis.
#[test]
fn cross_entropy_none_rejects_target_rank_mismatch() {
    let logits = Array::from_slice::<f32>(&[1.0, 2.0, 3.0, 4.0], &(2usize, 2)).unwrap();
    let same_rank_targets = Array::from_slice::<i32>(&[0, 1, 0, 1], &(2usize, 2)).unwrap();
    let outcome = cross_entropy_none(&logits, &same_rank_targets);
    assert!(outcome.is_err());
}
/// Targets shaped `[B, 1]` or `[1, S]` broadcast against `[B, S]`, but
/// `cross_entropy_none` must reject them with `ShapePairMismatch` instead of
/// broadcasting. Exact `[B, S]` targets on all-zero logits give `ln V` per token.
#[test]
fn cross_entropy_none_rejects_broadcastable_target_shape() {
    let (b, s, v) = (2usize, 3usize, 4usize);
    let logits = Array::from_slice::<f32>(&vec![0.0f32; b * s * v], &(b, s, v)).unwrap();

    // Runs the loss, requires an error, and checks the reported expected/actual shapes.
    let assert_rejected = |targets: &Array, actual: &[usize], why: &str| {
        match cross_entropy_none(&logits, targets).expect_err(why) {
            Error::ShapePairMismatch(payload) => {
                assert_eq!(payload.expected(), &[b, s][..]);
                assert_eq!(payload.actual(), actual);
            }
            other => panic!("expected ShapePairMismatch, got: {other:?}"),
        }
    };

    let bad_bs = Array::from_slice::<i32>(&[0, 1], &(b, 1usize)).unwrap();
    assert_rejected(
        &bad_bs,
        &[b, 1],
        "targets [B, 1] should be rejected, not broadcast across S",
    );

    let bad_1s = Array::from_slice::<i32>(&[0, 1, 2], &(1usize, s)).unwrap();
    assert_rejected(
        &bad_1s,
        &[1, s],
        "targets [1, S] should be rejected, not broadcast across B",
    );

    let good = Array::from_slice::<i32>(&[0, 1, 2, 3, 0, 1], &(b, s)).unwrap();
    let mut loss = cross_entropy_none(&logits, &good).unwrap();
    let per_token = loss.to_vec::<f32>().unwrap();
    assert_eq!(per_token.len(), b * s);
    // Uniform logits: every target has probability 1/V.
    let want = (v as f64).ln();
    per_token
        .iter()
        .for_each(|&x| assert!((x as f64 - want).abs() < 1e-5, "loss {x} != log V {want}"));
}
/// Eight 5-token windows evaluated one per batch must match the hand trace
/// and agree with a single batch covering every window.
#[test]
fn many_batches_match_single_batch_after_per_batch_eval() {
    const WINDOW: usize = 5;
    let model = MockModel::new(6);
    let rows: Vec<&[i32]> = vec![
        &[0, 1, 2, 3, 4],
        &[5, 4, 3, 2, 1],
        &[1, 2, 3, 4, 5],
        &[0, 0, 5, 5, 0],
        &[2, 4, 1, 3, 5],
        &[5, 5, 0, 0, 5],
        &[3, 3, 3, 3, 3],
        &[1, 0, 2, 4, 1],
    ];
    let data = matrix(&rows);

    let per_window = perplexity(&model, &data, 1, &no_cache()).unwrap();
    // Each window scores WINDOW - 1 targets.
    assert_eq!(per_window.num_tokens, rows.len() * (WINDOW - 1));
    let want = expected_ppl(&model.canned, &rows);
    assert!(
        (per_window.perplexity as f64 - want).abs() < 1e-4,
        "many-batch ppl {} != hand-traced {want}",
        per_window.perplexity
    );

    let all_at_once = perplexity(&model, &data, 64, &no_cache()).unwrap();
    assert!((per_window.perplexity as f64 - all_at_once.perplexity as f64).abs() < 1e-5);
    assert!((per_window.mean_loss as f64 - all_at_once.mean_loss as f64).abs() < 1e-5);
    assert_eq!(per_window.num_tokens, all_at_once.num_tokens);
}
/// A model that returns `F16` logits must still produce the hand-traced
/// perplexity. The tolerance is loose to absorb half-precision rounding.
#[test]
fn logits_cast_to_f32_before_loss() {
    // Test model: emits the same `canned` logits row at every position, as an F16 array.
    struct F16Model {
        canned: Vec<f32>,
    }
    impl Model for F16Model {
        fn forward(&self, tokens: &Array, _cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
            let shape = tokens.shape();
            let dims = shape.as_slice();
            if dims.len() != 2 {
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "F16Model::forward: tokens must be rank-2 [B, S]",
                    dims.len() as u32,
                    dims.to_vec(),
                )));
            }
            let (batch, seq) = (dims[0], dims[1]);
            let vocab = self.canned.len();
            // One copy of the canned row for each of the batch * seq positions.
            let logits: Vec<f32> = self.canned.repeat(batch * seq);
            Array::from_slice::<f32>(&logits, &(batch, seq, vocab))?.astype(Dtype::F16)
        }
    }

    let model = F16Model {
        canned: vec![0.0, 1.0, 2.0, 3.0],
    };
    let row: &[i32] = &[0, 1, 2, 3];
    let res = perplexity(&model, &matrix(&[row]), 8, &no_cache()).unwrap();
    let want = expected_ppl(&model.canned, &[row]);
    assert!(
        (res.perplexity as f64 - want).abs() < 1e-2,
        "f16-model ppl {} != {want}",
        res.perplexity
    );
}