Skip to main content

entrenar/autograd/cuda_forward/
normalization.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10    BatchedRopeBackwardKernel, BatchedRopeKernel, BatchedVectorizedRmsNormKernel,
11    FusedResidualRmsNormKernel, Kernel, LayerNormKernel, PerHeadRmsNormKernel, RopeNeoxKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
/// Layer normalization forward pass on GPU.
///
/// Computes `output = gamma * (input - mean) / sqrt(var + eps) + beta` row by row over
/// `batch_size` rows of `hidden_size` elements.
///
/// Launch shape: one block per row (`grid.x = batch_size`) with `min(256, hidden_size)`
/// threads per block. The compiled module is cached in `FORWARD_KERNEL_CACHE` under the key
/// `layer_norm_forward_{hidden_size}`.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails;
///   errors from PTX compilation in the cache are propagated as-is.
#[cfg(feature = "cuda")]
pub fn layer_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    beta: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let layer_norm = LayerNormKernel::new(hidden_size);
    let cache_key = format!("layer_norm_forward_{hidden_size}");
    let module = match cache.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            // Cache miss: emit PTX for this device's SM target and compile it once.
            let ptx = layer_norm.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx)?
        }
    };

    // One block per row; at most 256 threads cooperate on a single row.
    let threads_per_row = hidden_size.min(256);
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (threads_per_row, 1, 1),
        shared_mem: 0,
    };

    let (input_ptr, gamma_ptr, beta_ptr, output_ptr) =
        (input.as_ptr(), gamma.as_ptr(), beta.as_ptr(), output.as_ptr());

    // Kernel parameter list: each slot points at the host-side value of one PTX parameter,
    // in signature order (input, gamma, beta, output, batch_size, hidden_size).
    let mut args: [*mut std::ffi::c_void; 6] = [
        &input_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
        &beta_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Every entry of `args` points at a local that
    // outlives the call, in the order the LayerNorm PTX signature expects. The caller is
    // responsible for the buffers being valid device allocations of matching sizes
    // (presumably batch_size * hidden_size for input/output, hidden_size for gamma/beta);
    // sizes are not checked here.
    unsafe {
        stream
            .launch_kernel(module, layer_norm.name(), &config, &mut args)
            .map_err(|e| {
                CudaTensorError::KernelError(format!("LayerNorm forward launch failed: {e:?}"))
            })?;
    }

    Ok(())
}
79
/// RMS normalization forward pass on GPU (LLaMA-style).
///
/// Computes `output = gamma * input / sqrt(mean(input^2) + eps)` independently for each of
/// the `batch_size` rows of `hidden_size` elements.
///
/// All rows are processed by a single launch of `BatchedVectorizedRmsNormKernel`:
/// grid `(1, batch_size, 1)` (one block per row, rows in parallel via grid.y) and
/// 256 threads per block.
///
/// ALB-076: this replaced one 32-thread launch per row (2048 launches for batch=4,
/// seq=512), which nsys profiling showed at 97.1% of all GPU time. The single batched
/// launch eliminates 100K+ kernel launches per step.
///
/// NOTE(review): the cache key covers only `hidden_size`, although the kernel is also
/// constructed with `batch_size`. This is only correct if the emitted PTX does not depend
/// on `batch_size` (the row index comes from grid.y) — confirm against the kernel.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails;
///   errors from PTX compilation in the cache are propagated as-is.
#[cfg(feature = "cuda")]
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    const THREADS_PER_BLOCK: u32 = 256;

    let cache_cell = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let rmsnorm = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size);
    let cache_key = format!("batched_rmsnorm_fwd_{hidden_size}");
    let module = match cache.get_cached(&cache_key) {
        Some(compiled) => compiled,
        None => {
            let ptx = rmsnorm.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&cache_key, &ptx)?
        }
    };

    // Every row gets its own block (grid.y = batch_size); 256 threads = 8 warps per block,
    // each warp contributing one f32 partial sum to shared memory for the reduction.
    let config = LaunchConfig {
        grid: (1, batch_size, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 8 * 4,
    };

    let (input_ptr, output_ptr, gamma_ptr) = (input.as_ptr(), output.as_ptr(), gamma.as_ptr());

    // PTX parameter order: (u64 input_ptr, u64 output_ptr, u64 gamma_ptr).
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. `args` points at locals that outlive the call and
    // matches the PTX signature (u64 input_ptr, u64 output_ptr, u64 gamma_ptr). The caller
    // must supply input/output with batch_size * hidden_size elements and gamma with
    // hidden_size elements; sizes are not checked here.
    unsafe {
        stream
            .launch_kernel(module, "batched_rmsnorm_vectorized", &config, &mut args)
            .map_err(|e| {
                CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}"))
            })?;
    }

    Ok(())
}
144
/// Per-head RMSNorm forward pass on GPU (ENT-270: QK-norm for Qwen3).
///
/// Applies RMSNorm independently to each attention head:
///   output[h] = input[h] / sqrt(mean(input[h]^2) + eps) * gamma
///
/// Input layout: `[num_heads * head_dim]` per sequence position (heads interleaved).
/// `pos_offset` selects which position's slice of `input` / `output` is processed: the
/// kernel sees both buffers starting at element `pos_offset * num_heads * head_dim`.
/// Gamma: `[head_dim]` (shared across all heads).
///
/// For seq_len > 1, call once per position (loop in caller).
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the byte offset for `pos_offset` overflows
///   `usize`, if the kernel-cache lock is poisoned, or if the launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
#[cfg(feature = "cuda")]
pub fn per_head_rmsnorm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos_offset: usize,
    stream: &CudaStream,
) -> Result<()> {
    // Byte offset of this position's slice. Widen to usize *before* multiplying so that
    // `num_heads * head_dim` cannot overflow in u32, and report overflow as an error rather
    // than panicking (debug) or silently wrapping (release).
    let byte_offset = (num_heads as usize)
        .checked_mul(head_dim as usize)
        .and_then(|stride| stride.checked_mul(pos_offset))
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "PerHeadRmsNorm offset overflow: pos_offset={pos_offset}, \
                 num_heads={num_heads}, head_dim={head_dim}"
            ))
        })? as u64;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = PerHeadRmsNormKernel::new(head_dim, num_heads);

    let key = format!("per_head_rmsnorm_fwd_{head_dim}_{num_heads}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One block per head, one warp (32 threads) per block
    let config = LaunchConfig { grid: (num_heads, 1, 1), block: (32, 1, 1), shared_mem: 0 };

    // CUdeviceptr is u64 — use integer arithmetic, not pointer .add().
    // Input and output share the same per-position layout, hence the same offset.
    let input_ptr = input.as_ptr() + byte_offset;
    let output_ptr = output.as_ptr() + byte_offset;
    let gamma_ptr = gamma.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. `args` points at locals that outlive the call,
    // assumed to match the per_head_rmsnorm PTX signature (input, output, gamma pointers).
    // The caller must ensure input/output hold at least
    // (pos_offset + 1) * num_heads * head_dim f32 elements and gamma holds head_dim
    // elements; buffer sizes are not checked here.
    unsafe {
        stream.launch_kernel(module, "per_head_rmsnorm", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("PerHeadRmsNorm forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
207
/// RoPE (NeoX/half-rotation) forward pass on GPU (ENT-270).
///
/// Applies rotary position embeddings with half-rotation layout:
///   pairs at (i, i + half_dim) — required for Qwen/LLaMA models.
///
/// Input layout: `[num_heads * head_dim]` per sequence position (heads interleaved).
/// `pos_offset` selects which position's slice of `input` / `output` is processed (both
/// start at element `pos_offset * num_heads * head_dim`), and `pos` is the rotation
/// position handed to the kernel.
///
/// For seq_len > 1, call once per position with the position index.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the byte offset for `pos_offset` overflows
///   `usize`, if the kernel-cache lock is poisoned, or if the launch fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
#[cfg(feature = "cuda")]
pub fn rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos: u32,
    pos_offset: usize,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    // Byte offset of this position's slice. Widen to usize *before* multiplying so that
    // `num_heads * head_dim` cannot overflow in u32, and report overflow as an error.
    let byte_offset = (num_heads as usize)
        .checked_mul(head_dim as usize)
        .and_then(|stride| stride.checked_mul(pos_offset))
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "RoPE NeoX offset overflow: pos_offset={pos_offset}, \
                 num_heads={num_heads}, head_dim={head_dim}"
            ))
        })? as u64;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = RopeNeoxKernel::new(num_heads, head_dim, theta);

    // `theta` is not a launch argument (only input, output and `pos` are passed below), so
    // any effect it has must be compiled into the PTX. It therefore has to be part of the
    // cache key: keyed on shape alone, a model with a different rope theta would silently
    // reuse a previously compiled kernel's frequencies. `to_bits` is an exact encoding.
    let theta_bits = theta.to_bits();
    let key = format!("rope_neox_fwd_{num_heads}_{head_dim}_{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // One block per head, half_dim threads per block
    let config =
        LaunchConfig { grid: (num_heads, 1, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    // CUdeviceptr is u64 — use integer arithmetic, not pointer .add()
    let input_ptr = input.as_ptr() + byte_offset;
    let output_ptr = output.as_ptr() + byte_offset;

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &pos as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. `args` points at locals that outlive the call,
    // assumed to match the rope_neox PTX signature (input ptr, output ptr, u32 pos). The
    // caller must ensure input/output hold at least (pos_offset + 1) * num_heads * head_dim
    // f32 elements; buffer sizes are not checked here.
    unsafe {
        stream.launch_kernel(module, "rope_neox", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("RoPE NeoX forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
269
/// Batched RoPE NeoX forward — processes all seq_len positions in a single kernel launch.
///
/// Replaces the per-position `rope_neox_forward` loop to avoid ~2048 kernel launches per
/// block. Uses grid `(num_heads, seq_len, 1)` with `head_dim / 2` threads per block, and
/// reads each row's rotation position from the `positions` GPU buffer.
///
/// Input layout: `[seq_len, num_heads * head_dim]` (interleaved).
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails;
///   errors from PTX compilation in the cache are propagated as-is.
#[cfg(feature = "cuda")]
pub fn batched_rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedRopeKernel::new(num_heads, head_dim, seq_len, theta);

    // `theta` is not a launch argument (only input, output and positions are passed
    // below), so any effect it has must be compiled into the PTX and it must be part of
    // the cache key. Otherwise a second theta with the same shape would silently reuse a
    // previously compiled kernel. `to_bits` is an exact encoding.
    // NOTE(review): `seq_len` is also given to the constructor but is not keyed; that is
    // only correct if the PTX does not depend on it (rows come from grid.y) — confirm.
    let theta_bits = theta.to_bits();
    let key = format!("batched_rope_fwd_{num_heads}_{head_dim}_{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config =
        LaunchConfig { grid: (num_heads, seq_len, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let positions_ptr = positions.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. `args` points at locals that outlive the call,
    // assumed to match the batched_rope PTX signature (input, output, positions pointers).
    // The caller must ensure input/output hold seq_len * num_heads * head_dim f32 elements
    // and positions presumably holds seq_len entries; sizes are not checked here.
    unsafe {
        stream.launch_kernel(module, "batched_rope", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched RoPE NeoX forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
324
/// Batched RoPE NeoX backward — inverse rotation for gradient flow.
///
/// Applies the transpose (inverse) rotation, R(θ)^T = R(-θ), to the incoming gradients so
/// the Q/K projection backward receives correctly-framed gradients. Without this, dW_q and
/// dW_k are computed in the rotated coordinate frame, producing incorrect weight updates.
///
/// Launch shape mirrors [`batched_rope_neox_forward`]: grid `(num_heads, seq_len, 1)`,
/// `head_dim / 2` threads per block, rotation positions read from `positions`.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails;
///   errors from PTX compilation in the cache are propagated as-is.
#[cfg(feature = "cuda")]
pub fn batched_rope_neox_backward(
    grad_input: &GpuBuffer<f32>,
    grad_output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedRopeBackwardKernel::new(num_heads, head_dim, seq_len, theta);

    // `theta` is not a launch argument (only the two gradient buffers and positions are
    // passed below), so any effect it has must be compiled into the PTX and it must be
    // part of the cache key; otherwise a different theta with the same shape would reuse a
    // previously compiled kernel. `to_bits` is an exact encoding.
    // NOTE(review): `seq_len` is also given to the constructor but is not keyed; that is
    // only correct if the PTX does not depend on it (rows come from grid.y) — confirm.
    let theta_bits = theta.to_bits();
    let key = format!("batched_rope_bwd_{num_heads}_{head_dim}_{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config =
        LaunchConfig { grid: (num_heads, seq_len, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = grad_input.as_ptr();
    let output_ptr = grad_output.as_ptr();
    let positions_ptr = positions.as_ptr();

    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. `args` points at locals that outlive the call,
    // assumed to match the batched_rope_backward PTX signature (grad_input, grad_output,
    // positions pointers). The caller must size the gradient buffers like the forward
    // activations (presumably seq_len * num_heads * head_dim f32 elements) and positions
    // for seq_len entries; sizes are not checked here.
    unsafe {
        stream.launch_kernel(module, "batched_rope_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched RoPE NeoX backward failed: {e:?}"))
        })?;
    }

    Ok(())
}
378
/// Fused residual add + RMSNorm forward: output = RMSNorm(residual + input, gamma)
///
/// Contract: entrenar#321 — eliminates NaN cascade in layers 24-27 by fusing
/// the residual add with RMSNorm into a single kernel pass. The RMSNorm
/// normalization prevents activation explosion through the residual chain.
///
/// Saves the un-normalized residual sum in `residual_out` for the backward pass.
///
/// # Parameters
/// - `residual`: Previous layer output (residual connection input)
/// - `input`: Current block output to add
/// - `residual_out`: Receives residual + input (for backward; may alias `residual`)
/// - `output`: RMSNorm(residual + input) * gamma
/// - `gamma`: Scale weights (hidden_size elements)
/// - `batch_size`: Number of rows (seq_len)
/// - `hidden_size`: Number of columns per row
///
/// When `residual_out` points at different device memory than `residual`, a separate
/// residual-add kernel writes the sum into `residual_out` before the fused kernel runs.
/// NOTE(review): when `residual_out` *does* alias `residual`, nothing in this function
/// writes the sum explicitly; that path relies on the fused kernel (which receives
/// `residual`, `input`, `output`, `gamma`) updating `residual` itself — confirm against
/// `FusedResidualRmsNormKernel`. Also, if `residual_out` aliased `input`, the add would
/// overwrite `input` before the fused kernel reads it.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if `batch_size * hidden_size` overflows `u32`, if the
///   kernel-cache lock is poisoned, or if the fused launch fails; errors from the residual
///   add and from PTX compilation are propagated as-is.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
///
/// Note that the residual add (when issued) is enqueued before the fused kernel is
/// compiled/launched, so on a later error `residual_out` may already have been written.
#[cfg(feature = "cuda")]
pub fn fused_residual_rmsnorm_forward(
    residual: &GpuBuffer<f32>,
    input: &GpuBuffer<f32>,
    residual_out: &mut GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Store the un-normalized sum for backward *before* taking the kernel-cache lock.
    // `std::sync::Mutex` is not re-entrant: if `residual_add_forward` locks
    // FORWARD_KERNEL_CACHE the way every launcher in this file does, calling it while the
    // guard below is held would deadlock or panic. Both kernels go to the same `stream`,
    // so the GPU still runs the add before the fused kernel.
    if residual_out.as_ptr() != residual.as_ptr() {
        let num_elements = batch_size.checked_mul(hidden_size).ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "Fused residual+RMSNorm size overflow: batch_size={batch_size}, \
                 hidden_size={hidden_size}"
            ))
        })?;
        crate::autograd::cuda_forward::residual_add_forward(
            residual,
            input,
            residual_out,
            num_elements,
            stream,
        )?;
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("fused_residual_rmsnorm_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            // Only build the PTX generator on a cache miss.
            let kernel = FusedResidualRmsNormKernel::new(hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    // Grid: (1, batch_size, 1) — one block per row
    // Block: (32, 1, 1) — single warp for reduction
    let config = LaunchConfig { grid: (1, batch_size, 1), block: (32, 1, 1), shared_mem: 0 };

    let residual_ptr = residual.as_ptr();
    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();

    let mut args: [*mut std::ffi::c_void; 4] = [
        &residual_ptr as *const _ as *mut _,
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. `args` points at locals that outlive the call,
    // assumed to match the fused_residual_rmsnorm PTX signature (residual, input, output,
    // gamma pointers). The caller must ensure residual/input/output hold
    // batch_size * hidden_size f32 elements and gamma holds hidden_size elements; sizes are
    // not checked here.
    unsafe {
        stream.launch_kernel(module, "fused_residual_rmsnorm", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Fused residual+RMSNorm forward failed: {e:?}"
                ))
            },
        )?;
    }

    Ok(())
}