use crate::autograd::{BackwardOp, Tensor};
use ndarray::Array1;
use std::cell::RefCell;
use std::rc::Rc;
use super::matmul::{matmul_compute, transpose};
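/// Scaled dot-product attention: softmax(Q·Kᵀ / √d_k) · V.
///
/// All tensors are flattened row-major matrices: `q` is `seq_len × d_k`,
/// `k` is `seq_len × d_k`, and `v` is `seq_len × d_v`; the result is
/// `seq_len × d_v`. The `_k_seq_len` argument is currently unused, so the
/// implementation assumes the key/value sequence length equals `seq_len`.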
pub fn attention(
q: &Tensor,
k: &Tensor,
v: &Tensor,
seq_len: usize,
d_k: usize,
    _k_seq_len: usize,
    d_v: usize,
) -> Tensor {
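    // Dot products are divided by √d_k so their magnitude does not grow with the key dimension.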
let scale = (d_k as f32).sqrt();
let q_slice = q.data().as_slice().unwrap_or(&[]);
let k_slice = k.data().as_slice().unwrap_or(&[]);
    let k_t = transpose(k_slice, seq_len, d_k);
    let mut scores = matmul_compute(q_slice, &k_t, seq_len, d_k, seq_len);
for score in &mut scores {
*score /= scale;
}
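    // Numerically stable row-wise softmax: subtract each row's max before exponentiating.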
let mut attention_weights = vec![0.0; seq_len * seq_len];
for i in 0..seq_len {
let row_start = i * seq_len;
let row_end = row_start + seq_len;
let row = &scores[row_start..row_end];
let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let exp_vals: Vec<f32> = row.iter().map(|&x| (x - max_val).exp()).collect();
let sum_exp: f32 = exp_vals.iter().sum();
for (j, &exp_val) in exp_vals.iter().enumerate() {
attention_weights[row_start + j] = exp_val / sum_exp;
}
}
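    // Output = attention_weights · V; each output row is a convex combination of V's rows.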
let v_slice = v.data().as_slice().unwrap_or(&[]);
let output_data = matmul_compute(&attention_weights, v_slice, seq_len, seq_len, d_v);
let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
let mut result = Tensor::new(Array1::from(output_data), requires_grad);
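    // Register a backward node so gradients can flow back to Q, K, and V.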
if requires_grad {
let q_clone = q.clone();
let k_clone = k.clone();
let v_clone = v.clone();
let backward_op = Rc::new(AttentionBackward {
q: q_clone,
k: k_clone,
v: v_clone,
attention_weights: Array1::from(attention_weights),
seq_len,
d_k,
d_v,
scale,
result_grad: result.grad_cell(),
});
result.set_backward_op(backward_op);
}
result
}
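/// State captured by `attention` for the backward pass. The forward-pass
/// attention weights are cached because the softmax Jacobian is expressed
/// in terms of them.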
struct AttentionBackward {
q: Tensor,
k: Tensor,
v: Tensor,
attention_weights: Array1<f32>,
seq_len: usize,
d_k: usize,
d_v: usize,
scale: f32,
result_grad: Rc<RefCell<Option<Array1<f32>>>>,
}
impl BackwardOp for AttentionBackward {
fn backward(&self) {
if let Some(grad_output) = self.result_grad.borrow().as_ref() {
let seq_len = self.seq_len;
let d_k = self.d_k;
let d_v = self.d_v;
let grad_out_slice = grad_output.as_slice().unwrap_or(&[]);
let attn_slice = self.attention_weights.as_slice().unwrap_or(&[]);
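            // dL/dV = Pᵀ · dL/dOut, where P is the cached attention-weight matrix.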
if self.v.requires_grad() {
let attn_t = transpose(attn_slice, seq_len, seq_len);
let grad_v = matmul_compute(&attn_t, grad_out_slice, seq_len, seq_len, d_v);
self.v.accumulate_grad(Array1::from(grad_v));
}
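            // dL/dP = dL/dOut · Vᵀ.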
let v_slice = self.v.data().as_slice().unwrap_or(&[]);
let v_t = transpose(v_slice, seq_len, d_v);
let grad_attention_weights =
matmul_compute(grad_out_slice, &v_t, seq_len, d_v, seq_len);
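            // Softmax backward per row i: dS[i][j] = P[i][j] * (dP[i][j] − Σ_k P[i][k] * dP[i][k]).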
let mut grad_scores = vec![0.0; seq_len * seq_len];
for i in 0..seq_len {
let row_start = i * seq_len;
for j in 0..seq_len {
let idx = row_start + j;
let p_j = attn_slice[idx];
let mut sum_pk_gradk = 0.0;
for k in 0..seq_len {
let k_idx = row_start + k;
sum_pk_gradk += attn_slice[k_idx] * grad_attention_weights[k_idx];
}
grad_scores[idx] = p_j * (grad_attention_weights[idx] - sum_pk_gradk);
}
}
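            // Chain rule through the 1/√d_k score scaling applied in the forward pass.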
for g in &mut grad_scores {
*g /= self.scale;
}
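            // dL/dQ = dS · K and dL/dK = dSᵀ · Q (dS already carries the 1/√d_k factor).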
if self.q.requires_grad() {
let k_slice = self.k.data().as_slice().unwrap_or(&[]);
let grad_q = matmul_compute(&grad_scores, k_slice, seq_len, seq_len, d_k);
self.q.accumulate_grad(Array1::from(grad_q));
}
if self.k.requires_grad() {
let grad_t = transpose(&grad_scores, seq_len, seq_len);
let q_slice = self.q.data().as_slice().unwrap_or(&[]);
let grad_k = matmul_compute(&grad_t, q_slice, seq_len, seq_len, d_k);
self.k.accumulate_grad(Array1::from(grad_k));
}
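            // Propagate further down the graph through Q, K, and V's own backward ops.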
if let Some(op) = self.q.backward_op() {
op.backward();
}
if let Some(op) = self.k.backward_op() {
op.backward();
}
if let Some(op) = self.v.backward_op() {
op.backward();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array1;
#[test]
fn falsify_att_001_weight_normalization_via_uniform_v() {
let seq_len = 3;
let d_k = 4;
let d_v = 4;
let v_row = vec![2.0, -1.0, 3.0, 0.5];
let v_data: Vec<f32> = v_row.iter().copied().cycle().take(seq_len * d_v).collect();
let q = Tensor::new(
Array1::from(vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9]),
false,
);
let k = Tensor::new(
Array1::from(vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9]),
false,
);
let v = Tensor::new(Array1::from(v_data), false);
let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
let out_data = output.data();
let out_slice = out_data.as_slice().expect("contiguous");
for i in 0..seq_len {
for d in 0..d_v {
let diff = (out_slice[i * d_v + d] - v_row[d]).abs();
assert!(
diff < 1e-4,
"FALSIFIED ATT-001: output[{i}][{d}] = {}, expected {} (uniform V → weights sum to 1)",
out_slice[i * d_v + d],
v_row[d]
);
}
}
}
#[test]
fn falsify_att_002_output_convexity() {
let seq_len = 3;
let d_k = 4;
let d_v = 4;
let v_data = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0, -2.0, 7.0, 3.0, 0.0, -4.0, 6.0];
let q = Tensor::new(
Array1::from(vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9]),
false,
);
let k = Tensor::new(
Array1::from(vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9]),
false,
);
let v = Tensor::new(Array1::from(v_data.clone()), false);
let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
let out_data = output.data();
let out_slice = out_data.as_slice().expect("contiguous");
for i in 0..seq_len {
for d in 0..d_v {
let out_val = out_slice[i * d_v + d];
let v_col_min =
(0..seq_len).map(|j| v_data[j * d_v + d]).fold(f32::INFINITY, f32::min);
let v_col_max =
(0..seq_len).map(|j| v_data[j * d_v + d]).fold(f32::NEG_INFINITY, f32::max);
assert!(
out_val >= v_col_min - 1e-4 && out_val <= v_col_max + 1e-4,
"FALSIFIED ATT-002: output[{i}][{d}] = {out_val} outside V column [{v_col_min}, {v_col_max}]"
);
}
}
}
#[test]
fn falsify_att_003_scaling_factor() {
let seq_len = 2;
let d_k = 4;
let d_v = 2;
let q_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
let k_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
let v_data = vec![10.0, 20.0, 30.0, 40.0];
let q = Tensor::new(Array1::from(q_data.clone()), false);
let k = Tensor::new(Array1::from(k_data.clone()), false);
let v = Tensor::new(Array1::from(v_data.clone()), false);
let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
let out_slice = output.data().as_slice().expect("contiguous").to_vec();
        let scale = (d_k as f32).sqrt();
        let s00 = 1.0 / scale;
let s01 = 0.0 / scale;
let max0 = s00.max(s01);
let e00 = (s00 - max0).exp();
let e01 = (s01 - max0).exp();
let sum0 = e00 + e01;
let w00 = e00 / sum0;
let w01 = e01 / sum0;
let ref_out_0_0 = w00 * v_data[0] + w01 * v_data[2];
let ref_out_0_1 = w00 * v_data[1] + w01 * v_data[3];
assert!(
(out_slice[0] - ref_out_0_0).abs() < 1e-4,
"FALSIFIED ATT-003: output[0][0] = {}, reference = {ref_out_0_0} (1/√d_k scaling)",
out_slice[0]
);
assert!(
(out_slice[1] - ref_out_0_1).abs() < 1e-4,
"FALSIFIED ATT-003: output[0][1] = {}, reference = {ref_out_0_1} (1/√d_k scaling)",
out_slice[1]
);
}
#[test]
fn falsify_att_005_single_position() {
let seq_len = 1;
let d_k = 4;
let d_v = 4;
let v_data = vec![7.0, -3.0, 2.5, 11.0];
let q = Tensor::new(Array1::from(vec![1.0, 0.0, 0.0, 0.0]), false);
let k = Tensor::new(Array1::from(vec![0.5, 0.5, 0.5, 0.5]), false);
let v = Tensor::new(Array1::from(v_data.clone()), false);
let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
let out_slice = output.data().as_slice().expect("contiguous").to_vec();
for (d, (&out_val, &v_val)) in out_slice.iter().zip(v_data.iter()).enumerate() {
let diff = (out_val - v_val).abs();
assert!(
diff < 1e-5,
"FALSIFIED ATT-005: single position output[{d}] = {out_val}, expected V[{d}] = {v_val}"
);
}
}
#[test]
fn enc_002_attention_is_bidirectional() {
let seq_len = 3;
let d_k = 4;
let d_v = 4;
let q_data = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9];
let k_data_a = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
let v_data = vec![10.0, 20.0, 30.0, 40.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let q_a = Tensor::new(Array1::from(q_data.clone()), false);
let k_a = Tensor::new(Array1::from(k_data_a.clone()), false);
let v_a = Tensor::new(Array1::from(v_data.clone()), false);
let out_a = attention(&q_a, &k_a, &v_a, seq_len, d_k, seq_len, d_v);
let slice_a = out_a.data().as_slice().expect("contiguous").to_vec();
let mut k_data_b = k_data_a;
        k_data_b[8] = 99.0;
        let q_b = Tensor::new(Array1::from(q_data), false);
let k_b = Tensor::new(Array1::from(k_data_b), false);
let v_b = Tensor::new(Array1::from(v_data), false);
let out_b = attention(&q_b, &k_b, &v_b, seq_len, d_k, seq_len, d_v);
let slice_b = out_b.data().as_slice().expect("contiguous").to_vec();
let diff_pos0: f32 = (0..d_v).map(|d| (slice_a[d] - slice_b[d]).abs()).sum();
assert!(
diff_pos0 > 1e-3,
"ENC-002 FAILED: position 0 output unchanged when K[2] modified \
(diff={diff_pos0}). Attention has causal mask — encoder requires bidirectional."
);
}
mod att_proptest_falsify {
use super::*;
use proptest::prelude::*;
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn falsify_att_002_prop_output_convexity(
seed in 0..1000u32,
) {
let seq = 3;
let d = 4;
let q_data: Vec<f32> = (0..seq * d)
.map(|i| ((i as f32 + seed as f32) * 0.37).sin())
.collect();
let k_data: Vec<f32> = (0..seq * d)
.map(|i| ((i as f32 + seed as f32) * 0.73).cos())
.collect();
let v_data: Vec<f32> = (0..seq * d)
.map(|i| ((i as f32 + seed as f32) * 1.23).sin() * 5.0)
.collect();
let q = Tensor::new(Array1::from(q_data), false);
let k = Tensor::new(Array1::from(k_data), false);
let v = Tensor::new(Array1::from(v_data.clone()), false);
let output = attention(&q, &k, &v, seq, d, seq, d);
let out_slice = output.data().as_slice().expect("contiguous").to_vec();
for dim in 0..d {
let v_min = (0..seq).map(|j| v_data[j * d + dim]).fold(f32::INFINITY, f32::min);
let v_max = (0..seq).map(|j| v_data[j * d + dim]).fold(f32::NEG_INFINITY, f32::max);
for i in 0..seq {
let val = out_slice[i * d + dim];
prop_assert!(
val >= v_min - 1e-4 && val <= v_max + 1e-4,
"FALSIFIED ATT-002-prop: output[{}][{}] = {} outside V [{}, {}]",
i, dim, val, v_min, v_max
);
}
}
}
}
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn falsify_att_001_prop_uniform_v(
seq in 2..=5usize,
seed in 0..1000u32,
) {
let d = 4;
let v_row: Vec<f32> = (0..d)
.map(|i| ((i as f32 + seed as f32) * 1.23).sin() * 5.0)
.collect();
let v_data: Vec<f32> = v_row.iter().copied().cycle().take(seq * d).collect();
let q_data: Vec<f32> = (0..seq * d)
.map(|i| ((i as f32 + seed as f32) * 0.37).sin())
.collect();
let k_data: Vec<f32> = (0..seq * d)
.map(|i| ((i as f32 + seed as f32) * 0.73).cos())
.collect();
let q = Tensor::new(Array1::from(q_data), false);
let k = Tensor::new(Array1::from(k_data), false);
let v = Tensor::new(Array1::from(v_data), false);
let output = attention(&q, &k, &v, seq, d, seq, d);
let out_slice = output.data().as_slice().expect("contiguous").to_vec();
for i in 0..seq {
for dim in 0..d {
let diff = (out_slice[i * d + dim] - v_row[dim]).abs();
prop_assert!(
diff < 1e-4,
"FALSIFIED ATT-001-prop: output[{}][{}] = {}, expected {} (uniform V)",
i, dim, out_slice[i * d + dim], v_row[dim]
);
}
}
}
}
}
}