Skip to main content

entrenar/autograd/cuda_forward/
elementwise.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10    BatchedToInterleavedKernel, BatchedTransposeKernel, ElementwiseMulKernel,
11    InterleavedToBatchedKernel, Kernel, ResidualAddKernel, ScaleKernel,
12};
13
14use crate::autograd::cuda_tensor::{CudaTensorError, Result};
15
16#[cfg(feature = "cuda")]
17use super::cache::FORWARD_KERNEL_CACHE;
18
/// Residual addition forward pass on GPU
///
/// Computes: output[i] = a[i] + b[i] for i in [0, n)
///
/// # Contract (C-RESADD-001)
///
/// - **Precondition**: a.len() >= n, b.len() >= n, output.len() >= n (checked at runtime)
/// - **Postcondition**: output[i] == a[i] + b[i] for all i in [0, n)
/// - **Invariant**: Zero CPU-side data transfers (no gpu_to_vec / vec_to_gpu)
///
/// `n == 0` is treated as an empty tensor: the function returns `Ok(())` without
/// launching anything, because a zero-sized grid is not a valid launch configuration.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if any buffer holds fewer than `n` elements, if the
///   kernel-cache lock is poisoned, or if the kernel launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn residual_add_forward(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    if n == 0 {
        return Ok(());
    }
    // The kernel receives only raw device pointers plus `n`, so it cannot detect an
    // undersized buffer; reject that here instead of touching memory out of bounds.
    let needed = n as usize;
    if a.len() < needed || b.len() < needed || output.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "Residual add forward: n={n} exceeds buffer length (a={}, b={}, output={})",
            a.len(),
            b.len(),
            output.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // PTX is n-independent (trueno#184), so one cached module serves every size.
    // A `&str` key avoids allocating a String on every call.
    let key = "residual_add_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = ResidualAddKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Enough 256-thread blocks to cover all `n` elements.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let output_ptr = output.as_ptr();

    // Kernel parameter order: (a, b, output, n).
    let mut args: [*mut std::ffi::c_void; 4] = [
        &a_ptr as *const _ as *mut _,
        &b_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All three buffers were checked above to hold
    // at least `n` elements, the kernel parameters match the expected PTX signature, and
    // the locals referenced by `args` outlive the launch call.
    unsafe {
        stream.launch_kernel(module, "residual_add", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Residual add forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
74
/// In-place GPU buffer addition: dst[i] += src[i]
///
/// # Contract (C-IPADD-001)
///
/// - **Precondition**: dst.len() >= n, src.len() >= n (checked at runtime)
/// - **Postcondition**: dst[i] == old_dst[i] + src[i] for all i in [0, n)
/// - **Invariant**: Zero CPU-side data transfers, zero stream synchronization
///
/// `n == 0` is a no-op returning `Ok(())`; no kernel is launched, because a zero-sized
/// grid is not a valid launch configuration.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if either buffer holds fewer than `n` elements, if the
///   kernel-cache lock is poisoned, or if the kernel launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn inplace_add_gpu(
    dst: &mut GpuBuffer<f32>,
    src: &GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    if n == 0 {
        return Ok(());
    }
    // The kernel trusts `n` as the element count; validate it against both buffers.
    let needed = n as usize;
    if dst.len() < needed || src.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "In-place add: n={n} exceeds buffer length (dst={}, src={})",
            dst.len(),
            src.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // PTX is n-independent (trueno#184); `&str` key avoids a per-call allocation.
    let key = "inplace_add";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = ResidualAddKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Enough 256-thread blocks to cover all `n` elements.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    // Use dst pointer for both input `a` and `output`, giving in-place accumulation.
    // ResidualAddKernel computes output[i] = a[i] + b[i] per thread. With a == output
    // (aliased), each thread reads dst[i], adds src[i], and writes dst[i]. There is no
    // inter-thread data dependency, so the aliasing is safe.
    let dst_ptr = dst.as_ptr();
    let src_ptr = src.as_ptr();

    // Kernel parameter order: (a = dst, b = src, output = dst, n).
    let mut args: [*mut std::ffi::c_void; 4] = [
        &dst_ptr as *const _ as *mut _,
        &src_ptr as *const _ as *mut _,
        &dst_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: kernel launch with device pointers whose buffers were checked above to
    // hold at least `n` elements. The locals referenced by `args` outlive the call.
    unsafe {
        stream.launch_kernel(module, "residual_add", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("In-place add launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
130
/// Element-wise multiplication forward pass on GPU
///
/// Computes: output[i] = a[i] * b[i] for i in [0, n)
///
/// # Contract (C-ELMUL-001)
///
/// - **Precondition**: a.len() >= n, b.len() >= n, output.len() >= n (checked at runtime)
/// - **Postcondition**: output[i] == a[i] * b[i] for all i in [0, n)
/// - **Invariant**: Zero CPU-side data transfers
///
/// `n == 0` is a no-op returning `Ok(())`; no kernel is launched, because a zero-sized
/// grid is not a valid launch configuration.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if any buffer holds fewer than `n` elements, if the
///   kernel-cache lock is poisoned, or if the kernel launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn elementwise_mul_forward(
    a: &GpuBuffer<f32>,
    b: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    if n == 0 {
        return Ok(());
    }
    // The kernel only sees raw pointers plus `n`; guard against undersized buffers.
    let needed = n as usize;
    if a.len() < needed || b.len() < needed || output.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "Elementwise mul forward: n={n} exceeds buffer length (a={}, b={}, output={})",
            a.len(),
            b.len(),
            output.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // PTX is n-independent (trueno#184); `&str` key avoids a per-call allocation.
    let key = "elementwise_mul_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = ElementwiseMulKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Enough 256-thread blocks to cover all `n` elements.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let a_ptr = a.as_ptr();
    let b_ptr = b.as_ptr();
    let output_ptr = output.as_ptr();

    // Kernel parameter order: (a, b, output, n).
    let mut args: [*mut std::ffi::c_void; 4] = [
        &a_ptr as *const _ as *mut _,
        &b_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. All buffers were checked above to hold at least
    // `n` elements, the kernel parameters match the expected PTX signature, and the locals
    // referenced by `args` outlive the call.
    unsafe {
        stream.launch_kernel(module, "elementwise_mul", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Elementwise mul forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
186
/// Scale forward pass on GPU
///
/// Computes: output[i] = input[i] * scale for i in [0, n)
///
/// # Contract (C-SCALE-001)
///
/// - **Precondition**: input.len() >= n, output.len() >= n (checked at runtime)
/// - **Postcondition**: output[i] == input[i] * scale for all i in [0, n)
/// - **Invariant**: Zero CPU-side data transfers
///
/// The kernel contract permits `output` to alias `input`. However, this signature borrows
/// `input` shared and `output` exclusively, so safe callers cannot pass the same buffer
/// for both.
///
/// `n == 0` is a no-op returning `Ok(())`; no kernel is launched, because a zero-sized
/// grid is not a valid launch configuration.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if either buffer holds fewer than `n` elements, if the
///   kernel-cache lock is poisoned, or if the kernel launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn scale_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    scale: f32,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    if n == 0 {
        return Ok(());
    }
    // The kernel trusts `n` as the element count; validate it against both buffers.
    let needed = n as usize;
    if input.len() < needed || output.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "Scale forward: n={n} exceeds buffer length (input={}, output={})",
            input.len(),
            output.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // PTX is n-independent (trueno#184); `&str` key avoids a per-call allocation.
    let key = "scale_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = ScaleKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Enough 256-thread blocks to cover all `n` elements.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Kernel parameter order: (input, output, scale, n).
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &scale as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Both buffers were checked above to hold at
    // least `n` elements, the kernel parameters match the expected PTX signature, and the
    // locals referenced by `args` outlive the call.
    unsafe {
        stream.launch_kernel(module, "scale", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Scale forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
241
/// Convert interleaved to batched layout on GPU
///
/// Transforms: [seq_len, n_heads * head_dim] → [n_heads, seq_len, head_dim]
///
/// Used to prepare Q/K/V for batched multi-head attention GEMM.
///
/// # Contract (C-I2B-001)
///
/// - **Precondition**: seq_len * n_heads * head_dim fits in u32; input.len() and
///   output.len() are >= that product (all checked at runtime)
/// - **Postcondition**: output[h, s, d] = input[s, h * head_dim + d]
/// - **Invariant**: Zero CPU-side data transfers; total element count preserved
///
/// If any dimension is zero there is nothing to move. The function returns `Ok(())`
/// without launching a kernel, because a zero-sized grid is not a valid launch configuration.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if the element count overflows `u32`, if either buffer
///   is too small, if the kernel-cache lock is poisoned, or if the launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn interleaved_to_batched_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    seq_len: u32,
    n_heads: u32,
    head_dim: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Plain `u32` multiplication panics in debug builds and silently wraps in release
    // builds, which would hand the kernel a wrong element count.
    let total = seq_len.checked_mul(n_heads).and_then(|x| x.checked_mul(head_dim)).ok_or_else(
        || {
            CudaTensorError::KernelError(format!(
                "Interleaved-to-batched: element count {seq_len}*{n_heads}*{head_dim} overflows u32"
            ))
        },
    )?;
    if total == 0 {
        return Ok(());
    }
    let needed = total as usize;
    if input.len() < needed || output.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "Interleaved-to-batched: {needed} elements required (input={}, output={})",
            input.len(),
            output.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Contract: dimension-independent-kernels-v1.yaml (FALSIFY-DIM-004)
    // Use generic cache key: PTX is dimension-independent, so one module handles all dims.
    let key = "interleaved_to_batched";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            // Constructor args don't matter: PTX is identical for any dimensions.
            let kernel = InterleavedToBatchedKernel::new(seq_len, n_heads, head_dim);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Enough 256-thread blocks to cover all `total` elements.
    let config =
        LaunchConfig { grid: (total.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Dimension-independent kernel: pass dims as runtime params
    // (input, output, seq_len, n_heads, head_dim, total).
    let mut args: [*mut std::ffi::c_void; 6] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &seq_len as *const _ as *mut _,
        &n_heads as *const _ as *mut _,
        &head_dim as *const _ as *mut _,
        &total as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Both buffers were checked above to hold at least
    // `total` elements, and the locals referenced by `args` outlive the call.
    unsafe {
        stream.launch_kernel(module, "interleaved_to_batched", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!("Interleaved-to-batched launch failed: {e:?}"))
            },
        )?;
    }

    Ok(())
}
308
/// Batched transpose on GPU
///
/// Transforms: [batch, rows, cols] → [batch, cols, rows]
///
/// Used for K^T in attention: [n_heads, seq_len, head_dim] → [n_heads, head_dim, seq_len]
///
/// # Contract (C-BTRANS-001)
///
/// - **Precondition**: batch * rows * cols fits in u32; input.len() and output.len() are
///   >= that product (all checked at runtime)
/// - **Postcondition**: output[b, j, i] = input[b, i, j]
/// - **Invariant**: Zero CPU-side data transfers; total element count preserved
///
/// If `batch`, `rows` or `cols` is zero there is nothing to move. The function returns
/// `Ok(())` without launching a kernel, because a zero-sized grid is not a valid launch
/// configuration. `batch` becomes the grid's z-dimension, which CUDA caps (65535), so
/// larger batch counts are reported as a launch failure.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if the element count overflows `u32`, if either buffer
///   is too small, if the kernel-cache lock is poisoned, or if the launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn batched_transpose_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch: u32,
    rows: u32,
    cols: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Plain `u32` multiplication panics in debug builds and silently wraps in release
    // builds; both the per-batch and the overall element count must fit in u32.
    let Some((total_per_batch, total)) = rows
        .checked_mul(cols)
        .and_then(|per_batch| batch.checked_mul(per_batch).map(|all| (per_batch, all)))
    else {
        return Err(CudaTensorError::KernelError(format!(
            "Batched transpose: element count {batch}*{rows}*{cols} overflows u32"
        )));
    };
    if total == 0 {
        return Ok(());
    }
    let needed = total as usize;
    if input.len() < needed || output.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "Batched transpose: {needed} elements required (input={}, output={})",
            input.len(),
            output.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Contract: dimension-independent-kernels-v1.yaml (FALSIFY-DIM-004)
    let key = "batched_transpose";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = BatchedTransposeKernel::new(batch, rows, cols);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Grid: (ceil(total_per_batch/256), 1, batch)
    let config = LaunchConfig {
        grid: (total_per_batch.div_ceil(256), 1, batch),
        block: (256, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Dimension-independent kernel: pass dims as runtime params
    // (input, output, batch, rows, cols, total_per_batch).
    let mut args: [*mut std::ffi::c_void; 6] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &batch as *const _ as *mut _,
        &rows as *const _ as *mut _,
        &cols as *const _ as *mut _,
        &total_per_batch as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Both buffers were checked above to hold at least
    // batch * rows * cols elements, and the locals referenced by `args` outlive the call.
    unsafe {
        stream.launch_kernel(module, "batched_transpose", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched transpose launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
375
/// Convert batched to interleaved layout on GPU
///
/// Transforms: [n_heads, seq_len, head_dim] → [seq_len, n_heads * head_dim]
///
/// Used to convert attention output back to interleaved layout for output projection.
///
/// # Contract (C-B2I-001)
///
/// - **Precondition**: n_heads * seq_len * head_dim fits in u32; input.len() and
///   output.len() are >= that product (all checked at runtime)
/// - **Postcondition**: output[s, h * head_dim + d] = input[h, s, d]
/// - **Invariant**: Zero CPU-side data transfers; total element count preserved
///
/// If any dimension is zero there is nothing to move. The function returns `Ok(())`
/// without launching a kernel, because a zero-sized grid is not a valid launch configuration.
///
/// # Errors
///
/// - `CudaTensorError::KernelError` if the element count overflows `u32`, if either buffer
///   is too small, if the kernel-cache lock is poisoned, or if the launch fails.
/// - `CudaTensorError::DeviceNotInitialized` if the forward kernel cache has not been set up.
#[cfg(feature = "cuda")]
pub fn batched_to_interleaved_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    seq_len: u32,
    n_heads: u32,
    head_dim: u32,
    stream: &CudaStream,
) -> Result<()> {
    // Plain `u32` multiplication panics in debug builds and silently wraps in release
    // builds, which would hand the kernel a wrong element count.
    let total = seq_len.checked_mul(n_heads).and_then(|x| x.checked_mul(head_dim)).ok_or_else(
        || {
            CudaTensorError::KernelError(format!(
                "Batched-to-interleaved: element count {seq_len}*{n_heads}*{head_dim} overflows u32"
            ))
        },
    )?;
    if total == 0 {
        return Ok(());
    }
    let needed = total as usize;
    if input.len() < needed || output.len() < needed {
        return Err(CudaTensorError::KernelError(format!(
            "Batched-to-interleaved: {needed} elements required (input={}, output={})",
            input.len(),
            output.len()
        )));
    }

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Contract: dimension-independent-kernels-v1.yaml (FALSIFY-DIM-004)
    let key = "batched_to_interleaved";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = BatchedToInterleavedKernel::new(seq_len, n_heads, head_dim);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    // Enough 256-thread blocks to cover all `total` elements.
    let config =
        LaunchConfig { grid: (total.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Dimension-independent kernel: pass dims as runtime params
    // (input, output, seq_len, n_heads, head_dim, total).
    let mut args: [*mut std::ffi::c_void; 6] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &seq_len as *const _ as *mut _,
        &n_heads as *const _ as *mut _,
        &head_dim as *const _ as *mut _,
        &total as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. Both buffers were checked above to hold at least
    // `total` elements, and the locals referenced by `args` outlive the call.
    unsafe {
        stream.launch_kernel(module, "batched_to_interleaved", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!("Batched-to-interleaved launch failed: {e:?}"))
            },
        )?;
    }

    Ok(())
}
440
/// Expand KV heads for grouped-query attention (GQA) on GPU
///
/// Replicates each KV head `heads_per_kv` times using D2D copies.
/// Transforms: [num_kv_heads, seq_len, head_dim] → [num_heads, seq_len, head_dim]
///
/// # Contract (C-GQAEXP-001)
///
/// - **Precondition**: src has at least num_kv_heads * elems_per_head elements and
///   dst has at least num_kv_heads * heads_per_kv * elems_per_head elements
///   (both checked at runtime, including overflow of the products)
/// - **Postcondition**: dst[h, :, :] = src[h / heads_per_kv, :, :] for all h in [0, num_heads)
/// - **Invariant**: Zero CPU-side data transfers (D2D only)
///
/// When the expanded size is zero, nothing is copied and `Ok(())` is returned.
///
/// # Errors
///
/// `CudaTensorError::TransferFailed` if a size product overflows `usize`, if either buffer
/// is smaller than the precondition requires, or if a D2D copy fails.
#[cfg(feature = "cuda")]
pub fn expand_kv_heads(
    src: &GpuBuffer<f32>,
    dst: &mut GpuBuffer<f32>,
    num_kv_heads: usize,
    heads_per_kv: usize,
    elems_per_head: usize,
    stream: &CudaStream,
) -> Result<()> {
    // Validate sizes once up front: the unsafe copies below rely on every
    // (offset + elems_per_head) staying inside its buffer.
    let Some((src_needed, dst_needed)) = num_kv_heads
        .checked_mul(elems_per_head)
        .and_then(|src_len| src_len.checked_mul(heads_per_kv).map(|dst_len| (src_len, dst_len)))
    else {
        return Err(CudaTensorError::TransferFailed(format!(
            "GQA head expansion size overflow: {num_kv_heads} kv heads x {heads_per_kv} reps x {elems_per_head} elems"
        )));
    };
    if dst_needed == 0 {
        return Ok(());
    }
    if src.len() < src_needed || dst.len() < dst_needed {
        return Err(CudaTensorError::TransferFailed(format!(
            "GQA head expansion: buffers too small (src {} < {src_needed} or dst {} < {dst_needed})",
            src.len(),
            dst.len()
        )));
    }

    for kv_h in 0..num_kv_heads {
        let src_offset = kv_h * elems_per_head;
        for rep in 0..heads_per_kv {
            // Output head index is kv_h * heads_per_kv + rep, so the replicas of a KV head
            // are contiguous in dst.
            let dst_offset = (kv_h * heads_per_kv + rep) * elems_per_head;
            // SAFETY: Both buffers were checked above to cover every
            // offset + elems_per_head range used here. The async D2D copy is ordered on
            // the stream with prior kernel launches.
            unsafe {
                dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
                    .map_err(|e| {
                        CudaTensorError::TransferFailed(format!(
                            "GQA head expansion D2D copy failed: {e}"
                        ))
                    })?;
            }
        }
    }
    Ok(())
}