1use super::ops;
12
13pub fn attention_scalar(
28 q: &[f32],
29 k: &[f32],
30 v: &[f32],
31 n: usize,
32 m: usize,
33 d_k: usize,
34 d_v: usize,
35 output: &mut [f32],
36) {
37 assert_eq!(
38 q.len(),
39 n * d_k,
40 "Q dimension mismatch: expected {} got {}",
41 n * d_k,
42 q.len()
43 );
44 assert_eq!(
45 k.len(),
46 m * d_k,
47 "K dimension mismatch: expected {} got {}",
48 m * d_k,
49 k.len()
50 );
51 assert_eq!(
52 v.len(),
53 m * d_v,
54 "V dimension mismatch: expected {} got {}",
55 m * d_v,
56 v.len()
57 );
58 assert_eq!(
59 output.len(),
60 n * d_v,
61 "output dimension mismatch: expected {} got {}",
62 n * d_v,
63 output.len()
64 );
65
66 let mut scores = vec![0.0f32; n * m];
68 ops::score_matrix(q, k, n, m, d_k, &mut scores);
69
70 ops::softmax_rows(&mut scores, n, m);
72
73 ops::matmul_sv(&scores, v, n, m, d_v, output);
75}
76
77#[cfg(target_arch = "x86_64")]
92#[target_feature(enable = "avx2")]
93pub unsafe fn attention_avx2(
94 q: &[f32],
95 k: &[f32],
96 v: &[f32],
97 n: usize,
98 m: usize,
99 d_k: usize,
100 d_v: usize,
101 output: &mut [f32],
102) {
103 attention_scalar(q, k, v, n, m, d_k, d_v, output);
104}
105
106include!("attention_ptx.rs");
107
#[cfg(test)]
mod tests {
    use super::super::ops::sequential_floats;
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    /// One query against one key: the output is expected to reproduce the
    /// single V row exactly.
    #[test]
    fn test_attention_single_query_single_key() {
        let d_k = 4;
        let d_v = 3;
        let q = vec![1.0, 0.0, 1.0, 0.0];
        let k = vec![1.0, 0.0, 1.0, 0.0];
        let v = vec![2.0, 3.0, 4.0];
        let mut output = vec![0.0f32; d_v];

        attention_scalar(&q, &k, &v, 1, 1, d_k, d_v, &mut output);

        // Zero ULP tolerance: the result must be bit-identical to `v`.
        assert_ulp_eq(&output, &v, 0);
    }

    /// Identical keys give identical scores, so the output should be the
    /// plain average of the V rows.
    #[test]
    fn test_attention_uniform_scores() {
        let n = 1;
        let m = 3;
        let d_k = 2;
        let d_v = 2;

        let q = vec![1.0, 0.0];
        let k = vec![1.0, 0.0, 1.0, 0.0, 1.0, 0.0];
        let v = vec![3.0, 6.0, 6.0, 9.0, 9.0, 12.0];
        let mut output = vec![0.0f32; d_v];

        attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

        // Column means of V: (3 + 6 + 9) / 3 and (6 + 9 + 12) / 3.
        let expected = [6.0, 9.0];
        for (a, b) in output.iter().zip(expected.iter()) {
            assert!((a - b).abs() < 1e-5, "expected ~{b}, got {a}");
        }
    }

    /// Each query matches exactly one key, so each output row should lean
    /// toward the V row belonging to that key.
    #[test]
    fn test_attention_two_queries_two_keys() {
        let n = 2;
        let m = 2;
        let d_k = 2;
        let d_v = 2;

        let q = vec![1.0, 0.0, 0.0, 1.0];
        let k = vec![1.0, 0.0, 0.0, 1.0];
        let v = vec![10.0, 20.0, 30.0, 40.0];
        let mut output = vec![0.0f32; n * d_v];

        attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

        // 20.0 is the midpoint between the first components of the two V
        // rows (10.0 and 30.0); each query must land on its own side of it.
        assert!(
            output[0] < 20.0,
            "first query, first dim should lean toward V[0]"
        );
        assert!(
            output[2] > 20.0,
            "second query, first dim should lean toward V[1]"
        );
    }

    /// `q` is one element short of `n * d_k`.
    #[test]
    #[should_panic(expected = "Q dimension mismatch")]
    fn test_attention_bad_q_dim() {
        let mut output = vec![0.0f32; 2];
        attention_scalar(&[1.0], &[1.0, 2.0], &[1.0, 2.0], 1, 1, 2, 2, &mut output);
    }

    /// `k` is one element short of `m * d_k`.
    #[test]
    #[should_panic(expected = "K dimension mismatch")]
    fn test_attention_bad_k_dim() {
        let mut output = vec![0.0f32; 2];
        attention_scalar(&[1.0, 2.0], &[1.0], &[1.0, 2.0], 1, 1, 2, 2, &mut output);
    }

    /// `v` is one element short of `m * d_v`.
    #[test]
    #[should_panic(expected = "V dimension mismatch")]
    fn test_attention_bad_v_dim() {
        let mut output = vec![0.0f32; 2];
        attention_scalar(&[1.0, 2.0], &[1.0, 2.0], &[1.0], 1, 1, 2, 2, &mut output);
    }

