#![allow(unsafe_code)]
#![allow(trivial_casts)]
#![allow(clippy::borrow_as_ptr)]
#![allow(clippy::ref_as_ptr)]
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::backward::{
BatchedRmsNormBackwardKernel, BatchedSoftmaxBackwardKernel, LayerNormBackwardKernel,
SoftmaxBackwardKernel,
};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::BatchedVectorizedRmsNormKernel;
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::Kernel;
use super::super::cuda_tensor::{CudaTensorError, Result};
#[cfg(feature = "cuda")]
use super::cache::KERNEL_CACHE;
#[cfg(feature = "cuda")]
use provable_contracts_macros::requires;
#[cfg(feature = "cuda")]
/// Launches the (non-batched) softmax backward kernel on `stream`.
///
/// Computes `grad_input` from the saved forward `softmax_output` and the
/// incoming `grad_output`, launching one CUDA block per batch row.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the kernel cache was never
///   initialized.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned, PTX
///   compilation fails, or the kernel launch fails.
#[requires(batch_size > 0 && seq_len > 0)]
pub fn softmax_backward(
    softmax_output: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    batch_size: u32,
    seq_len: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // One compiled module per (batch, seq) shape, since the kernel is
    // constructed with those dimensions.
    let key = format!("softmax_backward_{batch_size}_{seq_len}");
    let module = if let Some(cached) = cache.get_cached(&key) {
        cached
    } else {
        let ptx =
            SoftmaxBackwardKernel::new(batch_size, seq_len).emit_ptx_for_target(cache.sm_target());
        cache.get_or_compile(&key, &ptx)?
    };

    // One block per row; at most one warp (32 threads) per row.
    let threads = seq_len.min(32);
    let config = LaunchConfig { grid: (batch_size, 1, 1), block: (threads, 1, 1), shared_mem: 0 };

    // Keep the device pointers in named locals so the `&ptr` addresses handed
    // to the driver remain valid until the launch call returns.
    let p_softmax = softmax_output.as_ptr();
    let p_grad_out = grad_output.as_ptr();
    let p_grad_in = grad_input.as_ptr();
    let mut args: [*mut std::ffi::c_void; 5] = [
        &p_softmax as *const _ as *mut _,
        &p_grad_out as *const _ as *mut _,
        &p_grad_in as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &seq_len as *const _ as *mut _,
    ];

    // SAFETY: every argument pointer references a live local, and the entry
    // point name matches the kernel compiled above.
    unsafe {
        stream.launch_kernel(module, "softmax_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Softmax backward launch failed: {e:?}"))
        })?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
