// arcium-primitives 0.4.2
//
// Arcium primitives
// Documentation
use std::num::Wrapping;

use itertools::Itertools;

/// Portable carry-less (GF(2)\[x\]) multiplication of two `LIMBS`-limb operands.
///
/// Limbs are little-endian: index 0 holds the least significant 64 bits.
/// The product has `2 * LIMBS` limbs. It is returned as `(low, high)`, where
/// `low` holds limbs `0..LIMBS` and `high` holds limbs `LIMBS..2 * LIMBS`.
pub fn carry_less_mul<const LIMBS: usize>(
    a: [u64; LIMBS],
    b: [u64; LIMBS],
) -> ([u64; LIMBS], [u64; LIMBS]) {
    let mut halves = [[0u64; LIMBS]; 2];
    // View both halves as one contiguous `2 * LIMBS` accumulator.
    let acc = halves.as_flattened_mut();

    for (i, j) in (0..LIMBS).cartesian_product(0..LIMBS) {
        // The 128-bit partial product a[i] * b[j] starts at limb i + j and
        // spans two limbs. The highest start is 2 * LIMBS - 2, so it stays in bounds.
        let at = i + j;
        clfma_u64(&mut acc[at..at + 2], a[i], b[j]);
    }

    let [low, high] = halves;
    (low, high)
}

/// Carry-less multiplication of two single-limb (64-bit) operands.
///
/// Returns the 128-bit product as `([low 64 bits], [high 64 bits])`.
///
/// The backend is selected at compile time from the statically enabled target
/// features. Exactly one of the blocks below survives `cfg` evaluation and
/// becomes the function's tail expression:
/// - x86 / x86_64 with `pclmulqdq` and `sse2`: the PCLMULQDQ path;
/// - aarch64 with `neon` and `aes`: the PMULL path;
/// - anything else: the portable bit-serial `carry_less_mul`.
pub fn carry_less_mul_1limb(a: [u64; 1], b: [u64; 1]) -> ([u64; 1], [u64; 1]) {
    #[cfg(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "pclmulqdq",
        target_feature = "sse2",
    ))]
    {
        clmul_u64_sse2(a, b)
    }

    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        target_feature = "aes",
    ))]
    {
        clmul_u64_neon(a, b)
    }

    #[cfg(not(any(
        all(
            any(target_arch = "x86", target_arch = "x86_64"),
            target_feature = "pclmulqdq",
            target_feature = "sse2",
        ),
        all(
            target_arch = "aarch64",
            target_feature = "neon",
            target_feature = "aes",
        ),
    )))]
    {
        carry_less_mul(a, b)
    }
}

/// Carry-less multiplication of two 2-limb (128-bit) operands.
///
/// Limbs are little-endian. The 256-bit product is returned as `(low, high)`.
///
/// The backend is selected at compile time from the statically enabled target
/// features. Exactly one of the blocks below survives `cfg` evaluation and
/// becomes the function's tail expression:
/// - x86 / x86_64 with `pclmulqdq` and `sse2`: the PCLMULQDQ Karatsuba path;
/// - aarch64 with `neon` and `aes`: the PMULL Karatsuba path;
/// - anything else: the portable bit-serial `carry_less_mul`.
pub fn carry_less_mul_2limbs(a: [u64; 2], b: [u64; 2]) -> ([u64; 2], [u64; 2]) {
    #[cfg(all(
        any(target_arch = "x86", target_arch = "x86_64"),
        target_feature = "pclmulqdq",
        target_feature = "sse2",
    ))]
    {
        clmul_2_u64_sse2(a, b)
    }

    #[cfg(all(
        target_arch = "aarch64",
        target_feature = "neon",
        target_feature = "aes",
    ))]
    {
        clmul_2_u64_neon(a, b)
    }

    #[cfg(not(any(
        all(
            any(target_arch = "x86", target_arch = "x86_64"),
            target_feature = "pclmulqdq",
            target_feature = "sse2",
        ),
        all(
            target_arch = "aarch64",
            target_feature = "neon",
            target_feature = "aes",
        ),
    )))]
    {
        carry_less_mul(a, b)
    }
}