    /// `output` is one element short of `n * d_v`; all inputs are valid so
    /// the output check is the one that fires.
    #[test]
    #[should_panic(expected = "output dimension mismatch")]
    fn test_attention_bad_output_dim() {
        let mut output = vec![0.0f32; 1];
        attention_scalar(
            &[1.0, 2.0],
            &[1.0, 2.0],
            &[1.0, 2.0],
            1,
            1,
            2,
            2,
            &mut output,
        );
    }

    proptest! {
        // Every output element must stay inside the min/max range of the
        // corresponding V column (with a small float tolerance).
        #[test]
        fn prop_attention_output_bounded(
            n in 1usize..4,
            m in 1usize..4,
            d_k in 1usize..4,
            d_v in 1usize..4,
        ) {
            let q = sequential_floats(n * d_k, 0.1);
            let k = sequential_floats(m * d_k, 0.1);
            let v = sequential_floats(m * d_v, 0.1);
            let mut output = vec![0.0f32; n * d_v];

            attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

            for j in 0..d_v {
                let v_col_min = (0..m).map(|r| v[r * d_v + j]).fold(f32::INFINITY, f32::min);
                let v_col_max = (0..m).map(|r| v[r * d_v + j]).fold(f32::NEG_INFINITY, f32::max);
                for i in 0..n {
                    let val = output[i * d_v + j];
                    prop_assert!(
                        val >= v_col_min - 1e-5 && val <= v_col_max + 1e-5,
                        "output[{i},{j}] = {val} not in V column range [{v_col_min}, {v_col_max}]"
                    );
                }
            }
        }

        // With V made of ones and `d_v == 1`, each output element equals the
        // sum of the weights applied to that query's row, which must be 1.
        #[test]
        fn prop_attention_softmax_rows_sum_to_one(
            n in 1usize..3,
            m in 1usize..5,
            d_k in 1usize..4,
        ) {
            let d_v = 1;
            let q = sequential_floats(n * d_k, 0.1);
            let k = sequential_floats(m * d_k, 0.1);
            let v = vec![1.0f32; m * d_v];
            let mut output = vec![0.0f32; n * d_v];

            attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut output);

            for i in 0..n {
                prop_assert!(
                    (output[i] - 1.0).abs() < 1e-5,
                    "softmax row {i} should sum to 1.0, got {}",
                    output[i]
                );
            }
        }
    }

    /// The AVX2 entry point must agree with the scalar path within 8 ULPs.
    /// The test is a no-op on CPUs without AVX2.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_attention_avx2_parity() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let n = 3;
        let m = 4;
        let d_k = 5;
        let d_v = 6;
        let q = sequential_floats(n * d_k, 0.1);
        let k = sequential_floats(m * d_k, 0.2);
        let v = sequential_floats(m * d_v, 0.15);

        let mut scalar_out = vec![0.0f32; n * d_v];
        let mut avx2_out = vec![0.0f32; n * d_v];

        attention_scalar(&q, &k, &v, n, m, d_k, d_v, &mut scalar_out);
        // SAFETY: the `is_x86_feature_detected!("avx2")` guard above returned
        // early unless the running CPU supports AVX2.
        unsafe { attention_avx2(&q, &k, &v, n, m, d_k, d_v, &mut avx2_out) };

        assert_ulp_eq(&scalar_out, &avx2_out, 8);
    }

    /// Structural smoke test of the PTX text returned by `attention_ptx`:
    /// required directives and instructions are present and braces balance.
    #[test]
    fn test_attention_ptx_structure() {
        let ptx = attention_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
        assert!(
            ptx.contains(".entry attention_kernel"),
            "missing entry point"
        );
        assert!(ptx.contains("ret;"), "missing ret instruction");
        assert!(ptx.contains(".shared"), "missing shared memory declaration");
        assert!(ptx.contains("bar.sync"), "missing barrier synchronization");
        assert!(ptx.contains("ex2.approx.f32"), "missing exp approximation");
        assert!(ptx.contains("fma.rn.f32"), "missing FMA instruction");
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }

    /// The PTX source must not be empty.
    #[test]
    fn test_attention_ptx_nonempty() {
        assert!(!attention_ptx().is_empty());
    }

    /// Four equal inputs must each map to 1/4.
    #[test]
    fn test_softmax_row_uniform() {
        let mut row = vec![1.0, 1.0, 1.0, 1.0];
        ops::softmax_row(&mut row);
        for &v in &row {
            assert!(
                (v - 0.25).abs() < 1e-6,
                "uniform input should give 0.25, got {v}"
            );
        }
    }

    /// A single-element row must map to 1.0 regardless of its value.
    #[test]
    fn test_softmax_row_single() {
        let mut row = vec![42.0];
        ops::softmax_row(&mut row);
        assert!(
            (row[0] - 1.0).abs() < 1e-6,
            "single element softmax should be 1.0"
        );
    }

    /// The transformed row must sum to 1.0.
    #[test]
    fn test_softmax_row_sums_to_one() {
        let mut row = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        ops::softmax_row(&mut row);
        let sum: f32 = row.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-6,
            "softmax should sum to 1.0, got {sum}"
        );
    }

    /// Strictly increasing inputs must stay strictly increasing.
    #[test]
    fn test_softmax_row_monotonic() {
        let mut row = vec![1.0, 2.0, 3.0];
        ops::softmax_row(&mut row);
        assert!(row[0] < row[1], "softmax should preserve order");
        assert!(row[1] < row[2], "softmax should preserve order");
    }
}