/// Launches the batched softmax backward kernel on `stream`.
///
/// Treats the buffers as `total_rows` independent rows of `row_size`
/// elements each and computes `grad_input` from the saved forward
/// `softmax_output` and the incoming `grad_output`, one CUDA block per row.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the kernel cache was never
///   initialized.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned, PTX
///   compilation fails, or the kernel launch fails.
pub fn batched_softmax_backward(
    softmax_output: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    total_rows: u32,
    row_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    // The kernel is constructed with (total_rows, row_size), so the emitted
    // PTX may be specialized for this shape. Key the module cache by both
    // values: the previous shape-independent key ("batched_softmax_backward")
    // could hand back a module compiled for a different shape on a later
    // call. This also matches the keying scheme of `softmax_backward` and
    // `layer_norm_backward`.
    let key = format!("batched_softmax_backward_{total_rows}_{row_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = BatchedSoftmaxBackwardKernel::new(total_rows, row_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };
    // One block per row; at most one warp (32 threads) per row.
    let config =
        LaunchConfig { grid: (total_rows, 1, 1), block: (32.min(row_size), 1, 1), shared_mem: 0 };
    // Named locals keep the device pointers alive while their addresses are
    // passed to the driver.
    let output_ptr = softmax_output.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();
    let mut args: [*mut std::ffi::c_void; 5] = [
        &output_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &total_rows as *const _ as *mut _,
        &row_size as *const _ as *mut _,
    ];
    // SAFETY: every argument pointer references a live local, and the entry
    // point name matches the kernel compiled above.
    unsafe {
        stream.launch_kernel(module, "batched_softmax_backward", &config, &mut args).map_err(
            |e| {
                CudaTensorError::KernelError(format!(
                    "Batched softmax backward launch failed: {e:?}"
                ))
            },
        )?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
/// Launches the batched RMSNorm backward kernel on `stream`.
///
/// Computes `grad_input` and `grad_gamma` from the saved forward `input`,
/// the scale `gamma`, and the incoming `grad_output`, one CUDA block per
/// batch row. `eps` is the numerical-stability epsilon passed through to the
/// kernel.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the kernel cache was never
///   initialized.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned, PTX
///   compilation fails, or the kernel launch fails.
pub fn rms_norm_backward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    grad_gamma: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    eps: f32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    // The kernel is constructed with (batch_size, hidden_size, eps), so the
    // emitted PTX may be specialized for those values. Key the module cache
    // by all three (eps via its exact bit pattern): the previous fixed key
    // ("batched_rms_norm_backward") could hand back a module compiled for
    // different parameters on a later call.
    let key = format!(
        "batched_rms_norm_backward_{batch_size}_{hidden_size}_{:08x}",
        eps.to_bits()
    );
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = BatchedRmsNormBackwardKernel::new(batch_size, hidden_size, eps);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };
    // One block per batch row; at most one warp (32 threads) per row.
    let config = LaunchConfig {
        grid: (batch_size, 1, 1),
        block: (32.min(hidden_size), 1, 1),
        shared_mem: 0,
    };
    // Named locals keep the device pointers alive while their addresses are
    // passed to the driver.
    let input_ptr = input.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();
    let grad_gamma_ptr = grad_gamma.as_ptr();
    let mut args: [*mut std::ffi::c_void; 8] = [
        &input_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &grad_gamma_ptr as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
        &eps as *const _ as *mut _,
    ];
    // SAFETY: every argument pointer references a live local, and the entry
    // point name matches the kernel compiled above.
    unsafe {
        stream.launch_kernel(module, "batched_rms_norm_backward", &config, &mut args).map_err(
            |e| CudaTensorError::KernelError(format!("RMSNorm backward launch failed: {e:?}")),
        )?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
/// Launches the batched vectorized RMSNorm forward kernel on `stream`.
///
/// Normalizes `batch_size` rows of `hidden_size` elements from `input`,
/// scales by `gamma`, and writes the result into `output`.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the kernel cache was never
///   initialized.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned, PTX
///   compilation fails, or the kernel launch fails.
pub fn rms_norm_forward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;
    // This launch passes no runtime size arguments (only the three buffer
    // pointers below), so both dimensions given to
    // `BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size)` can end
    // up baked into the emitted PTX. The old key included only `hidden_size`,
    // which could reuse a module compiled for a different `batch_size`; key
    // the cache by both dimensions.
    let key = format!("batched_rmsnorm_fwd_{batch_size}_{hidden_size}");
    let module = match cache.get_cached(&key) {
        Some(m) => m,
        None => {
            let kernel = BatchedVectorizedRmsNormKernel::new(hidden_size, batch_size);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(&key, &ptx)?
        }
    };
    // grid.y indexes the batch row; 256 threads (8 warps) per block.
    // shared_mem is 8 * 4 bytes — presumably one f32 partial per warp for the
    // row reduction; TODO confirm against the kernel's shared-memory layout.
    let config = LaunchConfig { grid: (1, batch_size, 1), block: (256, 1, 1), shared_mem: 8 * 4 };
    // Named locals keep the device pointers alive while their addresses are
    // passed to the driver.
    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();
    let gamma_ptr = gamma.as_ptr();
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &gamma_ptr as *const _ as *mut _,
    ];
    // SAFETY: every argument pointer references a live local, and the entry
    // point name matches the kernel compiled above.
    unsafe {
        stream.launch_kernel(module, "batched_rmsnorm_vectorized", &config, &mut args).map_err(
            |e| CudaTensorError::KernelError(format!("RMSNorm forward launch failed: {e:?}")),
        )?;
    }
    Ok(())
}
#[cfg(feature = "cuda")]
/// Launches the fused LayerNorm backward kernel on `stream`.
///
/// Produces `grad_input`, `grad_gamma`, and `grad_beta` from the saved
/// forward `input`, the scale `gamma`, and the incoming `grad_output`,
/// launching one CUDA block per batch row.
///
/// # Errors
/// - [`CudaTensorError::DeviceNotInitialized`] if the kernel cache was never
///   initialized.
/// - [`CudaTensorError::KernelError`] if the cache lock is poisoned, PTX
///   compilation fails, or the kernel launch fails.
pub fn layer_norm_backward(
    input: &GpuBuffer<f32>,
    gamma: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    grad_gamma: &mut GpuBuffer<f32>,
    grad_beta: &mut GpuBuffer<f32>,
    batch_size: u32,
    hidden_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    let cache_cell = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache_cell.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // One compiled module per (batch, hidden) shape.
    let key = format!("layer_norm_backward_{batch_size}_{hidden_size}");
    let module = if let Some(cached) = cache.get_cached(&key) {
        cached
    } else {
        let ptx = LayerNormBackwardKernel::new(batch_size, hidden_size)
            .emit_ptx_for_target(cache.sm_target());
        cache.get_or_compile(&key, &ptx)?
    };

    // One block per batch row, capped at 256 threads per block.
    let threads = hidden_size.min(256);
    let config = LaunchConfig { grid: (batch_size, 1, 1), block: (threads, 1, 1), shared_mem: 0 };

    // Keep the device pointers in named locals so the `&ptr` addresses handed
    // to the driver remain valid until the launch call returns.
    let p_input = input.as_ptr();
    let p_gamma = gamma.as_ptr();
    let p_grad_out = grad_output.as_ptr();
    let p_grad_in = grad_input.as_ptr();
    let p_grad_gamma = grad_gamma.as_ptr();
    let p_grad_beta = grad_beta.as_ptr();
    let mut args: [*mut std::ffi::c_void; 8] = [
        &p_input as *const _ as *mut _,
        &p_gamma as *const _ as *mut _,
        &p_grad_out as *const _ as *mut _,
        &p_grad_in as *const _ as *mut _,
        &p_grad_gamma as *const _ as *mut _,
        &p_grad_beta as *const _ as *mut _,
        &batch_size as *const _ as *mut _,
        &hidden_size as *const _ as *mut _,
    ];

    // SAFETY: every argument pointer references a live local, and the entry
    // point name matches the kernel compiled above.
    unsafe {
        stream.launch_kernel(module, "layer_norm_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("LayerNorm backward launch failed: {e:?}"))
        })?;
    }
    Ok(())
}