/// Carry-less multiply-accumulate of one limb pair: `res[0..2] ^= a * b` in GF(2)\[x\].
///
/// `res[0]` receives the low 64 bits of the 128-bit carry-less product and
/// `res[1]` the high 64 bits. Only those two elements are touched.
///
/// The loop never branches on the bits of `a`. Each bit is expanded into an
/// all-zeros / all-ones mask instead.
///
/// # Panics
/// Panics if `res` has fewer than two elements.
fn clfma_u64(res: &mut [u64], a: u64, b: u64) {
    // mask(k) is all ones when bit k of `a` is set, and zero otherwise.
    let mask = |k: u32| (Wrapping(0u64) - Wrapping((a >> k) & 1)).0;

    // Bit 0 contributes only to the low limb. Handling it here avoids a
    // shift by 64 (an overflow) in the `64 - shift` term below.
    res[0] ^= b & mask(0);

    for shift in 1u32..64 {
        let term = b & mask(shift);
        res[0] ^= term << shift;
        res[1] ^= term >> (64 - shift);
    }
}

/// Two-limb carry-less multiplication using PCLMULQDQ and one level of Karatsuba.
///
/// With `x = a + b·2^64` and `y = c + d·2^64` (limb 0 least significant), the
/// product is `ac + ((a^b)(c^d) ^ ac ^ bd)·2^64 + bd·2^128`. This takes three
/// `PCLMULQDQ`s instead of four. The 256-bit result is returned as `(low, high)`.
///
/// Only compiled when `pclmulqdq` and `sse2` are statically enabled, so every
/// intrinsic used here must belong to one of those two feature sets.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    target_feature = "pclmulqdq",
    target_feature = "sse2",
))]
fn clmul_2_u64_sse2(x: [u64; 2], y: [u64; 2]) -> ([u64; 2], [u64; 2]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;
    use std::mem::MaybeUninit;

    // SAFETY: the cfg above guarantees `pclmulqdq` and `sse2` are enabled for
    // the whole build, and only intrinsics from those feature sets are used.
    // Every load/store is unaligned (`loadu`/`storeu`) and moves exactly 16
    // bytes, which is the size of `[u64; 2]`. Each `MaybeUninit` is fully
    // written before `assume_init`.
    unsafe {
        // Lane 0 = limb 0 (a / c), lane 1 = limb 1 (b / d).
        let a_b = _mm_loadu_si128(x.as_ptr() as *const _);
        let c_d = _mm_loadu_si128(y.as_ptr() as *const _);

        // imm 0x00 multiplies the lane-0 halves; imm 0x11 multiplies the lane-1 halves.
        let prod_a_c = _mm_clmulepi64_si128::<0x00>(a_b, c_d);
        let prod_b_d = _mm_clmulepi64_si128::<0x11>(a_b, c_d);

        // Put a^b (resp. c^d) in lane 1. `_mm_unpacklo_epi64(v, v)` duplicates
        // lane 0 using plain SSE2. The previous `_mm_broadcastq_epi64` is AVX2,
        // which the cfg does not guarantee, so it could fault on PCLMUL-capable
        // CPUs without AVX2.
        let s_ab = _mm_xor_si128(_mm_unpacklo_epi64(a_b, a_b), a_b);
        let s_cd = _mm_xor_si128(_mm_unpacklo_epi64(c_d, c_d), c_d);
        let prod_ab_cd = _mm_clmulepi64_si128::<0x11>(s_ab, s_cd);

        // (a^b)(c^d) ^ ac ^ bd = ad ^ bc: the cross term at bit offset 64.
        let middle = _mm_xor_si128(prod_ab_cd, _mm_xor_si128(prod_a_c, prod_b_d));

        // Byte shifts by 8 move the middle term's halves across the 128-bit boundary.
        let res_l = _mm_xor_si128(prod_a_c, _mm_slli_si128::<8>(middle));
        let res_h = _mm_xor_si128(prod_b_d, _mm_srli_si128::<8>(middle));

        let mut low = MaybeUninit::<[u64; 2]>::uninit();
        _mm_storeu_si128(low.as_mut_ptr() as *mut _, res_l);
        let mut high = MaybeUninit::<[u64; 2]>::uninit();
        _mm_storeu_si128(high.as_mut_ptr() as *mut _, res_h);

        (low.assume_init(), high.assume_init())
    }
}

