use proptest::prelude::*;
/// Element count at or above which `dispatch` selects `Backend::Gpu`.
const GPU_THRESHOLD: u64 = 100_000;
/// Element count below which `dispatch` selects `Backend::SimdOnly`; counts from
/// this value up to (but not including) `GPU_THRESHOLD` select `Backend::SimdThreaded`.
const SIMD_ONLY_THRESHOLD: u64 = 1_000;
/// Compute backends that `dispatch` can select.
///
/// The derived ordering follows declaration order, so
/// `SimdOnly < SimdThreaded < Gpu`. The monotonicity property test relies on
/// this ordering, so keep the variants declared from the smallest workload
/// tier to the largest.
///
/// `Eq`, `Ord` and `Hash` are derived alongside `PartialEq`/`PartialOrd` so the
/// enum can be sorted, compared with `Ord::max`/`min`, and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Backend {
    SimdOnly,
    SimdThreaded,
    Gpu,
}
/// Picks the backend for a workload of `element_count` elements.
///
/// The thresholds are checked from the largest tier down:
/// * `element_count >= GPU_THRESHOLD` selects `Backend::Gpu`;
/// * otherwise `element_count >= SIMD_ONLY_THRESHOLD` selects `Backend::SimdThreaded`;
/// * anything smaller selects `Backend::SimdOnly`.
fn dispatch(element_count: u64) -> Backend {
    match element_count {
        n if n >= GPU_THRESHOLD => Backend::Gpu,
        n if n >= SIMD_ONLY_THRESHOLD => Backend::SimdThreaded,
        _ => Backend::SimdOnly,
    }
}
/// Heuristically decides whether `text` looks like garbage output.
///
/// Only ASCII bytes (values below 128) are considered; every other byte is
/// ignored. The text is reported as garbage when any of the following holds:
/// * it contains no ASCII bytes at all (this includes the empty string);
/// * fewer than 10 distinct ASCII byte values occur;
/// * the most frequent ASCII byte accounts for more than 30% of the ASCII bytes.
fn is_garbage(text: &str) -> bool {
    // Per-byte-value occurrence counts for the 128 ASCII code points.
    let mut histogram = [0u32; 128];
    let mut ascii_len = 0u32;
    for byte in text.bytes().filter(u8::is_ascii) {
        histogram[usize::from(byte)] += 1;
        ascii_len += 1;
    }

    // Nothing countable: empty input or input made solely of non-ASCII bytes.
    if ascii_len == 0 {
        return true;
    }

    let distinct = histogram.iter().filter(|&&hits| hits != 0).count();
    if distinct < 10 {
        return true;
    }

    let peak = histogram.iter().max().copied().unwrap_or(0);
    f64::from(peak) / f64::from(ascii_len) > 0.3
}
/// Returns the square root of `head_dim` as an `f64`.
///
/// The property test in this file uses the value as an upper bound on the
/// absolute dot product of two L2-normalised vectors of length `head_dim`.
fn qk_score_bound(head_dim: usize) -> f64 {
    let dim = head_dim as f64;
    dim.sqrt()
}
proptest! {
    // `dispatch` must be monotonic: a larger element count never selects a
    // lower-ordered backend. The sampled range 0..200_000 straddles both
    // thresholds (1_000 and 100_000).
    #[test]
    fn prop_gpu_threshold_monotonic(
        n1 in 0u64..200_000,
        n2 in 0u64..200_000
    ) {
        // Order the pair rather than skipping cases where n2 <= n1, so every
        // generated input actually exercises the property.
        let (lo, hi) = if n1 <= n2 { (n1, n2) } else { (n2, n1) };
        let d_lo = dispatch(lo);
        let d_hi = dispatch(hi);
        prop_assert!(
            d_hi >= d_lo,
            "not monotonic: dispatch({})={:?}, dispatch({})={:?}",
            lo, d_lo, hi, d_hi
        );
    }

    // A string made of one lowercase letter repeated 10..200 times must be
    // flagged as garbage.
    #[test]
    fn prop_garbage_detects_repetition(
        c_idx in 0u8..26,
        n in 10usize..200
    ) {
        let c = (b'a' + c_idx) as char;
        let text: String = std::iter::repeat_n(c, n).collect();
        prop_assert!(
            is_garbage(&text),
            "repeated '{}' x {} not detected as garbage", c, n
        );
    }

    // A fixed, diverse sentence must not be flagged. The input is constant;
    // `_dummy` is unused and only satisfies the macro's need for a strategy.
    #[test]
    fn prop_garbage_passes_diverse(
        _dummy in 0u8..1
    ) {
        let text = "The quick brown fox jumps over the lazy dog. 0123456789!";
        prop_assert!(
            !is_garbage(text),
            "diverse text incorrectly flagged as garbage"
        );
    }

    // The dot product of two L2-normalised vectors must stay within
    // `qk_score_bound(d)`, where `d` is the shared (truncated) length.
    #[test]
    fn prop_qk_norm_bound(
        q in proptest::collection::vec(-10.0f32..10.0, 4..64usize),
        k in proptest::collection::vec(-10.0f32..10.0, 4..64usize)
    ) {
        // Truncate both vectors to the shorter length so they can be dotted.
        let d = q.len().min(k.len());
        let q = &q[..d];
        let k = &k[..d];
        let q_norm: f32 = q.iter().map(|x| x * x).sum::<f32>().sqrt();
        let k_norm: f32 = k.iter().map(|x| x * x).sum::<f32>().sqrt();

        // Reject (rather than silently pass) inputs that cannot be normalised,
        // so degenerate vectors do not count as successful cases.
        prop_assume!(q_norm > 1e-8 && k_norm > 1e-8);

        let q_unit: Vec<f32> = q.iter().map(|x| x / q_norm).collect();
        let k_unit: Vec<f32> = k.iter().map(|x| x / k_norm).collect();
        let score: f32 = q_unit.iter().zip(&k_unit).map(|(a, b)| a * b).sum();
        let bound = qk_score_bound(d) as f32;
        // The small epsilon absorbs f32 rounding in the normalisation and sum.
        prop_assert!(
            score.abs() <= bound + 1e-5,
            "|score|={} > sqrt({})={}", score.abs(), d, bound
        );
    }

    // Placeholder: intentionally empty and ignored; see the `ignore` reason.
    #[test]
    #[ignore = "BPE roundtrip requires tokenizer API — realizar domain"]
    fn prop_bpe_roundtrip(
        _x in proptest::collection::vec(0u8..=127, 1..64usize)
    ) {
    }

    // Placeholder: intentionally empty and ignored; see the `ignore` reason.
    #[test]
    #[ignore = "SIMD equivalence — trueno domain"]
    fn prop_simd_equivalence(
        _x in proptest::collection::vec(0u8..=255, 1..32usize)
    ) {
    }
}