entrenar/autograd/cuda_forward/
normalization.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10 BatchedRopeBackwardKernel, BatchedRopeKernel, BatchedVectorizedRmsNormKernel,
11 FusedResidualRmsNormKernel, Kernel, LayerNormKernel, PerHeadRmsNormKernel, RopeNeoxKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
#[cfg(feature = "cuda")]
/// Launches the LayerNorm forward kernel (`trueno_gpu::kernels::LayerNormKernel`) over
/// `batch_size` rows of `hidden_size` elements.
///
/// One block is launched per row (`grid.x = batch_size`) with `min(256, hidden_size)`
/// threads. The kernel receives `(input, gamma, beta, output, batch_size, hidden_size)`.
/// No epsilon is passed from here, so whatever epsilon `LayerNormKernel` builds in is used.
///
/// # Arguments
/// * `input` – source activations, presumably row-major `[batch_size, hidden_size]` (not checked here).
/// * `gamma`, `beta` – per-feature scale and shift, presumably of length `hidden_size`.
/// * `output` – destination buffer, same layout as `input`.
/// * `batch_size` – number of rows; `0` is a no-op.
/// * `hidden_size` – row length; `0` is a no-op.
/// * `stream` – stream the launch is enqueued on; synchronize it before reading `output`.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// * Any error propagated from PTX compilation via the kernel cache.
pub fn layer_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    beta: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    // CUDA rejects launches with a zero grid or block dimension (`256.min(0) == 0` below),
    // and an empty problem has nothing to normalize, so treat it as a no-op.
    if batch_size == 0 || hidden_size == 0 {
        return Ok(());
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = LayerNormKernel::new(hidden_size);
    let kernel_name = kernel.name();

    // `hidden_size` is the kernel's only constructor argument, so it is all the key needs.
    let key = format!("layer_norm_forward_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One block per row, at most 256 threads per block.
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (256.min(hidden_size), 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let beta_ptr = beta.as_ptr();
    let output_ptr = output.as_ptr();

    // Kernel parameter array: each entry points at a local holding one argument value,
    // in the kernel's parameter order.
    let mut args: [*mut std::ffi::c_void; 6] = [
        &input_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
        &beta_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a local that stays alive for the whole
    // `launch_kernel` call, and the order matches the argument list above. Buffer extents
    // are NOT validated here; the caller must size `input`/`output`/`gamma`/`beta` for
    // `batch_size` x `hidden_size`.
    unsafe {
        stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("LayerNorm forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
79
#[cfg(feature = "cuda")]
/// RMSNorm forward using the Llama-style default epsilon of `1e-5`.
///
/// Convenience wrapper that forwards to [`rms_norm_forward_with_eps`]. Models configured
/// with a different epsilon (for example Qwen2's `1e-6`) must call
/// [`rms_norm_forward_with_eps`] directly so the right value is compiled into the kernel.
///
/// # Errors
/// Exactly those of [`rms_norm_forward_with_eps`].
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    /// Epsilon applied when the caller does not choose one.
    const DEFAULT_RMS_NORM_EPS: f32 = 1e-5;

    rms_norm_forward_with_eps(
        input,
        gamma,
        output,
        batch_size,
        hidden_size,
        DEFAULT_RMS_NORM_EPS,
        stream,
    )
}
105
#[cfg(feature = "cuda")]
/// RMSNorm forward with an explicit epsilon, using `BatchedVectorizedRmsNormKernel`.
///
/// Normalizes `batch_size` rows of `hidden_size` elements from `input` into `output`, scaled
/// by `gamma`. `eps` is compiled into the PTX (via `with_epsilon`, it is not a launch
/// argument). The kernel-cache key therefore includes its exact bit pattern, so different
/// epsilons (e.g. Llama `1e-5` vs. Qwen2 `1e-6`) never share a module.
///
/// Launch shape:
/// * `grid = (1, batch_size, 1)`: one block per row along `y`.
/// * `block = (256, 1, 1)`.
/// * `8 * 4` bytes of dynamic shared memory: one `f32` per warp of the 256-thread block,
///   presumably scratch for the cross-warp reduction.
///
/// NOTE(review): `batch_size` is passed to the kernel constructor but is not part of the cache
/// key. That is only correct if the emitted PTX does not depend on it (the grid supplies the
/// row count). Confirm against `BatchedVectorizedRmsNormKernel`.
///
/// # Arguments
/// * `input` – presumably row-major `[batch_size, hidden_size]` (not checked here).
/// * `gamma` – per-feature scale, presumably of length `hidden_size`.
/// * `output` – destination, same layout as `input`.
/// * `batch_size` – number of rows; `0` is a no-op.
/// * `hidden_size` – row length; `0` is a no-op.
/// * `eps` – epsilon added to the mean square before the square root.
/// * `stream` – stream the launch is enqueued on; synchronize it before reading `output`.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// * Any error propagated from PTX compilation via the kernel cache.
pub fn rms_norm_forward_with_eps(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    eps: f32,
    stream: &CudaStream,
) -> Result<()> {
    // A zero grid dimension (`grid.y = batch_size`) is rejected by CUDA, and an empty row has
    // nothing to normalize, so an empty problem is a no-op.
    if batch_size == 0 || hidden_size == 0 {
        return Ok(());
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size).with_epsilon(eps);

    // eps is baked into the PTX, so key on its exact bits to keep a module compiled for one
    // epsilon from shadowing another.
    let eps_bits = eps.to_bits();
    let key = format!("batched_rmsnorm_fwd_{hidden_size}_eps{eps_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig {
        // One block per row, rows laid out along grid.y.
        grid: (1, batch_size, 1),
        block: (256, 1, 1),
        // 8 warps x 4 bytes: one f32 slot per warp of the 256-thread block.
        shared_mem: 8 * 4,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: each `args` entry points at a local that stays alive for the whole
    // `launch_kernel` call, ordered (input, output, gamma). Buffer extents are NOT validated
    // here; the caller must size `input`/`output` for `batch_size * hidden_size` elements and
    // `gamma` for `hidden_size`.
    unsafe {
        stream.launch_kernel(module, "batched_rmsnorm_vectorized", &config, &mut args).map_err(
            |e| CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}")),
        )?;
    }

    Ok(())
}
177
#[cfg(feature = "cuda")]
/// Per-head RMSNorm over a single token row.
///
/// `input` and `output` are addressed in rows of `num_heads * head_dim` floats, and only row
/// `pos_offset` is touched. The launch uses `grid.x = num_heads` with 32 threads per block,
/// presumably one block per head slice of `head_dim` elements.
///
/// `gamma` is passed without an offset. Presumably it is one `head_dim`-length scale shared by
/// every head; confirm against `PerHeadRmsNormKernel`.
///
/// NOTE(review): no epsilon is threaded through here, so the kernel's built-in value is used.
/// Confirm it matches the model config (cf. `rms_norm_forward_with_eps`).
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// * Any error propagated from PTX compilation via the kernel cache.
pub fn per_head_rmsnorm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos_offset: usize,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let per_head_kernel = PerHeadRmsNormKernel::new(head_dim, num_heads);
    let cache_key = format!("per_head_rmsnorm_fwd_{head_dim}_{num_heads}");
    let module = if let Some(compiled) = guard.get_cached(&cache_key) {
        compiled
    } else {
        let ptx = per_head_kernel.emit_ptx_for_target(guard.sm_target());
        guard.get_or_compile(&cache_key, &ptx)?
    };

    let config = LaunchConfig {
        grid: (num_heads, 1, 1),
        block: (32, 1, 1),
        shared_mem: 0,
    };

    // Input and output use the same row stride, so one byte offset locates the row in both.
    let row_len = (num_heads * head_dim) as usize;
    let row_byte_offset = (pos_offset * row_len * std::mem::size_of::<f32>()) as u64;
    let input_ptr = input.as_ptr() + row_byte_offset;
    let output_ptr = output.as_ptr() + row_byte_offset;
    let gamma_ptr = gamma.as_ptr();

    let mut kernel_args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: each `kernel_args` entry points at a live local for the duration of the call,
    // ordered (input, output, gamma). The row offset is NOT bounds-checked; the caller must
    // ensure row `pos_offset` exists in both buffers.
    let launched =
        unsafe { stream.launch_kernel(module, "per_head_rmsnorm", &config, &mut kernel_args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("PerHeadRmsNorm forward failed: {e:?}"))
    })?;

    Ok(())
}
240
#[cfg(feature = "cuda")]
/// NeoX-style rotary position embedding (as implemented by `RopeNeoxKernel`) for one token row.
///
/// Row `pos_offset` of `input`, viewed as `num_heads * head_dim` consecutive floats, is rotated
/// for position `pos` and written to the same row of `output`. `theta` is a constructor
/// argument of the kernel, not a launch argument, so it is compiled into the PTX and its bit
/// pattern is part of the cache key.
///
/// Launch shape: `grid = (num_heads, 1, 1)`, `block = (head_dim / 2, 1, 1)`.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// * Any error propagated from PTX compilation via the kernel cache.
pub fn rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos: u32,
    pos_offset: usize,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let rope_kernel = RopeNeoxKernel::new(num_heads, head_dim, theta);
    // theta lives in the PTX, so distinct values need distinct cache entries.
    let cache_key = format!("rope_neox_fwd_{num_heads}_{head_dim}_th{:08x}", theta.to_bits());
    let module = if let Some(compiled) = guard.get_cached(&cache_key) {
        compiled
    } else {
        let ptx = rope_kernel.emit_ptx_for_target(guard.sm_target());
        guard.get_or_compile(&cache_key, &ptx)?
    };

    let threads_per_head = head_dim / 2;
    let config = LaunchConfig {
        grid: (num_heads, 1, 1),
        block: (threads_per_head, 1, 1),
        shared_mem: 0,
    };

    // Both buffers share the row layout, so one byte offset addresses row `pos_offset` in each.
    let row_len = (num_heads * head_dim) as usize;
    let row_byte_offset = (pos_offset * row_len * std::mem::size_of::<f32>()) as u64;
    let input_ptr = input.as_ptr() + row_byte_offset;
    let output_ptr = output.as_ptr() + row_byte_offset;

    let mut kernel_args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &pos as *const _ as *mut _,
    ];

    // SAFETY: each `kernel_args` entry points at a live local/parameter for the duration of
    // the call, ordered (input, output, pos). The row offset is NOT bounds-checked; the caller
    // must ensure row `pos_offset` exists in both buffers.
    let launched =
        unsafe { stream.launch_kernel(module, "rope_neox", &config, &mut kernel_args) };
    launched
        .map_err(|e| CudaTensorError::KernelError(format!("RoPE NeoX forward failed: {e:?}")))?;

    Ok(())
}
310
#[cfg(feature = "cuda")]
/// NeoX-style RoPE forward for a whole sequence in a single launch (`BatchedRopeKernel`).
///
/// `input` and `output` presumably hold `seq_len` rows of `num_heads * head_dim` floats (not
/// checked here). `positions` supplies the position id for each row, presumably one `u32` per
/// token; confirm against `BatchedRopeKernel`.
///
/// Launch shape: `grid = (num_heads, seq_len, 1)` (heads along `x`, tokens along `y`),
/// `block = (head_dim / 2, 1, 1)`. All constructor arguments, including `seq_len` and the
/// exact bits of `theta`, are part of the cache key.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// * Any error propagated from PTX compilation via the kernel cache.
pub fn batched_rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let rope_kernel = BatchedRopeKernel::new(num_heads, head_dim, seq_len, theta);
    let cache_key = format!(
        "batched_rope_fwd_{num_heads}_{head_dim}_{seq_len}_th{:08x}",
        theta.to_bits()
    );
    let module = if let Some(compiled) = guard.get_cached(&cache_key) {
        compiled
    } else {
        let ptx = rope_kernel.emit_ptx_for_target(guard.sm_target());
        guard.get_or_compile(&cache_key, &ptx)?
    };

    let config = LaunchConfig {
        grid: (num_heads, seq_len, 1),
        block: (head_dim / 2, 1, 1),
        shared_mem: 0,
    };

    let (input_ptr, output_ptr, positions_ptr) =
        (input.as_ptr(), output.as_ptr(), positions.as_ptr());

    let mut kernel_args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: each `kernel_args` entry points at a live local for the duration of the call,
    // ordered (input, output, positions). Buffer extents are NOT validated here.
    let launched =
        unsafe { stream.launch_kernel(module, "batched_rope", &config, &mut kernel_args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("Batched RoPE NeoX forward failed: {e:?}"))
    })?;

    Ok(())
}
369
#[cfg(feature = "cuda")]
/// Backward pass of the batched NeoX-style RoPE (`BatchedRopeBackwardKernel`).
///
/// Reads `grad_input` and writes `grad_output`. Presumably `grad_input` is the gradient w.r.t.
/// the RoPE output and `grad_output` receives the gradient w.r.t. the RoPE input; confirm
/// against the kernel. The buffer layout and `positions` follow
/// [`batched_rope_neox_forward`].
///
/// Launch shape: `grid = (num_heads, seq_len, 1)`, `block = (head_dim / 2, 1, 1)`. All
/// constructor arguments, including `seq_len` and the exact bits of `theta`, are part of the
/// cache key.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// * Any error propagated from PTX compilation via the kernel cache.
pub fn batched_rope_neox_backward(
    grad_input: &GpuBuffer<f32>,
    grad_output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let backward_kernel = BatchedRopeBackwardKernel::new(num_heads, head_dim, seq_len, theta);
    let cache_key = format!(
        "batched_rope_bwd_{num_heads}_{head_dim}_{seq_len}_th{:08x}",
        theta.to_bits()
    );
    let module = if let Some(compiled) = guard.get_cached(&cache_key) {
        compiled
    } else {
        let ptx = backward_kernel.emit_ptx_for_target(guard.sm_target());
        guard.get_or_compile(&cache_key, &ptx)?
    };

    let config = LaunchConfig {
        grid: (num_heads, seq_len, 1),
        block: (head_dim / 2, 1, 1),
        shared_mem: 0,
    };

    let (input_ptr, output_ptr, positions_ptr) =
        (grad_input.as_ptr(), grad_output.as_ptr(), positions.as_ptr());

    let mut kernel_args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: each `kernel_args` entry points at a live local for the duration of the call,
    // ordered (grad_input, grad_output, positions). Buffer extents are NOT validated here.
    let launched = unsafe {
        stream.launch_kernel(module, "batched_rope_backward", &config, &mut kernel_args)
    };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("Batched RoPE NeoX backward failed: {e:?}"))
    })?;

    Ok(())
}
426
#[cfg(feature = "cuda")]
/// Fused residual-add + RMSNorm forward.
///
/// Launches `fused_residual_rmsnorm` with `(residual, input, output, gamma)` over `batch_size`
/// rows (`grid.y`), using one 32-thread block per row. If `residual_out` is a different device
/// allocation than `residual`, it also materialises `residual + input` into `residual_out`
/// via `crate::autograd::cuda_forward::residual_add_forward`, enqueued on the same stream
/// ahead of the fused kernel.
///
/// NOTE(review): when `residual_out` aliases `residual`, no separate add is issued. This
/// presumably relies on the fused kernel writing the sum back into `residual`; confirm
/// against `FusedResidualRmsNormKernel`.
///
/// # Arguments
/// * `residual`, `input` – presumably `[batch_size, hidden_size]` row-major (not checked here).
/// * `residual_out` – receives `residual + input` when it is a distinct allocation.
/// * `output` – receives the normalized result.
/// * `gamma` – per-feature scale, presumably of length `hidden_size`.
/// * `batch_size`, `hidden_size` – problem shape; either being `0` makes the call a no-op.
/// * `stream` – stream both launches are enqueued on.
///
/// # Errors
/// * [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never initialized.
/// * [`CudaTensorError::KernelError`] if the cache lock is poisoned, the launch fails, or
///   `batch_size * hidden_size` overflows `u32`.
/// * Any error from `residual_add_forward` or from PTX compilation.
pub fn fused_residual_rmsnorm_forward(
    residual: &GpuBuffer<f32>,
    input: &GpuBuffer<f32>,
    residual_out: &mut GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    // A zero grid dimension is rejected by CUDA, and there is nothing to add or normalize.
    if batch_size == 0 || hidden_size == 0 {
        return Ok(());
    }

    // Issue the standalone residual add BEFORE taking the kernel-cache lock.
    // `residual_add_forward` presumably acquires `FORWARD_KERNEL_CACHE` itself, and
    // `std::sync::Mutex` is not re-entrant: calling it while this function held the guard
    // could deadlock or panic. Both launches target `stream`, so the add is still ordered
    // ahead of the fused kernel.
    if residual_out.as_ptr() != residual.as_ptr() {
        // Checked, because a wrapped product would silently under-size the add.
        let num_elements = batch_size.checked_mul(hidden_size).ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "Fused residual+RMSNorm: element count {batch_size} * {hidden_size} overflows u32"
            ))
        })?;
        crate::autograd::cuda_forward::residual_add_forward(
            residual,
            input,
            residual_out,
            num_elements,
            stream,
        )?;
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // The kernel is constructed from `hidden_size` only; build it lazily on a cache miss.
    let key = format!("fused_residual_rmsnorm_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = FusedResidualRmsNormKernel::new(hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One 32-thread block per row, rows laid out along grid.y.
    let config = LaunchConfig { grid: (1, batch_size, 1), block: (32, 1, 1), shared_mem: 0 };

    let residual_ptr = residual.as_ptr();
    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();

    let mut args: [*mut std::ffi::c_void; 4] = [
        &residual_ptr as *const _ as *mut _,
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: each `args` entry points at a local that stays alive for the whole
    // `launch_kernel` call, ordered (residual, input, output, gamma). Buffer extents are NOT
    // validated here.
    unsafe {
        stream.launch_kernel(module, "fused_residual_rmsnorm", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Fused residual+RMSNorm forward failed: {e:?}"
                ))
            },
        )?;
    }

    Ok(())
}
512
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use crate::autograd::cuda_forward::cache::init_forward_kernel_cache;
    use crate::autograd::cuda_tensor::CudaDevice;
    use trueno_gpu::driver::GpuBuffer;

    /// Host-side RMSNorm of one row in `f32`: `(x / sqrt(mean(x^2) + eps)) * gamma`.
    fn cpu_rmsnorm_reference(input: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
        let len = input.len() as f32;
        let sum_of_squares = input.iter().map(|v| v * v).sum::<f32>();
        let rms = (sum_of_squares / len + eps).sqrt();
        input.iter().zip(gamma).map(|(&x, &g)| (x / rms) * g).collect()
    }

    /// Falsification test: the CUDA RMSNorm path must honour a non-default epsilon
    /// (Qwen's `1e-6`) and agree with the host reference to within `1e-4`.
    ///
    /// Returns early (passing) when no CUDA device is available or the forward kernel cache
    /// cannot be initialized. NOTE(review): if `init_forward_kernel_cache` errors when the cache
    /// is already initialized (e.g. by an earlier test in the same process), this test silently
    /// skips. Confirm that function's behavior.
    #[test]
    fn falsify_cuda_rmsnorm_eps_parity_qwen_1e_minus_6() {
        const TAG: &str = "[falsify-cuda-rmsnorm-eps-parity-001]";
        const EPS: f32 = 1e-6;

        let device = match CudaDevice::default_device() {
            Ok(device) => device,
            Err(e) => {
                eprintln!("{TAG} skipping (no CUDA host): {e}");
                return;
            }
        };
        let ctx = device.context().clone();
        let stream = device.stream();
        if let Err(e) = init_forward_kernel_cache(ctx.clone()) {
            eprintln!("{TAG} kernel cache init failed: {e}");
            return;
        }

        // Four rows of 896 elements, so the launch uses more than one block along grid.y.
        let hidden_size = 896usize;
        let batch_size = 4u32;
        let total = batch_size as usize * hidden_size;
        let input_data: Vec<f32> =
            (0..total).map(|i| (((i as f32) * 0.013).sin()) * 0.02).collect();
        let gamma_data: Vec<f32> =
            (0..hidden_size).map(|i| 1.0 + ((i as f32) * 0.005).cos() * 0.1).collect();

        // Host reference, computed row by row.
        let cpu_out: Vec<f32> = input_data
            .chunks_exact(hidden_size)
            .flat_map(|row| cpu_rmsnorm_reference(row, &gamma_data, EPS))
            .collect();

        let input_gpu = GpuBuffer::from_host(&ctx, &input_data).expect("input");
        let gamma_gpu = GpuBuffer::from_host(&ctx, &gamma_data).expect("gamma");
        let mut output_gpu = GpuBuffer::<f32>::new(&ctx, total).expect("output alloc");

        rms_norm_forward_with_eps(
            &input_gpu,
            &gamma_gpu,
            &mut output_gpu,
            batch_size,
            hidden_size as u32,
            EPS,
            stream,
        )
        .expect("kernel launch");
        stream.synchronize().expect("sync");

        let mut gpu_out = vec![0.0f32; total];
        output_gpu.copy_to_host(&mut gpu_out).expect("download");

        let max_diff = cpu_out
            .iter()
            .zip(&gpu_out)
            .map(|(c, g)| (c - g).abs())
            .fold(0.0f32, f32::max);

        eprintln!("{TAG} max_diff={max_diff} (Qwen eps=1e-6)");
        assert!(
            max_diff < 1e-4,
            "FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: max_diff={max_diff} >= 1e-4. \
             CUDA RMSNorm kernel disagrees with CPU reference at Qwen eps=1e-6. \
             Pre-fix root cause: BatchedVectorizedRmsNormKernel::new hardcodes \
             epsilon=1e-5 (Llama default) so calling `rms_norm_forward` for \
             Qwen2 silently uses the wrong eps. Fix: \
             `rms_norm_forward_with_eps(.., eps, ..)` threads `config.rms_norm_eps` \
             into the kernel and the cache key includes eps bits to avoid stale \
             PTX shadowing. See contract apr-pretrain-cuda-rmsnorm-eps-parity-v1.yaml."
        );
    }
}