Skip to main content

entrenar/autograd/cuda_forward/
normalization.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10    BatchedRopeBackwardKernel, BatchedRopeKernel, BatchedVectorizedRmsNormKernel,
11    FusedResidualRmsNormKernel, Kernel, LayerNormKernel, PerHeadRmsNormKernel, RopeNeoxKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
/// Layer normalization forward pass on GPU.
///
/// Computes `output = gamma * (input - mean) / sqrt(var + eps) + beta` for each of the
/// `batch_size` rows of `hidden_size` elements.
///
/// Launch geometry: one block per row (`grid.x = batch_size`) with `min(256, hidden_size)`
/// threads. The compiled module is cached under a key built from `hidden_size` alone,
/// which is the only value the kernel is constructed from.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error returned by the kernel cache while compiling the PTX.
#[cfg(feature = "cuda")]
pub fn layer_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    beta: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let shared = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = shared.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let layer_norm = LayerNormKernel::new(hidden_size);
    let entry_point = layer_norm.name();

    let cache_key = format!("layer_norm_forward_{hidden_size}");
    let module = match cache.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = layer_norm.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx)?
        }
    };

    let threads_per_block = hidden_size.min(256);
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (threads_per_block, 1, 1),
        shared_mem: 0,
    };

    let (input_ptr, gamma_ptr, beta_ptr, output_ptr) =
        (input.as_ptr(), gamma.as_ptr(), beta.as_ptr(), output.as_ptr());

    // Kernel parameter array: each entry is the address of one argument value, in the
    // order input, gamma, beta, output, batch_size, hidden_size.
    let mut args: [*mut std::ffi::c_void; 6] = [
        std::ptr::addr_of!(input_ptr) as *mut _,
        std::ptr::addr_of!(gamma_ptr) as *mut _,
        std::ptr::addr_of!(beta_ptr) as *mut _,
        std::ptr::addr_of!(output_ptr) as *mut _,
        std::ptr::addr_of!(batch_size) as *mut _,
        std::ptr::addr_of!(hidden_size) as *mut _,
    ];

    // SAFETY: FFI kernel launch. Every `args` entry points at a local or parameter that
    // lives until this function returns, so the pointers stay valid for the call. The
    // argument order is assumed to match the LayerNorm PTX signature. Buffer lengths are
    // NOT checked here: the caller must supply `batch_size * hidden_size` elements in
    // `input`/`output` and `hidden_size` elements in `gamma`/`beta`.
    let launched = unsafe { stream.launch_kernel(module, entry_point, &config, &mut args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("LayerNorm forward launch failed: {e:?}"))
    })?;

    Ok(())
}
79
/// RMS normalization forward pass on GPU (LLaMA-style), using the Llama default epsilon.
///
/// Computes `output = gamma * input / sqrt(mean(input^2) + 1e-5)` over `batch_size` rows
/// of `hidden_size` elements by delegating to [`rms_norm_forward_with_eps`].
///
/// The underlying `BatchedVectorizedRmsNormKernel` handles every row in a single launch
/// (`grid.y = batch_size`, 256 threads per block).
///
/// ALB-076: the previous implementation launched one 32-thread kernel per row (2048
/// launches for batch=4, seq=512), which nsys profiling showed to be 97.1% of all GPU
/// time. The single batched launch removes 100K+ kernel launches per step.
///
/// Kept for backwards compatibility with legacy callers. Production callers in
/// transformer/cuda_block.rs should call [`rms_norm_forward_with_eps`] with
/// `config.rms_norm_eps` so that Qwen2 / Qwen2.5 (`rms_norm_eps = 1e-6`) get the
/// right epsilon.
#[cfg(feature = "cuda")]
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Llama's `rms_norm_eps`: the historical default for this entry point.
    const LLAMA_DEFAULT_EPS: f32 = 1e-5;

    rms_norm_forward_with_eps(
        input,
        gamma,
        output,
        batch_size,
        hidden_size,
        LLAMA_DEFAULT_EPS,
        stream,
    )
}
105
/// Batched RMSNorm forward that honours a caller-supplied epsilon
/// (FALSIFY-CUDA-RMSNORM-EPS-PARITY-001).
///
/// Computes `output = gamma * input / sqrt(mean(input^2) + eps)` for each of `batch_size`
/// rows of `hidden_size` elements in a single kernel launch.
///
/// Background: `rms_norm_forward` used to build `BatchedVectorizedRmsNormKernel::new`
/// with its built-in eps of 1e-5 (the Llama default) regardless of model. Qwen2 /
/// Qwen2.5 specify `rms_norm_eps = 1e-6` in config.json. The 9e-6 absolute difference
/// compounds over 24 layers × 2 RMSNorms per block (48 calls) and is one of the residual
/// contributors to the CUDA-vs-CPU forward divergence surfaced by
/// `apr-pretrain-cuda-forward-parity-v1.yaml`.
///
/// The epsilon's bit pattern is part of the module cache key, so two different
/// epsilons compile to two different PTX modules instead of one silently shadowing the
/// other.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error returned by the kernel cache while compiling the PTX.
#[cfg(feature = "cuda")]
pub fn rms_norm_forward_with_eps(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    eps: f32,
    stream: &CudaStream,
) -> Result<()> {
    let shared = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = shared.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let rms_norm = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size).with_epsilon(eps);

    // The eps value is emitted into the PTX as a `mov.f32` immediate, so every distinct
    // bit pattern needs its own cached module.
    // NOTE(review): `batch_size` is passed to the kernel constructor but is not part of
    // the key. That is only sound if it influences launch geometry rather than the
    // emitted PTX — confirm against BatchedVectorizedRmsNormKernel.
    let eps_pattern = eps.to_bits();
    let cache_key = format!("batched_rmsnorm_fwd_{hidden_size}_eps{eps_pattern:08x}");
    let module = match cache.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = rms_norm.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx)?
        }
    };

    // Rows are spread along grid.y, one block each. Each block's 256 threads (8 warps)
    // cooperate on the reduction, with shared memory sized for one f32 partial per warp.
    let config = LaunchConfig { grid: (1, batch_size, 1), block: (256, 1, 1), shared_mem: 8 * 4 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();

    // Parameter order is input, output, gamma (gamma comes last for this kernel).
    let mut args: [*mut std::ffi::c_void; 3] = [
        std::ptr::addr_of!(input_ptr) as *mut _,
        std::ptr::addr_of!(output_ptr) as *mut _,
        std::ptr::addr_of!(gamma_ptr) as *mut _,
    ];

    // SAFETY: FFI kernel launch. The `args` entries point at locals that outlive the call,
    // and the order (u64 input_ptr, u64 output_ptr, u64 gamma_ptr) matches the documented
    // PTX signature. The caller must provide `batch_size * hidden_size` elements in
    // `input` and `output` and `hidden_size` elements in `gamma`; lengths are not checked.
    let launched =
        unsafe { stream.launch_kernel(module, "batched_rmsnorm_vectorized", &config, &mut args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}"))
    })?;

    Ok(())
}
177
/// Per-head RMSNorm forward pass on GPU (ENT-270: QK-norm for Qwen3).
///
/// Normalizes each attention head of a single sequence position independently:
///   output[h] = input[h] / sqrt(mean(input[h]^2) + eps) * gamma
///
/// Layout: both buffers are treated as consecutive rows of `num_heads * head_dim` f32s
/// (heads interleaved within a row). This call reads row `pos_offset` of `input` and
/// writes row `pos_offset` of `output`. `gamma` holds `head_dim` elements shared by all
/// heads.
///
/// For seq_len > 1, call once per position (loop in caller).
///
/// # Errors
/// - [`CudaTensorError::KernelError`] if the row byte offset overflows `usize`, the
///   kernel-cache lock is poisoned, or the launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - Any error returned by the kernel cache while compiling the PTX.
#[cfg(feature = "cuda")]
pub fn per_head_rmsnorm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos_offset: usize,
    stream: &CudaStream,
) -> Result<()> {
    // Byte offset of row `pos_offset`, identical for input and output. Widen to usize
    // BEFORE multiplying: `(num_heads * head_dim) as usize` multiplied in u32, which
    // panics on overflow in debug builds and wraps in release, producing a bogus
    // device address.
    let byte_offset = (num_heads as usize)
        .checked_mul(head_dim as usize)
        .and_then(|row_elems| row_elems.checked_mul(pos_offset))
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "PerHeadRmsNorm forward: byte offset overflows usize \
                 (pos_offset={pos_offset}, num_heads={num_heads}, head_dim={head_dim})"
            ))
        })?;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // NOTE(review): no epsilon is passed to PerHeadRmsNormKernel and the cache key has no
    // eps component. If Qwen3's `rms_norm_eps` differs from the kernel's built-in value,
    // this diverges the same way FALSIFY-CUDA-RMSNORM-EPS-PARITY-001 did for full-row
    // RMSNorm — confirm.
    let kernel = PerHeadRmsNormKernel::new(head_dim, num_heads);

    let key = format!("per_head_rmsnorm_fwd_{head_dim}_{num_heads}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One block per head, one warp (32 threads) per block.
    let config = LaunchConfig { grid: (num_heads, 1, 1), block: (32, 1, 1), shared_mem: 0 };

    // CUdeviceptr is u64 — offset with integer arithmetic, not pointer `.add()`.
    let input_ptr = input.as_ptr() + byte_offset as u64;
    let output_ptr = output.as_ptr() + byte_offset as u64;
    let gamma_ptr = gamma.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: FFI kernel launch. The `args` entries point at locals that outlive the call.
    // The order (input, output, gamma) is assumed to match the `per_head_rmsnorm` PTX
    // signature. Buffer lengths are NOT checked: the caller must ensure `input` and
    // `output` hold at least `(pos_offset + 1) * num_heads * head_dim` f32s and `gamma`
    // holds `head_dim`.
    unsafe {
        stream.launch_kernel(module, "per_head_rmsnorm", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("PerHeadRmsNorm forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
240
/// RoPE (NeoX / half-rotation) forward pass on GPU (ENT-270).
///
/// Applies rotary position embeddings with the half-rotation layout. Within each head it
/// rotates the element pairs `(i, i + head_dim / 2)`, the layout required for Qwen/LLaMA
/// models.
///
/// Layout: buffers are treated as consecutive rows of `num_heads * head_dim` f32s (heads
/// interleaved within a row). This call reads row `pos_offset` of `input` and writes row
/// `pos_offset` of `output`. `pos` is passed to the kernel as the sequence position, and
/// `theta` is the RoPE base (e.g. 1e4 for Llama, 1e6 for Qwen).
///
/// For seq_len > 1, call once per position with the position index.
///
/// # Errors
/// - [`CudaTensorError::KernelError`] if `head_dim` is zero or odd, the row byte offset
///   overflows `usize`, the kernel-cache lock is poisoned, or the launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - Any error returned by the kernel cache while compiling the PTX.
#[cfg(feature = "cuda")]
pub fn rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos: u32,
    pos_offset: usize,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    // Half-rotation pairs element i with i + head_dim/2, and the launch uses one thread per
    // pair (block.x = head_dim / 2). An odd head_dim cannot be split into such pairs, and
    // head_dim == 0 would request an empty block.
    if head_dim == 0 || head_dim % 2 == 1 {
        return Err(CudaTensorError::KernelError(format!(
            "RoPE NeoX forward requires a non-zero even head_dim, got {head_dim}"
        )));
    }

    // Byte offset of row `pos_offset`. Widen before multiplying so that
    // `num_heads * head_dim` cannot overflow in u32.
    let byte_offset = (num_heads as usize)
        .checked_mul(head_dim as usize)
        .and_then(|row_elems| row_elems.checked_mul(pos_offset))
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "RoPE NeoX forward: byte offset overflows usize \
                 (pos_offset={pos_offset}, num_heads={num_heads}, head_dim={head_dim})"
            ))
        })?;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = RopeNeoxKernel::new(num_heads, head_dim, theta);

    // FALSIFY-CUDA-ROPE-THETA-CACHE-KEY-001: theta is baked into the PTX at build_ptx
    // time (RopeNeoxKernel::build_ptx captures self.theta into the closure as
    // `mov.f32 imm`). Two calls with different theta values produce different PTX, so
    // the cache key MUST include theta_bits. Otherwise the first theta to populate the
    // cache wins and later calls silently use the wrong theta (e.g. Llama 1e4 caches
    // first → Qwen 1e6 calls reuse the 1e4 PTX).
    let theta_bits = theta.to_bits();
    let key = format!("rope_neox_fwd_{num_heads}_{head_dim}_th{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One block per head, one thread per rotated pair (half_dim threads per block).
    let config =
        LaunchConfig { grid: (num_heads, 1, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    // CUdeviceptr is u64 — offset with integer arithmetic, not pointer `.add()`.
    let input_ptr = input.as_ptr() + byte_offset as u64;
    let output_ptr = output.as_ptr() + byte_offset as u64;

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &pos as *const _ as *mut _,
    ];

    // SAFETY: FFI kernel launch. The `args` entries point at locals/parameters that outlive
    // the call. The order (input, output, pos) is assumed to match the `rope_neox` PTX
    // signature. Buffer lengths are NOT checked: the caller must ensure both buffers hold
    // at least `(pos_offset + 1) * num_heads * head_dim` f32s.
    unsafe {
        stream.launch_kernel(module, "rope_neox", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("RoPE NeoX forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
310
/// Batched RoPE NeoX forward — processes all `seq_len` positions in a single kernel launch.
///
/// Replaces the per-position `rope_neox_forward` loop to avoid ~2048 kernel launches per
/// block. The launch grid is `(num_heads, seq_len, 1)`, and each row's position index is
/// read from the `positions` GPU buffer.
///
/// Layout: `input`/`output` are `[seq_len, num_heads * head_dim]` (heads interleaved
/// within a row). `theta` is the RoPE base.
///
/// # Errors
/// - [`CudaTensorError::KernelError`] if `head_dim` is zero or odd, the kernel-cache lock
///   is poisoned, or the launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - Any error returned by the kernel cache while compiling the PTX.
#[cfg(feature = "cuda")]
pub fn batched_rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    // Half-rotation needs element pairs (i, i + head_dim/2) and one thread per pair
    // (block.x = head_dim / 2): reject odd or zero head_dim instead of launching a
    // mis-sized kernel.
    if head_dim == 0 || head_dim % 2 == 1 {
        return Err(CudaTensorError::KernelError(format!(
            "Batched RoPE NeoX forward requires a non-zero even head_dim, got {head_dim}"
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedRopeKernel::new(num_heads, head_dim, seq_len, theta);

    // FALSIFY-CUDA-ROPE-THETA-CACHE-KEY-001: the cache key MUST include theta_bits (and
    // seq_len, which is also baked in via grid sizing). See the `rope_neox_forward`
    // rationale.
    let theta_bits = theta.to_bits();
    let key = format!("batched_rope_fwd_{num_heads}_{head_dim}_{seq_len}_th{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // grid.x = head, grid.y = sequence row; one thread per rotated pair.
    let config =
        LaunchConfig { grid: (num_heads, seq_len, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let positions_ptr = positions.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: FFI kernel launch. The `args` entries point at locals that outlive the call.
    // The order (input, output, positions) is assumed to match the `batched_rope` PTX
    // signature. Buffer lengths are NOT checked: the caller must provide
    // `seq_len * num_heads * head_dim` f32s in `input`/`output` and `seq_len` entries
    // in `positions`.
    unsafe {
        stream.launch_kernel(module, "batched_rope", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched RoPE NeoX forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
369
/// Batched RoPE NeoX backward — inverse rotation for gradient flow.
///
/// The forward pass multiplies each `(i, i + head_dim/2)` pair by an orthogonal rotation
/// R(θ). The gradient with respect to the pre-rotation activations is therefore obtained by
/// applying R(θ)^T = R(-θ) to the incoming gradient. Doing so lets the Q/K projection
/// backward receive gradients in the un-rotated frame; without it, dW_q and dW_k would be
/// computed in the rotated coordinate frame, producing incorrect weight updates.
///
/// - `grad_input`: incoming gradient (w.r.t. the RoPE output), same layout as the forward
///   pass (`[seq_len, num_heads * head_dim]`).
/// - `grad_output`: receives the gradient w.r.t. the RoPE input (same layout).
/// - `positions`: per-row position indices read by the kernel (`seq_len` entries).
/// - `theta`: RoPE base; must match the value used in the forward pass.
///
/// # Errors
/// - [`CudaTensorError::KernelError`] if `head_dim` is zero or odd, the kernel-cache lock
///   is poisoned, or the launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - Any error returned by the kernel cache while compiling the PTX.
#[cfg(feature = "cuda")]
pub fn batched_rope_neox_backward(
    grad_input: &GpuBuffer<f32>,
    grad_output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    // Same pairing constraint as the forward pass: one thread per (i, i + head_dim/2)
    // pair, so head_dim must be even and non-zero.
    if head_dim == 0 || head_dim % 2 == 1 {
        return Err(CudaTensorError::KernelError(format!(
            "Batched RoPE NeoX backward requires a non-zero even head_dim, got {head_dim}"
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedRopeBackwardKernel::new(num_heads, head_dim, seq_len, theta);

    // FALSIFY-CUDA-ROPE-THETA-CACHE-KEY-001: the cache key MUST include theta_bits. See
    // the `rope_neox_forward` rationale.
    let theta_bits = theta.to_bits();
    let key = format!("batched_rope_bwd_{num_heads}_{head_dim}_{seq_len}_th{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // grid.x = head, grid.y = sequence row; one thread per rotated pair.
    let config =
        LaunchConfig { grid: (num_heads, seq_len, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = grad_input.as_ptr();
    let output_ptr = grad_output.as_ptr();
    let positions_ptr = positions.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: FFI kernel launch. The `args` entries point at locals that outlive the call.
    // The order (grad_input, grad_output, positions) is assumed to match the
    // `batched_rope_backward` PTX signature. Buffer lengths are NOT checked: the caller
    // must provide `seq_len * num_heads * head_dim` f32s in both gradient buffers and
    // `seq_len` entries in `positions`.
    unsafe {
        stream.launch_kernel(module, "batched_rope_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched RoPE NeoX backward failed: {e:?}"))
        })?;
    }

    Ok(())
}
426
/// Fused residual add + RMSNorm forward: `output = RMSNorm(residual + input) * gamma`.
///
/// Contract: entrenar#321 — eliminates the NaN cascade in layers 24-27 by fusing the
/// residual add with RMSNorm into a single kernel pass. The RMSNorm normalization keeps
/// activations from exploding through the residual chain.
///
/// The un-normalized sum `residual + input` is also stored in `residual_out` for the
/// backward pass. This uses a separate element-wise add, because the fused kernel is not
/// given a pointer to `residual_out`.
///
/// # Parameters
/// - `residual`: Previous layer output (residual connection input)
/// - `input`: Current block output to add
/// - `residual_out`: Receives `residual + input` (for backward). May alias `residual`
///   (same device pointer). In that case the in-place update is queued after the fused
///   kernel, so the kernel still reads the original residual.
/// - `output`: `RMSNorm(residual + input) * gamma`
/// - `gamma`: Scale weights (`hidden_size` elements)
/// - `batch_size`: Number of rows (seq_len)
/// - `hidden_size`: Number of columns per row
///
/// # Errors
/// - [`CudaTensorError::KernelError`] if `batch_size * hidden_size` overflows `u32`, the
///   kernel-cache lock is poisoned, or the fused launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - Any error from `residual_add_forward` or from compiling the fused kernel's PTX.
#[cfg(feature = "cuda")]
pub fn fused_residual_rmsnorm_forward(
    residual: &GpuBuffer<f32>,
    input: &GpuBuffer<f32>,
    residual_out: &mut GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let num_elements = batch_size.checked_mul(hidden_size).ok_or_else(|| {
        CudaTensorError::KernelError(format!(
            "Fused residual+RMSNorm forward: batch_size * hidden_size overflows u32 \
             ({batch_size} * {hidden_size})"
        ))
    })?;

    // Previously the add was skipped entirely when `residual_out` aliased `residual`, so
    // the "saved" sum was really the old residual. An in-place add cannot run before the
    // fused kernel either, because the kernel would then see residual + 2 * input. In the
    // aliased case it is therefore queued afterwards. Both launches target `stream`, which
    // executes work in issue order (assuming `residual_add_forward` launches on the stream
    // it is given).
    let sum_in_place = residual_out.as_ptr() == residual.as_ptr();

    if !sum_in_place {
        // Distinct destination: `residual` is left untouched, so the add can go first.
        // This now runs without the kernel-cache lock held; previously it ran while this
        // function still held the FORWARD_KERNEL_CACHE guard.
        crate::autograd::cuda_forward::residual_add_forward(
            residual,
            input,
            residual_out,
            num_elements,
            stream,
        )?;
    }

    // The kernel-cache lock is acquired and released inside this call.
    launch_fused_residual_rmsnorm(residual, input, output, gamma, batch_size, hidden_size, stream)?;

    if sum_in_place {
        // NOTE(review): relies on the element-wise add kernel tolerating out == a
        // (in-place), which holds for a plain per-element add — confirm.
        crate::autograd::cuda_forward::residual_add_forward(
            residual,
            input,
            residual_out,
            num_elements,
            stream,
        )?;
    }

    Ok(())
}

/// Compiles (or fetches from the cache) the fused residual+RMSNorm kernel and queues
/// `output = RMSNorm(residual + input) * gamma` on `stream`, one 32-thread block per row.
///
/// The kernel-cache lock is scoped to this function. It is released before any other
/// cached-kernel launch (such as `residual_add_forward`) tries to take it.
#[cfg(feature = "cuda")]
fn launch_fused_residual_rmsnorm(
    residual: &GpuBuffer<f32>,
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("fused_residual_rmsnorm_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = FusedResidualRmsNormKernel::new(hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Grid: (1, batch_size, 1) — one block per row.
    // Block: (32, 1, 1) — single warp for the reduction.
    let config = LaunchConfig { grid: (1, batch_size, 1), block: (32, 1, 1), shared_mem: 0 };

    let residual_ptr = residual.as_ptr();
    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();

    let mut args: [*mut std::ffi::c_void; 4] = [
        &residual_ptr as *const _ as *mut _,
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: FFI kernel launch. The `args` entries point at locals that outlive the call.
    // The order (residual, input, output, gamma) is assumed to match the
    // `fused_residual_rmsnorm` PTX signature. Buffer lengths are NOT checked: the caller
    // must provide `batch_size * hidden_size` f32s in residual/input/output and
    // `hidden_size` in gamma.
    unsafe {
        stream.launch_kernel(module, "fused_residual_rmsnorm", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Fused residual+RMSNorm forward failed: {e:?}"
                ))
            },
        )?;
    }

    Ok(())
}
512
#[cfg(all(test, feature = "cuda"))]
mod tests {
    use super::*;
    use crate::autograd::cuda_forward::cache::init_forward_kernel_cache;
    use crate::autograd::cuda_tensor::CudaDevice;
    use trueno_gpu::driver::GpuBuffer;

    /// Reference CPU RMSNorm for one row, written in the arithmetic order the kernel is
    /// intended to follow: rms = sqrt(mean(x^2) + eps); y = (x / rms) * gamma.
    fn cpu_rmsnorm_reference(input: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
        let n = input.len() as f32;
        let mean_sq: f32 = input.iter().map(|v| v * v).sum::<f32>() / n;
        let rms = (mean_sq + eps).sqrt();
        input.iter().zip(gamma.iter()).map(|(&x, &g)| (x / rms) * g).collect()
    }

    /// FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: with Qwen's eps = 1e-6,
    /// `rms_norm_forward_with_eps` MUST match the CPU reference (computed with the same
    /// eps) to within 1e-4 absolute, the bound the assertion below enforces.
    ///
    /// Only the eps-aware entry point is exercised here. The legacy `rms_norm_forward`
    /// (eps fixed at 1e-5) could not meet this bound on these inputs. With mean(x^2)
    /// around 2e-4, moving eps from 1e-6 to 1e-5 shifts the rms denominator by roughly 2%,
    /// i.e. output differences on the order of 1e-2. `rms_norm_forward_with_eps(eps = 1e-6)`
    /// compiles the kernel with the reference's eps, so the remaining differences are
    /// bounded by f32 round-off.
    ///
    /// The test returns early (passing) when no CUDA device is available or the kernel
    /// cache cannot be initialized.
    #[test]
    fn falsify_cuda_rmsnorm_eps_parity_qwen_1e_minus_6() {
        let device = match CudaDevice::default_device() {
            Ok(d) => d,
            Err(e) => {
                eprintln!("[falsify-cuda-rmsnorm-eps-parity-001] skipping (no CUDA host): {e}");
                return;
            }
        };
        let ctx = device.context().clone();
        let stream = device.stream();
        // NOTE(review): if `init_forward_kernel_cache` returns Err when the cache is already
        // initialized (e.g. by another test in this process), this test passes without
        // checking anything — confirm its behavior on re-initialization.
        if let Err(e) = init_forward_kernel_cache(ctx.clone()) {
            eprintln!("[falsify-cuda-rmsnorm-eps-parity-001] kernel cache init failed: {e}");
            return;
        }

        // Qwen 0.5B hidden size. Values are deliberately small (|x| <= 0.02) so that
        // mean(x^2) is small enough for the 1e-5 vs 1e-6 eps difference to move the rms
        // denominator measurably. Real Qwen activations post-embedding have std ~0.02, so
        // this is realistic.
        let hidden_size = 896usize;
        let batch_size = 4u32;
        let total = batch_size as usize * hidden_size;
        let input_data: Vec<f32> =
            (0..total).map(|i| (((i as f32) * 0.013).sin()) * 0.02).collect();
        let gamma_data: Vec<f32> =
            (0..hidden_size).map(|i| 1.0 + ((i as f32) * 0.005).cos() * 0.1).collect();

        // CPU reference, computed independently for each of the `batch_size` rows.
        let cpu_out: Vec<f32> = input_data
            .chunks_exact(hidden_size)
            .flat_map(|row| cpu_rmsnorm_reference(row, &gamma_data, 1e-6))
            .collect();

        let input_gpu = GpuBuffer::from_host(&ctx, &input_data).expect("input");
        let gamma_gpu = GpuBuffer::from_host(&ctx, &gamma_data).expect("gamma");
        let mut output_gpu = GpuBuffer::<f32>::new(&ctx, total).expect("output alloc");

        rms_norm_forward_with_eps(
            &input_gpu,
            &gamma_gpu,
            &mut output_gpu,
            batch_size,
            hidden_size as u32,
            1e-6,
            stream,
        )
        .expect("kernel launch");
        stream.synchronize().expect("sync");

        let mut gpu_out = vec![0.0f32; total];
        output_gpu.copy_to_host(&mut gpu_out).expect("download");

        let max_diff =
            cpu_out.iter().zip(gpu_out.iter()).map(|(c, g)| (c - g).abs()).fold(0.0f32, f32::max);

        eprintln!("[falsify-cuda-rmsnorm-eps-parity-001] max_diff={max_diff} (Qwen eps=1e-6)");
        assert!(
            max_diff < 1e-4,
            "FALSIFY-CUDA-RMSNORM-EPS-PARITY-001: max_diff={max_diff} >= 1e-4. \
             CUDA RMSNorm kernel disagrees with CPU reference at Qwen eps=1e-6. \
             Pre-fix root cause: BatchedVectorizedRmsNormKernel::new hardcodes \
             epsilon=1e-5 (Llama default) so calling `rms_norm_forward` for \
             Qwen2 silently uses the wrong eps. Fix: \
             `rms_norm_forward_with_eps(.., eps, ..)` threads `config.rms_norm_eps` \
             into the kernel and the cache key includes eps bits to avoid stale \
             PTX shadowing. See contract apr-pretrain-cuda-rmsnorm-eps-parity-v1.yaml."
        );
    }
}