ligerito_binary_fields/
simd.rs

// src/simd.rs
use crate::poly::{BinaryPoly64, BinaryPoly128, BinaryPoly256};
use crate::elem::BinaryElem32;

// 64x64 -> 128 bit carryless multiplication
pub fn carryless_mul_64(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
    // x86_64 with PCLMULQDQ
    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
    {
        use core::arch::x86_64::*;

        unsafe {
            let a_vec = _mm_set_epi64x(0, a.value() as i64);
            let b_vec = _mm_set_epi64x(0, b.value() as i64);

            let result = _mm_clmulepi64_si128(a_vec, b_vec, 0x00);

            let lo = _mm_extract_epi64(result, 0) as u64;
            let hi = _mm_extract_epi64(result, 1) as u64;

            return BinaryPoly128::new(((hi as u128) << 64) | (lo as u128));
        }
    }

    // WASM with SIMD128
    #[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
    {
        return carryless_mul_64_wasm_simd(a, b);
    }

    // Software fallback for other platforms
    #[cfg(not(any(
        all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"),
        all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128")
    )))]
    {
        carryless_mul_64_soft(a, b)
    }
}

// software implementation for 64x64 (constant-time shift-and-mask)
#[allow(dead_code)]
fn carryless_mul_64_soft(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
    let mut result = 0u128;
    let a_val = a.value();
    let b_val = b.value();

    for i in 0..64 {
        // all-ones mask when bit i of b is set, all-zeros otherwise
        let mask = 0u128.wrapping_sub(((b_val >> i) & 1) as u128);
        result ^= ((a_val as u128) << i) & mask;
    }

    BinaryPoly128::new(result)
}

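// A small added sanity check (not part of the original suite; it assumes a
// `BinaryPoly64::new` constructor mirroring `BinaryPoly128::new`): verifies the
// software fallback on a hand-computed product, (x + 1)(x^2 + 1) = x^3 + x^2 + x + 1.
#[cfg(test)]
mod clmul64_sanity_tests {
    use super::*;

    #[test]
    fn soft_clmul_small_product() {
        let a = BinaryPoly64::new(0b11);  // x + 1
        let b = BinaryPoly64::new(0b101); // x^2 + 1
        let p = carryless_mul_64_soft(a, b);
        // carryless: (a << 0) ^ (a << 2) = 0b11 ^ 0b1100 = 0b1111
        assert_eq!(p.value(), 0b1111);
    }
}
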
// WASM fallback using Karatsuba decomposition over 32-bit halves
// (note: despite the simd128 gate, the body below is scalar; the gate only
// selects this path on wasm builds compiled with SIMD support)
#[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
fn carryless_mul_64_wasm_simd(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
    unsafe {
        let a_val = a.value();
        let b_val = b.value();

        // split into 32-bit halves for Karatsuba
        let a_lo = (a_val & 0xFFFFFFFF) as u32;
        let a_hi = (a_val >> 32) as u32;
        let b_lo = (b_val & 0xFFFFFFFF) as u32;
        let b_hi = (b_val >> 32) as u32;

        // three 32x32 multiplications instead of the schoolbook four
        let z0 = mul_32x32_to_64_simd(a_lo, b_lo);
        let z2 = mul_32x32_to_64_simd(a_hi, b_hi);
        let z1 = mul_32x32_to_64_simd(a_lo ^ a_hi, b_lo ^ b_hi);

        // Karatsuba middle term: (a_lo + a_hi)(b_lo + b_hi) - z0 - z2, with XOR as +/-
        let middle = z1 ^ z0 ^ z2;

        // combine: result = z0 + (middle << 32) + (z2 << 64)
        let result_lo = z0 ^ (middle << 32);
        let result_hi = (middle >> 32) ^ z2;

        BinaryPoly128::new(((result_hi as u128) << 64) | (result_lo as u128))
    }
}

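// An added consistency check (new test, platform-independent): the Karatsuba
// recombination used above must match a direct carryless multiply. It reuses the
// constant-time helper `mul_64x64_to_128` defined later in this file.
#[cfg(test)]
mod karatsuba_sanity_tests {
    use super::*;

    #[test]
    fn karatsuba_matches_direct_64x64() {
        let a: u64 = 0x0123_4567_89AB_CDEF;
        let b: u64 = 0xFEDC_BA98_7654_3210;

        // direct 64x64 product
        let direct = mul_64x64_to_128(a, b);

        // Karatsuba over 32-bit halves, mirroring carryless_mul_64_wasm_simd
        let (a_lo, a_hi) = (a & 0xFFFF_FFFF, a >> 32);
        let (b_lo, b_hi) = (b & 0xFFFF_FFFF, b >> 32);
        let z0 = mul_64x64_to_128(a_lo, b_lo);
        let z2 = mul_64x64_to_128(a_hi, b_hi);
        let middle = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi) ^ z0 ^ z2;
        let karatsuba = z0 ^ (middle << 32) ^ (z2 << 64);

        assert_eq!(direct, karatsuba);
    }
}
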
// 32x32 -> 64 carryless multiplication for WASM
// manually unrolled shift-and-XOR over 4-bit nibbles of b
// (note: the branches are data-dependent, so unlike the _soft paths this is not constant-time)
#[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
unsafe fn mul_32x32_to_64_simd(a: u32, b: u32) -> u64 {
    let mut result = 0u64;
    let a64 = a as u64;

    // process b in 4-bit nibbles (8 nibbles for 32 bits),
    // manually unrolled for better codegen
    let nibble0 = b & 0xF;
    if nibble0 & 1 != 0 { result ^= a64; }
    if nibble0 & 2 != 0 { result ^= a64 << 1; }
    if nibble0 & 4 != 0 { result ^= a64 << 2; }
    if nibble0 & 8 != 0 { result ^= a64 << 3; }

    let nibble1 = (b >> 4) & 0xF;
    if nibble1 & 1 != 0 { result ^= a64 << 4; }
    if nibble1 & 2 != 0 { result ^= a64 << 5; }
    if nibble1 & 4 != 0 { result ^= a64 << 6; }
    if nibble1 & 8 != 0 { result ^= a64 << 7; }

    let nibble2 = (b >> 8) & 0xF;
    if nibble2 & 1 != 0 { result ^= a64 << 8; }
    if nibble2 & 2 != 0 { result ^= a64 << 9; }
    if nibble2 & 4 != 0 { result ^= a64 << 10; }
    if nibble2 & 8 != 0 { result ^= a64 << 11; }

    let nibble3 = (b >> 12) & 0xF;
    if nibble3 & 1 != 0 { result ^= a64 << 12; }
    if nibble3 & 2 != 0 { result ^= a64 << 13; }
    if nibble3 & 4 != 0 { result ^= a64 << 14; }
    if nibble3 & 8 != 0 { result ^= a64 << 15; }

    let nibble4 = (b >> 16) & 0xF;
    if nibble4 & 1 != 0 { result ^= a64 << 16; }
    if nibble4 & 2 != 0 { result ^= a64 << 17; }
    if nibble4 & 4 != 0 { result ^= a64 << 18; }
    if nibble4 & 8 != 0 { result ^= a64 << 19; }

    let nibble5 = (b >> 20) & 0xF;
    if nibble5 & 1 != 0 { result ^= a64 << 20; }
    if nibble5 & 2 != 0 { result ^= a64 << 21; }
    if nibble5 & 4 != 0 { result ^= a64 << 22; }
    if nibble5 & 8 != 0 { result ^= a64 << 23; }

    let nibble6 = (b >> 24) & 0xF;
    if nibble6 & 1 != 0 { result ^= a64 << 24; }
    if nibble6 & 2 != 0 { result ^= a64 << 25; }
    if nibble6 & 4 != 0 { result ^= a64 << 26; }
    if nibble6 & 8 != 0 { result ^= a64 << 27; }

    let nibble7 = (b >> 28) & 0xF;
    if nibble7 & 1 != 0 { result ^= a64 << 28; }
    if nibble7 & 2 != 0 { result ^= a64 << 29; }
    if nibble7 & 4 != 0 { result ^= a64 << 30; }
    if nibble7 & 8 != 0 { result ^= a64 << 31; }

    result
}

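// A sketch of a target-gated check (added; assumes the tests are run under a wasm
// test harness, e.g. `cargo test --target wasm32-wasip1` with the matching feature
// flags): the unrolled nibble path must agree with the constant-time 64x64 helper.
#[cfg(all(test, feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
mod wasm_mul32_tests {
    use super::*;

    #[test]
    fn unrolled_matches_constant_time() {
        let a: u32 = 0xDEAD_BEEF;
        let b: u32 = 0x1234_5678;
        // SAFETY: the function performs no unsafe operations despite its signature
        let fast = unsafe { mul_32x32_to_64_simd(a, b) };
        // a 32x32 carryless product fits in 64 bits, so the truncation is lossless
        let reference = mul_64x64_to_128(a as u64, b as u64) as u64;
        assert_eq!(fast, reference);
    }
}
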
// 128x128 -> 128 bit carryless multiplication (truncated to the low 128 bits)
pub fn carryless_mul_128(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly128 {
    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
    {
        use core::arch::x86_64::*;

        unsafe {
            // split inputs into 64-bit halves
            let a_lo = a.value() as u64;
            let a_hi = (a.value() >> 64) as u64;
            let b_lo = b.value() as u64;
            let b_hi = (b.value() >> 64) as u64;

            // perform 3 64x64 -> 128 bit multiplications
            // (hi*hi only contributes at bit 128 and above, so it is skipped)
            let lo_lo = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_lo as i64),
                _mm_set_epi64x(0, b_lo as i64),
                0x00
            );

            let lo_hi = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_lo as i64),
                _mm_set_epi64x(0, b_hi as i64),
                0x00
            );

            let hi_lo = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_hi as i64),
                _mm_set_epi64x(0, b_lo as i64),
                0x00
            );

            // assemble the 128-bit results (cast each 64-bit half to u128
            // before shifting to avoid overflow)
            let r0 = (_mm_extract_epi64(lo_lo, 0) as u64) as u128
                   | ((_mm_extract_epi64(lo_lo, 1) as u64) as u128) << 64;
            let r1 = (_mm_extract_epi64(lo_hi, 0) as u64) as u128
                   | ((_mm_extract_epi64(lo_hi, 1) as u64) as u128) << 64;
            let r2 = (_mm_extract_epi64(hi_lo, 0) as u64) as u128
                   | ((_mm_extract_epi64(hi_lo, 1) as u64) as u128) << 64;

            // combine: result = r0 ^ (r1 << 64) ^ (r2 << 64), truncated to 128 bits
            let result = r0 ^ (r1 << 64) ^ (r2 << 64);

            return BinaryPoly128::new(result);
        }
    }

    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
    {
        // software fallback
        carryless_mul_128_soft(a, b)
    }
}

// software implementation for 128x128 truncated
#[allow(dead_code)]
fn carryless_mul_128_soft(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly128 {
    let a_lo = a.value() as u64;
    let a_hi = (a.value() >> 64) as u64;
    let b_lo = b.value() as u64;
    let b_hi = (b.value() >> 64) as u64;

    let z0 = mul_64x64_to_128(a_lo, b_lo);
    let z1 = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi);
    let z2 = mul_64x64_to_128(a_hi, b_hi);

    // Karatsuba combination, truncated: the middle term is z1 ^ z0 ^ z2,
    // so result = z0 ^ ((z1 ^ z0 ^ z2) << 64) mod 2^128
    let result = z0 ^ (z1 << 64) ^ (z0 << 64) ^ (z2 << 64);
    BinaryPoly128::new(result)
}

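// An added cross-check (new test): the truncated 128x128 product must equal the
// low half of the full 256-bit product, on whichever backend is compiled in.
#[cfg(test)]
mod truncated_mul_tests {
    use super::*;

    #[test]
    fn truncated_matches_low_half_of_full() {
        let a = BinaryPoly128::new(0x0123_4567_89AB_CDEF_0F1E_2D3C_4B5A_6978);
        let b = BinaryPoly128::new(0xFFEE_DDCC_BBAA_9988_7766_5544_3322_1100);

        let truncated = carryless_mul_128(a, b);
        let (_hi, lo) = carryless_mul_128_full(a, b).split();

        assert_eq!(truncated.value(), lo.value());
    }
}
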
// 128x128 -> 256 bit full multiplication
pub fn carryless_mul_128_full(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly256 {
    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
    {
        use core::arch::x86_64::*;

        unsafe {
            let a_lo = a.value() as u64;
            let a_hi = (a.value() >> 64) as u64;
            let b_lo = b.value() as u64;
            let b_hi = (b.value() >> 64) as u64;

            // 4 multiplications (schoolbook)
            let lo_lo = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_lo as i64),
                _mm_set_epi64x(0, b_lo as i64),
                0x00
            );

            let lo_hi = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_lo as i64),
                _mm_set_epi64x(0, b_hi as i64),
                0x00
            );

            let hi_lo = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_hi as i64),
                _mm_set_epi64x(0, b_lo as i64),
                0x00
            );

            let hi_hi = _mm_clmulepi64_si128(
                _mm_set_epi64x(0, a_hi as i64),
                _mm_set_epi64x(0, b_hi as i64),
                0x00
            );

            // extract and combine
            let r0_lo = _mm_extract_epi64(lo_lo, 0) as u64;
            let r0_hi = _mm_extract_epi64(lo_lo, 1) as u64;
            let r1_lo = _mm_extract_epi64(lo_hi, 0) as u64;
            let r1_hi = _mm_extract_epi64(lo_hi, 1) as u64;
            let r2_lo = _mm_extract_epi64(hi_lo, 0) as u64;
            let r2_hi = _mm_extract_epi64(hi_lo, 1) as u64;
            let r3_lo = _mm_extract_epi64(hi_hi, 0) as u64;
            let r3_hi = _mm_extract_epi64(hi_hi, 1) as u64;

            // build 256-bit result
            let mut lo = r0_lo as u128 | ((r0_hi as u128) << 64);
            let mut hi = 0u128;

            // add r1 << 64: the low half of r1 lands in bits 64..127,
            // the high half in bits 128..191
            lo ^= (r1_lo as u128) << 64;
            hi ^= r1_hi as u128;

            // add r2 << 64
            lo ^= (r2_lo as u128) << 64;
            hi ^= r2_hi as u128;

            // add r3 << 128
            hi ^= r3_lo as u128 | ((r3_hi as u128) << 64);

            return BinaryPoly256::from_parts(hi, lo);
        }
    }

    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
    {
        // software fallback
        carryless_mul_128_full_soft(a, b)
    }
}

// software implementation for 128x128 full
#[allow(dead_code)]
fn carryless_mul_128_full_soft(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly256 {
    let a_lo = a.value() as u64;
    let a_hi = (a.value() >> 64) as u64;
    let b_lo = b.value() as u64;
    let b_hi = (b.value() >> 64) as u64;

    let z0 = mul_64x64_to_128(a_lo, b_lo);
    let z2 = mul_64x64_to_128(a_hi, b_hi);
    let z1 = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi) ^ z0 ^ z2;

    // combine: result = z0 + (z1 << 64) + (z2 << 128)
    let mut lo = z0;
    let mut hi = 0u128;

    // add z1 << 64
    lo ^= z1 << 64;
    hi ^= z1 >> 64;

    // add z2 << 128 (lands entirely in the high word)
    hi ^= z2;

    BinaryPoly256::from_parts(hi, lo)
}

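// An added cross-check (new test): the 3-multiplication Karatsuba combination must
// agree with the 4-multiplication schoolbook product built from the same helper.
#[cfg(test)]
mod full_mul_tests {
    use super::*;

    #[test]
    fn karatsuba_full_matches_schoolbook() {
        let a = BinaryPoly128::new(0x1122_3344_5566_7788_99AA_BBCC_DDEE_FF00);
        let b = BinaryPoly128::new(0x0F0E_0D0C_0B0A_0908_0706_0504_0302_0100);

        let karatsuba = carryless_mul_128_full_soft(a, b);

        // schoolbook: four 64x64 products at offsets 0, 64, 64 and 128
        let (a_lo, a_hi) = (a.value() as u64, (a.value() >> 64) as u64);
        let (b_lo, b_hi) = (b.value() as u64, (b.value() >> 64) as u64);
        let z0 = mul_64x64_to_128(a_lo, b_lo);
        let m1 = mul_64x64_to_128(a_lo, b_hi);
        let m2 = mul_64x64_to_128(a_hi, b_lo);
        let z2 = mul_64x64_to_128(a_hi, b_hi);

        let lo = z0 ^ (m1 << 64) ^ (m2 << 64);
        let hi = (m1 >> 64) ^ (m2 >> 64) ^ z2;

        let (k_hi, k_lo) = karatsuba.split();
        assert_eq!(k_lo.value(), lo);
        assert_eq!(k_hi.value(), hi);
    }
}
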
// helper: constant-time 64x64 -> 128
#[inline(always)]
#[allow(dead_code)]
fn mul_64x64_to_128(a: u64, b: u64) -> u128 {
    let mut result = 0u128;
    let mut a_shifted = a as u128;

    for i in 0..64 {
        let mask = 0u128.wrapping_sub(((b >> i) & 1) as u128);
        result ^= a_shifted & mask;
        a_shifted <<= 1;
    }

    result
}

// batch field operations

use crate::{BinaryElem128, BinaryFieldElement};

/// batch multiply GF(2^128) elements with two-tier dispatch:
/// hardware-accel → pclmulqdq, else → scalar
pub fn batch_mul_gf128(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
    {
        return batch_mul_gf128_hw(a, b, out);
    }

    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
    {
        // scalar fallback
        for i in 0..a.len() {
            out[i] = a[i].mul(&b[i]);
        }
    }
}

/// batch add GF(2^128) elements (addition in GF(2^n) is XOR)
pub fn batch_add_gf128(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    // scalar path (XOR is already very fast; no SIMD needed)
    for i in 0..a.len() {
        out[i] = a[i].add(&b[i]);
    }
}

// pclmulqdq-based batch multiply for x86_64
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
fn batch_mul_gf128_hw(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
    for i in 0..a.len() {
        let a_poly = a[i].poly();
        let b_poly = b[i].poly();
        let product = carryless_mul_128_full(a_poly, b_poly);
        let reduced = reduce_gf128(product);
        out[i] = BinaryElem128::from_value(reduced.value());
    }
}

/// reduce 256-bit product modulo the GF(2^128) irreducible polynomial
/// irreducible: x^128 + x^7 + x^2 + x + 1 (low bits 0b10000111 = 0x87)
/// matches julia's @generated mod_irreducible (binaryfield.jl:73-114)
#[inline(always)]
pub fn reduce_gf128(product: BinaryPoly256) -> BinaryPoly128 {
    let (hi, lo) = product.split();
    let high = hi.value();
    let low = lo.value();

    // pre-fold the bits of hi * (x^7 + x^2 + x + 1) that overflow past bit 127:
    // set bits 1, 2, 7 contribute hi >> 127, hi >> 126, hi >> 121
    // (bit 0 overflows nothing, so `high` itself is the base term)
    let tmp = high ^ (high >> 127) ^ (high >> 126) ^ (high >> 121);

    // multiply the folded value by the low part of the irreducible:
    // for each set bit i in {0, 1, 2, 7}: res ^= tmp << i
    let res = low ^ tmp ^ (tmp << 1) ^ (tmp << 2) ^ (tmp << 7);

    BinaryPoly128::new(res)
}

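// An added spot check (new test): since x^128 ≡ x^7 + x^2 + x + 1 (mod p), reducing
// a 256-bit value whose only set bit is bit 128 must yield exactly 0x87.
#[cfg(test)]
mod reduce_gf128_tests {
    use super::*;

    #[test]
    fn x128_reduces_to_0x87() {
        let product = BinaryPoly256::from_parts(1, 0); // hi = 1, i.e. the bit at x^128
        assert_eq!(reduce_gf128(product).value(), 0x87);
    }
}
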
// =========================================================================
// BinaryElem32 batch operations - FFT optimization
// =========================================================================

/// AVX-512 vectorized FFT butterfly for GF(2^32)
/// processes 4 elements at once using 256-bit vectors with VPCLMULQDQ,
/// falling back to the SSE path when AVX-512 is unavailable at runtime
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
pub fn fft_butterfly_gf32_avx512(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    // runtime detection for VPCLMULQDQ (AVX-512 carryless multiply); the function
    // is already cfg-gated to x86_64, so no separate arch branch is needed here
    if is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx512f") {
        // SAFETY: we just checked that the CPU supports these features
        unsafe { fft_butterfly_gf32_avx512_impl(u, w, lambda) };
        return;
    }

    // fallback to SSE if AVX-512 is not available
    fft_butterfly_gf32_sse(u, w, lambda)
}

/// AVX-512 implementation (unsafe; the caller must have verified the CPU features)
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
#[target_feature(enable = "avx512f,vpclmulqdq")]
unsafe fn fft_butterfly_gf32_avx512_impl(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    // irreducible polynomial for GF(2^32): x^32 + x^15 + x^9 + x^7 + x^3 + 1
    const IRREDUCIBLE_32: u64 = (1u64 << 32) | (1 << 15) | (1 << 9) | (1 << 7) | (1 << 3) | 1;

    let lambda_val = lambda.poly().value() as u64;
    let lambda_vec = _mm256_set1_epi64x(lambda_val as i64);

    let mut i = 0;

    // process 4 elements at once: one 32-bit operand per 64-bit qword,
    // two qwords per 128-bit lane
    while i + 4 <= len {
        let w0 = w[i].poly().value() as u64;
        let w1 = w[i+1].poly().value() as u64;
        let w2 = w[i+2].poly().value() as u64;
        let w3 = w[i+3].poly().value() as u64;

        let w_vec = _mm256_set_epi64x(w3 as i64, w2 as i64, w1 as i64, w0 as i64);

        // VPCLMULQDQ performs one 64x64 carryless multiply per 128-bit lane, so
        // two calls are needed to cover all four qwords:
        // imm 0x00 multiplies qword 0 of each lane (w0, w2),
        // imm 0x11 multiplies qword 1 of each lane (w1, w3)
        let prod_even = _mm256_clmulepi64_epi128(lambda_vec, w_vec, 0x00);
        let prod_odd = _mm256_clmulepi64_epi128(lambda_vec, w_vec, 0x11);

        // each 32x32 product fits in the low qword of its 128-bit lane
        let p0 = _mm256_extract_epi64(prod_even, 0) as u64;
        let p1 = _mm256_extract_epi64(prod_odd, 0) as u64;
        let p2 = _mm256_extract_epi64(prod_even, 2) as u64;
        let p3 = _mm256_extract_epi64(prod_odd, 2) as u64;

        let lambda_w0 = reduce_gf32(p0, IRREDUCIBLE_32);
        let lambda_w1 = reduce_gf32(p1, IRREDUCIBLE_32);
        let lambda_w2 = reduce_gf32(p2, IRREDUCIBLE_32);
        let lambda_w3 = reduce_gf32(p3, IRREDUCIBLE_32);

        // u[i] = u[i] XOR lambda*w[i]
        let u0 = u[i].poly().value() ^ (lambda_w0 as u32);
        let u1 = u[i+1].poly().value() ^ (lambda_w1 as u32);
        let u2 = u[i+2].poly().value() ^ (lambda_w2 as u32);
        let u3 = u[i+3].poly().value() ^ (lambda_w3 as u32);

        // w[i] = w[i] XOR u[i] (using the updated u)
        let w0_new = w[i].poly().value() ^ u0;
        let w1_new = w[i+1].poly().value() ^ u1;
        let w2_new = w[i+2].poly().value() ^ u2;
        let w3_new = w[i+3].poly().value() ^ u3;

        u[i] = BinaryElem32::from(u0);
        u[i+1] = BinaryElem32::from(u1);
        u[i+2] = BinaryElem32::from(u2);
        u[i+3] = BinaryElem32::from(u3);
        w[i] = BinaryElem32::from(w0_new);
        w[i+1] = BinaryElem32::from(w1_new);
        w[i+2] = BinaryElem32::from(w2_new);
        w[i+3] = BinaryElem32::from(w3_new);

        i += 4;
    }

    // handle remaining elements with scalar field operations
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}

/// SSE vectorized FFT butterfly operation for GF(2^32)
/// computes: u[i] = u[i] + lambda*w[i]; w[i] = w[i] + u[i]
/// processes 2 elements at a time using PCLMULQDQ
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
pub fn fft_butterfly_gf32_sse(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    // irreducible polynomial for GF(2^32):
    // x^32 + x^15 + x^9 + x^7 + x^3 + 1
    const IRREDUCIBLE_32: u64 = (1u64 << 32) | (1 << 15) | (1 << 9) | (1 << 7) | (1 << 3) | 1;

    unsafe {
        let lambda_val = lambda.poly().value() as u64;
        let lambda_vec = _mm_set1_epi64x(lambda_val as i64);

        let mut i = 0;

        // process 2 elements at once (one 32-bit operand per 64-bit lane)
        while i + 2 <= len {
            // load w[i] and w[i+1] into the two 64-bit lanes
            let w0 = w[i].poly().value() as u64;
            let w1 = w[i+1].poly().value() as u64;
            let w_vec = _mm_set_epi64x(w1 as i64, w0 as i64);

            // carryless multiply: one 64x64 product per call,
            // selecting the lane via the immediate
            let prod_lo = _mm_clmulepi64_si128(lambda_vec, w_vec, 0x00); // lambda * w0
            let prod_hi = _mm_clmulepi64_si128(lambda_vec, w_vec, 0x11); // lambda * w1

            // reduce modulo the irreducible
            let p0 = _mm_extract_epi64(prod_lo, 0) as u64;
            let p1 = _mm_extract_epi64(prod_hi, 0) as u64;

            let lambda_w0 = reduce_gf32(p0, IRREDUCIBLE_32);
            let lambda_w1 = reduce_gf32(p1, IRREDUCIBLE_32);

            // u[i] = u[i] XOR lambda*w[i]
            let u0 = u[i].poly().value() ^ (lambda_w0 as u32);
            let u1 = u[i+1].poly().value() ^ (lambda_w1 as u32);

            // w[i] = w[i] XOR u[i] (using updated u)
            let w0_new = w[i].poly().value() ^ u0;
            let w1_new = w[i+1].poly().value() ^ u1;

            u[i] = BinaryElem32::from(u0);
            u[i+1] = BinaryElem32::from(u1);
            w[i] = BinaryElem32::from(w0_new);
            w[i+1] = BinaryElem32::from(w1_new);

            i += 2;
        }

        // handle a remaining odd element
        if i < len {
            let lambda_w = lambda.mul(&w[i]);
            u[i] = u[i].add(&lambda_w);
            w[i] = w[i].add(&u[i]);
        }
    }
}

/// reduce a 64-bit carryless product modulo the GF(2^32) irreducible
/// x^32 + x^15 + x^9 + x^7 + x^3 + 1 (hardcoded; the `_irr` argument is unused)
/// branchless two-step folding; only the low 32 bits of the return are meaningful
#[inline(always)]
fn reduce_gf32(p: u64, _irr: u64) -> u64 {
    // for a 32x32 -> 64 multiplication, reduce bits [63:32]
    let hi = p >> 32;
    let lo = p & 0xFFFFFFFF;

    // pre-fold the overflow of hi * (x^15 + x^9 + x^7 + x^3 + 1):
    // each set bit i contributes hi >> (32 - i)
    let tmp = hi
        ^ (hi >> 17)  // bit 15: shift by (32-15)
        ^ (hi >> 23)  // bit 9:  shift by (32-9)
        ^ (hi >> 25)  // bit 7:  shift by (32-7)
        ^ (hi >> 29); // bit 3:  shift by (32-3)

    // XOR the low bits with tmp times the low part of the irreducible
    lo ^ tmp ^ (tmp << 3) ^ (tmp << 7) ^ (tmp << 9) ^ (tmp << 15)
}

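// An added spot check (new test): since x^32 ≡ x^15 + x^9 + x^7 + x^3 + 1 (mod p),
// reducing a product whose only set bit is bit 32 must yield 0x8289. The second
// argument is ignored by the implementation, so 0 is passed.
#[cfg(test)]
mod reduce_gf32_tests {
    use super::*;

    #[test]
    fn x32_reduces_to_low_terms() {
        let product = 1u64 << 32;
        assert_eq!(reduce_gf32(product, 0) & 0xFFFF_FFFF, 0x8289);
    }
}
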
/// scalar fallback for FFT butterfly
pub fn fft_butterfly_gf32_scalar(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    assert_eq!(u.len(), w.len());

    for i in 0..u.len() {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
    }
}

/// dispatch FFT butterfly to the best available SIMD version
pub fn fft_butterfly_gf32(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
    {
        // try AVX-512 first (runtime detection), falling back to SSE
        return fft_butterfly_gf32_avx512(u, w, lambda);
    }

    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
    {
        fft_butterfly_gf32_scalar(u, w, lambda)
    }
}

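// An added regression check (new test): a length that is not a multiple of the
// SIMD width exercises both the vectorized loops and the scalar remainder handling.
#[cfg(test)]
mod butterfly_width_tests {
    use super::*;

    #[test]
    fn butterfly_matches_scalar_on_odd_lengths() {
        let lambda = BinaryElem32::from(0xABCD_EF01);
        let mut u: Vec<BinaryElem32> = (1u32..8).map(BinaryElem32::from).collect();
        let mut w: Vec<BinaryElem32> = (11u32..18).map(BinaryElem32::from).collect();
        let mut u_ref = u.clone();
        let mut w_ref = w.clone();

        fft_butterfly_gf32(&mut u, &mut w, lambda);
        fft_butterfly_gf32_scalar(&mut u_ref, &mut w_ref, lambda);

        assert_eq!(u, u_ref);
        assert_eq!(w, w_ref);
    }
}
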
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fft_butterfly_gf32() {
        // check that the SIMD and scalar butterflies give the same results
        let mut u_simd = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
        ];
        let mut w_simd = vec![
            BinaryElem32::from(5),
            BinaryElem32::from(6),
            BinaryElem32::from(7),
            BinaryElem32::from(8),
        ];
        let lambda = BinaryElem32::from(3);

        let mut u_scalar = u_simd.clone();
        let mut w_scalar = w_simd.clone();

        fft_butterfly_gf32(&mut u_simd, &mut w_simd, lambda);
        fft_butterfly_gf32_scalar(&mut u_scalar, &mut w_scalar, lambda);

        for i in 0..u_simd.len() {
            assert_eq!(u_simd[i], u_scalar[i], "u mismatch at index {}", i);
            assert_eq!(w_simd[i], w_scalar[i], "w mismatch at index {}", i);
        }
    }

    #[test]
    fn test_batch_add() {
        let a = vec![
            BinaryElem128::from(1),
            BinaryElem128::from(2),
            BinaryElem128::from(3),
        ];
        let b = vec![
            BinaryElem128::from(4),
            BinaryElem128::from(5),
            BinaryElem128::from(6),
        ];
        let mut out = vec![BinaryElem128::zero(); 3];

        batch_add_gf128(&a, &b, &mut out);

        for i in 0..3 {
            assert_eq!(out[i], a[i].add(&b[i]));
        }
    }

    #[test]
    fn test_batch_mul() {
        let a = vec![
            BinaryElem128::from(7),
            BinaryElem128::from(11),
            BinaryElem128::from(13),
        ];
        let b = vec![
            BinaryElem128::from(3),
            BinaryElem128::from(5),
            BinaryElem128::from(7),
        ];
        let mut out = vec![BinaryElem128::zero(); 3];

        batch_mul_gf128(&a, &b, &mut out);

        for i in 0..3 {
            assert_eq!(out[i], a[i].mul(&b[i]));
        }
    }

    #[test]
    fn test_batch_mul_large() {
        // test with large field elements
        let a = vec![
            BinaryElem128::from(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0),
            BinaryElem128::from(u128::MAX),
        ];
        let b = vec![
            BinaryElem128::from(0x123456789ABCDEF0123456789ABCDEF0),
            BinaryElem128::from(0x8000000000000000_0000000000000000),
        ];
        let mut out = vec![BinaryElem128::zero(); 2];

        batch_mul_gf128(&a, &b, &mut out);

        for i in 0..2 {
            assert_eq!(out[i], a[i].mul(&b[i]));
        }
    }
}