1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9#[repr(C)]
12#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
13struct SoftmaxParams {
14 rows: u32,
15 cols: u32,
16 _pad: [u32; 2],
17}
18
/// Pass 1 of the two-pass softmax.
///
/// One invocation per row computes the row maximum and `sum(exp(x - max))`
/// and stores them at `stats[2*row]` and `stats[2*row + 1]`.
///
/// The row index folds in `gid.y` exactly like `SHADER_SOFTMAX_APPLY` does
/// (`gid.x + gid.y * 65535 * 256`). Without that term, rows past the first
/// `65535 * 256` would be skipped once the 1-D dispatch spills into the y
/// dimension. NOTE(review): this assumes `dispatch_1d` splits in slices of
/// 65535 workgroups, as the apply shader's indexing implies. Confirm there.
const SHADER_SOFTMAX_STATS: &str = "
struct P { rows: u32, cols: u32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> stats: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Same flattening as the apply pass so large row counts are covered.
    let row = gid.x + gid.y * 65535u * 256u;
    if row >= p.rows { return; }

    let base = row * p.cols;
    var mx: f32 = input[base];
    for (var i: u32 = 1u; i < p.cols; i++) {
        mx = max(mx, input[base + i]);
    }
    // Shifting by the row max keeps exp() from overflowing on large logits.
    var sum: f32 = 0.0;
    for (var i: u32 = 0u; i < p.cols; i++) {
        sum += exp(input[base + i] - mx);
    }
    stats[row * 2u] = mx;
    stats[row * 2u + 1u] = sum;
}
";
43
/// Pass 2 of the two-pass softmax.
///
/// One invocation per element writes `exp(input[idx] - max) / sum`. It uses the
/// `[max, sum]` pair stored for the element's row at `stats[2*row]` and
/// `stats[2*row + 1]` by the preceding `SHADER_SOFTMAX_STATS` dispatch.
///
/// The flat element index is `gid.x + gid.y * 65535 * 256`. This presumably
/// matches how `super::dispatch_1d` spreads large counts over a second grid
/// dimension (65535 workgroups of 256 per y step). TODO confirm against
/// `dispatch_1d`.
/// NOTE(review): `p.rows * p.cols` is evaluated in wrapping u32 arithmetic, so
/// the host must ensure the element count fits in a u32.
const SHADER_SOFTMAX_APPLY: &str = "
struct P { rows: u32, cols: u32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read> stats: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    if idx >= p.rows * p.cols { return; }

