Skip to main content

oxiphysics_core/
simd_math.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! SIMD-accelerated math kernels for batch vector and particle operations.
6//!
7//! This module provides Structure-of-Arrays (SoA) layouts and batch operations
8//! optimized for CPU cache locality and auto-vectorization. The batch operations
9//! process multiple elements at once, enabling the compiler to emit SIMD
10//! instructions on supported platforms.
11//!
12//! # Layout
13//!
14//! Instead of the traditional Array-of-Structures (AoS) layout:
15//! ```text
16//! [x0,y0,z0, x1,y1,z1, x2,y2,z2, ...]
17//! ```
18//! We use Structure-of-Arrays (SoA):
19//! ```text
20//! xs: [x0, x1, x2, ...]
21//! ys: [y0, y1, y2, ...]
22//! zs: [z0, z1, z2, ...]
23//! ```
24//! This layout allows the compiler to vectorize operations across contiguous
25//! memory, improving throughput for large batches.
26
27use std::f64;
28
/// Batch of 3D vectors stored in Structure-of-Arrays (SoA) form.
///
/// Vector `i` of the batch is `(x[i], y[i], z[i])`. Keeping each component in
/// its own contiguous `Vec<f64>` lets batch kernels stream through one array
/// at a time.
#[derive(Clone, Debug, PartialEq)]
pub struct Vec3Batch {
    /// X component of every vector, indexed by position in the batch.
    pub x: Vec<f64>,
    /// Y component of every vector, indexed by position in the batch.
    pub y: Vec<f64>,
    /// Z component of every vector, indexed by position in the batch.
    pub z: Vec<f64>,
}
42
43/// Errors that can occur in SIMD batch operations.
44#[derive(Debug, Clone, thiserror::Error)]
45pub enum SimdMathError {
46    /// Batch size mismatch between operands.
47    #[error("batch size mismatch: left has {left} elements, right has {right} elements")]
48    SizeMismatch {
49        /// Size of the left operand.
50        left: usize,
51        /// Size of the right operand.
52        right: usize,
53    },
54    /// Inconsistent internal dimensions in a Vec3Batch.
55    #[error("inconsistent Vec3Batch dimensions: x={x_len}, y={y_len}, z={z_len}")]
56    InconsistentDimensions {
57        /// Length of x component.
58        x_len: usize,
59        /// Length of y component.
60        y_len: usize,
61        /// Length of z component.
62        z_len: usize,
63    },
64    /// Zero-length vector encountered where normalization is needed.
65    #[error("cannot normalize zero-length vector at index {index}")]
66    ZeroLengthVector {
67        /// Index of the zero-length vector.
68        index: usize,
69    },
70}
71
72impl Vec3Batch {
73    /// Allocate a new batch with `n` zero-initialized vectors.
74    #[must_use]
75    pub fn new(n: usize) -> Self {
76        Self {
77            x: vec![0.0; n],
78            y: vec![0.0; n],
79            z: vec![0.0; n],
80        }
81    }
82
83    /// Convert an Array-of-Structures slice to Structure-of-Arrays layout.
84    ///
85    /// Each element of `positions` is `[x, y, z]`.
86    #[must_use]
87    pub fn from_aos(positions: &[[f64; 3]]) -> Self {
88        let n = positions.len();
89        let mut x = Vec::with_capacity(n);
90        let mut y = Vec::with_capacity(n);
91        let mut z = Vec::with_capacity(n);
92        for p in positions {
93            x.push(p[0]);
94            y.push(p[1]);
95            z.push(p[2]);
96        }
97        Self { x, y, z }
98    }
99
100    /// Convert back from SoA to AoS layout.
101    ///
102    /// Returns an error if internal dimensions are inconsistent.
103    pub fn to_aos(&self) -> Result<Vec<[f64; 3]>, SimdMathError> {
104        self.validate()?;
105        let n = self.x.len();
106        let mut result = Vec::with_capacity(n);
107        for i in 0..n {
108            result.push([self.x[i], self.y[i], self.z[i]]);
109        }
110        Ok(result)
111    }
112
113    /// Returns the number of vectors in this batch.
114    #[must_use]
115    pub fn len(&self) -> usize {
116        self.x.len()
117    }
118
119    /// Returns true if the batch contains no vectors.
120    #[must_use]
121    pub fn is_empty(&self) -> bool {
122        self.x.is_empty()
123    }
124
125    /// Validate that all component vectors have equal length.
126    fn validate(&self) -> Result<(), SimdMathError> {
127        let (xl, yl, zl) = (self.x.len(), self.y.len(), self.z.len());
128        if xl != yl || yl != zl {
129            return Err(SimdMathError::InconsistentDimensions {
130                x_len: xl,
131                y_len: yl,
132                z_len: zl,
133            });
134        }
135        Ok(())
136    }
137
138    /// Check that two batches have equal size.
139    fn check_size(&self, other: &Self) -> Result<(), SimdMathError> {
140        if self.len() != other.len() {
141            return Err(SimdMathError::SizeMismatch {
142                left: self.len(),
143                right: other.len(),
144            });
145        }
146        Ok(())
147    }
148
149    /// Element-wise addition of two batches.
150    ///
151    /// # Errors
152    /// Returns `SimdMathError::SizeMismatch` if batch sizes differ.
153    pub fn add(&self, other: &Vec3Batch) -> Result<Vec3Batch, SimdMathError> {
154        self.check_size(other)?;
155        let n = self.len();
156        let mut rx = vec![0.0_f64; n];
157        let mut ry = vec![0.0_f64; n];
158        let mut rz = vec![0.0_f64; n];
159
160        // Written as simple loops to encourage auto-vectorization
161        for i in 0..n {
162            rx[i] = self.x[i] + other.x[i];
163        }
164        for i in 0..n {
165            ry[i] = self.y[i] + other.y[i];
166        }
167        for i in 0..n {
168            rz[i] = self.z[i] + other.z[i];
169        }
170
171        Ok(Vec3Batch {
172            x: rx,
173            y: ry,
174            z: rz,
175        })
176    }
177
178    /// Element-wise subtraction of two batches.
179    ///
180    /// # Errors
181    /// Returns `SimdMathError::SizeMismatch` if batch sizes differ.
182    pub fn sub(&self, other: &Vec3Batch) -> Result<Vec3Batch, SimdMathError> {
183        self.check_size(other)?;
184        let n = self.len();
185        let mut rx = vec![0.0_f64; n];
186        let mut ry = vec![0.0_f64; n];
187        let mut rz = vec![0.0_f64; n];
188
189        for i in 0..n {
190            rx[i] = self.x[i] - other.x[i];
191        }
192        for i in 0..n {
193            ry[i] = self.y[i] - other.y[i];
194        }
195        for i in 0..n {
196            rz[i] = self.z[i] - other.z[i];
197        }
198
199        Ok(Vec3Batch {
200            x: rx,
201            y: ry,
202            z: rz,
203        })
204    }
205
206    /// Scale all vectors by a uniform scalar.
207    #[must_use]
208    pub fn scale(&self, s: f64) -> Vec3Batch {
209        let n = self.len();
210        let mut rx = vec![0.0_f64; n];
211        let mut ry = vec![0.0_f64; n];
212        let mut rz = vec![0.0_f64; n];
213
214        for i in 0..n {
215            rx[i] = self.x[i] * s;
216        }
217        for i in 0..n {
218            ry[i] = self.y[i] * s;
219        }
220        for i in 0..n {
221            rz[i] = self.z[i] * s;
222        }
223
224        Vec3Batch {
225            x: rx,
226            y: ry,
227            z: rz,
228        }
229    }
230
231    /// Batch dot product: returns `x[i]*other.x[i] + y[i]*other.y[i] + z[i]*other.z[i]`
232    /// for each `i`.
233    ///
234    /// # Errors
235    /// Returns `SimdMathError::SizeMismatch` if batch sizes differ.
236    pub fn dot(&self, other: &Vec3Batch) -> Result<Vec<f64>, SimdMathError> {
237        self.check_size(other)?;
238        let n = self.len();
239        let mut result = vec![0.0_f64; n];
240
241        // Accumulate component-wise to allow vectorization of each loop
242        for i in 0..n {
243            result[i] = self.x[i] * other.x[i];
244        }
245        for i in 0..n {
246            result[i] += self.y[i] * other.y[i];
247        }
248        for i in 0..n {
249            result[i] += self.z[i] * other.z[i];
250        }
251
252        Ok(result)
253    }
254
255    /// Batch cross product.
256    ///
257    /// For each index `i`, computes `self[i] × other[i]`.
258    ///
259    /// # Errors
260    /// Returns `SimdMathError::SizeMismatch` if batch sizes differ.
261    pub fn cross(&self, other: &Vec3Batch) -> Result<Vec3Batch, SimdMathError> {
262        self.check_size(other)?;
263        let n = self.len();
264        let mut rx = vec![0.0_f64; n];
265        let mut ry = vec![0.0_f64; n];
266        let mut rz = vec![0.0_f64; n];
267
268        // cross.x = self.y * other.z - self.z * other.y
269        for i in 0..n {
270            rx[i] = self.y[i] * other.z[i] - self.z[i] * other.y[i];
271        }
272        // cross.y = self.z * other.x - self.x * other.z
273        for i in 0..n {
274            ry[i] = self.z[i] * other.x[i] - self.x[i] * other.z[i];
275        }
276        // cross.z = self.x * other.y - self.y * other.x
277        for i in 0..n {
278            rz[i] = self.x[i] * other.y[i] - self.y[i] * other.x[i];
279        }
280
281        Ok(Vec3Batch {
282            x: rx,
283            y: ry,
284            z: rz,
285        })
286    }
287
288    /// Batch squared length: `x[i]^2 + y[i]^2 + z[i]^2` for each `i`.
289    #[must_use]
290    pub fn length_sq(&self) -> Vec<f64> {
291        let n = self.len();
292        let mut result = vec![0.0_f64; n];
293
294        for i in 0..n {
295            result[i] = self.x[i] * self.x[i];
296        }
297        for i in 0..n {
298            result[i] += self.y[i] * self.y[i];
299        }
300        for i in 0..n {
301            result[i] += self.z[i] * self.z[i];
302        }
303
304        result
305    }
306
307    /// Batch vector length (Euclidean norm).
308    #[must_use]
309    pub fn length(&self) -> Vec<f64> {
310        let sq = self.length_sq();
311        sq.into_iter().map(f64::sqrt).collect()
312    }
313
314    /// Normalize all vectors in-place to unit length.
315    ///
316    /// Vectors with length below `f64::EPSILON` are left unchanged and their
317    /// indices are collected in the returned error. If all vectors are valid,
318    /// returns `Ok(())`.
319    ///
320    /// # Errors
321    /// Returns `SimdMathError::ZeroLengthVector` for the first zero-length vector found.
322    pub fn normalize(&mut self) -> Result<(), SimdMathError> {
323        let lengths = self.length();
324        let n = self.len();
325
326        // First pass: check for zero-length vectors
327        for (i, &len) in lengths.iter().enumerate() {
328            if len < f64::EPSILON {
329                return Err(SimdMathError::ZeroLengthVector { index: i });
330            }
331        }
332
333        // Second pass: compute reciprocals and scale (vectorization-friendly)
334        let mut inv_lengths = vec![0.0_f64; n];
335        for i in 0..n {
336            inv_lengths[i] = 1.0 / lengths[i];
337        }
338
339        for i in 0..n {
340            self.x[i] *= inv_lengths[i];
341        }
342        for i in 0..n {
343            self.y[i] *= inv_lengths[i];
344        }
345        for i in 0..n {
346            self.z[i] *= inv_lengths[i];
347        }
348
349        Ok(())
350    }
351
352    /// Batch pairwise squared distance: `|a[i] - b[i]|^2` for each `i`.
353    ///
354    /// # Errors
355    /// Returns `SimdMathError::SizeMismatch` if batch sizes differ.
356    pub fn distance_sq_pairwise(a: &Vec3Batch, b: &Vec3Batch) -> Result<Vec<f64>, SimdMathError> {
357        a.check_size(b)?;
358        let n = a.len();
359        let mut result = vec![0.0_f64; n];
360
361        // dx^2
362        for i in 0..n {
363            let dx = a.x[i] - b.x[i];
364            result[i] = dx * dx;
365        }
366        // dy^2
367        for i in 0..n {
368            let dy = a.y[i] - b.y[i];
369            result[i] += dy * dy;
370        }
371        // dz^2
372        for i in 0..n {
373            let dz = a.z[i] - b.z[i];
374            result[i] += dz * dz;
375        }
376
377        Ok(result)
378    }
379}
380
381// ---------------------------------------------------------------------------
382// Batch Particle Operations
383// ---------------------------------------------------------------------------
384
385/// Compute distances from a single reference point to many positions.
386///
387/// Returns the Euclidean distance from `ref_pos` to each position in the batch.
388#[must_use]
389pub fn compute_distances_batch(positions: &Vec3Batch, ref_pos: [f64; 3]) -> Vec<f64> {
390    let n = positions.len();
391    let mut result = vec![0.0_f64; n];
392
393    // dx^2
394    for i in 0..n {
395        let dx = positions.x[i] - ref_pos[0];
396        result[i] = dx * dx;
397    }
398    // dy^2
399    for i in 0..n {
400        let dy = positions.y[i] - ref_pos[1];
401        result[i] += dy * dy;
402    }
403    // dz^2
404    for i in 0..n {
405        let dz = positions.z[i] - ref_pos[2];
406        result[i] += dz * dz;
407    }
408
409    // sqrt
410    for val in &mut result {
411        *val = val.sqrt();
412    }
413
414    result
415}
416
417/// Accumulate forces onto a mutable force batch.
418///
419/// For each index `i`, computes:
420/// ```text
421/// forces[i] += directions[i] * magnitudes[i]
422/// ```
423///
424/// # Errors
425/// Returns `SimdMathError::SizeMismatch` if the sizes of `forces`, `directions`,
426/// or `magnitudes` do not match.
427pub fn accumulate_forces_batch(
428    forces: &mut Vec3Batch,
429    directions: &Vec3Batch,
430    magnitudes: &[f64],
431) -> Result<(), SimdMathError> {
432    let n = forces.len();
433    if n != directions.len() {
434        return Err(SimdMathError::SizeMismatch {
435            left: n,
436            right: directions.len(),
437        });
438    }
439    if n != magnitudes.len() {
440        return Err(SimdMathError::SizeMismatch {
441            left: n,
442            right: magnitudes.len(),
443        });
444    }
445
446    for i in 0..n {
447        forces.x[i] += directions.x[i] * magnitudes[i];
448    }
449    for i in 0..n {
450        forces.y[i] += directions.y[i] * magnitudes[i];
451    }
452    for i in 0..n {
453        forces.z[i] += directions.z[i] * magnitudes[i];
454    }
455
456    Ok(())
457}
458
/// Evaluate the cubic spline kernel (M4 kernel) for a batch of distances.
///
/// The cubic spline kernel is widely used in Smoothed Particle Hydrodynamics (SPH).
/// It is defined in 3D as:
///
/// ```text
/// W(r, h) = σ * { (2-q)^3 - 4*(1-q)^3   if 0 ≤ q < 1
///               { (2-q)^3                  if 1 ≤ q < 2
///               { 0                        if q ≥ 2
/// ```
///
/// where `q = |r|/h` and `σ = 1/(4π h³)` is the 3D normalization constant.
/// Negative distances are treated by their absolute value, so the kernel is
/// symmetric in `r`.
///
/// # Arguments
/// * `r` - batch of distances
/// * `h` - smoothing length (must be positive)
///
/// # Returns
/// Vector of kernel values, one per entry of `r`. If `h <= 0.0`, every value
/// is `0.0` (the result still has `r.len()` entries).
#[must_use]
pub fn cubic_spline_kernel_batch(r: &[f64], h: f64) -> Vec<f64> {
    if h <= 0.0 {
        return vec![0.0; r.len()];
    }

    let inv_h = 1.0 / h;
    let sigma = 1.0 / (4.0 * f64::consts::PI * h * h * h);

    r.iter()
        .map(|&ri| {
            // Radially symmetric kernel: only |q| matters.
            let q = (ri * inv_h).abs();
            if q >= 2.0 {
                0.0
            } else if q >= 1.0 {
                let t = 2.0 - q;
                sigma * t * t * t
            } else {
                // Also reached for NaN input, which propagates as NaN.
                let t2 = 2.0 - q;
                let t1 = 1.0 - q;
                sigma * (t2 * t2 * t2 - 4.0 * t1 * t1 * t1)
            }
        })
        .collect()
}
526
527// ---------------------------------------------------------------------------
528// Tests
529// ---------------------------------------------------------------------------
530
#[cfg(test)]
mod tests {
    use super::*;

