Skip to main content

provable_contracts/kernels/
attention.rs

//! Scaled dot-product attention kernel.
//!
//! Matches `attention-kernel-v1.yaml`.
//!
//! `Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V`
//!
//! The kernel is exposed through three backends:
//! - `fn attention_scalar(...)` -- pure Rust scalar reference (ground truth)
//! - `unsafe fn attention_avx2(...)` -- AVX2 entry point (currently forwards to the scalar reference)
//! - `fn attention_ptx() -> &'static str` -- PTX assembly source string
10
11use super::ops;
12
13// ────────────────────────────────────────────────────────────────────────────
14// Scalar implementation
15// ────────────────────────────────────────────────────────────────────────────
16
17/// Scaled dot-product attention (scalar reference).
18///
19/// Q is n x d_k, K is m x d_k, V is m x d_v, output is n x d_v.
20///
21/// Step 1: scores = Q * K^T / sqrt(d_k)  -- n x m matrix
22/// Step 2: softmax each row of scores
23/// Step 3: output = scores * V            -- n x d_v matrix
24///
25/// # Panics
26/// Panics if dimensions do not match expected sizes.
27pub fn attention_scalar(
28    q: &[f32],
29    k: &[f32],
30    v: &[f32],
31    n: usize,
32    m: usize,
33    d_k: usize,
34    d_v: usize,
35    output: &mut [f32],
36) {
37    assert_eq!(
38        q.len(),
39        n * d_k,
40        "Q dimension mismatch: expected {} got {}",
41        n * d_k,
42        q.len()
43    );
44    assert_eq!(
45        k.len(),
46        m * d_k,
47        "K dimension mismatch: expected {} got {}",
48        m * d_k,
49        k.len()
50    );
51    assert_eq!(
52        v.len(),
53        m * d_v,
54        "V dimension mismatch: expected {} got {}",
55        m * d_v,
56        v.len()
57    );
58    assert_eq!(
59        output.len(),
60        n * d_v,
61        "output dimension mismatch: expected {} got {}",
62        n * d_v,
63        output.len()
64    );
65
66    // Step 1: Compute scores = Q * K^T / sqrt(d_k), shape n x m
67    let mut scores = vec![0.0f32; n * m];
68    ops::score_matrix(q, k, n, m, d_k, &mut scores);
69
70    // Step 2: Softmax each row
71    ops::softmax_rows(&mut scores, n, m);
72
73    // Step 3: output = scores * V, shape n x d_v
74    ops::matmul_sv(&scores, v, n, m, d_v, output);
75}
76
77// ────────────────────────────────────────────────────────────────────────────
78// AVX2 implementation
79// ────────────────────────────────────────────────────────────────────────────
80
81/// AVX2 scaled dot-product attention -- delegates to scalar.
82///
83/// Attention is a composition of matmul and softmax; the scalar implementation
84/// is already efficient for the composed operation.
85///
86/// # Safety
87/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
88///
89/// # Panics
90/// Panics if dimensions do not match expected sizes.
91#[cfg(target_arch = "x86_64")]
92#[target_feature(enable = "avx2")]
93pub unsafe fn attention_avx2(
94    q: &[f32],
95    k: &[f32],
96    v: &[f32],
97    n: usize,
98    m: usize,
99    d_k: usize,
100    d_v: usize,
101    output: &mut [f32],
102) {
103    attention_scalar(q, k, v, n, m, d_k, d_v, output);
104}
105
106include!("attention_ptx.rs");
107
// ────────────────────────────────────────────────────────────────────────────
// Tests
// ────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::super::ops::sequential_floats;
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // ── Single query, single key ────────────────────────────────────────

    #[test]
    fn test_attention_single_query_single_key() {
        // n=1 query, m=1 key: softmax of single score = 1.0, output = V
        let d_k = 4;
        let d_v = 3;
        let q = vec![1.0, 0.0, 1.0, 0.0];
        let k = vec![1.0, 0.0, 1.0, 0.0];
        let v = vec![2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; d_v];

        attention_scalar(&q, &k, &v, 1, 1, d_k, d_v, &mut output);

        // softmax of a single element = 1.0, so output = 1.0 * V
        assert_ulp_eq(&output, &v, 0);
    }

    // ── Uniform attention ───────────────────────────────────────────────

    #[test]
    fn test_attention_uniform_scores() {
        // When all scores are equal, softmax gives uniform weights = 1/m.
        // Output should be the mean of V rows.
        let n = 1;
        let m = 3;
        let d_k = 2;
        let d_v = 2;

        // Q and K arranged so all dot products are equal
        let q = vec![1.0, 0.0];
        let k = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0]; // all same K row
        let v = vec![3.0, 6.0, 6.0, 9.0, 9.0, 12.0]; // V rows
        let mut output = vec![0.0f32; d_v];

        attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

        // Mean of V rows: [(3+6+9)/3, (6+9+12)/3] = [6.0, 9.0]
        let expected = [6.0, 9.0];
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "expected ~{b}, got {a}");
        }
    }

    // ── Known 2-query, 2-key attention ──────────────────────────────────

    #[test]
    fn test_attention_two_queries_two_keys() {
        let n = 2;
        let m = 2;
        let d_k = 2;
        let d_v = 2;

        // Q = [[1,0],[0,1]], K = [[1,0],[0,1]]
        // QK^T = I, scaled by 1/sqrt(2): scores = [[s, 0], [0, s]], s = 1/sqrt(2).
        // Each softmax row puts weight w = e^s / (e^s + 1) on the matching key
        // and 1 - w on the other one.
        let q = vec![1.0, 0.0, 0.0, 1.0];
        let k = vec![1.0, 0.0, 0.0, 1.0];
        let v = vec![10.0, 20.0, 30.0, 40.0];
        let mut output = vec![0.0f32; n * d_v];

        attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

        // Expected values derived from first principles rather than hard-coded,
        // so every output element (not just two of them) is pinned.
        let e = (1.0f32 / 2.0f32.sqrt()).exp();
        let w = e / (e + 1.0);
        // Row 0 = w * V[0] + (1 - w) * V[1]; row 1 = (1 - w) * V[0] + w * V[1].
        let expected = [
            w * 10.0 + (1.0 - w) * 30.0,
            w * 20.0 + (1.0 - w) * 40.0,
            (1.0 - w) * 10.0 + w * 30.0,
            (1.0 - w) * 20.0 + w * 40.0,
        ];
        for (idx, (got, want)) in output.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-4,
                "output[{idx}]: expected ~{want}, got {got}"
            );
        }

        // First query attends more to first key, second to second key.
        assert!(
            output[0] < 20.0,
            "first query, first dim should lean toward V[0]"
        );
        assert!(
            output[2] > 20.0,
            "second query, first dim should lean toward V[1]"
        );
    }

    // ── Dimension assertions ────────────────────────────────────────────

    #[test]
    #[should_panic(expected = "Q dimension mismatch")]
    fn test_attention_bad_q_dim() {
        let mut output = vec![0.0f32; 2];
        attention_scalar(&[1.0], &[1.0, 2.0], &[1.0, 2.0], 1, 1, 2, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "K dimension mismatch")]
    fn test_attention_bad_k_dim() {
        let mut output = vec![0.0f32; 2];
        attention_scalar(&[1.0, 2.0], &[1.0], &[1.0, 2.0], 1, 1, 2, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "V dimension mismatch")]
    fn test_attention_bad_v_dim() {
        let mut output = vec![0.0f32; 2];
        attention_scalar(&[1.0, 2.0], &[1.0, 2.0], &[1.0], 1, 1, 2, 2, &mut output);
    }

    #[test]
    #[should_panic(expected = "output dimension mismatch")]
    fn test_attention_bad_output_dim() {
        // Q, K, V are all well-formed; only the output buffer is too short.
        let mut output = vec![0.0f32; 1];
        attention_scalar(&[1.0, 2.0], &[1.0, 2.0], &[1.0, 2.0], 1, 1, 2, 2, &mut output);
    }

    // ── Property-based tests ────────────────────────────────────────────

    proptest! {
        #[test]
        fn prop_attention_output_bounded(
            n in 1usize..4,
            m in 1usize..4,
            d_k in 1usize..4,
            d_v in 1usize..4,
        ) {
            let q = sequential_floats(n * d_k, 0.1);
            let k = sequential_floats(m * d_k, 0.1);
            let v = sequential_floats(m * d_v, 0.1);
            let mut output = vec![0.0f32; n * d_v];

            attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

            // Output is a convex combination of V rows, so each output element
            // must lie between the min and max of the corresponding V column.
            for j in 0..d_v {
                let v_col_min = (0..m).map(|r| v[r * d_v + j]).fold(f32::INFINITY, f32::min);
                let v_col_max = (0..m).map(|r| v[r * d_v + j]).fold(f32::NEG_INFINITY, f32::max);
                for i in 0..n {
                    let val = output[i * d_v + j];
                    prop_assert!(
                        val >= v_col_min - 1e-5 && val <= v_col_max + 1e-5,
                        "output[{i},{j}] = {val} not in V column range [{v_col_min}, {v_col_max}]"
                    );
                }
            }
        }

        #[test]
        fn prop_attention_softmax_rows_sum_to_one(
            n in 1usize..3,
            m in 1usize..5,
            d_k in 1usize..4,
        ) {
            let d_v = 1; // use d_v=1 so output = softmax weights * V column
            let q = sequential_floats(n * d_k, 0.1);
            let k = sequential_floats(m * d_k, 0.1);
            // V = all ones => output[i] = sum of softmax weights = 1.0
            let v = vec![1.0f32; m * d_v];
            let mut output = vec![0.0f32; n * d_v];

            attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

            for i in 0..n {
                prop_assert!(
                    (output[i] - 1.0).abs() < 1e-5,
                    "softmax row {i} should sum to 1.0, got {}",
                    output[i]
                );
            }
        }
    }

    // ── AVX2 parity test ────────────────────────────────────────────────

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_attention_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let n = 3;
        let m = 4;
        let d_k = 5;
        let d_v = 6;
        let q = sequential_floats(n * d_k, 0.1);
        let k = sequential_floats(m * d_k, 0.2);
        let v = sequential_floats(m * d_v, 0.15);

        let mut scalar_out = vec![0.0f32; n * d_v];
        let mut avx2_out = vec![0.0f32; n * d_v];

        attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime at the top of this test.
        unsafe { attention_avx2(&q, &k, &v, n, m, d_k, d_v, &mut avx2_out) };

        // Composed operations allow up to 8 ULP
        assert_ulp_eq(&scalar_out, &avx2_out, 8);
    }

    // ── PTX structural tests ────────────────────────────────────────────

    #[test]
    fn test_attention_ptx_structure() {
        let ptx = attention_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(
            ptx.contains(".entry attention_kernel"),
            "missing entry point"
        );
        assert!(ptx.contains("ret;"), "missing ret instruction");
        assert!(ptx.contains(".shared"), "missing shared memory declaration");
        assert!(ptx.contains("bar.sync"), "missing barrier synchronization");
        assert!(ptx.contains("ex2.approx.f32"), "missing exp approximation");
        assert!(ptx.contains("fma.rn.f32"), "missing FMA instruction");
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    #[test]
    fn test_attention_ptx_nonempty() {
        assert!(!attention_ptx().is_empty());
    }

    // ── Softmax helper test ─────────────────────────────────────────────

    #[test]
    fn test_softmax_row_uniform() {
        let mut row = vec![1.0, 1.0, 1.0, 1.0];
        ops::softmax_row(&mut row);
        for &v in &row {
            assert!(
                (v - 0.25).abs() < 1e-6,
                "uniform input should give 0.25, got {v}"
            );
        }
    }

    #[test]
    fn test_softmax_row_single() {
        let mut row = vec![42.0];
        ops::softmax_row(&mut row);
        assert!(
            (row[0] - 1.0).abs() < 1e-6,
            "single element softmax should be 1.0"
        );
    }

    #[test]
    fn test_softmax_row_sums_to_one() {
        let mut row = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        ops::softmax_row(&mut row);
        let sum: f32 = row.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "softmax should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn test_softmax_row_monotonic() {
        let mut row = vec![1.0, 2.0, 3.0];
        ops::softmax_row(&mut row);
        assert!(row[0] < row[1], "softmax should preserve order");
        assert!(row[1] < row[2], "softmax should preserve order");
    }
}