//! aprender-core 0.29.1
//!
//! Next-generation machine learning library in pure Rust.
// CONTRACT: sampling-algorithms-v1.yaml
// HASH: sha256:b2c3d4e5f6789012
// Generated by: pv probar --binding
// DO NOT EDIT — regenerate with `pv probar --binding`

use proptest::prelude::*;

/// Greedy sampling: argmax of logits.
/// Greedy sampling: return the index of the largest logit.
///
/// Ties resolve to the *last* maximal index (matching `Iterator::max_by`);
/// an empty slice yields index 0.
fn greedy(logits: &[f32]) -> usize {
    let mut best_idx = 0usize;
    let mut best_val = f32::NEG_INFINITY;
    for (idx, &value) in logits.iter().enumerate() {
        // Replace unless the incumbent is strictly greater: this reproduces
        // `max_by`'s last-wins behavior on ties (NaN compares as Equal here,
        // matching the original `unwrap_or(Equal)` fallback).
        if best_val.partial_cmp(&value).unwrap_or(std::cmp::Ordering::Equal)
            != std::cmp::Ordering::Greater
        {
            best_idx = idx;
            best_val = value;
        }
    }
    best_idx
}

/// Top-K filtering: keep only K highest probability tokens, zero out rest.
/// Top-K filtering: keep only the K highest-probability tokens, zero out the
/// rest, and renormalize the survivors so they sum to 1.
///
/// `k == 0` (or zero total kept mass) yields an all-zero vector; `k` larger
/// than `probs.len()` keeps everything.
fn top_k_filter(probs: &[f32], k: usize) -> Vec<f32> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // `total_cmp` is a genuine total order (NaN included). The previous
    // `partial_cmp(..).unwrap_or(Equal)` comparator is not, and `sort_by`
    // may panic on such comparators when NaN is present.
    indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    let mut filtered = vec![0.0f32; probs.len()];
    let sum: f32 = indexed.iter().take(k).map(|(_, p)| p).sum();
    if sum > 0.0 {
        for &(i, p) in indexed.iter().take(k) {
            filtered[i] = p / sum; // renormalize so kept mass sums to 1
        }
    }
    filtered
}

/// Top-P (nucleus) filtering: keep minimal set with cumulative prob >= p.
/// Top-P (nucleus) filtering: keep the smallest high-probability prefix whose
/// cumulative mass reaches `p`, zero out the rest, and renormalize the kept
/// probabilities so they sum to 1.
fn top_p_filter(probs: &[f32], p: f32) -> Vec<f32> {
    let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // `total_cmp` gives a true total order; the previous
    // `partial_cmp(..).unwrap_or(Equal)` comparator violates `sort_by`'s
    // total-order contract when NaN is present and can panic.
    indexed.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
    let mut cumulative = 0.0f32;
    let mut filtered = vec![0.0f32; probs.len()];
    let mut kept = Vec::new();
    for &(i, prob) in &indexed {
        kept.push((i, prob));
        cumulative += prob;
        if cumulative >= p {
            break; // minimal prefix reached
        }
    }
    let sum: f32 = kept.iter().map(|(_, prob)| prob).sum();
    if sum > 0.0 {
        for (i, prob) in kept {
            filtered[i] = prob / sum; // renormalize kept mass to 1
        }
    }
    filtered
}

/// Softmax with temperature.
/// Numerically stable softmax over `logits / temperature`.
///
/// Subtracts the maximum scaled logit before exponentiating so large inputs
/// cannot overflow. `temperature` must be nonzero (callers here pass 1.0).
fn softmax_with_temp(logits: &[f32], temperature: f32) -> Vec<f32> {
    let mut scaled = Vec::with_capacity(logits.len());
    for &logit in logits {
        scaled.push(logit / temperature);
    }
    // Running maximum for the stability shift (NEG_INFINITY for empty input).
    let mut max = f32::NEG_INFINITY;
    for &s in &scaled {
        max = max.max(s);
    }
    // Exponentiate and accumulate the normalizer in one pass.
    let mut exps = Vec::with_capacity(scaled.len());
    let mut sum = 0.0f32;
    for &s in &scaled {
        let e = (s - max).exp();
        sum += e;
        exps.push(e);
    }
    for e in &mut exps {
        *e /= sum;
    }
    exps
}

