axhash-core 1.0.0

Platform-agnostic AxHash core for Rust with no_std compatibility.
Documentation
// Portable AVX2 backend.
//
// Bit-identical to `scalar::hash_bytes_long`. Operations used (XOR, ADD,
// 32x32->64 multiply via _mm256_mul_epu32, lane shuffle) match scalar Rust
// exactly. No AES instructions: x86 `aesenc` and ARM `aese`+`aesmc` are not
// bit-equivalent, so using them would re-introduce cross-device divergence.

use crate::backend::scalar::{
    BLOCK_STRIPES, FINAL_STRIPE_SECRET_OFFSET, SCRAMBLE_PRIME, SCRAMBLE_SECRET_OFFSET, STRIPE_BYTES,
    STRIPE_SECRET_BASE, init_acc, merge_acc,
};
use crate::constants::SECRET_STREAM;

use core::arch::x86_64::*;

// Control word for `_mm256_shuffle_epi32` that exchanges the two u64 lanes
// inside each 128-bit half: destination 32-bit slots (0,1,2,3) of every half
// take source slots (2,3,0,1).
const SWAP_U64_PAIRS: i32 = 0b01_00_11_10;

/// Mixes four data lanes into four accumulator lanes.
///
/// For each u64 lane `i`, with `k = data[i] ^ secret[i]`, the result is
/// `acc[i] + data[i ^ 1] + lo32(k) * hi32(k)`, all wrapping mod 2^64. This is
/// the vector form of the scalar `acc[i ^ 1] += data[i]` cross-add.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mix_quad(acc: __m256i, data: __m256i, secret: __m256i) -> __m256i {
    let keyed = _mm256_xor_si256(data, secret);
    // `_mm256_mul_epu32` reads only the low 32 bits of each u64 lane, so
    // shifting the high half down yields lo32(keyed) * hi32(keyed) per lane.
    let folded = _mm256_mul_epu32(keyed, _mm256_srli_epi64(keyed, 32));
    // Lanes 0<->1 and 2<->3 swapped: lane i now carries data[i ^ 1].
    let neighbour = _mm256_shuffle_epi32(data, SWAP_U64_PAIRS);
    // Lane addition wraps mod 2^64, so this grouping gives the same bits.
    _mm256_add_epi64(acc, _mm256_add_epi64(folded, neighbour))
}

// Lane-wise 64-bit wrapping multiply by a 32-bit factor, emulated with two
// 32x32->64 multiplies. Writing each lane as v = hi * 2^32 + lo gives
// v * c == lo * c + ((hi * c) << 32) (mod 2^64), so every lane equals the
// scalar `v.wrapping_mul(c as u64)`.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mul_m256i_by_u32(v: __m256i, c: u32) -> __m256i {
    let factor = _mm256_set1_epi64x(i64::from(c));
    // Full 64-bit product of the low halves with the factor.
    let low_part = _mm256_mul_epu32(v, factor);
    // The high halves times the factor. Only the low 32 bits of that product
    // survive the shift back into place.
    let high_part = _mm256_slli_epi64(_mm256_mul_epu32(_mm256_srli_epi64(v, 32), factor), 32);
    _mm256_add_epi64(low_part, high_part)
}

/// Scrambles four accumulator lanes.
///
/// Computes `(v ^ (v >> 47) ^ s) * SCRAMBLE_PRIME` per u64 lane with wrapping
/// 64-bit arithmetic. The multiply uses the 64x32-bit emulation
/// `mul_m256i_by_u32`.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scramble_quad(v: __m256i, s: __m256i) -> __m256i {
    // `mul_m256i_by_u32` only matches scalar `wrapping_mul` for a factor that
    // fits in 32 bits. A bare `SCRAMBLE_PRIME as u32` would silently drop any
    // high bits if the prime ever changed, and this backend would diverge from
    // the scalar one. Fail the build instead.
    const PRIME32: u32 = {
        assert!(
            SCRAMBLE_PRIME <= u32::MAX as u64,
            "SCRAMBLE_PRIME must fit in 32 bits for the AVX2 scramble"
        );
        SCRAMBLE_PRIME as u32
    };

    let shifted = _mm256_srli_epi64(v, 47);
    let mixed = _mm256_xor_si256(_mm256_xor_si256(v, shifted), s);
    // SAFETY: the helper only requires AVX2, which this function enables.
    unsafe { mul_m256i_by_u32(mixed, PRIME32) }
}

/// Scrambles all eight accumulator lanes (`a0` = lanes 0..4, `a1` = lanes
/// 4..8), keyed by eight consecutive u64 secrets starting at `secret_ptr`.
///
/// # Safety
///
/// AVX2 must be available, and `secret_ptr` must be valid for reads of
/// eight `u64` values. No alignment is required.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn scramble_avx2(a0: &mut __m256i, a1: &mut __m256i, secret_ptr: *const u64) {
    // SAFETY: the caller guarantees eight readable u64 values at
    // `secret_ptr`. Both loads are unaligned.
    let (lo_key, hi_key) = unsafe {
        (
            _mm256_loadu_si256(secret_ptr.cast::<__m256i>()),
            _mm256_loadu_si256(secret_ptr.add(4).cast::<__m256i>()),
        )
    };

    // SAFETY: `scramble_quad` only requires AVX2, which this function enables.
    unsafe {
        *a0 = scramble_quad(*a0, lo_key);
        *a1 = scramble_quad(*a1, hi_key);
    }
}

