Skip to main content

entrenar/autograd/ops/
attention.rs

1//! Attention autograd operations: scaled dot-product attention
2//!
3//! Uses CUDA GEMM for Q@K^T and Attn@V operations when available.
4
5use crate::autograd::{BackwardOp, Tensor};
6use ndarray::Array1;
7use std::cell::RefCell;
8use std::rc::Rc;
9
10// Import matmul_compute from sibling module for GPU-accelerated matrix operations
11use super::matmul::{matmul_compute, transpose};
12
13/// Scaled Dot-Product Attention (GPU-accelerated)
14///
15/// Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
16///
17/// Parameters:
18/// - q: Query matrix (seq_len x d_k, stored flattened)
19/// - k: Key matrix (seq_len x d_k, stored flattened)
20/// - v: Value matrix (seq_len x d_v, stored flattened)
21/// - seq_len: Sequence length
22/// - d_k: Dimension of queries and keys
23/// - d_v: Dimension of values
24///
25/// Returns: Tensor of shape (seq_len x d_v, stored flattened)
26pub fn attention(
27    q: &Tensor,
28    k: &Tensor,
29    v: &Tensor,
30    seq_len: usize,
31    d_k: usize,
32    _k_seq_len: usize, // Kept for API compatibility, assumes same as seq_len
33    d_v: usize,
34) -> Tensor {
35    let scale = (d_k as f32).sqrt();
36
37    // Step 1: Compute Q @ K^T (seq_len x seq_len) using GPU GEMM
38    // Q is (seq_len, d_k), K is (seq_len, d_k), K^T is (d_k, seq_len)
39    // Result: (seq_len, d_k) @ (d_k, seq_len) = (seq_len, seq_len)
40    let q_slice = q.data().as_slice().unwrap_or(&[]);
41    let k_slice = k.data().as_slice().unwrap_or(&[]);
42    let k_t = transpose(k_slice, seq_len, d_k); // K^T: (d_k, seq_len)
43    let mut scores = matmul_compute(q_slice, &k_t, seq_len, d_k, seq_len);
44
45    // Apply scaling
46    for score in &mut scores {
47        *score /= scale;
48    }
49
50    // Step 2: Apply softmax row-wise (CPU for numerical stability)
51    let mut attention_weights = vec![0.0; seq_len * seq_len];
52    for i in 0..seq_len {
53        let row_start = i * seq_len;
54        let row_end = row_start + seq_len;
55        let row = &scores[row_start..row_end];
56
57        // Softmax for numerical stability
58        let max_val = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
59        let exp_vals: Vec<f32> = row.iter().map(|&x| (x - max_val).exp()).collect();
60        let sum_exp: f32 = exp_vals.iter().sum();
61
62        for (j, &exp_val) in exp_vals.iter().enumerate() {
63            attention_weights[row_start + j] = exp_val / sum_exp;
64        }
65    }
66
67    // Step 3: Compute attention_weights @ V (seq_len x d_v) using GPU GEMM
68    // attention_weights is (seq_len, seq_len), V is (seq_len, d_v)
69    // Result: (seq_len, seq_len) @ (seq_len, d_v) = (seq_len, d_v)
70    let v_slice = v.data().as_slice().unwrap_or(&[]);
71    let output_data = matmul_compute(&attention_weights, v_slice, seq_len, seq_len, d_v);
72
73    let requires_grad = q.requires_grad() || k.requires_grad() || v.requires_grad();
74    let mut result = Tensor::new(Array1::from(output_data), requires_grad);
75
76    if requires_grad {
77        let q_clone = q.clone();
78        let k_clone = k.clone();
79        let v_clone = v.clone();
80        let backward_op = Rc::new(AttentionBackward {
81            q: q_clone,
82            k: k_clone,
83            v: v_clone,
84            attention_weights: Array1::from(attention_weights),
85            seq_len,
86            d_k,
87            d_v,
88            scale,
89            result_grad: result.grad_cell(),
90        });
91        result.set_backward_op(backward_op);
92    }
93
94    result
95}
96
97struct AttentionBackward {
98    q: Tensor,
99    k: Tensor,
100    v: Tensor,
101    attention_weights: Array1<f32>,
102    seq_len: usize,
103    d_k: usize,
104    d_v: usize,
105    scale: f32,
106    result_grad: Rc<RefCell<Option<Array1<f32>>>>,
107}
108
109impl BackwardOp for AttentionBackward {
110    fn backward(&self) {
111        if let Some(grad_output) = self.result_grad.borrow().as_ref() {
112            let seq_len = self.seq_len;
113            let d_k = self.d_k;
114            let d_v = self.d_v;
115            let grad_out_slice = grad_output.as_slice().unwrap_or(&[]);
116            let attn_slice = self.attention_weights.as_slice().unwrap_or(&[]);
117
118            // Gradient w.r.t. V: attention_weights^T @ grad_output
119            // attention_weights is (seq_len, seq_len), grad_output is (seq_len, d_v)
120            // attention_weights^T is (seq_len, seq_len)
121            // Result: (seq_len, seq_len) @ (seq_len, d_v) = (seq_len, d_v)
122            if self.v.requires_grad() {
123                let attn_t = transpose(attn_slice, seq_len, seq_len);
124                let grad_v = matmul_compute(&attn_t, grad_out_slice, seq_len, seq_len, d_v);
125                self.v.accumulate_grad(Array1::from(grad_v));
126            }
127
128            // Gradient w.r.t. attention_weights: grad_output @ V^T
129            // grad_output is (seq_len, d_v), V is (seq_len, d_v), V^T is (d_v, seq_len)
130            // Result: (seq_len, d_v) @ (d_v, seq_len) = (seq_len, seq_len)
131            let v_slice = self.v.data().as_slice().unwrap_or(&[]);
132            let v_t = transpose(v_slice, seq_len, d_v);
133            let grad_attention_weights =
134                matmul_compute(grad_out_slice, &v_t, seq_len, d_v, seq_len);
135
136            // Gradient through softmax (row-wise) - must be CPU for numerical stability
137            let mut grad_scores = vec![0.0; seq_len * seq_len];
138            for i in 0..seq_len {
139                let row_start = i * seq_len;
140                for j in 0..seq_len {
141                    let idx = row_start + j;
142                    let p_j = attn_slice[idx];
143
144                    // Softmax gradient: p_j * (grad_j - sum_k(p_k * grad_k))
145                    let mut sum_pk_gradk = 0.0;
146                    for k in 0..seq_len {
147                        let k_idx = row_start + k;
148                        sum_pk_gradk += attn_slice[k_idx] * grad_attention_weights[k_idx];
149                    }
150
151                    grad_scores[idx] = p_j * (grad_attention_weights[idx] - sum_pk_gradk);
152                }
153            }
154
155            // Gradient through scaling
156            for g in &mut grad_scores {
157                *g /= self.scale;
158            }
159
160            // Gradient w.r.t. Q: grad_scaled @ K
161            // grad_scaled is (seq_len, seq_len), K is (seq_len, d_k)
162            // Result: (seq_len, seq_len) @ (seq_len, d_k) = (seq_len, d_k)
163            if self.q.requires_grad() {
164                let k_slice = self.k.data().as_slice().unwrap_or(&[]);
165                let grad_q = matmul_compute(&grad_scores, k_slice, seq_len, seq_len, d_k);
166                self.q.accumulate_grad(Array1::from(grad_q));
167            }
168
169            // Gradient w.r.t. K: grad_scaled^T @ Q
170            // grad_scaled is (seq_len, seq_len), grad_scaled^T is (seq_len, seq_len)
171            // Q is (seq_len, d_k)
172            // Result: (seq_len, seq_len) @ (seq_len, d_k) = (seq_len, d_k)
173            if self.k.requires_grad() {
174                let grad_t = transpose(&grad_scores, seq_len, seq_len);
175                let q_slice = self.q.data().as_slice().unwrap_or(&[]);
176                let grad_k = matmul_compute(&grad_t, q_slice, seq_len, seq_len, d_k);
177                self.k.accumulate_grad(Array1::from(grad_k));
178            }
179
180            // Continue backward through the graph
181            if let Some(op) = self.q.backward_op() {
182                op.backward();
183            }
184            if let Some(op) = self.k.backward_op() {
185                op.backward();
186            }
187            if let Some(op) = self.v.backward_op() {
188                op.backward();
189            }
190        }
191    }
192}
193
194// =========================================================================
195// FALSIFY-ATT: attention-kernel-v1.yaml contract (entrenar attention)
196//
197// Five-Whys (PMAT-354):
198//   Why 1: entrenar had zero attention tests
199//   Why 2: attention was added for GPU GEMM acceleration, tested via model-level e2e
200//   Why 3: no mapping from attention-kernel-v1.yaml to entrenar test names
201//   Why 4: entrenar predates the provable-contracts YAML convention
202//   Why 5: scaled dot-product attention was "obviously correct"
203//
204// References:
205//   - provable-contracts/contracts/attention-kernel-v1.yaml
206//   - Vaswani et al. (2017) "Attention Is All You Need"
207// =========================================================================
208
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// FALSIFY-ATT-001: Weight normalization (indirect) — uniform V → output equals V
    ///
    /// If all V rows are identical [c, c, ...], any convex combination gives [c, c, ...].
    /// This implies the weights summed to 1.0.
    #[test]
    fn falsify_att_001_weight_normalization_via_uniform_v() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 4;
        let v_row = vec![2.0, -1.0, 3.0, 0.5];
        let v_data: Vec<f32> = v_row.iter().copied().cycle().take(seq_len * d_v).collect();

        let q = Tensor::new(
            Array1::from(vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9]),
            false,
        );
        let k = Tensor::new(
            Array1::from(vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9]),
            false,
        );
        let v = Tensor::new(Array1::from(v_data), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_data = output.data();
        let out_slice = out_data.as_slice().expect("contiguous");

        for i in 0..seq_len {
            for d in 0..d_v {
                let diff = (out_slice[i * d_v + d] - v_row[d]).abs();
                assert!(
                    diff < 1e-4,
                    "FALSIFIED ATT-001: output[{i}][{d}] = {}, expected {} (uniform V → weights sum to 1)",
                    out_slice[i * d_v + d],
                    v_row[d]
                );
            }
        }
    }

    /// FALSIFY-ATT-002: Output convexity — output bounded by min/max of V columns
    ///
    /// Contract: min_j(V[j][d]) ≤ output[i][d] ≤ max_j(V[j][d])
    #[test]
    fn falsify_att_002_output_convexity() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 4;
        let v_data = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0, -2.0, 7.0, 3.0, 0.0, -4.0, 6.0];

        let q = Tensor::new(
            Array1::from(vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9]),
            false,
        );
        let k = Tensor::new(
            Array1::from(vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9]),
            false,
        );
        let v = Tensor::new(Array1::from(v_data.clone()), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_data = output.data();
        let out_slice = out_data.as_slice().expect("contiguous");

        for i in 0..seq_len {
            for d in 0..d_v {
                let out_val = out_slice[i * d_v + d];

                let v_col_min =
                    (0..seq_len).map(|j| v_data[j * d_v + d]).fold(f32::INFINITY, f32::min);
                let v_col_max =
                    (0..seq_len).map(|j| v_data[j * d_v + d]).fold(f32::NEG_INFINITY, f32::max);

                assert!(
                    out_val >= v_col_min - 1e-4 && out_val <= v_col_max + 1e-4,
                    "FALSIFIED ATT-002: output[{i}][{d}] = {out_val} outside V column [{v_col_min}, {v_col_max}]"
                );
            }
        }
    }

    /// FALSIFY-ATT-003: Scaling factor — uses 1/√d_k not 1/d_k
    ///
    /// With d_k=1, both scalings are identical (1/√1 = 1/1 = 1).
    /// With d_k=4, 1/√4 = 0.5 but 1/4 = 0.25 — outputs differ.
    /// We verify by comparing every output row against a manual reference.
    #[test]
    fn falsify_att_003_scaling_factor() {
        let seq_len = 2;
        let d_k = 4;
        let d_v = 2;

        let q_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let k_data = vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0];
        let v_data = vec![10.0, 20.0, 30.0, 40.0];

        let q = Tensor::new(Array1::from(q_data), false);
        let k = Tensor::new(Array1::from(k_data), false);
        let v = Tensor::new(Array1::from(v_data.clone()), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_slice = output.data().as_slice().expect("contiguous").to_vec();

        // Manual reference with correct 1/√d_k scaling (scale = 2.0).
        // K[0] = [1,0,0,0], K[1] = [0,0,1,0]
        // Q[0] = [1,0,0,0] → raw scores [1, 0] → scaled [0.5, 0.0]
        // Q[1] = [0,1,0,0] → raw scores [0, 0] → scaled [0.0, 0.0] (uniform weights)
        let scale = (d_k as f32).sqrt();
        let scaled_scores = [[1.0 / scale, 0.0 / scale], [0.0 / scale, 0.0 / scale]];

        for (i, row) in scaled_scores.iter().enumerate() {
            let max = row[0].max(row[1]);
            let e0 = (row[0] - max).exp();
            let e1 = (row[1] - max).exp();
            let (w0, w1) = (e0 / (e0 + e1), e1 / (e0 + e1));
            for d in 0..d_v {
                let reference = w0 * v_data[d] + w1 * v_data[d_v + d];
                let actual = out_slice[i * d_v + d];
                assert!(
                    (actual - reference).abs() < 1e-4,
                    "FALSIFIED ATT-003: output[{i}][{d}] = {actual}, reference = {reference} (1/√d_k scaling)"
                );
            }
        }
    }

    /// FALSIFY-ATT-005: Single position — softmax of single score is 1.0, output = V
    #[test]
    fn falsify_att_005_single_position() {
        let seq_len = 1;
        let d_k = 4;
        let d_v = 4;
        let v_data = vec![7.0, -3.0, 2.5, 11.0];

        let q = Tensor::new(Array1::from(vec![1.0, 0.0, 0.0, 0.0]), false);
        let k = Tensor::new(Array1::from(vec![0.5, 0.5, 0.5, 0.5]), false);
        let v = Tensor::new(Array1::from(v_data.clone()), false);

        let output = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        let out_slice = output.data().as_slice().expect("contiguous").to_vec();

        for (d, (&out_val, &v_val)) in out_slice.iter().zip(v_data.iter()).enumerate() {
            let diff = (out_val - v_val).abs();
            assert!(
                diff < 1e-5,
                "FALSIFIED ATT-005: single position output[{d}] = {out_val}, expected V[{d}] = {v_val}"
            );
        }
    }

    /// ENC-002: Verify attention is bidirectional (no causal mask).
    ///
    /// In causal attention, position 0 cannot attend to position 1+.
    /// In bidirectional attention, every position attends to every position.
    /// We verify by checking that changing a later token affects earlier outputs.
    #[test]
    fn enc_002_attention_is_bidirectional() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 4;

        let q_data = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9];
        let k_data_a = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
        let v_data = vec![10.0, 20.0, 30.0, 40.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        // Run A: original K
        let q_a = Tensor::new(Array1::from(q_data.clone()), false);
        let k_a = Tensor::new(Array1::from(k_data_a.clone()), false);
        let v_a = Tensor::new(Array1::from(v_data.clone()), false);
        let out_a = attention(&q_a, &k_a, &v_a, seq_len, d_k, seq_len, d_v);
        let slice_a = out_a.data().as_slice().expect("contiguous").to_vec();

        // Run B: modify K at position 2 (last token)
        let mut k_data_b = k_data_a;
        k_data_b[8] = 99.0; // K[2][0] = 99.0 (was 0.6)
        let q_b = Tensor::new(Array1::from(q_data), false);
        let k_b = Tensor::new(Array1::from(k_data_b), false);
        let v_b = Tensor::new(Array1::from(v_data), false);
        let out_b = attention(&q_b, &k_b, &v_b, seq_len, d_k, seq_len, d_v);
        let slice_b = out_b.data().as_slice().expect("contiguous").to_vec();

        // Position 0's output MUST change — it attends bidirectionally to position 2
        let diff_pos0: f32 = (0..d_v).map(|d| (slice_a[d] - slice_b[d]).abs()).sum();
        assert!(
            diff_pos0 > 1e-3,
            "ENC-002 FAILED: position 0 output unchanged when K[2] modified \
             (diff={diff_pos0}). Attention has causal mask — encoder requires bidirectional."
        );
    }

    /// Run forward + backward (self-attention) with `grad_out` as the output gradient and
    /// return the gradients accumulated on q, k and v, in that order.
    fn attention_grads(
        inputs: &[Vec<f32>; 3],
        seq_len: usize,
        d_k: usize,
        d_v: usize,
        grad_out: &[f32],
    ) -> [Vec<f32>; 3] {
        let [q, k, v] = inputs.clone().map(|data| Tensor::new(Array1::from(data), true));
        let out = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
        *out.grad_cell().borrow_mut() = Some(Array1::from(grad_out.to_vec()));
        out.backward_op().expect("attention output must carry a backward op").backward();
        [&q, &k, &v].map(|t| {
            let cell = t.grad_cell();
            let grad = cell.borrow().as_ref().expect("input gradient was not accumulated").to_vec();
            grad
        })
    }

    /// FALSIFY-ATT-GRAD: Backward pass agrees with central finite differences
    ///
    /// With L = sum(output * grad_out), dL/dQ, dL/dK and dL/dV are exactly what the backward
    /// op accumulates when grad_out is fed in as the output gradient. Each analytic entry is
    /// compared with (L(x + eps) - L(x - eps)) / (2 * eps).
    #[test]
    fn falsify_att_grad_matches_finite_differences() {
        let seq_len = 3;
        let d_k = 4;
        let d_v = 2;
        let inputs = [
            vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5, 0.4, -0.6, 0.3, 0.9],
            vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9],
            vec![2.0, -1.0, 0.5, 3.0, -1.5, 1.0],
        ];
        let grad_out = vec![0.7, -0.2, 0.1, 0.5, -0.9, 0.3];

        let analytic = attention_grads(&inputs, seq_len, d_k, d_v, &grad_out);

        let loss = |xs: &[Vec<f32>; 3]| -> f32 {
            let [q, k, v] = xs.clone().map(|data| Tensor::new(Array1::from(data), false));
            let out = attention(&q, &k, &v, seq_len, d_k, seq_len, d_v);
            let total: f32 = out.data().iter().zip(&grad_out).map(|(o, g)| o * g).sum();
            total
        };

        let eps = 1e-2_f32;
        for (which, name) in ["q", "k", "v"].iter().enumerate() {
            assert_eq!(
                analytic[which].len(),
                inputs[which].len(),
                "FALSIFIED ATT-GRAD: d{name} has the wrong length"
            );
            for i in 0..inputs[which].len() {
                let mut plus = inputs.clone();
                let mut minus = inputs.clone();
                plus[which][i] += eps;
                minus[which][i] -= eps;
                let numeric = (loss(&plus) - loss(&minus)) / (2.0 * eps);
                let got = analytic[which][i];
                assert!(
                    (got - numeric).abs() < 5e-3 * (1.0 + numeric.abs()),
                    "FALSIFIED ATT-GRAD: d{name}[{i}] = {got}, finite difference = {numeric}"
                );
            }
        }
    }

    mod att_proptest_falsify {
        use super::*;
        use proptest::prelude::*;

        // FALSIFY-ATT-002-prop: Output convexity for random V
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn falsify_att_002_prop_output_convexity(
                seed in 0..1000u32,
            ) {
                let seq = 3;
                let d = 4;

                let q_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                    .collect();
                let k_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
                    .collect();
                let v_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 1.23).sin() * 5.0)
                    .collect();

                let q = Tensor::new(Array1::from(q_data), false);
                let k = Tensor::new(Array1::from(k_data), false);
                let v = Tensor::new(Array1::from(v_data.clone()), false);

                let output = attention(&q, &k, &v, seq, d, seq, d);
                let out_slice = output.data().as_slice().expect("contiguous").to_vec();

                for dim in 0..d {
                    let v_min = (0..seq).map(|j| v_data[j * d + dim]).fold(f32::INFINITY, f32::min);
                    let v_max = (0..seq).map(|j| v_data[j * d + dim]).fold(f32::NEG_INFINITY, f32::max);

                    for i in 0..seq {
                        let val = out_slice[i * d + dim];
                        prop_assert!(
                            val >= v_min - 1e-4 && val <= v_max + 1e-4,
                            "FALSIFIED ATT-002-prop: output[{}][{}] = {} outside V [{}, {}]",
                            i, dim, val, v_min, v_max
                        );
                    }
                }
            }
        }

        // FALSIFY-ATT-001-prop: Uniform V -> output equals V (weights sum to 1)
        proptest! {
            #![proptest_config(ProptestConfig::with_cases(100))]

            #[test]
            fn falsify_att_001_prop_uniform_v(
                seq in 2..=5usize,
                seed in 0..1000u32,
            ) {
                let d = 4;
                let v_row: Vec<f32> = (0..d)
                    .map(|i| ((i as f32 + seed as f32) * 1.23).sin() * 5.0)
                    .collect();
                let v_data: Vec<f32> = v_row.iter().copied().cycle().take(seq * d).collect();

                let q_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.37).sin())
                    .collect();
                let k_data: Vec<f32> = (0..seq * d)
                    .map(|i| ((i as f32 + seed as f32) * 0.73).cos())
                    .collect();

                let q = Tensor::new(Array1::from(q_data), false);
                let k = Tensor::new(Array1::from(k_data), false);
                let v = Tensor::new(Array1::from(v_data), false);

                let output = attention(&q, &k, &v, seq, d, seq, d);
                let out_slice = output.data().as_slice().expect("contiguous").to_vec();

                for i in 0..seq {
                    for dim in 0..d {
                        let diff = (out_slice[i * d + dim] - v_row[dim]).abs();
                        prop_assert!(
                            diff < 1e-4,
                            "FALSIFIED ATT-001-prop: output[{}][{}] = {}, expected {} (uniform V)",
                            i, dim, out_slice[i * d + dim], v_row[dim]
                        );
                    }
                }
            }
        }
    }
}