/// Proptest strategy: logit vectors of length 4..32 with entries in [-10, 10).
fn logit_strategy() -> impl Strategy<Value = Vec<f32>> {
    let element = -10.0f32..10.0;
    proptest::collection::vec(element, 4..32usize)
}

proptest! {
    /// Obligation: Greedy = argmax (equivalence)
    /// Formal: greedy(logits) == argmax(logits)
    #[test]
    fn prop_greedy_argmax(
        logits in logit_strategy()
    ) {
        let result = greedy(&logits);

        // Manual argmax. `Iterator::max_by` (used by `greedy`) returns the
        // *last* maximal element, so ties must be resolved from the right:
        // `position` (first match) would spuriously fail on duplicate maxima.
        let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let expected = logits.iter().rposition(|&v| v == max_val).unwrap_or(0);

        prop_assert_eq!(
            result, expected,
            "greedy({:?}) = {}, argmax = {}", &logits[..4.min(logits.len())], result, expected
        );
    }

    /// Obligation: Top-K cardinality (bound)
    /// Formal: count(nonzero(top_k(p, K))) <= K
    #[test]
    fn prop_top_k_cardinality(
        logits in logit_strategy(),
        k in 1usize..16
    ) {
        let probs = softmax_with_temp(&logits, 1.0);
        let filtered = top_k_filter(&probs, k);
        let nonzero_count = filtered.iter().filter(|&&p| p > 0.0).count();
        prop_assert!(
            nonzero_count <= k,
            "top_k returned {} nonzero, expected <= {}", nonzero_count, k
        );
    }

    /// Obligation: Top-P cumulative (bound)
    /// Formal: sum(top_p(p, threshold)) >= threshold
    #[test]
    fn prop_top_p_cumulative(
        logits in logit_strategy(),
        p_pct in 50u32..100
    ) {
        let p = p_pct as f32 / 100.0;
        let probs = softmax_with_temp(&logits, 1.0);
        let filtered = top_p_filter(&probs, p);

        // The formal obligation: the *original* mass of the retained tokens
        // must reach the threshold (checked before renormalization).
        let pre_renorm: f32 = probs
            .iter()
            .zip(filtered.iter())
            .filter(|&(_, &f)| f > 0.0)
            .map(|(&orig, _)| orig)
            .sum();
        prop_assert!(
            pre_renorm >= p - 1e-4,
            "top_p retained mass {} < threshold {}", pre_renorm, p
        );

        // The filtered distribution is renormalized to sum to ~1.0. Softmax
        // output always sums to 1, so the kept set is never empty; the old
        // `> 0.99 || < 0.01` assertion was near-vacuous.
        let retained_sum: f32 = filtered.iter().sum();
        prop_assert!(
            (retained_sum - 1.0).abs() < 1e-4,
            "top_p sum={}, expected ~1.0", retained_sum
        );
    }

    /// Obligation: Temperature identity (equivalence)
    /// Formal: softmax(l/1) == softmax(l)
    #[test]
    fn prop_temperature_identity(
        logits in logit_strategy()
    ) {
        let t1_probs = softmax_with_temp(&logits, 1.0);

        // Reference softmax computed directly (no temperature path) so the
        // comparison is not tautological: the previous version compared two
        // identical softmax_with_temp(.., 1.0) calls and could never fail.
        let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        let raw_probs: Vec<f32> = exps.iter().map(|&e| e / sum).collect();

        for (i, (&p1, &p2)) in t1_probs.iter().zip(raw_probs.iter()).enumerate() {
            let diff = (p1 - p2).abs();
            prop_assert!(
                diff < 1e-6,
                "T=1 identity violated at [{}]: {} != {}", i, p1, p2
            );
        }
    }

    /// Obligation: SIMD sampling equivalence (equivalence)
    #[test]
    #[ignore = "SIMD equivalence — trueno domain"]
    fn prop_simd_equivalence(
        _x in proptest::collection::vec(-10.0f32..10.0, 1..32usize)
    ) {
        // SIMD sampling testing is trueno's responsibility
    }
}