// ligerito_binary_fields — src/simd.rs
1// src/simd.rs
2use crate::elem::BinaryElem32;
3use crate::poly::{BinaryPoly128, BinaryPoly256, BinaryPoly64};
4
5// 64x64 -> 128 bit carryless multiplication
6pub fn carryless_mul_64(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
7    // x86_64 with PCLMULQDQ
8    #[cfg(all(
9        feature = "hardware-accel",
10        target_arch = "x86_64",
11        target_feature = "pclmulqdq"
12    ))]
13    {
14        use core::arch::x86_64::*;
15
16        unsafe {
17            let a_vec = _mm_set_epi64x(0, a.value() as i64);
18            let b_vec = _mm_set_epi64x(0, b.value() as i64);
19
20            let result = _mm_clmulepi64_si128(a_vec, b_vec, 0x00);
21
22            let lo = _mm_extract_epi64(result, 0) as u64;
23            let hi = _mm_extract_epi64(result, 1) as u64;
24
25            return BinaryPoly128::new(((hi as u128) << 64) | (lo as u128));
26        }
27    }
28
29    // WASM with SIMD128
30    #[cfg(all(
31        feature = "hardware-accel",
32        target_arch = "wasm32",
33        target_feature = "simd128"
34    ))]
35    {
36        return carryless_mul_64_wasm_simd(a, b);
37    }
38
39    // Software fallback for other platforms
40    #[cfg(not(any(
41        all(
42            feature = "hardware-accel",
43            target_arch = "x86_64",
44            target_feature = "pclmulqdq"
45        ),
46        all(
47            feature = "hardware-accel",
48            target_arch = "wasm32",
49            target_feature = "simd128"
50        )
51    )))]
52    {
53        // software fallback
54        carryless_mul_64_soft(a, b)
55    }
56}
57
58// software implementation for 64x64 using lookup tables
59#[allow(dead_code)]
60fn carryless_mul_64_soft(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
61    let a_val = a.value();
62    let b_val = b.value();
63
64    // Split into 32-bit halves for Karatsuba
65    let a_lo = (a_val & 0xFFFFFFFF) as u32;
66    let a_hi = (a_val >> 32) as u32;
67    let b_lo = (b_val & 0xFFFFFFFF) as u32;
68    let b_hi = (b_val >> 32) as u32;
69
70    // Karatsuba multiplication with lookup tables
71    let z0 = mul_32x32_to_64_lut(a_lo, b_lo);
72    let z2 = mul_32x32_to_64_lut(a_hi, b_hi);
73    let z1 = mul_32x32_to_64_lut(a_lo ^ a_hi, b_lo ^ b_hi);
74
75    // Karatsuba combination
76    let middle = z1 ^ z0 ^ z2;
77    let result_lo = z0 ^ (middle << 32);
78    let result_hi = (middle >> 32) ^ z2;
79
80    BinaryPoly128::new(((result_hi as u128) << 64) | (result_lo as u128))
81}
82
// WASM SIMD128 optimized implementation using Karatsuba decomposition
// (three 32x32 products instead of four; each product uses the swizzle LUT)
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "wasm32",
    target_feature = "simd128"
))]
fn carryless_mul_64_wasm_simd(a: BinaryPoly64, b: BinaryPoly64) -> BinaryPoly128 {
    // SAFETY: only calls mul_32x32_to_64_simd, whose simd128 intrinsics are
    // statically enabled by the cfg gate above.
    unsafe {
        let a_val = a.value();
        let b_val = b.value();

        // Split into 32-bit halves for Karatsuba
        let a_lo = (a_val & 0xFFFFFFFF) as u32;
        let a_hi = (a_val >> 32) as u32;
        let b_lo = (b_val & 0xFFFFFFFF) as u32;
        let b_hi = (b_val >> 32) as u32;

        // Use SIMD for parallel 32x32 multiplications
        let z0 = mul_32x32_to_64_simd(a_lo, b_lo);
        let z2 = mul_32x32_to_64_simd(a_hi, b_hi);
        let z1 = mul_32x32_to_64_simd(a_lo ^ a_hi, b_lo ^ b_hi);

        // Karatsuba combination: z1 = z1 ^ z0 ^ z2
        let middle = z1 ^ z0 ^ z2;

        // Combine: result = z0 + (middle << 32) + (z2 << 64)
        // (low half of `middle` lands in the top of result_lo, high half in result_hi)
        let result_lo = z0 ^ (middle << 32);
        let result_hi = (middle >> 32) ^ z2;

        BinaryPoly128::new(((result_hi as u128) << 64) | (result_lo as u128))
    }
}
115
// 4-bit x 4-bit carryless multiplication lookup table
// Entry [a][b] = a * b in GF(2) polynomial multiplication (no reduction)
// This is the core building block for branchless carryless multiply
static CLMUL_4X4: [[u8; 16]; 16] = {
    // Const-evaluated shift-and-xor schoolbook multiply over all 16x16 nibble
    // pairs; the widest product (0xF * 0xF) is 7 bits, so u8 is sufficient.
    let mut table = [[0u8; 16]; 16];
    let mut lhs = 0usize;
    while lhs < 16 {
        let mut rhs = 0usize;
        while rhs < 16 {
            let mut product = 0u8;
            let mut bit = 0;
            while bit < 4 {
                // XOR in a shifted copy of lhs for each set bit of rhs
                if (rhs >> bit) & 1 == 1 {
                    product ^= (lhs as u8) << bit;
                }
                bit += 1;
            }
            table[lhs][rhs] = product;
            rhs += 1;
        }
        lhs += 1;
    }
    table
};
141
// 32x32 -> 64 carryless multiplication using WASM SIMD128 i8x16_swizzle
// Uses swizzle for parallel 16-way table lookups - 4x faster than scalar
//
// Strategy: decompose both operands into 4-bit nibbles. For each nibble of
// `a`, one swizzle over that nibble's CLMUL_4X4 row yields all 8 partial
// products against the nibbles of `b`; each is xor-accumulated at bit
// position (i + j) * 4.
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "wasm32",
    target_feature = "simd128"
))]
#[inline(always)]
unsafe fn mul_32x32_to_64_simd(a: u32, b: u32) -> u64 {
    use core::arch::wasm32::*;

    // Extract 4-bit nibbles from both operands
    let a_nibbles: [usize; 8] = [
        (a & 0xF) as usize,
        ((a >> 4) & 0xF) as usize,
        ((a >> 8) & 0xF) as usize,
        ((a >> 12) & 0xF) as usize,
        ((a >> 16) & 0xF) as usize,
        ((a >> 20) & 0xF) as usize,
        ((a >> 24) & 0xF) as usize,
        ((a >> 28) & 0xF) as usize,
    ];

    // Create b_nibbles index vector for swizzle (replicate to fill 16 lanes)
    let b0 = (b & 0xF) as u8;
    let b1 = ((b >> 4) & 0xF) as u8;
    let b2 = ((b >> 8) & 0xF) as u8;
    let b3 = ((b >> 12) & 0xF) as u8;
    let b4 = ((b >> 16) & 0xF) as u8;
    let b5 = ((b >> 20) & 0xF) as u8;
    let b6 = ((b >> 24) & 0xF) as u8;
    let b7 = ((b >> 28) & 0xF) as u8;

    // Index vector: b_nibbles[0..7] in first 8 lanes, zeros in upper 8 (unused)
    // All indices are <= 0xF, i.e. in range for a 16-lane swizzle.
    let b_indices = u8x16(b0, b1, b2, b3, b4, b5, b6, b7, 0, 0, 0, 0, 0, 0, 0, 0);

    let mut result = 0u64;

    // For each a_nibble, load its CLMUL_4X4 row and swizzle with b_indices
    // This computes all 8 products a_nibble[i] * b_nibble[j] in parallel
    // NOTE(review): v128_load on a [u8; 16] row — wasm loads tolerate
    // unaligned addresses, so the cast is sound here.
    for i in 0..8 {
        // Load the 16-byte lookup table row for this a_nibble
        let table_row = v128_load(CLMUL_4X4[a_nibbles[i]].as_ptr() as *const v128);

        // Swizzle: products[j] = CLMUL_4X4[a_nibbles[i]][b_nibbles[j]]
        let products = i8x16_swizzle(table_row, b_indices);

        // Extract the 8 products and accumulate with proper shifts
        // Each product at position j contributes at bit position (i+j)*4
        let p0 = u8x16_extract_lane::<0>(products) as u64;
        let p1 = u8x16_extract_lane::<1>(products) as u64;
        let p2 = u8x16_extract_lane::<2>(products) as u64;
        let p3 = u8x16_extract_lane::<3>(products) as u64;
        let p4 = u8x16_extract_lane::<4>(products) as u64;
        let p5 = u8x16_extract_lane::<5>(products) as u64;
        let p6 = u8x16_extract_lane::<6>(products) as u64;
        let p7 = u8x16_extract_lane::<7>(products) as u64;

        // largest shift is 7*4 + 28 + 7 product bits = 63, so no overflow
        let base_shift = i * 4;
        result ^= p0 << base_shift;
        result ^= p1 << (base_shift + 4);
        result ^= p2 << (base_shift + 8);
        result ^= p3 << (base_shift + 12);
        result ^= p4 << (base_shift + 16);
        result ^= p5 << (base_shift + 20);
        result ^= p6 << (base_shift + 24);
        result ^= p7 << (base_shift + 28);
    }

    result
}
213
214// Lookup table based 32x32 carryless multiply (also used as software fallback)
215#[inline(always)]
216fn mul_32x32_to_64_lut(a: u32, b: u32) -> u64 {
217    // Split a and b into 4-bit nibbles
218    let a_nibbles: [usize; 8] = [
219        (a & 0xF) as usize,
220        ((a >> 4) & 0xF) as usize,
221        ((a >> 8) & 0xF) as usize,
222        ((a >> 12) & 0xF) as usize,
223        ((a >> 16) & 0xF) as usize,
224        ((a >> 20) & 0xF) as usize,
225        ((a >> 24) & 0xF) as usize,
226        ((a >> 28) & 0xF) as usize,
227    ];
228
229    let b_nibbles: [usize; 8] = [
230        (b & 0xF) as usize,
231        ((b >> 4) & 0xF) as usize,
232        ((b >> 8) & 0xF) as usize,
233        ((b >> 12) & 0xF) as usize,
234        ((b >> 16) & 0xF) as usize,
235        ((b >> 20) & 0xF) as usize,
236        ((b >> 24) & 0xF) as usize,
237        ((b >> 28) & 0xF) as usize,
238    ];
239
240    // Schoolbook multiplication with 4-bit chunks
241    // Each a_nibble[i] * b_nibble[j] contributes at bit position (i+j)*4
242    let mut result = 0u64;
243
244    // Unrolled for performance
245    for i in 0..8 {
246        for j in 0..8 {
247            let prod = CLMUL_4X4[a_nibbles[i]][b_nibbles[j]] as u64;
248            result ^= prod << ((i + j) * 4);
249        }
250    }
251
252    result
253}
254
255// 128x128 -> 128 bit carryless multiplication (truncated)
256pub fn carryless_mul_128(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly128 {
257    #[cfg(all(
258        feature = "hardware-accel",
259        target_arch = "x86_64",
260        target_feature = "pclmulqdq"
261    ))]
262    {
263        use core::arch::x86_64::*;
264
265        unsafe {
266            // split inputs into 64-bit halves
267            let a_lo = a.value() as u64;
268            let a_hi = (a.value() >> 64) as u64;
269            let b_lo = b.value() as u64;
270            let b_hi = (b.value() >> 64) as u64;
271
272            // perform 3 64x64->128 bit multiplications (skip hi*hi for truncated result)
273            let lo_lo = _mm_clmulepi64_si128(
274                _mm_set_epi64x(0, a_lo as i64),
275                _mm_set_epi64x(0, b_lo as i64),
276                0x00,
277            );
278
279            let lo_hi = _mm_clmulepi64_si128(
280                _mm_set_epi64x(0, a_lo as i64),
281                _mm_set_epi64x(0, b_hi as i64),
282                0x00,
283            );
284
285            let hi_lo = _mm_clmulepi64_si128(
286                _mm_set_epi64x(0, a_hi as i64),
287                _mm_set_epi64x(0, b_lo as i64),
288                0x00,
289            );
290
291            // extract 128-bit results - fix the overflow by casting to u128 first
292            let r0 = (_mm_extract_epi64(lo_lo, 0) as u64) as u128
293                | ((_mm_extract_epi64(lo_lo, 1) as u64) as u128) << 64;
294            let r1 = (_mm_extract_epi64(lo_hi, 0) as u64) as u128
295                | ((_mm_extract_epi64(lo_hi, 1) as u64) as u128) << 64;
296            let r2 = (_mm_extract_epi64(hi_lo, 0) as u64) as u128
297                | ((_mm_extract_epi64(hi_lo, 1) as u64) as u128) << 64;
298
299            // combine: result = r0 + (r1 << 64) + (r2 << 64)
300            let result = r0 ^ (r1 << 64) ^ (r2 << 64);
301
302            return BinaryPoly128::new(result);
303        }
304    }
305
306    #[cfg(not(all(
307        feature = "hardware-accel",
308        target_arch = "x86_64",
309        target_feature = "pclmulqdq"
310    )))]
311    {
312        // software fallback
313        carryless_mul_128_soft(a, b)
314    }
315}
316
317// software implementation for 128x128 truncated
318#[allow(dead_code)]
319fn carryless_mul_128_soft(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly128 {
320    let a_lo = a.value() as u64;
321    let a_hi = (a.value() >> 64) as u64;
322    let b_lo = b.value() as u64;
323    let b_hi = (b.value() >> 64) as u64;
324
325    let z0 = mul_64x64_to_128(a_lo, b_lo);
326    let z1 = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi);
327    let z2 = mul_64x64_to_128(a_hi, b_hi);
328
329    // karatsuba combination (truncated)
330    let result = z0 ^ (z1 << 64) ^ (z0 << 64) ^ (z2 << 64);
331    BinaryPoly128::new(result)
332}
333
334// 128x128 -> 256 bit full multiplication
335pub fn carryless_mul_128_full(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly256 {
336    #[cfg(all(
337        feature = "hardware-accel",
338        target_arch = "x86_64",
339        target_feature = "pclmulqdq"
340    ))]
341    {
342        use core::arch::x86_64::*;
343
344        unsafe {
345            let a_lo = a.value() as u64;
346            let a_hi = (a.value() >> 64) as u64;
347            let b_lo = b.value() as u64;
348            let b_hi = (b.value() >> 64) as u64;
349
350            // 4 multiplications
351            let lo_lo = _mm_clmulepi64_si128(
352                _mm_set_epi64x(0, a_lo as i64),
353                _mm_set_epi64x(0, b_lo as i64),
354                0x00,
355            );
356
357            let lo_hi = _mm_clmulepi64_si128(
358                _mm_set_epi64x(0, a_lo as i64),
359                _mm_set_epi64x(0, b_hi as i64),
360                0x00,
361            );
362
363            let hi_lo = _mm_clmulepi64_si128(
364                _mm_set_epi64x(0, a_hi as i64),
365                _mm_set_epi64x(0, b_lo as i64),
366                0x00,
367            );
368
369            let hi_hi = _mm_clmulepi64_si128(
370                _mm_set_epi64x(0, a_hi as i64),
371                _mm_set_epi64x(0, b_hi as i64),
372                0x00,
373            );
374
375            // extract and combine
376            let r0_lo = _mm_extract_epi64(lo_lo, 0) as u64;
377            let r0_hi = _mm_extract_epi64(lo_lo, 1) as u64;
378            let r1_lo = _mm_extract_epi64(lo_hi, 0) as u64;
379            let r1_hi = _mm_extract_epi64(lo_hi, 1) as u64;
380            let r2_lo = _mm_extract_epi64(hi_lo, 0) as u64;
381            let r2_hi = _mm_extract_epi64(hi_lo, 1) as u64;
382            let r3_lo = _mm_extract_epi64(hi_hi, 0) as u64;
383            let r3_hi = _mm_extract_epi64(hi_hi, 1) as u64;
384
385            // build 256-bit result
386            let mut lo = r0_lo as u128 | ((r0_hi as u128) << 64);
387            let mut hi = 0u128;
388
389            // add r1 << 64
390            lo ^= (r1_lo as u128) << 64;
391            hi ^= (r1_lo as u128) >> 64;
392            hi ^= r1_hi as u128;
393
394            // add r2 << 64
395            lo ^= (r2_lo as u128) << 64;
396            hi ^= (r2_lo as u128) >> 64;
397            hi ^= r2_hi as u128;
398
399            // add r3 << 128
400            hi ^= r3_lo as u128 | ((r3_hi as u128) << 64);
401
402            return BinaryPoly256::from_parts(hi, lo);
403        }
404    }
405
406    #[cfg(not(all(
407        feature = "hardware-accel",
408        target_arch = "x86_64",
409        target_feature = "pclmulqdq"
410    )))]
411    {
412        // software fallback
413        carryless_mul_128_full_soft(a, b)
414    }
415}
416
417// software implementation for 128x128 full
418#[allow(dead_code)]
419fn carryless_mul_128_full_soft(a: BinaryPoly128, b: BinaryPoly128) -> BinaryPoly256 {
420    let a_lo = a.value() as u64;
421    let a_hi = (a.value() >> 64) as u64;
422    let b_lo = b.value() as u64;
423    let b_hi = (b.value() >> 64) as u64;
424
425    let z0 = mul_64x64_to_128(a_lo, b_lo);
426    let z2 = mul_64x64_to_128(a_hi, b_hi);
427    let z1 = mul_64x64_to_128(a_lo ^ a_hi, b_lo ^ b_hi) ^ z0 ^ z2;
428
429    // combine: result = z0 + (z1 << 64) + (z2 << 128)
430    let mut lo = z0;
431    let mut hi = 0u128;
432
433    // add z1 << 64
434    lo ^= z1 << 64;
435    hi ^= z1 >> 64;
436
437    // add z2 << 128
438    hi ^= z2;
439
440    BinaryPoly256::from_parts(hi, lo)
441}
442
// helper: constant-time 64x64 -> 128
#[inline(always)]
#[allow(dead_code)]
fn mul_64x64_to_128(a: u64, b: u64) -> u128 {
    // Branchless shift-and-xor: for each set bit of b, xor in a copy of `a`
    // shifted to that bit position. The all-ones/all-zeros mask (instead of
    // an `if`) keeps the timing independent of the operand values.
    (0..64).fold(0u128, |acc, bit| {
        let mask = (((b >> bit) & 1) as u128).wrapping_neg();
        acc ^ (((a as u128) << bit) & mask)
    })
}
458
459// batch field operations
460
461use crate::{BinaryElem128, BinaryFieldElement};
462
/// batch multiply gf(2^128) elements with two-tier dispatch:
/// hardware-accel → pclmulqdq, else → scalar
///
/// Computes out[i] = a[i] * b[i]. All three slices must have the same
/// length (asserted below); panics otherwise.
pub fn batch_mul_gf128(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
    assert_eq!(a.len(), b.len());
    assert_eq!(a.len(), out.len());

    // hardware path: compiled in only when pclmulqdq is statically enabled
    #[cfg(all(
        feature = "hardware-accel",
        target_arch = "x86_64",
        target_feature = "pclmulqdq"
    ))]
    {
        return batch_mul_gf128_hw(a, b, out);
    }

    // mutually exclusive with the cfg above, so exactly one branch compiles
    #[cfg(not(all(
        feature = "hardware-accel",
        target_arch = "x86_64",
        target_feature = "pclmulqdq"
    )))]
    {
        // scalar fallback
        for i in 0..a.len() {
            out[i] = a[i].mul(&b[i]);
        }
    }
}
490
491/// batch add gf(2^128) elements (xor in gf(2^n))
492pub fn batch_add_gf128(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
493    assert_eq!(a.len(), b.len());
494    assert_eq!(a.len(), out.len());
495
496    // scalar fallback (XOR is already very fast)
497    for i in 0..a.len() {
498        out[i] = a[i].add(&b[i]);
499    }
500}
501
// pclmulqdq-based batch multiply for x86_64
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
fn batch_mul_gf128_hw(a: &[BinaryElem128], b: &[BinaryElem128], out: &mut [BinaryElem128]) {
    // Full 256-bit carryless product followed by modular reduction, one
    // element pair at a time (the clmul itself is hardware-accelerated).
    for ((dst, lhs), rhs) in out.iter_mut().zip(a).zip(b) {
        let wide = carryless_mul_128_full(lhs.poly(), rhs.poly());
        *dst = BinaryElem128::from_value(reduce_gf128(wide).value());
    }
}
517
/// reduce 256-bit product modulo GF(2^128) irreducible polynomial
/// irreducible: x^128 + x^7 + x^2 + x + 1 (0x87 = 0b10000111)
/// matches julia's @generated mod_irreducible (binaryfield.jl:73-114)
#[inline(always)]
pub fn reduce_gf128(product: BinaryPoly256) -> BinaryPoly128 {
    // split the 256-bit product into 128-bit halves
    let (hi, lo) = product.split();
    let high = hi.value();
    let low = lo.value();

    // julia's compute_tmp for irreducible 0b10000111 (bits 0,1,2,7):
    // for each set bit i in irreducible: tmp ^= hi >> (128 - i)
    // bits set: 0, 1, 2, 7 -> shifts: 128, 127, 126, 121
    // (the shift-by-128 term is `high` itself; the extra shifted terms
    // pre-compensate bits that would otherwise need a second fold pass)
    let tmp = high ^ (high >> 127) ^ (high >> 126) ^ (high >> 121);

    // julia's compute_res:
    // for each set bit i in irreducible: res ^= tmp << i
    // bits set: 0, 1, 2, 7 -> shifts: 0, 1, 2, 7
    let res = low ^ tmp ^ (tmp << 1) ^ (tmp << 2) ^ (tmp << 7);

    BinaryPoly128::new(res)
}
539
540// =========================================================================
541// BinaryElem32 batch operations - FFT optimization
542// =========================================================================
543
/// Vectorized FFT butterfly for GF(2^32) with tiered fallback
/// Tries: AVX-512 (8 elements) -> AVX2 (4 elements) -> SSE (2 elements)
///
/// This function only exists on x86_64 builds with `hardware-accel` +
/// pclmulqdq (see cfg below); other targets reach the scalar path through
/// the `fft_butterfly_gf32` dispatcher instead, so the previous inner
/// `#[cfg(target_arch = "x86_64")]` / `#[cfg(not(...))]` pair was dead
/// code and has been removed.
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
pub fn fft_butterfly_gf32_avx512(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    // Tier 1: Try AVX-512 with 512-bit VPCLMULQDQ (8 elements at once)
    if is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx512f") {
        // SAFETY: we just checked the CPU supports these features
        unsafe { fft_butterfly_gf32_avx512_impl(u, w, lambda) };
        return;
    }

    // Tier 2: Try AVX2 with 256-bit VPCLMULQDQ (4 elements at once)
    if is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx2") {
        // SAFETY: we just checked the CPU supports these features
        unsafe { fft_butterfly_gf32_avx2_impl(u, w, lambda) };
        return;
    }

    // Tier 3: Fall back to SSE with 128-bit PCLMULQDQ (2 elements at once)
    fft_butterfly_gf32_sse(u, w, lambda)
}
581
/// AVX-512 implementation using 512-bit VPCLMULQDQ
/// Processes 8 elements at once using full 512-bit vectors
/// Requires Rust 1.89+ for _mm512_extracti64x4_epi64
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
#[target_feature(enable = "avx512f,vpclmulqdq")]
unsafe fn fft_butterfly_gf32_avx512_impl(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    let lambda_val = lambda.poly().value() as u64;
    // Broadcast lambda to all 8 lanes of 512-bit vector
    let lambda_512 = _mm512_set1_epi64(lambda_val as i64);

    let mut i = 0;

    // Process 8 elements at once using 512-bit vectors
    // VPCLMULQDQ on 512-bit does 4 clmuls per instruction (one per 128-bit lane)
    // We pack elements: [w0, w1] [w2, w3] [w4, w5] [w6, w7] in 4 x 128-bit lanes
    while i + 8 <= len {
        // Load 8 w elements into 512-bit vector
        // Each 128-bit lane holds 2 elements for clmul
        let w_512 = _mm512_set_epi64(
            w[i + 7].poly().value() as i64,
            w[i + 6].poly().value() as i64,
            w[i + 5].poly().value() as i64,
            w[i + 4].poly().value() as i64,
            w[i + 3].poly().value() as i64,
            w[i + 2].poly().value() as i64,
            w[i + 1].poly().value() as i64,
            w[i].poly().value() as i64,
        );

        // Selector 0x00: low qword of each operand, per 128-bit lane
        // This gives us: lambda*w[0], lambda*w[2], lambda*w[4], lambda*w[6]
        let prod_even = _mm512_clmulepi64_epi128::<0x00>(lambda_512, w_512);

        // Selector 0x10: low qword of lambda (imm8[0]=0) with HIGH qword of w
        // (imm8[4]=1), giving lambda*w[1], lambda*w[3], lambda*w[5], lambda*w[7].
        // BUGFIX: selector 0x01 selected w's LOW qword again (imm8[4]=0), so
        // the "odd" products duplicated the even ones; cf. the SSE path which
        // correctly uses 0x11 for the odd-element product.
        let prod_odd = _mm512_clmulepi64_epi128::<0x10>(lambda_512, w_512);

        // Extract 256-bit halves using _mm512_extracti64x4_epi64 (Rust 1.89+)
        let prod_even_lo: __m256i = _mm512_extracti64x4_epi64::<0>(prod_even);
        let prod_even_hi: __m256i = _mm512_extracti64x4_epi64::<1>(prod_even);
        let prod_odd_lo: __m256i = _mm512_extracti64x4_epi64::<0>(prod_odd);
        let prod_odd_hi: __m256i = _mm512_extracti64x4_epi64::<1>(prod_odd);

        // Extract individual 64-bit products
        let p0 = _mm256_extract_epi64::<0>(prod_even_lo) as u64; // lambda * w[0]
        let p2 = _mm256_extract_epi64::<2>(prod_even_lo) as u64; // lambda * w[2]
        let p4 = _mm256_extract_epi64::<0>(prod_even_hi) as u64; // lambda * w[4]
        let p6 = _mm256_extract_epi64::<2>(prod_even_hi) as u64; // lambda * w[6]

        let p1 = _mm256_extract_epi64::<0>(prod_odd_lo) as u64; // lambda * w[1]
        let p3 = _mm256_extract_epi64::<2>(prod_odd_lo) as u64; // lambda * w[3]
        let p5 = _mm256_extract_epi64::<0>(prod_odd_hi) as u64; // lambda * w[5]
        let p7 = _mm256_extract_epi64::<2>(prod_odd_hi) as u64; // lambda * w[7]

        // Reduce all 8 products
        let lw = [
            reduce_gf32_inline(p0) as u32,
            reduce_gf32_inline(p1) as u32,
            reduce_gf32_inline(p2) as u32,
            reduce_gf32_inline(p3) as u32,
            reduce_gf32_inline(p4) as u32,
            reduce_gf32_inline(p5) as u32,
            reduce_gf32_inline(p6) as u32,
            reduce_gf32_inline(p7) as u32,
        ];

        // u[i] = u[i] XOR lambda_w[i], then w[i] = w[i] XOR u[i]
        for j in 0..8 {
            let u_val = u[i + j].poly().value() ^ lw[j];
            let w_val = w[i + j].poly().value() ^ u_val;
            u[i + j] = BinaryElem32::from(u_val);
            w[i + j] = BinaryElem32::from(w_val);
        }

        i += 8;
    }

    // Handle remaining elements with scalar
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}
680
/// Inline reduction for GF(2^32) - branchless
#[inline(always)]
fn reduce_gf32_inline(p: u64) -> u64 {
    // Irreducible: x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
    // First fold the high 32 bits down (one shift of 32 - k per set bit k),
    // then xor the folded value back in at every set-bit position.
    let (hi, lo) = (p >> 32, p & 0xFFFF_FFFF);

    let mut folded = hi;
    for shift in [17, 23, 25, 28, 29] {
        folded ^= hi >> shift;
    }

    let mut reduced = lo ^ folded;
    for shift in [3, 4, 7, 9, 15] {
        reduced ^= folded << shift;
    }
    reduced
}
692
/// AVX2 vectorized FFT butterfly operation for GF(2^32)
/// Processes 4 elements at once using 256-bit VPCLMULQDQ
/// For CPUs with AVX2 but not AVX-512
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
#[target_feature(enable = "avx2,vpclmulqdq")]
unsafe fn fft_butterfly_gf32_avx2_impl(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    let lambda_val = lambda.poly().value() as u64;
    // Broadcast lambda to both 128-bit lanes
    let lambda_256 = _mm256_set1_epi64x(lambda_val as i64);

    let mut i = 0;

    // Process 4 elements at once using 256-bit vectors
    // VPCLMULQDQ on 256-bit does 2 clmuls per instruction (one per 128-bit lane)
    while i + 4 <= len {
        // Load 4 w elements: [w0, w1] in low lane, [w2, w3] in high lane
        let w_256 = _mm256_set_epi64x(
            w[i + 3].poly().value() as i64,
            w[i + 2].poly().value() as i64,
            w[i + 1].poly().value() as i64,
            w[i].poly().value() as i64,
        );

        // Selector 0x00: low qword of each operand, per 128-bit lane
        // Gives: lambda*w[0], lambda*w[2]
        let prod_even = _mm256_clmulepi64_epi128::<0x00>(lambda_256, w_256);

        // Selector 0x10: low qword of lambda (imm8[0]=0) with HIGH qword of w
        // (imm8[4]=1), giving lambda*w[1], lambda*w[3].
        // BUGFIX: selector 0x01 selected w's LOW qword again (imm8[4]=0), so
        // the "odd" products duplicated the even ones; cf. the SSE path which
        // correctly uses 0x11 for the odd-element product.
        let prod_odd = _mm256_clmulepi64_epi128::<0x10>(lambda_256, w_256);

        // Extract products
        let p0 = _mm256_extract_epi64::<0>(prod_even) as u64;
        let p2 = _mm256_extract_epi64::<2>(prod_even) as u64;
        let p1 = _mm256_extract_epi64::<0>(prod_odd) as u64;
        let p3 = _mm256_extract_epi64::<2>(prod_odd) as u64;

        // Reduce all 4 products
        let lw = [
            reduce_gf32_inline(p0) as u32,
            reduce_gf32_inline(p1) as u32,
            reduce_gf32_inline(p2) as u32,
            reduce_gf32_inline(p3) as u32,
        ];

        // u[i] = u[i] XOR lambda_w[i], then w[i] = w[i] XOR u[i]
        for j in 0..4 {
            let u_val = u[i + j].poly().value() ^ lw[j];
            let w_val = w[i + j].poly().value() ^ u_val;
            u[i + j] = BinaryElem32::from(u_val);
            w[i + j] = BinaryElem32::from(w_val);
        }

        i += 4;
    }

    // Handle remaining elements with scalar
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}
770
/// Force AVX-512 path (for benchmarking)
/// Panics if AVX-512 + VPCLMULQDQ not available
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
pub fn fft_butterfly_gf32_avx512_only(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    let supported =
        is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx512f");
    assert!(supported, "AVX-512 + VPCLMULQDQ required");

    // SAFETY: the required CPU features were verified by the assert above.
    unsafe { fft_butterfly_gf32_avx512_impl(u, w, lambda) };
}
789
/// Force AVX2 path (for benchmarking)
/// Panics if AVX2 + VPCLMULQDQ not available
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
pub fn fft_butterfly_gf32_avx2_only(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    let supported =
        is_x86_feature_detected!("vpclmulqdq") && is_x86_feature_detected!("avx2");
    assert!(supported, "AVX2 + VPCLMULQDQ required");

    // SAFETY: the required CPU features were verified by the assert above.
    unsafe { fft_butterfly_gf32_avx2_impl(u, w, lambda) };
}
808
/// SSE vectorized FFT butterfly operation for GF(2^32)
/// computes: u[i] = u[i] + lambda*w[i]; w[i] = w[i] + u[i]
/// processes 2 elements at a time using PCLMULQDQ
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "x86_64",
    target_feature = "pclmulqdq"
))]
pub fn fft_butterfly_gf32_sse(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    use core::arch::x86_64::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    // irreducible polynomial for GF(2^32):
    // x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
    // (0b11001 covers bits 0, 3 and 4 — the old comment omitted x^4)
    const IRREDUCIBLE_32: u64 = (1u64 << 32) | 0b11001 | (1 << 7) | (1 << 9) | (1 << 15);

    // SAFETY: pclmulqdq is statically enabled for this cfg'd function.
    unsafe {
        let lambda_val = lambda.poly().value() as u64;
        let lambda_vec = _mm_set1_epi64x(lambda_val as i64);

        let mut i = 0;

        // process 2 elements at once (2x32-bit = 64 bits, fits in one lane)
        while i + 2 <= len {
            // load w[i] and w[i+1] into 64-bit lanes
            let w0 = w[i].poly().value() as u64;
            let w1 = w[i + 1].poly().value() as u64;
            let w_vec = _mm_set_epi64x(w1 as i64, w0 as i64);

            // carryless multiply: lambda * w[i]
            // Immediates are const generics — turbofish form, consistent with
            // the AVX2/AVX-512 code above.
            let prod_lo = _mm_clmulepi64_si128::<0x00>(lambda_vec, w_vec); // lambda * w0
            let prod_hi = _mm_clmulepi64_si128::<0x11>(lambda_vec, w_vec); // lambda * w1

            // reduce modulo irreducible
            let p0 = _mm_extract_epi64::<0>(prod_lo) as u64;
            let p1 = _mm_extract_epi64::<0>(prod_hi) as u64;

            let lambda_w0 = reduce_gf32(p0, IRREDUCIBLE_32);
            let lambda_w1 = reduce_gf32(p1, IRREDUCIBLE_32);

            // u[i] = u[i] XOR lambda_w[i]
            let u0 = u[i].poly().value() ^ (lambda_w0 as u32);
            let u1 = u[i + 1].poly().value() ^ (lambda_w1 as u32);

            // w[i] = w[i] XOR u[i] (using updated u)
            let w0_new = w[i].poly().value() ^ u0;
            let w1_new = w[i + 1].poly().value() ^ u1;

            u[i] = BinaryElem32::from(u0);
            u[i + 1] = BinaryElem32::from(u1);
            w[i] = BinaryElem32::from(w0_new);
            w[i + 1] = BinaryElem32::from(w1_new);

            i += 2;
        }

        // handle remaining element
        if i < len {
            let lambda_w = lambda.mul(&w[i]);
            u[i] = u[i].add(&lambda_w);
            w[i] = w[i].add(&u[i]);
        }
    }
}
879
/// reduce 64-bit product modulo GF(2^32) irreducible
/// optimized branchless reduction for GF(2^32)
///
/// Irreducible: x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1 — the same
/// polynomial encoded by IRREDUCIBLE_32 and hard-coded in
/// reduce_gf32_inline. (`_irr` is kept only for signature compatibility;
/// the shifts below hard-code the polynomial.)
///
/// BUGFIX: the x^4 terms (`hi >> 28` and `tmp << 4`) were missing, which
/// made this (SSE-path) reduction disagree with reduce_gf32_inline and
/// with the bit-4 set in IRREDUCIBLE_32.
///
/// Note: callers truncate the result to 32 bits; bits >= 32 of the return
/// value are not meaningful.
#[inline(always)]
fn reduce_gf32(p: u64, _irr: u64) -> u64 {
    // for 32x32 -> 64 multiplication, we need to reduce bits [63:32]
    let hi = p >> 32;
    let lo = p & 0xFFFFFFFF;

    // fold the high bits down: one shift of (32 - k) per set bit k of the
    // irreducible, k in {3, 4, 7, 9, 15}
    let tmp = hi
        ^ (hi >> 17)  // bit 15: shift by (32-15)
        ^ (hi >> 23)  // bit 9:  shift by (32-9)
        ^ (hi >> 25)  // bit 7:  shift by (32-7)
        ^ (hi >> 28)  // bit 4:  shift by (32-4)  (previously missing)
        ^ (hi >> 29); // bit 3:  shift by (32-3)

    // XOR the folded value back in at every set-bit position of the irreducible
    lo ^ tmp ^ (tmp << 3) ^ (tmp << 4) ^ (tmp << 7) ^ (tmp << 9) ^ (tmp << 15)
}
902
903/// scalar fallback for FFT butterfly
904pub fn fft_butterfly_gf32_scalar(
905    u: &mut [BinaryElem32],
906    w: &mut [BinaryElem32],
907    lambda: BinaryElem32,
908) {
909    assert_eq!(u.len(), w.len());
910
911    for i in 0..u.len() {
912        let lambda_w = lambda.mul(&w[i]);
913        u[i] = u[i].add(&lambda_w);
914        w[i] = w[i].add(&u[i]);
915    }
916}
917
918/// dispatch FFT butterfly to best available SIMD version
919pub fn fft_butterfly_gf32(u: &mut [BinaryElem32], w: &mut [BinaryElem32], lambda: BinaryElem32) {
920    #[cfg(all(
921        feature = "hardware-accel",
922        target_arch = "x86_64",
923        target_feature = "pclmulqdq"
924    ))]
925    {
926        // Try AVX-512 first (runtime detection), fallback to SSE
927        return fft_butterfly_gf32_avx512(u, w, lambda);
928    }
929
930    #[cfg(all(
931        feature = "hardware-accel",
932        target_arch = "wasm32",
933        target_feature = "simd128"
934    ))]
935    {
936        return fft_butterfly_gf32_wasm_simd(u, w, lambda);
937    }
938
939    #[cfg(not(any(
940        all(
941            feature = "hardware-accel",
942            target_arch = "x86_64",
943            target_feature = "pclmulqdq"
944        ),
945        all(
946            feature = "hardware-accel",
947            target_arch = "wasm32",
948            target_feature = "simd128"
949        )
950    )))]
951    {
952        fft_butterfly_gf32_scalar(u, w, lambda)
953    }
954}
955
/// WASM SIMD128 optimized FFT butterfly: u[i] += lambda * w[i]; w[i] += u[i].
///
/// GF additions are `v128_xor`; each product goes through
/// `mul_32x32_to_64_simd` (defined elsewhere in this file — presumably a
/// swizzle/table-based carryless multiply; confirm at its definition) and is
/// then reduced by `reduce_gf32_wasm`. Processes 4 elements per iteration
/// with a scalar tail loop.
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "wasm32",
    target_feature = "simd128"
))]
pub fn fft_butterfly_gf32_wasm_simd(
    u: &mut [BinaryElem32],
    w: &mut [BinaryElem32],
    lambda: BinaryElem32,
) {
    use core::arch::wasm32::*;

    assert_eq!(u.len(), w.len());
    let len = u.len();

    // GF(2^32) irreducible: x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1
    // (low part 0x8299 = bits 15, 9, 7, 4, 3, 0)
    const IRR: u64 = 0x100008299;

    // Process 4 elements at once using v128
    let mut i = 0;
    while i + 4 <= len {
        // Load 4 elements (each is 32-bit)
        unsafe {
            // SAFETY: i + 4 <= len, so the 16-byte loads/stores below stay
            // in-bounds; wasm v128_load/v128_store have no alignment
            // requirement. Assumes BinaryElem32 is a plain 32-bit value with
            // no padding — TODO confirm its layout.
            let u_vec = v128_load(u.as_ptr().add(i) as *const v128);
            let w_vec = v128_load(w.as_ptr().add(i) as *const v128);

            // Compute lambda * w[i..i+4] using swizzle-based multiply
            // (multiplies are done per-lane via the scalar helper below)
            let w0 = w[i].poly().value();
            let w1 = w[i + 1].poly().value();
            let w2 = w[i + 2].poly().value();
            let w3 = w[i + 3].poly().value();
            let lambda_val = lambda.poly().value();

            // Multiply and reduce each element
            let p0 = mul_32x32_to_64_simd(lambda_val, w0);
            let p1 = mul_32x32_to_64_simd(lambda_val, w1);
            let p2 = mul_32x32_to_64_simd(lambda_val, w2);
            let p3 = mul_32x32_to_64_simd(lambda_val, w3);

            // Reduce mod irreducible polynomial (the `as u32` casts discard
            // unmasked high bits left by reduce_gf32_wasm)
            let r0 = reduce_gf32_wasm(p0, IRR) as u32;
            let r1 = reduce_gf32_wasm(p1, IRR) as u32;
            let r2 = reduce_gf32_wasm(p2, IRR) as u32;
            let r3 = reduce_gf32_wasm(p3, IRR) as u32;

            // Create lambda*w vector
            let lambda_w = u32x4(r0, r1, r2, r3);

            // u[i] = u[i] ^ lambda*w[i] (GF addition is XOR)
            let new_u = v128_xor(u_vec, lambda_w);

            // w[i] = w[i] ^ new_u[i] (uses the already-updated u lanes)
            let new_w = v128_xor(w_vec, new_u);

            // Store results
            v128_store(u.as_mut_ptr().add(i) as *mut v128, new_u);
            v128_store(w.as_mut_ptr().add(i) as *mut v128, new_w);
        }

        i += 4;
    }

    // Handle remaining elements with the scalar formula (same update order)
    while i < len {
        let lambda_w = lambda.mul(&w[i]);
        u[i] = u[i].add(&lambda_w);
        w[i] = w[i].add(&u[i]);
        i += 1;
    }
}
1028
/// GF(2^32) reduction for WASM (branchless).
///
/// Folds the high 32 bits of a 64-bit carryless product back into the low
/// word modulo x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1 (0x1_0000_8299 —
/// matches the caller's `IRR` constant; `_irr` itself is ignored).
/// NOTE: bits above 31 of the return value are not masked; callers must
/// truncate with `as u32`.
#[cfg(all(
    feature = "hardware-accel",
    target_arch = "wasm32",
    target_feature = "simd128"
))]
#[inline(always)]
fn reduce_gf32_wasm(p: u64, _irr: u64) -> u64 {
    // Reduction for x^32 + x^15 + x^9 + x^7 + x^4 + x^3 + 1.
    // Shifts are (32 - k) per non-leading term x^k:
    // >>17 (x^15), >>23 (x^9), >>25 (x^7), >>28 (x^4), >>29 (x^3);
    // then tmp is multiplied by the low part of the irreducible below.
    let hi = p >> 32;
    let lo = p & 0xFFFFFFFF;
    let tmp = hi ^ (hi >> 17) ^ (hi >> 23) ^ (hi >> 25) ^ (hi >> 28) ^ (hi >> 29);
    lo ^ tmp ^ (tmp << 3) ^ (tmp << 4) ^ (tmp << 7) ^ (tmp << 9) ^ (tmp << 15)
}
1043
#[cfg(test)]
mod tests {
    use super::*;

