Skip to main content

any_gpu/ops/
attention.rs

// Unlicense — cochranblock.org
// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
//
// GPU softmax, scaled dot-product attention, and MSE loss.
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9// --- Softmax (two-pass) ---
10
11#[repr(C)]
12#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
13struct SoftmaxParams {
14    rows: u32,
15    cols: u32,
16    _pad: [u32; 2],
17}
18
// Pass 1: one invocation per row computes the row maximum and
// sum(exp(x - max)), stored at stats[2 * row] and stats[2 * row + 1].
// Subtracting the max keeps exp() from overflowing on large inputs.
//
// The row index is flattened from both dispatch dimensions with the same
// `gid.x + gid.y * 65535u * 256u` formula SHADER_SOFTMAX_APPLY uses, so rows
// are still covered when the dispatch spills into the y dimension (using only
// gid.x would recompute low rows and never reach the high ones).
// NOTE(review): the 65535 * 256 stride assumes `super::dispatch_1d` caps x at
// 65535 workgroups of 256 invocations — confirm there.
const SHADER_SOFTMAX_STATS: &str = "
struct P { rows: u32, cols: u32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> stats: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x + gid.y * 65535u * 256u;
    // An empty row has no first element to seed the max with.
    if row >= p.rows || p.cols == 0u { return; }

    let base = row * p.cols;
    var mx: f32 = input[base];
    for (var i: u32 = 1u; i < p.cols; i++) {
        mx = max(mx, input[base + i]);
    }
    var sum: f32 = 0.0;
    for (var i: u32 = 0u; i < p.cols; i++) {
        sum += exp(input[base + i] - mx);
    }
    stats[row * 2u] = mx;
    stats[row * 2u + 1u] = sum;
}
";
43
// Pass 2: one invocation per element computes exp(x - max) / sum, reading its
// row's [max, sum] pair from the stats buffer filled by SHADER_SOFTMAX_STATS.
// Bindings: 0 = params, 1 = input, 2 = stats (read-only here), 3 = output.
//
// The flat element index combines both dispatch dimensions: each step of
// gid.y stands for 65535 * 256 invocations along x.
// NOTE(review): presumably mirrors how `super::dispatch_1d` splits large
// counts across x and y — confirm there.
const SHADER_SOFTMAX_APPLY: &str = "
struct P { rows: u32, cols: u32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> stats: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= p.rows * p.cols { return; }

