entrenar/autograd/cuda_forward/
normalization.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10 BatchedRopeBackwardKernel, BatchedRopeKernel, BatchedVectorizedRmsNormKernel,
11 FusedResidualRmsNormKernel, Kernel, LayerNormKernel, PerHeadRmsNormKernel, RopeNeoxKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
#[cfg(feature = "cuda")]
/// Launches the LayerNorm forward kernel on `stream`.
///
/// One thread block is launched per batch entry (`grid.x = batch_size`), each with
/// `min(256, hidden_size)` threads. The kernel receives, in order, the device
/// pointers of `input`, `gamma`, `beta` and `output`, followed by `batch_size` and
/// `hidden_size`. The buffers are presumably laid out row-major as
/// `[batch_size, hidden_size]`, with `gamma`/`beta` of length `hidden_size`
/// (TODO: confirm against `LayerNormKernel`).
///
/// The compiled module is cached under `layer_norm_forward_{hidden_size}`.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
pub fn layer_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    beta: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    /// Type-erases a reference to one kernel argument for the launch argument array.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        (value as *const T).cast_mut().cast()
    }

    let shared_cache =
        FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = shared_cache.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let layer_norm = LayerNormKernel::new(hidden_size);
    let entry_point = layer_norm.name();
    let cache_key = format!("layer_norm_forward_{hidden_size}");

    let module = match guard.get_cached(&cache_key) {
        Some(hit) => hit,
        None => {
            let ptx = layer_norm.emit_ptx_for_target(guard.sm_target());
            guard.get_or_compile(&cache_key, &ptx)?
        }
    };

    let threads_per_block = hidden_size.min(256);
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (threads_per_block, 1, 1),
        shared_mem: 0,
    };

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let input_ptr = input.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let beta_ptr = beta.as_ptr();
    let output_ptr = output.as_ptr();
    let mut args: [*mut std::ffi::c_void; 6] = [
        kernel_arg(&input_ptr),
        kernel_arg(&gamma_ptr),
        kernel_arg(&beta_ptr),
        kernel_arg(&output_ptr),
        kernel_arg(&batch_size),
        kernel_arg(&hidden_size),
    ];

    // SAFETY: every entry of `args` points at a live local declared above; the
    // argument order/types are assumed to match the LayerNorm kernel's parameter
    // list (TODO confirm against `LayerNormKernel`).
    let launched = unsafe { stream.launch_kernel(module, entry_point, &config, &mut args) };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("LayerNorm forward launch failed: {e:?}"))
    })?;

    Ok(())
}
79
#[cfg(feature = "cuda")]
/// Launches the batched, vectorized RMSNorm forward kernel on `stream`.
///
/// The grid is `(1, batch_size, 1)` with 256-thread blocks and 32 bytes of dynamic
/// shared memory. The shared memory is presumably one `f32` partial per warp
/// (TODO: confirm). The `batched_rmsnorm_vectorized` entry point receives the device
/// pointers of `input`, `output` and `gamma`, in that order.
///
/// The module is cached under `batched_rmsnorm_fwd_{hidden_size}`. `batch_size` is
/// passed to the kernel constructor but is not part of the key. This assumes the
/// emitted PTX does not depend on it, since rows are selected by the grid's y
/// dimension (NOTE(review): confirm).
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    /// Type-erases a reference to one kernel argument for the launch argument array.
    fn kernel_arg<T>(value: &T) -> *mut std::ffi::c_void {
        (value as *const T).cast_mut().cast()
    }

    const THREADS_PER_BLOCK: u32 = 256;

    let shared_cache =
        FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut guard = shared_cache.lock().map_err(|_poisoned| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let rms_norm = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size);
    let cache_key = format!("batched_rmsnorm_fwd_{hidden_size}");

    let module = match guard.get_cached(&cache_key) {
        Some(hit) => hit,
        None => {
            let ptx = rms_norm.emit_ptx_for_target(guard.sm_target());
            guard.get_or_compile(&cache_key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (1, batch_size, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        // 8 slots of 4 bytes for a 256-thread block.
        shared_mem: 8 * 4,
    };

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let (input_ptr, output_ptr, gamma_ptr) = (input.as_ptr(), output.as_ptr(), gamma.as_ptr());
    let mut args: [*mut std::ffi::c_void; 3] =
        [kernel_arg(&input_ptr), kernel_arg(&output_ptr), kernel_arg(&gamma_ptr)];

    // SAFETY: every entry of `args` points at a live local declared above; the
    // (input, output, gamma) order is assumed to match `batched_rmsnorm_vectorized`.
    let launched = unsafe {
        stream.launch_kernel(module, "batched_rmsnorm_vectorized", &config, &mut args)
    };
    launched.map_err(|e| {
        CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}"))
    })?;

    Ok(())
}
144
#[cfg(feature = "cuda")]
/// Launches the per-head RMSNorm kernel for one position on `stream`.
///
/// Both `input` and `output` are offset by `pos_offset * num_heads * head_dim` `f32`
/// elements before launch. `gamma` is passed unshifted. One 32-thread block is
/// launched per head (`grid.x = num_heads`). The `per_head_rmsnorm` entry point
/// receives the (offset) input pointer, the (offset) output pointer and the gamma
/// pointer, in that order. `gamma` is presumably `head_dim` long and shared by all
/// heads (TODO: confirm against `PerHeadRmsNormKernel`).
///
/// The module is cached under `per_head_rmsnorm_fwd_{head_dim}_{num_heads}`.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the element/byte offset for `pos_offset` overflows.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
pub fn per_head_rmsnorm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos_offset: usize,
    stream: &CudaStream,
) -> Result<()> {
    // The offset becomes a raw device address handed to an unsafe launch. Widen to
    // usize before multiplying, since `num_heads * head_dim` in u32 could wrap, and
    // check every step. An overflow fails loudly instead of pointing the kernel at
    // unrelated memory.
    let byte_offset = (num_heads as usize)
        .checked_mul(head_dim as usize)
        .and_then(|stride| stride.checked_mul(pos_offset))
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .and_then(|bytes| u64::try_from(bytes).ok())
        .ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "PerHeadRmsNorm forward offset overflow: pos_offset={pos_offset}, \
                 num_heads={num_heads}, head_dim={head_dim}"
            ))
        })?;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = PerHeadRmsNormKernel::new(head_dim, num_heads);

    let key = format!("per_head_rmsnorm_fwd_{head_dim}_{num_heads}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig { grid: (num_heads, 1, 1), block: (32, 1, 1), shared_mem: 0 };

    // Input and output use the same layout, so they share one byte offset.
    let input_ptr = input.as_ptr() + byte_offset;
    let output_ptr = output.as_ptr() + byte_offset;
    let gamma_ptr = gamma.as_ptr();

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a live local declared above; the
    // (input, output, gamma) order is assumed to match `per_head_rmsnorm`.
    unsafe {
        stream.launch_kernel(module, "per_head_rmsnorm", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("PerHeadRmsNorm forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
207
#[cfg(feature = "cuda")]
/// Launches the NeoX-style RoPE kernel for one position on `stream`.
///
/// Both `input` and `output` are offset by `pos_offset * num_heads * head_dim` `f32`
/// elements before launch. The grid is one block per head (`grid.x = num_heads`) with
/// `head_dim / 2` threads. The `rope_neox` entry point receives the (offset) input
/// pointer, the (offset) output pointer and `pos`.
///
/// `theta` is never a launch argument; it is only handed to `RopeNeoxKernel::new`.
/// It is therefore presumably baked into the emitted PTX, so it is part of the cache
/// key (`rope_neox_fwd_{num_heads}_{head_dim}_{theta bits}`). Without it, a module
/// compiled for one base frequency would be reused for another.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the element/byte offset for `pos_offset` overflows.
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
#[allow(clippy::too_many_arguments)]
pub fn rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    num_heads: u32,
    head_dim: u32,
    pos: u32,
    pos_offset: usize,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    // The offset becomes a raw device address handed to an unsafe launch. Widen
    // before multiplying, since u32 `num_heads * head_dim` could wrap, and check
    // every step.
    let byte_offset = (num_heads as usize)
        .checked_mul(head_dim as usize)
        .and_then(|stride| stride.checked_mul(pos_offset))
        .and_then(|elems| elems.checked_mul(std::mem::size_of::<f32>()))
        .and_then(|bytes| u64::try_from(bytes).ok())
        .ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "RoPE NeoX forward offset overflow: pos_offset={pos_offset}, \
                 num_heads={num_heads}, head_dim={head_dim}"
            ))
        })?;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = RopeNeoxKernel::new(num_heads, head_dim, theta);

    // Key on the exact bit pattern of theta, so distinct values never share a module.
    let theta_bits = theta.to_bits();
    let key = format!("rope_neox_fwd_{num_heads}_{head_dim}_{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config =
        LaunchConfig { grid: (num_heads, 1, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr() + byte_offset;
    let output_ptr = output.as_ptr() + byte_offset;

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &pos as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a live local/parameter; the
    // (input, output, pos) order is assumed to match `rope_neox`.
    unsafe {
        stream.launch_kernel(module, "rope_neox", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("RoPE NeoX forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
269
#[cfg(feature = "cuda")]
/// Launches the batched NeoX-style RoPE forward kernel on `stream`.
///
/// The grid is `(num_heads, seq_len, 1)` with `head_dim / 2` threads per block. The
/// `batched_rope` entry point receives the device pointers of `input`, `output` and
/// `positions`, in that order. `positions` presumably holds one position per
/// sequence index (TODO: confirm against `BatchedRopeKernel`).
///
/// `theta` is never a launch argument; it is only handed to `BatchedRopeKernel::new`.
/// It is therefore presumably baked into the emitted PTX, and it is included in the
/// cache key. `seq_len` is also given to the constructor but is left out of the key.
/// This assumes it only drives the grid's y dimension (NOTE(review): confirm, or
/// cached modules would be reused across sequence lengths).
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
#[allow(clippy::too_many_arguments)]
pub fn batched_rope_neox_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedRopeKernel::new(num_heads, head_dim, seq_len, theta);

    // Key on the exact bit pattern of theta, so distinct values never share a module.
    let theta_bits = theta.to_bits();
    let key = format!("batched_rope_fwd_{num_heads}_{head_dim}_{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config =
        LaunchConfig { grid: (num_heads, seq_len, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let positions_ptr = positions.as_ptr();

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a live local declared above; the
    // (input, output, positions) order is assumed to match `batched_rope`.
    unsafe {
        stream.launch_kernel(module, "batched_rope", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched RoPE NeoX forward failed: {e:?}"))
        })?;
    }

    Ok(())
}
324
#[cfg(feature = "cuda")]
/// Launches the batched NeoX-style RoPE backward kernel on `stream`.
///
/// The grid is `(num_heads, seq_len, 1)` with `head_dim / 2` threads per block. The
/// `batched_rope_backward` entry point receives the device pointers of `grad_input`,
/// `grad_output` and `positions`, in that order. `grad_input` is borrowed immutably
/// and `grad_output` mutably, so they are presumably the incoming gradient and the
/// result, respectively (TODO: confirm against `BatchedRopeBackwardKernel`).
///
/// `theta` is never a launch argument; it is only handed to
/// `BatchedRopeBackwardKernel::new`. It is therefore presumably baked into the emitted
/// PTX, and it is included in the cache key. `seq_len` is left out of the key,
/// which assumes it only drives the grid's y dimension (NOTE(review): confirm).
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
#[allow(clippy::too_many_arguments)]
pub fn batched_rope_neox_backward(
    grad_input: &GpuBuffer<f32>,
    grad_output: &mut GpuBuffer<f32>,
    positions: &GpuBuffer<u32>,
    num_heads: u32,
    head_dim: u32,
    seq_len: u32,
    theta: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let kernel = BatchedRopeBackwardKernel::new(num_heads, head_dim, seq_len, theta);

    // Key on the exact bit pattern of theta, so distinct values never share a module.
    let theta_bits = theta.to_bits();
    let key = format!("batched_rope_bwd_{num_heads}_{head_dim}_{theta_bits:08x}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config =
        LaunchConfig { grid: (num_heads, seq_len, 1), block: (head_dim / 2, 1, 1), shared_mem: 0 };

    let input_ptr = grad_input.as_ptr();
    let output_ptr = grad_output.as_ptr();
    let positions_ptr = positions.as_ptr();

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &positions_ptr as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a live local declared above; the
    // (grad_input, grad_output, positions) order is assumed to match the kernel.
    unsafe {
        stream.launch_kernel(module, "batched_rope_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched RoPE NeoX backward failed: {e:?}"))
        })?;
    }

    Ok(())
}
378
#[cfg(feature = "cuda")]
/// Launches the fused residual-add + RMSNorm forward kernel on `stream`.
///
/// If `residual_out` and `residual` have different device pointers, this function
/// first enqueues `residual_add_forward(residual, input, residual_out,
/// batch_size * hidden_size)`. If they share a pointer, that add is skipped here.
/// NOTE(review): in the aliased case, confirm the fused kernel itself leaves the
/// intended sum in `residual`.
///
/// The fused kernel `fused_residual_rmsnorm` is then launched with grid
/// `(1, batch_size, 1)` and 32-thread blocks. It receives the device pointers of
/// `residual`, `input`, `output` and `gamma`, in that order. It presumably writes
/// `RMSNorm(residual + input) * gamma` into `output` (TODO: confirm against
/// `FusedResidualRmsNormKernel`).
///
/// The module is cached under `fused_residual_rmsnorm_{hidden_size}`.
///
/// # Errors
///
/// - [`CudaTensorError::DeviceNotInitialized`] if the forward kernel cache was never set up.
/// - [`CudaTensorError::KernelError`] if `batch_size * hidden_size` overflows `u32`.
/// - Any error returned by `residual_add_forward`.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned or the launch fails.
/// - Any error the cache returns while compiling the emitted PTX.
#[allow(clippy::too_many_arguments)]
pub fn fused_residual_rmsnorm_forward(
    residual: &GpuBuffer<f32>,
    input: &GpuBuffer<f32>,
    residual_out: &mut GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;

    // Enqueue the residual add BEFORE taking the kernel-cache lock.
    // `residual_add_forward` is a sibling launcher that presumably locks this same
    // mutex. `lock()` returns a Result here, so it is not reentrant. Calling it
    // while this function holds the guard (which `module` borrows from below) would
    // deadlock or panic. Both launches still go onto `stream` in the same order:
    // add first, then the fused kernel.
    if residual_out.as_ptr() != residual.as_ptr() {
        let n_elements = batch_size.checked_mul(hidden_size).ok_or_else(|| {
            CudaTensorError::KernelError(format!(
                "Fused residual+RMSNorm element count overflows u32: \
                 batch_size={batch_size}, hidden_size={hidden_size}"
            ))
        })?;
        crate::autograd::cuda_forward::residual_add_forward(
            residual,
            input,
            residual_out,
            n_elements,
            stream,
        )?;
    }

    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    let key = format!("fused_residual_rmsnorm_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = FusedResidualRmsNormKernel::new(hidden_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };

    let config = LaunchConfig { grid: (1, batch_size, 1), block: (32, 1, 1), shared_mem: 0 };

    let residual_ptr = residual.as_ptr();
    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();

    // `args` stores the addresses of these locals, so they must outlive the launch.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &residual_ptr as *const _ as *mut _,
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];

    // SAFETY: every entry of `args` points at a live local declared above; the
    // (residual, input, output, gamma) order is assumed to match the kernel.
    unsafe {
        stream.launch_kernel(module, "fused_residual_rmsnorm", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Fused residual+RMSNorm forward failed: {e:?}"
                ))
            },
        )?;
    }

    Ok(())
}