1use crate::autograd::{BackwardOp, Tensor};
6use ndarray::Array1;
7use std::cell::RefCell;
8use std::rc::Rc;
9
10use super::matmul::{matmul_compute, transpose};
12
13pub fn attention(
27 q: &Tensor,
28 k: &Tensor,
29 v: &Tensor,
30 seq_len: usize,
31 d_k: usize,
32 _k_seq_len: usize, d_v: usize,
34) -> Tensor {
35 let scale = (d_k as f32).sqrt();
36
37 let q_slice = q.data().as_slice().unwrap_or(&[]);
41 let k_slice = k.data().as_slice().unwrap_or(&[]);
42 let k_t = transpose(k_slice, seq_len, d_k); let mut scores = matmul_compute(q_slice, &k_t, seq_len, d_k, seq_len);
44
45 for score in &mut scores {
47 *score /= scale;
48 }
49
50 let mut attention_weights = vec![0.0; seq_len * seq_len];
52 for i in 0..seq_len {
53 let row_start = i * seq_len;
54 let row_end = row_start + seq_len;
55 let row = &scores[row_start..row_end];
56
57 let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
59 let exp_vals: Vec<f32> = row.iter().map(|&x| (x - max_val).exp()).collect();
60 let sum_exp: f32 = exp_vals.iter().sum();
61
62 for (j, &exp_val) in exp_vals.iter().enumerate() {
63 attention_weights[row_start + j] = exp_val / sum_exp;
64 }
65 }
66
67 let v_slice = v.data().as_slice().unwrap_or(&[]);
71 let output_data = matmul_compute(&attention_weights, v_slice, seq_len, seq_len, d_v);
72
73 let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
74 let mut result = Tensor::new(Array1::from(output_data), requires_grad);
75
76 if requires_grad {
77 let q_clone = q.clone();
78 let k_clone = k.clone();
79 let v_clone = v.clone();
80 let backward_op = Rc::new(AttentionBackward {
81 q: q_clone,
82 k: k_clone,
83 v: v_clone,
84 attention_weights: Array1::from(attention_weights),
85 seq_len,
86 d_k,
87 d_v,
88 scale,
89 result_grad: result.grad_cell(),
90 });
91 result.set_backward_op(backward_op);
92 }
93
94 result
95}
96
97struct AttentionBackward {
98 q: Tensor,
99 k: Tensor,
100 v: Tensor,
101 attention_weights: Array1<f32>,
102 seq_len: usize,
103 d_k: usize,
104 d_v: usize,
105 scale: f32,
106 result_grad: Rc<RefCell<Option<Array1<f32>>>>,
107}
108
109impl BackwardOp for AttentionBackward {
110 fn backward(&self) {
111 if let Some(grad_output) = self.result_grad.borrow().as_ref() {
112 let seq_len = self.seq_len;
113 let d_k = self.d_k;
114 let d_v = self.d_v;
115 let grad_out_slice = grad_output.as_slice().unwrap_or(&[]);
116 let attn_slice = self.attention_weights.as_slice().unwrap_or(&[]);
117
118 if self.v.requires_grad() {
123 let attn_t = transpose(attn_slice, seq_len, seq_len);
124 let grad_v = matmul_compute(&attn_t, grad_out_slice, seq_len, seq_len, d_v);
125 self.v.accumulate_grad(Array1::from(grad_v));
126 }
127
128 let v_slice = self.v.data().as_slice().unwrap_or(&[]);
132 let v_t = transpose(v_slice, seq_len, d_v);
133 let grad_attention_weights =
134 matmul_compute(grad_out_slice, &v_t, seq_len, d_v, seq_len);
135
136 let mut grad_scores = vec![0.0; seq_len * seq_len];
138 for i in 0..seq_len {
139 let row_start = i * seq_len;
140 for j in 0..seq_len {
141 let idx = row_start + j;
142 let p_j = attn_slice[idx];
143
144 let mut sum_pk_gradk = 0.0;
146 for k in 0..seq_len {
147 let k_idx = row_start + k;
148 sum_pk_gradk += attn_slice[k_idx] * grad_attention_weights[k_idx];
149 }
150
151 grad_scores[idx] = p_j * (grad_attention_weights[idx] - sum_pk_gradk);
152 }
153 }
154
155 for g in &mut grad_scores {
157 *g /= self.scale;
158 }
159
160 if self.q.requires_grad() {
164 let k_slice = self.k.data().as_slice().unwrap_or(&[]);
165 let grad_q = matmul_compute(&grad_scores, k_slice, seq_len, seq_len, d_k);
166 self.q.accumulate_grad(Array1::from(grad_q));
167 }
168
169 if self.k.requires_grad() {
174 let grad_t = transpose(&grad_scores, seq_len, seq_len);
175 let q_slice = self.q.data().as_slice().unwrap_or(&[]);
176 let grad_k = matmul_compute(&grad_t, q_slice, seq_len, seq_len, d_k);
177 self.k.accumulate_grad(Array1::from(grad_k));
178 }
179
180 if let Some(op) = self.q.backward_op() {
182 op.backward();
183 }
184 if let Some(op) = self.k.backward_op() {
185 op.backward();
186 }
187 if let Some(op) = self.v.backward_op() {
188 op.backward();
189 }
190 }
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// ATT-001: each row of attention weights sums to 1, so when every row of
    /// V is identical the output must reproduce that row exactly.
    #[test]
    fn falsify_att_001_weight_normalization_via_uniform_v() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 4;
        let v_row = vec![2.0, -1.0, 3.0, 0.5];
        let v_data: Vec<f32> = v_row.iter().copied().cycle().take(seq_len * d_v).collect();

        let q = Tensor::new(
            Array1::from(vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9]),
            false,
        );
        let k = Tensor::new(
            Array1::from(vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9]),
            false,
        );
        let v = Tensor::new(Array1::from(v_data), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_data = output.data();
        let out_slice = out_data.as_slice().expect("contiguous");

        for i in 0..seq_len {
            for d in 0..d_v {
                let diff = (out_slice[i * d_v + d] - v_row[d]).abs();
                assert!(
                    diff < 1e-4,
                    "FALSIFIED ATT-001: output[{i}][{d}] = {}, expected {} (uniform V → weights sum to 1)",
                    out_slice[i * d_v + d],
                    v_row[d]
                );
            }
        }
    }

    /// ATT-002: every output is a convex combination of V's rows, so each
    /// output column lies within the min/max of the matching V column.
    #[test]
    fn falsify_att_002_output_convexity() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 4;
        let v_data = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0, -2.0, 7.0, 3.0, 0.0, -4.0, 6.0];

        let q = Tensor::new(
            Array1::from(vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9]),
            false,
        );
        let k = Tensor::new(
            Array1::from(vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9]),
            false,
        );
        let v = Tensor::new(Array1::from(v_data.clone()), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_data = output.data();
        let out_slice = out_data.as_slice().expect("contiguous");

        for i in 0..seq_len {
            for d in 0..d_v {
                let out_val = out_slice[i * d_v + d];

                let v_col_min =
                    (0..seq_len).map(|j| v_data[j * d_v + d]).fold(f32::INFINITY, f32::min);
                let v_col_max =
                    (0..seq_len).map(|j| v_data[j * d_v + d]).fold(f32::NEG_INFINITY, f32::max);

                assert!(
                    out_val >= v_col_min - 1e-4 && out_val <= v_col_max + 1e-4,
                    "FALSIFIED ATT-002: output[{i}][{d}] = {out_val} outside V column [{v_col_min}, {v_col_max}]"
                );
            }
        }
    }

    /// ATT-003: scores are divided by √d_k; compares row 0 against a
    /// hand-computed reference.
    #[test]
    fn falsify_att_003_scaling_factor() {
        let seq_len = 2;
        let d_k = 4;
        let d_v = 2;

        let q_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let k_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let v_data = vec![10.0, 20.0, 30.0, 40.0];

        let q = Tensor::new(Array1::from(q_data), false);
        let k = Tensor::new(Array1::from(k_data), false);
        let v = Tensor::new(Array1::from(v_data.clone()), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_slice = output.data().as_slice().expect("contiguous").to_vec();

        // Row 0: q0 = e0 matches k0 = e0 (raw score 1) and is orthogonal to k1 = e2.
        let scale = (d_k as f32).sqrt();
        let s00 = 1.0 / scale;
        let s01 = 0.0 / scale;
        let max0 = s00.max(s01);
        let e00 = (s00 - max0).exp();
        let e01 = (s01 - max0).exp();
        let sum0 = e00 + e01;
        let w00 = e00 / sum0;
        let w01 = e01 / sum0;
        let ref_out_0_0 = w00 * v_data[0] + w01 * v_data[2];
        let ref_out_0_1 = w00 * v_data[1] + w01 * v_data[3];

        assert!(
            (out_slice[0] - ref_out_0_0).abs() < 1e-4,
            "FALSIFIED ATT-003: output[0][0] = {}, reference = {ref_out_0_0} (1/√d_k scaling)",
            out_slice[0]
        );
        assert!(
            (out_slice[1] - ref_out_0_1).abs() < 1e-4,
            "FALSIFIED ATT-003: output[0][1] = {}, reference = {ref_out_0_1} (1/√d_k scaling)",
            out_slice[1]
        );
    }

    /// ATT-005: with a single position the only weight is 1, so the output
    /// equals V.
    #[test]
    fn falsify_att_005_single_position() {
        let seq_len = 1;
        let d_k = 4;
        let d_v = 4;
        let v_data = vec![7.0, -3.0, 2.5, 11.0];

        let q = Tensor::new(Array1::from(vec![1.0, 0.0, 0.0, 0.0]), false);
        let k = Tensor::new(Array1::from(vec![0.5, 0.5, 0.5, 0.5]), false);
        let v = Tensor::new(Array1::from(v_data.clone()), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_slice = output.data().as_slice().expect("contiguous").to_vec();

        for (d, (&out_val, &v_val)) in out_slice.iter().zip(v_data.iter()).enumerate() {
            let diff = (out_val - v_val).abs();
            assert!(
                diff < 1e-5,
                "FALSIFIED ATT-005: single position output[{d}] = {out_val}, expected V[{d}] = {v_val}"
            );
        }
    }

    /// ENC-002: attention is unmasked, so changing a later key (K row 2) must
    /// change the output at position 0.
    #[test]
    fn enc_002_attention_is_bidirectional() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 4;

        let q_data = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9];
        let k_data_a = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
        let v_data = vec![10.0, 20.0, 30.0, 40.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let q_a = Tensor::new(Array1::from(q_data.clone()), false);
        let k_a = Tensor::new(Array1::from(k_data_a.clone()), false);
        let v_a = Tensor::new(Array1::from(v_data.clone()), false);
        let out_a = attention(&q_a, &k_a, &v_a, seq_len, d_k, seq_len, d_v);
        let slice_a = out_a.data().as_slice().expect("contiguous").to_vec();

        // Index 8 is the first element of key row 2.
        let mut k_data_b = k_data_a;
        k_data_b[8] = 99.0;
        let q_b = Tensor::new(Array1::from(q_data), false);
        let k_b = Tensor::new(Array1::from(k_data_b), false);
        let v_b = Tensor::new(Array1::from(v_data), false);
        let out_b = attention(&q_b, &k_b, &v_b, seq_len, d_k, seq_len, d_v);
        let slice_b = out_b.data().as_slice().expect("contiguous").to_vec();

        let diff_pos0: f32 = (0..d_v).map(|d| (slice_a[d] - slice_b[d]).abs()).sum();
        assert!(
            diff_pos0 > 1e-3,
            "ENC-002 FAILED: position 0 output unchanged when K[2] modified \
             (diff={diff_pos0}). Attention has causal mask — encoder requires bidirectional."
        );
    }

    /// Σᵢ outᵢ·wᵢ over a no-grad attention forward pass (self-attention); the
    /// scalar loss whose gradient `backward_matches_finite_differences` checks.
    fn weighted_output_sum(
        q: &[f32],
        k: &[f32],
        v: &[f32],
        weights: &[f32],
        (seq_len, d_k, d_v): (usize, usize, usize),
    ) -> f32 {
        let tensor = |data: &[f32]| Tensor::new(Array1::from(data.to_vec()), false);
        let output = attention(&tensor(q), &tensor(k), &tensor(v), seq_len, d_k, seq_len, d_v);
        let out_vals: Vec<f32> = output.data().iter().copied().collect();
        out_vals.iter().zip(weights).map(|(o, w)| o * w).sum()
    }

    /// Copies the gradient accumulated on `tensor` by a backward pass.
    fn accumulated_grad(tensor: &Tensor) -> Vec<f32> {
        tensor
            .grad_cell()
            .borrow()
            .as_ref()
            .map(|grad| grad.iter().copied().collect())
            .expect("backward pass did not accumulate a gradient")
    }

    /// Backward pass: the analytic gradients accumulated on Q, K and V must
    /// match central finite differences of Σᵢ outᵢ·wᵢ.
    #[test]
    fn backward_matches_finite_differences() {
        let dims = (3, 4, 2);
        let (seq_len, d_k, d_v) = dims;
        let q_data = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9];
        let k_data = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
        let v_data = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0];
        let upstream = vec![0.5, -1.0, 2.0, 0.25, -0.75, 1.5];

        let q = Tensor::new(Array1::from(q_data.clone()), true);
        let k = Tensor::new(Array1::from(k_data.clone()), true);
        let v = Tensor::new(Array1::from(v_data.clone()), true);
        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        *output.grad_cell().borrow_mut() = Some(Array1::from(upstream.clone()));
        let op = output.backward_op().expect("attention must record a backward op");
        op.backward();

        let inputs = [q_data, k_data, v_data];
        let analytic = [accumulated_grad(&q), accumulated_grad(&k), accumulated_grad(&v)];
        let eps = 1e-3_f32;

        for (which, name) in ["q", "k", "v"].iter().enumerate() {
            assert_eq!(analytic[which].len(), inputs[which].len(), "d{name} has the wrong length");
            for idx in 0..inputs[which].len() {
                let mut perturbed = inputs.clone();
                perturbed[which][idx] = inputs[which][idx] + eps;
                let plus = weighted_output_sum(
                    &perturbed[0],
                    &perturbed[1],
                    &perturbed[2],
                    &upstream,
                    dims,
                );
                perturbed[which][idx] = inputs[which][idx] - eps;
                let minus = weighted_output_sum(
                    &perturbed[0],
                    &perturbed[1],
                    &perturbed[2],
                    &upstream,
                    dims,
                );
                let numeric = (plus - minus) / (2.0 * eps);
                let analytic_val = analytic[which][idx];
                assert!(
                    (numeric - analytic_val).abs() < 1e-2,
                    "d{name}[{idx}]: analytic {analytic_val} vs numeric {numeric}"
                );
            }
        }
    }

    /// Property-based variants of ATT-001 and ATT-002 over seeded inputs.
    mod att_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn falsify_att_002_prop_output_convexity(
                seed in 0..1000u32,
            ) {
                let seq = 3;
                let d = 4;

                let q_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                    .collect();
                let k_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
                    .collect();
                let v_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 1.23).sin() * 5.0)
                    .collect();

                let q = Tensor::new(Array1::from(q_data), false);
                let k = Tensor::new(Array1::from(k_data), false);
                let v = Tensor::new(Array1::from(v_data.clone()), false);

                let output = attention(&q, &k, &v, seq, d, seq, d);
                let out_slice = output.data().as_slice().expect("contiguous").to_vec();

                for dim in 0..d {
                    let v_min = (0..seq).map(|j| v_data[j * d + dim]).fold(f32::INFINITY, f32::min);
                    let v_max = (0..seq).map(|j| v_data[j * d + dim]).fold(f32::NEG_INFINITY, f32::max);

                    for i in 0..seq {
                        let val = out_slice[i * d + dim];
                        prop_assert!(
                            val >= v_min - 1e-4 && val <= v_max + 1e-4,
                            "FALSIFIED ATT-002-prop: output[{}][{}] = {} outside V [{}, {}]",
                            i, dim, val, v_min, v_max
                        );
                    }
                }
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn falsify_att_001_prop_uniform_v(
                seq in 2..=5usize,
                seed in 0..1000u32,
            ) {
                let d = 4;
                let v_row: Vec<f32> = (0..d)
                    .map(|i| ((i as f32 + seed as f32) * 1.23).sin() * 5.0)
                    .collect();
                let v_data: Vec<f32> = v_row.iter().copied().cycle().take(seq * d).collect();

                let q_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                    .collect();
                let k_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
                    .collect();

                let q = Tensor::new(Array1::from(q_data), false);
                let k = Tensor::new(Array1::from(k_data), false);
                let v = Tensor::new(Array1::from(v_data), false);

                let output = attention(&q, &k, &v, seq, d, seq, d);
                let out_slice = output.data().as_slice().expect("contiguous").to_vec();

                for i in 0..seq {
                    for dim in 0..d {
                        let diff = (out_slice[i * d + dim] - v_row[dim]).abs();
                        prop_assert!(
                            diff < 1e-4,
                            "FALSIFIED ATT-001-prop: output[{}][{}] = {}, expected {} (uniform V)",
                            i, dim, out_slice[i * d + dim], v_row[dim]
                        );
                    }
                }
            }
        }
    }
}