    const EPSILON: f64 = 1e-12;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // -----------------------------------------------------------------------
    // AoS <-> SoA round-trip
    // -----------------------------------------------------------------------

    #[test]
    fn test_aos_soa_round_trip() {
        let positions = vec![
            [1.0, 2.0, 3.0],
            [4.0, 5.0, 6.0],
            [7.0, 8.0, 9.0],
            [-1.5, 0.0, 3.125],
        ];
        let batch = Vec3Batch::from_aos(&positions);
        assert_eq!(batch.len(), 4);
        assert_eq!(batch.x, vec![1.0, 4.0, 7.0, -1.5]);
        assert_eq!(batch.y, vec![2.0, 5.0, 8.0, 0.0]);
        assert_eq!(batch.z, vec![3.0, 6.0, 9.0, 3.125]);

        let back = batch.to_aos().expect("to_aos should succeed");
        assert_eq!(back, positions);
    }

    #[test]
    fn test_aos_soa_round_trip_empty() {
        let positions: Vec<[f64; 3]> = vec![];
        let batch = Vec3Batch::from_aos(&positions);
        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);
        let back = batch.to_aos().expect("to_aos should succeed for empty");
        assert!(back.is_empty());
    }

    #[test]
    fn test_to_aos_rejects_inconsistent_dimensions() {
        let batch = Vec3Batch {
            x: vec![1.0, 2.0],
            y: vec![1.0],
            z: vec![1.0, 2.0],
        };
        match batch.to_aos() {
            Err(SimdMathError::InconsistentDimensions {
                x_len,
                y_len,
                z_len,
            }) => assert_eq!((x_len, y_len, z_len), (2, 1, 2)),
            other => panic!("expected InconsistentDimensions error, got {other:?}"),
        }
    }

    // -----------------------------------------------------------------------
    // Batch dot product matches scalar
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_dot_matches_scalar() {
        let a_aos = vec![[1.0, 2.0, 3.0], [4.0, -1.0, 2.0], [0.0, 0.0, 1.0]];
        let b_aos = vec![[3.0, -2.0, 1.0], [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]];

        let a = Vec3Batch::from_aos(&a_aos);
        let b = Vec3Batch::from_aos(&b_aos);

        let batch_dots = a.dot(&b).expect("dot should succeed");

        // Scalar reference
        for (i, (&da, &db)) in a_aos.iter().zip(b_aos.iter()).enumerate() {
            let scalar_dot = da[0] * db[0] + da[1] * db[1] + da[2] * db[2];
            assert!(
                approx_eq(batch_dots[i], scalar_dot, EPSILON),
                "dot mismatch at index {i}: batch={}, scalar={scalar_dot}",
                batch_dots[i]
            );
        }
    }

    // -----------------------------------------------------------------------
    // Batch cross product matches scalar
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_cross_matches_scalar() {
        let a_aos = vec![[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 2.0, 3.0]];
        let b_aos = vec![[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [4.0, 5.0, 6.0]];

        let a = Vec3Batch::from_aos(&a_aos);
        let b = Vec3Batch::from_aos(&b_aos);

        let cross_batch = a.cross(&b).expect("cross should succeed");
        let cross_aos = cross_batch.to_aos().expect("to_aos should succeed");

        for (i, (&va, &vb)) in a_aos.iter().zip(b_aos.iter()).enumerate() {
            let cx = va[1] * vb[2] - va[2] * vb[1];
            let cy = va[2] * vb[0] - va[0] * vb[2];
            let cz = va[0] * vb[1] - va[1] * vb[0];
            assert!(
                approx_eq(cross_aos[i][0], cx, EPSILON),
                "cross.x mismatch at {i}"
            );
            assert!(
                approx_eq(cross_aos[i][1], cy, EPSILON),
                "cross.y mismatch at {i}"
            );
            assert!(
                approx_eq(cross_aos[i][2], cz, EPSILON),
                "cross.z mismatch at {i}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Batch normalize produces unit vectors
    // -----------------------------------------------------------------------

    #[test]
    fn test_batch_normalize_unit_vectors() {
        let positions = vec![
            [3.0, 4.0, 0.0],  // length = 5
            [0.0, 0.0, 7.0],  // length = 7
            [1.0, 1.0, 1.0],  // length = sqrt(3)
            [10.0, 0.0, 0.0], // length = 10
        ];
        let mut batch = Vec3Batch::from_aos(&positions);
        batch.normalize().expect("normalize should succeed");

        let lengths = batch.length();
        for (i, &len) in lengths.iter().enumerate() {
            assert!(
                approx_eq(len, 1.0, 1e-10),
                "expected unit length at index {i}, got {len}"
            );
        }

        // Check specific known result: [3,4,0] -> [0.6, 0.8, 0.0]
        assert!(approx_eq(batch.x[0], 0.6, EPSILON));
        assert!(approx_eq(batch.y[0], 0.8, EPSILON));
        assert!(approx_eq(batch.z[0], 0.0, EPSILON));
    }

    #[test]
    fn test_normalize_zero_vector_error() {
        let mut batch = Vec3Batch::from_aos(&[[0.0, 0.0, 0.0]]);
        let result = batch.normalize();
        assert!(result.is_err());
        match result {
            Err(SimdMathError::ZeroLengthVector { index }) => assert_eq!(index, 0),
            other => panic!("expected ZeroLengthVector error, got {other:?}"),
        }
    }

    #[test]
    fn test_normalize_error_leaves_batch_unchanged() {
        let original =
            Vec3Batch::from_aos(&[[3.0, 4.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]);
        let mut batch = original.clone();
        match batch.normalize() {
            Err(SimdMathError::ZeroLengthVector { index }) => assert_eq!(index, 1),
            other => panic!("expected ZeroLengthVector error, got {other:?}"),
        }
        // No vector, not even the valid ones, may be modified on failure.
        assert_eq!(batch, original);
    }

    // -----------------------------------------------------------------------
    // Distance computation matches naive loop
    // -----------------------------------------------------------------------

    #[test]
    fn test_compute_distances_batch_matches_naive() {
        let positions_aos = vec![
            [1.0, 0.0, 0.0],
            [0.0, 3.0, 4.0],
            [1.0, 1.0, 1.0],
            [10.0, 20.0, 30.0],
        ];
        let ref_pos = [0.0, 0.0, 0.0];

        let batch = Vec3Batch::from_aos(&positions_aos);
        let distances = compute_distances_batch(&batch, ref_pos);

        for (i, pos) in positions_aos.iter().enumerate() {
            let naive = ((pos[0] - ref_pos[0]).powi(2)
                + (pos[1] - ref_pos[1]).powi(2)
                + (pos[2] - ref_pos[2]).powi(2))
            .sqrt();
            assert!(
                approx_eq(distances[i], naive, EPSILON),
                "distance mismatch at {i}: batch={}, naive={naive}",
                distances[i]
            );
        }
    }

    #[test]
    fn test_distance_sq_pairwise() {
        let a = Vec3Batch::from_aos(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let b = Vec3Batch::from_aos(&[[4.0, 6.0, 3.0], [4.0, 5.0, 6.0]]);

        let dsq = Vec3Batch::distance_sq_pairwise(&a, &b).expect("should succeed");
        // [1]: (4-4)^2 + (5-5)^2 + (6-6)^2 = 0
        assert!(approx_eq(dsq[1], 0.0, EPSILON));
        // [0]: (4-1)^2 + (6-2)^2 + (3-3)^2 = 9+16+0 = 25
        assert!(approx_eq(dsq[0], 25.0, EPSILON));
    }

    // -----------------------------------------------------------------------
    // Kernel evaluation matches scalar version
    // -----------------------------------------------------------------------

    #[test]
    fn test_cubic_spline_kernel_matches_scalar() {
        let h = 1.0;
        let sigma = 1.0 / (4.0 * f64::consts::PI * h * h * h);

        // Test representative q values
        let r_values = vec![0.0, 0.5, 0.99, 1.0, 1.5, 1.99, 2.0, 3.0];
        let kernel_vals = cubic_spline_kernel_batch(&r_values, h);

        for (i, &r) in r_values.iter().enumerate() {
            let q = r / h;
            let expected = if q >= 2.0 {
                0.0
            } else if q >= 1.0 {
                let t = 2.0 - q;
                sigma * t * t * t
            } else {
                let t2 = 2.0 - q;
                let t1 = 1.0 - q;
                sigma * (t2 * t2 * t2 - 4.0 * t1 * t1 * t1)
            };
            assert!(
                approx_eq(kernel_vals[i], expected, EPSILON),
                "kernel mismatch at r={r}: batch={}, expected={expected}",
                kernel_vals[i]
            );
        }
    }

    #[test]
    fn test_cubic_spline_kernel_zero_at_boundary() {
        let h = 2.0;
        let vals = cubic_spline_kernel_batch(&[4.0, 5.0, 100.0], h);
        for (i, &v) in vals.iter().enumerate() {
            assert!(
                approx_eq(v, 0.0, EPSILON),
                "expected zero at index {i}, got {v}"
            );
        }
    }

    #[test]
    fn test_cubic_spline_kernel_non_positive_h() {
        let vals = cubic_spline_kernel_batch(&[1.0, 2.0], 0.0);
        assert_eq!(vals, vec![0.0, 0.0]);

        let vals_neg = cubic_spline_kernel_batch(&[1.0], -1.0);
        assert_eq!(vals_neg, vec![0.0]);
    }

    #[test]
    fn test_cubic_spline_kernel_symmetric_in_r() {
        let r = [0.25, 0.75, 1.25, 1.75, 2.5];
        let neg: Vec<f64> = r.iter().map(|v| -v).collect();
        assert_eq!(
            cubic_spline_kernel_batch(&r, 1.5),
            cubic_spline_kernel_batch(&neg, 1.5)
        );
    }

    // -----------------------------------------------------------------------
    // Empty batch operations
    // -----------------------------------------------------------------------

    #[test]
    fn test_empty_batch_operations() {
        let a = Vec3Batch::new(0);
        let b = Vec3Batch::new(0);

        assert!(a.is_empty());

        let sum = a.add(&b).expect("add empty should succeed");
        assert!(sum.is_empty());

        let diff = a.sub(&b).expect("sub empty should succeed");
        assert!(diff.is_empty());

        let dots = a.dot(&b).expect("dot empty should succeed");
        assert!(dots.is_empty());

        let cross = a.cross(&b).expect("cross empty should succeed");
        assert!(cross.is_empty());

        let scaled = a.scale(5.0);
        assert!(scaled.is_empty());

        let lsq = a.length_sq();
        assert!(lsq.is_empty());

        let lens = a.length();
        assert!(lens.is_empty());

        let dsq =
            Vec3Batch::distance_sq_pairwise(&a, &b).expect("distance_sq empty should succeed");
        assert!(dsq.is_empty());

        let dists = compute_distances_batch(&a, [0.0, 0.0, 0.0]);
        assert!(dists.is_empty());

        let kernel = cubic_spline_kernel_batch(&[], 1.0);
        assert!(kernel.is_empty());
    }

    // -----------------------------------------------------------------------
    // Size mismatch errors
    // -----------------------------------------------------------------------

    #[test]
    fn test_size_mismatch_errors() {
        let a = Vec3Batch::new(3);
        let b = Vec3Batch::new(5);

        assert!(a.add(&b).is_err());
        assert!(a.sub(&b).is_err());
        assert!(a.dot(&b).is_err());
        assert!(a.cross(&b).is_err());
        assert!(Vec3Batch::distance_sq_pairwise(&a, &b).is_err());
    }

    #[test]
    fn test_error_messages() {
        let size = SimdMathError::SizeMismatch { left: 3, right: 5 };
        assert_eq!(
            size.to_string(),
            "batch size mismatch: left has 3 elements, right has 5 elements"
        );
        let zero = SimdMathError::ZeroLengthVector { index: 4 };
        assert_eq!(
            zero.to_string(),
            "cannot normalize zero-length vector at index 4"
        );
    }

    // -----------------------------------------------------------------------
    // Force accumulation
    // -----------------------------------------------------------------------

    #[test]
    fn test_accumulate_forces_batch() {
        let mut forces = Vec3Batch::new(3);
        let directions = Vec3Batch::from_aos(&[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]);
        let magnitudes = vec![10.0, 20.0, 30.0];

        accumulate_forces_batch(&mut forces, &directions, &magnitudes)
            .expect("accumulate should succeed");

        assert!(approx_eq(forces.x[0], 10.0, EPSILON));
        assert!(approx_eq(forces.y[1], 20.0, EPSILON));
        assert!(approx_eq(forces.z[2], 30.0, EPSILON));

        // Accumulate again (forces should add up)
        accumulate_forces_batch(&mut forces, &directions, &magnitudes)
            .expect("second accumulate should succeed");

        assert!(approx_eq(forces.x[0], 20.0, EPSILON));
        assert!(approx_eq(forces.y[1], 40.0, EPSILON));
        assert!(approx_eq(forces.z[2], 60.0, EPSILON));
    }

    #[test]
    fn test_accumulate_forces_size_mismatch() {
        let mut forces = Vec3Batch::new(3);
        let directions = Vec3Batch::new(2);
        let magnitudes = vec![1.0, 2.0, 3.0];

        assert!(accumulate_forces_batch(&mut forces, &directions, &magnitudes).is_err());

        let directions2 = Vec3Batch::new(3);
        let magnitudes2 = vec![1.0, 2.0];
        assert!(accumulate_forces_batch(&mut forces, &directions2, &magnitudes2).is_err());
    }

    // -----------------------------------------------------------------------
    // Add / Sub correctness
    // -----------------------------------------------------------------------

    #[test]
    fn test_add_sub_inverse() {
        let a = Vec3Batch::from_aos(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let b = Vec3Batch::from_aos(&[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]);

        let sum = a.add(&b).expect("add should succeed");
        let back = sum.sub(&b).expect("sub should succeed");

        for i in 0..a.len() {
            assert!(approx_eq(back.x[i], a.x[i], EPSILON));
            assert!(approx_eq(back.y[i], a.y[i], EPSILON));
            assert!(approx_eq(back.z[i], a.z[i], EPSILON));
        }
    }

    // -----------------------------------------------------------------------
    // Scale correctness
    // -----------------------------------------------------------------------

    #[test]
    fn test_scale_by_zero() {
        let a = Vec3Batch::from_aos(&[[1.0, 2.0, 3.0]]);
        let scaled = a.scale(0.0);
        assert!(approx_eq(scaled.x[0], 0.0, EPSILON));
        assert!(approx_eq(scaled.y[0], 0.0, EPSILON));
        assert!(approx_eq(scaled.z[0], 0.0, EPSILON));
    }

    #[test]
    fn test_scale_negative() {
        let a = Vec3Batch::from_aos(&[[1.0, 2.0, 3.0]]);
        let scaled = a.scale(-2.0);
        assert!(approx_eq(scaled.x[0], -2.0, EPSILON));
        assert!(approx_eq(scaled.y[0], -4.0, EPSILON));
        assert!(approx_eq(scaled.z[0], -6.0, EPSILON));
    }

    // -----------------------------------------------------------------------
    // Length / length_sq
    // -----------------------------------------------------------------------

    #[test]
    fn test_length_sq_and_length() {
        let batch = Vec3Batch::from_aos(&[[3.0, 4.0, 0.0], [0.0, 0.0, 5.0]]);
        let lsq = batch.length_sq();
        assert!(approx_eq(lsq[0], 25.0, EPSILON));
        assert!(approx_eq(lsq[1], 25.0, EPSILON));

        let lens = batch.length();
        assert!(approx_eq(lens[0], 5.0, EPSILON));
        assert!(approx_eq(lens[1], 5.0, EPSILON));
    }

    // -----------------------------------------------------------------------
    // Cross product specific identities
    // -----------------------------------------------------------------------

    #[test]
    fn test_cross_product_anticommutative() {
        let a = Vec3Batch::from_aos(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let b = Vec3Batch::from_aos(&[[7.0, 8.0, 9.0], [10.0, 11.0, 12.0]]);

        let axb = a.cross(&b).expect("cross a x b");
        let bxa = b.cross(&a).expect("cross b x a");

        // a x b = -(b x a)
        for i in 0..a.len() {
            assert!(approx_eq(axb.x[i], -bxa.x[i], EPSILON));
            assert!(approx_eq(axb.y[i], -bxa.y[i], EPSILON));
            assert!(approx_eq(axb.z[i], -bxa.z[i], EPSILON));
        }
    }

    #[test]
    fn test_cross_product_perpendicular() {
        // a x b should be perpendicular to both a and b
        let a = Vec3Batch::from_aos(&[[1.0, 2.0, 3.0]]);
        let b = Vec3Batch::from_aos(&[[4.0, 5.0, 6.0]]);

        let c = a.cross(&b).expect("cross");
        let dot_ac = a.dot(&c).expect("dot a.c");
        let dot_bc = b.dot(&c).expect("dot b.c");

        assert!(approx_eq(dot_ac[0], 0.0, 1e-10));
        assert!(approx_eq(dot_bc[0], 0.0, 1e-10));
    }

    // -----------------------------------------------------------------------
    // Kernel monotonicity
    // -----------------------------------------------------------------------

    #[test]
    fn test_cubic_spline_kernel_monotonic_decrease() {
        let h = 1.0;
        let r: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        let vals = cubic_spline_kernel_batch(&r, h);

        // Kernel should be non-negative
        for (i, &v) in vals.iter().enumerate() {
            assert!(v >= 0.0, "kernel negative at r={}: {v}", r[i]);
        }

        // Kernel should never increase as r grows (r is sorted ascending)
        for (i, pair) in vals.windows(2).enumerate() {
            assert!(
                pair[1] <= pair[0],
                "kernel increased between r={} and r={}: {} -> {}",
                r[i],
                r[i + 1],
                pair[0],
                pair[1]
            );
        }

        // Maximum should be at r=0
        let max_val = vals.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        assert!(approx_eq(max_val, vals[0], EPSILON));
    }

    // -----------------------------------------------------------------------
    // Compute distances with non-origin reference
    // -----------------------------------------------------------------------

    #[test]
    fn test_compute_distances_nonzero_ref() {
        let positions = Vec3Batch::from_aos(&[[4.0, 0.0, 0.0], [1.0, 1.0, 1.0]]);
        let ref_pos = [1.0, 0.0, 0.0];
        let dists = compute_distances_batch(&positions, ref_pos);

        assert!(approx_eq(dists[0], 3.0, EPSILON));
        let expected_1 = (0.0_f64 + 1.0 + 1.0_f64).sqrt();
        assert!(approx_eq(dists[1], expected_1, EPSILON));
    }
}