Skip to main content

trueno/backends/gpu/
backend_ops.rs

//! GPU backend operation implementations
//!
//! Contains all compute operations for [`GpuBackend`], including:
//! - Vector operations (add, dot product)
//! - Activation functions (ReLU, leaky ReLU, ELU, sigmoid, tanh, swish, GELU, softmax, log-softmax, etc.)
//! - Matrix operations (matmul, convolve2d, eigendecomposition)
//! - Tiled reductions (sum, max, min)
8
9use super::GpuBackend;
10
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
impl GpuBackend {
    /// Vector addition on GPU: c = a + b
    ///
    /// # Arguments
    ///
    /// * `a` - Vector a
    /// * `b` - Vector b
    ///
    /// # Returns
    ///
    /// Vector c (element-wise sum)
    ///
    /// # Errors
    ///
    /// Returns an error if the lengths differ, if the vectors are empty
    /// (wgpu cannot bind zero-sized buffers), or if device setup/dispatch fails.
    pub fn vec_add(&mut self, a: &[f32], b: &[f32]) -> Result<Vec<f32>, String> {
        if a.len() != b.len() {
            return Err(format!("Vector length mismatch: {} != {}", a.len(), b.len()));
        }
        ensure_non_empty(a.len())?;

        let device = self.ensure_device()?;

        // Output buffer sized to match the inputs.
        let mut result = vec![0.0f32; a.len()];
        device.vec_add(a, b, &mut result)?;

        Ok(result)
    }

    /// Dot product on GPU: result = sum(a[i] * b[i])
    ///
    /// # Arguments
    ///
    /// * `a` - Vector a
    /// * `b` - Vector b
    ///
    /// # Returns
    ///
    /// Scalar dot product result
    ///
    /// # Errors
    ///
    /// Returns an error if the lengths differ, if the vectors are empty,
    /// or if device setup/dispatch fails.
    pub fn dot(&mut self, a: &[f32], b: &[f32]) -> Result<f32, String> {
        if a.len() != b.len() {
            return Err(format!("Vector length mismatch: {} != {}", a.len(), b.len()));
        }
        // Same zero-sized-buffer restriction as `vec_add`.
        ensure_non_empty(a.len())?;

        let device = self.ensure_device()?;
        device.dot(a, b)
    }

    /// ReLU activation on GPU: result[i] = max(0, input[i])
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with ReLU applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn relu(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.relu(input, &mut result)?;

        Ok(result)
    }

    /// Leaky ReLU activation on GPU: result[i] = max(negative_slope * input[i], input[i])
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    /// * `negative_slope` - Slope for negative values (typically 0.01)
    ///
    /// # Returns
    ///
    /// Vector with leaky ReLU applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn leaky_relu(&mut self, input: &[f32], negative_slope: f32) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.leaky_relu(input, &mut result, negative_slope)?;

