#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::{
BatchedToInterleavedKernel, BatchedTransposeKernel, ElementwiseMulKernel,
InterleavedToBatchedKernel, Kernel, ResidualAddKernel, ScaleKernel,
};
use crate::autograd::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::FORWARD_KERNEL_CACHE;
#[cfg(feature = "cuda")]
pub fn residual_add_forward(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let key = "residual_add_forward".to_string(); let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let kernel = ResidualAddKernel::new(n);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 4] = [
&a_ptr as *const _ as *mut _,
&b_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&n as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "residual_add", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("Residual add forward launch failed: {e:?}"))
})?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn inplace_add_gpu(
dst: &mut GpuBuffer<f32>,
src: &GpuBuffer<f32>,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let key = "inplace_add".to_string(); let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let kernel = ResidualAddKernel::new(n);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
let dst_ptr = dst.as_ptr();
let src_ptr = src.as_ptr();
let mut args: [*mut std::ffi::c_void; 4] = [
&dst_ptr as *const _ as *mut _,
&src_ptr as *const _ as *mut _,
&dst_ptr as *const _ as *mut _,
&n as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "residual_add", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("In-place add launch failed: {e:?}"))
})?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn elementwise_mul_forward(
a: &GpuBuffer<f32>,
b: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let key = "elementwise_mul_forward".to_string(); let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let kernel = ElementwiseMulKernel::new(n);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
let a_ptr = a.as_ptr();
let b_ptr = b.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 4] = [
&a_ptr as *const _ as *mut _,
&b_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&n as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "elementwise_mul", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("Elementwise mul forward launch failed: {e:?}"))
})?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn scale_forward(
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
scale: f32,
n: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let key = "scale_forward".to_string(); let module = match cache.get_cached(&key) {
Some(m) => m,
None => {
let kernel = ScaleKernel::new(n);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(&key, &ptx)?
}
};
let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
let input_ptr = input.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 4] = [
&input_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&scale as *const _ as *mut _,
&n as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "scale", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("Scale forward launch failed: {e:?}"))
})?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn interleaved_to_batched_forward(
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
seq_len: u32,
n_heads: u32,
head_dim: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let total = seq_len * n_heads * head_dim;
let key = "interleaved_to_batched";
let module = match cache.get_cached(key) {
Some(m) => m,
None => {
let kernel = InterleavedToBatchedKernel::new(seq_len, n_heads, head_dim);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(key, &ptx)?
}
};
let config =
LaunchConfig { grid: (total.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
let input_ptr = input.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&input_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&seq_len as *const _ as *mut _,
&n_heads as *const _ as *mut _,
&head_dim as *const _ as *mut _,
&total as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "interleaved_to_batched", &config, &mut args).map_err(
|e| {
CudaTensorError::KernelError(format!("Interleaved-to-batched launch failed: {e:?}"))
},
)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn batched_transpose_forward(
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
batch: u32,
rows: u32,
cols: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let total_per_batch = rows * cols;
let key = "batched_transpose";
let module = match cache.get_cached(key) {
Some(m) => m,
None => {
let kernel = BatchedTransposeKernel::new(batch, rows, cols);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(key, &ptx)?
}
};
let config = LaunchConfig {
grid: (total_per_batch.div_ceil(256), 1, batch),
block: (256, 1, 1),
shared_mem: 0,
};
let input_ptr = input.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&input_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&batch as *const _ as *mut _,
&rows as *const _ as *mut _,
&cols as *const _ as *mut _,
&total_per_batch as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "batched_transpose", &config, &mut args).map_err(|e| {
CudaTensorError::KernelError(format!("Batched transpose launch failed: {e:?}"))
})?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn batched_to_interleaved_forward(
input: &GpuBuffer<f32>,
output: &mut GpuBuffer<f32>,
seq_len: u32,
n_heads: u32,
head_dim: u32,
stream: &CudaStream,
) -> Result<()> {
let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
let mut cache = cache.lock().map_err(|_err| {
CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
})?;
let total = seq_len * n_heads * head_dim;
let key = "batched_to_interleaved";
let module = match cache.get_cached(key) {
Some(m) => m,
None => {
let kernel = BatchedToInterleavedKernel::new(seq_len, n_heads, head_dim);
let ptx = kernel.emit_ptx_for_target(cache.sm_target());
cache.get_or_compile(key, &ptx)?
}
};
let config =
LaunchConfig { grid: (total.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };
let input_ptr = input.as_ptr();
let output_ptr = output.as_ptr();
let mut args: [*mut std::ffi::c_void; 6] = [
&input_ptr as *const _ as *mut _,
&output_ptr as *const _ as *mut _,
&seq_len as *const _ as *mut _,
&n_heads as *const _ as *mut _,
&head_dim as *const _ as *mut _,
&total as *const _ as *mut _,
];
unsafe {
stream.launch_kernel(module, "batched_to_interleaved", &config, &mut args).map_err(
|e| {
CudaTensorError::KernelError(format!("Batched-to-interleaved launch failed: {e:?}"))
},
)?;
}
Ok(())
}
#[cfg(feature = "cuda")]
pub fn expand_kv_heads(
src: &GpuBuffer<f32>,
dst: &mut GpuBuffer<f32>,
num_kv_heads: usize,
heads_per_kv: usize,
elems_per_head: usize,
stream: &CudaStream,
) -> Result<()> {
for kv_h in 0..num_kv_heads {
let src_offset = kv_h * elems_per_head;
for rep in 0..heads_per_kv {
let dst_offset = (kv_h * heads_per_kv + rep) * elems_per_head;
unsafe {
dst.copy_from_buffer_at_async(src, dst_offset, src_offset, elems_per_head, stream)
.map_err(|e| {
CudaTensorError::TransferFailed(format!(
"GQA head expansion D2D copy failed: {e}"
))
})?;
}
}
}
Ok(())
}