// ligerito_binary_fields — src/simd.rs

1// src/simd.rs
2use crate::poly::{BinaryPoly64, BinaryPoly128, BinaryPoly256};
3use crate::elem::BinaryElem32;
4
5// 64x64 -> 128 bit carryless multiplication
6pub fn carryless_mul_64(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
7    // x86_64 with PCLMULQDQ
8    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
9    {
10        use core::arch::x86_64::*;
11
12        unsafe {
13            let a_vec = _mm_set_epi64x(0, a.value() as i64);
14            let b_vec = _mm_set_epi64x(0, b.value() as i64);
15
16            let result = _mm_clmulepi64_si128(a_vec, b_vec, 0x00);
17
18            let lo = _mm_extract_epi64(result, 0) as u64;
19            let hi = _mm_extract_epi64(result, 1) as u64;
20
21            return BinaryPoly128::new(((hi as u128) << 64) | (lo as u128));
22        }
23    }
24
25    // WASM with SIMD128
26    #[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
27    {
28        return carryless_mul_64_wasm_simd(a, b);
29    }
30
31    // Software fallback for other platforms
32    #[cfg(not(any(
33        all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"),
34        all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128")
35    )))]
36    {
37        // software fallback
38        carryless_mul_64_soft(a, b)
39    }
40}
41
42// software implementation for 64x64 using lookup tables
43#[allow(dead_code)]
44fn carryless_mul_64_soft(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
45    let a_val = a.value();
46    let b_val = b.value();
47
48    // Split into 32-bit halves for Karatsuba
49    let a_lo = (a_val & 0xFFFFFFFF) as u32;
50    let a_hi = (a_val >> 32) as u32;
51    let b_lo = (b_val & 0xFFFFFFFF) as u32;
52    let b_hi = (b_val >> 32) as u32;
53
54    // Karatsuba multiplication with lookup tables
55    let z0 = mul_32x32_to_64_lut(a_lo, b_lo);
56    let z2 = mul_32x32_to_64_lut(a_hi, b_hi);
57    let z1 = mul_32x32_to_64_lut(a_lo ^ a_hi, b_lo ^ b_hi);
58
59    // Karatsuba combination
60    let middle = z1 ^ z0 ^ z2;
61    let result_lo = z0 ^ (middle << 32);
62    let result_hi = (middle >> 32) ^ z2;
63
64    BinaryPoly128::new(((result_hi as u128) << 64) | (result_lo as u128))
65}
66
// WASM SIMD128 path: same Karatsuba structure as the software fallback, but
// the 32x32 kernel runs on i8x16_swizzle table lookups.
#[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
fn carryless_mul_64_wasm_simd(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
    let (x, y) = (a.value(), b.value());

    // 32-bit halves of each operand.
    let (x_lo, x_hi) = ((x & 0xFFFF_FFFF) as u32, (x >> 32) as u32);
    let (y_lo, y_hi) = ((y & 0xFFFF_FFFF) as u32, (y >> 32) as u32);

    // SAFETY: mul_32x32_to_64_simd only requires simd128, guaranteed by the cfg.
    unsafe {
        // Karatsuba: low, high, and cross products.
        let low = mul_32x32_to_64_simd(x_lo, y_lo);
        let high = mul_32x32_to_64_simd(x_hi, y_hi);
        let cross = mul_32x32_to_64_simd(x_lo ^ x_hi, y_lo ^ y_hi) ^ low ^ high;

        // Recombine: result = low ^ (cross << 32) ^ (high << 64).
        let out_lo = low ^ (cross << 32);
        let out_hi = high ^ (cross >> 32);

        BinaryPoly128::new(((out_hi as u128) << 64) | (out_lo as u128))
    }
}
95
// 4-bit x 4-bit carryless multiplication lookup table.
// table[a][b] holds the GF(2) polynomial product a*b (at most 7 bits; no
// reduction). Built at compile time; the mask-and-XOR form keeps the
// construction branchless.
static CLMUL_4X4: [[u8; 16]; 16] = {
    let mut table = [[0u8; 16]; 16];
    let mut a = 0usize;
    while a < 16 {
        let mut b = 0usize;
        while b < 16 {
            let mut product = 0u8;
            let mut bit = 0;
            while bit < 4 {
                // All-ones mask when bit `bit` of b is set, zero otherwise.
                let mask = 0u8.wrapping_sub(((b >> bit) & 1) as u8);
                product ^= ((a as u8) << bit) & mask;
                bit += 1;
            }
            table[a][b] = product;
            b += 1;
        }
        a += 1;
    }
    table
};
121
// 32x32 -> 64 carryless multiplication using WASM SIMD128 i8x16_swizzle.
//
// Strategy: split both operands into 4-bit nibbles. For each nibble of `a`,
// load the corresponding 16-entry CLMUL_4X4 row into a v128 and use swizzle
// to look up the products against all 8 nibbles of `b` in one instruction —
// one 16-way parallel table lookup per row instead of 8 scalar lookups.
//
// SAFETY (caller): requires the simd128 target feature, guaranteed here by
// the cfg attribute.
#[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
unsafe fn mul_32x32_to_64_simd(a: u32, b: u32) -> u64 {
    use core::arch::wasm32::*;

    // Extract 4-bit nibbles of `a` (nibble i covers bits 4i..4i+3).
    let a_nibbles: [usize; 8] = [
        (a & 0xF) as usize,
        ((a >> 4) & 0xF) as usize,
        ((a >> 8) & 0xF) as usize,
        ((a >> 12) & 0xF) as usize,
        ((a >> 16) & 0xF) as usize,
        ((a >> 20) & 0xF) as usize,
        ((a >> 24) & 0xF) as usize,
        ((a >> 28) & 0xF) as usize,
    ];

    // Extract the 8 nibbles of `b`; they become swizzle lane indices.
    let b0 = (b & 0xF) as u8;
    let b1 = ((b >> 4) & 0xF) as u8;
    let b2 = ((b >> 8) & 0xF) as u8;
    let b3 = ((b >> 12) & 0xF) as u8;
    let b4 = ((b >> 16) & 0xF) as u8;
    let b5 = ((b >> 20) & 0xF) as u8;
    let b6 = ((b >> 24) & 0xF) as u8;
    let b7 = ((b >> 28) & 0xF) as u8;

    // Index vector: b nibbles in the first 8 lanes, zeros in the upper 8
    // (those lanes are never read back).
    let b_indices = u8x16(b0, b1, b2, b3, b4, b5, b6, b7, 0, 0, 0, 0, 0, 0, 0, 0);

    let mut result = 0u64;

    // For each a_nibble, load its CLMUL_4X4 row and swizzle with b_indices.
    // This computes all 8 products a_nibble[i] * b_nibble[j] in parallel.
    for i in 0..8 {
        // Load the 16-byte lookup-table row for this a_nibble.
        // (wasm v128.load has no alignment requirement.)
        let table_row = v128_load(CLMUL_4X4[a_nibbles[i]].as_ptr() as *const v128);

        // Swizzle: products[j] = CLMUL_4X4[a_nibbles[i]][b_nibbles[j]].
        let products = i8x16_swizzle(table_row, b_indices);

        // Extract the 8 products and accumulate with the proper shifts.
        // The product at lane j contributes at bit position (i + j) * 4.
        let p0 = u8x16_extract_lane::<0>(products) as u64;
        let p1 = u8x16_extract_lane::<1>(products) as u64;
        let p2 = u8x16_extract_lane::<2>(products) as u64;
        let p3 = u8x16_extract_lane::<3>(products) as u64;
        let p4 = u8x16_extract_lane::<4>(products) as u64;
        let p5 = u8x16_extract_lane::<5>(products) as u64;
        let p6 = u8x16_extract_lane::<6>(products) as u64;
        let p7 = u8x16_extract_lane::<7>(products) as u64;

        let base_shift = i * 4;
        result ^= p0 << base_shift;
        result ^= p1 << (base_shift + 4);
        result ^= p2 << (base_shift + 8);
        result ^= p3 << (base_shift + 12);
        result ^= p4 << (base_shift + 16);
        result ^= p5 << (base_shift + 20);
        result ^= p6 << (base_shift + 24);
        result ^= p7 << (base_shift + 28);
    }

    result
}
189
190// Lookup table based 32x32 carryless multiply (also used as software fallback)
191#[inline(always)]
192fn mul_32x32_to_64_lut(a: u32, b: u32) -> u64 {
193    // Split a and b into 4-bit nibbles
194    let a_nibbles: [usize; 8] = [
195        (a & 0xF) as usize,
196        ((a >> 4) & 0xF) as usize,
197        ((a >> 8) & 0xF) as usize,
198        ((a >> 12) & 0xF) as usize,
199        ((a >> 16) & 0xF) as usize,
200        ((a >> 20) & 0xF) as usize,
201        ((a >> 24) & 0xF) as usize,
202        ((a >> 28) & 0xF) as usize,
203    ];
204
205    let b_nibbles: [usize; 8] = [
206        (b & 0xF) as usize,
207        ((b >> 4) & 0xF) as usize,
208        ((b >> 8) & 0xF) as usize,
209        ((b >> 12) & 0xF) as usize,
210        ((b >> 16) & 0xF) as usize,
211        ((b >> 20) & 0xF) as usize,
212        ((b >> 24) & 0xF) as usize,
213        ((b >> 28) & 0xF) as usize,
214    ];
215
216    // Schoolbook multiplication with 4-bit chunks
217    // Each a_nibble[i] * b_nibble[j] contributes at bit position (i+j)*4
218    let mut result = 0u64;
219
220    // Unrolled for performance
221    for i in 0..8 {
222        for j in 0..8 {
223            let prod = CLMUL_4X4[a_nibbles[i]][b_nibbles[j]] as u64;
224            result ^= prod << ((i + j) * 4);
225        }
226    }
227
228    result
229}
230
231// 128x128 -> 128 bit carryless multiplication (truncated)
232pub fn carryless_mul_128(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly128 {
233    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
234    {
235        use core::arch::x86_64::*;
236
237        unsafe {
238            // split inputs into 64-bit halves
239            let a_lo = a.value() as u64;
240            let a_hi = (a.value() >> 64) as u64;
241            let b_lo = b.value() as u64;
242            let b_hi = (b.value() >> 64) as u64;
243
244            // perform 3 64x64->128 bit multiplications (skip hi*hi for truncated result)
245            let lo_lo = _mm_clmulepi64_si128(
246                _mm_set_epi64x(0, a_lo as i64),
247                _mm_set_epi64x(0, b_lo as i64),
248                0x00
249            );
250
251            let lo_hi = _mm_clmulepi64_si128(
252                _mm_set_epi64x(0, a_lo as i64),
253                _mm_set_epi64x(0, b_hi as i64),
254                0x00
255            );
256
257            let hi_lo = _mm_clmulepi64_si128(
258                _mm_set_epi64x(0, a_hi as i64),
259                _mm_set_epi64x(0, b_lo as i64),
260                0x00
261            );
262
263            // extract 128-bit results - fix the overflow by casting to u128 first
264            let r0 = (_mm_extract_epi64(lo_lo, 0) as u64) as u128
265                   | ((_mm_extract_epi64(lo_lo, 1) as u64) as u128) << 64;
266            let r1 = (_mm_extract_epi64(lo_hi, 0) as u64) as u128
267                   | ((_mm_extract_epi64(lo_hi, 1) as u64) as u128) << 64;
268            let r2 = (_mm_extract_epi64(hi_lo, 0) as u64) as u128
269                   | ((_mm_extract_epi64(hi_lo, 1) as u64) as u128) << 64;
270
271            // combine: result = r0 + (r1 << 64) + (r2 << 64)
272            let result = r0 ^ (r1 << 64) ^ (r2 << 64);
273
274            return BinaryPoly128::new(result);
275        }
276    }
277
278    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
279    {
280        // software fallback
281        carryless_mul_128_soft(a, b)
282    }
283}
284
285// software implementation for 128x128 truncated
286#[allow(dead_code)]
287fn carryless_mul_128_soft(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly128 {
288    let a_lo = a.value() as u64;
289    let a_hi = (a.value() >> 64) as u64;
290    let b_lo = b.value() as u64;
291    let b_hi = (b.value() >> 64) as u64;
292
293    let z0 = mul_64x64_to_128(a_lo, b_lo);
294    let z1 = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi);
295    let z2 = mul_64x64_to_128(a_hi, b_hi);
296
297    // karatsuba combination (truncated)
298    let result = z0 ^ (z1 << 64) ^ (z0 << 64) ^ (z2 << 64);
299    BinaryPoly128::new(result)
300}
301
302// 128x128 -> 256 bit full multiplication
303pub fn carryless_mul_128_full(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly256 {
304    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
305    {
306        use core::arch::x86_64::*;
307
308        unsafe {
309            let a_lo = a.value() as u64;
310            let a_hi = (a.value() >> 64) as u64;
311            let b_lo = b.value() as u64;
312            let b_hi = (b.value() >> 64) as u64;
313
314            // 4 multiplications
315            let lo_lo = _mm_clmulepi64_si128(
316                _mm_set_epi64x(0, a_lo as i64),
317                _mm_set_epi64x(0, b_lo as i64),
318                0x00
319            );
320
321            let lo_hi = _mm_clmulepi64_si128(
322                _mm_set_epi64x(0, a_lo as i64),
323                _mm_set_epi64x(0, b_hi as i64),
324                0x00
325            );
326
327            let hi_lo = _mm_clmulepi64_si128(
328                _mm_set_epi64x(0, a_hi as i64),
329                _mm_set_epi64x(0, b_lo as i64),
330                0x00
331            );
332
333            let hi_hi = _mm_clmulepi64_si128(
334                _mm_set_epi64x(0, a_hi as i64),
335                _mm_set_epi64x(0, b_hi as i64),
336                0x00
337            );
338
339            // extract and combine
340            let r0_lo = _mm_extract_epi64(lo_lo, 0) as u64;
341            let r0_hi = _mm_extract_epi64(lo_lo, 1) as u64;
342            let r1_lo = _mm_extract_epi64(lo_hi, 0) as u64;
343            let r1_hi = _mm_extract_epi64(lo_hi, 1) as u64;
344            let r2_lo = _mm_extract_epi64(hi_lo, 0) as u64;
345            let r2_hi = _mm_extract_epi64(hi_lo, 1) as u64;
346            let r3_lo = _mm_extract_epi64(hi_hi, 0) as u64;
347            let r3_hi = _mm_extract_epi64(hi_hi, 1) as u64;
348
349            // build 256-bit result
350            let mut lo = r0_lo as u128 | ((r0_hi as u128) << 64);
351            let mut hi = 0u128;
352
353            // add r1 << 64
354            lo ^= (r1_lo as u128) << 64;
355            hi ^= (r1_lo as u128) >> 64;
356            hi ^= r1_hi as u128;
357
358            // add r2 << 64
359            lo ^= (r2_lo as u128) << 64;
360            hi ^= (r2_lo as u128) >> 64;
361            hi ^= r2_hi as u128;
362
363            // add r3 << 128
364            hi ^= r3_lo as u128 | ((r3_hi as u128) << 64);
365
366            return BinaryPoly256::from_parts(hi, lo);
367        }
368    }
369
370    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
371    {
372        // software fallback
373        carryless_mul_128_full_soft(a, b)
374    }
375}
376
377// software implementation for 128x128 full
378#[allow(dead_code)]
379fn carryless_mul_128_full_soft(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly256 {
380    let a_lo = a.value() as u64;
381    let a_hi = (a.value() >> 64) as u64;
382    let b_lo = b.value() as u64;
383    let b_hi = (b.value() >> 64) as u64;
384
385    let z0 = mul_64x64_to_128(a_lo, b_lo);
386    let z2 = mul_64x64_to_128(a_hi, b_hi);
387    let z1 = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi) ^ z0 ^ z2;
388
389    // combine: result = z0 + (z1 << 64) + (z2 << 128)
390    let mut lo = z0;
391    let mut hi = 0u128;
392
393    // add z1 << 64
394    lo ^= z1 << 64;
395    hi ^= z1 >> 64;
396
397    // add z2 << 128
398    hi ^= z2;
399
400    BinaryPoly256::from_parts(hi, lo)
401}
402
// Constant-time 64x64 -> 128 carryless multiply: branchless shift-and-XOR.
// (Scans the bits of `a`, accumulating shifted copies of `b`; carryless
// multiplication is commutative, so the operand roles are interchangeable.)
#[inline(always)]
#[allow(dead_code)]
fn mul_64x64_to_128(a: u64, b: u64) -> u128 {
    let widened = b as u128;
    let mut acc = 0u128;

    for bit in 0..64 {
        // All-ones mask when bit `bit` of `a` is set — avoids a data-dependent branch.
        let mask = 0u128.wrapping_sub(((a >> bit) & 1) as u128);
        acc ^= (widened << bit) & mask;
    }

    acc
}
418
419// batch field operations
420
421use crate::{BinaryElem128, BinaryFieldElement};
422
423/// batch multiply gf(2^128) elements with two-tier dispatch:
424/// hardware-accel → pclmulqdq, else → scalar
425pub fn batch_mul_gf128(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
426    assert_eq!(a.len(), b.len());
427    assert_eq!(a.len(), out.len());
428
429    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
430    {
431        return batch_mul_gf128_hw(a, b, out);
432    }
433
434    #[cfg(not(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq")))]
435    {
436        // scalar fallback
437        for i in 0..a.len() {
438            out[i] = a[i].mul(&b[i]);
439        }
440    }
441}
442
443/// batch add gf(2^128) elements (xor in gf(2^n))
444pub fn batch_add_gf128(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
445    assert_eq!(a.len(), b.len());
446    assert_eq!(a.len(), out.len());
447
448    // scalar fallback (XOR is already very fast)
449    for i in 0..a.len() {
450        out[i] = a[i].add(&b[i]);
451    }
452}
453
// PCLMULQDQ-backed batch multiply for x86_64: widening 128x128->256
// carryless multiply followed by reduction into GF(2^128).
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
fn batch_mul_gf128_hw(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
    for ((dst, x), y) in out.iter_mut().zip(a).zip(b) {
        let wide = carryless_mul_128_full(x.poly(), y.poly());
        *dst = BinaryElem128::from_value(reduce_gf128(wide).value());
    }
}
465
/// Reduce a 256-bit carryless product modulo the GF(2^128) irreducible
/// polynomial x^128 + x^7 + x^2 + x + 1 (low bits 0x87 = 0b1000_0111).
/// Matches Julia's @generated mod_irreducible (binaryfield.jl:73-114).
#[inline(always)]
pub fn reduce_gf128(product: BinaryPoly256) -> BinaryPoly128 {
    let (hi, lo) = product.split();
    let high = hi.value();
    let low = lo.value();

    // Julia's compute_tmp for irreducible 0b10000111 (bits 0, 1, 2, 7):
    // fold the high half down, pre-compensating for the bits a single
    // reduction pass would push back above bit 127 — for each set bit i,
    // XOR in high >> (128 - i). Shifts: 127, 126, 121 (the 128 term is the
    // identity `high` itself).
    let tmp = high ^ (high >> 127) ^ (high >> 126) ^ (high >> 121);

    // Julia's compute_res: multiply the folded value by x^7 + x^2 + x + 1
    // (shifts 0, 1, 2, 7) and add the low half.
    let res = low ^ tmp ^ (tmp << 1) ^ (tmp << 2) ^ (tmp << 7);

    BinaryPoly128::new(res)
}
487
488// =========================================================================
489// BinaryElem32 batch operations - FFT optimization
490// =========================================================================
491
/// Vectorized FFT butterfly for GF(2^32) with tiered runtime dispatch:
/// AVX-512 (8 elements) → AVX2 (4 elements) → SSE (2 elements).
///
/// Computes, for each i: u[i] += lambda * w[i]; then w[i] += u[i].
///
/// NOTE: the inner `#[cfg(target_arch = "x86_64")]` / `#[cfg(not(...))]`
/// branches of the previous version were redundant — this function is already
/// compiled only for x86_64 via the outer cfg — so the unreachable scalar
/// branch has been removed.
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
pub fn fft_butterfly_gf32_avx512(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    // Tier 1: AVX-512 with 512-bit VPCLMULQDQ (8 elements at once).
    if is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx512f") {
        // SAFETY: feature availability checked at runtime just above.
        unsafe { fft_butterfly_gf32_avx512_impl(u, w, lambda) };
        return;
    }

    // Tier 2: AVX2 with 256-bit VPCLMULQDQ (4 elements at once).
    if is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx2") {
        // SAFETY: feature availability checked at runtime just above.
        unsafe { fft_butterfly_gf32_avx2_impl(u, w, lambda) };
        return;
    }

    // Tier 3: SSE with 128-bit PCLMULQDQ (2 elements at once) — always
    // available here because of the compile-time pclmulqdq requirement.
    fft_butterfly_gf32_sse(u, w, lambda)
}
521
/// AVX-512 implementation using 512-bit VPCLMULQDQ.
/// Processes 8 elements per outer iteration: four 128-bit lanes, each packing
/// two 32-bit w elements. Requires Rust 1.89+ for `_mm512_extracti64x4_epi64`.
///
/// # Safety
/// The caller must guarantee the CPU supports `avx512f` and `vpclmulqdq`.
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
#[target_feature(enable = "avx512f,vpclmulqdq")]
unsafe fn fft_butterfly_gf32_avx512_impl(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    let lambda_val = lambda.poly().value() as u64;
    // Broadcast lambda into every 64-bit lane of the 512-bit vector.
    let lambda_512 = _mm512_set1_epi64(lambda_val as i64);

    let mut i = 0;

    // Process 8 elements at once. VPCLMULQDQ on 512 bits performs 4 carryless
    // multiplies per instruction (one per 128-bit lane); lane k packs
    // [w[i+2k] (low qword), w[i+2k+1] (high qword)].
    while i + 8 <= len {
        let w_512 = _mm512_set_epi64(
            w[i+7].poly().value() as i64, w[i+6].poly().value() as i64,
            w[i+5].poly().value() as i64, w[i+4].poly().value() as i64,
            w[i+3].poly().value() as i64, w[i+2].poly().value() as i64,
            w[i+1].poly().value() as i64, w[i].poly().value() as i64,
        );

        // PCLMULQDQ imm8 semantics (applied per 128-bit lane): bit 0 selects
        // the qword of the FIRST operand, bit 4 the qword of the SECOND.
        // 0x00: lambda (low) * w even (low) -> lambda*w[0], w[2], w[4], w[6]
        let prod_even = _mm512_clmulepi64_epi128(lambda_512, w_512, 0x00);

        // 0x10: lambda (low) * w odd (high) -> lambda*w[1], w[3], w[5], w[7]
        // BUGFIX: this previously used 0x01, which selects the high qword of
        // lambda_512 (still lambda, due to set1) and the LOW qword of w_512 —
        // recomputing the even products and corrupting every odd-indexed
        // element.
        let prod_odd = _mm512_clmulepi64_epi128(lambda_512, w_512, 0x10);

        // Extract 256-bit halves using _mm512_extracti64x4_epi64 (Rust 1.89+).
        let prod_even_lo: __m256i = _mm512_extracti64x4_epi64::<0>(prod_even);
        let prod_even_hi: __m256i = _mm512_extracti64x4_epi64::<1>(prod_even);
        let prod_odd_lo: __m256i = _mm512_extracti64x4_epi64::<0>(prod_odd);
        let prod_odd_hi: __m256i = _mm512_extracti64x4_epi64::<1>(prod_odd);

        // Each 63-bit product sits in the low qword of its 128-bit lane.
        let p0 = _mm256_extract_epi64::<0>(prod_even_lo) as u64; // lambda * w[0]
        let p2 = _mm256_extract_epi64::<2>(prod_even_lo) as u64; // lambda * w[2]
        let p4 = _mm256_extract_epi64::<0>(prod_even_hi) as u64; // lambda * w[4]
        let p6 = _mm256_extract_epi64::<2>(prod_even_hi) as u64; // lambda * w[6]

        let p1 = _mm256_extract_epi64::<0>(prod_odd_lo) as u64;  // lambda * w[1]
        let p3 = _mm256_extract_epi64::<2>(prod_odd_lo) as u64;  // lambda * w[3]
        let p5 = _mm256_extract_epi64::<0>(prod_odd_hi) as u64;  // lambda * w[5]
        let p7 = _mm256_extract_epi64::<2>(prod_odd_hi) as u64;  // lambda * w[7]

        // Reduce all 8 products into GF(2^32); the `as u32` truncation
        // discards reduce_gf32_inline's meaningless high bits.
        let lw = [
            reduce_gf32_inline(p0) as u32, reduce_gf32_inline(p1) as u32,
            reduce_gf32_inline(p2) as u32, reduce_gf32_inline(p3) as u32,
            reduce_gf32_inline(p4) as u32, reduce_gf32_inline(p5) as u32,
            reduce_gf32_inline(p6) as u32, reduce_gf32_inline(p7) as u32,
        ];

        // Butterfly: u[i] ^= lambda*w[i]; then w[i] ^= (updated) u[i].
        for j in 0..8 {
            let u_val = u[i+j].poly().value() ^ lw[j];
            let w_val = w[i+j].poly().value() ^ u_val;
            u[i+j] = BinaryElem32::from(u_val);
            w[i+j] = BinaryElem32::from(w_val);
        }

        i += 8;
    }

    // Scalar tail for the remaining < 8 elements.
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}
604
/// Branchless reduction of a 64-bit carryless product into GF(2^32).
/// Irreducible: x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1.
/// Only the low 32 bits of the return value are meaningful; callers truncate.
#[inline(always)]
fn reduce_gf32_inline(p: u64) -> u64 {
    let hi = p >> 32;
    let lo = p & 0xFFFF_FFFF;

    // Pre-fold the overflow so one multiply-by-remainder pass suffices:
    // shifts are 32 minus each nonzero set bit {15, 9, 7, 4, 3}.
    let folded = [17, 23, 25, 28, 29]
        .iter()
        .fold(hi, |acc, &s| acc ^ (hi >> s));

    // Multiply the folded value by x^15+x^9+x^7+x^4+x^3+1 and add the low half.
    [0, 3, 4, 7, 9, 15]
        .iter()
        .fold(lo, |acc, &s| acc ^ (folded << s))
}
616
/// AVX2 implementation using 256-bit VPCLMULQDQ.
/// Processes 4 elements per outer iteration: two 128-bit lanes, each packing
/// two 32-bit w elements. For CPUs with AVX2 + VPCLMULQDQ but no AVX-512.
///
/// # Safety
/// The caller must guarantee the CPU supports `avx2` and `vpclmulqdq`.
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
#[target_feature(enable = "avx2,vpclmulqdq")]
unsafe fn fft_butterfly_gf32_avx2_impl(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    let lambda_val = lambda.poly().value() as u64;
    // Broadcast lambda into every 64-bit lane.
    let lambda_256 = _mm256_set1_epi64x(lambda_val as i64);

    let mut i = 0;

    // Process 4 elements at once. VPCLMULQDQ on 256 bits performs 2 carryless
    // multiplies per instruction (one per 128-bit lane).
    while i + 4 <= len {
        // Lane 0 packs [w[i], w[i+1]], lane 1 packs [w[i+2], w[i+3]].
        let w_256 = _mm256_set_epi64x(
            w[i+3].poly().value() as i64, w[i+2].poly().value() as i64,
            w[i+1].poly().value() as i64, w[i].poly().value() as i64,
        );

        // PCLMULQDQ imm8 semantics (applied per 128-bit lane): bit 0 selects
        // the qword of the FIRST operand, bit 4 the qword of the SECOND.
        // 0x00: lambda (low) * w even (low) -> lambda*w[0], lambda*w[2]
        let prod_even = _mm256_clmulepi64_epi128(lambda_256, w_256, 0x00);

        // 0x10: lambda (low) * w odd (high) -> lambda*w[1], lambda*w[3]
        // BUGFIX: this previously used 0x01, which selects the high qword of
        // lambda_256 (still lambda, due to set1) and the LOW qword of w_256 —
        // recomputing the even products and corrupting every odd-indexed
        // element.
        let prod_odd = _mm256_clmulepi64_epi128(lambda_256, w_256, 0x10);

        // Each product sits in the low qword of its 128-bit lane.
        let p0 = _mm256_extract_epi64::<0>(prod_even) as u64;
        let p2 = _mm256_extract_epi64::<2>(prod_even) as u64;
        let p1 = _mm256_extract_epi64::<0>(prod_odd) as u64;
        let p3 = _mm256_extract_epi64::<2>(prod_odd) as u64;

        // Reduce all 4 products into GF(2^32); truncation discards the
        // meaningless high bits of reduce_gf32_inline's result.
        let lw = [
            reduce_gf32_inline(p0) as u32,
            reduce_gf32_inline(p1) as u32,
            reduce_gf32_inline(p2) as u32,
            reduce_gf32_inline(p3) as u32,
        ];

        // Butterfly: u[i] ^= lambda*w[i]; then w[i] ^= (updated) u[i].
        for j in 0..4 {
            let u_val = u[i+j].poly().value() ^ lw[j];
            let w_val = w[i+j].poly().value() ^ u_val;
            u[i+j] = BinaryElem32::from(u_val);
            w[i+j] = BinaryElem32::from(w_val);
        }

        i += 4;
    }

    // Scalar tail for the remaining < 4 elements.
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}
684
/// Force the AVX-512 path (benchmarking helper).
/// Panics when AVX-512F + VPCLMULQDQ are not available on this CPU.
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
pub fn fft_butterfly_gf32_avx512_only(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    let supported = is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx512f");
    assert!(supported, "AVX-512 + VPCLMULQDQ required");
    // SAFETY: required CPU features verified by the assert above.
    unsafe { fft_butterfly_gf32_avx512_impl(u, w, lambda) };
}
693
/// Force the AVX2 path (benchmarking helper).
/// Panics when AVX2 + VPCLMULQDQ are not available on this CPU.
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
pub fn fft_butterfly_gf32_avx2_only(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    let supported = is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx2");
    assert!(supported, "AVX2 + VPCLMULQDQ required");
    // SAFETY: required CPU features verified by the assert above.
    unsafe { fft_butterfly_gf32_avx2_impl(u, w, lambda) };
}
702
/// SSE vectorized FFT butterfly for GF(2^32), 2 elements per iteration.
/// Computes, for each i: u[i] = u[i] + lambda*w[i]; w[i] = w[i] + u[i]
/// (addition in GF(2^n) is XOR).
#[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
pub fn fft_butterfly_gf32_sse(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    // Irreducible polynomial for GF(2^32):
    // x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
    // (0b11001 contributes bits 0, 3 AND 4 — the old comment omitted x^4.)
    const IRREDUCIBLE_32: u64 = (1u64 << 32) | 0b11001 | (1 << 7) | (1 << 9) | (1 << 15);

    // SAFETY: the cfg above guarantees pclmulqdq at compile time.
    unsafe {
        let lambda_val = lambda.poly().value() as u64;
        let lambda_vec = _mm_set1_epi64x(lambda_val as i64);

        let mut i = 0;

        // Two 32-bit elements per iteration: one element per 64-bit lane.
        while i + 2 <= len {
            // load w[i] and w[i+1] into the two 64-bit lanes
            let w0 = w[i].poly().value() as u64;
            let w1 = w[i+1].poly().value() as u64;
            let w_vec = _mm_set_epi64x(w1 as i64, w0 as i64);

            // Carryless multiply lambda against each lane. Selector 0x00 takes
            // the low qwords; 0x11 takes the high qwords (lambda_vec holds
            // lambda in both, so either selection of it yields lambda).
            let prod_lo = _mm_clmulepi64_si128(lambda_vec, w_vec, 0x00); // lambda * w0
            let prod_hi = _mm_clmulepi64_si128(lambda_vec, w_vec, 0x11); // lambda * w1

            // Reduce each 64-bit product modulo the irreducible. Only the low
            // 32 bits of reduce_gf32's result are meaningful — the `as u32`
            // casts below perform the intended truncation.
            let p0 = _mm_extract_epi64(prod_lo, 0) as u64;
            let p1 = _mm_extract_epi64(prod_hi, 0) as u64;

            let lambda_w0 = reduce_gf32(p0, IRREDUCIBLE_32);
            let lambda_w1 = reduce_gf32(p1, IRREDUCIBLE_32);

            // u[i] = u[i] XOR lambda*w[i]
            let u0 = u[i].poly().value() ^ (lambda_w0 as u32);
            let u1 = u[i+1].poly().value() ^ (lambda_w1 as u32);

            // w[i] = w[i] XOR u[i] (uses the freshly updated u)
            let w0_new = w[i].poly().value() ^ u0;
            let w1_new = w[i+1].poly().value() ^ u1;

            u[i] = BinaryElem32::from(u0);
            u[i+1] = BinaryElem32::from(u1);
            w[i] = BinaryElem32::from(w0_new);
            w[i+1] = BinaryElem32::from(w1_new);

            i += 2;
        }

        // At most one leftover element; handle it with the scalar field ops.
        if i < len {
            let lambda_w = lambda.mul(&w[i]);
            u[i] = u[i].add(&lambda_w);
            w[i] = w[i].add(&u[i]);
        }
    }
}
765
/// Reduce a 64-bit carryless product modulo the GF(2^32) irreducible
/// x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1 (= IRREDUCIBLE_32 = 0x1_0000_8299).
/// Branchless. Only the low 32 bits of the return value are meaningful;
/// callers truncate with `as u32`. The `_irr` parameter is unused and kept
/// only for interface compatibility.
#[inline(always)]
fn reduce_gf32(p: u64, _irr: u64) -> u64 {
    let hi = p >> 32;
    let lo = p & 0xFFFFFFFF;

    // Pre-fold the bits that a single reduction pass would push back above
    // bit 31. Shifts are 32 minus each nonzero set bit of the irreducible:
    // bits {15, 9, 7, 4, 3} -> shifts {17, 23, 25, 28, 29}.
    // BUGFIX: the x^4 term (hi >> 28) was missing, making this function
    // disagree both with reduce_gf32_inline (used by the AVX paths) and with
    // IRREDUCIBLE_32 itself (0b11001 sets bit 4).
    let tmp = hi
        ^ (hi >> 17)  // bit 15: shift by (32 - 15)
        ^ (hi >> 23)  // bit 9:  shift by (32 - 9)
        ^ (hi >> 25)  // bit 7:  shift by (32 - 7)
        ^ (hi >> 28)  // bit 4:  shift by (32 - 4)  [BUGFIX: was absent]
        ^ (hi >> 29); // bit 3:  shift by (32 - 3)

    // Multiply the folded value by x^15 + x^9 + x^7 + x^4 + x^3 + 1 and add
    // the low half. BUGFIX: the (tmp << 4) term for x^4 was also missing.
    lo ^ tmp ^ (tmp << 3) ^ (tmp << 4) ^ (tmp << 7) ^ (tmp << 9) ^ (tmp << 15)
}
788
789/// scalar fallback for FFT butterfly
790pub fn fft_butterfly_gf32_scalar(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
791    assert_eq!(u.len(), w.len());
792
793    for i in 0..u.len() {
794        let lambda_w = lambda.mul(&w[i]);
795        u[i] = u[i].add(&lambda_w);
796        w[i] = w[i].add(&u[i]);
797    }
798}
799
/// Dispatch the FFT butterfly to the best SIMD implementation compiled in:
/// x86_64 → runtime-tiered AVX-512 / AVX2 / SSE; wasm32 → SIMD128; otherwise
/// the scalar loop. Exactly one of the cfg branches exists per build.
pub fn fft_butterfly_gf32(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    #[cfg(all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"))]
    {
        // Runtime feature detection selects AVX-512 / AVX2 / SSE inside.
        return fft_butterfly_gf32_avx512(u, w, lambda);
    }

    #[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
    {
        return fft_butterfly_gf32_wasm_simd(u, w, lambda);
    }

    #[cfg(not(any(
        all(feature = "hardware-accel", target_arch = "x86_64", target_feature = "pclmulqdq"),
        all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128")
    )))]
    {
        fft_butterfly_gf32_scalar(u, w, lambda)
    }
}
821
/// WASM SIMD128 optimized FFT butterfly
/// Uses v128_xor for additions and swizzle-based table lookups for multiplication
///
/// Same contract as `fft_butterfly_gf32_scalar`: per lane,
/// u[i] <- u[i] ^ lambda*w[i], then w[i] <- w[i] ^ u[i].
/// Processes 4 lanes per iteration, with a scalar tail for len % 4.
#[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
pub fn fft_butterfly_gf32_wasm_simd(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
    use core::arch::wasm32::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    // GF(2^32) irreducible: 0x1_0000_8299
    //   = x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
    // (the previous comment omitted the x^4 term, which bit 4 of the
    // constant clearly encodes).
    // NOTE(review): the scalar `reduce_gf32` in this file reduces by
    // 0x1_0000_8289 (no x^4 term) — the two paths use different
    // polynomials, so SIMD and scalar results can diverge; confirm which
    // one matches BinaryElem32's multiplication.
    const IRR: u64 = 0x100008299;

    // Process 4 elements at once using v128
    let mut i = 0;
    while i + 4 <= len {
        // Load 4 elements (each is 32-bit)
        // SAFETY: i + 4 <= len, so the 16-byte windows at u[i..i+4] and
        // w[i..i+4] are in bounds; wasm v128 loads/stores tolerate
        // unaligned addresses.
        unsafe {
            let u_vec = v128_load(u.as_ptr().add(i) as *const v128);
            let w_vec = v128_load(w.as_ptr().add(i) as *const v128);

            // Compute lambda * w[i..i+4] using swizzle-based multiply.
            // NOTE(review): the operands are re-read scalar-by-scalar here
            // (w_vec is only used for the final XOR); presumably
            // mul_32x32_to_64_simd (defined elsewhere) is vectorized
            // internally — confirm.
            let w0 = w[i].poly().value();
            let w1 = w[i+1].poly().value();
            let w2 = w[i+2].poly().value();
            let w3 = w[i+3].poly().value();
            let lambda_val = lambda.poly().value();

            // Multiply and reduce each element
            let p0 = mul_32x32_to_64_simd(lambda_val, w0);
            let p1 = mul_32x32_to_64_simd(lambda_val, w1);
            let p2 = mul_32x32_to_64_simd(lambda_val, w2);
            let p3 = mul_32x32_to_64_simd(lambda_val, w3);

            // Reduce mod irreducible polynomial; `as u32` discards any
            // unreduced bits reduce_gf32_wasm leaves above bit 31.
            let r0 = reduce_gf32_wasm(p0, IRR) as u32;
            let r1 = reduce_gf32_wasm(p1, IRR) as u32;
            let r2 = reduce_gf32_wasm(p2, IRR) as u32;
            let r3 = reduce_gf32_wasm(p3, IRR) as u32;

            // Create lambda*w vector
            let lambda_w = u32x4(r0, r1, r2, r3);

            // u[i] = u[i] ^ lambda*w[i] (GF addition is XOR)
            let new_u = v128_xor(u_vec, lambda_w);

            // w[i] = w[i] ^ new_u[i]
            let new_w = v128_xor(w_vec, new_u);

            // Store results
            v128_store(u.as_mut_ptr().add(i) as *mut v128, new_u);
            v128_store(w.as_mut_ptr().add(i) as *mut v128, new_w);
        }

        i += 4;
    }

    // Scalar tail: handle the last len % 4 elements.
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}
886
/// GF(2^32) reduction for WASM (branchless)
///
/// Two-step fold of a 64-bit carryless product using
/// x^32 = x^15 + x^9 + x^7 + x^4 + x^3 + 1 (matches IRR = 0x100008299
/// at the call site; `_irr` itself is unused).
///
/// NOTE(review): the scalar `reduce_gf32` omits the x^4 term (no
/// `>> 28` / `<< 4`), so the two reduction paths implement different
/// polynomials — confirm which matches BinaryElem32::mul.
/// The return value may carry unreduced bits above bit 31 (from
/// `tmp << 15`); the caller truncates with `as u32`, discarding them.
#[cfg(all(feature = "hardware-accel", target_arch = "wasm32", target_feature = "simd128"))]
#[inline(always)]
fn reduce_gf32_wasm(p: u64, _irr: u64) -> u64 {
    // Reduction for x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
    let hi = p >> 32;
    let lo = p & 0xFFFFFFFF;
    // First fold: bits of hi*(r-1) that overflow bit 31; for a term x^k
    // the overflow is hi >> (32-k).
    let tmp = hi ^ (hi >> 17) ^ (hi >> 23) ^ (hi >> 25) ^ (hi >> 28) ^ (hi >> 29);
    // Second fold: lo ^ tmp * r (r = reduction polynomial incl. the +1).
    lo ^ tmp ^ (tmp << 3) ^ (tmp << 4) ^ (tmp << 7) ^ (tmp << 9) ^ (tmp << 15)
}
897
#[cfg(test)]
mod tests {
    use super::*;

