trueno/backends/mod.rs
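//! Backend implementations for different SIMD instruction sets.
//!
//! This module contains the actual SIMD implementations for each backend.
//! All backends implement the same trait-based interface ([`VectorBackend`]) to
//! ensure API consistency.
//!
//! # Safety
//!
//! All `unsafe` code is isolated within backend implementations. The crate's
//! high-level public API remains 100% safe; the backend functions themselves are
//! `unsafe fn`s whose documented contracts callers must uphold.
//!
//! # Backends
//!
//! - `scalar`: Portable baseline implementation (no SIMD)
//! - `sse2`: x86_64 baseline SIMD (128-bit)
//! - `avx2`: x86_64 advanced SIMD (256-bit with FMA)
//! - `avx512`: x86_64 maximum SIMD (512-bit)
//! - `neon`: ARM SIMD (128-bit; compiled for `aarch64` and `arm` targets)
//! - `wasm`: WebAssembly SIMD128
//!
//! Other submodules: `q4k` and `q6k` (compiled on every target), and `gpu`
//! (always compiled for the TensorView/PartitionView abstractions; actual GPU
//! compute requires the `gpu` feature).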
19
20pub mod q4k;
21pub mod q6k;
22pub mod scalar;
23
24#[cfg(target_arch = "x86_64")]
25pub mod sse2;
26
27#[cfg(target_arch = "x86_64")]
28pub mod avx2;
29
30#[cfg(target_arch = "x86_64")]
31#[cfg(test)]
32mod avx2_tests;
33
34#[cfg(any(target_arch = "aarch64", target_arch = "arm"))]
35pub mod neon;
36
37#[cfg(target_arch = "wasm32")]
38pub mod wasm;
39
40// GPU module - always available for TensorView/PartitionView abstractions
41// Actual GPU compute requires "gpu" feature
42pub mod gpu;
43
44#[cfg(target_arch = "x86_64")]
45pub mod avx512;
46
/// Backend trait defining common operations.
///
/// All backend implementations must implement this trait to ensure
/// consistent behavior across different SIMD instruction sets. Every method is
/// an associated function (none takes `self`), so a backend is selected by type,
/// e.g. `B::add(a, b, result)` for some `B: VectorBackend`.
///
/// # Safety
///
/// Every method is an `unsafe fn`; implementations may use unsafe SIMD
/// intrinsics. Callers must ensure:
/// - the per-method length requirements documented below hold (binary and
///   ternary operations take inputs of equal length, and `result` must be long
///   enough to receive one output per input element);
/// - NOTE(review): backends built on optional instruction sets (e.g. AVX2 with
///   FMA, AVX-512) presumably also require the running CPU to support those
///   instructions, e.g. verified with `is_x86_feature_detected!` before
///   dispatching. If a backend uses `#[target_feature]`, calling it on a CPU
///   without those features is undefined behavior. Confirm per backend.
pub trait VectorBackend {
    /// Element-wise addition: a\[i\] + b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Element-wise subtraction: a\[i\] - b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Element-wise multiplication: a\[i\] * b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Element-wise division: a\[i\] / b\[i\]
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have length >= `a.len()`
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]);

    /// Dot product: sum(a\[i\] * b\[i\])
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn dot(a: &[f32], b: &[f32]) -> f32;

    /// Sum reduction: sum(a\[i\])
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn sum(a: &[f32]) -> f32;

    /// Max reduction: max(a\[i\])
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn max(a: &[f32]) -> f32;

    /// Min reduction: min(a\[i\])
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn min(a: &[f32]) -> f32;

    /// Argmax: index of maximum value
    ///
    /// Returns the index of the first occurrence of the maximum value.
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn argmax(a: &[f32]) -> usize;

    /// Argmin: index of minimum value
    ///
    /// Returns the index of the first occurrence of the minimum value.
    ///
    /// # Safety
    ///
    /// - `a` must not be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn argmin(a: &[f32]) -> usize;

    /// Kahan summation: numerically stable sum(a\[i\])
    ///
    /// Uses the Kahan summation algorithm to reduce floating-point rounding errors
    /// when summing many numbers. Tracks a running compensation for lost low-order bits.
    /// An empty slice yields `0.0`.
    ///
    /// # Safety
    ///
    /// - No length precondition on `a` (it may be empty); only the trait-level
    ///   requirements apply
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn sum_kahan(a: &[f32]) -> f32;

    /// L2 norm (Euclidean norm): sqrt(sum(a\[i\]^2))
    ///
    /// Computes the Euclidean length of the vector. This is equivalent to sqrt(dot(a, a)).
    /// An empty slice yields `0.0`.
    ///
    /// # Safety
    ///
    /// - No length precondition on `a` (it may be empty); only the trait-level
    ///   requirements apply
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn norm_l2(a: &[f32]) -> f32;

    /// L1 norm (Manhattan norm): sum(|a\[i\]|)
    ///
    /// Computes the sum of absolute values of all elements.
    /// Used in machine learning (L1 regularization), distance metrics, and sparse modeling.
    /// An empty slice yields `0.0`.
    ///
    /// # Safety
    ///
    /// - No length precondition on `a` (it may be empty); only the trait-level
    ///   requirements apply
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn norm_l1(a: &[f32]) -> f32;

    /// L-infinity norm (maximum absolute value): max(|a\[i\]|)
    ///
    /// Computes the maximum absolute value of all elements.
    /// Used in optimization (constraint checking), numerical analysis, and error bounds.
    /// An empty slice yields `0.0`.
    ///
    /// # Safety
    ///
    /// - No length precondition on `a` (it may be empty); only the trait-level
    ///   requirements apply
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn norm_linf(a: &[f32]) -> f32;

    /// Scalar multiplication: result\[i\] = a\[i\] * scalar
    ///
    /// Multiplies all elements by a scalar value.
    /// Used in vector scaling, normalization, and linear transformations.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]);

    /// Absolute value: result\[i\] = |a\[i\]|
    ///
    /// Computes the absolute value of each element.
    /// Used in distance metrics (L1 norm), numerical stability, and signal processing.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn abs(a: &[f32], result: &mut [f32]);

    /// Clamp elements to the range \[min_val, max_val\]: result\[i\] = max(min_val, min(a\[i\], max_val))
    ///
    /// Constrains each element to the specified range.
    /// Used in neural networks (gradient clipping), graphics (color clamping), and signal processing.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    /// - Assumes `min_val <= max_val` (caller must validate)
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]);

    /// Linear interpolation: result\[i\] = a\[i\] + t * (b\[i\] - a\[i\])
    ///
    /// Computes element-wise linear interpolation between two vectors.
    /// When t=0, returns a; when t=1, returns b; values outside \[0,1\] extrapolate.
    /// Used in graphics, animation, neural networks, and signal processing.
    ///
    /// # Safety
    ///
    /// - `a` and `b` must have the same length
    /// - `result` must have the same length as `a`
    /// - The slices may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]);

    /// Fused multiply-add: result\[i\] = a\[i\] * b\[i\] + c\[i\]
    ///
    /// Computes element-wise fused multiply-add operation.
    /// On hardware with FMA support, this is a single instruction with better performance
    /// and numerical accuracy (no intermediate rounding).
    /// Used in neural networks, matrix multiplication, and scientific computing.
    ///
    /// # Safety
    ///
    /// - `a`, `b`, and `c` must all have the same length
    /// - `result` must have the same length as `a`
    /// - The slices may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]);

    /// ReLU activation: result\[i\] = max(0, a\[i\])
    ///
    /// Rectified Linear Unit - the most common activation function in neural networks.
    /// Sets negative values to zero, passes positive values unchanged.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn relu(a: &[f32], result: &mut [f32]);

    /// Exponential function: result\[i\] = exp(a\[i\])
    ///
    /// Computes e^x for each element using range reduction for numerical accuracy.
    /// Foundation for sigmoid, softmax, GELU, and other activation functions.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn exp(a: &[f32], result: &mut [f32]);

    /// Sigmoid activation: result\[i\] = 1 / (1 + exp(-a\[i\]))
    ///
    /// Logistic sigmoid function - maps inputs to (0, 1) range.
    /// Used in binary classification and as gating mechanism.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn sigmoid(a: &[f32], result: &mut [f32]);

    /// GELU activation (tanh approximation):
    /// result\[i\] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
    ///
    /// Gaussian Error Linear Unit - smooth non-monotonic activation.
    /// Used in BERT, GPT, and modern transformers.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn gelu(a: &[f32], result: &mut [f32]);

    /// Swish activation: result\[i\] = x * sigmoid(x) = x / (1 + exp(-x))
    ///
    /// Self-gated activation function (also called SiLU).
    /// Used in EfficientNet, MobileNetV3.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn swish(a: &[f32], result: &mut [f32]);

    /// Hyperbolic tangent activation: result\[i\] = tanh(a\[i\]) = (exp(2x) - 1) / (exp(2x) + 1)
    ///
    /// Hyperbolic tangent - maps inputs to (-1, 1) range.
    /// Classic activation function from early neural networks.
    /// Used in RNNs, LSTMs, and as smooth alternative to ReLU.
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn tanh(a: &[f32], result: &mut [f32]);

    /// Square root: result\[i\] = sqrt(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn sqrt(a: &[f32], result: &mut [f32]);

    /// Reciprocal: result\[i\] = 1 / a\[i\]
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn recip(a: &[f32], result: &mut [f32]);

    /// Natural logarithm: result\[i\] = ln(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn ln(a: &[f32], result: &mut [f32]);

    /// Base-2 logarithm: result\[i\] = log2(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn log2(a: &[f32], result: &mut [f32]);

    /// Base-10 logarithm: result\[i\] = log10(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn log10(a: &[f32], result: &mut [f32]);

    /// Sine: result\[i\] = sin(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn sin(a: &[f32], result: &mut [f32]);

    /// Cosine: result\[i\] = cos(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn cos(a: &[f32], result: &mut [f32]);

    /// Tangent: result\[i\] = tan(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn tan(a: &[f32], result: &mut [f32]);

    /// Floor: result\[i\] = floor(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn floor(a: &[f32], result: &mut [f32]);

    /// Ceiling: result\[i\] = ceil(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn ceil(a: &[f32], result: &mut [f32]);

    /// Round: result\[i\] = round(a\[i\])
    ///
    /// # Safety
    ///
    /// - `result` must have the same length as `a`
    /// - `a` may be empty
    // SAFETY: callers must uphold the `# Safety` contract documented above.
    unsafe fn round(a: &[f32], result: &mut [f32]);
}