        Ok(result)
    }

    /// ELU activation on GPU: result[i] = x if x > 0, else alpha * (exp(x) - 1)
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    /// * `alpha` - Scaling factor for negative values (typically 1.0)
    ///
    /// # Returns
    ///
    /// Vector with ELU applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn elu(&mut self, input: &[f32], alpha: f32) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.elu(input, &mut result, alpha)?;

        Ok(result)
    }

    /// Clip (clamp) operation on GPU: result[i] = clamp(input[i], min_val, max_val)
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    /// * `min_val` - Minimum value
    /// * `max_val` - Maximum value
    ///
    /// # Returns
    ///
    /// Vector with clip applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty, if `min_val > max_val` (an empty
    /// range has no well-defined clamp result), or if device setup/dispatch fails.
    pub fn clip(&mut self, input: &[f32], min_val: f32, max_val: f32) -> Result<Vec<f32>, String> {
        if min_val > max_val {
            return Err(format!(
                "Invalid clip range: min_val ({}) > max_val ({})",
                min_val, max_val
            ));
        }
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.clip(input, &mut result, min_val, max_val)?;

        Ok(result)
    }

    /// Sigmoid activation on GPU: result[i] = 1 / (1 + exp(-input[i]))
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with sigmoid applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn sigmoid(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.sigmoid(input, &mut result)?;

        Ok(result)
    }

    /// Tanh activation on GPU: result[i] = tanh(input[i])
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with tanh applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn tanh(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.tanh(input, &mut result)?;

        Ok(result)
    }

    /// Swish activation on GPU: result[i] = input[i] / (1 + exp(-input[i]))
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with swish applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn swish(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.swish(input, &mut result)?;

        Ok(result)
    }

    /// GELU activation on GPU: result[i] = 0.5 * input[i] * (1 + tanh(...))
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with GELU applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn gelu(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.gelu(input, &mut result)?;

        Ok(result)
    }

    /// Softmax activation on GPU: result[i] = exp(input[i] - max) / sum(exp(input - max))
    ///
    /// Uses multi-pass reduction for numerical stability:
    /// - Pass 1: Max reduction (parallel)
    /// - Pass 2: Exp-subtract (element-wise)
    /// - Pass 3: Sum reduction (parallel)
    /// - Pass 4: Normalize (element-wise)
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with softmax applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn softmax(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        contract_pre_softmax!(input);
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.softmax(input, &mut result)?;

        contract_post_softmax!(&result);
        Ok(result)
    }

    /// Log-softmax activation on GPU: result[i] = log(softmax(input)[i])
    ///
    /// Uses multi-pass reduction for numerical stability:
    /// - Pass 1: Max reduction (parallel)
    /// - Pass 2: Exp-subtract (element-wise)
    /// - Pass 3: Sum reduction (parallel)
    /// - Pass 4: Log-normalize (element-wise)
    ///
    /// # Arguments
    ///
    /// * `input` - Input vector
    ///
    /// # Returns
    ///
    /// Vector with log-softmax applied element-wise
    ///
    /// # Errors
    ///
    /// Returns an error if `input` is empty or if device setup/dispatch fails.
    pub fn log_softmax(&mut self, input: &[f32]) -> Result<Vec<f32>, String> {
        contract_pre_log_softmax!(input);
        ensure_non_empty(input.len())?;
        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; input.len()];
        device.log_softmax(input, &mut result)?;

        Ok(result)
    }

    /// 2D Convolution on GPU: output = input (convolved with) kernel
    ///
    /// # Arguments
    ///
    /// * `input` - Input matrix (flattened row-major)
    /// * `kernel` - Convolution kernel (flattened row-major)
    /// * `input_rows` - Number of rows in input
    /// * `input_cols` - Number of columns in input
    /// * `kernel_rows` - Number of rows in kernel
    /// * `kernel_cols` - Number of columns in kernel
    ///
    /// # Returns
    ///
    /// Output matrix (flattened row-major, "valid" convolution)
    /// - output_rows = input_rows - kernel_rows + 1
    /// - output_cols = input_cols - kernel_cols + 1
    ///
    /// # Errors
    ///
    /// Returns an error if a kernel dimension is zero, if the kernel is larger
    /// than the input in either dimension (the "valid" output would be empty),
    /// if `input`/`kernel` lengths disagree with their stated dimensions, or if
    /// device setup/dispatch fails.
    pub fn convolve2d(
        &mut self,
        input: &[f32],
        kernel: &[f32],
        input_rows: usize,
        input_cols: usize,
        kernel_rows: usize,
        kernel_cols: usize,
    ) -> Result<Vec<f32>, String> {
        // A zero-sized kernel would otherwise produce input_rows + 1 output rows.
        if kernel_rows == 0 || kernel_cols == 0 {
            return Err(format!(
                "Kernel dimensions must be non-zero: {} x {}",
                kernel_rows, kernel_cols
            ));
        }
        // Saturating arithmetic previously turned this case into a bogus 1x1 output.
        if kernel_rows > input_rows || kernel_cols > input_cols {
            return Err(format!(
                "Kernel ({} x {}) is larger than input ({} x {})",
                kernel_rows, kernel_cols, input_rows, input_cols
            ));
        }
        check_flat_shape(input.len(), input_rows, input_cols, "Convolution input")?;
        check_flat_shape(kernel.len(), kernel_rows, kernel_cols, "Convolution kernel")?;

        let device = self.ensure_device()?;

        // Both subtractions are safe: kernel <= input was checked above, and the
        // result is at least 1 because kernel dimensions are non-zero.
        let output_rows = input_rows - kernel_rows + 1;
        let output_cols = input_cols - kernel_cols + 1;

        // Cannot overflow: output_rows * output_cols <= input_rows * input_cols,
        // which `check_flat_shape` already proved fits in usize.
        let mut result = vec![0.0f32; output_rows * output_cols];

        device.convolve2d(
            input,
            kernel,
            &mut result,
            input_rows,
            input_cols,
            kernel_rows,
            kernel_cols,
        )?;

        Ok(result)
    }

    /// Matrix multiplication on GPU: C = A x B
    ///
    /// # Arguments
    ///
    /// * `a` - Matrix A (m x k) in row-major order
    /// * `b` - Matrix B (k x n) in row-major order
    /// * `m` - Rows of A and C
    /// * `k` - Cols of A, rows of B
    /// * `n` - Cols of B and C
    ///
    /// # Returns
    ///
    /// Matrix C (m x n) in row-major order
    ///
    /// # Errors
    ///
    /// Returns an error if any dimension is zero (zero-sized GPU buffers),
    /// if `a.len() != m * k` or `b.len() != k * n`, if `m * n` overflows,
    /// or if device setup/dispatch fails.
    pub fn matmul(
        &mut self,
        a: &[f32],
        b: &[f32],
        m: usize,
        k: usize,
        n: usize,
    ) -> Result<Vec<f32>, String> {
        if m == 0 || k == 0 || n == 0 {
            return Err(format!(
                "Matrix dimensions must be non-zero: m={}, k={}, n={}",
                m, k, n
            ));
        }
        check_flat_shape(a.len(), m, k, "Matrix A")?;
        check_flat_shape(b.len(), k, n, "Matrix B")?;
        // m*k and k*n fitting in usize does not imply m*n does.
        let out_len = checked_area(m, n, "Matrix C")?;

        let device = self.ensure_device()?;

        let mut result = vec![0.0f32; out_len];
        device.matmul(a, b, &mut result, m, k, n)?;

        Ok(result)
    }

    /// Symmetric eigendecomposition on GPU
    ///
    /// Computes eigenvalues and eigenvectors using Jacobi algorithm with
    /// GPU-accelerated Givens rotations.
    ///
    /// # Arguments
    ///
    /// * `matrix` - Symmetric matrix data (row-major, n x n)
    /// * `n` - Matrix dimension
    ///
    /// # Returns
    ///
    /// Tuple of (eigenvalues, eigenvector_data) where eigenvector_data is row-major
    ///
    /// # Errors
    ///
    /// Returns an error if `matrix.len() != n * n`, if `n` is zero, or if
    /// device setup/dispatch fails. Symmetry itself is not checked here.
    pub fn symmetric_eigen(
        &mut self,
        matrix: &[f32],
        n: usize,
    ) -> Result<(Vec<f32>, Vec<f32>), String> {
        check_flat_shape(matrix.len(), n, n, "Symmetric matrix")?;
        ensure_non_empty(matrix.len())?;

        let device = self.ensure_device()?;
        device.symmetric_eigen(matrix, n)
    }

    /// 2D Tiled Sum Reduction on GPU
    ///
    /// Uses 16x16 workgroups for efficient parallel reduction with
    /// optimal memory coalescing.
    ///
    /// # Arguments
    ///
    /// * `data` - Input 2D data in row-major order
    /// * `width` - Number of columns
    /// * `height` - Number of rows
    ///
    /// # Returns
    ///
    /// Sum of all elements
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != width * height`, if `data` is empty,
    /// or if device setup/dispatch fails.
    pub fn tiled_sum_2d_gpu(
        &mut self,
        data: &[f32],
        width: usize,
        height: usize,
    ) -> Result<f32, String> {
        check_flat_shape(data.len(), height, width, "2D data")?;
        ensure_non_empty(data.len())?;

        let device = self.ensure_device()?;
        device.tiled_sum_2d(data, width, height)
    }

    /// 2D Tiled Max Reduction on GPU
    ///
    /// Uses 16x16 workgroups for efficient parallel max reduction.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != width * height`, if `data` is empty
    /// (the maximum of no elements is undefined), or if device setup/dispatch fails.
    pub fn tiled_max_2d_gpu(
        &mut self,
        data: &[f32],
        width: usize,
        height: usize,
    ) -> Result<f32, String> {
        check_flat_shape(data.len(), height, width, "2D data")?;
        ensure_non_empty(data.len())?;

        let device = self.ensure_device()?;
        device.tiled_max_2d(data, width, height)
    }

    /// 2D Tiled Min Reduction on GPU
    ///
    /// Uses 16x16 workgroups for efficient parallel min reduction.
    ///
    /// # Errors
    ///
    /// Returns an error if `data.len() != width * height`, if `data` is empty
    /// (the minimum of no elements is undefined), or if device setup/dispatch fails.
    pub fn tiled_min_2d_gpu(
        &mut self,
        data: &[f32],
        width: usize,
        height: usize,
    ) -> Result<f32, String> {
        check_flat_shape(data.len(), height, width, "2D data")?;
        ensure_non_empty(data.len())?;

        let device = self.ensure_device()?;
        device.tiled_min_2d(data, width, height)
    }
}

/// Rejects zero-length operands before any device work is attempted.
///
/// wgpu does not allow zero-sized buffers, so an empty operand can never be
/// dispatched. Reporting it as an `Err` up front lets callers handle it
/// (e.g. by using a CPU path) instead of hitting a device-level failure.
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
fn ensure_non_empty(len: usize) -> Result<(), String> {
    if len == 0 {
        Err("Cannot perform GPU operation on empty vectors".to_string())
    } else {
        Ok(())
    }
}

/// Computes `rows * cols`, reporting overflow as an error instead of wrapping.
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
fn checked_area(rows: usize, cols: usize, what: &str) -> Result<usize, String> {
    rows.checked_mul(cols)
        .ok_or_else(|| format!("{} dimensions overflow usize: {} x {}", what, rows, cols))
}

/// Verifies that a flattened row-major buffer holds exactly `rows * cols` elements.
#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
fn check_flat_shape(len: usize, rows: usize, cols: usize, what: &str) -> Result<(), String> {
    let expected = checked_area(rows, cols, what)?;
    if len == expected {
        Ok(())
    } else {
        Err(format!(
            "{} length mismatch: expected {} x {} = {} elements, got {}",
            what, rows, cols, expected, len
        ))
    }
}