Skip to main content

trueno/backends/gpu/device/
activations.rs

1//! GPU activation functions and softmax operations
2//!
3//! Element-wise activation functions (ReLU, sigmoid, tanh, etc.) and
4//! multi-pass softmax/log_softmax implementations.
5
6#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
7use super::super::runtime;
8use super::super::shaders;
9use super::GpuDevice;
10
11impl GpuDevice {
12    /// Execute ReLU activation on GPU: result[i] = max(0, input[i]) (sync, native only)
13    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
14    pub fn relu(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
15        runtime::block_on(async {
16            self.execute_element_wise_op("ReLU", shaders::RELU_SHADER, input, result, None).await
17        })
18    }
19
20    /// Execute ReLU activation on GPU (async, works on all platforms)
21    pub async fn relu_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
22        self.execute_element_wise_op("ReLU", shaders::RELU_SHADER, input, result, None).await
23    }
24
25    /// Execute leaky ReLU activation on GPU (sync, native only)
26    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
27    pub fn leaky_relu(
28        &self,
29        input: &[f32],
30        result: &mut [f32],
31        negative_slope: f32,
32    ) -> Result<(), String> {
33        runtime::block_on(self.leaky_relu_async(input, result, negative_slope))
34    }
35
36    /// Execute leaky ReLU activation on GPU (async, works on all platforms)
37    pub async fn leaky_relu_async(
38        &self,
39        input: &[f32],
40        result: &mut [f32],
41        negative_slope: f32,
42    ) -> Result<(), String> {
43        #[repr(C)]
44        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
45        struct LeakyReluParams {
46            negative_slope: f32,
47        }
48
49        let params = LeakyReluParams { negative_slope };
50        let uniform_data = bytemuck::bytes_of(&params);
51
52        self.execute_element_wise_op(
53            "LeakyReLU",
54            shaders::LEAKY_RELU_SHADER,
55            input,
56            result,
57            Some(uniform_data),
58        )
59        .await
60    }
61
62    /// Execute ELU activation on GPU (sync, native only)
63    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
64    pub fn elu(&self, input: &[f32], result: &mut [f32], alpha: f32) -> Result<(), String> {
65        runtime::block_on(self.elu_async(input, result, alpha))
66    }
67
68    /// Execute ELU activation on GPU (async, works on all platforms)
69    pub async fn elu_async(
70        &self,
71        input: &[f32],
72        result: &mut [f32],
73        alpha: f32,
74    ) -> Result<(), String> {
75        #[repr(C)]
76        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
77        struct EluParams {
78            alpha: f32,
79        }
80
81        let params = EluParams { alpha };
82        let uniform_data = bytemuck::bytes_of(&params);
83
84        self.execute_element_wise_op("ELU", shaders::ELU_SHADER, input, result, Some(uniform_data))
85            .await
86    }
87
88    /// Execute sigmoid activation on GPU (sync, native only)
89    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
90    pub fn sigmoid(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
91        runtime::block_on(self.sigmoid_async(input, result))
92    }
93
94    /// Execute sigmoid activation on GPU (async, works on all platforms)
95    pub async fn sigmoid_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
96        self.execute_element_wise_op("Sigmoid", shaders::SIGMOID_SHADER, input, result, None).await
97    }
98
99    /// Execute tanh activation on GPU (sync, native only)
100    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
101    pub fn tanh(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
102        runtime::block_on(self.tanh_async(input, result))
103    }
104
105    /// Execute tanh activation on GPU (async, works on all platforms)
106    pub async fn tanh_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
107        self.execute_element_wise_op("Tanh", shaders::TANH_SHADER, input, result, None).await
108    }
109
110    /// Execute swish activation on GPU (sync, native only)
111    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
112    pub fn swish(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
113        runtime::block_on(self.swish_async(input, result))
114    }
115
116    /// Execute swish activation on GPU (async, works on all platforms)
117    pub async fn swish_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
118        self.execute_element_wise_op("Swish", shaders::SWISH_SHADER, input, result, None).await
119    }
120
121    /// Execute GELU activation on GPU (sync, native only)
122    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
123    pub fn gelu(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
124        runtime::block_on(self.gelu_async(input, result))
125    }
126
127    /// Execute GELU activation on GPU (async, works on all platforms)
128    pub async fn gelu_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
129        self.execute_element_wise_op("GELU", shaders::GELU_SHADER, input, result, None).await
130    }
131
132    /// Execute clip (clamp) operation on GPU (sync, native only)
133    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
134    pub fn clip(
135        &self,
136        input: &[f32],
137        result: &mut [f32],
138        min_val: f32,
139        max_val: f32,
140    ) -> Result<(), String> {
141        runtime::block_on(self.clip_async(input, result, min_val, max_val))
142    }
143
144    /// Execute clip (clamp) operation on GPU (async, works on all platforms)
145    pub async fn clip_async(
146        &self,
147        input: &[f32],
148        result: &mut [f32],
149        min_val: f32,
150        max_val: f32,
151    ) -> Result<(), String> {
152        #[repr(C)]
153        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
154        struct ClipParams {
155            min_val: f32,
156            max_val: f32,
157        }
158
159        let params = ClipParams { min_val, max_val };
160        let uniform_data = bytemuck::bytes_of(&params);
161
162        self.execute_element_wise_op(
163            "Clip",
164            shaders::CLIP_SHADER,
165            input,
166            result,
167            Some(uniform_data),
168        )
169        .await
170    }
171
172    /// Execute softmax on GPU (sync, native only)
173    ///
174    /// Multi-pass implementation:
175    /// 1. Find max value (parallel reduction)
176    /// 2. Compute exp(x - max) (element-wise)
177    /// 3. Sum exp values (parallel reduction)
178    /// 4. Normalize by sum (element-wise)
179    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
180    pub fn softmax(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
181        runtime::block_on(async { self.softmax_async(input, result).await })
182    }
183
184    /// Execute softmax on GPU (async, works on all platforms)
185    pub async fn softmax_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
186        // Pass 1: Find max value
187        let max_val = self.reduce_max(input).await?;
188
189        // Pass 2: Compute exp(x - max)
190        let exp_vals = self.compute_exp_subtract(input, max_val).await?;
191
192        // Pass 3: Sum exp values
193        let sum_exp = self.reduce_sum(&exp_vals).await?;
194
195        // Pass 4: Normalize by sum
196        self.normalize_by_sum(&exp_vals, result, sum_exp).await?;
197
198        Ok(())
199    }
200
201    /// Execute log_softmax on GPU (sync, native only)
202    ///
203    /// Multi-pass implementation:
204    /// 1. Find max value (parallel reduction)
205    /// 2. Compute exp(x - max) (element-wise)
206    /// 3. Sum exp values (parallel reduction)
207    /// 4. Compute log_softmax (element-wise)
208    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
209    pub fn log_softmax(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
210        runtime::block_on(async { self.log_softmax_async(input, result).await })
211    }
212
213    /// Execute log_softmax on GPU (async, works on all platforms)
214    pub async fn log_softmax_async(&self, input: &[f32], result: &mut [f32]) -> Result<(), String> {
215        // Pass 1: Find max value
216        let max_val = self.reduce_max(input).await?;
217
218        // Pass 2: Compute exp(x - max)
219        let exp_vals = self.compute_exp_subtract(input, max_val).await?;
220
221        // Pass 3: Sum exp values
222        let sum_exp = self.reduce_sum(&exp_vals).await?;
223
224        // Pass 4: Compute log_softmax = x - max - log(sum_exp)
225        let log_sum_exp = sum_exp.max(f32::EPSILON).ln();
226
227        #[repr(C)]
228        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
229        struct LogSoftmaxParams {
230            max_val: f32,
231            log_sum_exp: f32,
232        }
233
234        let params = LogSoftmaxParams { max_val, log_sum_exp };
235        let uniform_data = bytemuck::bytes_of(&params);
236
237        self.execute_element_wise_op(
238            "LogSoftmax",
239            shaders::LOG_SOFTMAX_SHADER,
240            input,
241            result,
242            Some(uniform_data),
243        )
244        .await?;
245
246        Ok(())
247    }
248
249    /// Helper: Compute exp(input[i] - max_val)
250    pub(super) async fn compute_exp_subtract(
251        &self,
252        input: &[f32],
253        max_val: f32,
254    ) -> Result<Vec<f32>, String> {
255        #[repr(C)]
256        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
257        struct MaxValue {
258            max_val: f32,
259        }
260
261        let params = MaxValue { max_val };
262        let uniform_data = bytemuck::bytes_of(&params);
263
264        let mut result = vec![0.0f32; input.len()];
265        self.execute_element_wise_op(
266            "SoftmaxExp",
267            shaders::SOFTMAX_EXP_SHADER,
268            input,
269            &mut result,
270            Some(uniform_data),
271        )
272        .await?;
273
274        Ok(result)
275    }
276
277    /// Helper: Normalize by sum
278    pub(super) async fn normalize_by_sum(
279        &self,
280        input: &[f32],
281        result: &mut [f32],
282        sum_val: f32,
283    ) -> Result<(), String> {
284        #[repr(C)]
285        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
286        struct SumValue {
287            sum_val: f32,
288        }
289
290        let params = SumValue { sum_val };
291        let uniform_data = bytemuck::bytes_of(&params);
292
293        self.execute_element_wise_op(
294            "SoftmaxNormalize",
295            shaders::SOFTMAX_NORMALIZE_SHADER,
296            input,
297            result,
298            Some(uniform_data),
299        )
300        .await?;
301
302        Ok(())
303    }
304}