// trueno/backends/mod.rs

//! Backend implementations for different SIMD instruction sets
//!
//! This module contains the actual SIMD implementations for each backend.
//! All backends implement the same trait-based interface to ensure API consistency.
//!
//! # Safety
//!
//! All `unsafe` code is isolated within backend implementations. The public API
//! remains 100% safe.
//!
//! # Backends
//!
//! - `scalar`: Portable baseline implementation (no SIMD)
//! - `sse2`: x86_64 baseline SIMD (128-bit)
//! - `avx2`: x86_64 advanced SIMD (256-bit with FMA)
//! - `avx512`: x86_64 maximum SIMD (512-bit)
//! - `neon`: ARM SIMD (128-bit)
//! - `wasm`: WebAssembly SIMD128

20pub mod q4k;
21pub mod q6k;
22pub mod scalar;
23
24#[cfg(target_arch = "x86_64")]
25pub mod sse2;
26
27#[cfg(target_arch = "x86_64")]
28pub mod avx2;
29
30#[cfg(target_arch = "x86_64")]
31#[cfg(test)]
32mod avx2_tests;
33
34#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
35pub mod neon;
36
37#[cfg(target_arch = "wasm32")]
38pub mod wasm;
39
40// GPU module - always available for TensorView/PartitionView abstractions
41// Actual GPU compute requires "gpu" feature
42pub mod gpu;
43
44#[cfg(target_arch = "x86_64")]
45pub mod avx512;
46
/// Backend trait defining common operations
///
/// All backend implementations must implement this trait to ensure
/// consistent behavior across different SIMD instruction sets.
///
/// Every operation is an `unsafe` associated function (there is no `self`
/// receiver), so it is invoked as `Backend::op(...)`. All operations work on
/// `f32` slices: element-wise operations write into a caller-provided `result`
/// buffer, while reductions return a single value.
///
/// # Safety
///
/// Implementations may use unsafe SIMD intrinsics. Callers must ensure:
/// - Input slices are valid
/// - Result slice has sufficient capacity
/// - Slices `a` and `b` have the same length
pub trait VectorBackend {
    /// Element-wise addition: a\[i\] + b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Element-wise subtraction: a\[i\] - b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Element-wise multiplication: a\[i\] * b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Element-wise division: a\[i\] / b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Dot product: sum(a\[i\] * b\[i\])
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn dot(a: &[f32], b: &[f32]) -> f32;

    /// Sum reduction: sum(a\[i\])
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn sum(a: &[f32]) -> f32;

    /// Max reduction: max(a\[i\])
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn max(a: &[f32]) -> f32;

    /// Min reduction: min(a\[i\])
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn min(a: &[f32]) -> f32;

    /// Argmax: index of maximum value
    ///
    /// Returns the index of the first occurrence of the maximum value.
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn argmax(a: &[f32]) -> usize;

    /// Argmin: index of minimum value
    ///
    /// Returns the index of the first occurrence of the minimum value.
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn argmin(a: &[f32]) -> usize;

    /// Kahan summation: numerically stable sum(a\[i\])
    ///
    /// Uses the Kahan summation algorithm to reduce floating-point rounding errors
    /// when summing many numbers. Tracks a running compensation for lost low-order bits.
    ///
    /// # Safety
    ///
    /// - Can handle empty slice (returns 0.0)
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn sum_kahan(a: &[f32]) -> f32;

    /// L2 norm (Euclidean norm): sqrt(sum(a\[i\]^2))
    ///
    /// Computes the Euclidean length of the vector. This is equivalent to sqrt(dot(a, a)).
    ///
    /// # Safety
    ///
    /// - Can handle empty slice (returns 0.0)
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn norm_l2(a: &[f32]) -> f32;

    /// L1 norm (Manhattan norm): sum(|a\[i\]|)
    ///
    /// Computes the sum of absolute values of all elements.
    /// Used in machine learning (L1 regularization), distance metrics, and sparse modeling.
    ///
    /// # Safety
    ///
    /// - Can handle empty slice (returns 0.0)
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn norm_l1(a: &[f32]) -> f32;

    /// L-infinity norm (maximum absolute value): max(|a\[i\]|)
    ///
    /// Computes the maximum absolute value of all elements.
    /// Used in optimization (constraint checking), numerical analysis, and error bounds.
    ///
    /// # Safety
    ///
    /// - Can handle empty slice (returns 0.0)
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn norm_linf(a: &[f32]) -> f32;

