Skip to main content

trueno/backends/scalar/
mod.rs

1//! Scalar (non-SIMD) backend implementation
2//!
3//! This is the portable baseline implementation that works on all platforms.
4//! It uses simple loops without any SIMD instructions.
5//!
6//! # Performance
7//!
//! This backend serves as the correctness reference and provides no SIMD acceleration.
//! It is expected to be 8-32x slower than the SIMD backends on operations with 1K+ elements.
10
11use super::VectorBackend;
12
/// Scalar backend: the portable, SIMD-free implementation of `VectorBackend`.
///
/// Zero-sized marker type. Every backend operation is an associated function
/// (no `self`), so a value of this type is only ever used to select the backend.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScalarBackend;
15
16impl VectorBackend for ScalarBackend {
17    // SAFETY: This function is safe because:
18    // 1. All slice accesses are bounds-checked by Rust iterator/indexing
19    // 2. No raw pointer arithmetic is performed
20    // 3. Marked unsafe only to match VectorBackend trait interface
21    unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
22        for i in 0..a.len() {
23            result[i] = a[i] + b[i];
24        }
25    }
26
27    // SAFETY: This function is safe because:
28    // 1. All slice accesses are bounds-checked by Rust iterator/indexing
29    // 2. No raw pointer arithmetic is performed
30    // 3. Marked unsafe only to match VectorBackend trait interface
31    unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
32        for i in 0..a.len() {
33            result[i] = a[i] - b[i];
34        }
35    }
36
37    // SAFETY: This function is safe because:
38    // 1. All slice accesses are bounds-checked by Rust iterator/indexing
39    // 2. No raw pointer arithmetic is performed
40    // 3. Marked unsafe only to match VectorBackend trait interface
41    unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
42        for i in 0..a.len() {
43            result[i] = a[i] * b[i];
44        }
45    }
46
47    // SAFETY: This function is safe because:
48    // 1. All slice accesses are bounds-checked by Rust iterator/indexing
49    // 2. No raw pointer arithmetic is performed
50    // 3. Marked unsafe only to match VectorBackend trait interface
51    unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
52        for i in 0..a.len() {
53            result[i] = a[i] / b[i];
54        }
55    }
56
57    // SAFETY: This function is safe because:
58    // 1. All slice accesses are bounds-checked by Rust iterator/indexing
59    // 2. No raw pointer arithmetic is performed
60    // 3. Marked unsafe only to match VectorBackend trait interface
61    //
62    // OPTIMIZATION: 4× unrolling with mul_add for better ILP and auto-vectorization.
63    // This follows the cuda-tile pattern for improved throughput (spec: cuda-tile-behavior.md).
64    // Using f32::mul_add provides FMA semantics where available, improving accuracy.
65    #[inline(always)]
66    // SAFETY: caller ensures preconditions are met for this unsafe function
67    unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
68        contract_pre_dot_product!();
69        let len = a.len();
70        let chunks = len / 4;
71
72        // 4 independent accumulators for better ILP (cuda-tile inspired optimization)
73        let mut acc0 = 0.0f32;
74        let mut acc1 = 0.0f32;
75        let mut acc2 = 0.0f32;
76        let mut acc3 = 0.0f32;
77
78        // Process 4 elements at a time with independent accumulation chains
79        for i in 0..chunks {
80            let base = i * 4;
81            acc0 = a[base].mul_add(b[base], acc0);
82            acc1 = a[base + 1].mul_add(b[base + 1], acc1);
83            acc2 = a[base + 2].mul_add(b[base + 2], acc2);
84            acc3 = a[base + 3].mul_add(b[base + 3], acc3);
85        }
86
87        // Combine all 4 accumulators
88        let mut sum = (acc0 + acc1) + (acc2 + acc3);
89
90        // Handle remainder
91        for i in (chunks * 4)..len {
92            sum = a[i].mul_add(b[i], sum);
93        }
94
95        contract_post_dot_product_parity!(sum);
96        sum
97    }
98
99    // SAFETY: This function is safe because:
100    // 1. All slice accesses are bounds-checked by Rust iterator
101    // 2. No raw pointer arithmetic is performed
102    // 3. Marked unsafe only to match VectorBackend trait interface
103    unsafe fn sum(a: &[f32]) -> f32 {
104        let mut total = 0.0;
105        for &val in a {
106            total += val;
107        }
108        total
109    }
110
111    // SAFETY: This function is safe because:
112    // 1. All slice accesses are bounds-checked by Rust slicing/iteration
113    // 2. Caller must ensure slice is non-empty (a[0] access)
114    // 3. Marked unsafe only to match VectorBackend trait interface
115    unsafe fn max(a: &[f32]) -> f32 {
116        let mut maximum = a[0];
117        for &val in a.get(1..).unwrap_or(&[]) {
118            if val > maximum {
119                maximum = val;
120            }
121        }
122        maximum
123    }
124
125    // SAFETY: This function is safe because:
126    // 1. All slice accesses are bounds-checked by Rust slicing/iteration
127    // 2. Caller must ensure slice is non-empty (a[0] access)
128    // 3. Marked unsafe only to match VectorBackend trait interface
129    unsafe fn min(a: &[f32]) -> f32 {
130        let mut minimum = a[0];
131        for &val in a.get(1..).unwrap_or(&[]) {
132            if val < minimum {
133                minimum = val;
134            }
135        }
136        minimum
137    }
138
139    // SAFETY: This function is safe because:
140    // 1. All slice accesses are bounds-checked by Rust iterator
141    // 2. Caller must ensure slice is non-empty (a[0] access)
142    // 3. Marked unsafe only to match VectorBackend trait interface
143    unsafe fn argmax(a: &[f32]) -> usize {
144        let mut max_value = a[0];
145        let mut max_index = 0;
146        for (i, &val) in a.iter().enumerate() {
147            if val > max_value {
148                max_value = val;
149                max_index = i;
150            }
151        }
152        max_index
153    }
154
155    // SAFETY: This function is safe because:
156    // 1. All slice accesses are bounds-checked by Rust iterator
157    // 2. Caller must ensure slice is non-empty (a[0] access)
158    // 3. Marked unsafe only to match VectorBackend trait interface
159    unsafe fn argmin(a: &[f32]) -> usize {
160        let mut min_value = a[0];
161        let mut min_index = 0;
162        for (i, &val) in a.iter().enumerate() {
163            if val < min_value {
164                min_value = val;
165                min_index = i;
166            }
167        }
168        min_index
169    }
170
171    // SAFETY: This function is safe because:
172    // 1. All slice accesses are bounds-checked by Rust iterator
173    // 2. Kahan summation uses only safe floating-point arithmetic
174    // 3. Marked unsafe only to match VectorBackend trait interface
175    unsafe fn sum_kahan(a: &[f32]) -> f32 {
176        let mut sum = 0.0;
177        let mut c = 0.0; // Compensation for lost low-order bits
178
179        for &value in a {
180            let y = value - c; // Subtract the compensation
181            let t = sum + y; // Add to sum
182            c = (t - sum) - y; // Update compensation
183            sum = t; // Update sum
184        }
185
186        sum
187    }
188
189    // SAFETY: This function is safe because:
190    // 1. All slice accesses are bounds-checked by Rust iterator
191    // 2. Empty check prevents undefined behavior
192    // 3. Marked unsafe only to match VectorBackend trait interface
193    unsafe fn norm_l2(a: &[f32]) -> f32 {
194        if a.is_empty() {
195            return 0.0;
196        }
197
198        let mut sum_of_squares = 0.0;
199        for &val in a {
200            sum_of_squares += val * val;
201        }
202        sum_of_squares.sqrt()
203    }
204
205    // SAFETY: This function is safe because:
206    // 1. All slice accesses are bounds-checked by Rust iterator
207    // 2. Empty check prevents undefined behavior
208    // 3. Marked unsafe only to match VectorBackend trait interface
209    unsafe fn norm_l1(a: &[f32]) -> f32 {
210        if a.is_empty() {
211            return 0.0;
212        }
213
214        let mut sum = 0.0;
215        for &val in a {
216            sum += val.abs();
217        }
218        sum
219    }
220
221    // SAFETY: This function is safe because:
222    // 1. All slice accesses are bounds-checked by Rust iterator
223    // 2. Empty check prevents undefined behavior
224    // 3. Marked unsafe only to match VectorBackend trait interface
225    unsafe fn norm_linf(a: &[f32]) -> f32 {
226        if a.is_empty() {
227            return 0.0;
228        }
229
230        let mut max_val = 0.0_f32;
231        for &val in a {
232            let abs_val = val.abs();
233            if abs_val > max_val {
234                max_val = abs_val;
235            }
236        }
237        max_val
238    }
239
240    // SAFETY: This function is safe because:
241    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
242    // 2. No raw pointer arithmetic is performed
243    // 3. Marked unsafe only to match VectorBackend trait interface
244    unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
245        for (i, &val) in a.iter().enumerate() {
246            result[i] = val * scalar;
247        }
248    }
249
250    // SAFETY: This function is safe because:
251    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
252    // 2. No raw pointer arithmetic is performed
253    // 3. Marked unsafe only to match VectorBackend trait interface
254    unsafe fn abs(a: &[f32], result: &mut [f32]) {
255        for (i, &val) in a.iter().enumerate() {
256            result[i] = val.abs();
257        }
258    }
259
260    // SAFETY: This function is safe because:
261    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
262    // 2. No raw pointer arithmetic is performed
263    // 3. Marked unsafe only to match VectorBackend trait interface
264    unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
265        for (i, &val) in a.iter().enumerate() {
266            result[i] = val.max(min_val).min(max_val);
267        }
268    }
269
270    // SAFETY: This function is safe because:
271    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate/zip
272    // 2. No raw pointer arithmetic is performed
273    // 3. Marked unsafe only to match VectorBackend trait interface
274    unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
275        for (i, (&a_val, &b_val)) in a.iter().zip(b.iter()).enumerate() {
276            // result = a + t * (b - a)
277            result[i] = a_val + t * (b_val - a_val);
278        }
279    }
280
281    // SAFETY: This function is safe because:
282    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate/zip
283    // 2. No raw pointer arithmetic is performed
284    // 3. Marked unsafe only to match VectorBackend trait interface
285    unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
286        for (i, ((&a_val, &b_val), &c_val)) in a.iter().zip(b.iter()).zip(c.iter()).enumerate() {
287            // result = a * b + c
288            result[i] = a_val * b_val + c_val;
289        }
290    }
291
292    // SAFETY: This function is safe because:
293    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
294    // 2. No raw pointer arithmetic is performed
295    // 3. Marked unsafe only to match VectorBackend trait interface
296    unsafe fn relu(a: &[f32], result: &mut [f32]) {
297        for (i, &val) in a.iter().enumerate() {
298            result[i] = if val > 0.0 { val } else { 0.0 };
299        }
300    }
301
302    // SAFETY: This function is safe because:
303    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
304    // 2. No raw pointer arithmetic is performed
305    // 3. Marked unsafe only to match VectorBackend trait interface
306    unsafe fn exp(a: &[f32], result: &mut [f32]) {
307        for (i, &val) in a.iter().enumerate() {
308            result[i] = val.exp();
309        }
310    }
311
312    // SAFETY: This function is safe because:
313    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
314    // 2. Clamping prevents exp() overflow
315    // 3. Marked unsafe only to match VectorBackend trait interface
316    unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
317        contract_pre_sigmoid!(a);
318        for (i, &val) in a.iter().enumerate() {
319            // Handle extreme values for numerical stability
320            result[i] = if val < -50.0 {
321                0.0 // exp(-x) would overflow, but sigmoid approaches 0
322            } else if val > 50.0 {
323                1.0 // exp(-x) underflows to 0, sigmoid approaches 1
324            } else {
325                1.0 / (1.0 + (-val).exp())
326            };
327        }
328    }
329
330    // SAFETY: This function is safe because:
331    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
332    // 2. No raw pointer arithmetic is performed
333    // 3. Marked unsafe only to match VectorBackend trait interface
334    unsafe fn gelu(a: &[f32], result: &mut [f32]) {
335        // GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
336        contract_pre_gelu!(a);
337        const SQRT_2_OVER_PI: f32 = 0.797_884_6;
338        const COEFF: f32 = 0.044715;
339
340        for (i, &x) in a.iter().enumerate() {
341            let x3 = x * x * x;
342            let inner = SQRT_2_OVER_PI * (x + COEFF * x3);
343            result[i] = 0.5 * x * (1.0 + inner.tanh());
344        }
345        contract_post_gelu!(result);
346    }
347
348    // SAFETY: This function is safe because:
349    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
350    // 2. Clamping prevents exp() overflow
351    // 3. Marked unsafe only to match VectorBackend trait interface
352    unsafe fn swish(a: &[f32], result: &mut [f32]) {
353        contract_pre_silu!();
354        // Swish: x * sigmoid(x) = x / (1 + exp(-x))
355        for (i, &x) in a.iter().enumerate() {
356            if x < -50.0 {
357                result[i] = 0.0; // x * 0 = 0
358            } else if x > 50.0 {
359                result[i] = x; // x * 1 = x
360            } else {
361                let sigmoid = 1.0 / (1.0 + (-x).exp());
362                result[i] = x * sigmoid;
363            }
364        }
365        contract_post_silu!(result);
366    }
367
368    // SAFETY: This function is safe because:
369    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
370    // 2. No raw pointer arithmetic is performed
371    // 3. Marked unsafe only to match VectorBackend trait interface
372    unsafe fn tanh(a: &[f32], result: &mut [f32]) {
373        // tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)
374        for (i, &x) in a.iter().enumerate() {
375            result[i] = x.tanh();
376        }
377    }
378
379    // SAFETY: This function is safe because:
380    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
381    // 2. No raw pointer arithmetic is performed
382    // 3. Marked unsafe only to match VectorBackend trait interface
383    unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
384        for (i, &val) in a.iter().enumerate() {
385            result[i] = val.sqrt();
386        }
387    }
388
389    // SAFETY: This function is safe because:
390    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
391    // 2. No raw pointer arithmetic is performed
392    // 3. Marked unsafe only to match VectorBackend trait interface
393    unsafe fn recip(a: &[f32], result: &mut [f32]) {
394        for (i, &val) in a.iter().enumerate() {
395            result[i] = val.recip();
396        }
397    }
398
399    // SAFETY: This function is safe because:
400    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
401    // 2. No raw pointer arithmetic is performed
402    // 3. Marked unsafe only to match VectorBackend trait interface
403    unsafe fn ln(a: &[f32], result: &mut [f32]) {
404        for (i, &val) in a.iter().enumerate() {
405            result[i] = val.ln();
406        }
407    }
408
409    // SAFETY: This function is safe because:
410    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
411    // 2. No raw pointer arithmetic is performed
412    // 3. Marked unsafe only to match VectorBackend trait interface
413    unsafe fn log2(a: &[f32], result: &mut [f32]) {
414        for (i, &val) in a.iter().enumerate() {
415            result[i] = val.log2();
416        }
417    }
418
419    // SAFETY: This function is safe because:
420    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
421    // 2. No raw pointer arithmetic is performed
422    // 3. Marked unsafe only to match VectorBackend trait interface
423    unsafe fn log10(a: &[f32], result: &mut [f32]) {
424        for (i, &val) in a.iter().enumerate() {
425            result[i] = val.log10();
426        }
427    }
428
429    // SAFETY: This function is safe because:
430    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
431    // 2. No raw pointer arithmetic is performed
432    // 3. Marked unsafe only to match VectorBackend trait interface
433    unsafe fn sin(a: &[f32], result: &mut [f32]) {
434        for (i, &val) in a.iter().enumerate() {
435            result[i] = val.sin();
436        }
437    }
438
439    // SAFETY: This function is safe because:
440    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
441    // 2. No raw pointer arithmetic is performed
442    // 3. Marked unsafe only to match VectorBackend trait interface
443    unsafe fn cos(a: &[f32], result: &mut [f32]) {
444        for (i, &val) in a.iter().enumerate() {
445            result[i] = val.cos();
446        }
447    }
448
449    // SAFETY: This function is safe because:
450    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
451    // 2. No raw pointer arithmetic is performed
452    // 3. Marked unsafe only to match VectorBackend trait interface
453    unsafe fn tan(a: &[f32], result: &mut [f32]) {
454        for (i, &val) in a.iter().enumerate() {
455            result[i] = val.tan();
456        }
457    }
458
459    // SAFETY: This function is safe because:
460    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
461    // 2. No raw pointer arithmetic is performed
462    // 3. Marked unsafe only to match VectorBackend trait interface
463    unsafe fn floor(a: &[f32], result: &mut [f32]) {
464        for (i, &val) in a.iter().enumerate() {
465            result[i] = val.floor();
466        }
467    }
468
469    // SAFETY: This function is safe because:
470    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
471    // 2. No raw pointer arithmetic is performed
472    // 3. Marked unsafe only to match VectorBackend trait interface
473    unsafe fn ceil(a: &[f32], result: &mut [f32]) {
474        for (i, &val) in a.iter().enumerate() {
475            result[i] = val.ceil();
476        }
477    }
478
479    // SAFETY: This function is safe because:
480    // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
481    // 2. No raw pointer arithmetic is performed
482    // 3. Marked unsafe only to match VectorBackend trait interface
483    unsafe fn round(a: &[f32], result: &mut [f32]) {
484        for (i, &val) in a.iter().enumerate() {
485            result[i] = val.round();
486        }
487    }
488}
489
490#[cfg(test)]
491mod tests;