/// Verdict produced by [`verdict_from_greedy_argmax`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Sa001Verdict {
    /// The observed index matched the arg-max of the logits.
    Pass,
    /// The observed index differed, or the logits had no usable arg-max.
    Fail,
}
/// Returns the index of the largest logit.
///
/// When several entries share the maximum value, the lowest index wins
/// because a later entry only replaces the current best if it is strictly
/// greater.
///
/// Returns `None` when `logits` is empty or contains any non-finite value
/// (NaN or ±infinity).
#[must_use]
pub fn argmax(logits: &[f32]) -> Option<usize> {
    if logits.iter().any(|value| !value.is_finite()) {
        return None;
    }
    // `split_first` yields `None` for an empty slice, which `?` propagates.
    let (&head, tail) = logits.split_first()?;
    let (winner, _) = (1_usize..)
        .zip(tail)
        .fold((0_usize, head), |(best_idx, best_val), (idx, &candidate)| {
            if candidate > best_val {
                (idx, candidate)
            } else {
                (best_idx, best_val)
            }
        });
    Some(winner)
}
/// Checks that `observed` is the index [`argmax`] selects for `logits`.
///
/// Returns [`Sa001Verdict::Pass`] only when `argmax(logits)` yields exactly
/// `observed`. Any other outcome, including `argmax` returning `None`, is
/// [`Sa001Verdict::Fail`].
#[must_use]
pub fn verdict_from_greedy_argmax(logits: &[f32], observed: usize) -> Sa001Verdict {
    if argmax(logits) == Some(observed) {
        Sa001Verdict::Pass
    } else {
        Sa001Verdict::Fail
    }
}
/// Verdict produced by [`verdict_from_top_k_cardinality`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Sa002Verdict {
    /// The filtered distribution satisfied the top-k check.
    Pass,
    /// The inputs were rejected or the filtered distribution violated the check.
    Fail,
}
/// Checks that a top-k filter kept at most `k` entries, all drawn from the
/// `k` most probable positions of `probs`.
///
/// An entry of `filtered_probs` counts as "kept" when it is strictly greater
/// than zero. The verdict is [`Sa002Verdict::Pass`] when every kept position
/// can be part of *some* valid top-`k` selection of `probs`:
///
/// * no more than `k` positions are kept;
/// * no kept position has a source probability below the `k`-th largest
///   value of `probs` (the cutoff);
/// * among positions whose source probability equals the cutoff, no more are
///   kept than the slots left over after all strictly larger probabilities
///   are accounted for. Which of several tied positions a filter keeps is
///   therefore irrelevant.
///
/// The verdict is [`Sa002Verdict::Fail`] when:
///
/// * either slice is empty or the lengths differ;
/// * `k` is zero or exceeds `probs.len()`;
/// * any entry of `probs` or `filtered_probs` is non-finite or negative;
/// * any of the three conditions above is violated.
#[must_use]
pub fn verdict_from_top_k_cardinality(
    probs: &[f32],
    k: usize,
    filtered_probs: &[f32],
) -> Sa002Verdict {
    fn is_valid_prob(p: f32) -> bool {
        p.is_finite() && p >= 0.0
    }

    if probs.is_empty() || probs.len() != filtered_probs.len() {
        return Sa002Verdict::Fail;
    }
    if k == 0 || k > probs.len() {
        return Sa002Verdict::Fail;
    }
    // Reject NaN/inf/negative values up front. NaN would otherwise break the
    // total order the sort comparator below relies on, and a NaN in the
    // filtered output would silently be treated as "dropped".
    if !probs.iter().all(|&p| is_valid_prob(p))
        || !filtered_probs.iter().all(|&p| is_valid_prob(p))
    {
        return Sa002Verdict::Fail;
    }

    let kept_count = filtered_probs.iter().filter(|&&p| p > 0.0).count();
    if kept_count > k {
        return Sa002Verdict::Fail;
    }

    // Determine the k-th largest source probability. All values are finite
    // here, so `partial_cmp` never yields `None`; the fallback is inert.
    let mut ranked = probs.to_vec();
    ranked.sort_unstable_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
    // In bounds: 1 <= k <= probs.len() was established above.
    let cutoff = ranked[k - 1];

    // Positions strictly above the cutoff belong to every valid top-k set.
    // At most k - 1 values can exceed the k-th largest, so this cannot
    // underflow and leaves at least one slot for values tied at the cutoff.
    let strictly_above = probs.iter().filter(|&&p| p > cutoff).count();
    let tie_slots = k - strictly_above;

    let mut ties_kept = 0_usize;
    for (&source, &kept) in probs.iter().zip(filtered_probs) {
        if kept > 0.0 {
            if source < cutoff {
                return Sa002Verdict::Fail;
            }
            if source == cutoff {
                ties_kept += 1;
                if ties_kept > tie_slots {
                    return Sa002Verdict::Fail;
                }
            }
        }
    }
    Sa002Verdict::Pass
}
/// Verdict produced by [`verdict_from_top_p_cumulative`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Sa003Verdict {
    /// The retained probability mass reached the threshold.
    Pass,
    /// The inputs were rejected or the retained mass fell short.
    Fail,
}
/// Checks that the probability mass left in `filtered_probs` reaches
/// `threshold`.
///
/// Returns [`Sa003Verdict::Pass`] when the sum of `filtered_probs`, plus a
/// fixed slack of `1.0e-5`, is at least `threshold`.
///
/// Returns [`Sa003Verdict::Fail`] when:
///
/// * `filtered_probs` is empty;
/// * `threshold` is non-finite or outside `(0, 1]`;
/// * any entry is non-finite or negative;
/// * the sum overflows to a non-finite value;
/// * the sum plus slack is below `threshold`.
#[must_use]
pub fn verdict_from_top_p_cumulative(filtered_probs: &[f32], threshold: f32) -> Sa003Verdict {
    // Absorbs rounding error accumulated while summing in f32.
    const SLACK: f32 = 1.0e-5;

    let threshold_ok = threshold.is_finite() && threshold > 0.0 && threshold <= 1.0;
    let entries_ok = !filtered_probs.is_empty()
        && filtered_probs.iter().all(|&p| p.is_finite() && p >= 0.0);
    if !(threshold_ok && entries_ok) {
        return Sa003Verdict::Fail;
    }

    let retained_mass: f32 = filtered_probs.iter().sum();
    if retained_mass.is_finite() && retained_mass + SLACK >= threshold {
        Sa003Verdict::Pass
    } else {
        Sa003Verdict::Fail
    }
}
/// Largest absolute per-element difference that
/// [`verdict_from_temperature_identity`] accepts between its two softmax
/// results.
pub const AC_SA_004_TOLERANCE: f32 = 1e-6;
/// Verdict produced by [`verdict_from_temperature_identity`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Sa004Verdict {
    /// Both softmax results agreed within the tolerance.
    Pass,
    /// The inputs were rejected or the results diverged.
    Fail,
}
/// Computes a softmax over `logits`, subtracting the maximum before
/// exponentiating.
///
/// Returns an empty vector when no distribution can be produced:
///
/// * `logits` is empty;
/// * the maximum logit is not finite (for example a `+inf` entry, or every
///   entry being `-inf`);
/// * the sum of exponentials is zero or non-finite (for example when a NaN
///   logit is present).
#[must_use]
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    // An empty slice leaves the fold at NEG_INFINITY, so the guard below
    // also covers the empty-input case.
    let peak = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if !peak.is_finite() {
        return Vec::new();
    }

    let mut weights: Vec<f32> = logits.iter().map(|&logit| (logit - peak).exp()).collect();
    let total: f32 = weights.iter().sum();
    if total == 0.0 || !total.is_finite() {
        return Vec::new();
    }

    for weight in &mut weights {
        *weight /= total;
    }
    weights
}
/// Checks that dividing `logits` by a temperature of `1.0` leaves the
/// [`softmax`] output unchanged within [`AC_SA_004_TOLERANCE`].
///
/// Returns [`Sa004Verdict::Fail`] when:
///
/// * `logits` is empty or contains a non-finite value;
/// * either softmax result is empty, or their lengths differ;
/// * any pair of corresponding probabilities differs by more than the
///   tolerance.
///
/// Otherwise returns [`Sa004Verdict::Pass`].
#[must_use]
pub fn verdict_from_temperature_identity(logits: &[f32]) -> Sa004Verdict {
    if logits.is_empty() || logits.iter().any(|logit| !logit.is_finite()) {
        return Sa004Verdict::Fail;
    }

    let baseline = softmax(logits);
    let rescaled: Vec<f32> = logits.iter().map(|&logit| logit / 1.0).collect();
    let at_unit_temperature = softmax(&rescaled);

    let comparable = !baseline.is_empty()
        && !at_unit_temperature.is_empty()
        && baseline.len() == at_unit_temperature.len();
    if !comparable {
        return Sa004Verdict::Fail;
    }

    let drifted = baseline
        .iter()
        .zip(&at_unit_temperature)
        .any(|(p, q)| (p - q).abs() > AC_SA_004_TOLERANCE);
    if drifted {
        Sa004Verdict::Fail
    } else {
        Sa004Verdict::Pass
    }
}
/// Verdict produced by [`verdict_from_simd_sampling_parity`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Sa005Verdict {
    /// Indices matched and the distributions were bit-identical.
    Pass,
    /// The inputs were rejected or a mismatch was found.
    Fail,
}
/// Checks that two sampling results agree exactly.
///
/// Returns [`Sa005Verdict::Pass`] only when all of the following hold:
///
/// * `scalar_idx` equals `simd_idx`;
/// * both distributions are non-empty and have the same length;
/// * every value in both distributions is finite;
/// * each pair of corresponding values has the same bit pattern.
///
/// The comparison uses bit patterns, so `0.0` and `-0.0` count as a
/// mismatch. Anything else yields [`Sa005Verdict::Fail`].
#[must_use]
pub fn verdict_from_simd_sampling_parity(
    scalar_idx: usize,
    simd_idx: usize,
    scalar_dist: &[f32],
    simd_dist: &[f32],
) -> Sa005Verdict {
    let same_index = scalar_idx == simd_idx;
    // Equal lengths plus a non-empty scalar side imply a non-empty SIMD side.
    let same_shape = !scalar_dist.is_empty() && scalar_dist.len() == simd_dist.len();
    if !(same_index && same_shape) {
        return Sa005Verdict::Fail;
    }

    let bit_identical = scalar_dist
        .iter()
        .zip(simd_dist)
        .all(|(lhs, rhs)| lhs.is_finite() && rhs.is_finite() && lhs.to_bits() == rhs.to_bits());
    if bit_identical {
        Sa005Verdict::Pass
    } else {
        Sa005Verdict::Fail
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- SA-001: greedy arg-max -------------------------------------------

    #[test]
    fn sa001_pass_canonical() {
        let verdict = verdict_from_greedy_argmax(&[0.1, 0.5, 0.9, 0.3], 2);
        assert_eq!(verdict, Sa001Verdict::Pass);
    }

    #[test]
    fn sa001_pass_max_at_zero() {
        let verdict = verdict_from_greedy_argmax(&[5.0, 0.5, 0.9], 0);
        assert_eq!(verdict, Sa001Verdict::Pass);
    }

    #[test]
    fn sa001_pass_max_at_end() {
        let verdict = verdict_from_greedy_argmax(&[0.1, 0.5, 0.9, 5.0], 3);
        assert_eq!(verdict, Sa001Verdict::Pass);
    }

    #[test]
    fn sa001_fail_wrong_index() {
        let verdict = verdict_from_greedy_argmax(&[0.1, 0.5, 0.9, 0.3], 1);
        assert_eq!(verdict, Sa001Verdict::Fail);
    }

    #[test]
    fn sa001_fail_empty() {
        let verdict = verdict_from_greedy_argmax(&[], 0);
        assert_eq!(verdict, Sa001Verdict::Fail);
    }

    #[test]
    fn sa001_fail_nan() {
        let verdict = verdict_from_greedy_argmax(&[0.1, f32::NAN], 0);
        assert_eq!(verdict, Sa001Verdict::Fail);
    }

    // --- SA-002: top-k cardinality ----------------------------------------

    #[test]
    fn sa002_pass_canonical() {
        let probs = [0.1_f32, 0.05, 0.5, 0.35];
        let filtered = [0.0_f32, 0.0, 0.5882, 0.4118];
        let verdict = verdict_from_top_k_cardinality(&probs, 2, &filtered);
        assert_eq!(verdict, Sa002Verdict::Pass);
    }

    #[test]
    fn sa002_pass_k_equals_n() {
        let probs = [0.25_f32; 4];
        let filtered = probs;
        let verdict = verdict_from_top_k_cardinality(&probs, 4, &filtered);
        assert_eq!(verdict, Sa002Verdict::Pass);
    }

    #[test]
    fn sa002_fail_too_many_kept() {
        let probs = [0.1_f32, 0.2, 0.3, 0.4];
        let filtered = [0.0_f32, 0.2, 0.3, 0.4];
        let verdict = verdict_from_top_k_cardinality(&probs, 2, &filtered);
        assert_eq!(verdict, Sa002Verdict::Fail);
    }

    #[test]
    fn sa002_fail_kept_non_top_k() {
        let probs = [0.1_f32, 0.2, 0.3, 0.4];
        let filtered = [0.5_f32, 0.0, 0.0, 0.5];
        let verdict = verdict_from_top_k_cardinality(&probs, 2, &filtered);
        assert_eq!(verdict, Sa002Verdict::Fail);
    }

    #[test]
    fn sa002_fail_k_zero() {
        let probs = [0.5_f32, 0.5];
        let filtered = [0.0_f32, 0.0];
        let verdict = verdict_from_top_k_cardinality(&probs, 0, &filtered);
        assert_eq!(verdict, Sa002Verdict::Fail);
    }

    #[test]
    fn sa002_fail_k_above_v() {
        let probs = [0.5_f32, 0.5];
        let filtered = probs;
        let verdict = verdict_from_top_k_cardinality(&probs, 5, &filtered);
        assert_eq!(verdict, Sa002Verdict::Fail);
    }

    // --- SA-003: top-p cumulative mass ------------------------------------

    #[test]
    fn sa003_pass_canonical() {
        let verdict = verdict_from_top_p_cumulative(&[0.5, 0.35, 0.0, 0.0], 0.8);
        assert_eq!(verdict, Sa003Verdict::Pass);
    }

    #[test]
    fn sa003_pass_at_threshold() {
        let verdict = verdict_from_top_p_cumulative(&[0.5, 0.5], 1.0);
        assert_eq!(verdict, Sa003Verdict::Pass);
    }

    #[test]
    fn sa003_fail_below_threshold() {
        let verdict = verdict_from_top_p_cumulative(&[0.3, 0.0], 0.8);
        assert_eq!(verdict, Sa003Verdict::Fail);
    }

    #[test]
    fn sa003_fail_threshold_zero() {
        let verdict = verdict_from_top_p_cumulative(&[1.0], 0.0);
        assert_eq!(verdict, Sa003Verdict::Fail);
    }

    #[test]
    fn sa003_fail_threshold_above_one() {
        let verdict = verdict_from_top_p_cumulative(&[1.0], 1.5);
        assert_eq!(verdict, Sa003Verdict::Fail);
    }

    #[test]
    fn sa003_fail_negative_prob() {
        let verdict = verdict_from_top_p_cumulative(&[-0.1, 0.5], 0.4);
        assert_eq!(verdict, Sa003Verdict::Fail);
    }

    // --- SA-004: temperature identity -------------------------------------

    #[test]
    fn sa004_pass_canonical() {
        let verdict = verdict_from_temperature_identity(&[0.5, 1.0, 1.5, 2.0]);
        assert_eq!(verdict, Sa004Verdict::Pass);
    }

    #[test]
    fn sa004_pass_uniform_logits() {
        let verdict = verdict_from_temperature_identity(&[1.0, 1.0, 1.0]);
        assert_eq!(verdict, Sa004Verdict::Pass);
    }

    #[test]
    fn sa004_pass_negative_logits() {
        let verdict = verdict_from_temperature_identity(&[-3.0, 0.0, 5.0]);
        assert_eq!(verdict, Sa004Verdict::Pass);
    }

    #[test]
    fn sa004_fail_empty() {
        let verdict = verdict_from_temperature_identity(&[]);
        assert_eq!(verdict, Sa004Verdict::Fail);
    }

    #[test]
    fn sa004_fail_nan() {
        let verdict = verdict_from_temperature_identity(&[0.5, f32::NAN]);
        assert_eq!(verdict, Sa004Verdict::Fail);
    }

    // --- SA-005: scalar / SIMD sampling parity ----------------------------

    #[test]
    fn sa005_pass_identical() {
        let dist = [0.25_f32, 0.5, 0.25];
        let verdict = verdict_from_simd_sampling_parity(1, 1, &dist, &dist);
        assert_eq!(verdict, Sa005Verdict::Pass);
    }

    #[test]
    fn sa005_fail_index_drift() {
        let dist = [0.25_f32, 0.5, 0.25];
        let verdict = verdict_from_simd_sampling_parity(1, 2, &dist, &dist);
        assert_eq!(verdict, Sa005Verdict::Fail);
    }

    #[test]
    fn sa005_fail_dist_byte_drift() {
        // Bump the middle value by one bit so the two sides differ minimally.
        let scalar = [0.25_f32, 0.5, 0.25];
        let simd = [0.25_f32, f32::from_bits(0.5_f32.to_bits() + 1), 0.25];
        let verdict = verdict_from_simd_sampling_parity(1, 1, &scalar, &simd);
        assert_eq!(verdict, Sa005Verdict::Fail);
    }

    #[test]
    fn sa005_fail_length() {
        let scalar = [0.25_f32];
        let simd = [0.25_f32, 0.5];
        let verdict = verdict_from_simd_sampling_parity(0, 0, &scalar, &simd);
        assert_eq!(verdict, Sa005Verdict::Fail);
    }

    // --- Helpers and constants --------------------------------------------

    #[test]
    fn argmax_canonical() {
        assert_eq!(argmax(&[0.1, 0.5, 0.9, 0.3]), Some(2));
        assert_eq!(argmax(&[5.0, 1.0, 0.5]), Some(0));
        assert_eq!(argmax(&[]), None);
    }

    #[test]
    fn provenance_constants() {
        assert!((AC_SA_004_TOLERANCE - 1e-6).abs() < 1e-12);
    }
}