    let row = idx / p.cols;
    let mx = stats[row * 2u];
    let sum = stats[row * 2u + 1u];
    out[idx] = exp(input[idx] - mx) / sum;
}
";
62
63// --- MSE Loss ---
64
65#[repr(C)]
66#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
67struct ReduceParams {
68    n: u32,
69    _pad: [u32; 3],
70}
71
// Mean squared error in a single invocation (workgroup_size(1), and `mse_loss`
// dispatches exactly one workgroup): serially sums (pred[i] - tgt[i])^2 over
// all p.n elements and writes sum / n to out[0].
// Bindings: 0 = params, 1 = pred, 2 = tgt, 3 = out (one element).
// NOTE(review): the serial f32 loop runs on one GPU thread and its rounding
// error grows with n; a parallel tree reduction would address both.
// p.n == 0 would evaluate 0.0 / 0.0 (NaN), so callers must pass non-empty input.
const SHADER_MSE_SUM: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> pred: array<f32>;
@group(0) @binding(2) var<storage, read> tgt: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(1)
fn main() {
    var sum: f32 = 0.0;
    for (var i: u32 = 0u; i < p.n; i++) {
        let d = pred[i] - tgt[i];
        sum += d * d;
    }
    out[0] = sum / f32(p.n);
}
";
88
89impl GpuDevice {
90    /// Softmax along the last dimension. Input shape: [rows, cols].
91    pub fn softmax(&self, input: &GpuBuffer, rows: u32, cols: u32) -> Result<GpuBuffer> {
92        ensure!(input.len == (rows * cols) as usize);
93
94        let params = SoftmaxParams { rows, cols, _pad: [0; 2] };
95
96        // Pass 1: per-row max and sum
97        let stats = self.alloc((rows * 2) as usize);
98        self.dispatch_shader(
99            SHADER_SOFTMAX_STATS, Some("softmax_stats"),
100            &params, &[input], &stats,
101            super::dispatch_1d(rows),
102        );
103
104        // Pass 2: normalize
105        let total = rows * cols;
106        let out = self.alloc(total as usize);
107
108        let params_buf = self.upload_uniform(&params);
109        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
110            label: Some("softmax_apply"),
111            source: wgpu::ShaderSource::Wgsl(SHADER_SOFTMAX_APPLY.into()),
112        });
113        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
114            label: Some("softmax_apply"),
115            layout: None,
116            module: &shader,
117            entry_point: Some("main"),
118            compilation_options: Default::default(),
119            cache: None,
120        });
121        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
122            label: None,
123            layout: &pipeline.get_bind_group_layout(0),
124            entries: &[
125                wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
126                wgpu::BindGroupEntry { binding: 1, resource: input.buffer.as_entire_binding() },
127                wgpu::BindGroupEntry { binding: 2, resource: stats.buffer.as_entire_binding() },
128                wgpu::BindGroupEntry { binding: 3, resource: out.buffer.as_entire_binding() },
129            ],
130        });
131        let (wx, wy, wz) = super::dispatch_1d(total);
132        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
133        {
134            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
135                label: Some("softmax_apply"),
136                timestamp_writes: None,
137            });
138            pass.set_pipeline(&pipeline);
139            pass.set_bind_group(0, &bind_group, &[]);
140            pass.dispatch_workgroups(wx, wy, wz);
141        }
142        self.queue.submit(Some(encoder.finish()));
143
144        Ok(out)
145    }
146
147    /// Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.
148    /// Q,K,V: [batch_heads, seq_len, d_k]. Returns [batch_heads, seq_len, d_k].
149    pub fn scaled_dot_product_attention(
150        &self,
151        q: &GpuBuffer, k: &GpuBuffer, v: &GpuBuffer,
152        batch_heads: u32, seq_len: u32, d_k: u32,
153    ) -> Result<GpuBuffer> {
154        // 1. K^T: [batch_heads, d_k, seq_len]
155        let kt = self.transpose(k, batch_heads, seq_len, d_k, 1)?;
156
157        // 2. scores = Q @ K^T: [batch_heads, seq_len, seq_len]
158        let scores = self.batch_matmul(q, &kt, batch_heads, seq_len, seq_len, d_k)?;
159
160        // 3. Scale by 1/sqrt(d_k)
161        let scale = 1.0 / (d_k as f32).sqrt();
162        let scaled = self.scale(&scores, scale)?;
163
164        // 4. Softmax over last dim (each row of seq_len)
165        let attn = self.softmax(&scaled, batch_heads * seq_len, seq_len)?;
166
167        // 5. attn @ V: [batch_heads, seq_len, d_k]
168        self.batch_matmul(&attn, v, batch_heads, seq_len, d_k, seq_len)
169    }
170
171    /// MSE loss: mean((pred - target)^2). Returns a 1-element buffer.
172    pub fn mse_loss(&self, pred: &GpuBuffer, target: &GpuBuffer) -> Result<GpuBuffer> {
173        ensure!(pred.len == target.len, "mse: length mismatch");
174        let out = self.alloc(1);
175        let params = ReduceParams { n: pred.len as u32, _pad: [0; 3] };
176        self.dispatch_shader(
177            SHADER_MSE_SUM, Some("mse"),
178            &params, &[pred, target], &out,
179            (1, 1, 1),
180        );
181        Ok(out)
182    }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// Shared GPU device used by every test in this module.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// Uploads `data`, runs GPU softmax over it as `[rows, cols]`, and reads back the result.
    fn gpu_softmax(data: &[f32], rows: u32, cols: u32) -> Vec<f32> {
        let d = dev();
        let out = d.softmax(&d.upload(data), rows, cols).unwrap();
        d.read(&out).unwrap()
    }

    /// Uploads Q, K and V (in that order), runs GPU attention, and reads back the result.
    fn gpu_attention(q: &[f32], k: &[f32], v: &[f32], heads: u32, seq: u32, dk: u32) -> Vec<f32> {
        let d = dev();
        let (qb, kb, vb) = (d.upload(q), d.upload(k), d.upload(v));
        let out = d.scaled_dot_product_attention(&qb, &kb, &vb, heads, seq, dk).unwrap();
        d.read(&out).unwrap()
    }

    /// Uploads prediction and target (in that order), runs GPU MSE, and reads back the result.
    fn gpu_mse(pred: &[f32], target: &[f32]) -> Vec<f32> {
        let d = dev();
        let (pb, tb) = (d.upload(pred), d.upload(target));
        let out = d.mse_loss(&pb, &tb).unwrap();
        d.read(&out).unwrap()
    }

    /// Host reference: max-subtracted softmax of each `cols`-wide row of `input`.
    fn cpu_softmax(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; input.len()];
        for r in 0..rows {
            let span = r * cols..(r + 1) * cols;
            let src = &input[span.clone()];
            let peak = src.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = src.iter().map(|&x| (x - peak).exp()).collect();
            let total: f32 = exps.iter().sum();
            for (dst, e) in out[span].iter_mut().zip(&exps) {
                *dst = e / total;
            }
        }
        out
    }

    /// Host reference: single-head attention for `[seq, dk]` Q, K and V.
    fn cpu_attention(q: &[f32], k: &[f32], v: &[f32], seq: usize, dk: usize) -> Vec<f32> {
        /// Dot product accumulated left to right from zero.
        fn dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b).fold(0.0, |acc, (x, y)| acc + x * y)
        }

        let scale = 1.0 / (dk as f32).sqrt();

        // scores = Q @ K^T * scale, laid out as [seq, seq].
        let mut scores = Vec::with_capacity(seq * seq);
        for i in 0..seq {
            let qi = &q[i * dk..(i + 1) * dk];
            for j in 0..seq {
                scores.push(dot(qi, &k[j * dk..(j + 1) * dk]) * scale);
            }
        }
        let attn = cpu_softmax(&scores, seq, seq);

        // out = attn @ V, laid out as [seq, dk].
        let mut out = Vec::with_capacity(seq * dk);
        for i in 0..seq {
            let weights = &attn[i * seq..(i + 1) * seq];
            for d in 0..dk {
                let acc = weights
                    .iter()
                    .enumerate()
                    .fold(0.0, |s, (j, &w)| s + w * v[j * dk + d]);
                out.push(acc);
            }
        }
        out
    }

    #[test]
    fn test_softmax_vs_cpu() {
        let data = [1.0, 2.0, 3.0, -1.0, 0.0, 1.0, 5.0, 5.0, 5.0];
        let want = cpu_softmax(&data, 3, 3);
        let got = gpu_softmax(&data, 3, 3);
        assert_approx(&got, &want, 1e-4);
        // Each row must form a probability distribution.
        for (r, row) in got.chunks(3).take(3).enumerate() {
            let sum: f32 = row.iter().sum();
            assert!((sum - 1.0).abs() < 1e-4, "row {r} sum = {sum}");
        }
    }

    #[test]
    fn test_softmax_large_values() {
        // Numerical stability: exp(1000) overflows f32 unless the max is subtracted first.
        let data = [1000.0, 1001.0, 1002.0];
        let want = cpu_softmax(&data, 1, 3);
        let got = gpu_softmax(&data, 1, 3);
        assert_approx(&got, &want, 1e-4);
        let sum: f32 = got.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "sum = {sum}");
    }

    #[test]
    fn test_softmax_single_element() {
        let got = gpu_softmax(&[42.0], 1, 1);
        assert_approx(&got, &[1.0], 1e-5);
    }

    #[test]
    fn test_attention_vs_cpu() {
        // 1 head, seq=3, d_k=4, checked element-wise against the CPU reference.
        let ramp = |step: f32, offset: f32| -> Vec<f32> {
            (0..12).map(|i| i as f32 * step + offset).collect()
        };
        let q = ramp(0.1, -0.3);
        let k = ramp(0.05, 0.1);
        let v = ramp(0.2, -0.5);
        let want = cpu_attention(&q, &k, &v, 3, 4);
        let got = gpu_attention(&q, &k, &v, 1, 3, 4);
        assert_approx(&got, &want, 1e-3);
    }

    #[test]
    fn test_attention_uniform_qk() {
        // Identical Q and K rows give equal scores, so attention is uniform
        // and every output row is the mean of V's rows.
        let q = vec![1.0, 1.0, 1.0, 1.0]; // seq=2, dk=2
        let k = q.clone();
        let v = vec![0.0, 10.0, 20.0, 30.0]; // seq=2, dk=2
        let want = cpu_attention(&q, &k, &v, 2, 2);
        let got = gpu_attention(&q, &k, &v, 1, 2, 2);
        assert_approx(&got, &want, 1e-3);
    }

    #[test]
    fn test_mse_loss() {
        // Every element is off by 0.5, so MSE = 0.5^2 = 0.25.
        let got = gpu_mse(&[1.0, 2.0, 3.0], &[1.5, 2.5, 3.5]);
        assert_approx(&got, &[0.25], 1e-5);
    }

    #[test]
    fn test_mse_loss_zero() {
        let shared = dev().upload(&[1.0, 2.0, 3.0]);
        let got = dev().read(&dev().mse_loss(&shared, &shared).unwrap()).unwrap();
        assert_approx(&got, &[0.0], 1e-6);
    }

    #[test]
    fn test_mse_loss_known_value() {
        // (1 + 4 + 9) / 3
        let got = gpu_mse(&[0.0, 0.0, 0.0], &[1.0, 2.0, 3.0]);
        assert_approx(&got, &[14.0 / 3.0], 1e-5);
    }

    #[test]
    fn test_softmax_size_mismatch() {
        // 3 elements uploaded, but a [2, 3] view needs 6.
        let short = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().softmax(&short, 2, 3).is_err());
    }

    #[test]
    fn test_mse_loss_length_mismatch() {
        let two = dev().upload(&[1.0, 2.0]);
        let three = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().mse_loss(&two, &three).is_err());
    }

    #[test]
    fn test_mse_loss_single_element() {
        // (5 - 3)^2 / 1 = 4
        let got = gpu_mse(&[5.0], &[3.0]);
        assert_approx(&got, &[4.0], 1e-5);
    }

    #[test]
    fn test_mse_loss_negative_values() {
        // ((-1 - 1)^2 + (-2 - 2)^2) / 2 = (4 + 16) / 2 = 10
        let got = gpu_mse(&[-1.0, -2.0], &[1.0, 2.0]);
        assert_approx(&got, &[10.0], 1e-5);
    }
}