/// Mixes one 64-byte stripe at `data_ptr` into the eight accumulator lanes
/// (`a0` = lanes 0..4, `a1` = lanes 4..8). The stripe is keyed by the eight
/// u64 secrets starting at `secret_ptr`.
///
/// # Safety
///
/// AVX2 must be available. `data_ptr` must be valid for reads of 64 bytes,
/// and `secret_ptr` for reads of eight `u64` values. No alignment is required.
#[target_feature(enable = "avx2")]
#[inline]
unsafe fn mix_stripe_avx2(
    a0: &mut __m256i,
    a1: &mut __m256i,
    data_ptr: *const u8,
    secret_ptr: *const u64,
) {
    // SAFETY: the caller guarantees 64 readable bytes at `data_ptr` and eight
    // readable u64 values at `secret_ptr`. All loads are unaligned.
    let (lo_data, hi_data, lo_key, hi_key) = unsafe {
        (
            _mm256_loadu_si256(data_ptr.cast::<__m256i>()),
            _mm256_loadu_si256(data_ptr.add(32).cast::<__m256i>()),
            _mm256_loadu_si256(secret_ptr.cast::<__m256i>()),
            _mm256_loadu_si256(secret_ptr.add(4).cast::<__m256i>()),
        )
    };

    // SAFETY: `mix_quad` only requires AVX2, which this function enables.
    unsafe {
        *a0 = mix_quad(*a0, lo_data, lo_key);
        *a1 = mix_quad(*a1, hi_data, hi_key);
    }
}

/// Hashes `len` bytes starting at `ptr` with the AVX2 backend.
///
/// The function works in these steps:
/// - Every full `STRIPE_BYTES` stripe is mixed in order. The stripe secret
///   advances by one `u64` per stripe.
/// - After each `BLOCK_STRIPES` stripes, the accumulators are scrambled and
///   the stripe secret restarts at `STRIPE_SECRET_BASE`.
/// - The last `STRIPE_BYTES` bytes of the input are then mixed once more with
///   the final-stripe secret. This window overlaps the last full stripe.
/// - Finally, the eight lanes are folded by `merge_acc`.
///
/// # Safety
///
/// - The running CPU must support AVX2.
/// - `ptr` must be valid for reads of `len` bytes.
/// - `len` must be at least `STRIPE_BYTES`. The trailing stripe is read from
///   `ptr + (len - STRIPE_BYTES)`, and that subtraction wraps for shorter
///   inputs in release builds, causing an out-of-bounds read.
#[target_feature(enable = "avx2")]
pub(crate) unsafe fn hash_bytes_long(ptr: *const u8, len: usize, acc_in: u64) -> u64 {
    // Report a violated length precondition clearly in debug builds, before
    // any pointer arithmetic happens.
    debug_assert!(
        len >= STRIPE_BYTES,
        "hash_bytes_long: input shorter than STRIPE_BYTES"
    );

    let init = init_acc(acc_in);

    // Load scalar init into AVX2 vectors. Each __m256i holds 4 u64 lanes in
    // the order [0,1,2,3] / [4,5,6,7]. The 128-bit-half pairing matches the
    // scalar `i ^ 1` cross-add pattern (pairs (0,1), (2,3), (4,5), (6,7)).
    // SAFETY: these are unaligned loads of lanes 0..4 and 4..8. This assumes
    // `init_acc` returns the eight u64 lanes that `merge_acc` later consumes.
    let mut a0 = unsafe { _mm256_loadu_si256(init.as_ptr().cast::<__m256i>()) };
    let mut a1 = unsafe { _mm256_loadu_si256(init.as_ptr().add(4).cast::<__m256i>()) };

    // NOTE: the secret reads below rely on the offset constants being sized so
    // that every access stays inside `SECRET_STREAM`. This is not checked here.
    let secret_base = SECRET_STREAM.as_ptr();
    // SAFETY: offset into the static secret stream (see NOTE above).
    let scramble_secret = unsafe { secret_base.add(SCRAMBLE_SECRET_OFFSET) };
    let mut secret_off = STRIPE_SECRET_BASE;
    let mut offset = 0usize;
    let mut stripe_in_block = 0usize;

    while offset + STRIPE_BYTES <= len {
        // SAFETY: `offset + STRIPE_BYTES <= len`, so this stripe lies inside
        // the caller's buffer.
        unsafe {
            mix_stripe_avx2(
                &mut a0,
                &mut a1,
                ptr.add(offset),
                secret_base.add(secret_off),
            );
        }
        offset += STRIPE_BYTES;
        secret_off += 1;
        stripe_in_block += 1;

        if stripe_in_block == BLOCK_STRIPES {
            // SAFETY: AVX2 is enabled, and the scramble secret is inside the
            // static secret stream.
            unsafe { scramble_avx2(&mut a0, &mut a1, scramble_secret) };
            stripe_in_block = 0;
            secret_off = STRIPE_SECRET_BASE;
        }
    }

    // SAFETY: the caller guarantees `len >= STRIPE_BYTES`, so the window
    // `len - STRIPE_BYTES..len` lies inside the caller's buffer.
    unsafe {
        mix_stripe_avx2(
            &mut a0,
            &mut a1,
            ptr.add(len - STRIPE_BYTES),
            secret_base.add(FINAL_STRIPE_SECRET_OFFSET),
        );
    }

    let mut acc = [0u64; 8];
    // SAFETY: `acc` is 64 bytes, and each unaligned store writes 32 of them.
    unsafe {
        _mm256_storeu_si256(acc.as_mut_ptr().cast::<__m256i>(), a0);
        _mm256_storeu_si256(acc.as_mut_ptr().add(4).cast::<__m256i>(), a1);
    }

    merge_acc(&acc, len, acc_in)
}