    /// The dispatched butterfly (SIMD or scalar, whichever was compiled
    /// in) must produce exactly the same field values as the scalar
    /// reference implementation.
    #[test]
    fn test_fft_butterfly_gf32() {
        let mut u_simd = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
        ];
        let mut w_simd = vec![
            BinaryElem32::from(5),
            BinaryElem32::from(6),
            BinaryElem32::from(7),
            BinaryElem32::from(8),
        ];
        let lambda = BinaryElem32::from(3);

        // Run both paths on identical copies of the inputs.
        let mut u_scalar = u_simd.clone();
        let mut w_scalar = w_simd.clone();

        fft_butterfly_gf32(&mut u_simd, &mut w_simd, lambda);
        fft_butterfly_gf32_scalar(&mut u_scalar, &mut w_scalar, lambda);

        for (i, (got, want)) in u_simd.iter().zip(&u_scalar).enumerate() {
            assert_eq!(got, want, "u mismatch at index {}", i);
        }
        for (i, (got, want)) in w_simd.iter().zip(&w_scalar).enumerate() {
            assert_eq!(got, want, "w mismatch at index {}", i);
        }
    }

    /// Batched GF(2^128) addition must agree with element-wise `add`.
    #[test]
    fn test_batch_add() {
        let a = vec![
            BinaryElem128::from(1),
            BinaryElem128::from(2),
            BinaryElem128::from(3),
        ];
        let b = vec![
            BinaryElem128::from(4),
            BinaryElem128::from(5),
            BinaryElem128::from(6),
        ];
        let mut out = vec![BinaryElem128::zero(); a.len()];

        batch_add_gf128(&a, &b, &mut out);

        for ((x, y), got) in a.iter().zip(&b).zip(&out) {
            assert_eq!(*got, x.add(y));
        }
    }

    /// Batched GF(2^128) multiplication must agree with element-wise `mul`.
    #[test]
    fn test_batch_mul() {
        let a = vec![
            BinaryElem128::from(7),
            BinaryElem128::from(11),
            BinaryElem128::from(13),
        ];
        let b = vec![
            BinaryElem128::from(3),
            BinaryElem128::from(5),
            BinaryElem128::from(7),
        ];
        let mut out = vec![BinaryElem128::zero(); a.len()];

        batch_mul_gf128(&a, &b, &mut out);

        for ((x, y), got) in a.iter().zip(&b).zip(&out) {
            assert_eq!(*got, x.mul(y));
        }
    }

    /// Batched multiplication with operands near the top of the 128-bit
    /// range, exercising the reduction path.
    #[test]
    fn test_batch_mul_large() {
        let a = vec![
            BinaryElem128::from(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0),
            BinaryElem128::from(u128::MAX),
        ];
        let b = vec![
            BinaryElem128::from(0x123456789ABCDEF0123456789ABCDEF0),
            BinaryElem128::from(0x8000000000000000_0000000000000000),
        ];
        let mut out = vec![BinaryElem128::zero(); a.len()];

        batch_mul_gf128(&a, &b, &mut out);

        for ((x, y), got) in a.iter().zip(&b).zip(&out) {
            assert_eq!(*got, x.mul(y));
        }
    }
}