trueno/backends/scalar/mod.rs
1//! Scalar (non-SIMD) backend implementation
2//!
3//! This is the portable baseline implementation that works on all platforms.
4//! It uses simple loops without any SIMD instructions.
5//!
6//! # Performance
7//!
//! This backend serves as the correctness reference but provides no SIMD acceleration.
//! It is expected to be 8-32x slower than the SIMD backends on operations with 1K+ elements.
10
11use super::VectorBackend;
12
/// Scalar backend (portable, no SIMD).
///
/// A zero-sized marker type: it holds no state and exists only as the type on
/// which the scalar `VectorBackend` implementation is defined.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScalarBackend;
15
16impl VectorBackend for ScalarBackend {
/// Element-wise sum: writes `a[i] + b[i]` to `result[i]` for every index of `a`.
///
/// Elements of `b` and `result` beyond `a.len()` are ignored and left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Lengths are validated before
/// anything is written.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn add(a: &[f32], b: &[f32], result: &mut [f32]) {
    // One up-front slice per operand replaces three bounds checks per element,
    // leaving a check-free loop body the compiler is free to vectorize.
    let len = a.len();
    let (b, result) = (&b[..len], &mut result[..len]);
    for ((out, &x), &y) in result.iter_mut().zip(a).zip(b) {
        *out = x + y;
    }
}
26
/// Element-wise difference: writes `a[i] - b[i]` to `result[i]` for every index of `a`.
///
/// Elements of `b` and `result` beyond `a.len()` are ignored and left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Lengths are validated before
/// anything is written.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn sub(a: &[f32], b: &[f32], result: &mut [f32]) {
    // Slice once so the loop itself carries no per-element bounds checks.
    let len = a.len();
    let (b, result) = (&b[..len], &mut result[..len]);
    for ((out, &minuend), &subtrahend) in result.iter_mut().zip(a).zip(b) {
        *out = minuend - subtrahend;
    }
}
36
/// Element-wise product: writes `a[i] * b[i]` to `result[i]` for every index of `a`.
///
/// Elements of `b` and `result` beyond `a.len()` are ignored and left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Lengths are validated before
/// anything is written.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn mul(a: &[f32], b: &[f32], result: &mut [f32]) {
    // Slice once so the loop itself carries no per-element bounds checks.
    let len = a.len();
    let (b, result) = (&b[..len], &mut result[..len]);
    for ((out, &lhs), &rhs) in result.iter_mut().zip(a).zip(b) {
        *out = lhs * rhs;
    }
}
46
/// Element-wise quotient: writes `a[i] / b[i]` to `result[i]` for every index of `a`.
///
/// Division follows IEEE 754 `f32` rules: a zero divisor yields an infinity or
/// NaN rather than a panic. Elements of `b` and `result` beyond `a.len()` are
/// ignored and left untouched.
///
/// # Panics
///
/// Panics if `b` or `result` is shorter than `a`. Lengths are validated before
/// anything is written.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn div(a: &[f32], b: &[f32], result: &mut [f32]) {
    // Slice once so the loop itself carries no per-element bounds checks.
    let len = a.len();
    let (b, result) = (&b[..len], &mut result[..len]);
    for ((out, &numerator), &denominator) in result.iter_mut().zip(a).zip(b) {
        *out = numerator / denominator;
    }
}
56
57 // SAFETY: This function is safe because:
58 // 1. All slice accesses are bounds-checked by Rust iterator/indexing
59 // 2. No raw pointer arithmetic is performed
60 // 3. Marked unsafe only to match VectorBackend trait interface
61 //
62 // OPTIMIZATION: 4× unrolling with mul_add for better ILP and auto-vectorization.
63 // This follows the cuda-tile pattern for improved throughput (spec: cuda-tile-behavior.md).
64 // Using f32::mul_add provides FMA semantics where available, improving accuracy.
65 #[inline(always)]
66 // SAFETY: caller ensures preconditions are met for this unsafe function
67 unsafe fn dot(a: &[f32], b: &[f32]) -> f32 {
68 contract_pre_dot_product!();
69 let len = a.len();
70 let chunks = len / 4;
71
72 // 4 independent accumulators for better ILP (cuda-tile inspired optimization)
73 let mut acc0 = 0.0f32;
74 let mut acc1 = 0.0f32;
75 let mut acc2 = 0.0f32;
76 let mut acc3 = 0.0f32;
77
78 // Process 4 elements at a time with independent accumulation chains
79 for i in 0..chunks {
80 let base = i * 4;
81 acc0 = a[base].mul_add(b[base], acc0);
82 acc1 = a[base + 1].mul_add(b[base + 1], acc1);
83 acc2 = a[base + 2].mul_add(b[base + 2], acc2);
84 acc3 = a[base + 3].mul_add(b[base + 3], acc3);
85 }
86
87 // Combine all 4 accumulators
88 let mut sum = (acc0 + acc1) + (acc2 + acc3);
89
90 // Handle remainder
91 for i in (chunks * 4)..len {
92 sum = a[i].mul_add(b[i], sum);
93 }
94
95 contract_post_dot_product_parity!(sum);
96 sum
97 }
98
/// Adds the elements of `a` from first to last, starting from `0.0`.
///
/// This is plain, uncompensated accumulation. An empty slice yields `0.0`.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn sum(a: &[f32]) -> f32 {
    // Explicit fold keeps the `0.0` seed and left-to-right order spelled out.
    a.iter().fold(0.0_f32, |running, &x| running + x)
}
110
/// Returns the largest element of `a`.
///
/// The first element seeds the search and is replaced only by a strictly
/// greater value. Consequently a NaN in the first slot is returned as-is,
/// while a NaN anywhere else is never selected.
///
/// # Panics
///
/// Panics if `a` is empty.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface. Callers must pass a non-empty slice to
/// avoid the panic above.
unsafe fn max(a: &[f32]) -> f32 {
    let seed = a[0];
    // `f32::max` is deliberately not used: it treats NaN differently.
    a[1..]
        .iter()
        .fold(seed, |best, &x| if x > best { x } else { best })
}
124
/// Returns the smallest element of `a`.
///
/// The first element seeds the search and is replaced only by a strictly
/// smaller value. Consequently a NaN in the first slot is returned as-is,
/// while a NaN anywhere else is never selected.
///
/// # Panics
///
/// Panics if `a` is empty.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface. Callers must pass a non-empty slice to
/// avoid the panic above.
unsafe fn min(a: &[f32]) -> f32 {
    let seed = a[0];
    // `f32::min` is deliberately not used: it treats NaN differently.
    a.iter()
        .skip(1)
        .copied()
        .fold(seed, |lowest, x| if x < lowest { x } else { lowest })
}
138
/// Returns the index of the largest element of `a`.
///
/// Ties resolve to the earliest index because a candidate must be strictly
/// greater to replace the current best. If `a[0]` is NaN the result is `0`;
/// a NaN at any other position is never selected.
///
/// # Panics
///
/// Panics if `a` is empty.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface. Callers must pass a non-empty slice to
/// avoid the panic above.
unsafe fn argmax(a: &[f32]) -> usize {
    let seed = (0_usize, a[0]);
    // Index 0 is the seed itself and can never beat itself, so start at 1.
    let (winner, _) = a
        .iter()
        .enumerate()
        .skip(1)
        .fold(seed, |best, (idx, &x)| if x > best.1 { (idx, x) } else { best });
    winner
}
154
/// Returns the index of the smallest element of `a`.
///
/// Ties resolve to the earliest index because a candidate must be strictly
/// smaller to replace the current best. If `a[0]` is NaN the result is `0`;
/// a NaN at any other position is never selected.
///
/// # Panics
///
/// Panics if `a` is empty.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface. Callers must pass a non-empty slice to
/// avoid the panic above.
unsafe fn argmin(a: &[f32]) -> usize {
    let mut best = (0_usize, a[0]);
    // Index 0 is the seed itself and can never beat itself, so start at 1.
    a.iter().enumerate().skip(1).for_each(|(idx, &x)| {
        if x < best.1 {
            best = (idx, x);
        }
    });
    best.0
}
170
/// Sums `a` with Kahan (compensated) summation.
///
/// Alongside the running total, a compensation term tracks the low-order bits
/// lost by each addition and feeds them back into the next one. An empty
/// slice yields `0.0`.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn sum_kahan(a: &[f32]) -> f32 {
    // State is (running total, compensation for lost low-order bits).
    let (total, _compensation) =
        a.iter()
            .fold((0.0_f32, 0.0_f32), |(running, carry), &value| {
                let corrected = value - carry; // apply the previous correction
                let next = running + corrected; // low bits of `corrected` may be lost here
                let lost = (next - running) - corrected; // recover what was lost
                (next, lost)
            });
    total
}
188
/// Euclidean (L2) norm: the square root of the sum of squared elements.
///
/// Squares are accumulated directly in `f32`, left to right, with no
/// rescaling. An empty slice yields `0.0`.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn norm_l2(a: &[f32]) -> f32 {
    match a {
        [] => 0.0,
        values => values
            .iter()
            .fold(0.0_f32, |squares, &x| squares + x * x)
            .sqrt(),
    }
}
204
/// Manhattan (L1) norm: the sum of the absolute values of the elements.
///
/// Accumulates left to right in `f32`. An empty slice yields `0.0`.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn norm_l1(a: &[f32]) -> f32 {
    match a {
        [] => 0.0,
        values => values
            .iter()
            .fold(0.0_f32, |total, x| total + x.abs()),
    }
}
220
/// Infinity (L∞) norm: the largest absolute value among the elements.
///
/// The search starts from `0.0` and only a strictly greater magnitude replaces
/// the current peak, so NaN elements are skipped and an empty slice yields `0.0`.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn norm_linf(a: &[f32]) -> f32 {
    a.iter().map(|x| x.abs()).fold(0.0_f32, |peak, magnitude| {
        if magnitude > peak {
            magnitude
        } else {
            peak
        }
    })
}
239
/// Multiplies every element of `a` by `scalar`, writing `a[i] * scalar` to `result[i]`.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn scale(a: &[f32], scalar: f32, result: &mut [f32]) {
    a.iter()
        .enumerate()
        .for_each(|(idx, &x)| result[idx] = x * scalar);
}
249
/// Writes the absolute value of each element of `a` to the matching slot of `result`.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn abs(a: &[f32], result: &mut [f32]) {
    for (idx, magnitude) in a.iter().map(|x| x.abs()).enumerate() {
        result[idx] = magnitude;
    }
}
259
/// Limits each element of `a` to `[min_val, max_val]`, writing to `result`.
///
/// Each value is first raised to at least `min_val` with `f32::max`, then
/// lowered to at most `max_val` with `f32::min`. Because those functions
/// ignore a NaN operand, a NaN input comes out as `min_val` (subject to the
/// `max_val` cap). The bounds are not checked for `min_val <= max_val`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn clamp(a: &[f32], min_val: f32, max_val: f32, result: &mut [f32]) {
    for (idx, x) in a.iter().copied().enumerate() {
        let raised = f32::max(x, min_val);
        result[idx] = f32::min(raised, max_val);
    }
}
269
/// Linear interpolation: writes `a[i] + t * (b[i] - a[i])` to `result[i]`.
///
/// Only the first `min(a.len(), b.len())` pairs are processed, so a length
/// mismatch between `a` and `b` truncates silently. `t` is used as given and
/// is not clamped to `[0, 1]`.
///
/// # Panics
///
/// Panics if `result` is shorter than the number of processed pairs; elements
/// before the offending index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn lerp(a: &[f32], b: &[f32], t: f32, result: &mut [f32]) {
    let endpoints = a.iter().zip(b);
    for (idx, (&start, &end)) in endpoints.enumerate() {
        let span = end - start;
        result[idx] = start + t * span;
    }
}
280
/// Multiply-add: writes `a[i] * b[i] + c[i]` to `result[i]`.
///
/// The product and the sum are separate `f32` operations (two roundings);
/// `f32::mul_add` is not used here. Only as many triples as the shortest of
/// `a`, `b` and `c` are processed.
///
/// # Panics
///
/// Panics if `result` is shorter than the number of processed triples;
/// elements before the offending index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn fma(a: &[f32], b: &[f32], c: &[f32], result: &mut [f32]) {
    a.iter()
        .zip(b)
        .zip(c)
        .map(|((&x, &y), &addend)| x * y + addend)
        .enumerate()
        .for_each(|(idx, value)| result[idx] = value);
}
291
/// Rectified linear unit: writes `a[i]` to `result[i]` when it is strictly
/// positive, and `0.0` otherwise.
///
/// "Otherwise" covers negatives, both zeros and NaN, since `x > 0.0` is false
/// for all of them.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn relu(a: &[f32], result: &mut [f32]) {
    a.iter()
        .map(|&x| if x > 0.0 { x } else { 0.0 })
        .enumerate()
        .for_each(|(idx, rectified)| result[idx] = rectified);
}
301
/// Writes `e^x` (via `f32::exp`) for each element of `a` to the matching slot of `result`.
///
/// Elements of `result` beyond `a.len()` are left untouched.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn exp(a: &[f32], result: &mut [f32]) {
    for (idx, x) in a.iter().copied().enumerate() {
        result[idx] = f32::exp(x);
    }
}
311
312 // SAFETY: This function is safe because:
313 // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
314 // 2. Clamping prevents exp() overflow
315 // 3. Marked unsafe only to match VectorBackend trait interface
316 unsafe fn sigmoid(a: &[f32], result: &mut [f32]) {
317 contract_pre_sigmoid!(a);
318 for (i, &val) in a.iter().enumerate() {
319 // Handle extreme values for numerical stability
320 result[i] = if val < -50.0 {
321 0.0 // exp(-x) would overflow, but sigmoid approaches 0
322 } else if val > 50.0 {
323 1.0 // exp(-x) underflows to 0, sigmoid approaches 1
324 } else {
325 1.0 / (1.0 + (-val).exp())
326 };
327 }
328 }
329
330 // SAFETY: This function is safe because:
331 // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
332 // 2. No raw pointer arithmetic is performed
333 // 3. Marked unsafe only to match VectorBackend trait interface
334 unsafe fn gelu(a: &[f32], result: &mut [f32]) {
335 // GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))
336 contract_pre_gelu!(a);
337 const SQRT_2_OVER_PI: f32 = 0.797_884_6;
338 const COEFF: f32 = 0.044715;
339
340 for (i, &x) in a.iter().enumerate() {
341 let x3 = x * x * x;
342 let inner = SQRT_2_OVER_PI * (x + COEFF * x3);
343 result[i] = 0.5 * x * (1.0 + inner.tanh());
344 }
345 contract_post_gelu!(result);
346 }
347
348 // SAFETY: This function is safe because:
349 // 1. All slice accesses are bounds-checked by Rust iterator/enumerate
350 // 2. Clamping prevents exp() overflow
351 // 3. Marked unsafe only to match VectorBackend trait interface
352 unsafe fn swish(a: &[f32], result: &mut [f32]) {
353 contract_pre_silu!();
354 // Swish: x * sigmoid(x) = x / (1 + exp(-x))
355 for (i, &x) in a.iter().enumerate() {
356 if x < -50.0 {
357 result[i] = 0.0; // x * 0 = 0
358 } else if x > 50.0 {
359 result[i] = x; // x * 1 = x
360 } else {
361 let sigmoid = 1.0 / (1.0 + (-x).exp());
362 result[i] = x * sigmoid;
363 }
364 }
365 contract_post_silu!(result);
366 }
367
/// Writes the hyperbolic tangent of each element of `a` to the matching slot of `result`.
///
/// Mathematically `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)`; the computation
/// itself is delegated to `f32::tanh`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn tanh(a: &[f32], result: &mut [f32]) {
    a.iter()
        .enumerate()
        .for_each(|(idx, x)| result[idx] = x.tanh());
}
378
/// Writes the square root of each element of `a` to the matching slot of `result`.
///
/// Follows `f32::sqrt`: negative inputs produce NaN.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn sqrt(a: &[f32], result: &mut [f32]) {
    for (idx, root) in a.iter().copied().map(f32::sqrt).enumerate() {
        result[idx] = root;
    }
}
388
/// Writes the reciprocal `1 / x` of each element of `a` to the matching slot of `result`.
///
/// Follows `f32::recip`: a zero input produces an infinity rather than a panic.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn recip(a: &[f32], result: &mut [f32]) {
    a.iter()
        .copied()
        .enumerate()
        .for_each(|(idx, x)| result[idx] = f32::recip(x));
}
398
/// Writes the natural logarithm of each element of `a` to the matching slot of `result`.
///
/// Follows `f32::ln`: zero gives negative infinity and negative inputs give NaN.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn ln(a: &[f32], result: &mut [f32]) {
    for (idx, log) in a.iter().map(|x| x.ln()).enumerate() {
        result[idx] = log;
    }
}
408
/// Writes the base-2 logarithm (`f32::log2`) of each element of `a` to the
/// matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn log2(a: &[f32], result: &mut [f32]) {
    a.iter()
        .enumerate()
        .for_each(|(idx, x)| result[idx] = x.log2());
}
418
/// Writes the base-10 logarithm (`f32::log10`) of each element of `a` to the
/// matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn log10(a: &[f32], result: &mut [f32]) {
    for (idx, x) in a.iter().copied().enumerate() {
        result[idx] = f32::log10(x);
    }
}
428
/// Writes the sine (`f32::sin`, argument in radians) of each element of `a`
/// to the matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn sin(a: &[f32], result: &mut [f32]) {
    for (idx, sine) in a.iter().copied().map(f32::sin).enumerate() {
        result[idx] = sine;
    }
}
438
/// Writes the cosine (`f32::cos`, argument in radians) of each element of `a`
/// to the matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn cos(a: &[f32], result: &mut [f32]) {
    a.iter()
        .enumerate()
        .for_each(|(idx, x)| result[idx] = x.cos());
}
448
/// Writes the tangent (`f32::tan`, argument in radians) of each element of `a`
/// to the matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn tan(a: &[f32], result: &mut [f32]) {
    for (idx, x) in a.iter().copied().enumerate() {
        result[idx] = f32::tan(x);
    }
}
458
/// Rounds each element of `a` down to the nearest integer value (`f32::floor`),
/// writing to the matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn floor(a: &[f32], result: &mut [f32]) {
    for (idx, lowered) in a.iter().copied().map(f32::floor).enumerate() {
        result[idx] = lowered;
    }
}
468
/// Rounds each element of `a` up to the nearest integer value (`f32::ceil`),
/// writing to the matching slot of `result`.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn ceil(a: &[f32], result: &mut [f32]) {
    a.iter()
        .enumerate()
        .for_each(|(idx, x)| result[idx] = x.ceil());
}
478
/// Rounds each element of `a` to the nearest integer value, writing to the
/// matching slot of `result`.
///
/// Uses `f32::round`, which rounds half-way cases away from zero.
///
/// # Panics
///
/// Panics if `result` is shorter than `a`; elements before the offending
/// index have already been written by then.
///
/// # Safety
///
/// Performs no unsafe operations; it is `unsafe` only to match the
/// `VectorBackend` trait interface.
unsafe fn round(a: &[f32], result: &mut [f32]) {
    for (idx, x) in a.iter().copied().enumerate() {
        result[idx] = f32::round(x);
    }
}
488}
489
490#[cfg(test)]
491mod tests;