1#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
9use super::super::runtime;
10use super::super::shaders;
11use super::GpuDevice;
12
13impl GpuDevice {
14 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
21 pub fn silu_backward(
22 &self,
23 input: &[f32],
24 grad_output: &[f32],
25 grad_input: &mut [f32],
26 ) -> Result<(), String> {
27 runtime::block_on(self.silu_backward_async(input, grad_output, grad_input))
28 }
29
30 pub async fn silu_backward_async(
32 &self,
33 input: &[f32],
34 grad_output: &[f32],
35 grad_input: &mut [f32],
36 ) -> Result<(), String> {
37 let n = input.len();
38 if grad_output.len() != n || grad_input.len() != n {
39 return Err(format!(
40 "SiLU backward: length mismatch: input={}, grad_output={}, grad_input={}",
41 n,
42 grad_output.len(),
43 grad_input.len()
44 ));
45 }
46
47 self.execute_backward_elementwise(
48 "SiLU Backward",
49 shaders::backward::SILU_BACKWARD_SHADER,
50 input,
51 grad_output,
52 grad_input,
53 n as u32,
54 )
55 .await
56 }
57
    /// Shared driver for element-wise backward kernels: two read-only f32
    /// inputs (`input`, `grad_output`), one writable output (`grad_input`),
    /// and a uniform carrying the element count `n`.
    ///
    /// Builds the shader module, bind group, and pipeline from scratch on
    /// every call (no caching), dispatches the compute pass, then blocks on a
    /// staging-buffer readback into `grad_input`.
    async fn execute_backward_elementwise(
        &self,
        op_name: &str,
        shader_source: &str,
        input: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
        n: u32,
    ) -> Result<(), String> {
        use wgpu;

        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some(&format!("{op_name} Shader")),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        let input_buf = self.create_storage_buffer(&format!("{op_name} input"), input, true);
        let grad_out_buf =
            self.create_storage_buffer(&format!("{op_name} grad_output"), grad_output, true);
        let grad_in_buf = self.create_rw_storage_buffer(
            &format!("{op_name} grad_input"),
            (grad_input.len() * 4) as u64,
        );

        // Uniform padded to 16 bytes; only the first word (n) is meaningful.
        let uniform_data: [u32; 4] = [n, 0, 0, 0];
        let uniform_buf = self.create_uniform_buffer(
            &format!("{op_name} uniform"),
            bytemuck::cast_slice(&uniform_data),
        );

        // Bindings: 0 = input (ro), 1 = grad_output (ro), 2 = grad_input (rw),
        // 3 = uniform params.
        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some(&format!("{op_name} BGL")),
            entries: &[
                storage_entry(0, true),
                storage_entry(1, true),
                storage_entry(2, false),
                uniform_entry(3),
            ],
        });

        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some(&format!("{op_name} BG")),
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: grad_out_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: grad_in_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some(&format!("{op_name} PL")),
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });

        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some(&format!("{op_name} Pipeline")),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Staging buffer for the CPU-visible readback of grad_input.
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(&format!("{op_name} Staging")),
            size: (grad_input.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // Workgroup size 256; split across x/y because a single dispatch
            // dimension is capped at 65535 workgroups. x*y may exceed the
            // needed count — assumes the shader bounds-checks its global
            // index against n (TODO confirm in the WGSL source).
            let total_wg = n.div_ceil(256);
            pass.dispatch_workgroups(total_wg.min(65535), total_wg.div_ceil(65535), 1);
        }
        encoder.copy_buffer_to_buffer(&grad_in_buf, 0, &staging, 0, (grad_input.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        // Map the staging buffer; poll(Wait) blocks until the GPU finishes so
        // the map callback is guaranteed to have fired before we await it.
        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or_else(|| format!("{op_name}: map_async cancelled"))?
            .map_err(|e| format!("{op_name}: map_async failed: {e}"))?;

        let data = slice.get_mapped_range();
        grad_input.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();

        Ok(())
    }
172
173 fn create_storage_buffer(&self, label: &str, data: &[f32], read_only: bool) -> wgpu::Buffer {
176 let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
177 label: Some(label),
178 size: (data.len() * 4) as u64,
179 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
180 mapped_at_creation: false,
181 });
182 self.queue.write_buffer(&buf, 0, bytemuck::cast_slice(data));
183 let _ = read_only; buf
185 }
186
187 fn create_rw_storage_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
188 self.device.create_buffer(&wgpu::BufferDescriptor {
189 label: Some(label),
190 size,
191 usage: wgpu::BufferUsages::STORAGE
192 | wgpu::BufferUsages::COPY_SRC
193 | wgpu::BufferUsages::COPY_DST,
194 mapped_at_creation: false,
195 })
196 }
197
198 fn create_uniform_buffer(&self, label: &str, data: &[u8]) -> wgpu::Buffer {
199 let buf = self.device.create_buffer(&wgpu::BufferDescriptor {
200 label: Some(label),
201 size: data.len() as u64,
202 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
203 mapped_at_creation: false,
204 });
205 self.queue.write_buffer(&buf, 0, data);
206 buf
207 }
208
209 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
216 pub fn gemm_backward_a(
217 &self,
218 grad_c: &[f32],
219 b: &[f32],
220 grad_a: &mut [f32],
221 m: u32,
222 k: u32,
223 n: u32,
224 ) -> Result<(), String> {
225 runtime::block_on(self.gemm_backward_a_async(grad_c, b, grad_a, m, k, n))
226 }
227
228 pub async fn gemm_backward_a_async(
230 &self,
231 grad_c: &[f32],
232 b: &[f32],
233 grad_a: &mut [f32],
234 m: u32,
235 k: u32,
236 n: u32,
237 ) -> Result<(), String> {
238 self.execute_backward_gemm(
239 "GEMM Backward A",
240 shaders::backward::GEMM_BACKWARD_A_SHADER,
241 grad_c,
242 b,
243 grad_a,
244 m,
245 k,
246 n,
247 )
248 .await
249 }
250
251 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
253 pub fn gemm_backward_b(
254 &self,
255 a: &[f32],
256 grad_c: &[f32],
257 grad_b: &mut [f32],
258 m: u32,
259 k: u32,
260 n: u32,
261 ) -> Result<(), String> {
262 runtime::block_on(self.gemm_backward_b_async(a, grad_c, grad_b, m, k, n))
263 }
264
265 pub async fn gemm_backward_b_async(
267 &self,
268 a: &[f32],
269 grad_c: &[f32],
270 grad_b: &mut [f32],
271 m: u32,
272 k: u32,
273 n: u32,
274 ) -> Result<(), String> {
275 self.execute_backward_gemm(
276 "GEMM Backward B",
277 shaders::backward::GEMM_BACKWARD_B_SHADER,
278 a,
279 grad_c,
280 grad_b,
281 m,
282 k,
283 n,
284 )
285 .await
286 }
287
    /// Shared driver for the two GEMM backward kernels: binds `buf_a` and
    /// `buf_b` read-only, `output` read-write, and a uniform with (m, k, n),
    /// dispatches a 16×16-tiled 2D grid over the output matrix, and reads the
    /// result back into `output`.
    async fn execute_backward_gemm(
        &self,
        op_name: &str,
        shader_source: &str,
        buf_a: &[f32],
        buf_b: &[f32],
        output: &mut [f32],
        m: u32,
        k: u32,
        n: u32,
    ) -> Result<(), String> {
        use wgpu;

        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some(&format!("{op_name} Shader")),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        let a_buf = self.create_storage_buffer(&format!("{op_name} A"), buf_a, true);
        let b_buf = self.create_storage_buffer(&format!("{op_name} B"), buf_b, true);
        let out_buf =
            self.create_rw_storage_buffer(&format!("{op_name} Output"), (output.len() * 4) as u64);

        // Uniform padded to 16 bytes; last word unused.
        let dims: [u32; 4] = [m, k, n, 0];
        let uniform_buf =
            self.create_uniform_buffer(&format!("{op_name} Dims"), bytemuck::cast_slice(&dims));

        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, true),
                storage_entry(1, true),
                storage_entry(2, false),
                uniform_entry(3),
            ],
        });

        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: a_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: b_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: out_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });

        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some(&format!("{op_name} Pipeline")),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (output.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);

            // Output shape differs per caller: grad_A is m×k, grad_B is k×n.
            // The uppercase 'A' only occurs in "GEMM Backward A" ("Backward"
            // contains only lowercase 'a'), so this works today — but
            // NOTE(review): string-based dispatch on op_name is fragile;
            // an explicit flag/enum parameter would be safer.
            let out_rows = if op_name.contains("A") { m } else { k };
            let out_cols = if op_name.contains("A") { k } else { n };
            pass.dispatch_workgroups(out_rows.div_ceil(16), out_cols.div_ceil(16), 1);
        }
        encoder.copy_buffer_to_buffer(&out_buf, 0, &staging, 0, (output.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        // Blocking readback: poll(Wait) ensures the map callback has fired.
        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or_else(|| format!("{op_name}: map cancelled"))?
            .map_err(|e| format!("{op_name}: map failed: {e}"))?;

        let data = slice.get_mapped_range();
        output.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();

        Ok(())
    }
398
399 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
403 pub fn rope_backward(
404 &self,
405 grad_output: &[f32],
406 grad_input: &mut [f32],
407 num_heads: u32,
408 head_dim: u32,
409 seq_len: u32,
410 theta: f32,
411 ) -> Result<(), String> {
412 runtime::block_on(self.rope_backward_async(
413 grad_output,
414 grad_input,
415 num_heads,
416 head_dim,
417 seq_len,
418 theta,
419 ))
420 }
421
422 pub async fn rope_backward_async(
424 &self,
425 grad_output: &[f32],
426 grad_input: &mut [f32],
427 num_heads: u32,
428 head_dim: u32,
429 seq_len: u32,
430 theta: f32,
431 ) -> Result<(), String> {
432 use wgpu;
433
434 let n = grad_output.len();
435 let total_pairs = num_heads * seq_len * (head_dim / 2);
436
437 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
438 label: Some("RoPE Backward Shader"),
439 source: wgpu::ShaderSource::Wgsl(shaders::backward::ROPE_BACKWARD_SHADER.into()),
440 });
441
442 let go_buf = self.create_storage_buffer("rope_bwd grad_out", grad_output, true);
443 let gi_buf = self.create_rw_storage_buffer("rope_bwd grad_in", (n * 4) as u64);
444
445 let params: [u32; 4] = [num_heads, head_dim, seq_len, theta.log2().to_bits()];
447 let uniform_buf =
448 self.create_uniform_buffer("rope_bwd params", bytemuck::cast_slice(¶ms));
449
450 let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
451 label: None,
452 entries: &[storage_entry(0, true), storage_entry(1, false), uniform_entry(2)],
453 });
454 let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
455 label: None,
456 layout: &bgl,
457 entries: &[
458 wgpu::BindGroupEntry { binding: 0, resource: go_buf.as_entire_binding() },
459 wgpu::BindGroupEntry { binding: 1, resource: gi_buf.as_entire_binding() },
460 wgpu::BindGroupEntry { binding: 2, resource: uniform_buf.as_entire_binding() },
461 ],
462 });
463
464 let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
465 label: None,
466 bind_group_layouts: &[&bgl],
467 push_constant_ranges: &[],
468 });
469 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
470 label: Some("RoPE Backward"),
471 layout: Some(&pl),
472 module: &shader,
473 entry_point: Some("main"),
474 compilation_options: Default::default(),
475 cache: None,
476 });
477
478 let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
479 label: None,
480 size: (n * 4) as u64,
481 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
482 mapped_at_creation: false,
483 });
484
485 let mut encoder =
486 self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
487 {
488 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
489 pass.set_pipeline(&pipeline);
490 pass.set_bind_group(0, &bg, &[]);
491 let total_wg = total_pairs.div_ceil(256);
492 pass.dispatch_workgroups(total_wg.min(65535), total_wg.div_ceil(65535), 1);
493 }
494 encoder.copy_buffer_to_buffer(&gi_buf, 0, &staging, 0, (n * 4) as u64);
495 self.queue.submit(Some(encoder.finish()));
496
497 let slice = staging.slice(..);
498 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
499 slice.map_async(wgpu::MapMode::Read, move |r| {
500 sender.send(r).ok();
501 });
502 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
503 receiver
504 .receive()
505 .await
506 .ok_or("RoPE backward: cancelled".to_string())?
507 .map_err(|e| format!("RoPE backward: {e}"))?;
508 let data = slice.get_mapped_range();
509 grad_input.copy_from_slice(bytemuck::cast_slice(&data));
510 drop(data);
511 staging.unmap();
512 Ok(())
513 }
514
515 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
517 pub fn adamw_step(
518 &self,
519 params: &mut [f32],
520 grads: &[f32],
521 m: &mut [f32],
522 v: &mut [f32],
523 lr: f32,
524 beta1: f32,
525 beta2: f32,
526 eps: f32,
527 weight_decay: f32,
528 step: u32,
529 ) -> Result<(), String> {
530 runtime::block_on(self.adamw_step_async(
531 params,
532 grads,
533 m,
534 v,
535 lr,
536 beta1,
537 beta2,
538 eps,
539 weight_decay,
540 step,
541 ))
542 }
543
544 pub async fn adamw_step_async(
546 &self,
547 params: &mut [f32],
548 grads: &[f32],
549 m: &mut [f32],
550 v: &mut [f32],
551 lr: f32,
552 beta1: f32,
553 beta2: f32,
554 eps: f32,
555 weight_decay: f32,
556 step: u32,
557 ) -> Result<(), String> {
558 use wgpu;
559
560 let n = params.len() as u32;
561 let bc1 = 1.0 - beta1.powi(step as i32);
562 let bc2 = 1.0 - beta2.powi(step as i32);
563
564 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
565 label: Some("AdamW Step"),
566 source: wgpu::ShaderSource::Wgsl(shaders::backward::ADAMW_STEP_SHADER.into()),
567 });
568
569 let params_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
571 label: Some("adamw params"),
572 size: (params.len() * 4) as u64,
573 usage: wgpu::BufferUsages::STORAGE
574 | wgpu::BufferUsages::COPY_DST
575 | wgpu::BufferUsages::COPY_SRC,
576 mapped_at_creation: false,
577 });
578 self.queue.write_buffer(¶ms_buf, 0, bytemuck::cast_slice(params));
579
580 let grads_buf = self.create_storage_buffer("adamw grads", grads, true);
581
582 let m_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
583 label: Some("adamw m"),
584 size: (m.len() * 4) as u64,
585 usage: wgpu::BufferUsages::STORAGE
586 | wgpu::BufferUsages::COPY_DST
587 | wgpu::BufferUsages::COPY_SRC,
588 mapped_at_creation: false,
589 });
590 self.queue.write_buffer(&m_buf, 0, bytemuck::cast_slice(m));
591
592 let v_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
593 label: Some("adamw v"),
594 size: (v.len() * 4) as u64,
595 usage: wgpu::BufferUsages::STORAGE
596 | wgpu::BufferUsages::COPY_DST
597 | wgpu::BufferUsages::COPY_SRC,
598 mapped_at_creation: false,
599 });
600 self.queue.write_buffer(&v_buf, 0, bytemuck::cast_slice(v));
601
602 let hp: [u32; 8] = [
605 n,
606 lr.to_bits(),
607 beta1.to_bits(),
608 beta2.to_bits(),
609 eps.to_bits(),
610 weight_decay.to_bits(),
611 bc1.to_bits(),
612 bc2.to_bits(),
613 ];
614 let uniform_buf = self.create_uniform_buffer("adamw hp", bytemuck::cast_slice(&hp));
615
616 let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
617 label: None,
618 entries: &[
619 storage_entry(0, false), storage_entry(1, true), storage_entry(2, false), storage_entry(3, false), uniform_entry(4),
624 ],
625 });
626 let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
627 label: None,
628 layout: &bgl,
629 entries: &[
630 wgpu::BindGroupEntry { binding: 0, resource: params_buf.as_entire_binding() },
631 wgpu::BindGroupEntry { binding: 1, resource: grads_buf.as_entire_binding() },
632 wgpu::BindGroupEntry { binding: 2, resource: m_buf.as_entire_binding() },
633 wgpu::BindGroupEntry { binding: 3, resource: v_buf.as_entire_binding() },
634 wgpu::BindGroupEntry { binding: 4, resource: uniform_buf.as_entire_binding() },
635 ],
636 });
637
638 let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
639 label: None,
640 bind_group_layouts: &[&bgl],
641 push_constant_ranges: &[],
642 });
643 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
644 label: Some("AdamW"),
645 layout: Some(&pl),
646 module: &shader,
647 entry_point: Some("main"),
648 compilation_options: Default::default(),
649 cache: None,
650 });
651
652 let params_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
654 label: None,
655 size: (params.len() * 4) as u64,
656 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
657 mapped_at_creation: false,
658 });
659 let m_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
660 label: None,
661 size: (m.len() * 4) as u64,
662 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
663 mapped_at_creation: false,
664 });
665 let v_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
666 label: None,
667 size: (v.len() * 4) as u64,
668 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
669 mapped_at_creation: false,
670 });
671
672 let mut encoder =
673 self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
674 {
675 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
676 pass.set_pipeline(&pipeline);
677 pass.set_bind_group(0, &bg, &[]);
678 let total_wg = n.div_ceil(256);
680 pass.dispatch_workgroups(total_wg.min(65535), total_wg.div_ceil(65535), 1);
681 }
682 encoder.copy_buffer_to_buffer(
683 ¶ms_buf,
684 0,
685 ¶ms_staging,
686 0,
687 (params.len() * 4) as u64,
688 );
689 encoder.copy_buffer_to_buffer(&m_buf, 0, &m_staging, 0, (m.len() * 4) as u64);
690 encoder.copy_buffer_to_buffer(&v_buf, 0, &v_staging, 0, (v.len() * 4) as u64);
691 self.queue.submit(Some(encoder.finish()));
692
693 let read_buf = |staging: &wgpu::Buffer, out: &mut [f32]| -> Result<(), String> {
695 let slice = staging.slice(..);
696 let (tx, rx) = std::sync::mpsc::channel();
697 slice.map_async(wgpu::MapMode::Read, move |r| {
698 tx.send(r).ok();
699 });
700 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
701 rx.recv()
702 .map_err(|e| format!("AdamW readback: {e}"))?
703 .map_err(|e| format!("AdamW map: {e}"))?;
704 let data = slice.get_mapped_range();
705 out.copy_from_slice(bytemuck::cast_slice(&data));
706 drop(data);
707 staging.unmap();
708 Ok(())
709 };
710 read_buf(¶ms_staging, params)?;
711 read_buf(&m_staging, m)?;
712 read_buf(&v_staging, v)?;
713
714 Ok(())
715 }
716
717 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
724 pub fn rmsnorm_backward(
725 &self,
726 input: &[f32],
727 gamma: &[f32],
728 grad_output: &[f32],
729 grad_input: &mut [f32],
730 grad_gamma: &mut [f32],
731 num_rows: u32,
732 hidden_dim: u32,
733 eps: f32,
734 ) -> Result<(), String> {
735 runtime::block_on(self.rmsnorm_backward_async(
736 input,
737 gamma,
738 grad_output,
739 grad_input,
740 grad_gamma,
741 num_rows,
742 hidden_dim,
743 eps,
744 ))
745 }
746
    /// GPU RMSNorm backward: computes per-element `grad_input` and the
    /// per-feature `grad_gamma` (length `hidden_dim`) in a single pass, then
    /// reads both back through separate staging buffers.
    pub async fn rmsnorm_backward_async(
        &self,
        input: &[f32],
        gamma: &[f32],
        grad_output: &[f32],
        grad_input: &mut [f32],
        grad_gamma: &mut [f32],
        num_rows: u32,
        hidden_dim: u32,
        eps: f32,
    ) -> Result<(), String> {
        use wgpu;

        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("RMSNorm Backward"),
            source: wgpu::ShaderSource::Wgsl(shaders::backward::RMSNORM_BACKWARD_SHADER.into()),
        });

        let input_buf = self.create_storage_buffer("rms_bwd input", input, true);
        let gamma_buf = self.create_storage_buffer("rms_bwd gamma", gamma, true);
        let grad_out_buf = self.create_storage_buffer("rms_bwd grad_out", grad_output, true);
        let grad_in_buf =
            self.create_rw_storage_buffer("rms_bwd grad_in", (grad_input.len() * 4) as u64);

        let grad_gamma_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("rms_bwd grad_gamma"),
            size: (hidden_dim as usize * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        // grad_gamma is zero-initialised up front — presumably the shader
        // accumulates contributions from every row into it rather than
        // overwriting (TODO confirm against RMSNORM_BACKWARD_SHADER).
        let zeros = vec![0u8; hidden_dim as usize * 4];
        self.queue.write_buffer(&grad_gamma_buf, 0, &zeros);

        // eps travels as raw f32 bits; last word is padding.
        let params: [u32; 4] = [num_rows, hidden_dim, eps.to_bits(), 0];
        let uniform_buf =
            self.create_uniform_buffer("rms_bwd params", bytemuck::cast_slice(&params));

        // Bindings: 0 = input, 1 = gamma, 2 = grad_output (all ro);
        // 3 = grad_input, 4 = grad_gamma (rw); 5 = params.
        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, true), storage_entry(1, true), storage_entry(2, true), storage_entry(3, false), storage_entry(4, false), uniform_entry(5),
            ],
        });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: gamma_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: grad_out_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: grad_in_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 4, resource: grad_gamma_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 5, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("RMSNorm Backward"),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let gi_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (grad_input.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let gg_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (hidden_dim as usize * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // One workgroup per row of the (num_rows × hidden_dim) input.
            pass.dispatch_workgroups(num_rows, 1, 1);
        }
        encoder.copy_buffer_to_buffer(
            &grad_in_buf,
            0,
            &gi_staging,
            0,
            (grad_input.len() * 4) as u64,
        );
        encoder.copy_buffer_to_buffer(
            &grad_gamma_buf,
            0,
            &gg_staging,
            0,
            (hidden_dim as usize * 4) as u64,
        );
        self.queue.submit(Some(encoder.finish()));

        // Readback 1: grad_input, plain f32 copy.
        {
            let slice = gi_staging.slice(..);
            let (tx, rx) = std::sync::mpsc::channel();
            slice.map_async(wgpu::MapMode::Read, move |r| {
                tx.send(r).ok();
            });
            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
            rx.recv()
                .map_err(|e| format!("RMSNorm bwd gi: {e}"))?
                .map_err(|e| format!("RMSNorm bwd gi map: {e}"))?;
            let data = slice.get_mapped_range();
            grad_input.copy_from_slice(bytemuck::cast_slice(&data));
            drop(data);
            gi_staging.unmap();
        }
        // Readback 2: grad_gamma, decoded word-by-word with f32::from_bits —
        // presumably because the shader writes it via u32 atomic operations
        // (TODO confirm; otherwise a plain f32 cast would suffice).
        {
            let slice = gg_staging.slice(..);
            let (tx, rx) = std::sync::mpsc::channel();
            slice.map_async(wgpu::MapMode::Read, move |r| {
                tx.send(r).ok();
            });
            self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
            rx.recv()
                .map_err(|e| format!("RMSNorm bwd gg: {e}"))?
                .map_err(|e| format!("RMSNorm bwd gg map: {e}"))?;
            let data = slice.get_mapped_range();
            let raw: &[u32] = bytemuck::cast_slice(&data);
            for (i, &bits) in raw.iter().enumerate() {
                grad_gamma[i] = f32::from_bits(bits);
            }
            drop(data);
            gg_staging.unmap();
        }

        Ok(())
    }
