//! Property-based tests for scaled dot-product attention: weight
//! normalization and boundedness, output range, and the 1/sqrt(d)
//! scaling factor, checked against a scalar reference implementation.

use aprender::autograd::Tensor;
use aprender::nn::functional::softmax;
use proptest::prelude::*;

/// Scalar reference: raw attention scores QKᵀ scaled by 1/sqrt(d),
/// returned as a flat row-major n×n matrix.
fn scaled_scores(q: &[f32], k: &[f32], n: usize, d: usize) -> Vec<f32> {
    let scale = (d as f32).sqrt();
    let mut scores = vec![0.0f32; n * n];
    for i in 0..n {
        for j in 0..n {
            let mut dot = 0.0f32;
            for dd in 0..d {
                dot += q[i * d + dd] * k[j * d + dd];
            }
            scores[i * n + j] = dot / scale;
        }
    }
    scores
}
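
// A minimal deterministic sketch (values are illustrative, not from the
// original suite): with orthonormal rows, `scaled_scores` puts 1/sqrt(d)
// on the diagonal and 0 everywhere else.
#[test]
fn scaled_scores_orthonormal_rows() {
    let d = 4;
    let q = [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // two orthonormal rows
    let scores = scaled_scores(&q, &q, 2, d);
    let diag = 1.0 / (d as f32).sqrt();
    assert!((scores[0] - diag).abs() < 1e-6); // q0·q0 / sqrt(d)
    assert!(scores[1].abs() < 1e-6); // q0·q1 = 0
}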

/// Scalar reference: in-place row-wise softmax, subtracting each row's
/// max before exponentiating so large scores cannot overflow.
fn softmax_rows(scores: &mut [f32], n: usize) {
    for i in 0..n {
        let row = &mut scores[i * n..(i + 1) * n];
        let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for val in row.iter_mut() {
            *val = (*val - max).exp();
            sum += *val;
        }
        for val in row.iter_mut() {
            *val /= sum;
        }
    }
}
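
// A small sketch (illustrative values) probing the max-subtraction trick:
// logits far outside exp's safe f32 range still yield finite weights that
// sum to 1 per row.
#[test]
fn softmax_rows_stable_for_large_logits() {
    let n = 2;
    let mut scores = vec![1000.0, 1001.0, -1000.0, 1000.0];
    softmax_rows(&mut scores, n);
    for i in 0..n {
        let row = &scores[i * n..(i + 1) * n];
        assert!(row.iter().all(|v| v.is_finite()));
        let sum: f32 = row.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
}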

/// Scalar reference: each output row is the weight-row convex combination
/// of the rows of V.
fn weighted_sum(weights: &[f32], v: &[f32], n: usize, d: usize) -> Vec<f32> {
    let mut output = vec![0.0f32; n * d];
    for i in 0..n {
        for j in 0..d {
            let mut sum = 0.0f32;
            for kk in 0..n {
                sum += weights[i * n + kk] * v[kk * d + j];
            }
            output[i * d + j] = sum;
        }
    }
    output
}
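
// Sketch (illustrative values): uniform weights reduce `weighted_sum` to
// the column mean of V; these inputs make the result exact in f32.
#[test]
fn weighted_sum_uniform_weights_is_column_mean() {
    let (n, d) = (2, 2);
    let weights = [0.5, 0.5, 0.5, 0.5];
    let v = [1.0, 2.0, 3.0, 4.0];
    let out = weighted_sum(&weights, &v, n, d);
    assert_eq!(out, vec![2.0, 3.0, 2.0, 3.0]); // column means of V
}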

/// Scalar reference for the full pipeline: softmax(QKᵀ/sqrt(d)) · V.
fn reference_attention(q: &[f32], k: &[f32], v: &[f32], n: usize, d: usize) -> Vec<f32> {
    let mut scores = scaled_scores(q, k, n, d);
    softmax_rows(&mut scores, n);
    weighted_sum(&scores, v, n, d)
}
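
// Sketch tying the three stages together (illustrative values): a zero Q
// makes every score 0, softmax then gives uniform weights, and the output
// collapses to the column mean of V regardless of K.
#[test]
fn reference_attention_zero_query_averages_v() {
    let (n, d) = (2, 2);
    let q = [0.0; 4];
    let k = [1.0, -1.0, 2.0, 0.5];
    let v = [1.0, 2.0, 3.0, 4.0];
    let out = reference_attention(&q, &k, &v, n, d);
    let expected = [2.0, 3.0, 2.0, 3.0]; // column means of V
    for (o, e) in out.iter().zip(expected) {
        assert!((o - e).abs() < 1e-6, "got {o}, expected {e}");
    }
}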

/// Self-attention weights via the aprender tensor API: K = Q, so raw
/// scores are QQᵀ/sqrt(d), softmaxed row by row. Returns (row index,
/// weights) pairs; trailing elements of `data` past n·d are ignored.
fn attention_weight_rows(data: &[f32], d: usize) -> Vec<(usize, Vec<f32>)> {
    let n = data.len() / d;
    let total = n * d;
    let q = data[..total].to_vec();
    let scale = (d as f32).sqrt();
    let q_mat = Tensor::new(&q, &[n, d]);
    let q_t = q_mat.transpose();
    let raw = q_mat.matmul(&q_t);
    let scaled: Vec<f32> = raw.data().iter().map(|v| v / scale).collect();
    let scaled_mat = Tensor::new(&scaled, &[n, n]);
    (0..n)
        .map(|i| {
            let row: Vec<f32> = scaled_mat.data()[i * n..(i + 1) * n].to_vec();
            let row_t = Tensor::new(&row, &[1, n]);
            let sm = softmax(&row_t, -1);
            (i, sm.data().to_vec())
        })
        .collect()
}
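
// Sketch of a direct call (illustrative input, not from the original
// suite): eight values at d = 4 give two query rows, so we expect two
// weight rows of length 2, each a probability distribution.
#[test]
fn attention_weight_rows_small_input() {
    let data = [0.5, -0.5, 1.0, 0.0, -1.0, 0.25, 0.0, 2.0];
    let rows = attention_weight_rows(&data, 4);
    assert_eq!(rows.len(), 2);
    for (_, row) in rows {
        assert_eq!(row.len(), 2);
        let sum: f32 = row.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5);
    }
}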

proptest! {
    // Each attention weight row must be a probability distribution: sum = 1.
    #[test]
    fn prop_attention_weights_normalize(
        data in proptest::collection::vec(-5.0f32..5.0, 8..33usize)
    ) {
        let d = 4;
        let n = data.len() / d;
        if n < 1 { return Ok(()); } // defensive; the strategy guarantees n >= 2
        for (i, row) in attention_weight_rows(&data, d) {
            let sum: f32 = row.iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-5,
                "row {} weights sum={}", i, sum
            );
        }
    }

    // Softmax outputs are strictly positive and never exceed 1.
    #[test]
    fn prop_attention_weights_bounded(
        data in proptest::collection::vec(-5.0f32..5.0, 8..33usize)
    ) {
        let d = 4;
        let n = data.len() / d;
        if n < 1 { return Ok(()); }
        for (i, row) in attention_weight_rows(&data, d) {
            for (j, &val) in row.iter().enumerate() {
                prop_assert!(
                    val > 0.0 && val <= 1.0,
                    "attn[{},{}]={}, expected in (0,1]", i, j, val
                );
            }
        }
    }

    // The output is a convex combination of V rows, so every entry must
    // stay within the global [min, max] range of V.
    #[test]
    fn prop_output_bounded_by_v(
        data in proptest::collection::vec(-5.0f32..5.0, 16..33usize)
    ) {
        let d = 4;
        let n = data.len() / (3 * d);
        if n < 1 { return Ok(()); }
        let total = n * d;
        if data.len() < 3 * total { return Ok(()); }
        let q = &data[..total];
        let k = &data[total..2 * total];
        let v = &data[2 * total..3 * total];
        let output = reference_attention(q, k, v, n, d);
        let v_min = v.iter().copied().fold(f32::INFINITY, f32::min);
        let v_max = v.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        for (i, &val) in output.iter().enumerate() {
            prop_assert!(
                val >= v_min - 1e-5 && val <= v_max + 1e-5,
                "output[{i}]={val}, V range=[{v_min},{v_max}]"
            );
        }
    }

    // Placeholder: SIMD-vs-scalar equivalence is owned by the trueno crate.
    #[test]
    #[ignore = "SIMD equivalence — trueno domain"]
    fn prop_simd_matches_scalar(
        _x in proptest::collection::vec(-5.0f32..5.0, 1..32usize)
    ) {
    }

    // Guards against the classic bug of scaling by 1/d instead of 1/sqrt(d):
    // the two factors coincide only at d = 1.
    #[test]
    fn prop_scaling_factor(
        d in 1usize..17
    ) {
        let scale = 1.0 / (d as f32).sqrt();
        let wrong_scale = 1.0 / (d as f32);
        if d > 1 {
            prop_assert!(
                (scale - wrong_scale).abs() > 1e-6,
                "1/sqrt({d})={scale} should differ from 1/{d}={wrong_scale}"
            );
        }
    }
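
    // An additional property sketch (not in the original suite): softmax is
    // shift-invariant, so adding one constant to every score in a row must
    // leave the normalized weights unchanged up to float tolerance.
    #[test]
    fn prop_softmax_rows_shift_invariant(
        data in proptest::collection::vec(-5.0f32..5.0, 4..17usize),
        shift in -3.0f32..3.0,
    ) {
        let n = (data.len() as f32).sqrt() as usize; // largest n with n*n <= len
        let mut base: Vec<f32> = data[..n * n].to_vec();
        let mut shifted: Vec<f32> = base.iter().map(|v| *v + shift).collect();
        softmax_rows(&mut base, n);
        softmax_rows(&mut shifted, n);
        for (a, b) in base.iter().zip(&shifted) {
            prop_assert!((a - b).abs() < 1e-4, "shift changed weight: {a} vs {b}");
        }
    }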
}