entrenar/autograd/cuda_backward/
elementwise.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::backward::{GeluBackwardKernel, ReluBackwardKernel, SiluBackwardKernel};
10#[cfg(feature = "cuda")]
11use trueno_gpu::kernels::Kernel;
12
13use super::super::cuda_tensor::{CudaTensorError, Result};
14#[cfg(feature = "cuda")]
15use super::cache::KERNEL_CACHE;
16
/// Launches the `relu_backward` CUDA kernel on `stream`.
///
/// The kernel is handed, in order, the device pointers of `input`, `grad_output`
/// and `grad_input`, followed by the element count taken from `input.len()`.
/// Presumably it writes the ReLU gradient for every element into `grad_input`;
/// the kernel body lives in `trueno_gpu` and is not visible from this file.
///
/// The compiled module is fetched from `KERNEL_CACHE` under the key
/// `"relu_backward"` and compiled on first use. No explicit stream
/// synchronization is performed in this function.
///
/// An empty `input` is treated as a no-op and returns `Ok(())` without touching
/// the kernel cache or the stream.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the three buffers differ in length, if
///   the element count does not fit in `u32`, if the kernel cache lock cannot be
///   acquired, or if the launch itself fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if `KERNEL_CACHE` has not been set.
/// - Any error returned by the cache's `get_or_compile`.
#[cfg(feature = "cuda")]
pub fn relu_backward(
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // Threads per block; the grid is sized so that grid * BLOCK_SIZE >= n.
    const BLOCK_SIZE: u32 = 256;

    // The kernel is told to process `len` elements of all three buffers, so a
    // shorter buffer would be accessed out of bounds on the device.
    let len = input.len();
    if grad_output.len() != len || grad_input.len() != len {
        return Err(CudaTensorError::KernelError(format!(
            "ReLU backward length mismatch: input={len}, grad_output={}, grad_input={}",
            grad_output.len(),
            grad_input.len()
        )));
    }

    // The kernel takes a 32-bit element count; refuse instead of silently truncating.
    let n = <u32 as std::convert::TryFrom<_>>::try_from(len).map_err(|_err| {
        CudaTensorError::KernelError(format!(
            "ReLU backward: {len} elements do not fit in a u32 element count"
        ))
    })?;

    // Nothing to compute, and a zero-sized grid is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // NOTE(review): the module is cached by name only, while the kernel object is
    // built with the first caller's `n`. This is fine as long as the emitted PTX
    // does not bake `n` in (it is also passed as a launch argument) — confirm.
    let module = match cache.get_cached("relu_backward") {
        Some(m) => m,
        None => {
            let kernel = ReluBackwardKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile("relu_backward", &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(BLOCK_SIZE), 1, 1),
        block: (BLOCK_SIZE, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Each entry is the address of a local holding one argument value; the locals
    // outlive the `launch_kernel` call below.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: all three buffers were checked above to hold exactly `n` elements,
    // and `args` points at four live locals in the order (input, grad_output,
    // grad_input, n). This assumes the kernel signature matches that order and
    // touches at most `n` elements per buffer — defined in `trueno_gpu`, confirm.
    // `grad_input` is borrowed mutably, so no other safe reference aliases the
    // output buffer during this call.
    unsafe {
        stream.launch_kernel(module, "relu_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("ReLU backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
72
/// Launches the `gelu_backward` CUDA kernel on `stream`.
///
/// The kernel is handed, in order, the device pointers of `input`, `grad_output`
/// and `grad_input`, followed by the element count taken from `input.len()`.
/// Presumably it writes the GELU gradient for every element into `grad_input`;
/// the kernel body lives in `trueno_gpu` and is not visible from this file.
///
/// The compiled module is fetched from `KERNEL_CACHE` under the key
/// `"gelu_backward"` and compiled on first use. No explicit stream
/// synchronization is performed in this function.
///
/// An empty `input` is treated as a no-op and returns `Ok(())` without touching
/// the kernel cache or the stream.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the three buffers differ in length, if
///   the element count does not fit in `u32`, if the kernel cache lock cannot be
///   acquired, or if the launch itself fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if `KERNEL_CACHE` has not been set.
/// - Any error returned by the cache's `get_or_compile`.
#[cfg(feature = "cuda")]
pub fn gelu_backward(
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // Threads per block; the grid is sized so that grid * BLOCK_SIZE >= n.
    const BLOCK_SIZE: u32 = 256;

    // The kernel is told to process `len` elements of all three buffers, so a
    // shorter buffer would be accessed out of bounds on the device.
    let len = input.len();
    if grad_output.len() != len || grad_input.len() != len {
        return Err(CudaTensorError::KernelError(format!(
            "GELU backward length mismatch: input={len}, grad_output={}, grad_input={}",
            grad_output.len(),
            grad_input.len()
        )));
    }

    // The kernel takes a 32-bit element count; refuse instead of silently truncating.
    let n = <u32 as std::convert::TryFrom<_>>::try_from(len).map_err(|_err| {
        CudaTensorError::KernelError(format!(
            "GELU backward: {len} elements do not fit in a u32 element count"
        ))
    })?;

    // Nothing to compute, and a zero-sized grid is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // NOTE(review): the module is cached by name only, while the kernel object is
    // built with the first caller's `n`. This is fine as long as the emitted PTX
    // does not bake `n` in (it is also passed as a launch argument) — confirm.
    let module = match cache.get_cached("gelu_backward") {
        Some(m) => m,
        None => {
            let kernel = GeluBackwardKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile("gelu_backward", &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(BLOCK_SIZE), 1, 1),
        block: (BLOCK_SIZE, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Each entry is the address of a local holding one argument value; the locals
    // outlive the `launch_kernel` call below.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: all three buffers were checked above to hold exactly `n` elements,
    // and `args` points at four live locals in the order (input, grad_output,
    // grad_input, n). This assumes the kernel signature matches that order and
    // touches at most `n` elements per buffer — defined in `trueno_gpu`, confirm.
    // `grad_input` is borrowed mutably, so no other safe reference aliases the
    // output buffer during this call.
    unsafe {
        stream.launch_kernel(module, "gelu_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("GELU backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
122
/// Launches the `silu_backward` CUDA kernel on `stream`.
///
/// The kernel is handed, in order, the device pointers of `input`, `grad_output`
/// and `grad_input`, followed by the element count taken from `input.len()`.
/// Presumably it writes the SiLU gradient for every element into `grad_input`;
/// the kernel body lives in `trueno_gpu` and is not visible from this file.
///
/// The compiled module is fetched from `KERNEL_CACHE` under the key
/// `"silu_backward"` and compiled on first use. No explicit stream
/// synchronization is performed in this function.
///
/// An empty `input` is treated as a no-op and returns `Ok(())` without touching
/// the kernel cache or the stream.
///
/// # Errors
///
/// - [`CudaTensorError::KernelError`] if the three buffers differ in length, if
///   the element count does not fit in `u32`, if the kernel cache lock cannot be
///   acquired, or if the launch itself fails.
/// - [`CudaTensorError::DeviceNotInitialized`] if `KERNEL_CACHE` has not been set.
/// - Any error returned by the cache's `get_or_compile`.
#[cfg(feature = "cuda")]
pub fn silu_backward(
    input: &GpuBuffer<f32>,
    grad_output: &GpuBuffer<f32>,
    grad_input: &mut GpuBuffer<f32>,
    stream: &CudaStream,
) -> Result<()> {
    // Threads per block; the grid is sized so that grid * BLOCK_SIZE >= n.
    const BLOCK_SIZE: u32 = 256;

    // The kernel is told to process `len` elements of all three buffers, so a
    // shorter buffer would be accessed out of bounds on the device.
    let len = input.len();
    if grad_output.len() != len || grad_input.len() != len {
        return Err(CudaTensorError::KernelError(format!(
            "SiLU backward length mismatch: input={len}, grad_output={}, grad_input={}",
            grad_output.len(),
            grad_input.len()
        )));
    }

    // The kernel takes a 32-bit element count; refuse instead of silently truncating.
    let n = <u32 as std::convert::TryFrom<_>>::try_from(len).map_err(|_err| {
        CudaTensorError::KernelError(format!(
            "SiLU backward: {len} elements do not fit in a u32 element count"
        ))
    })?;

    // Nothing to compute, and a zero-sized grid is not a valid launch configuration.
    if n == 0 {
        return Ok(());
    }

    let cache = KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;
    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // NOTE(review): the module is cached by name only, while the kernel object is
    // built with the first caller's `n`. This is fine as long as the emitted PTX
    // does not bake `n` in (it is also passed as a launch argument) — confirm.
    let module = match cache.get_cached("silu_backward") {
        Some(m) => m,
        None => {
            let kernel = SiluBackwardKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile("silu_backward", &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(BLOCK_SIZE), 1, 1),
        block: (BLOCK_SIZE, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let grad_out_ptr = grad_output.as_ptr();
    let grad_in_ptr = grad_input.as_ptr();

    // Each entry is the address of a local holding one argument value; the locals
    // outlive the `launch_kernel` call below.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &grad_out_ptr as *const _ as *mut _,
        &grad_in_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: all three buffers were checked above to hold exactly `n` elements,
    // and `args` points at four live locals in the order (input, grad_output,
    // grad_input, n). This assumes the kernel signature matches that order and
    // touches at most `n` elements per buffer — defined in `trueno_gpu`, confirm.
    // `grad_input` is borrowed mutably, so no other safe reference aliases the
    // output buffer during this call.
    unsafe {
        stream.launch_kernel(module, "silu_backward", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("SiLU backward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}