
trueno/backends/gpu/device/backward.rs

//! GPU backward (gradient) operations for training
//!
//! Contract: wgpu-training-v1.yaml (FALSIFY-WGPU-001)
//!
//! Dispatches WGSL backward shaders to compute gradients on GPU.
//! Element-wise operations match the CPU reference within ε < 1e-4; GEMM and
//! atomically accumulated gradients use slightly looser bounds (see the tests).

#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use super::super::runtime;
use super::super::shaders;
use super::GpuDevice;

impl GpuDevice {
    /// SiLU backward on GPU: grad_input[i] = grad_output[i] * silu'(input[i])
    ///
    /// # Contract (FALSIFY-WGPU-001)
    ///
    /// - **Precondition**: input.len() == grad_output.len() == grad_input.len()
    /// - **Postcondition**: max|grad_input_gpu - grad_input_cpu| < 1e-4
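    ///
    /// With silu(x) = x·σ(x), the product rule gives
    /// silu'(x) = σ(x) + x·σ(x)·(1 − σ(x)) = σ(x)·(1 + x − silu(x)),
    /// which is the closed form used by the CPU reference in the tests below.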
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn silu_backward(
        &self,
        input: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
    ) -> Result<(), String> {
        runtime::block_on(self.silu_backward_async(input, grad_output, grad_input))
    }

    /// SiLU backward on GPU (async)
    pub async fn silu_backward_async(
        &self,
        input: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
    ) -> Result<(), String> {
        let n = input.len();
        if grad_output.len() != n || grad_input.len() != n {
            return Err(format!(
                "SiLU backward: length mismatch: input={}, grad_output={}, grad_input={}",
                n,
                grad_output.len(),
                grad_input.len()
            ));
        }

        self.execute_backward_elementwise(
            "SiLU Backward",
            shaders::backward::SILU_BACKWARD_SHADER,
            input,
            grad_output,
            grad_input,
            n as u32,
        )
        .await
    }

    /// Generic dispatch for element-wise backward shaders (3 buffers + uniform)
    ///
    /// Binding layout: 0=input(read), 1=grad_output(read), 2=grad_input(write), 3=uniform{n}
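    ///
    /// The WGSL side is assumed to declare bindings along these lines (a sketch,
    /// not the actual shader source):
    ///
    ///   @group(0) @binding(0) var<storage, read> input: array<f32>;
    ///   @group(0) @binding(1) var<storage, read> grad_output: array<f32>;
    ///   @group(0) @binding(2) var<storage, read_write> grad_input: array<f32>;
    ///   @group(0) @binding(3) var<uniform> params: Params; // struct Params { n: u32 }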
    async fn execute_backward_elementwise(
        &self,
        op_name: &str,
        shader_source: &str,
        input: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
        n: u32,
    ) -> Result<(), String> {
        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some(&format!("{op_name} Shader")),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        // Create buffers
        let input_buf = self.create_storage_buffer(&format!("{op_name} input"), input, true);
        let grad_out_buf =
            self.create_storage_buffer(&format!("{op_name} grad_output"), grad_output, true);
        let grad_in_buf = self.create_rw_storage_buffer(
            &format!("{op_name} grad_input"),
            (grad_input.len() * 4) as u64,
        );

        // Uniform: { n: u32 } padded to 16 bytes (WGSL alignment)
        let uniform_data: [u32; 4] = [n, 0, 0, 0];
        let uniform_buf = self.create_uniform_buffer(
            &format!("{op_name} uniform"),
            bytemuck::cast_slice(&uniform_data),
        );

        // Bind group layout: 3 storage + 1 uniform
        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some(&format!("{op_name} BGL")),
            entries: &[
                storage_entry(0, true),
                storage_entry(1, true),
                storage_entry(2, false),
                uniform_entry(3),
            ],
        });

        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(&format!("{op_name} BG")),
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: grad_out_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: grad_in_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some(&format!("{op_name} PL")),
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });

        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some(&format!("{op_name} Pipeline")),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Staging buffer for readback
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(&format!("{op_name} Staging")),
            size: (grad_input.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        // Dispatch
        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // 2D dispatch for large tensors (>16M elements)
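            // With 256 threads per workgroup, X·Y·256 ≥ n; the shader is assumed to
            // rebuild a linear index from (workgroup_id, local_id) and bounds-check
            // it against n, since the grid can overshoot.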
            let total_wg = n.div_ceil(256);
            pass.dispatch_workgroups(total_wg.min(65535), total_wg.div_ceil(65535), 1);
        }
        encoder.copy_buffer_to_buffer(&grad_in_buf, 0, &staging, 0, (grad_input.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        // Read back
        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or_else(|| format!("{op_name}: map_async cancelled"))?
            .map_err(|e| format!("{op_name}: map_async failed: {e}"))?;

        let data = slice.get_mapped_range();
        grad_input.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();

        Ok(())
    }

    // --- Buffer helpers ---

    fn create_storage_buffer(&self, label: &str, data: &[f32], read_only: bool) -> wgpu::Buffer {
        let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: (data.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
        // Usage flags are identical either way; read-only vs. read-write is
        // enforced by the bind group layout, not the buffer.
        let _ = read_only;
        buf
    }

    fn create_rw_storage_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
        self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        })
    }

    fn create_uniform_buffer(&self, label: &str, data: &[u8]) -> wgpu::Buffer {
        let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(label),
            size: data.len() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buf, 0, data);
        buf
    }

    /// GEMM backward for A: grad_a[M,K] = grad_c[M,N] @ B^T[N,K]
    ///
    /// # Contract (FALSIFY-WGPU-001)
    ///
    /// - **Precondition**: grad_c.len() == m*n, b.len() == k*n, grad_a.len() == m*k
    /// - **Postcondition**: max|grad_a_gpu - grad_a_cpu| < 1e-3 (see parity test)
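    ///
    /// Derivation: for C = A·B with loss L, ∂L/∂A = (∂L/∂C)·Bᵀ and ∂L/∂B = Aᵀ·(∂L/∂C),
    /// so both backward passes are plain GEMMs against a transposed operand.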
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn gemm_backward_a(
        &self,
        grad_c: &[f32],
        b: &[f32],
        grad_a: &mut [f32],
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<(), String> {
        runtime::block_on(self.gemm_backward_a_async(grad_c, b, grad_a, m, k, n))
    }

    /// GEMM backward for A (async): grad_a = grad_c @ B^T
    pub async fn gemm_backward_a_async(
        &self,
        grad_c: &[f32],
        b: &[f32],
        grad_a: &mut [f32],
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<(), String> {
        self.execute_backward_gemm(
            "GEMM Backward A",
            shaders::backward::GEMM_BACKWARD_A_SHADER,
            grad_c,
            b,
            grad_a,
            m,
            k,
            n,
        )
        .await
    }

    /// GEMM backward for B: grad_b[K,N] = A^T[K,M] @ grad_c[M,N]
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn gemm_backward_b(
        &self,
        a: &[f32],
        grad_c: &[f32],
        grad_b: &mut [f32],
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<(), String> {
        runtime::block_on(self.gemm_backward_b_async(a, grad_c, grad_b, m, k, n))
    }

    /// GEMM backward for B (async): grad_b = A^T @ grad_c
    pub async fn gemm_backward_b_async(
        &self,
        a: &[f32],
        grad_c: &[f32],
        grad_b: &mut [f32],
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<(), String> {
        self.execute_backward_gemm(
            "GEMM Backward B",
            shaders::backward::GEMM_BACKWARD_B_SHADER,
            a,
            grad_c,
            grad_b,
            m,
            k,
            n,
        )
        .await
    }

    /// Generic dispatch for GEMM backward shaders (tiled 16×16)
    ///
    /// Binding: 0=buf_a(read), 1=buf_b(read), 2=output(write), 3=uniform{M,K,N}
    async fn execute_backward_gemm(
        &self,
        op_name: &str,
        shader_source: &str,
        buf_a: &[f32],
        buf_b: &[f32],
        output: &mut [f32],
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<(), String> {
        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some(&format!("{op_name} Shader")),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        let a_buf = self.create_storage_buffer(&format!("{op_name} A"), buf_a, true);
        let b_buf = self.create_storage_buffer(&format!("{op_name} B"), buf_b, true);
        let out_buf =
            self.create_rw_storage_buffer(&format!("{op_name} Output"), (output.len() * 4) as u64);

        // Uniform: { M, K, N, pad }
        let dims: [u32; 4] = [m, k, n, 0];
        let uniform_buf =
            self.create_uniform_buffer(&format!("{op_name} Dims"), bytemuck::cast_slice(&dims));

        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, true),
                storage_entry(1, true),
                storage_entry(2, false),
                uniform_entry(3),
            ],
        });

        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: a_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: b_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: out_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });

        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some(&format!("{op_name} Pipeline")),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (output.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);

            // For GEMM backward A: output is [M,K], dispatch ceil(M/16) × ceil(K/16)
            // For GEMM backward B: output is [K,N], dispatch ceil(K/16) × ceil(N/16)
            // The op name selects which pair applies ("GEMM Backward A" contains 'A').
            let out_rows = if op_name.contains("A") { m } else { k };
            let out_cols = if op_name.contains("A") { k } else { n };
            pass.dispatch_workgroups(out_rows.div_ceil(16), out_cols.div_ceil(16), 1);
        }
        encoder.copy_buffer_to_buffer(&out_buf, 0, &staging, 0, (output.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or_else(|| format!("{op_name}: map cancelled"))?
            .map_err(|e| format!("{op_name}: map failed: {e}"))?;

        let data = slice.get_mapped_range();
        output.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();

        Ok(())
    }

    /// RoPE backward on GPU: applies the transposed rotation (negated sin)
    ///
    /// # Contract (FALSIFY-WGPU-001)
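    ///
    /// Forward rotates each (even, odd) pair by R(θ) = [[cos θ, −sin θ], [sin θ, cos θ]];
    /// the backward pass applies the transpose Rᵀ = R(−θ), per the CPU reference in the tests:
    ///   grad_even =  cos θ·dy_even + sin θ·dy_odd
    ///   grad_odd  = −sin θ·dy_even + cos θ·dy_odd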
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn rope_backward(
        &self,
        grad_output: &[f32],
        grad_input: &mut [f32],
        num_heads: u32,
        head_dim: u32,
        seq_len: u32,
        theta: f32,
    ) -> Result<(), String> {
        runtime::block_on(self.rope_backward_async(
            grad_output,
            grad_input,
            num_heads,
            head_dim,
            seq_len,
            theta,
        ))
    }

    /// RoPE backward (async)
    pub async fn rope_backward_async(
        &self,
        grad_output: &[f32],
        grad_input: &mut [f32],
        num_heads: u32,
        head_dim: u32,
        seq_len: u32,
        theta: f32,
    ) -> Result<(), String> {
        let n = grad_output.len();
        let total_pairs = num_heads * seq_len * (head_dim / 2);

        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("RoPE Backward Shader"),
            source: wgpu::ShaderSource::Wgsl(shaders::backward::ROPE_BACKWARD_SHADER.into()),
        });

        let go_buf = self.create_storage_buffer("rope_bwd grad_out", grad_output, true);
        let gi_buf = self.create_rw_storage_buffer("rope_bwd grad_in", (n * 4) as u64);

        // Uniform: { num_heads, head_dim, seq_len, theta_log2 }
        let params: [u32; 4] = [num_heads, head_dim, seq_len, theta.log2().to_bits()];
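        // Passing log2(θ) lets the shader compute θ^(−2p/head_dim) as
        // exp2(−(2p/head_dim)·log2 θ), matching the CPU reference in the tests
        // below (assumed shader behavior).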
        let uniform_buf =
            self.create_uniform_buffer("rope_bwd params", bytemuck::cast_slice(&params));

        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[storage_entry(0, true), storage_entry(1, false), uniform_entry(2)],
        });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: go_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: gi_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("RoPE Backward"),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (n * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            let total_wg = total_pairs.div_ceil(256);
            pass.dispatch_workgroups(total_wg.min(65535), total_wg.div_ceil(65535), 1);
        }
        encoder.copy_buffer_to_buffer(&gi_buf, 0, &staging, 0, (n * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or_else(|| "RoPE backward: cancelled".to_string())?
            .map_err(|e| format!("RoPE backward: {e}"))?;
        let data = slice.get_mapped_range();
        grad_input.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();
        Ok(())
    }

    /// AdamW optimizer step on GPU
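    ///
    /// Update rule (matching the CPU reference in the tests below):
    ///   m ← β1·m + (1−β1)·g        v ← β2·v + (1−β2)·g²
    ///   m_hat = m/(1−β1^t)         v_hat = v/(1−β2^t)
    ///   p ← p − lr·(m_hat/(√v_hat + ε) + weight_decay·p)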
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn adamw_step(
        &self,
        params: &mut [f32],
        grads: &[f32],
        m: &mut [f32],
        v: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        step: u32,
    ) -> Result<(), String> {
        runtime::block_on(self.adamw_step_async(
            params,
            grads,
            m,
            v,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            step,
        ))
    }

    /// AdamW step (async)
    pub async fn adamw_step_async(
        &self,
        params: &mut [f32],
        grads: &[f32],
        m: &mut [f32],
        v: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        step: u32,
    ) -> Result<(), String> {
        let n = params.len() as u32;
        let bc1 = 1.0 - beta1.powi(step as i32);
        let bc2 = 1.0 - beta2.powi(step as i32);

        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("AdamW Step"),
            source: wgpu::ShaderSource::Wgsl(shaders::backward::ADAMW_STEP_SHADER.into()),
        });

        // Params buffer is read-write (updated in-place)
        let params_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("adamw params"),
            size: (params.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&params_buf, 0, bytemuck::cast_slice(params));

        let grads_buf = self.create_storage_buffer("adamw grads", grads, true);

        let m_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("adamw m"),
            size: (m.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&m_buf, 0, bytemuck::cast_slice(m));

        let v_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("adamw v"),
            size: (v.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&v_buf, 0, bytemuck::cast_slice(v));

        // Uniform: { n: u32, lr: f32, beta1: f32, beta2: f32, eps: f32, wd: f32, bc1: f32, bc2: f32 }
        // Pack as raw u32 bytes to handle the mixed u32/f32 layout
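        // The matching WGSL uniform is assumed to be (sketch):
        //   struct Hp { n: u32, lr: f32, beta1: f32, beta2: f32,
        //               eps: f32, wd: f32, bc1: f32, bc2: f32 }  // 32 bytes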
        let hp: [u32; 8] = [
            n,
            lr.to_bits(),
            beta1.to_bits(),
            beta2.to_bits(),
            eps.to_bits(),
            weight_decay.to_bits(),
            bc1.to_bits(),
            bc2.to_bits(),
        ];
        let uniform_buf = self.create_uniform_buffer("adamw hp", bytemuck::cast_slice(&hp));

        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, false), // params (read-write)
                storage_entry(1, true),  // grads
                storage_entry(2, false), // m
                storage_entry(3, false), // v
                uniform_entry(4),
            ],
        });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: grads_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: m_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: v_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("AdamW"),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Staging buffers for readback
        let params_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (params.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let m_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (m.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let v_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (v.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // 2D dispatch for large tensors (>16M elements)
            let total_wg = n.div_ceil(256);
            pass.dispatch_workgroups(total_wg.min(65535), total_wg.div_ceil(65535), 1);
        }
        encoder.copy_buffer_to_buffer(
            &params_buf,
            0,
            &params_staging,
            0,
            (params.len() * 4) as u64,
        );
        encoder.copy_buffer_to_buffer(&m_buf, 0, &m_staging, 0, (m.len() * 4) as u64);
        encoder.copy_buffer_to_buffer(&v_buf, 0, &v_staging, 0, (v.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        // Read back all three buffers
        let read_buf = |staging: &wgpu::Buffer, out: &mut [f32]| -> Result<(), String> {
            let slice = staging.slice(..);
            let (tx, rx) = std::sync::mpsc::channel();
            slice.map_async(wgpu::MapMode::Read, move |r| {
                tx.send(r).ok();
            });
            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
            rx.recv()
                .map_err(|e| format!("AdamW readback: {e}"))?
                .map_err(|e| format!("AdamW map: {e}"))?;
            let data = slice.get_mapped_range();
            out.copy_from_slice(bytemuck::cast_slice(&data));
            drop(data);
            staging.unmap();
            Ok(())
        };
        read_buf(&params_staging, params)?;
        read_buf(&m_staging, m)?;
        read_buf(&v_staging, v)?;

        Ok(())
    }

    /// RMSNorm backward on GPU
    ///
    /// Computes grad_input and accumulates grad_gamma via atomic CAS.
    /// One workgroup (256 threads) per row.
    ///
    /// # Contract (FALSIFY-WGPU-001)
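    ///
    /// Per row, with rms = √(mean(x²) + ε), matching the CPU reference below:
    ///   grad_input_i = (γ_i·dy_i − (x_i/(mean(x²)+ε))·mean(x·dy·γ)) / rms
    ///   grad_gamma_i += dy_i·x_i / rms   (accumulated over rows)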
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn rmsnorm_backward(
        &self,
        input: &[f32],
        gamma: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
        grad_gamma: &mut [f32],
        num_rows: u32,
        hidden_dim: u32,
        eps: f32,
    ) -> Result<(), String> {
        runtime::block_on(self.rmsnorm_backward_async(
            input,
            gamma,
            grad_output,
            grad_input,
            grad_gamma,
            num_rows,
            hidden_dim,
            eps,
        ))
    }

    /// RMSNorm backward (async)
    pub async fn rmsnorm_backward_async(
        &self,
        input: &[f32],
        gamma: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
        grad_gamma: &mut [f32],
        num_rows: u32,
        hidden_dim: u32,
        eps: f32,
    ) -> Result<(), String> {
        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("RMSNorm Backward"),
            source: wgpu::ShaderSource::Wgsl(shaders::backward::RMSNORM_BACKWARD_SHADER.into()),
        });

        let input_buf = self.create_storage_buffer("rms_bwd input", input, true);
        let gamma_buf = self.create_storage_buffer("rms_bwd gamma", gamma, true);
        let grad_out_buf = self.create_storage_buffer("rms_bwd grad_out", grad_output, true);
        let grad_in_buf =
            self.create_rw_storage_buffer("rms_bwd grad_in", (grad_input.len() * 4) as u64);

        // grad_gamma: init to zero, accumulated via atomic CAS
        let grad_gamma_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("rms_bwd grad_gamma"),
            size: (hidden_dim as usize * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        // Zero-init grad_gamma
        let zeros = vec![0u8; hidden_dim as usize * 4];
        self.queue.write_buffer(&grad_gamma_buf, 0, &zeros);
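        // WGSL has no atomic<f32>; the shader is assumed to emulate atomic float
        // addition with a compare-and-swap loop over the u32 bit pattern, which is
        // why this buffer is read back as u32 words and bitcast below.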

        // Uniform: { num_rows, hidden_dim, eps_bits, pad }
        let params: [u32; 4] = [num_rows, hidden_dim, eps.to_bits(), 0];
        let uniform_buf =
            self.create_uniform_buffer("rms_bwd params", bytemuck::cast_slice(&params));

        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, true),  // input
                storage_entry(1, true),  // gamma
                storage_entry(2, true),  // grad_output
                storage_entry(3, false), // grad_input
                storage_entry(4, false), // grad_gamma (atomic)
                uniform_entry(5),
            ],
        });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: gamma_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: grad_out_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: grad_in_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: grad_gamma_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 5, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("RMSNorm Backward"),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Staging for grad_input and grad_gamma
        let gi_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (grad_input.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let gg_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (hidden_dim as usize * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // One workgroup (256 threads) per row
            pass.dispatch_workgroups(num_rows, 1, 1);
        }
        encoder.copy_buffer_to_buffer(
            &grad_in_buf,
            0,
            &gi_staging,
            0,
            (grad_input.len() * 4) as u64,
        );
        encoder.copy_buffer_to_buffer(
            &grad_gamma_buf,
            0,
            &gg_staging,
            0,
            (hidden_dim as usize * 4) as u64,
        );
        self.queue.submit(Some(encoder.finish()));

        // Read back grad_input
        {
            let slice = gi_staging.slice(..);
            let (tx, rx) = std::sync::mpsc::channel();
            slice.map_async(wgpu::MapMode::Read, move |r| {
                tx.send(r).ok();
            });
            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
            rx.recv()
                .map_err(|e| format!("RMSNorm bwd gi: {e}"))?
                .map_err(|e| format!("RMSNorm bwd gi map: {e}"))?;
            let data = slice.get_mapped_range();
            grad_input.copy_from_slice(bytemuck::cast_slice(&data));
            drop(data);
            gi_staging.unmap();
        }
        // Read back grad_gamma (stored as atomic<u32> = bitcast f32)
        {
            let slice = gg_staging.slice(..);
            let (tx, rx) = std::sync::mpsc::channel();
            slice.map_async(wgpu::MapMode::Read, move |r| {
                tx.send(r).ok();
            });
            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
            rx.recv()
                .map_err(|e| format!("RMSNorm bwd gg: {e}"))?
                .map_err(|e| format!("RMSNorm bwd gg map: {e}"))?;
            let data = slice.get_mapped_range();
            // The shader accumulated f32 values via CAS on their u32 bit patterns;
            // bitcast each word back to f32.
            let raw: &[u32] = bytemuck::cast_slice(&data);
            for (i, &bits) in raw.iter().enumerate() {
                grad_gamma[i] = f32::from_bits(bits);
            }
            drop(data);
            gg_staging.unmap();
        }

        Ok(())
    }

    /// NF4 dequantization on GPU
    ///
    /// Converts 4-bit NormalFloat packed weights to fp32 using codebook lookup.
    ///
    /// # Contract (FALSIFY-WGPU-003)
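    ///
    /// Packing, as exercised by the parity test below: element i of a u32 word is
    /// the nibble (word >> (4·(i % 8))) & 0xF; the 4-bit code indexes the 16-entry
    /// NF4 codebook, and the result is scaled by scales[i / block_size].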
    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
    pub fn nf4_dequant(
        &self,
        packed: &[u32],
        scales: &[f32],
        output: &mut [f32],
        n: u32,
        block_size: u32,
    ) -> Result<(), String> {
        runtime::block_on(self.nf4_dequant_async(packed, scales, output, n, block_size))
    }

    /// NF4 dequant (async)
    pub async fn nf4_dequant_async(
        &self,
        packed: &[u32],
        scales: &[f32],
        output: &mut [f32],
        n: u32,
        block_size: u32,
    ) -> Result<(), String> {
        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("NF4 Dequant"),
            source: wgpu::ShaderSource::Wgsl(shaders::backward::NF4_DEQUANT_SHADER.into()),
        });

        let packed_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("nf4 packed"),
            size: (packed.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&packed_buf, 0, bytemuck::cast_slice(packed));

        let scales_buf = self.create_storage_buffer("nf4 scales", scales, true);
        let output_buf = self.create_rw_storage_buffer("nf4 output", (output.len() * 4) as u64);

        let params: [u32; 4] = [n, block_size, 0, 0];
        let uniform_buf = self.create_uniform_buffer("nf4 params", bytemuck::cast_slice(&params));

        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, true),  // packed
                storage_entry(1, true),  // scales
                storage_entry(2, false), // output
                uniform_entry(3),
            ],
        });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: packed_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: scales_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NF4 Dequant"),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (output.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // Use 2D dispatch to handle >65535 workgroups.
            // Each workgroup has 256 threads; total threads must cover n.
            // X = min(ceil(n/256), 65535), Y = ceil(ceil(n/256) / 65535)
            let total_wg = n.div_ceil(256);
            let x = total_wg.min(65535);
            let y = total_wg.div_ceil(65535);
            pass.dispatch_workgroups(x, y, 1);
        }
        encoder.copy_buffer_to_buffer(&output_buf, 0, &staging, 0, (output.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or_else(|| "NF4 dequant: cancelled".to_string())?
            .map_err(|e| format!("NF4 dequant: {e}"))?;
        let data = slice.get_mapped_range();
        output.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();

        Ok(())
    }
}

fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Uniform,
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;

    /// CPU reference: SiLU backward
    fn silu_backward_cpu(input: &[f32], grad_output: &[f32]) -> Vec<f32> {
        input
            .iter()
            .zip(grad_output.iter())
            .map(|(&x, &dy)| {
                let sigmoid = 1.0 / (1.0 + (-x).exp());
                let y = x * sigmoid;
                let silu_prime = sigmoid * (1.0 + x - y);
                dy * silu_prime
            })
            .collect()
    }
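
    /// The analytic silu' above should agree with a central finite difference of
    /// silu(x) = x·σ(x); a pure-CPU sanity check on the reference itself.
    #[test]
    fn test_silu_backward_cpu_matches_finite_difference() {
        let h = 1e-3f32;
        let silu = |x: f32| x / (1.0 + (-x).exp());
        for i in -20..=20 {
            let x = i as f32 * 0.25;
            let numeric = (silu(x + h) - silu(x - h)) / (2.0 * h);
            let analytic = silu_backward_cpu(&[x], &[1.0])[0];
            assert!((numeric - analytic).abs() < 1e-3, "x={x}: {numeric} vs {analytic}");
        }
    }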

    /// FALSIFY-WGPU-001: SiLU backward matches CPU within ε < 1e-4
    #[test]
    fn test_falsify_wgpu_001_silu_backward_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let input: Vec<f32> = (-50..50).map(|i| i as f32 * 0.1).collect();
        let grad_output: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) * 0.01).collect();
        let expected = silu_backward_cpu(&input, &grad_output);

        let mut grad_input = vec![0.0f32; 100];
        device.silu_backward(&input, &grad_output, &mut grad_input).expect("silu_backward");

        let max_diff = grad_input
            .iter()
            .zip(expected.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-4,
            "FALSIFY-WGPU-001: SiLU backward max diff = {max_diff} (threshold: 1e-4)"
        );
    }

    /// SiLU backward at x=0 (sigmoid=0.5, silu'=0.5)
    #[test]
    fn test_silu_backward_at_zero() {
        let device = GpuDevice::new().expect("GPU device");

        let input = vec![0.0f32; 4];
        let grad_output = vec![1.0f32; 4];
        let mut grad_input = vec![0.0f32; 4];

        device.silu_backward(&input, &grad_output, &mut grad_input).expect("silu_backward");

        // At x=0: sigmoid(0)=0.5, silu'(0) = 0.5 * (1 + 0 - 0) = 0.5
        for &g in &grad_input {
            assert!((g - 0.5).abs() < 1e-5, "silu'(0) should be 0.5, got {g}");
        }
    }

    /// SiLU backward length mismatch error
    #[test]
    fn test_silu_backward_length_mismatch() {
        let device = GpuDevice::new().expect("GPU device");

        let input = vec![1.0f32; 10];
        let grad_output = vec![1.0f32; 5]; // wrong length
        let mut grad_input = vec![0.0f32; 10];

        let result = device.silu_backward(&input, &grad_output, &mut grad_input);
        assert!(result.is_err());
    }
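
    /// Pure-CPU sanity check of the 2D dispatch arithmetic used throughout this
    /// file: X = min(⌈n/256⌉, 65535), Y = ⌈⌈n/256⌉/65535⌉ must cover all n threads.
    #[test]
    fn test_dispatch_dims_cover_all_threads() {
        for &n in &[1u32, 255, 256, 257, 16_777_216, 16_777_217] {
            let total_wg = n.div_ceil(256);
            let x = total_wg.min(65535);
            let y = total_wg.div_ceil(65535);
            // Grid capacity in threads must be at least n.
            assert!(u64::from(x) * u64::from(y) * 256 >= u64::from(n), "n = {n}");
        }
    }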

    /// CPU reference: matmul C = A[M,K] @ B[K,N]
    fn matmul_cpu(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
        let mut c = vec![0.0f32; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0f32;
                for p in 0..k {
                    sum += a[i * k + p] * b[p * n + j];
                }
                c[i * n + j] = sum;
            }
        }
        c
    }

    /// FALSIFY-WGPU-001: GEMM backward A matches CPU within ε < 1e-3
    ///
    /// grad_a[M,K] = grad_c[M,N] @ B^T[N,K], i.e. matmul(grad_c, B^T, M, N, K);
    /// the shader handles the transpose of B internally.
    #[test]
    fn test_falsify_wgpu_001_gemm_backward_a_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let (m, k, n) = (4, 8, 6);

        // Deterministic, non-uniform test data
        let grad_c: Vec<f32> = (0..m * n).map(|i| (i as f32 - 12.0) * 0.1).collect();
        let b: Vec<f32> = (0..k * n).map(|i| (i as f32 - 24.0) * 0.05).collect();

        // CPU reference: grad_a = grad_c @ B^T (transpose B[K,N] → B^T[N,K] first)
        let mut b_t = vec![0.0f32; n * k];
        for i in 0..k {
            for j in 0..n {
                b_t[j * k + i] = b[i * n + j];
            }
        }
        let expected = matmul_cpu(&grad_c, &b_t, m, n, k);

        let mut grad_a = vec![0.0f32; m * k];
        device
            .gemm_backward_a(&grad_c, &b, &mut grad_a, m as u32, k as u32, n as u32)
            .expect("gemm_backward_a");

        let max_diff =
            grad_a.iter().zip(expected.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-3,
            "FALSIFY-WGPU-001: GEMM backward A max diff = {max_diff} (threshold: 1e-3)"
        );
    }

    /// FALSIFY-WGPU-001: GEMM backward B matches CPU within ε < 1e-3
    ///
    /// grad_b[K,N] = A^T[K,M] @ grad_c[M,N]
    #[test]
    fn test_falsify_wgpu_001_gemm_backward_b_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let (m, k, n) = (4, 8, 6);

        let a: Vec<f32> = (0..m * k).map(|i| (i as f32 - 16.0) * 0.1).collect();
        let grad_c: Vec<f32> = (0..m * n).map(|i| (i as f32 - 12.0) * 0.05).collect();

        // CPU reference: grad_b = A^T @ grad_c
        let mut a_t = vec![0.0f32; k * m];
        for i in 0..m {
            for j in 0..k {
                a_t[j * m + i] = a[i * k + j];
            }
        }
        let expected = matmul_cpu(&a_t, &grad_c, k, m, n);

        let mut grad_b = vec![0.0f32; k * n];
        device
            .gemm_backward_b(&a, &grad_c, &mut grad_b, m as u32, k as u32, n as u32)
            .expect("gemm_backward_b");

        let max_diff =
            grad_b.iter().zip(expected.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-3,
            "FALSIFY-WGPU-001: GEMM backward B max diff = {max_diff} (threshold: 1e-3)"
        );
    }

    /// FALSIFY-WGPU-001: RoPE backward matches CPU
    #[test]
    fn test_falsify_wgpu_001_rope_backward_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let (num_heads, head_dim, seq_len) = (2, 4, 3);
        let theta = 10000.0f32;
        let n = num_heads * head_dim * seq_len;

        let grad_output: Vec<f32> = (0..n).map(|i| (i as f32 - 12.0) * 0.1).collect();

        // CPU reference: RoPE backward applies the transposed rotation
        let half_dim = head_dim / 2;
        let mut expected = vec![0.0f32; n];
        for h in 0..num_heads {
            for s in 0..seq_len {
                for p in 0..half_dim {
                    let freq_exp = -((2 * p) as f32) / head_dim as f32 * theta.log2();
                    let inv_freq = 2.0f32.powf(freq_exp);
                    let angle = s as f32 * inv_freq;
                    let (sin_a, cos_a) = angle.sin_cos();

                    let base = h * seq_len * head_dim + s * head_dim;
                    let even = base + 2 * p;
                    let odd = base + 2 * p + 1;

                    let dy_even = grad_output[even];
                    let dy_odd = grad_output[odd];

                    // Backward: transpose of the rotation matrix
                    expected[even] = dy_even * cos_a + dy_odd * sin_a;
                    expected[odd] = -dy_even * sin_a + dy_odd * cos_a;
                }
            }
        }

        let mut grad_input = vec![0.0f32; n];
        device
            .rope_backward(
                &grad_output,
                &mut grad_input,
                num_heads as u32,
                head_dim as u32,
                seq_len as u32,
                theta,
            )
            .expect("rope_backward");

        let max_diff = grad_input
            .iter()
            .zip(expected.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-4,
            "FALSIFY-WGPU-001: RoPE backward max diff = {max_diff} (threshold: 1e-4)"
        );
    }

    /// FALSIFY-WGPU-001: AdamW step matches CPU
    #[test]
    fn test_falsify_wgpu_001_adamw_step_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let n = 16;
        let mut params: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let grads: Vec<f32> = (0..n).map(|i| (i as f32 - 8.0) * 0.01).collect();
        let mut m_state = vec![0.0f32; n];
        let mut v_state = vec![0.0f32; n];

        let lr: f32 = 1e-3;
        let beta1: f32 = 0.9;
        let beta2: f32 = 0.999;
        let eps: f32 = 1e-8;
        let wd: f32 = 0.01;
        let step = 1u32;

        // CPU reference
        let bc1: f32 = 1.0 - beta1.powi(step as i32);
        let bc2: f32 = 1.0 - beta2.powi(step as i32);
        let mut cpu_params = params.clone();
        let mut cpu_m = m_state.clone();
        let mut cpu_v = v_state.clone();
        for i in 0..n {
            cpu_m[i] = beta1 * cpu_m[i] + (1.0 - beta1) * grads[i];
            cpu_v[i] = beta2 * cpu_v[i] + (1.0 - beta2) * grads[i] * grads[i];
            let m_hat = cpu_m[i] / bc1;
            let v_hat = cpu_v[i] / bc2;
            cpu_params[i] -= lr * (m_hat / (v_hat.sqrt() + eps) + wd * cpu_params[i]);
        }

        device
            .adamw_step(
                &mut params,
                &grads,
                &mut m_state,
                &mut v_state,
                lr,
                beta1,
                beta2,
                eps,
                wd,
                step,
            )
            .expect("adamw_step");

        let max_diff =
            params.iter().zip(cpu_params.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-4,
            "FALSIFY-WGPU-001: AdamW step max diff = {max_diff} (threshold: 1e-4)"
        );
    }

    /// FALSIFY-WGPU-001: RMSNorm backward matches CPU
    #[test]
    fn test_falsify_wgpu_001_rmsnorm_backward_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let (num_rows, hidden_dim) = (3, 8);
        let eps: f32 = 1e-5;
        let n = num_rows * hidden_dim;

        let input: Vec<f32> = (0..n).map(|i| (i as f32 - 12.0) * 0.1).collect();
        let gamma: Vec<f32> = (0..hidden_dim).map(|i| 1.0 + i as f32 * 0.1).collect();
        let grad_output: Vec<f32> = (0..n).map(|i| (i as f32 - 12.0) * 0.05).collect();

        // CPU reference
        let mut cpu_grad_input = vec![0.0f32; n];
        let mut cpu_grad_gamma = vec![0.0f32; hidden_dim];
        for r in 0..num_rows {
            let row = &input[r * hidden_dim..(r + 1) * hidden_dim];
            let grow = &grad_output[r * hidden_dim..(r + 1) * hidden_dim];

            let sum_x2: f32 = row.iter().map(|x| x * x).sum();
            let mean_x2 = sum_x2 / hidden_dim as f32;
            let var_eps = mean_x2 + eps;
            let rms = var_eps.sqrt();
            let inv_rms = 1.0 / rms;

            let sum_xgg: f32 = row
                .iter()
                .zip(grow.iter())
                .zip(gamma.iter())
                .map(|((&x, &gy), &g)| x * gy * g)
                .sum();
            let mean_xgg = sum_xgg / hidden_dim as f32;

            for i in 0..hidden_dim {
                let x = row[i];
                let gy = grow[i];
                let g = gamma[i];
                let gamma_gy = g * gy;
                let correction = (x / var_eps) * mean_xgg;
                cpu_grad_input[r * hidden_dim + i] = inv_rms * (gamma_gy - correction);
                cpu_grad_gamma[i] += gy * x * inv_rms;
            }
        }

        let mut grad_input = vec![0.0f32; n];
        let mut grad_gamma = vec![0.0f32; hidden_dim];

        device
            .rmsnorm_backward(
                &input,
                &gamma,
                &grad_output,
                &mut grad_input,
                &mut grad_gamma,
                num_rows as u32,
                hidden_dim as u32,
                eps,
            )
            .expect("rmsnorm_backward");

        let gi_max_diff = grad_input
            .iter()
            .zip(cpu_grad_input.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        let gg_max_diff = grad_gamma
            .iter()
            .zip(cpu_grad_gamma.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(
            gi_max_diff < 1e-3,
            "FALSIFY-WGPU-001: RMSNorm grad_input max diff = {gi_max_diff}"
        );
        assert!(
            gg_max_diff < 1e-2,
            "FALSIFY-WGPU-001: RMSNorm grad_gamma max diff = {gg_max_diff} (atomic CAS accumulation)"
        );
    }

    /// FALSIFY-WGPU-003: NF4 dequant matches CPU
    #[test]
    fn test_falsify_wgpu_003_nf4_dequant_parity() {
        let device = GpuDevice::new().expect("GPU device");

        // NF4 codebook
        let nf4_lut: [f32; 16] = [
            -1.0,
            -0.6961928,
            -0.5250731,
            -0.39491749,
            -0.28444138,
            -0.18477343,
            -0.09105004,
            0.0,
            0.0795803,
            0.1609302,
            0.24611230,
            0.33791524,
            0.44070983,
            0.5626170,
            0.7229568,
            1.0,
        ];

        let block_size = 4u32; // small for testing
        let n = 8u32; // 8 elements = 2 blocks of 4

        // Pack: 8 elements = 4 bytes = 1 u32; each byte holds 2 nibbles
        // (low nibble = even element, high nibble = odd element).
        // Codebook indices: [3, 7, 12, 1, 5, 15, 0, 9]
        // Byte 0: elem[0]=3,  elem[1]=7  → 0x73
        // Byte 1: elem[2]=12, elem[3]=1  → 0x1C
        // Byte 2: elem[4]=5,  elem[5]=15 → 0xF5
        // Byte 3: elem[6]=0,  elem[7]=9  → 0x90
        let packed: Vec<u32> = vec![0x90F5_1C73_u32];

        let scales: Vec<f32> = vec![2.0, 0.5]; // 2 blocks
        let indices = [3, 7, 12, 1, 5, 15, 0, 9];

        // CPU reference
        let mut expected = vec![0.0f32; n as usize];
        for i in 0..n as usize {
            let scale = scales[i / block_size as usize];
            expected[i] = nf4_lut[indices[i]] * scale;
        }

        let mut output = vec![0.0f32; n as usize];
        device.nf4_dequant(&packed, &scales, &mut output, n, block_size).expect("nf4_dequant");

        let max_diff =
            output.iter().zip(expected.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-6,
            "FALSIFY-WGPU-003: NF4 dequant max diff = {max_diff} (threshold: 1e-6)"
        );
    }
}