    /// The dispatched butterfly must agree exactly with the scalar reference.
    #[test]
    fn test_fft_butterfly_gf32() {
        let lambda = BinaryElem32::from(3);

        let mut u_simd = vec![
            BinaryElem32::from(1),
            BinaryElem32::from(2),
            BinaryElem32::from(3),
            BinaryElem32::from(4),
        ];
        let mut w_simd = vec![
            BinaryElem32::from(5),
            BinaryElem32::from(6),
            BinaryElem32::from(7),
            BinaryElem32::from(8),
        ];
        let mut u_scalar = u_simd.clone();
        let mut w_scalar = w_simd.clone();

        fft_butterfly_gf32(&mut u_simd, &mut w_simd, lambda);
        fft_butterfly_gf32_scalar(&mut u_scalar, &mut w_scalar, lambda);

        for (i, (simd, scalar)) in u_simd.iter().zip(&u_scalar).enumerate() {
            assert_eq!(simd, scalar, "u mismatch at index {}", i);
        }
        for (i, (simd, scalar)) in w_simd.iter().zip(&w_scalar).enumerate() {
            assert_eq!(simd, scalar, "w mismatch at index {}", i);
        }
    }

    /// Batched GF(2^128) addition must match element-wise `add`.
    #[test]
    fn test_batch_add() {
        let a = vec![
            BinaryElem128::from(1),
            BinaryElem128::from(2),
            BinaryElem128::from(3),
        ];
        let b = vec![
            BinaryElem128::from(4),
            BinaryElem128::from(5),
            BinaryElem128::from(6),
        ];
        let mut out = vec![BinaryElem128::zero(); a.len()];

        batch_add_gf128(&a, &b, &mut out);

        for ((got, x), y) in out.iter().zip(&a).zip(&b) {
            assert_eq!(*got, x.add(y));
        }
    }

