Skip to main content

entrenar/autograd/cuda_backward/
elementwise.rs

1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::backward::{GeluBackwardKernel, ReluBackwardKernel, SiluBackwardKernel};
10#[cfg(feature = "cuda")]
11use trueno_gpu::kernels::Kernel;
12
13use super::super::cuda_tensor::{CudaTensorError, Result};
14#[cfg(feature = "cuda")]
15use super::cache::KERNEL_CACHE;
16
/// ReLU backward pass on GPU
///
/// Computes: grad_input = grad_output * (input > 0 ? 1 : 0)
///
/// # Arguments
/// * `input` - Original input to forward pass
/// * `grad_output` - Gradient from upstream; must hold as many elements as `input`
/// * `grad_input` - Output buffer for computed gradient; must hold as many elements as `input`
/// * `stream` - CUDA stream the kernel launch is submitted to
///
/// # Errors
/// * [`CudaTensorError::KernelError`] if the three buffers differ in length, if the element
///   count does not fit in `u32`, if the kernel cache lock cannot be acquired, or if the
///   kernel launch fails.
/// * [`CudaTensorError::DeviceNotInitialized`] if `KERNEL_CACHE` has not been set.
/// * Any error propagated from `get_or_compile` when the kernel is compiled.
///
/// Empty buffers are treated as a no-op: the function returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn relu_backward(
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // The kernel indexes all three buffers with the same element count, so a shorter
    // buffer would be read or written out of bounds on the device. Reject it up front.
    let len = input.len();
    if grad_output.len() != len || grad_input.len() != len {
        return Err(CudaTensorError::KernelError(format!(
            "ReLU backward buffer length mismatch: input={len}, grad_output={}, grad_input={}",
            grad_output.len(),
            grad_input.len()
        )));
    }

    // The element count is passed to the kernel as a `u32`; a plain `as` cast would
    // silently truncate and process only part of the buffer.
    let n = u32::try_from(len).map_err(|_err| {
        CudaTensorError::KernelError(format!(
            "ReLU backward element count {len} exceeds u32::MAX"
        ))
    })?;

    // Nothing to compute, and a grid of zero blocks is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Reuse the compiled module when present; otherwise emit PTX for the cache's SM target
    // and compile it under the same key.
    let module = match cache.get_cached("relu_backward") {
        Some(m) => m,
        None => {
            let kernel = ReluBackwardKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile("relu_backward", &ptx)?
        }
    };

    // One thread per element, 256 threads per block, rounded up to cover the tail.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values; the
    // locals above must therefore stay alive until `launch_kernel` returns.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. The three buffers were checked above to hold
    // exactly `n` elements each, every entry of `args` points at a live local, and the
    // argument order (input, grad_output, grad_input, n) is the one this launch has always
    // used for the `relu_backward` entry point.
    unsafe {
        stream.launch_kernel(module, "relu_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("ReLU backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
72
/// GELU backward pass on GPU
///
/// Computes gradient using tanh approximation derivative
///
/// # Arguments
/// * `input` - Original input to forward pass
/// * `grad_output` - Gradient from upstream; must hold as many elements as `input`
/// * `grad_input` - Output buffer for computed gradient; must hold as many elements as `input`
/// * `stream` - CUDA stream the kernel launch is submitted to
///
/// # Errors
/// * [`CudaTensorError::KernelError`] if the three buffers differ in length, if the element
///   count does not fit in `u32`, if the kernel cache lock cannot be acquired, or if the
///   kernel launch fails.
/// * [`CudaTensorError::DeviceNotInitialized`] if `KERNEL_CACHE` has not been set.
/// * Any error propagated from `get_or_compile` when the kernel is compiled.
///
/// Empty buffers are treated as a no-op: the function returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn gelu_backward(
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // The kernel indexes all three buffers with the same element count, so a shorter
    // buffer would be read or written out of bounds on the device. Reject it up front.
    let len = input.len();
    if grad_output.len() != len || grad_input.len() != len {
        return Err(CudaTensorError::KernelError(format!(
            "GELU backward buffer length mismatch: input={len}, grad_output={}, grad_input={}",
            grad_output.len(),
            grad_input.len()
        )));
    }

    // The element count is passed to the kernel as a `u32`; a plain `as` cast would
    // silently truncate and process only part of the buffer.
    let n = u32::try_from(len).map_err(|_err| {
        CudaTensorError::KernelError(format!(
            "GELU backward element count {len} exceeds u32::MAX"
        ))
    })?;

    // Nothing to compute, and a grid of zero blocks is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Reuse the compiled module when present; otherwise emit PTX for the cache's SM target
    // and compile it under the same key.
    let module = match cache.get_cached("gelu_backward") {
        Some(m) => m,
        None => {
            let kernel = GeluBackwardKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile("gelu_backward", &ptx)?
        }
    };

    // One thread per element, 256 threads per block, rounded up to cover the tail.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values; the
    // locals above must therefore stay alive until `launch_kernel` returns.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. The three buffers were checked above to hold
    // exactly `n` elements each, every entry of `args` points at a live local, and the
    // argument order (input, grad_output, grad_input, n) is the one this launch has always
    // used for the `gelu_backward` entry point.
    unsafe {
        stream.launch_kernel(module, "gelu_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("GELU backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
122
/// SiLU/Swish backward pass on GPU
///
/// Computes: grad_input = grad_output * (sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)))
///
/// # Arguments
/// * `input` - Original input to forward pass (`x` in the formula above)
/// * `grad_output` - Gradient from upstream; must hold as many elements as `input`
/// * `grad_input` - Output buffer for computed gradient; must hold as many elements as `input`
/// * `stream` - CUDA stream the kernel launch is submitted to
///
/// # Errors
/// * [`CudaTensorError::KernelError`] if the three buffers differ in length, if the element
///   count does not fit in `u32`, if the kernel cache lock cannot be acquired, or if the
///   kernel launch fails.
/// * [`CudaTensorError::DeviceNotInitialized`] if `KERNEL_CACHE` has not been set.
/// * Any error propagated from `get_or_compile` when the kernel is compiled.
///
/// Empty buffers are treated as a no-op: the function returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn silu_backward(
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // The kernel indexes all three buffers with the same element count, so a shorter
    // buffer would be read or written out of bounds on the device. Reject it up front.
    let len = input.len();
    if grad_output.len() != len || grad_input.len() != len {
        return Err(CudaTensorError::KernelError(format!(
            "SiLU backward buffer length mismatch: input={len}, grad_output={}, grad_input={}",
            grad_output.len(),
            grad_input.len()
        )));
    }

    // The element count is passed to the kernel as a `u32`; a plain `as` cast would
    // silently truncate and process only part of the buffer.
    let n = u32::try_from(len).map_err(|_err| {
        CudaTensorError::KernelError(format!(
            "SiLU backward element count {len} exceeds u32::MAX"
        ))
    })?;

    // Nothing to compute, and a grid of zero blocks is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // Reuse the compiled module when present; otherwise emit PTX for the cache's SM target
    // and compile it under the same key.
    let module = match cache.get_cached("silu_backward") {
        Some(m) => m,
        None => {
            let kernel = SiluBackwardKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile("silu_backward", &ptx)?
        }
    };

    // One thread per element, 256 threads per block, rounded up to cover the tail.
    let config = LaunchConfig { grid: (n.div_ceil(256), 1, 1), block: (256, 1, 1), shared_mem: 0 };

    let input_ptr = input.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Kernel arguments are passed as an array of pointers to the argument values; the
    // locals above must therefore stay alive until `launch_kernel` returns.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: Kernel launch requires FFI. The three buffers were checked above to hold
    // exactly `n` elements each, every entry of `args` points at a live local, and the
    // argument order (input, grad_output, grad_input, n) is the one this launch has always
    // used for the `silu_backward` entry point.
    unsafe {
        stream.launch_kernel(module, "silu_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("SiLU backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}