use crate::generate::*;
use crate::tensor::Tensor;
/// SA-001: greedy sampling must return the index of the largest logit.
///
/// The `[7.0, 7.0, 7.0]` case pins tie-breaking to the lowest index, and the
/// final case uses a 32000-entry vector that is zero everywhere except a single
/// spike at index 12345.
#[test]
fn falsify_sa_001_greedy_is_argmax() {
    // Vocabulary-sized input with exactly one dominant logit.
    let spiked_vocab = {
        let mut spike = vec![0.0_f32; 32000];
        spike[12345] = 100.0;
        spike
    };

    let test_cases: Vec<(Vec<f32>, usize)> = vec![
        (vec![1.0, 2.0, 3.0, 4.0, 5.0], 4),
        (vec![5.0, 4.0, 3.0, 2.0, 1.0], 0),
        (vec![1.0, 2.0, 100.0, 3.0, 4.0], 2),
        (vec![7.0, 7.0, 7.0], 0),
        (vec![42.0], 0),
        (vec![-5.0, -1.0, -3.0, -0.5, -10.0], 3),
        (spiked_vocab, 12345),
    ];

    // Consume the cases so each logits vector can be moved into its tensor.
    for (i, (logits_data, expected_idx)) in test_cases.into_iter().enumerate() {
        let shape = vec![logits_data.len()];
        let logits = Tensor::from_vec(shape, logits_data).expect("test tensor");
        let result = sample_greedy(&logits).expect("greedy should succeed");
        assert_eq!(
            result, expected_idx,
            "FALSIFIED SA-001 case {i}: greedy returned {result}, expected {expected_idx}"
        );
    }
}
/// SA-002: top-k sampling must only ever return one of the `k` highest-logit indices.
///
/// With logits `[5.0, 1.0, 0.1, 10.0, 3.0]` and `k = 3`, the admissible indices are
/// 3 (10.0), 0 (5.0) and 4 (3.0). The RNG value is swept over `0.00, 0.01, ..., 0.99`.
#[test]
fn falsify_sa_002_top_k_cardinality() {
    let k = 3;
    let logits =
        Tensor::from_vec(vec![5], vec![5.0, 1.0, 0.1, 10.0, 3.0]).expect("test tensor");
    let top_k_set = std::collections::HashSet::<usize>::from([3, 0, 4]);

    (0..100)
        .map(|rng_step| rng_step as f32 / 100.0)
        .for_each(|rng_value| {
            let result = sample_top_k(&logits, k, rng_value).expect("top_k should succeed");
            assert!(
                top_k_set.contains(&result),
                "FALSIFIED SA-002: top_k(k={k}, rng={rng_value}) returned {result}, \
                 not in top-{k} set {:?}",
                top_k_set
            );
        });
}
/// SA-003: nucleus (top-p) sampling must respect the cumulative-probability cutoff.
///
/// Logits `[10.0, 1.0, 0.0, -1.0, -10.0]` put roughly 99.98% of the softmax mass on
/// index 0, so with `p = 0.01` the nucleus is `{0}` and every RNG draw must return 0.
/// With `p = 1.0` every token is admissible: each draw must succeed and return an
/// index inside the vocabulary.
#[test]
fn falsify_sa_003_top_p_cumulative() {
    let logits_data = vec![10.0, 1.0, 0.0, -1.0, -10.0];
    let vocab_size = logits_data.len();
    let logits = Tensor::from_vec(vec![vocab_size], logits_data).expect("test tensor");

    for rng_step in 0..100 {
        let rng_value = rng_step as f32 / 100.0;
        let result = sample_top_p(&logits, 0.01, rng_value).expect("top_p should succeed");
        assert_eq!(
            result, 0,
            "FALSIFIED SA-003: top_p(p=0.01, rng={rng_value}) returned {result}, \
             expected 0 (only token above nucleus threshold)"
        );
    }

    for rng_step in 0..20 {
        let rng_value = rng_step as f32 / 20.0;
        // Report the actual error rather than a bare `is_ok()` failure.
        let result = sample_top_p(&logits, 1.0, rng_value).unwrap_or_else(|e| {
            panic!(
                "FALSIFIED SA-003: top_p(p=1.0) should always succeed \
                 (rng={rng_value}): {e:?}"
            )
        });
        // A successful call must still yield a valid token position.
        assert!(
            result < vocab_size,
            "FALSIFIED SA-003: top_p(p=1.0, rng={rng_value}) returned out-of-range \
             index {result} (vocab size {vocab_size})"
        );
    }
}
/// SA-004: a temperature of exactly 1.0 must leave every logit unchanged.
///
/// The output must have the same number of elements as the input. Without that
/// check, the element-wise `zip` comparison below would silently stop at the
/// shorter side, and a truncated result would pass. For the `1e6` entry the
/// `1e-6` absolute tolerance is below one f32 ULP, so that value must come back
/// bit-identical.
#[test]
fn falsify_sa_004_temperature_identity() {
    let test_cases: Vec<Vec<f32>> = vec![
        vec![1.0, 2.0, 3.0],
        vec![-100.0, 0.0, 100.0],
        vec![0.0; 10],
        vec![42.0],
        vec![1e-6, 1e6],
    ];
    for (i, logits_data) in test_cases.iter().enumerate() {
        let logits =
            Tensor::from_vec(vec![logits_data.len()], logits_data.clone()).expect("test tensor");
        let result = apply_temperature(&logits, 1.0).expect("temp=1.0 should succeed");
        assert_eq!(
            result.data().len(),
            logits.data().len(),
            "FALSIFIED SA-004 case {i}: temp=1.0 changed the number of logits"
        );
        for (j, (&original, &scaled)) in logits.data().iter().zip(result.data().iter()).enumerate()
        {
            assert!(
                (original - scaled).abs() < 1e-6,
                "FALSIFIED SA-004 case {i}[{j}]: temp=1.0 changed {original} to {scaled}"
            );
        }
    }
}
/// SA-004b: applying temperature `t` must divide every logit by `t`.
///
/// The comparison uses a relative tolerance (`1e-5` scaled by `max(|expected|, 1)`).
/// An absolute `1e-5` is only about one f32 ULP at the largest expected value here
/// (80.0, for temp=0.1), which would reject a correct implementation that
/// multiplies by `1/t` and lands a couple of ULPs away. The output length is also
/// checked, because `zip` would otherwise hide a truncated result.
#[test]
fn falsify_sa_004b_temperature_scaling() {
    let logits_data = vec![2.0, 4.0, 6.0, 8.0];
    let logits = Tensor::from_vec(vec![4], logits_data).expect("test tensor");
    for &temp in &[0.5_f32, 2.0, 0.1, 10.0] {
        let result = apply_temperature(&logits, temp).expect("should succeed");
        assert_eq!(
            result.data().len(),
            logits.data().len(),
            "FALSIFIED SA-004b: temp={temp} changed the number of logits"
        );
        for (j, (&original, &scaled)) in logits.data().iter().zip(result.data().iter()).enumerate()
        {
            let expected = original / temp;
            let diff = (scaled - expected).abs();
            // Scale-aware tolerance; identical to the absolute 1e-5 for |expected| <= 1.
            let tolerance = 1e-5 * expected.abs().max(1.0);
            assert!(
                diff < tolerance,
                "FALSIFIED SA-004b: temp={temp}, logits[{j}]={original} → {scaled}, expected {expected}"
            );
        }
    }
}