// aprender-core 0.30.0
//
// Next-generation machine learning library in pure Rust
// CONTRACT: backend-dispatch-v1.yaml
// HASH: sha256:a789012345678901
// Generated by: pv probar --binding
// DO NOT EDIT — regenerate with `pv probar --binding`

use proptest::prelude::*;

/// Minimum element count at which `dispatch` selects `Backend::Gpu`.
const GPU_THRESHOLD: u64 = 100_000;
/// Minimum element count at which `dispatch` selects `Backend::SimdThreaded`;
/// smaller counts fall through to `Backend::SimdOnly`.
const SIMD_ONLY_THRESHOLD: u64 = 1_000;

/// Dispatch decision based on element count.
///
/// Variants are declared from the lightest to the heaviest backend. The derived
/// ordering follows declaration order, so `SimdOnly < SimdThreaded < Gpu`;
/// reordering the variants would change every comparison result.
///
/// `Eq` and `Ord` are derived alongside `PartialEq`/`PartialOrd` because the
/// enum is fieldless and its ordering is total. This lets values be sorted and
/// passed to `min`/`max` without going through `partial_cmp`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Backend {
    /// Selected for fewer than 1_000 elements.
    SimdOnly,
    /// Selected for 1_000 up to (but not including) 100_000 elements.
    SimdThreaded,
    /// Selected for 100_000 elements or more.
    Gpu,
}

/// Picks the backend for a workload of `element_count` elements.
///
/// The thresholds are tested from the highest down, so the first matching arm
/// wins: counts at or above `GPU_THRESHOLD` go to the GPU, counts at or above
/// `SIMD_ONLY_THRESHOLD` go to threaded SIMD, and everything else stays on
/// single-threaded SIMD.
fn dispatch(element_count: u64) -> Backend {
    match element_count {
        count if count >= GPU_THRESHOLD => Backend::Gpu,
        count if count >= SIMD_ONLY_THRESHOLD => Backend::SimdThreaded,
        _ => Backend::SimdOnly,
    }
}

/// Garbage oracle: flags text that looks degenerate.
///
/// Only ASCII bytes are tallied; every other byte is skipped. The text is
/// reported as garbage when any of the following holds:
///
/// * it contains no ASCII bytes at all (this includes the empty string),
/// * the most frequent ASCII byte makes up more than 30% of the ASCII bytes,
/// * fewer than 10 distinct ASCII bytes occur.
fn is_garbage(text: &str) -> bool {
    // Per-byte histogram over the 128 ASCII code points.
    let mut histogram = [0u32; 128];
    let mut ascii_total = 0u32;
    for byte in text.bytes().filter(u8::is_ascii) {
        histogram[usize::from(byte)] += 1;
        ascii_total += 1;
    }

    // Nothing countable: empty input or input without any ASCII byte.
    if ascii_total == 0 {
        return true;
    }

    // One pass over the occupied buckets yields both the peak frequency and
    // the number of distinct bytes seen.
    let (peak, distinct) = histogram
        .iter()
        .filter(|&&count| count != 0)
        .fold((0u32, 0usize), |(peak, distinct), &count| {
            (peak.max(count), distinct + 1)
        });

    let dominated_by_one_byte = f64::from(peak) / f64::from(ascii_total) > 0.3;
    let too_few_distinct = distinct < 10;

    dominated_by_one_byte || too_few_distinct
}

/// Upper bound used for QK attention scores after normalization: `sqrt(head_dim)`.
///
/// The head dimension is widened to `f64` before the square root is taken.
fn qk_score_bound(head_dim: usize) -> f64 {
    let dim = head_dim as f64;
    f64::sqrt(dim)
}

