entrenar/autograd/cuda_forward/
activations.rs1#![allow(unsafe_code)]
2#![allow(trivial_casts)]
3#![allow(clippy::borrow_as_ptr)]
4#![allow(clippy::ref_as_ptr)]
5
6#[cfg(feature = "cuda")]
7use trueno_gpu::driver::{CudaStream, GpuBuffer, LaunchConfig};
8#[cfg(feature = "cuda")]
9use trueno_gpu::kernels::{
10 BatchedSoftmaxKernel, GeluKernel, Kernel, ReluKernel, SiluKernel, SoftmaxKernel,
11};
12
13use crate::autograd::cuda_tensor::{CudaTensorError, Result};
14
15#[cfg(feature = "cuda")]
16use super::cache::FORWARD_KERNEL_CACHE;
17
/// Launches the ReLU forward kernel over `n` elements on `stream`.
///
/// The compiled module is looked up in the shared forward-kernel cache under
/// the key `"relu_forward"`; on a miss the PTX is emitted for the cache's SM
/// target and handed to `get_or_compile`. The launch uses `n.div_ceil(256)`
/// blocks of 256 threads.
///
/// # Arguments
/// * `input` - device buffer whose pointer is passed as the first kernel argument.
/// * `output` - device buffer whose pointer is passed as the second kernel argument.
/// * `n` - element count, passed as the third kernel argument and used to size the grid.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error `get_or_compile` reports when the module cannot be built.
///
/// An `n` of zero is treated as a no-op and returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn relu_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    const THREADS_PER_BLOCK: u32 = 256;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;

    // An empty tensor has nothing to compute, and a grid of zero blocks is not
    // a valid launch configuration, so short-circuit instead of failing.
    if n == 0 {
        return Ok(());
    }

    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // A borrowed key avoids allocating a `String` on every call.
    // NOTE(review): a size-independent key assumes the emitted PTX does not
    // depend on the `n` given to `ReluKernel::new` — confirm.
    let key = "relu_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = ReluKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Each entry points at a local holding one kernel argument.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local (`input_ptr`,
    // `output_ptr`, `n`) that outlives the `launch_kernel` call.
    // NOTE(review): assumes the "relu" entry point takes (input ptr, output
    // ptr, u32 count) and that both buffers hold at least `n` elements — confirm.
    unsafe {
        stream.launch_kernel(module, "relu", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("ReLU forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
64
/// Launches the softmax forward kernel for a single vector of `length` elements.
///
/// The compiled module is looked up in the shared forward-kernel cache under
/// the key `"softmax_forward"`; on a miss the PTX is emitted for the cache's SM
/// target and handed to `get_or_compile`. The launch uses a single block of
/// `min(32, length)` threads, and the entry-point name is taken from
/// `SoftmaxKernel::name`.
///
/// # Arguments
/// * `input` - device buffer whose pointer is passed as the first kernel argument.
/// * `output` - device buffer whose pointer is passed as the second kernel argument.
/// * `length` - element count, passed as the third kernel argument.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error `get_or_compile` reports when the module cannot be built.
///
/// A `length` of zero is treated as a no-op and returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn softmax_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    length: u32,
    stream: &CudaStream,
) -> Result<()> {
    const MAX_THREADS: u32 = 32;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;

    // An empty vector has nothing to normalise, and `min(32, 0)` would request
    // a block of zero threads, which is not a valid launch configuration.
    if length == 0 {
        return Ok(());
    }

    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // The kernel object is built up front because its name is needed for the
    // launch even when the module is already cached.
    let kernel = SoftmaxKernel::new(length);
    let kernel_name = kernel.name();

    // A borrowed key avoids allocating a `String` on every call.
    // NOTE(review): a size-independent key assumes neither the PTX nor the
    // entry-point name depends on `length` — confirm.
    let key = "softmax_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (1, 1, 1),
        block: (length.min(MAX_THREADS), 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Each entry points at a local holding one kernel argument.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &length as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local (`input_ptr`,
    // `output_ptr`, `length`) that outlives the `launch_kernel` call.
    // NOTE(review): assumes the kernel takes (input ptr, output ptr, u32
    // length) and that both buffers hold at least `length` elements — confirm.
    unsafe {
        stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Softmax forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
115
/// Launches the GELU forward kernel over `n` elements on `stream`.
///
/// The compiled module is looked up in the shared forward-kernel cache under
/// the key `"gelu_forward"`; on a miss the PTX is emitted for the cache's SM
/// target and handed to `get_or_compile`. The launch uses `n.div_ceil(256)`
/// blocks of 256 threads.
///
/// # Arguments
/// * `input` - device buffer whose pointer is passed as the first kernel argument.
/// * `output` - device buffer whose pointer is passed as the second kernel argument.
/// * `n` - element count, passed as the third kernel argument and used to size the grid.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error `get_or_compile` reports when the module cannot be built.
///
/// An `n` of zero is treated as a no-op and returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn gelu_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    const THREADS_PER_BLOCK: u32 = 256;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;

    // An empty tensor has nothing to compute, and a grid of zero blocks is not
    // a valid launch configuration, so short-circuit instead of failing.
    if n == 0 {
        return Ok(());
    }

    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // A borrowed key avoids allocating a `String` on every call.
    // NOTE(review): a size-independent key assumes the emitted PTX does not
    // depend on the `n` given to `GeluKernel::new` — confirm.
    let key = "gelu_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = GeluKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Each entry points at a local holding one kernel argument.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local (`input_ptr`,
    // `output_ptr`, `n`) that outlives the `launch_kernel` call.
    // NOTE(review): assumes the "gelu" entry point takes (input ptr, output
    // ptr, u32 count) and that both buffers hold at least `n` elements — confirm.
    unsafe {
        stream.launch_kernel(module, "gelu", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("GELU forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
162
/// Launches the SiLU forward kernel over `n` elements on `stream`.
///
/// The compiled module is looked up in the shared forward-kernel cache under
/// the key `"silu_forward"`; on a miss the PTX is emitted for the cache's SM
/// target and handed to `get_or_compile`. The launch uses `n.div_ceil(256)`
/// blocks of 256 threads.
///
/// # Arguments
/// * `input` - device buffer whose pointer is passed as the first kernel argument.
/// * `output` - device buffer whose pointer is passed as the second kernel argument.
/// * `n` - element count, passed as the third kernel argument and used to size the grid.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error `get_or_compile` reports when the module cannot be built.
///
/// An `n` of zero is treated as a no-op and returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn silu_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    n: u32,
    stream: &CudaStream,
) -> Result<()> {
    const THREADS_PER_BLOCK: u32 = 256;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;

    // An empty tensor has nothing to compute, and a grid of zero blocks is not
    // a valid launch configuration, so short-circuit instead of failing.
    if n == 0 {
        return Ok(());
    }

    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // A borrowed key avoids allocating a `String` on every call.
    // NOTE(review): a size-independent key assumes the emitted PTX does not
    // depend on the `n` given to `SiluKernel::new` — confirm.
    let key = "silu_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let kernel = SiluKernel::new(n);
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (n.div_ceil(THREADS_PER_BLOCK), 1, 1),
        block: (THREADS_PER_BLOCK, 1, 1),
        shared_mem: 0,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Each entry points at a local holding one kernel argument.
    let mut args: [*mut std::ffi::c_void; 3] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &n as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local (`input_ptr`,
    // `output_ptr`, `n`) that outlives the `launch_kernel` call.
    // NOTE(review): assumes the "silu" entry point takes (input ptr, output
    // ptr, u32 count) and that both buffers hold at least `n` elements — confirm.
    unsafe {
        stream.launch_kernel(module, "silu", &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("SiLU forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}
209
/// Launches the batched softmax forward kernel: one block per row, for
/// `total_rows` rows of `row_size` elements each.
///
/// The compiled module is looked up in the shared forward-kernel cache under
/// the key `"batched_softmax_forward"`; on a miss the PTX is emitted for the
/// cache's SM target and handed to `get_or_compile`. The launch uses a grid of
/// `total_rows` blocks, each with `min(32, row_size)` threads and 72 bytes of
/// dynamic shared memory. The entry-point name is taken from
/// `BatchedSoftmaxKernel::name`.
///
/// # Arguments
/// * `input` - device buffer whose pointer is passed as the first kernel argument.
/// * `output` - device buffer whose pointer is passed as the second kernel argument.
/// * `total_rows` - row count, passed as the third kernel argument and used as the grid size.
/// * `row_size` - elements per row, passed as the fourth kernel argument.
/// * `stream` - CUDA stream the kernel is launched on.
///
/// # Errors
/// * `CudaTensorError::DeviceNotInitialized` if the kernel cache has not been set up.
/// * `CudaTensorError::KernelError` if the cache lock cannot be acquired or the
///   launch fails.
/// * Whatever error `get_or_compile` reports when the module cannot be built.
///
/// If either `total_rows` or `row_size` is zero the call is a no-op and
/// returns `Ok(())` without launching.
#[cfg(feature = "cuda")]
pub fn batched_softmax_forward(
    input: &GpuBuffer<f32>,
    output: &mut GpuBuffer<f32>,
    total_rows: u32,
    row_size: u32,
    stream: &CudaStream,
) -> Result<()> {
    const MAX_THREADS: u32 = 32;
    // NOTE(review): presumably the scratch space the kernel's reduction needs;
    // verify against `BatchedSoftmaxKernel` before changing.
    const SHARED_MEM_BYTES: u32 = 72;

    let cache = FORWARD_KERNEL_CACHE.get().ok_or(CudaTensorError::DeviceNotInitialized)?;

    // With no rows the grid is empty, and with empty rows the block is empty;
    // neither is a valid launch configuration and there is nothing to compute.
    if total_rows == 0 || row_size == 0 {
        return Ok(());
    }

    let mut cache = cache.lock().map_err(|_err| {
        CudaTensorError::KernelError("Failed to acquire kernel cache lock".to_string())
    })?;

    // The kernel object is built up front because its name is needed for the
    // launch even when the module is already cached.
    let kernel = BatchedSoftmaxKernel::new(total_rows, row_size);
    let kernel_name = kernel.name();

    // NOTE(review): a shape-independent key assumes neither the PTX nor the
    // entry-point name depends on `total_rows` / `row_size` — confirm.
    let key = "batched_softmax_forward";
    let module = match cache.get_cached(key) {
        Some(m) => m,
        None => {
            let ptx = kernel.emit_ptx_for_target(cache.sm_target());
            cache.get_or_compile(key, &ptx)?
        }
    };

    let config = LaunchConfig {
        grid: (total_rows, 1, 1),
        block: (row_size.min(MAX_THREADS), 1, 1),
        shared_mem: SHARED_MEM_BYTES,
    };

    let input_ptr = input.as_ptr();
    let output_ptr = output.as_ptr();

    // Each entry points at a local holding one kernel argument.
    let mut args: [*mut std::ffi::c_void; 4] = [
        &input_ptr as *const _ as *mut _,
        &output_ptr as *const _ as *mut _,
        &total_rows as *const _ as *mut _,
        &row_size as *const _ as *mut _,
    ];

    // SAFETY: every pointer in `args` refers to a local (`input_ptr`,
    // `output_ptr`, `total_rows`, `row_size`) that outlives the
    // `launch_kernel` call.
    // NOTE(review): assumes the kernel takes (input ptr, output ptr, u32 rows,
    // u32 row size) and that both buffers hold at least
    // `total_rows * row_size` elements — confirm.
    unsafe {
        stream.launch_kernel(module, kernel_name, &config, &mut args).map_err(|e| {
            CudaTensorError::KernelError(format!("Batched softmax forward launch failed: {e:?}"))
        })?;
    }

    Ok(())
}