    let row = idx / p.cols;
    let mx = stats[row * 2u];
    let sum = stats[row * 2u + 1u];
    out[idx] = exp(input[idx] - mx) / sum;
}
";
62
63#[repr(C)]
66#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
67struct ReduceParams {
68 n: u32,
69 _pad: [u32; 3],
70}
71
/// Mean-squared-error reduction over `p.n` elements of `pred` and `tgt`.
///
/// Writes `sum((pred - tgt)^2) / n` to `out[0]`. The shader is meant to be
/// dispatched as a single 256-invocation workgroup, i.e. `(1, 1, 1)`:
///
/// 1. Each invocation accumulates the squared differences of every 256th
///    element.
/// 2. A shared-memory tree reduction combines the 256 partial sums.
///
/// Compared with one invocation looping over the whole buffer, this spreads the
/// work across the workgroup. It also replaces a single long sequential f32
/// accumulation with many short ones, which reduces rounding error for large
/// `n`.
const SHADER_MSE_SUM: &str = "
struct P { n: u32, _p0: u32, _p1: u32, _p2: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> pred: array<f32>;
@group(0) @binding(2) var<storage, read> tgt: array<f32>;
@group(0) @binding(3) var<storage, read_write> out: array<f32>;

var<workgroup> partials: array<f32, 256>;

@compute @workgroup_size(256)
fn main(@builtin(local_invocation_index) lid: u32) {
    // Strided per-invocation partial sum of squared differences.
    var acc: f32 = 0.0;
    for (var i: u32 = lid; i < p.n; i += 256u) {
        let d = pred[i] - tgt[i];
        acc += d * d;
    }
    partials[lid] = acc;
    workgroupBarrier();

    // Tree reduction. The loop bounds are constants, so every barrier is
    // reached in uniform control flow.
    for (var offset: u32 = 128u; offset > 0u; offset = offset / 2u) {
        if lid < offset {
            partials[lid] += partials[lid + offset];
        }
        workgroupBarrier();
    }

    if lid == 0u {
        out[0] = partials[0] / f32(p.n);
    }
}
";
88
89impl GpuDevice {
90 pub fn softmax(&self, input: &GpuBuffer, rows: u32, cols: u32) -> Result<GpuBuffer> {
92 ensure!(input.len == (rows * cols) as usize);
93
94 let params = SoftmaxParams { rows, cols, _pad: [0; 2] };
95
96 let stats = self.alloc((rows * 2) as usize);
98 self.dispatch_shader(
99 SHADER_SOFTMAX_STATS, Some("softmax_stats"),
100 ¶ms, &[input], &stats,
101 super::dispatch_1d(rows),
102 );
103
104 let total = rows * cols;
106 let out = self.alloc(total as usize);
107
108 let params_buf = self.upload_uniform(¶ms);
109 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
110 label: Some("softmax_apply"),
111 source: wgpu::ShaderSource::Wgsl(SHADER_SOFTMAX_APPLY.into()),
112 });
113 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
114 label: Some("softmax_apply"),
115 layout: None,
116 module: &shader,
117 entry_point: Some("main"),
118 compilation_options: Default::default(),
119 cache: None,
120 });
121 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
122 label: None,
123 layout: &pipeline.get_bind_group_layout(0),
124 entries: &[
125 wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
126 wgpu::BindGroupEntry { binding: 1, resource: input.buffer.as_entire_binding() },
127 wgpu::BindGroupEntry { binding: 2, resource: stats.buffer.as_entire_binding() },
128 wgpu::BindGroupEntry { binding: 3, resource: out.buffer.as_entire_binding() },
129 ],
130 });
131 let (wx, wy, wz) = super::dispatch_1d(total);
132 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
133 {
134 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
135 label: Some("softmax_apply"),
136 timestamp_writes: None,
137 });
138 pass.set_pipeline(&pipeline);
139 pass.set_bind_group(0, &bind_group, &[]);
140 pass.dispatch_workgroups(wx, wy, wz);
141 }
142 self.queue.submit(Some(encoder.finish()));
143
144 Ok(out)
145 }
146
147 pub fn scaled_dot_product_attention(
150 &self,
151 q: &GpuBuffer, k: &GpuBuffer, v: &GpuBuffer,
152 batch_heads: u32, seq_len: u32, d_k: u32,
153 ) -> Result<GpuBuffer> {
154 let kt = self.transpose(k, batch_heads, seq_len, d_k, 1)?;
156
157 let scores = self.batch_matmul(q, &kt, batch_heads, seq_len, seq_len, d_k)?;
159
160 let scale = 1.0 / (d_k as f32).sqrt();
162 let scaled = self.scale(&scores, scale)?;
163
164 let attn = self.softmax(&scaled, batch_heads * seq_len, seq_len)?;
166
167 self.batch_matmul(&attn, v, batch_heads, seq_len, d_k, seq_len)
169 }
170
171 pub fn mse_loss(&self, pred: &GpuBuffer, target: &GpuBuffer) -> Result<GpuBuffer> {
173 ensure!(pred.len == target.len, "mse: length mismatch");
174 let out = self.alloc(1);
175 let params = ReduceParams { n: pred.len as u32, _pad: [0; 3] };
176 self.dispatch_shader(
177 SHADER_MSE_SUM, Some("mse"),
178 ¶ms, &[pred, target], &out,
179 (1, 1, 1),
180 );
181 Ok(out)
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ops::assert_approx;

    /// The shared GPU device every test in this module runs on.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// CPU reference softmax over a row-major `rows x cols` matrix.
    ///
    /// Each row is shifted by its maximum before exponentiating. Entries past
    /// `rows * cols` stay zero.
    fn cpu_softmax(input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; input.len()];
        for r in 0..rows {
            let span = r * cols..(r + 1) * cols;
            let row = &input[span.clone()];
            let peak = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let denom: f32 = row.iter().map(|&x| (x - peak).exp()).sum();
            for (slot, &x) in out[span].iter_mut().zip(row) {
                *slot = (x - peak).exp() / denom;
            }
        }
        out
    }

    /// CPU reference single-head attention: `softmax(Q·Kᵀ / sqrt(dk))·V` with
    /// `seq x dk` row-major `q`, `k` and `v`.
    fn cpu_attention(q: &[f32], k: &[f32], v: &[f32], seq: usize, dk: usize) -> Vec<f32> {
        fn dot(a: &[f32], b: &[f32]) -> f32 {
            a.iter().zip(b).fold(0.0, |acc, (&x, &y)| acc + x * y)
        }

        let scale = 1.0 / (dk as f32).sqrt();
        let mut scores = Vec::with_capacity(seq * seq);
        for i in 0..seq {
            let q_row = &q[i * dk..(i + 1) * dk];
            for j in 0..seq {
                let k_row = &k[j * dk..(j + 1) * dk];
                scores.push(dot(q_row, k_row) * scale);
            }
        }

        let attn = cpu_softmax(&scores, seq, seq);
        let mut out = vec![0.0f32; seq * dk];
        for i in 0..seq {
            for d in 0..dk {
                out[i * dk + d] =
                    (0..seq).fold(0.0, |acc, j| acc + attn[i * seq + j] * v[j * dk + d]);
            }
        }
        out
    }

    #[test]
    fn test_softmax_vs_cpu() {
        let data: Vec<f32> = vec![1.0, 2.0, 3.0, -1.0, 0.0, 1.0, 5.0, 5.0, 5.0];
        let expected = cpu_softmax(&data, 3, 3);
        let input = dev().upload(&data);
        let output = dev().softmax(&input, 3, 3).unwrap();
        let result = dev().read(&output).unwrap();
        assert_approx(&result, &expected, 1e-4);
        // Every row of a softmax must be a probability distribution.
        for r in 0..3 {
            let sum: f32 = result[r * 3..(r + 1) * 3].iter().sum();
            assert!((sum - 1.0).abs() < 1e-4, "row {r} sum = {sum}");
        }
    }

    #[test]
    fn test_softmax_large_values() {
        let data = vec![1000.0, 1001.0, 1002.0];
        let expected = cpu_softmax(&data, 1, 3);
        let input = dev().upload(&data);
        let output = dev().softmax(&input, 1, 3).unwrap();
        let result = dev().read(&output).unwrap();
        assert_approx(&result, &expected, 1e-4);
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-4, "sum = {sum}");
    }

    #[test]
    fn test_softmax_single_element() {
        let input = dev().upload(&[42.0]);
        let output = dev().softmax(&input, 1, 1).unwrap();
        let result = dev().read(&output).unwrap();
        assert_approx(&result, &[1.0], 1e-5);
    }

    #[test]
    fn test_attention_vs_cpu() {
        let ramp = |step: f32, offset: f32| -> Vec<f32> {
            (0..12).map(|i| i as f32 * step + offset).collect()
        };
        let q = ramp(0.1, -0.3);
        let k = ramp(0.05, 0.1);
        let v = ramp(0.2, -0.5);
        let expected = cpu_attention(&q, &k, &v, 3, 4);
        let (q_buf, k_buf, v_buf) = (dev().upload(&q), dev().upload(&k), dev().upload(&v));
        let output = dev()
            .scaled_dot_product_attention(&q_buf, &k_buf, &v_buf, 1, 3, 4)
            .unwrap();
        let result = dev().read(&output).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_attention_uniform_qk() {
        let q = vec![1.0; 4];
        let k = q.clone();
        let v = vec![0.0, 10.0, 20.0, 30.0];
        let expected = cpu_attention(&q, &k, &v, 2, 2);
        let (q_buf, k_buf, v_buf) = (dev().upload(&q), dev().upload(&k), dev().upload(&v));
        let output = dev()
            .scaled_dot_product_attention(&q_buf, &k_buf, &v_buf, 1, 2, 2)
            .unwrap();
        let result = dev().read(&output).unwrap();
        assert_approx(&result, &expected, 1e-3);
    }

    #[test]
    fn test_mse_loss() {
        let pred = dev().upload(&[1.0, 2.0, 3.0]);
        let target = dev().upload(&[1.5, 2.5, 3.5]);
        let loss = dev().mse_loss(&pred, &target).unwrap();
        let result = dev().read(&loss).unwrap();
        assert_approx(&result, &[0.25], 1e-5);
    }

    #[test]
    fn test_mse_loss_zero() {
        let a = dev().upload(&[1.0, 2.0, 3.0]);
        let loss = dev().mse_loss(&a, &a).unwrap();
        let result = dev().read(&loss).unwrap();
        assert_approx(&result, &[0.0], 1e-6);
    }

    #[test]
    fn test_mse_loss_known_value() {
        let pred = dev().upload(&[0.0, 0.0, 0.0]);
        let target = dev().upload(&[1.0, 2.0, 3.0]);
        let loss = dev().mse_loss(&pred, &target).unwrap();
        let result = dev().read(&loss).unwrap();
        assert_approx(&result, &[14.0 / 3.0], 1e-5);
    }

    #[test]
    fn test_softmax_size_mismatch() {
        let input = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().softmax(&input, 2, 3).is_err());
    }

    #[test]
    fn test_mse_loss_length_mismatch() {
        let pred = dev().upload(&[1.0, 2.0]);
        let target = dev().upload(&[1.0, 2.0, 3.0]);
        assert!(dev().mse_loss(&pred, &target).is_err());
    }

    #[test]
    fn test_mse_loss_single_element() {
        let pred = dev().upload(&[5.0]);
        let target = dev().upload(&[3.0]);
        let loss = dev().mse_loss(&pred, &target).unwrap();
        let result = dev().read(&loss).unwrap();
        assert_approx(&result, &[4.0], 1e-5);
    }

    #[test]
    fn test_mse_loss_negative_values() {
        let pred = dev().upload(&[-1.0, -2.0]);
        let target = dev().upload(&[1.0, 2.0]);
        let loss = dev().mse_loss(&pred, &target).unwrap();
        let result = dev().read(&loss).unwrap();
        assert_approx(&result, &[10.0], 1e-5);
    }
}