amari_core/simd.rs

//! SIMD optimizations for geometric algebra operations
//!
//! This module provides vectorized implementations of critical geometric algebra
//! operations using CPU SIMD instruction sets (AVX2, SSE2) for maximum performance.

#[cfg(all(
    target_arch = "x86_64",
    any(target_feature = "avx2", target_feature = "sse2")
))]
use core::arch::x86_64::*;

use crate::Multivector;
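// Coefficient layout assumed throughout this module (binary basis-blade
// indexing: bit 0 = e1, bit 1 = e2, bit 2 = e3), matching the indices the
// unrolled kernels below read with `get(..)`:
//
//   index:  0    1    2    3    4    5    6    7
//   blade:  1    e1   e2   e12  e3   e13  e23  e123
//
// Under this layout the product of components a and b always lands in
// coefficient a ^ b, up to a sign from reordering basis vectors.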
/// SIMD-optimized geometric product for 3D Euclidean algebra (most common case)
///
/// Gated on FMA as well as AVX2 because the kernel uses fused multiply-add
/// intrinsics; the two features ship together on mainstream x86-64 CPUs.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
#[inline(always)]
pub fn geometric_product_3d_avx2(
    lhs: &Multivector<3, 0, 0>,
    rhs: &Multivector<3, 0, 0>,
) -> Multivector<3, 0, 0> {
    unsafe {
        // Load rhs coefficients into AVX2 registers (256-bit, 4 doubles each)
        let rhs_low = _mm256_loadu_pd(rhs.as_slice().as_ptr());
        let rhs_high = _mm256_loadu_pd(rhs.as_slice().as_ptr().add(4));

        // Result accumulators
        let mut result_low = _mm256_setzero_pd();
        let mut result_high = _mm256_setzero_pd();

        // Unrolled computation for the 8x8 geometric product. Each lhs
        // component is broadcast, multiplied against a signed permutation of
        // rhs (one row of the Cl(3,0,0) Cayley table), and accumulated.
        // Note: _mm256_set_pd takes lanes in high-to-low order.

        // Scalar * all components
        let scalar_lhs = _mm256_set1_pd(lhs.get(0));
        result_low = _mm256_fmadd_pd(scalar_lhs, rhs_low, result_low);
        result_high = _mm256_fmadd_pd(scalar_lhs, rhs_high, result_high);

        // e1 products: e1 * {1, e1, e2, e12, e3, e13, e23, e123}
        //            = {e1, 1, e12, e2, e13, e3, e123, e23}, all positive
        let e1_lhs = _mm256_set1_pd(lhs.get(1));
        let e1_pattern_low = _mm256_set_pd(rhs.get(2), rhs.get(3), rhs.get(0), rhs.get(1));
        let e1_pattern_high = _mm256_set_pd(rhs.get(6), rhs.get(7), rhs.get(4), rhs.get(5));
        result_low = _mm256_fmadd_pd(e1_lhs, e1_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e1_lhs, e1_pattern_high, result_high);

        // e2 products (rhs blades containing e1 pick up a sign flip)
        let e2_lhs = _mm256_set1_pd(lhs.get(2));
        let e2_pattern_low = _mm256_set_pd(-rhs.get(1), rhs.get(0), -rhs.get(3), rhs.get(2));
        let e2_pattern_high = _mm256_set_pd(-rhs.get(5), rhs.get(4), -rhs.get(7), rhs.get(6));
        result_low = _mm256_fmadd_pd(e2_lhs, e2_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e2_lhs, e2_pattern_high, result_high);

        // e3 products
        let e3_lhs = _mm256_set1_pd(lhs.get(4));
        let e3_pattern_low = _mm256_set_pd(rhs.get(7), -rhs.get(6), -rhs.get(5), rhs.get(4));
        let e3_pattern_high = _mm256_set_pd(rhs.get(3), -rhs.get(2), -rhs.get(1), rhs.get(0));
        result_low = _mm256_fmadd_pd(e3_lhs, e3_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e3_lhs, e3_pattern_high, result_high);

        // e12 products (e12 squares to -1)
        let e12_lhs = _mm256_set1_pd(lhs.get(3));
        let e12_pattern_low = _mm256_set_pd(rhs.get(0), -rhs.get(1), rhs.get(2), -rhs.get(3));
        let e12_pattern_high = _mm256_set_pd(rhs.get(4), -rhs.get(5), rhs.get(6), -rhs.get(7));
        result_low = _mm256_fmadd_pd(e12_lhs, e12_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e12_lhs, e12_pattern_high, result_high);

        // e13 products
        let e13_lhs = _mm256_set1_pd(lhs.get(5));
        let e13_pattern_low = _mm256_set_pd(-rhs.get(6), rhs.get(7), rhs.get(4), -rhs.get(5));
        let e13_pattern_high = _mm256_set_pd(-rhs.get(2), rhs.get(3), rhs.get(0), -rhs.get(1));
        result_low = _mm256_fmadd_pd(e13_lhs, e13_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e13_lhs, e13_pattern_high, result_high);

        // e23 products
        let e23_lhs = _mm256_set1_pd(lhs.get(6));
        let e23_pattern_low = _mm256_set_pd(rhs.get(5), rhs.get(4), -rhs.get(7), -rhs.get(6));
        let e23_pattern_high = _mm256_set_pd(rhs.get(1), rhs.get(0), -rhs.get(3), -rhs.get(2));
        result_low = _mm256_fmadd_pd(e23_lhs, e23_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e23_lhs, e23_pattern_high, result_high);

        // e123 products (the pseudoscalar commutes with everything in Cl(3,0,0))
        let e123_lhs = _mm256_set1_pd(lhs.get(7));
        let e123_pattern_low = _mm256_set_pd(rhs.get(4), rhs.get(5), -rhs.get(6), -rhs.get(7));
        let e123_pattern_high = _mm256_set_pd(rhs.get(0), rhs.get(1), -rhs.get(2), -rhs.get(3));
        result_low = _mm256_fmadd_pd(e123_lhs, e123_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e123_lhs, e123_pattern_high, result_high);

        // Store results back to memory
        let mut coeffs = [0.0; 8];
        _mm256_storeu_pd(coeffs.as_mut_ptr(), result_low);
        _mm256_storeu_pd(coeffs.as_mut_ptr().add(4), result_high);

        Multivector::from_coefficients(coeffs.to_vec())
    }
}
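// Worked example for one pattern row above, as a sign sanity check: for the
// e2 component of lhs,
//
//   e2 * 1   =  e2      e2 * e1  = -e12     e2 * e2  =  1      e2 * e12  = -e1
//   e2 * e3  =  e23     e2 * e13 = -e123    e2 * e23 =  e3     e2 * e123 = -e13
//
// so result lane k receives sign(e2, b) * rhs[b] with b = 2 ^ k, which is
// exactly the signed permutation loaded into e2_pattern_low/e2_pattern_high.
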
/// SIMD-optimized geometric product using SSE2 (fallback for older CPUs)
///
/// Unlike the fully unrolled AVX2 path, this fallback drives the same
/// Cayley table from a small sign array, which keeps the implementation
/// complete and correct at a modest cost in instruction-level parallelism.
#[cfg(all(
    target_arch = "x86_64",
    target_feature = "sse2",
    not(target_feature = "avx2")
))]
#[inline(always)]
pub fn geometric_product_3d_sse2(
    lhs: &Multivector<3, 0, 0>,
    rhs: &Multivector<3, 0, 0>,
) -> Multivector<3, 0, 0> {
    // Sign of e_a * e_b in Cl(3,0,0) for the binary-indexed basis
    // [1, e1, e2, e12, e3, e13, e23, e123]; the product of components
    // a and b always lands in coefficient a ^ b.
    const SIGNS: [[f64; 8]; 8] = [
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
        [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
        [1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0],
        [1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
        [1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0],
        [1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0],
        [1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0],
    ];

    unsafe {
        let mut coeffs = [0.0; 8];

        for a in 0..8 {
            // Broadcast one lhs coefficient, build the signed/permuted rhs
            // pattern for it, then accumulate two result lanes at a time
            // (128-bit SSE registers hold 2 doubles each).
            let lhs_a = _mm_set1_pd(lhs.get(a));
            let mut pattern = [0.0; 8];
            for (k, slot) in pattern.iter_mut().enumerate() {
                let b = a ^ k;
                *slot = SIGNS[a][b] * rhs.get(b);
            }

            for chunk in 0..4 {
                let acc = _mm_loadu_pd(coeffs.as_ptr().add(2 * chunk));
                let pat = _mm_loadu_pd(pattern.as_ptr().add(2 * chunk));
                let acc = _mm_add_pd(acc, _mm_mul_pd(lhs_a, pat));
                _mm_storeu_pd(coeffs.as_mut_ptr().add(2 * chunk), acc);
            }
        }

        Multivector::from_coefficients(coeffs.to_vec())
    }
}

/// Optimized batch geometric product for processing multiple multivector pairs
#[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
pub fn batch_geometric_product_avx2(
    lhs_batch: &[f64],
    rhs_batch: &[f64],
    result_batch: &mut [f64],
) {
    const COEFFS_PER_MV: usize = 8;
    let num_pairs = lhs_batch.len() / COEFFS_PER_MV;
    debug_assert_eq!(lhs_batch.len(), rhs_batch.len());
    debug_assert_eq!(lhs_batch.len(), result_batch.len());

    for i in 0..num_pairs {
        let offset = i * COEFFS_PER_MV;

        // Create temporary multivectors from slices
        let lhs_coeffs = lhs_batch[offset..offset + COEFFS_PER_MV].to_vec();
        let rhs_coeffs = rhs_batch[offset..offset + COEFFS_PER_MV].to_vec();

        let lhs_mv = Multivector::<3, 0, 0>::from_coefficients(lhs_coeffs);
        let rhs_mv = Multivector::<3, 0, 0>::from_coefficients(rhs_coeffs);

        // Compute the product using the SIMD kernel
        let result_mv = geometric_product_3d_avx2(&lhs_mv, &rhs_mv);

        // Copy the result back into the flat output buffer
        result_batch[offset..offset + COEFFS_PER_MV].copy_from_slice(result_mv.as_slice());
    }
}
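// Usage sketch (hypothetical caller, not part of this module's API): the
// flat buffers hold `n` multivectors back to back, 8 doubles each.
//
//     let n = 16;
//     let lhs = vec![0.0; n * 8];
//     let rhs = vec![0.0; n * 8];
//     let mut out = vec![0.0; n * 8];
//     batch_geometric_product_avx2(&lhs, &rhs, &mut out);
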
/// CPU feature detection for optimal code path selection
///
/// In std environments this uses runtime detection; in no_std builds it falls
/// back to compile-time detection. Note that the SIMD kernels above only
/// exist when the matching target features are statically enabled, so the
/// runtime checks can merely confirm a capability the build already assumed,
/// or veto it when the binary runs on an older CPU.
#[allow(unreachable_code)]
pub fn select_geometric_product_impl(
) -> fn(&Multivector<3, 0, 0>, &Multivector<3, 0, 0>) -> Multivector<3, 0, 0> {
    // For no_std environments, use compile-time feature detection
    #[cfg(all(
        not(feature = "std"),
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    ))]
    {
        return geometric_product_3d_avx2;
    }

    #[cfg(all(
        not(feature = "std"),
        target_arch = "x86_64",
        target_feature = "sse2",
        not(target_feature = "avx2")
    ))]
    {
        return geometric_product_3d_sse2;
    }

    // For std environments, use runtime detection
    #[cfg(all(
        feature = "std",
        target_arch = "x86_64",
        target_feature = "avx2",
        target_feature = "fma"
    ))]
    {
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            return geometric_product_3d_avx2;
        }
    }

    #[cfg(all(
        feature = "std",
        target_arch = "x86_64",
        target_feature = "sse2",
        not(target_feature = "avx2")
    ))]
    {
        if is_x86_feature_detected!("sse2") {
            return geometric_product_3d_sse2;
        }
    }

    // Fallback to the scalar implementation
    |lhs, rhs| lhs.geometric_product(rhs)
}
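// The returned fn pointer lets callers hoist dispatch out of hot loops; a
// minimal sketch, where `pairs` is a hypothetical iterator of operands:
//
//     let gp = select_geometric_product_impl();
//     for (a, b) in pairs {
//         let product = gp(a, b);
//         // ... use `product` ...
//     }
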
/// Memory-aligned buffer for SIMD operations
#[repr(C, align(32))]
pub struct AlignedBuffer<const N: usize> {
    pub data: [f64; N],
}

impl<const N: usize> AlignedBuffer<N> {
    pub fn new() -> Self {
        Self { data: [0.0; N] }
    }

    pub fn as_ptr(&self) -> *const f64 {
        self.data.as_ptr()
    }

    pub fn as_mut_ptr(&mut self) -> *mut f64 {
        self.data.as_mut_ptr()
    }
}

impl<const N: usize> Default for AlignedBuffer<N> {
    fn default() -> Self {
        Self::new()
    }
}
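// Because the buffer guarantees 32-byte alignment, the aligned load/store
// intrinsics become usable on it. A minimal sketch, assuming an AVX2 build:
//
//     let buffer = AlignedBuffer::<8>::new();
//     let v = unsafe { _mm256_load_pd(buffer.as_ptr()) };
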
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Multivector;
    use approx::assert_relative_eq;

    type Cl3 = Multivector<3, 0, 0>;

    #[test]
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
    fn test_simd_geometric_product_correctness() {
        let e1 = Cl3::basis_vector(0);
        let e2 = Cl3::basis_vector(1);

        // Test against the scalar implementation
        let scalar_result = e1.geometric_product(&e2);
        let simd_result = geometric_product_3d_avx2(&e1, &e2);

        for i in 0..8 {
            assert_relative_eq!(scalar_result.get(i), simd_result.get(i), epsilon = 1e-14);
        }
    }

    #[test]
    fn test_aligned_buffer() {
        let mut buffer = AlignedBuffer::<8>::new();
        buffer.data[0] = 1.0;
        assert_eq!(buffer.data[0], 1.0);

        // Verify the 32-byte alignment promised by #[repr(C, align(32))]
        let ptr = buffer.as_ptr() as usize;
        assert_eq!(ptr % 32, 0);
    }

    #[test]
    #[ignore] // Temporarily ignored while SIMD is disabled
    fn test_runtime_feature_detection() {
        let impl_fn = select_geometric_product_impl();

        let e1 = Cl3::basis_vector(0);
        let e2 = Cl3::basis_vector(1);
        let result = impl_fn(&e1, &e2);

        // Should match the scalar implementation
        let expected = e1.geometric_product(&e2);
        for i in 0..8 {
            assert_relative_eq!(result.get(i), expected.get(i), epsilon = 1e-14);
        }
    }
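
    // Cross-check of the batch kernel against the scalar product. Assumes
    // only APIs already used in this module; the flat buffers follow the
    // 8-doubles-per-multivector layout that batch_geometric_product_avx2
    // expects.
    #[test]
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2", target_feature = "fma"))]
    fn test_batch_geometric_product_matches_scalar() {
        let a = Cl3::basis_vector(0);
        let b = Cl3::basis_vector(1);

        // Four identical (a, b) pairs packed back to back
        let mut lhs = Vec::new();
        let mut rhs = Vec::new();
        for _ in 0..4 {
            lhs.extend_from_slice(a.as_slice());
            rhs.extend_from_slice(b.as_slice());
        }
        let mut out = vec![0.0; lhs.len()];
        batch_geometric_product_avx2(&lhs, &rhs, &mut out);

        let expected = a.geometric_product(&b);
        for pair in out.chunks_exact(8) {
            for i in 0..8 {
                assert_relative_eq!(pair[i], expected.get(i), epsilon = 1e-14);
            }
        }
    }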
}