906
907 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
913 pub fn nf4_dequant(
914 &self,
915 packed: &[u32],
916 scales: &[f32],
917 output: &mut [f32],
918 n: u32,
919 block_size: u32,
920 ) -> Result<(), String> {
921 runtime::block_on(self.nf4_dequant_async(packed, scales, output, n, block_size))
922 }
923
    /// GPU NF4 dequantisation: expands `packed` 4-bit codes with per-block
    /// `scales` into `n` f32 values written to `output`.
    ///
    /// The exact packing (presumably eight 4-bit codes per u32, one scale per
    /// `block_size` elements) is defined by NF4_DEQUANT_SHADER — confirm
    /// there before changing the buffer layout here.
    pub async fn nf4_dequant_async(
        &self,
        packed: &[u32],
        scales: &[f32],
        output: &mut [f32],
        n: u32,
        block_size: u32,
    ) -> Result<(), String> {
        use wgpu;

        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("NF4 Dequant"),
            source: wgpu::ShaderSource::Wgsl(shaders::backward::NF4_DEQUANT_SHADER.into()),
        });

        // packed is &[u32], so the f32-typed create_storage_buffer helper
        // can't be reused; upload it manually.
        let packed_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("nf4 packed"),
            size: (packed.len() * 4) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&packed_buf, 0, bytemuck::cast_slice(packed));

        let scales_buf = self.create_storage_buffer("nf4 scales", scales, true);
        let output_buf = self.create_rw_storage_buffer("nf4 output", (output.len() * 4) as u64);

        // Uniform padded to 16 bytes; last two words unused.
        let params: [u32; 4] = [n, block_size, 0, 0];
        let uniform_buf = self.create_uniform_buffer("nf4 params", bytemuck::cast_slice(&params));

        // Bindings: 0 = packed (ro), 1 = scales (ro), 2 = output (rw), 3 = params.
        let bgl = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: None,
            entries: &[
                storage_entry(0, true), storage_entry(1, true), storage_entry(2, false), uniform_entry(3),
            ],
        });
        let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &bgl,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: packed_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: scales_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: uniform_buf.as_entire_binding() },
            ],
        });

        let pl = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: None,
            bind_group_layouts: &[&bgl],
            push_constant_ranges: &[],
        });
        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("NF4 Dequant"),
            layout: Some(&pl),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: None,
            size: (output.len() * 4) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bg, &[]);
            // Workgroup size 256; x capped at 65535, overflow spills into y.
            let total_wg = n.div_ceil(256);
            let x = total_wg.min(65535);
            let y = total_wg.div_ceil(65535);
            pass.dispatch_workgroups(x, y, 1);
        }
        encoder.copy_buffer_to_buffer(&output_buf, 0, &staging, 0, (output.len() * 4) as u64);
        self.queue.submit(Some(encoder.finish()));

        // Blocking readback of the dequantised output.
        let slice = staging.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            sender.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        receiver
            .receive()
            .await
            .ok_or("NF4 dequant: cancelled".to_string())?
            .map_err(|e| format!("NF4 dequant: {e}"))?;
        let data = slice.get_mapped_range();
        output.copy_from_slice(bytemuck::cast_slice(&data));
        drop(data);
        staging.unmap();

        Ok(())
    }