    /// Scalar multiplication: result\[i\] = a\[i\] * scalar
    ///
    /// Multiplies all elements by a scalar value.
    /// Used in vector scaling, normalization, and linear transformations.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slice
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]);

    /// Absolute value: result\[i\] = |a\[i\]|
    ///
    /// Computes the absolute value of each element.
    /// Used in distance metrics (L1 norm), numerical stability, and signal processing.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slice
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn abs(a: &[f32], result: &mut [f32]);

    /// Clamp elements to range \[min_val, max_val\]: result\[i\] = max(min_val, min(a\[i\], max_val))
    ///
    /// Constrains each element to the specified range.
    /// Used in neural networks (gradient clipping), graphics (color clamping), and signal processing.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slice
    /// - Assumes min_val <= max_val (caller must validate)
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]);

    /// Linear interpolation: result\[i\] = a\[i\] + t * (b\[i\] - a\[i\])
    ///
    /// Computes element-wise linear interpolation between two vectors.
    /// When t=0, returns a; when t=1, returns b; values outside \[0,1\] extrapolate.
    /// Used in graphics, animation, neural networks, and signal processing.
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]);

    /// Fused multiply-add: result\[i\] = a\[i\] * b\[i\] + c\[i\]
    ///
    /// Computes element-wise fused multiply-add operation.
    /// On hardware with FMA support, this is a single instruction with better performance
    /// and numerical accuracy (no intermediate rounding).
    /// Used in neural networks, matrix multiplication, and scientific computing.
    ///
    /// # Safety
    ///
    /// - `a`, `b`, and `c` must all have the same length
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]);

    /// ReLU activation: result\[i\] = max(0, a\[i\])
    ///
    /// Rectified Linear Unit - the most common activation function in neural networks.
    /// Sets negative values to zero, passes positive values unchanged.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn relu(a: &[f32], result: &mut [f32]);

    /// Exponential function: result\[i\] = exp(a\[i\])
    ///
    /// Computes e^x for each element using range reduction for numerical accuracy.
    /// Foundation for sigmoid, softmax, GELU, and other activation functions.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn exp(a: &[f32], result: &mut [f32]);

    /// Sigmoid activation: result\[i\] = 1 / (1 + exp(-a\[i\]))
    ///
    /// Logistic sigmoid function - maps inputs to (0, 1) range.
    /// Used in binary classification and as gating mechanism.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn sigmoid(a: &[f32], result: &mut [f32]);

    /// GELU activation: result\[i\] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
    ///
    /// Gaussian Error Linear Unit - smooth non-monotonic activation.
    /// Used in BERT, GPT, and modern transformers.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn gelu(a: &[f32], result: &mut [f32]);

    /// Swish activation: result\[i\] = x * sigmoid(x) = x / (1 + exp(-x))
    ///
    /// Self-gated activation function (also called SiLU).
    /// Used in EfficientNet, MobileNetV3.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn swish(a: &[f32], result: &mut [f32]);

    /// Hyperbolic tangent activation: result\[i\] = tanh(a\[i\]) = (exp(2x) - 1) / (exp(2x) + 1)
    ///
    /// Hyperbolic tangent - maps inputs to (-1, 1) range.
    /// Classic activation function from early neural networks.
    /// Used in RNNs, LSTMs, and as smooth alternative to ReLU.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn tanh(a: &[f32], result: &mut [f32]);

    /// Square root: result\[i\] = sqrt(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn sqrt(a: &[f32], result: &mut [f32]);

    /// Reciprocal: result\[i\] = 1 / a\[i\]
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn recip(a: &[f32], result: &mut [f32]);

    /// Natural logarithm: result\[i\] = ln(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn ln(a: &[f32], result: &mut [f32]);

    /// Base-2 logarithm: result\[i\] = log2(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn log2(a: &[f32], result: &mut [f32]);

    /// Base-10 logarithm: result\[i\] = log10(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn log10(a: &[f32], result: &mut [f32]);

    /// Sine: result\[i\] = sin(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn sin(a: &[f32], result: &mut [f32]);

    /// Cosine: result\[i\] = cos(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn cos(a: &[f32], result: &mut [f32]);

    /// Tangent: result\[i\] = tan(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn tan(a: &[f32], result: &mut [f32]);

    /// Floor: result\[i\] = floor(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn floor(a: &[f32], result: &mut [f32]);

    /// Ceiling: result\[i\] = ceil(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn ceil(a: &[f32], result: &mut [f32]);

    /// Round: result\[i\] = round(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - Can handle empty slices
    // SAFETY: Caller must satisfy the documented preconditions for slice validity
    unsafe fn round(a: &[f32], result: &mut [f32]);
}