Skip to main content

trueno/brick/ops/
mod.rs

1#![allow(missing_docs)]
2//! Built-in Compute Operations
3//!
4//! Pre-defined operations that implement the ComputeOp trait:
5//! - DotOp: Vector dot product
6//! - AddOp: Element-wise vector addition
7//! - MatmulOp: Matrix multiplication (SIMD-optimized)
8//! - SoftmaxOp: Softmax with SIMD exp approximation (SIMD-EXP)
9
10use super::{Backend, ComputeOp};
11use crate::error::TruenoError;
12
13// ============================================================================
14// DotOp: Dot Product
15// ============================================================================
16
/// Dot product operation.
#[derive(Debug, Clone)]
pub struct DotOp {
    /// Expected vector length
    pub len: usize,
}

impl DotOp {
    /// Create a dot-product op expecting input vectors of length `len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
29
30impl ComputeOp for DotOp {
31    type Input = (Vec<f32>, Vec<f32>);
32    type Output = f32;
33
34    fn name(&self) -> &'static str {
35        "dot"
36    }
37
38    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
39        let (a, b) = input;
40        if a.len() != b.len() {
41            return Err(TruenoError::SizeMismatch { expected: a.len(), actual: b.len() });
42        }
43        // Simple scalar implementation for now
44        let sum: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
45        Ok(sum)
46    }
47
48    fn tokens(&self, input: &Self::Input) -> usize {
49        // Each element pair is roughly 1 "token" of work
50        input.0.len()
51    }
52}
53
54// ============================================================================
55// AddOp: Element-wise Addition
56// ============================================================================
57
/// Element-wise add operation.
#[derive(Debug, Clone)]
pub struct AddOp {
    /// Expected vector length
    pub len: usize,
}

impl AddOp {
    /// Create an element-wise add op expecting input vectors of length `len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
70
71impl ComputeOp for AddOp {
72    type Input = (Vec<f32>, Vec<f32>);
73    type Output = Vec<f32>;
74
75    fn name(&self) -> &'static str {
76        "add"
77    }
78
79    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
80        let (a, b) = input;
81        if a.len() != b.len() {
82            return Err(TruenoError::SizeMismatch { expected: a.len(), actual: b.len() });
83        }
84        Ok(a.iter().zip(b.iter()).map(|(x, y)| x + y).collect())
85    }
86
87    fn tokens(&self, input: &Self::Input) -> usize {
88        input.0.len()
89    }
90}
91
92// ============================================================================
93// MatmulOp: Matrix Multiplication
94// ============================================================================
95
/// Matrix multiplication operation.
#[derive(Debug, Clone)]
pub struct MatmulOp {
    /// M dimension (rows of A)
    pub m: usize,
    /// K dimension (cols of A = rows of B)
    pub k: usize,
    /// N dimension (cols of B)
    pub n: usize,
}

impl MatmulOp {
    /// Create a matmul op for an `m x k` matrix A times a `k x n` matrix B.
    #[must_use]
    pub fn new(m: usize, k: usize, n: usize) -> Self {
        Self { m, k, n }
    }
}
112
113impl ComputeOp for MatmulOp {
114    type Input = (Vec<f32>, Vec<f32>);
115    type Output = Vec<f32>;
116
117    fn name(&self) -> &'static str {
118        "matmul"
119    }
120
121    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
122        let (a, b) = input;
123        let expected_a = self.m * self.k;
124        let expected_b = self.k * self.n;
125
126        if a.len() != expected_a {
127            return Err(TruenoError::SizeMismatch { expected: expected_a, actual: a.len() });
128        }
129        if b.len() != expected_b {
130            return Err(TruenoError::SizeMismatch { expected: expected_b, actual: b.len() });
131        }
132
133        // SIMD-optimized matrix multiplication via Matrix type
134        // Uses AVX2/AVX-512 with cache blocking for ~10-50x speedup
135        let simd_backend = crate::Backend::select_best();
136        let mat_a = crate::Matrix::from_vec_with_backend(self.m, self.k, a, simd_backend);
137        let mat_b = crate::Matrix::from_vec_with_backend(self.k, self.n, b, simd_backend);
138
139        let result = mat_a.matmul(&mat_b)?;
140        // Take ownership of the data Vec directly — avoids redundant copy.
141        Ok(result.data)
142    }
143
144    fn tokens(&self, _input: &Self::Input) -> usize {
145        // For matmul, "tokens" = number of output elements
146        // Each output requires K multiply-adds
147        self.m * self.n
148    }
149}
150
151// ============================================================================
152// SoftmaxOp: Softmax with SIMD Exp (SIMD-EXP)
153// ============================================================================
154
/// Softmax operation.
#[derive(Debug, Clone)]
pub struct SoftmaxOp {
    /// Expected vector length
    pub len: usize,
}

impl SoftmaxOp {
    /// Create a softmax op expecting an input vector of length `len`.
    #[must_use]
    pub fn new(len: usize) -> Self {
        Self { len }
    }
}
167
168impl ComputeOp for SoftmaxOp {
169    type Input = Vec<f32>;
170    type Output = Vec<f32>;
171
172    fn name(&self) -> &'static str {
173        "softmax"
174    }
175
176    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
177        if input.is_empty() {
178            return Ok(vec![]);
179        }
180
181        // CGP-DBUF: delegate to blis::softmax which has AVX2 fused exp+sum
182        // (3-pass: max, fused exp+sum, normalize). This eliminates 3 intermediate
183        // allocations (shifted, exp_vals, result) and uses polynomial fast_exp.
184        Ok(crate::blis::softmax::softmax_1d_alloc(&input))
185    }
186
187    fn tokens(&self, input: &Self::Input) -> usize {
188        input.len()
189    }
190}
191
192impl SoftmaxOp {
193    /// Check if backend supports SIMD acceleration
194    #[inline]
195    pub fn is_simd_backend(backend: Backend) -> bool {
196        matches!(
197            backend,
198            Backend::Avx2 | Backend::Avx512 | Backend::Sse2 | Backend::Neon | Backend::Auto
199        )
200    }
201    // CGP-DBUF: SIMD helper methods (simd_max, simd_exp, simd_sum, simd_scale,
202    // avx2_max, avx2_exp, avx2_sum, avx2_scale) removed — execute() now delegates
203    // to blis::softmax::softmax_1d_alloc which has fused AVX2 fast_exp path.
204}
205
206#[cfg(test)]
207mod tests;