/// One-limb carry-less multiplication with a single `PCLMULQDQ`.
///
/// Returns the 128-bit product as `([low 64 bits], [high 64 bits])`.
/// Only compiled when `pclmulqdq` and `sse2` are statically enabled.
#[cfg(all(
    any(target_arch = "x86", target_arch = "x86_64"),
    target_feature = "pclmulqdq",
    target_feature = "sse2",
))]
fn clmul_u64_sse2(x: [u64; 1], y: [u64; 1]) -> ([u64; 1], [u64; 1]) {
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;
    use std::mem::MaybeUninit;

    let mut words = MaybeUninit::<[u64; 2]>::uninit();

    // SAFETY: the cfg above guarantees `pclmulqdq` and `sse2`. The unaligned
    // store writes exactly 16 bytes into `words` (a `[u64; 2]`), fully
    // initialising it before `assume_init`.
    let [lo, hi] = unsafe {
        // Each operand sits in lane 0 with lane 1 zeroed. `_mm_set_epi64x`
        // takes (high, low), and the `as i64` cast only reinterprets the bits.
        let lhs = _mm_set_epi64x(0, x[0] as i64);
        let rhs = _mm_set_epi64x(0, y[0] as i64);
        let product = _mm_clmulepi64_si128::<0x00>(lhs, rhs);
        _mm_storeu_si128(words.as_mut_ptr() as *mut _, product);
        words.assume_init()
    };

    ([lo], [hi])
}

// ARM NEON PMULL carry-less multiply (aarch64 with crypto extensions)

/// Two-limb carry-less multiplication using AArch64 `PMULL` and one level of Karatsuba.
///
/// With `x = x0 + x1·2^64` and `y = y0 + y1·2^64`, the product is
/// `x0·y0 + ((x0^x1)(y0^y1) ^ x0·y0 ^ x1·y1)·2^64 + x1·y1·2^128`.
/// This takes three 64×64→128 carry-less multiplies. The 256-bit result is
/// returned as `(low, high)` with little-endian limbs.
///
/// Only compiled when `neon` and `aes` are statically enabled.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    target_feature = "aes",
))]
fn clmul_2_u64_neon(x: [u64; 2], y: [u64; 2]) -> ([u64; 2], [u64; 2]) {
    use core::arch::aarch64::vmull_p64;

    let [x0, x1] = x;
    let [y0, y1] = y;

    // SAFETY: the cfg above guarantees `neon` and `aes`, which is what
    // `vmull_p64` requires. It operates on plain values, with no memory access.
    let (lo, hi, cross) = unsafe {
        (
            vmull_p64(x0, y0),
            vmull_p64(x1, y1),
            vmull_p64(x0 ^ x1, y0 ^ y1),
        )
    };

    // Cross term x0·y1 ^ x1·y0, which lands at bit offset 64.
    let mid = cross ^ lo ^ hi;

    let low = lo ^ (mid << 64);
    let high = hi ^ (mid >> 64);

    (
        [low as u64, (low >> 64) as u64],
        [high as u64, (high >> 64) as u64],
    )
}