proptest! {
    /// Obligation: GPU threshold monotonic (monotonicity)
    /// Formal: n1 >= threshold AND n2 > n1 => n2 >= threshold
    ///
    /// Draws two element counts from `0..200_000`, a range that straddles both
    /// dispatch thresholds, and checks that the larger count never maps to a
    /// lower-ordered `Backend` than the smaller one.
    #[test]
    fn prop_gpu_threshold_monotonic(
        n1 in 0u64..200_000,
        n2 in 0u64..200_000
    ) {
        // Only strictly ordered pairs are checked; cases with n2 <= n1 pass
        // without asserting anything.
        if n2 > n1 {
            let d1 = dispatch(n1);
            let d2 = dispatch(n2);
            // More elements => same or higher dispatch level
            prop_assert!(
                d2 >= d1,
                "not monotonic: dispatch({})={:?}, dispatch({})={:?}",
                n1, d1, n2, d2
            );
        }
    }

    /// Obligation: Garbage oracle detects repetition (invariant)
    /// Formal: repetition_ratio > 0.3 => is_garbage
    ///
    /// Builds a string of `n` (10 to 199) copies of one lowercase ASCII letter,
    /// so a single character accounts for the whole text.
    #[test]
    fn prop_garbage_detects_repetition(
        c_idx in 0u8..26,
        n in 10usize..200
    ) {
        // String of repeated single character
        // `c_idx` is an offset from 'a', giving a letter in 'a'..='z'.
        let c = (b'a' + c_idx) as char;
        let text: String = std::iter::repeat_n(c, n).collect();
        prop_assert!(
            is_garbage(&text),
            "repeated '{}' x {} not detected as garbage", c, n
        );
    }

    /// Obligation: Garbage oracle passes diverse text (invariant)
    ///
    /// Checks one fixed sample sentence; the generated input is not used.
    #[test]
    fn prop_garbage_passes_diverse(
        // The strategy can only yield 0, so every case checks the same text.
        // NOTE(review): presumably present only to satisfy the `proptest!`
        // parameter syntax — confirm.
        _dummy in 0u8..1
    ) {
        let text = "The quick brown fox jumps over the lazy dog. 0123456789!";
        prop_assert!(
            !is_garbage(text),
            "diverse text incorrectly flagged as garbage"
        );
    }

    /// Obligation: QK norm score bound (bound)
    /// Formal: |score| <= sqrt(d_k) after L2 normalization
    ///
    /// Generates two vectors with components in `-10.0..10.0` and independent
    /// lengths in `4..64`, normalizes both, and checks their dot product
    /// against `qk_score_bound`.
    #[test]
    fn prop_qk_norm_bound(
        q in proptest::collection::vec(-10.0f32..10.0, 4..64usize),
        k in proptest::collection::vec(-10.0f32..10.0, 4..64usize)
    ) {
        // The two lengths are drawn independently, so truncate both vectors to
        // the shorter one before taking the dot product.
        let d = q.len().min(k.len());
        let q = &q[..d];
        let k = &k[..d];

        // L2 normalize
        let q_norm: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let k_norm: f32 = k.iter().map(|x| x * x).sum::<f32>().sqrt();

        // Skip (near-)zero vectors: dividing by a vanishing norm would not
        // produce a meaningful unit vector. Such cases pass without asserting.
        if q_norm > 1e-8 && k_norm > 1e-8 {
            let q_unit: Vec<f32> = q.iter().map(|x| x / q_norm).collect();
            let k_unit: Vec<f32> = k.iter().map(|x| x / k_norm).collect();

            // Dot product of unit vectors
            let score: f32 = q_unit.iter().zip(&k_unit).map(|(a, b)| a * b).sum();
            let bound = qk_score_bound(d) as f32;

            // By Cauchy-Schwarz the dot product of two unit vectors has
            // magnitude at most 1, and d >= 4 here, so sqrt(d) >= 2 is a loose
            // bound; the 1e-5 slack is added on top of it.
            prop_assert!(
                score.abs() <= bound + 1e-5,
                "|score|={} > sqrt({})={}", score.abs(), d, bound
            );
        }
    }

    /// Obligation: BPE roundtrip (equivalence)
    ///
    /// Placeholder: ignored by default and has an empty body.
    #[test]
    #[ignore = "BPE roundtrip requires tokenizer API — realizar domain"]
    fn prop_bpe_roundtrip(
        _x in proptest::collection::vec(0u8..=127, 1..64usize)
    ) {
        // BPE roundtrip testing requires tokenizer API
    }

    /// Obligation: SIMD dispatch equivalence (equivalence)
    ///
    /// Placeholder: ignored by default and has an empty body.
    #[test]
    #[ignore = "SIMD equivalence — trueno domain"]
    fn prop_simd_equivalence(
        _x in proptest::collection::vec(0u8..=255, 1..32usize)
    ) {
        // Dispatch is pure math — no SIMD variant
    }
}