use proptest::prelude::*;
/// Greedy decoding: returns the index of the largest logit.
///
/// Ties resolve to the *first* (lowest) index holding the maximum, the usual argmax
/// convention. NaN logits are skipped. An empty or all-NaN slice returns `0`.
fn greedy(logits: &[f32]) -> usize {
    // `Iterator::max_by` yields the *last* of several equal maxima, so keep an explicit
    // running best and only replace it on a strictly larger value.
    logits
        .iter()
        .enumerate()
        .filter(|(_, v)| !v.is_nan())
        .fold(None, |best: Option<(usize, f32)>, (i, &v)| match best {
            Some((_, b)) if v <= b => best,
            _ => Some((i, v)),
        })
        .map_or(0, |(i, _)| i)
}
/// Top-k filter: keeps the `k` most probable entries of `probs`, rescaled so they sum
/// to 1, and zeroes every other entry.
///
/// Ranking uses a stable descending sort, so among equal probabilities the lower index
/// is kept first. A `k` larger than `probs.len()` keeps everything. `k == 0`, or a kept
/// mass that is not positive, yields all zeros. The output has the same length as `probs`.
fn top_k_filter(probs: &[f32], k: usize) -> Vec<f32> {
    let mut ranked: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));

    // The k highest-ranked (index, probability) pairs; clamp so k > len keeps all.
    let survivors = &ranked[..k.min(ranked.len())];
    let mass: f32 = survivors.iter().map(|&(_, prob)| prob).sum();

    let mut out = vec![0.0f32; probs.len()];
    if mass > 0.0 {
        survivors
            .iter()
            .for_each(|&(idx, prob)| out[idx] = prob / mass);
    }
    out
}
/// Nucleus (top-p) filter: keeps the shortest highest-probability prefix whose running
/// total reaches `p`, rescales it to sum to 1, and zeroes everything else.
///
/// Tokens are ranked with a stable descending sort, so ties keep the lower index first.
/// The token that makes the running total reach `p` is included. If the total never
/// reaches `p`, every token is kept. If the kept mass is not positive, the result is
/// all zeros. The output has the same length as `probs`.
fn top_p_filter(probs: &[f32], p: f32) -> Vec<f32> {
    let mut ranked: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    ranked.sort_by(|lhs, rhs| rhs.1.partial_cmp(&lhs.1).unwrap_or(std::cmp::Ordering::Equal));

    // Prefix length to keep: up to and including the first token at which the running
    // total reaches `p`, or the whole ranking if it never does.
    let mut running = 0.0f32;
    let cutoff = ranked
        .iter()
        .position(|&(_, prob)| {
            running += prob;
            running >= p
        })
        .map_or(ranked.len(), |last| last + 1);
    let nucleus = &ranked[..cutoff];

    let mass: f32 = nucleus.iter().map(|&(_, prob)| prob).sum();
    let mut out = vec![0.0f32; probs.len()];
    if mass > 0.0 {
        for &(idx, prob) in nucleus {
            out[idx] = prob / mass;
        }
    }
    out
}
/// Converts logits into a probability distribution via `softmax(logits / temperature)`.
///
/// The largest scaled logit is subtracted before exponentiating so that large logits
/// do not overflow `exp`. Returns an empty vector for empty `logits`.
///
/// A temperature of exactly zero (`0.0` or `-0.0`) is treated as the `T -> 0+` limit:
/// a one-hot vector on the first largest logit. NaN logits are never selected, and
/// all-NaN input yields all zeros. Without this guard, the division by zero produced
/// NaN probabilities. Other temperatures, including negative ones, are not
/// special-cased.
fn softmax_with_temp(logits: &[f32], temperature: f32) -> Vec<f32> {
    if temperature == 0.0 {
        // `f32::max` ignores NaN operands, so `max` is the largest non-NaN logit
        // (or -inf when there is none).
        let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut one_hot = vec![0.0f32; logits.len()];
        if let Some(i) = logits.iter().position(|&l| l == max) {
            one_hot[i] = 1.0;
        }
        return one_hot;
    }
    let scaled: Vec<f32> = logits.iter().map(|&l| l / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}
/// Strategy producing logit vectors of 4 to 31 elements, each drawn from `[-10, 10)`.
fn logit_strategy() -> impl Strategy<Value = Vec<f32>> {
    let element = -10.0f32..10.0;
    let length = 4usize..32;
    proptest::collection::vec(element, length)
}
proptest! {
    // Greedy decoding must pick the first index holding the maximum logit.
    #[test]
    fn prop_greedy_argmax(
        logits in logit_strategy()
    ) {
        let result = greedy(&logits);
        let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let expected = logits.iter().position(|&v| v == max_val).unwrap_or(0);
        prop_assert_eq!(
            result, expected,
            "greedy({:?}) = {}, argmax = {}", &logits[..4.min(logits.len())], result, expected
        );
    }

    // Top-k keeps at most `k` tokens and returns a renormalized distribution.
    #[test]
    fn prop_top_k_cardinality(
        logits in logit_strategy(),
        k in 1usize..16
    ) {
        let probs = softmax_with_temp(&logits, 1.0);
        let filtered = top_k_filter(&probs, k);
        let nonzero_count = filtered.iter().filter(|&&p| p > 0.0).count();
        prop_assert!(
            nonzero_count <= k,
            "top_k returned {} nonzero, expected <= {}", nonzero_count, k
        );
        // Softmax probabilities are strictly positive here, so the kept mass is
        // positive and the output must be renormalized to ~1.
        let total: f32 = filtered.iter().sum();
        prop_assert!(
            (total - 1.0).abs() < 1e-4,
            "top_k sum={}, expected ~1.0", total
        );
    }

    // Top-p output is a renormalized distribution that covers at least `p` of the mass.
    #[test]
    fn prop_top_p_cumulative(
        logits in logit_strategy(),
        p_pct in 50u32..100
    ) {
        let p = p_pct as f32 / 100.0;
        let probs = softmax_with_temp(&logits, 1.0);
        let filtered = top_p_filter(&probs, p);
        let retained_sum: f32 = filtered.iter().sum();
        prop_assert!(
            retained_sum > 0.99 || retained_sum < 0.01,
            "top_p sum={}, expected ~1.0 or ~0.0", retained_sum
        );
        // The kept tokens must cover at least `p` of the original probability mass.
        // The small slack absorbs f32 rounding from summing in index order rather
        // than in ranked order.
        let kept_mass: f32 = probs
            .iter()
            .zip(&filtered)
            .filter(|&(_, &f)| f > 0.0)
            .map(|(&orig, _)| orig)
            .sum();
        prop_assert!(
            kept_mass >= p - 1e-5,
            "top_p kept mass {} < p={}", kept_mass, p
        );
    }

    // T = 1 must reproduce a plain softmax. The reference is computed independently
    // here; comparing `softmax_with_temp(.., 1.0)` with itself would test nothing.
    #[test]
    fn prop_temperature_identity(
        logits in logit_strategy()
    ) {
        let t1_probs = softmax_with_temp(&logits, 1.0);
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
        let total: f32 = exps.iter().sum();
        let raw_probs: Vec<f32> = exps.iter().map(|&e| e / total).collect();
        prop_assert_eq!(t1_probs.len(), raw_probs.len());
        for (i, (&p1, &p2)) in t1_probs.iter().zip(raw_probs.iter()).enumerate() {
            let diff = (p1 - p2).abs();
            prop_assert!(
                diff < 1e-6,
                "T=1 identity violated at [{}]: {} != {}", i, p1, p2
            );
        }
    }

    #[test]
    #[ignore = "SIMD equivalence — trueno domain"]
    fn prop_simd_equivalence(
        _x in proptest::collection::vec(-10.0f32..10.0, 1..32usize)
    ) {
    }
}