/// One-limb carry-less multiplication with a single AArch64 `PMULL`.
///
/// Returns the 128-bit product as `([low 64 bits], [high 64 bits])`.
/// Only compiled when `neon` and `aes` are statically enabled.
#[cfg(all(
    target_arch = "aarch64",
    target_feature = "neon",
    target_feature = "aes",
))]
fn clmul_u64_neon(x: [u64; 1], y: [u64; 1]) -> ([u64; 1], [u64; 1]) {
    use core::arch::aarch64::vmull_p64;

    let ([lhs], [rhs]) = (x, y);

    // SAFETY: the cfg above guarantees `neon` and `aes`, which is what
    // `vmull_p64` requires. It operates on plain values, with no memory access.
    let wide = unsafe { vmull_p64(lhs, rhs) };

    let low = wide as u64;
    let high = (wide >> 64) as u64;
    ([low], [high])
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer vectors for 2-limb carry-less products: `(a, b, (low, high))`.
    const KAT_2LIMBS: [([u64; 2], [u64; 2], ([u64; 2], [u64; 2])); 3] = [
        (
            [0x9f418f3bffd84bba, 0x4a7c605645afdfb1],
            [0x80b7bd91cddc5be5, 0x3a97291035e41e1f],
            (
                [0xfffa315f244b1f92, 0x288de208c77fb2f9],
                [0x94aced70da538690, 0x0f0f341c05b65c5e],
            ),
        ),
        (
            [0x74ef862bc1b6d333, 0x3a88103b80d97b73],
            [0x753f4846eb020b5a, 0x8f108359ea25fa8f],
            (
                [0x419f60d2de880d0e, 0xfd4d74204161d27d],
                [0xe96db1bf781f351f, 0x1c36456adc21ac7a],
            ),
        ),
        (
            [0x6447b3dcaed62649, 0x6e4af40b2ee1b4c1],
            [0xbd7a4e12fdb29840, 0x8950f56742015f25],
            (
                [0x010b1b56e559ca40, 0xf1ff4cfe33d20957],
                [0xed77c432e4701779, 0x342437199bebda57],
            ),
        ),
    ];

    /// Known-answer vectors for 1-limb products: `(a, b, low, high)`.
    /// Values follow from GF(2)[x] algebra, e.g. (x+1)^2 = x^2+1 and
    /// (sum of x^i for i < 64)^2 = sum of x^(2i), i.e. alternating bits.
    const KAT_1LIMB: [(u64, u64, u64, u64); 6] = [
        (0, 0xdead_beef, 0, 0),
        (1, 0x9f41_8f3b_ffd8_4bba, 0x9f41_8f3b_ffd8_4bba, 0),
        (3, 3, 5, 0),
        (1 << 63, 2, 0, 1),
        (1 << 63, 1 << 63, 0, 1 << 62),
        (u64::MAX, u64::MAX, 0x5555_5555_5555_5555, 0x5555_5555_5555_5555),
    ];

    #[test]
    fn test_carry_less_mul_2limbs() {
        for (a, b, expected) in KAT_2LIMBS {
            assert_eq!(carry_less_mul(a, b), expected, "a = {a:x?}, b = {b:x?}");
        }
    }

    #[test]
    fn test_dispatched_matches_portable_2limbs() {
        for (a, b, expected) in KAT_2LIMBS {
            let portable = carry_less_mul(a, b);
            let dispatched = carry_less_mul_2limbs(a, b);
            assert_eq!(portable, dispatched, "a = {a:x?}, b = {b:x?}");
            assert_eq!(dispatched, expected, "a = {a:x?}, b = {b:x?}");
        }
    }

    #[test]
    fn test_carry_less_mul_1limb_known_answers() {
        for (a, b, lo, hi) in KAT_1LIMB {
            assert_eq!(carry_less_mul([a], [b]), ([lo], [hi]), "a = {a:#x}, b = {b:#x}");
            assert_eq!(carry_less_mul_1limb([a], [b]), ([lo], [hi]), "a = {a:#x}, b = {b:#x}");
        }
    }

    #[test]
    fn test_dispatched_matches_portable_1limb() {
        let cases: [([u64; 1], [u64; 1]); 3] = [
            ([0x9f418f3bffd84bba], [0x80b7bd91cddc5be5]),
            ([0x74ef862bc1b6d333], [0x753f4846eb020b5a]),
            ([0xffffffffffffffff], [0xffffffffffffffff]),
        ];
        for (a, b) in cases {
            let portable = carry_less_mul(a, b);
            let dispatched = carry_less_mul_1limb(a, b);
            assert_eq!(portable, dispatched);
        }
    }

    /// The generic path for LIMBS > 2 has no SIMD counterpart to compare against.
    /// Instead, check algebraic laws that any correct GF(2)[x] product must satisfy.
    #[test]
    fn test_carry_less_mul_3limbs_algebraic_properties() {
        let a = [0x9f418f3bffd84bba, 0x4a7c605645afdfb1, 0x74ef862bc1b6d333];
        let b = [0x80b7bd91cddc5be5, 0x3a97291035e41e1f, 0x753f4846eb020b5a];
        let c = [0x6447b3dcaed62649, 0x6e4af40b2ee1b4c1, 0xbd7a4e12fdb29840];

        // Multiplicative identity.
        assert_eq!(carry_less_mul(a, [1, 0, 0]), (a, [0; 3]));
        // Multiplying by x^64 shifts the operand up by exactly one limb.
        assert_eq!(
            carry_less_mul(a, [0, 1, 0]),
            ([0, a[0], a[1]], [a[2], 0, 0])
        );
        // Commutativity.
        assert_eq!(carry_less_mul(a, b), carry_less_mul(b, a));
        // Distributivity over XOR (addition in GF(2)[x]).
        let xor3 = |p: [u64; 3], q: [u64; 3]| [p[0] ^ q[0], p[1] ^ q[1], p[2] ^ q[2]];
        let (ab_lo, ab_hi) = carry_less_mul(a, b);
        let (ac_lo, ac_hi) = carry_less_mul(a, c);
        assert_eq!(
            carry_less_mul(a, xor3(b, c)),
            (xor3(ab_lo, ac_lo), xor3(ab_hi, ac_hi))
        );
    }
}