1030}
1031
1032fn storage_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
1033 wgpu::BindGroupLayoutEntry {
1034 binding,
1035 visibility: wgpu::ShaderStages::COMPUTE,
1036 ty: wgpu::BindingType::Buffer {
1037 ty: wgpu::BufferBindingType::Storage { read_only },
1038 has_dynamic_offset: false,
1039 min_binding_size: None,
1040 },
1041 count: None,
1042 }
1043}
1044
1045fn uniform_entry(binding: u32) -> wgpu::BindGroupLayoutEntry {
1046 wgpu::BindGroupLayoutEntry {
1047 binding,
1048 visibility: wgpu::ShaderStages::COMPUTE,
1049 ty: wgpu::BindingType::Buffer {
1050 ty: wgpu::BufferBindingType::Uniform,
1051 has_dynamic_offset: false,
1052 min_binding_size: None,
1053 },
1054 count: None,
1055 }
1056}
1057
1058#[cfg(all(test, feature = "gpu"))]
1059mod tests {
1060 use super::*;
1061
1062 fn silu_backward_cpu(input: &[f32], grad_output: &[f32]) -> Vec<f32> {
1064 input
1065 .iter()
1066 .zip(grad_output.iter())
1067 .map(|(&x, &dy)| {
1068 let sigmoid = 1.0 / (1.0 + (-x).exp());
1069 let y = x * sigmoid;
1070 let silu_prime = sigmoid * (1.0 + x - y);
1071 dy * silu_prime
1072 })
1073 .collect()
1074 }
1075
1076 #[test]
1078 fn test_falsify_wgpu_001_silu_backward_parity() {
1079 let device = GpuDevice::new().expect("GPU device");
1080
1081 let input: Vec<f32> = (-50..50).map(|i| i as f32 * 0.1).collect();
1082 let grad_output: Vec<f32> = (0..100).map(|i| (i as f32 - 50.0) * 0.01).collect();
1083 let expected = silu_backward_cpu(&input, &grad_output);
1084
1085 let mut grad_input = vec![0.0f32; 100];
1086 device.silu_backward(&input, &grad_output, &mut grad_input).expect("silu_backward");
1087
1088 let max_diff = grad_input
1089 .iter()
1090 .zip(expected.iter())
1091 .map(|(a, b)| (a - b).abs())
1092 .fold(0.0f32, f32::max);
1093
1094 assert!(
1095 max_diff < 1e-4,
1096 "FALSIFY-WGPU-001: SiLU backward max diff = {max_diff} (threshold: 1e-4)"
1097 );
1098 }
1099
1100 #[test]
1102 fn test_silu_backward_at_zero() {
1103 let device = GpuDevice::new().expect("GPU device");
1104
1105 let input = vec![0.0f32; 4];
1106 let grad_output = vec![1.0f32; 4];
1107 let mut grad_input = vec![0.0f32; 4];
1108
1109 device.silu_backward(&input, &grad_output, &mut grad_input).expect("silu_backward");
1110
1111 for &g in &grad_input {
1113 assert!((g - 0.5).abs() < 1e-5, "silu'(0) should be 0.5, got {g}");
1114 }
1115 }
1116
1117 #[test]
1119 fn test_silu_backward_length_mismatch() {
1120 let device = GpuDevice::new().expect("GPU device");
1121
1122 let input = vec![1.0f32; 10];
1123 let grad_output = vec![1.0f32; 5]; let mut grad_input = vec![0.0f32; 10];
1125
1126 let result = device.silu_backward(&input, &grad_output, &mut grad_input);
1127 assert!(result.is_err());
1128 }
1129
1130 fn matmul_cpu(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
1132 let mut c = vec![0.0f32; m * n];
1133 for i in 0..m {
1134 for j in 0..n {
1135 let mut sum = 0.0f32;
1136 for p in 0..k {
1137 sum += a[i * k + p] * b[p * n + j];
1138 }
1139 c[i * n + j] = sum;
1140 }
1141 }
1142 c
1143 }
1144
1145 #[test]
1150 fn test_falsify_wgpu_001_gemm_backward_a_parity() {
1151 let device = GpuDevice::new().expect("GPU device");
1152
1153 let (m, k, n) = (4, 8, 6);
1154
1155 let grad_c: Vec<f32> = (0..m * n).map(|i| (i as f32 - 12.0) * 0.1).collect();
1157 let b: Vec<f32> = (0..k * n).map(|i| (i as f32 - 24.0) * 0.05).collect();
1158
1159 let mut b_t = vec![0.0f32; n * k];
1162 for i in 0..k {
1163 for j in 0..n {
1164 b_t[j * k + i] = b[i * n + j];
1165 }
1166 }
1167 let expected = matmul_cpu(&grad_c, &b_t, m, n, k);
1168
1169 let mut grad_a = vec![0.0f32; m * k];
1170 device
1171 .gemm_backward_a(&grad_c, &b, &mut grad_a, m as u32, k as u32, n as u32)
1172 .expect("gemm_backward_a");
1173
1174 let max_diff =
1175 grad_a.iter().zip(expected.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
1176
1177 assert!(
1178 max_diff < 1e-3,
1179 "FALSIFY-WGPU-001: GEMM backward A max diff = {max_diff} (threshold: 1e-3)"
1180 );
1181 }
1182
1183 #[test]
1187 fn test_falsify_wgpu_001_gemm_backward_b_parity() {
1188 let device = GpuDevice::new().expect("GPU device");
1189
1190 let (m, k, n) = (4, 8, 6);
1191
1192 let a: Vec<f32> = (0..m * k).map(|i| (i as f32 - 16.0) * 0.1).collect();
1193 let grad_c: Vec<f32> = (0..m * n).map(|i| (i as f32 - 12.0) * 0.05).collect();
1194
1195 let mut a_t = vec![0.0f32; k * m];
1197 for i in 0..m {
1198 for j in 0..k {
1199 a_t[j * m + i] = a[i * k + j];
1200 }
1201 }
1202 let expected = matmul_cpu(&a_t, &grad_c, k, m, n);
1203
1204 let mut grad_b = vec![0.0f32; k * n];
1205 device
1206 .gemm_backward_b(&a, &grad_c, &mut grad_b, m as u32, k as u32, n as u32)
1207 .expect("gemm_backward_b");
1208
1209 let max_diff =
1210 grad_b.iter().zip(expected.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
1211
1212 assert!(
1213 max_diff < 1e-3,
1214 "FALSIFY-WGPU-001: GEMM backward B max diff = {max_diff} (threshold: 1e-3)"
1215 );
1216 }
1217
    /// GPU RoPE backward must match a CPU reference rotation (FALSIFY-WGPU-001).
    ///
    /// The index math below assumes a contiguous [num_heads, seq_len, head_dim]
    /// layout, with each (2p, 2p+1) dimension pair forming one rotated plane.
    #[test]
    fn test_falsify_wgpu_001_rope_backward_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let (num_heads, head_dim, seq_len) = (2, 4, 3);
        let theta = 10000.0f32;
        let n = num_heads * head_dim * seq_len;

        let grad_output: Vec<f32> = (0..n).map(|i| (i as f32 - 12.0) * 0.1).collect();

        // CPU reference: the backward of a rotation is the transposed rotation.
        let half_dim = head_dim / 2;
        let mut expected = vec![0.0f32; n];
        for h in 0..num_heads {
            for s in 0..seq_len {
                for p in 0..half_dim {
                    // inv_freq = theta^(-2p / head_dim), written as
                    // 2^(-2p/head_dim * log2(theta)) — presumably mirroring the
                    // shader's exp2-based evaluation; TODO confirm against WGSL.
                    let freq_exp = -((2 * p) as f32) / head_dim as f32 * theta.log2();
                    let inv_freq = 2.0f32.powf(freq_exp);
                    let angle = s as f32 * inv_freq;
                    let (sin_a, cos_a) = angle.sin_cos();

                    // Flat offset of (head h, position s); pair p occupies
                    // elements 2p and 2p+1 of the head_dim axis.
                    let base = h * seq_len * head_dim + s * head_dim;
                    let even = base + 2 * p;
                    let odd = base + 2 * p + 1;

                    let dy_even = grad_output[even];
                    let dy_odd = grad_output[odd];

                    // Apply R(angle)^T to the upstream gradient pair:
                    // dx_e =  dy_e*cos + dy_o*sin; dx_o = -dy_e*sin + dy_o*cos.
                    expected[even] = dy_even * cos_a + dy_odd * sin_a;
                    expected[odd] = -dy_even * sin_a + dy_odd * cos_a;
                }
            }
        }

        let mut grad_input = vec![0.0f32; n];
        device
            .rope_backward(
                &grad_output,
                &mut grad_input,
                num_heads as u32,
                head_dim as u32,
                seq_len as u32,
                theta,
            )
            .expect("rope_backward");

        // Element-wise max absolute deviation between GPU and CPU results.
        let max_diff = grad_input
            .iter()
            .zip(expected.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(
            max_diff < 1e-4,
            "FALSIFY-WGPU-001: RoPE backward max diff = {max_diff} (threshold: 1e-4)"
        );
    }
1277
1278 #[test]
1280 fn test_falsify_wgpu_001_adamw_step_parity() {
1281 let device = GpuDevice::new().expect("GPU device");
1282
1283 let n = 16;
1284 let mut params: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
1285 let grads: Vec<f32> = (0..n).map(|i| (i as f32 - 8.0) * 0.01).collect();
1286 let mut m_state = vec![0.0f32; n];
1287 let mut v_state = vec![0.0f32; n];
1288
1289 let lr: f32 = 1e-3;
1290 let beta1: f32 = 0.9;
1291 let beta2: f32 = 0.999;
1292 let eps: f32 = 1e-8;
1293 let wd: f32 = 0.01;
1294 let step = 1u32;
1295
1296 let bc1: f32 = 1.0 - beta1.powi(step as i32);
1298 let bc2: f32 = 1.0 - beta2.powi(step as i32);
1299 let mut cpu_params = params.clone();
1300 let mut cpu_m = m_state.clone();
1301 let mut cpu_v = v_state.clone();
1302 for i in 0..n {
1303 cpu_m[i] = beta1 * cpu_m[i] + (1.0 - beta1) * grads[i];
1304 cpu_v[i] = beta2 * cpu_v[i] + (1.0 - beta2) * grads[i] * grads[i];
1305 let m_hat = cpu_m[i] / bc1;
1306 let v_hat = cpu_v[i] / bc2;
1307 cpu_params[i] -= lr * (m_hat / (v_hat.sqrt() + eps) + wd * cpu_params[i]);
1308 }
1309
1310 device
1311 .adamw_step(
1312 &mut params,
1313 &grads,
1314 &mut m_state,
1315 &mut v_state,
1316 lr as f32,
1317 beta1 as f32,
1318 beta2 as f32,
1319 eps as f32,
1320 wd as f32,
1321 step,
1322 )
1323 .expect("adamw_step");
1324
1325 let max_diff =
1326 params.iter().zip(cpu_params.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
1327
1328 assert!(
1329 max_diff < 1e-4,
1330 "FALSIFY-WGPU-001: AdamW step max diff = {max_diff} (threshold: 1e-4)"
1331 );
1332 }
1333
    /// GPU RMSNorm backward must match a CPU reference for both grad_input and
    /// grad_gamma (FALSIFY-WGPU-001).
    ///
    /// Reference math, per row: rms = sqrt(mean(x^2) + eps),
    /// dL/dx_i = (gamma_i*gy_i - x_i/(mean(x^2)+eps) * mean(x*gy*gamma)) / rms,
    /// dL/dgamma_i = sum over rows of gy_i * x_i / rms.
    #[test]
    fn test_falsify_wgpu_001_rmsnorm_backward_parity() {
        let device = GpuDevice::new().expect("GPU device");

        let (num_rows, hidden_dim) = (3, 8);
        let eps: f32 = 1e-5;
        let n = num_rows * hidden_dim;

        let input: Vec<f32> = (0..n).map(|i| (i as f32 - 12.0) * 0.1).collect();
        let gamma: Vec<f32> = (0..hidden_dim).map(|i| 1.0 + i as f32 * 0.1).collect();
        let grad_output: Vec<f32> = (0..n).map(|i| (i as f32 - 12.0) * 0.05).collect();

        // CPU reference, computed row by row; grad_gamma accumulates across rows.
        let mut cpu_grad_input = vec![0.0f32; n];
        let mut cpu_grad_gamma = vec![0.0f32; hidden_dim];
        for r in 0..num_rows {
            let row = &input[r * hidden_dim..(r + 1) * hidden_dim];
            let grow = &grad_output[r * hidden_dim..(r + 1) * hidden_dim];

            // Row statistics: mean of squares, then rms = sqrt(mean + eps).
            let sum_x2: f32 = row.iter().map(|x| x * x).sum();
            let mean_x2 = sum_x2 / hidden_dim as f32;
            let var_eps = mean_x2 + eps;
            let rms = var_eps.sqrt();
            let inv_rms = 1.0 / rms;

            // mean(x * gy * gamma): the shared term coupling every element of
            // the row through the normalization denominator.
            let sum_xgg: f32 = row
                .iter()
                .zip(grow.iter())
                .zip(gamma.iter())
                .map(|((&x, &gy), &g)| x * gy * g)
                .sum();
            let mean_xgg = sum_xgg / hidden_dim as f32;

            for i in 0..hidden_dim {
                let x = row[i];
                let gy = grow[i];
                let g = gamma[i];
                let gamma_gy = g * gy;
                // Correction from differentiating through the rms denominator.
                let correction = (x / var_eps) * mean_xgg;
                cpu_grad_input[r * hidden_dim + i] = inv_rms * (gamma_gy - correction);
                cpu_grad_gamma[i] += gy * x * inv_rms;
            }
        }

        let mut grad_input = vec![0.0f32; n];
        let mut grad_gamma = vec![0.0f32; hidden_dim];

        device
            .rmsnorm_backward(
                &input,
                &gamma,
                &grad_output,
                &mut grad_input,
                &mut grad_gamma,
                num_rows as u32,
                hidden_dim as u32,
                eps,
            )
            .expect("rmsnorm_backward");

        let gi_max_diff = grad_input
            .iter()
            .zip(cpu_grad_input.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        let gg_max_diff = grad_gamma
            .iter()
            .zip(cpu_grad_gamma.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);

        assert!(
            gi_max_diff < 1e-3,
            "FALSIFY-WGPU-001: RMSNorm grad_input max diff = {gi_max_diff}"
        );
        // Looser threshold: per the message below, grad_gamma is accumulated on
        // the GPU via atomic CAS, so summation order differs from the CPU loop.
        assert!(
            gg_max_diff < 1e-2,
            "FALSIFY-WGPU-001: RMSNorm grad_gamma max diff = {gg_max_diff} (atomic CAS accumulation)"
        );
    }
1416
1417 #[test]
1419 fn test_falsify_wgpu_003_nf4_dequant_parity() {
1420 let device = GpuDevice::new().expect("GPU device");
1421
1422 let nf4_lut: [f32; 16] = [
1424 -1.0,
1425 -0.6961928,
1426 -0.5250731,
1427 -0.39491749,
1428 -0.28444138,
1429 -0.18477343,
1430 -0.09105004,
1431 0.0,
1432 0.0795803,
1433 0.1609302,
1434 0.24611230,
1435 0.33791524,
1436 0.44070983,
1437 0.5626170,
1438 0.7229568,
1439 1.0,
1440 ];
1441
1442 let block_size = 4u32; let n = 8u32; let packed: Vec<u32> = vec![0x90F5_1C73_u32];
1457
1458 let scales: Vec<f32> = vec![2.0, 0.5]; let indices = [3, 7, 12, 1, 5, 15, 0, 9];
1460
1461 let mut expected = vec![0.0f32; n as usize];
1463 for i in 0..n as usize {
1464 let scale = scales[i / block_size as usize];
1465 expected[i] = nf4_lut[indices[i]] * scale;
1466 }
1467
1468 let mut output = vec![0.0f32; n as usize];
1469 device.nf4_dequant(&packed, &scales, &mut output, n, block_size).expect("nf4_dequant");
1470
1471 let max_diff =
1472 output.iter().zip(expected.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
1473
1474 assert!(
1475 max_diff < 1e-6,
1476 "FALSIFY-WGPU-003: NF4 dequant max diff = {max_diff} (threshold: 1e-6)"
1477 );
1478 }
1479}