#[cfg(all(target_arch = "x86_64", any(target_feature = "avx2", target_feature = "sse2")))]
use core::arch::x86_64::*;

use crate::Multivector;

/// Geometric product of two `Cl(3,0,0)` multivectors using AVX2/FMA.
///
/// Coefficients are indexed in bitmask (binary) basis order, matching the
/// accessors used below: `[scalar, e1, e2, e12, e3, e13, e23, e123]`.
/// Note that `_mm256_fmadd_pd` also requires the `fma` target feature, which
/// in practice is enabled alongside AVX2 (e.g. `-C target-cpu=x86-64-v3`).
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[inline(always)]
pub fn geometric_product_3d_avx2(
    lhs: &Multivector<3, 0, 0>,
    rhs: &Multivector<3, 0, 0>,
) -> Multivector<3, 0, 0> {
    unsafe {
        let rhs_low = _mm256_loadu_pd(rhs.as_slice().as_ptr());
        let rhs_high = _mm256_loadu_pd(rhs.as_slice().as_ptr().add(4));

        // Accumulators for result coefficients 0..=3 (low) and 4..=7 (high).
        let mut result_low = _mm256_setzero_pd();
        let mut result_high = _mm256_setzero_pd();

        // Each lhs coefficient is broadcast and multiplied by a permuted,
        // sign-adjusted view of rhs taken from the Cl(3,0) Cayley table.
        // `_mm256_setr_pd` lists its lanes in memory order (result indices
        // 0..=3 for the low half, 4..=7 for the high half).

        // scalar term: result[k] += lhs[0] * rhs[k]
        let scalar_lhs = _mm256_set1_pd(lhs.get(0));
        result_low = _mm256_fmadd_pd(scalar_lhs, rhs_low, result_low);
        result_high = _mm256_fmadd_pd(scalar_lhs, rhs_high, result_high);

        // e1 row
        let e1_lhs = _mm256_set1_pd(lhs.get(1));
        let e1_pattern_low = _mm256_setr_pd(rhs.get(1), rhs.get(0), rhs.get(3), rhs.get(2));
        let e1_pattern_high = _mm256_setr_pd(rhs.get(5), rhs.get(4), rhs.get(7), rhs.get(6));
        result_low = _mm256_fmadd_pd(e1_lhs, e1_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e1_lhs, e1_pattern_high, result_high);

        // e2 row
        let e2_lhs = _mm256_set1_pd(lhs.get(2));
        let e2_pattern_low = _mm256_setr_pd(rhs.get(2), -rhs.get(3), rhs.get(0), -rhs.get(1));
        let e2_pattern_high = _mm256_setr_pd(rhs.get(6), -rhs.get(7), rhs.get(4), -rhs.get(5));
        result_low = _mm256_fmadd_pd(e2_lhs, e2_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e2_lhs, e2_pattern_high, result_high);

        // e3 row
        let e3_lhs = _mm256_set1_pd(lhs.get(4));
        let e3_pattern_low = _mm256_setr_pd(rhs.get(4), -rhs.get(5), -rhs.get(6), rhs.get(7));
        let e3_pattern_high = _mm256_setr_pd(rhs.get(0), -rhs.get(1), -rhs.get(2), rhs.get(3));
        result_low = _mm256_fmadd_pd(e3_lhs, e3_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e3_lhs, e3_pattern_high, result_high);

        // e12 row
        let e12_lhs = _mm256_set1_pd(lhs.get(3));
        let e12_pattern_low = _mm256_setr_pd(-rhs.get(3), rhs.get(2), -rhs.get(1), rhs.get(0));
        let e12_pattern_high = _mm256_setr_pd(-rhs.get(7), rhs.get(6), -rhs.get(5), rhs.get(4));
        result_low = _mm256_fmadd_pd(e12_lhs, e12_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e12_lhs, e12_pattern_high, result_high);

        // e13 row
        let e13_lhs = _mm256_set1_pd(lhs.get(5));
        let e13_pattern_low = _mm256_setr_pd(-rhs.get(5), rhs.get(4), rhs.get(7), -rhs.get(6));
        let e13_pattern_high = _mm256_setr_pd(-rhs.get(1), rhs.get(0), rhs.get(3), -rhs.get(2));
        result_low = _mm256_fmadd_pd(e13_lhs, e13_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e13_lhs, e13_pattern_high, result_high);

        // e23 row
        let e23_lhs = _mm256_set1_pd(lhs.get(6));
        let e23_pattern_low = _mm256_setr_pd(-rhs.get(6), -rhs.get(7), rhs.get(4), rhs.get(5));
        let e23_pattern_high = _mm256_setr_pd(-rhs.get(2), -rhs.get(3), rhs.get(0), rhs.get(1));
        result_low = _mm256_fmadd_pd(e23_lhs, e23_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e23_lhs, e23_pattern_high, result_high);

        // e123 (pseudoscalar) row
        let e123_lhs = _mm256_set1_pd(lhs.get(7));
        let e123_pattern_low = _mm256_setr_pd(-rhs.get(7), -rhs.get(6), rhs.get(5), rhs.get(4));
        let e123_pattern_high = _mm256_setr_pd(-rhs.get(3), -rhs.get(2), rhs.get(1), rhs.get(0));
        result_low = _mm256_fmadd_pd(e123_lhs, e123_pattern_low, result_low);
        result_high = _mm256_fmadd_pd(e123_lhs, e123_pattern_high, result_high);

        let mut coeffs = [0.0; 8];
        _mm256_storeu_pd(coeffs.as_mut_ptr(), result_low);
        _mm256_storeu_pd(coeffs.as_mut_ptr().add(4), result_high);

        Multivector::from_coefficients(coeffs.to_vec())
    }
}

/// SSE2 fallback: same Cayley-table scheme as the AVX2 path, accumulating the
/// eight result coefficients two lanes at a time.
#[cfg(all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")))]
#[inline(always)]
pub fn geometric_product_3d_sse2(
    lhs: &Multivector<3, 0, 0>,
    rhs: &Multivector<3, 0, 0>,
) -> Multivector<3, 0, 0> {
    unsafe {
        let rhs_0_1 = _mm_loadu_pd(rhs.as_slice().as_ptr());
        let rhs_2_3 = _mm_loadu_pd(rhs.as_slice().as_ptr().add(2));
        let rhs_4_5 = _mm_loadu_pd(rhs.as_slice().as_ptr().add(4));
        let rhs_6_7 = _mm_loadu_pd(rhs.as_slice().as_ptr().add(6));

        // scalar term: result[k] = lhs[0] * rhs[k]
        let scalar_lhs = _mm_set1_pd(lhs.get(0));
        let mut result_0_1 = _mm_mul_pd(scalar_lhs, rhs_0_1);
        let mut result_2_3 = _mm_mul_pd(scalar_lhs, rhs_2_3);
        let mut result_4_5 = _mm_mul_pd(scalar_lhs, rhs_4_5);
        let mut result_6_7 = _mm_mul_pd(scalar_lhs, rhs_6_7);

        // Remaining rows of the Cl(3,0) Cayley table; `_mm_setr_pd` lists its
        // two lanes in memory order (result indices k, k + 1).
        let e1_lhs = _mm_set1_pd(lhs.get(1));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e1_lhs, _mm_setr_pd(rhs.get(1), rhs.get(0))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e1_lhs, _mm_setr_pd(rhs.get(3), rhs.get(2))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e1_lhs, _mm_setr_pd(rhs.get(5), rhs.get(4))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e1_lhs, _mm_setr_pd(rhs.get(7), rhs.get(6))));

        let e2_lhs = _mm_set1_pd(lhs.get(2));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e2_lhs, _mm_setr_pd(rhs.get(2), -rhs.get(3))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e2_lhs, _mm_setr_pd(rhs.get(0), -rhs.get(1))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e2_lhs, _mm_setr_pd(rhs.get(6), -rhs.get(7))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e2_lhs, _mm_setr_pd(rhs.get(4), -rhs.get(5))));

        let e3_lhs = _mm_set1_pd(lhs.get(4));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e3_lhs, _mm_setr_pd(rhs.get(4), -rhs.get(5))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e3_lhs, _mm_setr_pd(-rhs.get(6), rhs.get(7))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e3_lhs, _mm_setr_pd(rhs.get(0), -rhs.get(1))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e3_lhs, _mm_setr_pd(-rhs.get(2), rhs.get(3))));

        let e12_lhs = _mm_set1_pd(lhs.get(3));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e12_lhs, _mm_setr_pd(-rhs.get(3), rhs.get(2))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e12_lhs, _mm_setr_pd(-rhs.get(1), rhs.get(0))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e12_lhs, _mm_setr_pd(-rhs.get(7), rhs.get(6))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e12_lhs, _mm_setr_pd(-rhs.get(5), rhs.get(4))));

        let e13_lhs = _mm_set1_pd(lhs.get(5));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e13_lhs, _mm_setr_pd(-rhs.get(5), rhs.get(4))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e13_lhs, _mm_setr_pd(rhs.get(7), -rhs.get(6))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e13_lhs, _mm_setr_pd(-rhs.get(1), rhs.get(0))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e13_lhs, _mm_setr_pd(rhs.get(3), -rhs.get(2))));

        let e23_lhs = _mm_set1_pd(lhs.get(6));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e23_lhs, _mm_setr_pd(-rhs.get(6), -rhs.get(7))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e23_lhs, _mm_setr_pd(rhs.get(4), rhs.get(5))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e23_lhs, _mm_setr_pd(-rhs.get(2), -rhs.get(3))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e23_lhs, _mm_setr_pd(rhs.get(0), rhs.get(1))));

        let e123_lhs = _mm_set1_pd(lhs.get(7));
        result_0_1 = _mm_add_pd(result_0_1, _mm_mul_pd(e123_lhs, _mm_setr_pd(-rhs.get(7), -rhs.get(6))));
        result_2_3 = _mm_add_pd(result_2_3, _mm_mul_pd(e123_lhs, _mm_setr_pd(rhs.get(5), rhs.get(4))));
        result_4_5 = _mm_add_pd(result_4_5, _mm_mul_pd(e123_lhs, _mm_setr_pd(-rhs.get(3), -rhs.get(2))));
        result_6_7 = _mm_add_pd(result_6_7, _mm_mul_pd(e123_lhs, _mm_setr_pd(rhs.get(1), rhs.get(0))));

        let mut coeffs = [0.0; 8];
        _mm_storeu_pd(coeffs.as_mut_ptr(), result_0_1);
        _mm_storeu_pd(coeffs.as_mut_ptr().add(2), result_2_3);
        _mm_storeu_pd(coeffs.as_mut_ptr().add(4), result_4_5);
        _mm_storeu_pd(coeffs.as_mut_ptr().add(6), result_6_7);

        Multivector::from_coefficients(coeffs.to_vec())
    }
}

/// Multiplies corresponding multivectors from two flat coefficient buffers
/// (8 `f64` per multivector) and writes the products into `result_batch`.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
pub fn batch_geometric_product_avx2(
    lhs_batch: &[f64],
    rhs_batch: &[f64],
    result_batch: &mut [f64],
) {
    const COEFFS_PER_MV: usize = 8;
    let num_pairs = lhs_batch.len() / COEFFS_PER_MV;
    debug_assert!(rhs_batch.len() >= num_pairs * COEFFS_PER_MV);
    debug_assert!(result_batch.len() >= num_pairs * COEFFS_PER_MV);

    for i in 0..num_pairs {
        let offset = i * COEFFS_PER_MV;

        let lhs_coeffs = lhs_batch[offset..offset + COEFFS_PER_MV].to_vec();
        let rhs_coeffs = rhs_batch[offset..offset + COEFFS_PER_MV].to_vec();

        let lhs_mv = Multivector::<3, 0, 0>::from_coefficients(lhs_coeffs);
        let rhs_mv = Multivector::<3, 0, 0>::from_coefficients(rhs_coeffs);

        let result_mv = geometric_product_3d_avx2(&lhs_mv, &rhs_mv);

        result_batch[offset..offset + COEFFS_PER_MV].copy_from_slice(result_mv.as_slice());
    }
}

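// Illustrative usage sketch for the batch API (the helper name and data are
// made up for the example): multivectors are laid out back-to-back as flat
// `[f64]` buffers with 8 coefficients each.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[allow(dead_code)]
fn batch_usage_example() -> Vec<f64> {
    // Two input multivectors per operand, 2 * 8 = 16 coefficients.
    let lhs: Vec<f64> = (0..16).map(f64::from).collect();
    let rhs: Vec<f64> = vec![1.0; 16];
    let mut out = vec![0.0; 16];
    batch_geometric_product_avx2(&lhs, &rhs, &mut out);
    out
}
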
/// Picks the fastest geometric-product implementation that is both compiled
/// in and confirmed available at runtime, falling back to the scalar path.
pub fn select_geometric_product_impl(
) -> fn(&Multivector<3, 0, 0>, &Multivector<3, 0, 0>) -> Multivector<3, 0, 0> {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    {
        if is_x86_feature_detected!("avx2") {
            return geometric_product_3d_avx2;
        }
    }

    // Mirrors the cfg on `geometric_product_3d_sse2`, which is only compiled
    // when AVX2 is not enabled at build time.
    #[cfg(all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")))]
    {
        if is_x86_feature_detected!("sse2") {
            return geometric_product_3d_sse2;
        }
    }

    |lhs, rhs| lhs.geometric_product(rhs)
}

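// Illustrative usage sketch (the helper name is made up for the example):
// resolve the implementation once and reuse the returned function pointer in
// hot loops instead of re-running feature detection per call.
#[allow(dead_code)]
fn dispatch_usage_example() -> Multivector<3, 0, 0> {
    let product = select_geometric_product_impl();
    let a = Multivector::<3, 0, 0>::basis_vector(0);
    let b = Multivector::<3, 0, 0>::basis_vector(1);
    // e1 * e2 = e12 for any of the selected implementations.
    product(&a, &b)
}
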
/// Scratch storage with 32-byte alignment, suitable for aligned AVX2 loads
/// and stores.
#[repr(C, align(32))]
pub struct AlignedBuffer<const N: usize> {
    pub data: [f64; N],
}

impl<const N: usize> AlignedBuffer<N> {
    pub fn new() -> Self {
        Self { data: [0.0; N] }
    }

    pub fn as_ptr(&self) -> *const f64 {
        self.data.as_ptr()
    }

    pub fn as_mut_ptr(&mut self) -> *mut f64 {
        self.data.as_mut_ptr()
    }
}

impl<const N: usize> Default for AlignedBuffer<N> {
    fn default() -> Self {
        Self::new()
    }
}

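// Illustrative sketch (the helper name is made up for the example): because
// `AlignedBuffer` guarantees 32-byte alignment, its storage can be read with
// the aligned `_mm256_load_pd` instead of the unaligned loads used above.
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
#[allow(dead_code)]
fn aligned_load_example(buffer: &AlignedBuffer<8>) -> __m256d {
    // Safety: the pointer is valid for four f64 reads and 32-byte aligned.
    unsafe { _mm256_load_pd(buffer.as_ptr()) }
}
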
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Multivector;
    use approx::assert_relative_eq;

    type Cl3 = Multivector<3, 0, 0>;

    #[test]
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    fn test_simd_geometric_product_correctness() {
        let e1 = Cl3::basis_vector(0);
        let e2 = Cl3::basis_vector(1);

        let scalar_result = e1.geometric_product(&e2);
        let simd_result = geometric_product_3d_avx2(&e1, &e2);

        for i in 0..8 {
            assert_relative_eq!(scalar_result.get(i), simd_result.get(i), epsilon = 1e-14);
        }
    }

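    // Extra check (illustrative): compare the SIMD path against the scalar
    // `geometric_product` for full multivectors rather than single basis
    // vectors, so every Cayley-table entry is exercised.
    #[test]
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    fn test_simd_matches_scalar_for_general_multivectors() {
        let a = Cl3::from_coefficients(vec![1.0, 2.0, -3.0, 4.0, 0.5, -1.5, 2.5, -0.25]);
        let b = Cl3::from_coefficients(vec![-2.0, 1.0, 0.0, -4.0, 3.0, 2.0, -1.0, 0.75]);

        let expected = a.geometric_product(&b);
        let actual = geometric_product_3d_avx2(&a, &b);

        for i in 0..8 {
            assert_relative_eq!(expected.get(i), actual.get(i), epsilon = 1e-12);
        }
    }
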
    #[test]
    fn test_aligned_buffer() {
        let mut buffer = AlignedBuffer::<8>::new();
        buffer.data[0] = 1.0;
        assert_eq!(buffer.data[0], 1.0);

        let ptr = buffer.as_ptr() as usize;
        assert_eq!(ptr % 32, 0);
    }

    #[test]
    #[ignore]
    fn test_runtime_feature_detection() {
        let impl_fn = select_geometric_product_impl();

        let e1 = Cl3::basis_vector(0);
        let e2 = Cl3::basis_vector(1);
        let result = impl_fn(&e1, &e2);

        let expected = e1.geometric_product(&e2);
        for i in 0..8 {
            assert_relative_eq!(result.get(i), expected.get(i), epsilon = 1e-14);
        }
    }
}
283}