ringkernel_cpu/
simd.rs

1//! SIMD-accelerated operations for CPU backend.
2//!
3//! This module provides high-performance implementations of common GPU-like
4//! operations using SIMD (Single Instruction, Multiple Data) instructions.
5//!
6//! # Operations
7//!
8//! - **Vector Operations**: SAXPY, dot product, element-wise operations
9//! - **Reductions**: Sum, min, max, mean
10//! - **Stencil Operations**: 2D/3D Laplacian for FDTD simulations
11//! - **Array Operations**: Fill, copy, compare
12//!
13//! # Example
14//!
15//! ```
16//! use ringkernel_cpu::simd::SimdOps;
17//!
18//! // SAXPY: y = a * x + y
19//! let x = vec![1.0f32; 1024];
20//! let mut y = vec![2.0f32; 1024];
21//! SimdOps::saxpy(2.0, &x, &mut y);
22//!
23//! // Reduction
24//! let sum = SimdOps::sum_f32(&y);
25//! ```
26
27use rayon::prelude::*;
28use wide::{f32x8, f64x4, i32x8};
29
30/// SIMD-accelerated operations.
31pub struct SimdOps;
32
33// ============================================================================
34// VECTOR OPERATIONS
35// ============================================================================
36
37impl SimdOps {
38    /// SAXPY: y = a * x + y (f32)
39    ///
40    /// Single-precision A*X Plus Y operation, fundamental to linear algebra.
41    #[inline]
42    pub fn saxpy(a: f32, x: &[f32], y: &mut [f32]) {
43        let n = x.len().min(y.len());
44        let a_vec = f32x8::splat(a);
45
46        // Process 8 elements at a time
47        let chunks = n / 8;
48        let remainder = n % 8;
49
50        for i in 0..chunks {
51            let offset = i * 8;
52            let x_vec = f32x8::new([
53                x[offset],
54                x[offset + 1],
55                x[offset + 2],
56                x[offset + 3],
57                x[offset + 4],
58                x[offset + 5],
59                x[offset + 6],
60                x[offset + 7],
61            ]);
62            let y_vec = f32x8::new([
63                y[offset],
64                y[offset + 1],
65                y[offset + 2],
66                y[offset + 3],
67                y[offset + 4],
68                y[offset + 5],
69                y[offset + 6],
70                y[offset + 7],
71            ]);
72
73            let result = a_vec * x_vec + y_vec;
74            let arr: [f32; 8] = result.into();
75            y[offset..offset + 8].copy_from_slice(&arr);
76        }
77
78        // Handle remainder
79        let tail_start = chunks * 8;
80        for i in 0..remainder {
81            y[tail_start + i] += a * x[tail_start + i];
82        }
83    }
84
85    /// DAXPY: y = a * x + y (f64)
86    ///
87    /// Double-precision A*X Plus Y operation.
88    #[inline]
89    pub fn daxpy(a: f64, x: &[f64], y: &mut [f64]) {
90        let n = x.len().min(y.len());
91        let a_vec = f64x4::splat(a);
92
93        // Process 4 elements at a time
94        let chunks = n / 4;
95        let remainder = n % 4;
96
97        for i in 0..chunks {
98            let offset = i * 4;
99            let x_vec = f64x4::new([x[offset], x[offset + 1], x[offset + 2], x[offset + 3]]);
100            let y_vec = f64x4::new([y[offset], y[offset + 1], y[offset + 2], y[offset + 3]]);
101
102            let result = a_vec * x_vec + y_vec;
103            let arr: [f64; 4] = result.into();
104            y[offset..offset + 4].copy_from_slice(&arr);
105        }
106
107        // Handle remainder
108        let tail_start = chunks * 4;
109        for i in 0..remainder {
110            y[tail_start + i] += a * x[tail_start + i];
111        }
112    }
113
114    /// Element-wise addition: z = x + y
115    #[inline]
116    pub fn add_f32(x: &[f32], y: &[f32], z: &mut [f32]) {
117        let n = x.len().min(y.len()).min(z.len());
118        let chunks = n / 8;
119        let remainder = n % 8;
120
121        for i in 0..chunks {
122            let offset = i * 8;
123            let x_vec = f32x8::new([
124                x[offset],
125                x[offset + 1],
126                x[offset + 2],
127                x[offset + 3],
128                x[offset + 4],
129                x[offset + 5],
130                x[offset + 6],
131                x[offset + 7],
132            ]);
133            let y_vec = f32x8::new([
134                y[offset],
135                y[offset + 1],
136                y[offset + 2],
137                y[offset + 3],
138                y[offset + 4],
139                y[offset + 5],
140                y[offset + 6],
141                y[offset + 7],
142            ]);
143
144            let result = x_vec + y_vec;
145            let arr: [f32; 8] = result.into();
146            z[offset..offset + 8].copy_from_slice(&arr);
147        }
148
149        let tail_start = chunks * 8;
150        for i in 0..remainder {
151            z[tail_start + i] = x[tail_start + i] + y[tail_start + i];
152        }
153    }
154
155    /// Element-wise subtraction: z = x - y
156    #[inline]
157    pub fn sub_f32(x: &[f32], y: &[f32], z: &mut [f32]) {
158        let n = x.len().min(y.len()).min(z.len());
159        let chunks = n / 8;
160        let remainder = n % 8;
161
162        for i in 0..chunks {
163            let offset = i * 8;
164            let x_vec = f32x8::new([
165                x[offset],
166                x[offset + 1],
167                x[offset + 2],
168                x[offset + 3],
169                x[offset + 4],
170                x[offset + 5],
171                x[offset + 6],
172                x[offset + 7],
173            ]);
174            let y_vec = f32x8::new([
175                y[offset],
176                y[offset + 1],
177                y[offset + 2],
178                y[offset + 3],
179                y[offset + 4],
180                y[offset + 5],
181                y[offset + 6],
182                y[offset + 7],
183            ]);
184
185            let result = x_vec - y_vec;
186            let arr: [f32; 8] = result.into();
187            z[offset..offset + 8].copy_from_slice(&arr);
188        }
189
190        let tail_start = chunks * 8;
191        for i in 0..remainder {
192            z[tail_start + i] = x[tail_start + i] - y[tail_start + i];
193        }
194    }
195
196    /// Element-wise multiplication: z = x * y
197    #[inline]
198    pub fn mul_f32(x: &[f32], y: &[f32], z: &mut [f32]) {
199        let n = x.len().min(y.len()).min(z.len());
200        let chunks = n / 8;
201        let remainder = n % 8;
202
203        for i in 0..chunks {
204            let offset = i * 8;
205            let x_vec = f32x8::new([
206                x[offset],
207                x[offset + 1],
208                x[offset + 2],
209                x[offset + 3],
210                x[offset + 4],
211                x[offset + 5],
212                x[offset + 6],
213                x[offset + 7],
214            ]);
215            let y_vec = f32x8::new([
216                y[offset],
217                y[offset + 1],
218                y[offset + 2],
219                y[offset + 3],
220                y[offset + 4],
221                y[offset + 5],
222                y[offset + 6],
223                y[offset + 7],
224            ]);
225
226            let result = x_vec * y_vec;
227            let arr: [f32; 8] = result.into();
228            z[offset..offset + 8].copy_from_slice(&arr);
229        }
230
231        let tail_start = chunks * 8;
232        for i in 0..remainder {
233            z[tail_start + i] = x[tail_start + i] * y[tail_start + i];
234        }
235    }
236
237    /// Dot product: sum(x * y)
238    #[inline]
239    pub fn dot_f32(x: &[f32], y: &[f32]) -> f32 {
240        let n = x.len().min(y.len());
241        let chunks = n / 8;
242        let remainder = n % 8;
243
244        let mut acc = f32x8::splat(0.0);
245
246        for i in 0..chunks {
247            let offset = i * 8;
248            let x_vec = f32x8::new([
249                x[offset],
250                x[offset + 1],
251                x[offset + 2],
252                x[offset + 3],
253                x[offset + 4],
254                x[offset + 5],
255                x[offset + 6],
256                x[offset + 7],
257            ]);
258            let y_vec = f32x8::new([
259                y[offset],
260                y[offset + 1],
261                y[offset + 2],
262                y[offset + 3],
263                y[offset + 4],
264                y[offset + 5],
265                y[offset + 6],
266                y[offset + 7],
267            ]);
268
269            acc += x_vec * y_vec;
270        }
271
272        // Horizontal sum
273        let arr: [f32; 8] = acc.into();
274        let mut sum: f32 = arr.iter().sum();
275
276        // Handle remainder
277        let tail_start = chunks * 8;
278        for i in 0..remainder {
279            sum += x[tail_start + i] * y[tail_start + i];
280        }
281
282        sum
283    }
284
285    /// Scale vector: x *= a
286    #[inline]
287    pub fn scale_f32(a: f32, x: &mut [f32]) {
288        let n = x.len();
289        let a_vec = f32x8::splat(a);
290        let chunks = n / 8;
291        let remainder = n % 8;
292
293        for i in 0..chunks {
294            let offset = i * 8;
295            let x_vec = f32x8::new([
296                x[offset],
297                x[offset + 1],
298                x[offset + 2],
299                x[offset + 3],
300                x[offset + 4],
301                x[offset + 5],
302                x[offset + 6],
303                x[offset + 7],
304            ]);
305
306            let result = a_vec * x_vec;
307            let arr: [f32; 8] = result.into();
308            x[offset..offset + 8].copy_from_slice(&arr);
309        }
310
311        let tail_start = chunks * 8;
312        for i in 0..remainder {
313            x[tail_start + i] *= a;
314        }
315    }
316}
317
318// ============================================================================
319// REDUCTION OPERATIONS
320// ============================================================================
321
322impl SimdOps {
323    /// Sum of f32 array using SIMD.
324    #[inline]
325    pub fn sum_f32(x: &[f32]) -> f32 {
326        let n = x.len();
327        let chunks = n / 8;
328        let remainder = n % 8;
329
330        let mut acc = f32x8::splat(0.0);
331
332        for i in 0..chunks {
333            let offset = i * 8;
334            let x_vec = f32x8::new([
335                x[offset],
336                x[offset + 1],
337                x[offset + 2],
338                x[offset + 3],
339                x[offset + 4],
340                x[offset + 5],
341                x[offset + 6],
342                x[offset + 7],
343            ]);
344            acc += x_vec;
345        }
346
347        let arr: [f32; 8] = acc.into();
348        let mut sum: f32 = arr.iter().sum();
349
350        let tail_start = chunks * 8;
351        for i in 0..remainder {
352            sum += x[tail_start + i];
353        }
354
355        sum
356    }
357
358    /// Sum of f64 array using SIMD.
359    #[inline]
360    pub fn sum_f64(x: &[f64]) -> f64 {
361        let n = x.len();
362        let chunks = n / 4;
363        let remainder = n % 4;
364
365        let mut acc = f64x4::splat(0.0);
366
367        for i in 0..chunks {
368            let offset = i * 4;
369            let x_vec = f64x4::new([x[offset], x[offset + 1], x[offset + 2], x[offset + 3]]);
370            acc += x_vec;
371        }
372
373        let arr: [f64; 4] = acc.into();
374        let mut sum: f64 = arr.iter().sum();
375
376        let tail_start = chunks * 4;
377        for i in 0..remainder {
378            sum += x[tail_start + i];
379        }
380
381        sum
382    }
383
384    /// Maximum of f32 array.
385    #[inline]
386    pub fn max_f32(x: &[f32]) -> f32 {
387        if x.is_empty() {
388            return f32::NEG_INFINITY;
389        }
390
391        let n = x.len();
392        let chunks = n / 8;
393        let remainder = n % 8;
394
395        let mut max_vec = f32x8::splat(f32::NEG_INFINITY);
396
397        for i in 0..chunks {
398            let offset = i * 8;
399            let x_vec = f32x8::new([
400                x[offset],
401                x[offset + 1],
402                x[offset + 2],
403                x[offset + 3],
404                x[offset + 4],
405                x[offset + 5],
406                x[offset + 6],
407                x[offset + 7],
408            ]);
409            max_vec = max_vec.max(x_vec);
410        }
411
412        let arr: [f32; 8] = max_vec.into();
413        let mut max_val = arr.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
414
415        let tail_start = chunks * 8;
416        for i in 0..remainder {
417            max_val = max_val.max(x[tail_start + i]);
418        }
419
420        max_val
421    }
422
423    /// Minimum of f32 array.
424    #[inline]
425    pub fn min_f32(x: &[f32]) -> f32 {
426        if x.is_empty() {
427            return f32::INFINITY;
428        }
429
430        let n = x.len();
431        let chunks = n / 8;
432        let remainder = n % 8;
433
434        let mut min_vec = f32x8::splat(f32::INFINITY);
435
436        for i in 0..chunks {
437            let offset = i * 8;
438            let x_vec = f32x8::new([
439                x[offset],
440                x[offset + 1],
441                x[offset + 2],
442                x[offset + 3],
443                x[offset + 4],
444                x[offset + 5],
445                x[offset + 6],
446                x[offset + 7],
447            ]);
448            min_vec = min_vec.min(x_vec);
449        }
450
451        let arr: [f32; 8] = min_vec.into();
452        let mut min_val = arr.iter().cloned().fold(f32::INFINITY, f32::min);
453
454        let tail_start = chunks * 8;
455        for i in 0..remainder {
456            min_val = min_val.min(x[tail_start + i]);
457        }
458
459        min_val
460    }
461
462    /// Mean of f32 array.
463    #[inline]
464    pub fn mean_f32(x: &[f32]) -> f32 {
465        if x.is_empty() {
466            return 0.0;
467        }
468        Self::sum_f32(x) / x.len() as f32
469    }
470}
471
472// ============================================================================
473// STENCIL OPERATIONS
474// ============================================================================
475
476impl SimdOps {
477    /// 2D Laplacian stencil (5-point).
478    ///
479    /// Computes: laplacian[i,j] = p[i-1,j] + p[i+1,j] + p[i,j-1] + p[i,j+1] - 4*p[i,j]
480    ///
481    /// This is the core operation for FDTD wave simulations.
482    #[inline]
483    pub fn laplacian_2d_f32(p: &[f32], laplacian: &mut [f32], width: usize, height: usize) {
484        let four = f32x8::splat(4.0);
485
486        // Skip boundary cells (halo of 1)
487        for y in 1..height - 1 {
488            let row_start = y * width;
489            let row_above = (y - 1) * width;
490            let row_below = (y + 1) * width;
491
492            // Process 8 cells at a time
493            let inner_width = width - 2;
494            let chunks = inner_width / 8;
495            let remainder = inner_width % 8;
496
497            for chunk in 0..chunks {
498                let x = 1 + chunk * 8;
499                let idx = row_start + x;
500
501                // Center
502                let center = f32x8::new([
503                    p[idx],
504                    p[idx + 1],
505                    p[idx + 2],
506                    p[idx + 3],
507                    p[idx + 4],
508                    p[idx + 5],
509                    p[idx + 6],
510                    p[idx + 7],
511                ]);
512
513                // North (y - 1)
514                let north_idx = row_above + x;
515                let north = f32x8::new([
516                    p[north_idx],
517                    p[north_idx + 1],
518                    p[north_idx + 2],
519                    p[north_idx + 3],
520                    p[north_idx + 4],
521                    p[north_idx + 5],
522                    p[north_idx + 6],
523                    p[north_idx + 7],
524                ]);
525
526                // South (y + 1)
527                let south_idx = row_below + x;
528                let south = f32x8::new([
529                    p[south_idx],
530                    p[south_idx + 1],
531                    p[south_idx + 2],
532                    p[south_idx + 3],
533                    p[south_idx + 4],
534                    p[south_idx + 5],
535                    p[south_idx + 6],
536                    p[south_idx + 7],
537                ]);
538
539                // West (x - 1)
540                let west = f32x8::new([
541                    p[idx - 1],
542                    p[idx],
543                    p[idx + 1],
544                    p[idx + 2],
545                    p[idx + 3],
546                    p[idx + 4],
547                    p[idx + 5],
548                    p[idx + 6],
549                ]);
550
551                // East (x + 1)
552                let east = f32x8::new([
553                    p[idx + 1],
554                    p[idx + 2],
555                    p[idx + 3],
556                    p[idx + 4],
557                    p[idx + 5],
558                    p[idx + 6],
559                    p[idx + 7],
560                    p[idx + 8],
561                ]);
562
563                // Laplacian = north + south + west + east - 4 * center
564                let result = north + south + west + east - four * center;
565                let arr: [f32; 8] = result.into();
566                laplacian[idx..idx + 8].copy_from_slice(&arr);
567            }
568
569            // Handle remainder
570            let tail_start = 1 + chunks * 8;
571            for i in 0..remainder {
572                let x = tail_start + i;
573                let idx = row_start + x;
574                laplacian[idx] =
575                    p[row_above + x] + p[row_below + x] + p[idx - 1] + p[idx + 1] - 4.0 * p[idx];
576            }
577        }
578    }
579
580    /// 2D FDTD wave equation step.
581    ///
582    /// Computes: p_next[i,j] = 2*p[i,j] - p_prev[i,j] + c2 * laplacian(p)[i,j]
583    ///
584    /// This is a complete wave simulation timestep.
585    #[inline]
586    pub fn fdtd_step_2d_f32(p: &[f32], p_prev: &mut [f32], c2: f32, width: usize, height: usize) {
587        let two = f32x8::splat(2.0);
588        let four = f32x8::splat(4.0);
589        let c2_vec = f32x8::splat(c2);
590
591        for y in 1..height - 1 {
592            let row_start = y * width;
593            let row_above = (y - 1) * width;
594            let row_below = (y + 1) * width;
595
596            let inner_width = width - 2;
597            let chunks = inner_width / 8;
598            let remainder = inner_width % 8;
599
600            for chunk in 0..chunks {
601                let x = 1 + chunk * 8;
602                let idx = row_start + x;
603
604                let center = f32x8::new([
605                    p[idx],
606                    p[idx + 1],
607                    p[idx + 2],
608                    p[idx + 3],
609                    p[idx + 4],
610                    p[idx + 5],
611                    p[idx + 6],
612                    p[idx + 7],
613                ]);
614
615                let prev = f32x8::new([
616                    p_prev[idx],
617                    p_prev[idx + 1],
618                    p_prev[idx + 2],
619                    p_prev[idx + 3],
620                    p_prev[idx + 4],
621                    p_prev[idx + 5],
622                    p_prev[idx + 6],
623                    p_prev[idx + 7],
624                ]);
625
626                let north_idx = row_above + x;
627                let north = f32x8::new([
628                    p[north_idx],
629                    p[north_idx + 1],
630                    p[north_idx + 2],
631                    p[north_idx + 3],
632                    p[north_idx + 4],
633                    p[north_idx + 5],
634                    p[north_idx + 6],
635                    p[north_idx + 7],
636                ]);
637
638                let south_idx = row_below + x;
639                let south = f32x8::new([
640                    p[south_idx],
641                    p[south_idx + 1],
642                    p[south_idx + 2],
643                    p[south_idx + 3],
644                    p[south_idx + 4],
645                    p[south_idx + 5],
646                    p[south_idx + 6],
647                    p[south_idx + 7],
648                ]);
649
650                let west = f32x8::new([
651                    p[idx - 1],
652                    p[idx],
653                    p[idx + 1],
654                    p[idx + 2],
655                    p[idx + 3],
656                    p[idx + 4],
657                    p[idx + 5],
658                    p[idx + 6],
659                ]);
660
661                let east = f32x8::new([
662                    p[idx + 1],
663                    p[idx + 2],
664                    p[idx + 3],
665                    p[idx + 4],
666                    p[idx + 5],
667                    p[idx + 6],
668                    p[idx + 7],
669                    p[idx + 8],
670                ]);
671
672                let laplacian = north + south + west + east - four * center;
673                let result = two * center - prev + c2_vec * laplacian;
674
675                let arr: [f32; 8] = result.into();
676                p_prev[idx..idx + 8].copy_from_slice(&arr);
677            }
678
679            let tail_start = 1 + chunks * 8;
680            for i in 0..remainder {
681                let x = tail_start + i;
682                let idx = row_start + x;
683                let laplacian =
684                    p[row_above + x] + p[row_below + x] + p[idx - 1] + p[idx + 1] - 4.0 * p[idx];
685                p_prev[idx] = 2.0 * p[idx] - p_prev[idx] + c2 * laplacian;
686            }
687        }
688    }
689
690    /// 3D Laplacian stencil (7-point).
691    ///
692    /// Computes the 3D discrete Laplacian for volumetric simulations.
693    #[inline]
694    pub fn laplacian_3d_f32(
695        p: &[f32],
696        laplacian: &mut [f32],
697        width: usize,
698        height: usize,
699        depth: usize,
700    ) {
701        let stride_y = width;
702        let stride_z = width * height;
703        let six = f32x8::splat(6.0);
704
705        for z in 1..depth - 1 {
706            for y in 1..height - 1 {
707                let row_start = z * stride_z + y * stride_y;
708                let inner_width = width - 2;
709                let chunks = inner_width / 8;
710                let remainder = inner_width % 8;
711
712                for chunk in 0..chunks {
713                    let x = 1 + chunk * 8;
714                    let idx = row_start + x;
715
716                    let center = f32x8::new([
717                        p[idx],
718                        p[idx + 1],
719                        p[idx + 2],
720                        p[idx + 3],
721                        p[idx + 4],
722                        p[idx + 5],
723                        p[idx + 6],
724                        p[idx + 7],
725                    ]);
726
727                    // X neighbors
728                    let west = f32x8::new([
729                        p[idx - 1],
730                        p[idx],
731                        p[idx + 1],
732                        p[idx + 2],
733                        p[idx + 3],
734                        p[idx + 4],
735                        p[idx + 5],
736                        p[idx + 6],
737                    ]);
738                    let east = f32x8::new([
739                        p[idx + 1],
740                        p[idx + 2],
741                        p[idx + 3],
742                        p[idx + 4],
743                        p[idx + 5],
744                        p[idx + 6],
745                        p[idx + 7],
746                        p[idx + 8],
747                    ]);
748
749                    // Y neighbors
750                    let north_idx = idx - stride_y;
751                    let south_idx = idx + stride_y;
752                    let north = f32x8::new([
753                        p[north_idx],
754                        p[north_idx + 1],
755                        p[north_idx + 2],
756                        p[north_idx + 3],
757                        p[north_idx + 4],
758                        p[north_idx + 5],
759                        p[north_idx + 6],
760                        p[north_idx + 7],
761                    ]);
762                    let south = f32x8::new([
763                        p[south_idx],
764                        p[south_idx + 1],
765                        p[south_idx + 2],
766                        p[south_idx + 3],
767                        p[south_idx + 4],
768                        p[south_idx + 5],
769                        p[south_idx + 6],
770                        p[south_idx + 7],
771                    ]);
772
773                    // Z neighbors
774                    let up_idx = idx - stride_z;
775                    let down_idx = idx + stride_z;
776                    let up = f32x8::new([
777                        p[up_idx],
778                        p[up_idx + 1],
779                        p[up_idx + 2],
780                        p[up_idx + 3],
781                        p[up_idx + 4],
782                        p[up_idx + 5],
783                        p[up_idx + 6],
784                        p[up_idx + 7],
785                    ]);
786                    let down = f32x8::new([
787                        p[down_idx],
788                        p[down_idx + 1],
789                        p[down_idx + 2],
790                        p[down_idx + 3],
791                        p[down_idx + 4],
792                        p[down_idx + 5],
793                        p[down_idx + 6],
794                        p[down_idx + 7],
795                    ]);
796
797                    let result = west + east + north + south + up + down - six * center;
798                    let arr: [f32; 8] = result.into();
799                    laplacian[idx..idx + 8].copy_from_slice(&arr);
800                }
801
802                let tail_start = 1 + chunks * 8;
803                for i in 0..remainder {
804                    let x = tail_start + i;
805                    let idx = row_start + x;
806                    laplacian[idx] = p[idx - 1]
807                        + p[idx + 1]
808                        + p[idx - stride_y]
809                        + p[idx + stride_y]
810                        + p[idx - stride_z]
811                        + p[idx + stride_z]
812                        - 6.0 * p[idx];
813                }
814            }
815        }
816    }
817}
818
819// ============================================================================
820// PARALLEL OPERATIONS (SIMD + Rayon)
821// ============================================================================
822
823impl SimdOps {
824    /// Parallel SAXPY using Rayon + SIMD.
825    ///
826    /// Best for large arrays (> 100K elements).
827    pub fn par_saxpy(a: f32, x: &[f32], y: &mut [f32]) {
828        const CHUNK_SIZE: usize = 4096;
829
830        y.par_chunks_mut(CHUNK_SIZE)
831            .zip(x.par_chunks(CHUNK_SIZE))
832            .for_each(|(y_chunk, x_chunk)| {
833                Self::saxpy(a, x_chunk, y_chunk);
834            });
835    }
836
837    /// Parallel sum using Rayon + SIMD.
838    pub fn par_sum_f32(x: &[f32]) -> f32 {
839        const CHUNK_SIZE: usize = 4096;
840
841        x.par_chunks(CHUNK_SIZE).map(Self::sum_f32).sum()
842    }
843
844    /// Parallel 2D FDTD step using Rayon + SIMD.
845    ///
846    /// Parallelizes over rows for better cache efficiency.
847    pub fn par_fdtd_step_2d_f32(
848        p: &[f32],
849        p_prev: &mut [f32],
850        c2: f32,
851        width: usize,
852        height: usize,
853    ) {
854        // Each row can be processed independently
855        p_prev
856            .par_chunks_mut(width)
857            .enumerate()
858            .skip(1)
859            .take(height - 2)
860            .for_each(|(y, row)| {
861                let row_above = (y - 1) * width;
862                let row_below = (y + 1) * width;
863                let row_start = y * width;
864
865                let two = f32x8::splat(2.0);
866                let four = f32x8::splat(4.0);
867                let c2_vec = f32x8::splat(c2);
868
869                let inner_width = width - 2;
870                let chunks = inner_width / 8;
871                let remainder = inner_width % 8;
872
873                for chunk in 0..chunks {
874                    let x = 1 + chunk * 8;
875                    let idx = row_start + x;
876                    let local_x = x;
877
878                    let center = f32x8::new([
879                        p[idx],
880                        p[idx + 1],
881                        p[idx + 2],
882                        p[idx + 3],
883                        p[idx + 4],
884                        p[idx + 5],
885                        p[idx + 6],
886                        p[idx + 7],
887                    ]);
888
889                    let prev = f32x8::new([
890                        row[local_x],
891                        row[local_x + 1],
892                        row[local_x + 2],
893                        row[local_x + 3],
894                        row[local_x + 4],
895                        row[local_x + 5],
896                        row[local_x + 6],
897                        row[local_x + 7],
898                    ]);
899
900                    let north_idx = row_above + x;
901                    let north = f32x8::new([
902                        p[north_idx],
903                        p[north_idx + 1],
904                        p[north_idx + 2],
905                        p[north_idx + 3],
906                        p[north_idx + 4],
907                        p[north_idx + 5],
908                        p[north_idx + 6],
909                        p[north_idx + 7],
910                    ]);
911
912                    let south_idx = row_below + x;
913                    let south = f32x8::new([
914                        p[south_idx],
915                        p[south_idx + 1],
916                        p[south_idx + 2],
917                        p[south_idx + 3],
918                        p[south_idx + 4],
919                        p[south_idx + 5],
920                        p[south_idx + 6],
921                        p[south_idx + 7],
922                    ]);
923
924                    let west = f32x8::new([
925                        p[idx - 1],
926                        p[idx],
927                        p[idx + 1],
928                        p[idx + 2],
929                        p[idx + 3],
930                        p[idx + 4],
931                        p[idx + 5],
932                        p[idx + 6],
933                    ]);
934
935                    let east = f32x8::new([
936                        p[idx + 1],
937                        p[idx + 2],
938                        p[idx + 3],
939                        p[idx + 4],
940                        p[idx + 5],
941                        p[idx + 6],
942                        p[idx + 7],
943                        p[idx + 8],
944                    ]);
945
946                    let laplacian = north + south + west + east - four * center;
947                    let result = two * center - prev + c2_vec * laplacian;
948
949                    let arr: [f32; 8] = result.into();
950                    row[local_x..local_x + 8].copy_from_slice(&arr);
951                }
952
953                let tail_start = 1 + chunks * 8;
954                for i in 0..remainder {
955                    let x = tail_start + i;
956                    let idx = row_start + x;
957                    let laplacian = p[row_above + x] + p[row_below + x] + p[idx - 1] + p[idx + 1]
958                        - 4.0 * p[idx];
959                    row[x] = 2.0 * p[idx] - row[x] + c2 * laplacian;
960                }
961            });
962    }
963}
964
965// ============================================================================
966// INTEGER OPERATIONS
967// ============================================================================
968
969impl SimdOps {
970    /// Sum of i32 array using SIMD.
971    #[inline]
972    pub fn sum_i32(x: &[i32]) -> i64 {
973        let n = x.len();
974        let chunks = n / 8;
975        let remainder = n % 8;
976
977        let mut acc = i32x8::splat(0);
978
979        for i in 0..chunks {
980            let offset = i * 8;
981            let x_vec = i32x8::new([
982                x[offset],
983                x[offset + 1],
984                x[offset + 2],
985                x[offset + 3],
986                x[offset + 4],
987                x[offset + 5],
988                x[offset + 6],
989                x[offset + 7],
990            ]);
991            acc += x_vec;
992        }
993
994        let arr: [i32; 8] = acc.into();
995        let mut sum: i64 = arr.iter().map(|&v| v as i64).sum();
996
997        let tail_start = chunks * 8;
998        for i in 0..remainder {
999            sum += x[tail_start + i] as i64;
1000        }
1001
1002        sum
1003    }
1004
1005    /// Element-wise i32 addition.
1006    #[inline]
1007    pub fn add_i32(x: &[i32], y: &[i32], z: &mut [i32]) {
1008        let n = x.len().min(y.len()).min(z.len());
1009        let chunks = n / 8;
1010        let remainder = n % 8;
1011
1012        for i in 0..chunks {
1013            let offset = i * 8;
1014            let x_vec = i32x8::new([
1015                x[offset],
1016                x[offset + 1],
1017                x[offset + 2],
1018                x[offset + 3],
1019                x[offset + 4],
1020                x[offset + 5],
1021                x[offset + 6],
1022                x[offset + 7],
1023            ]);
1024            let y_vec = i32x8::new([
1025                y[offset],
1026                y[offset + 1],
1027                y[offset + 2],
1028                y[offset + 3],
1029                y[offset + 4],
1030                y[offset + 5],
1031                y[offset + 6],
1032                y[offset + 7],
1033            ]);
1034
1035            let result = x_vec + y_vec;
1036            let arr: [i32; 8] = result.into();
1037            z[offset..offset + 8].copy_from_slice(&arr);
1038        }
1039
1040        let tail_start = chunks * 8;
1041        for i in 0..remainder {
1042            z[tail_start + i] = x[tail_start + i] + y[tail_start + i];
1043        }
1044    }
1045}
1046
1047// ============================================================================
1048// TESTS
1049// ============================================================================
1050
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_saxpy() {
        let x = vec![1.0f32; 100];
        let mut y = vec![2.0f32; 100];

        SimdOps::saxpy(3.0, &x, &mut y);

        for v in y.iter() {
            assert!((v - 5.0).abs() < 1e-6, "Expected 5.0, got {}", v);
        }
    }

    #[test]
    fn test_saxpy_unaligned() {
        let x = vec![1.0f32; 13]; // Not divisible by 8
        let mut y = vec![2.0f32; 13];

        SimdOps::saxpy(2.0, &x, &mut y);

        for v in y.iter() {
            assert!((v - 4.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_daxpy() {
        let x = vec![1.0f64; 100];
        let mut y = vec![2.0f64; 100];

        SimdOps::daxpy(3.0, &x, &mut y);

        for v in y.iter() {
            assert!((v - 5.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_dot_product() {
        let x = vec![1.0f32; 100];
        let y = vec![2.0f32; 100];

        let dot = SimdOps::dot_f32(&x, &y);
        assert!((dot - 200.0).abs() < 1e-4);
    }

    #[test]
    fn test_sum() {
        let x = vec![1.0f32; 1000];
        let sum = SimdOps::sum_f32(&x);
        assert!((sum - 1000.0).abs() < 1e-3);
    }

    #[test]
    fn test_max_min() {
        let x = vec![1.0f32, -5.0, 3.0, 7.0, -2.0, 4.0, 6.0, 8.0, -1.0];

        let max = SimdOps::max_f32(&x);
        let min = SimdOps::min_f32(&x);

        assert!((max - 8.0).abs() < 1e-6);
        assert!((min - (-5.0)).abs() < 1e-6);
    }

    #[test]
    fn test_laplacian_2d() {
        // 5x5 grid
        let width = 5;
        let height = 5;
        let mut p = vec![0.0f32; width * height];

        // Set center to 1.0
        p[12] = 1.0; // (2, 2)

        let mut laplacian = vec![0.0f32; width * height];
        SimdOps::laplacian_2d_f32(&p, &mut laplacian, width, height);

        // Center should have laplacian of -4
        assert!((laplacian[12] - (-4.0)).abs() < 1e-6);

        // Neighbors should have laplacian of 1
        assert!((laplacian[11] - 1.0).abs() < 1e-6); // (1, 2)
        assert!((laplacian[13] - 1.0).abs() < 1e-6); // (3, 2)
        assert!((laplacian[7] - 1.0).abs() < 1e-6); // (2, 1)
        assert!((laplacian[17] - 1.0).abs() < 1e-6); // (2, 3)
    }

    #[test]
    fn test_fdtd_step_2d() {
        let width = 10;
        let height = 10;
        let mut p = vec![0.0f32; width * height];
        let mut p_prev = vec![0.0f32; width * height];

        // Initial impulse at center
        p[55] = 1.0; // (5, 5)

        let c2 = 0.1;
        SimdOps::fdtd_step_2d_f32(&p, &mut p_prev, c2, width, height);

        // After one step, energy should spread from center
        // Center should now be: 2*1 - 0 + 0.1*(-4) = 1.6
        assert!((p_prev[55] - 1.6).abs() < 1e-6);
    }

    #[test]
    fn test_par_saxpy() {
        let x = vec![1.0f32; 10000];
        let mut y = vec![2.0f32; 10000];

        SimdOps::par_saxpy(3.0, &x, &mut y);

        for v in y.iter() {
            assert!((v - 5.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_par_sum() {
        let x = vec![1.0f32; 100000];
        let sum = SimdOps::par_sum_f32(&x);
        assert!((sum - 100000.0).abs() < 1.0); // Allow small floating point error
    }

    #[test]
    fn test_sum_i32() {
        let x = vec![1i32; 1000];
        let sum = SimdOps::sum_i32(&x);
        assert_eq!(sum, 1000);
    }

    #[test]
    fn test_sum_i32_empty() {
        assert_eq!(SimdOps::sum_i32(&[]), 0);
    }

    #[test]
    fn test_sum_i32_unaligned_mixed_signs() {
        // 19 elements: two full chunks of eight plus a 3-element scalar tail.
        // 0 - 1 + 2 - 3 + ... + 18 = (0 + 2 + ... + 18) - (1 + 3 + ... + 17) = 90 - 81.
        let x: Vec<i32> = (0..19).map(|i| if i % 2 == 0 { i } else { -i }).collect();
        assert_eq!(SimdOps::sum_i32(&x), 9);
    }

    #[test]
    fn test_add_i32_unaligned() {
        // 11 elements exercise both the 8-wide body and the scalar tail.
        let x: Vec<i32> = (0..11).collect();
        let y: Vec<i32> = (0..11).map(|i| 100 - i).collect();
        let mut z = vec![0i32; 11];

        SimdOps::add_i32(&x, &y, &mut z);

        assert!(z.iter().all(|&v| v == 100), "got {:?}", z);
    }

    #[test]
    fn test_add_i32_uses_shortest_length() {
        // Only min(len) elements are written; the rest of `z` is untouched.
        let x = vec![1i32; 10];
        let y = vec![2i32; 9];
        let mut z = vec![-1i32; 12];

        SimdOps::add_i32(&x, &y, &mut z);

        assert!(z[..9].iter().all(|&v| v == 3), "got {:?}", z);
        assert!(z[9..].iter().all(|&v| v == -1), "got {:?}", z);
    }

    #[test]
    fn test_add_vectors() {
        let x = vec![1.0f32; 100];
        let y = vec![2.0f32; 100];
        let mut z = vec![0.0f32; 100];

        SimdOps::add_f32(&x, &y, &mut z);

        for v in z.iter() {
            assert!((v - 3.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_scale() {
        let mut x = vec![2.0f32; 100];
        SimdOps::scale_f32(3.0, &mut x);

        for v in x.iter() {
            assert!((v - 6.0).abs() < 1e-6);
        }
    }
}