    /// Batched GF(2^128) multiplication must match element-wise `mul`.
    #[test]
    fn test_batch_mul() {
        let a = vec![
            BinaryElem128::from(7),
            BinaryElem128::from(11),
            BinaryElem128::from(13),
        ];
        let b = vec![
            BinaryElem128::from(3),
            BinaryElem128::from(5),
            BinaryElem128::from(7),
        ];
        let mut out = vec![BinaryElem128::zero(); a.len()];

        batch_mul_gf128(&a, &b, &mut out);

        for ((got, x), y) in out.iter().zip(&a).zip(&b) {
            assert_eq!(*got, x.mul(y));
        }
    }

    /// Batched multiply with near-maximal operands — exercises the reduction
    /// path rather than small-value shortcuts.
    #[test]
    fn test_batch_mul_large() {
        let a = vec![
            BinaryElem128::from(0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF0),
            BinaryElem128::from(u128::MAX),
        ];
        let b = vec![
            BinaryElem128::from(0x123456789ABCDEF0123456789ABCDEF0),
            BinaryElem128::from(0x8000000000000000_0000000000000000),
        ];
        let mut out = vec![BinaryElem128::zero(); a.len()];

        batch_mul_gf128(&a, &b, &mut out);

        for ((got, x), y) in out.iter().zip(&a).zip(&b) {
            assert_eq!(*got, x.mul(y));
        }
    }
}