axhash-core 1.0.0

Platform-agnostic AxHash core for Rust with no_std compatibility.
Documentation
// Portable NEON backend.
//
// Bit-identical to `scalar::hash_bytes_long`. Operations used (XOR, ADD,
// 32x32->64 multiply, lane swap) all have well-defined wraparound semantics
// that match scalar Rust on every CPU. No AES instructions: ARM AES rounds
// are not equivalent to x86 AES rounds, so using them would re-introduce
// cross-device hash divergence.

use crate::backend::scalar::{
    BLOCK_STRIPES, FINAL_STRIPE_SECRET_OFFSET, SCRAMBLE_PRIME, SCRAMBLE_SECRET_OFFSET, STRIPE_BYTES,
    STRIPE_SECRET_BASE, init_acc, merge_acc,
};
use crate::constants::SECRET_STREAM;
use crate::memory::{r_u64x2, r_u64x2_aligned};
use core::arch::aarch64::*;

// Folds one 128-bit chunk of input into a two-lane accumulator.
//
// With `k = data ^ secret`, result lane `i` is
// `acc[i] + data[i ^ 1] + lo32(k[i]) * hi32(k[i])`, where the product is a
// full 32x32 -> 64 multiply and the additions wrap modulo 2^64.
#[target_feature(enable = "neon")]
#[inline]
unsafe fn mix_pair(
    acc: uint64x2_t,
    data: uint64x2_t,
    secret: uint64x2_t,
) -> uint64x2_t {
    // Raw input with its two lanes exchanged, so that result lane i picks up
    // data[i ^ 1] (the scalar cross-add `acc[i ^ 1] += data[i]`).
    let crossed = vextq_u64(data, data, 1);
    let keyed = veorq_u64(data, secret);
    // Narrow each keyed lane to its low half and its high half, then take
    // the widening per-lane product of the two.
    let product = vmull_u32(vmovn_u64(keyed), vshrn_n_u64(keyed, 32));
    vaddq_u64(acc, vaddq_u64(crossed, product))
}

// Per-lane 64-bit wrapping multiply of `v` by the 32-bit value `c`.
//
// Writing each lane as `hi * 2^32 + lo`, the product modulo 2^64 is
// `lo * c + ((hi * c) << 32)`. Both partial products are 32x32 -> 64
// multiplies, so the result equals scalar `lane.wrapping_mul(c as u64)`.
#[target_feature(enable = "neon")]
#[inline]
unsafe fn mul_u64x2_by_u32(v: uint64x2_t, c: u32) -> uint64x2_t {
    let multiplier = vdup_n_u32(c);
    // High halves times c, moved into the upper 32 bits; anything shifted
    // past bit 63 is discarded, which is exactly the wraparound we want.
    let high_partial = vshlq_n_u64(vmull_u32(vshrn_n_u64(v, 32), multiplier), 32);
    // Low halves times c contribute the full 64-bit partial product.
    let low_partial = vmull_u32(vmovn_u64(v), multiplier);
    vaddq_u64(low_partial, high_partial)
}

// Scrambles a two-lane accumulator: each lane `x` is replaced by
// `(x ^ (x >> 47) ^ s) * p` (wrapping), where `p` is `SCRAMBLE_PRIME`
// converted to `u32`.
#[target_feature(enable = "neon")]
#[inline]
unsafe fn scramble_pair(v: uint64x2_t, s: uint64x2_t) -> uint64x2_t {
    // Fold the top bits of each lane down onto the low bits, then key it.
    let folded = veorq_u64(v, vshrq_n_u64(v, 47));
    let keyed = veorq_u64(folded, s);
    // NOTE(review): `as u32` keeps only the low 32 bits of SCRAMBLE_PRIME.
    // This assumes the constant fits in 32 bits; verify against its
    // definition in the scalar backend.
    let prime = SCRAMBLE_PRIME as u32;
    unsafe { mul_u64x2_by_u32(keyed, prime) }
}

// Runs `scramble_pair` over all four accumulator vectors in order. The key
// for `a0` is loaded from `secret_ptr`, and each following accumulator uses
// a key loaded two `u64` words further on (word offsets 0, 2, 4, 6).
//
// Safety: `secret_ptr` must be readable at all four of those offsets by
// `r_u64x2_aligned` (presumably two u64 words per load, i.e. eight words in
// total; confirm against that helper's contract, including alignment).
#[target_feature(enable = "neon")]
#[inline]
unsafe fn scramble_neon(
    a0: &mut uint64x2_t,
    a1: &mut uint64x2_t,
    a2: &mut uint64x2_t,
    a3: &mut uint64x2_t,
    secret_ptr: *const u64,
) {
    let mut word = 0usize;
    for acc in [a0, a1, a2, a3] {
        let key = unsafe { r_u64x2_aligned(secret_ptr.add(word)) };
        *acc = unsafe { scramble_pair(*acc, key) };
        word += 2;
    }
}

// Mixes one stripe of input into the four accumulator vectors.
//
// Accumulator `aN` is combined, via `mix_pair`, with the input loaded at
// byte offset `16 * N` from `data_ptr` and the secret loaded at word offset
// `2 * N` from `secret_ptr`.
//
// Safety: `data_ptr` must be readable by `r_u64x2` at byte offsets 0, 16, 32
// and 48, and `secret_ptr` must be readable by `r_u64x2_aligned` at word
// offsets 0, 2, 4 and 6.
#[target_feature(enable = "neon")]
#[inline]
unsafe fn mix_stripe_neon(
    a0: &mut uint64x2_t,
    a1: &mut uint64x2_t,
    a2: &mut uint64x2_t,
    a3: &mut uint64x2_t,
    data_ptr: *const u8,
    secret_ptr: *const u64,
) {
    // First accumulator: start of the stripe, start of the secret window.
    let (input, key) = unsafe { (r_u64x2(data_ptr), r_u64x2_aligned(secret_ptr)) };
    *a0 = unsafe { mix_pair(*a0, input, key) };

    // Second accumulator: 16 bytes / 2 secret words in.
    let (input, key) = unsafe { (r_u64x2(data_ptr.add(16)), r_u64x2_aligned(secret_ptr.add(2))) };
    *a1 = unsafe { mix_pair(*a1, input, key) };

    // Third accumulator: 32 bytes / 4 secret words in.
    let (input, key) = unsafe { (r_u64x2(data_ptr.add(32)), r_u64x2_aligned(secret_ptr.add(4))) };
    *a2 = unsafe { mix_pair(*a2, input, key) };

    // Fourth accumulator: 48 bytes / 6 secret words in.
    let (input, key) = unsafe { (r_u64x2(data_ptr.add(48)), r_u64x2_aligned(secret_ptr.add(6))) };
    *a3 = unsafe { mix_pair(*a3, input, key) };
}

// NEON implementation of the long-input hash.
//
// Flow: seed eight u64 accumulator lanes with `init_acc(acc_in)`, mix every
// whole `STRIPE_BYTES`-sized stripe of the input in order (scrambling the
// accumulators after each run of `BLOCK_STRIPES` stripes), mix the last
// `STRIPE_BYTES` bytes of the input once more with a dedicated secret
// offset, then reduce the lanes with `merge_acc`.
//
// Safety:
// - `ptr` must be valid for reads of `len` bytes.
// - `len` must be at least `STRIPE_BYTES`; otherwise `len - STRIPE_BYTES`
//   underflows when locating the final stripe.
// - The CPU must support NEON (required by the `target_feature` attribute).
#[target_feature(enable = "neon")]
pub(crate) unsafe fn hash_bytes_long(ptr: *const u8, len: usize, acc_in: u64) -> u64 {
    let seeded = init_acc(acc_in);
    let seeded_ptr = seeded.as_ptr();

    // Each vector holds two adjacent scalar lanes: (0,1), (2,3), (4,5),
    // (6,7). That is the pairing `mix_pair` relies on for its lane swap.
    let (mut v0, mut v1, mut v2, mut v3) = unsafe {
        (
            vld1q_u64(seeded_ptr),
            vld1q_u64(seeded_ptr.add(2)),
            vld1q_u64(seeded_ptr.add(4)),
            vld1q_u64(seeded_ptr.add(6)),
        )
    };

    let secret = SECRET_STREAM.as_ptr();
    let scramble_key = unsafe { secret.add(SCRAMBLE_SECRET_OFFSET) };

    // `consumed` counts input bytes already mixed; `stripes_in_block` counts
    // stripes since the last scramble and doubles as the secret word offset
    // relative to STRIPE_SECRET_BASE (one word further per stripe).
    let mut consumed = 0usize;
    let mut stripes_in_block = 0usize;

    while len >= consumed + STRIPE_BYTES {
        // NOTE(review): the secret pointer moves by a single u64 per stripe,
        // so it is only guaranteed to be u64-aligned; confirm that is all
        // `r_u64x2_aligned` requires.
        unsafe {
            mix_stripe_neon(
                &mut v0,
                &mut v1,
                &mut v2,
                &mut v3,
                ptr.add(consumed),
                secret.add(STRIPE_SECRET_BASE + stripes_in_block),
            );
        }
        consumed += STRIPE_BYTES;
        stripes_in_block += 1;

        if stripes_in_block == BLOCK_STRIPES {
            // Block boundary: scramble, then restart the secret window.
            unsafe {
                scramble_neon(&mut v0, &mut v1, &mut v2, &mut v3, scramble_key);
            }
            stripes_in_block = 0;
        }
    }

    // Final stripe: always the last STRIPE_BYTES bytes of the input. When
    // `len` is an exact multiple of STRIPE_BYTES this re-mixes the stripe the
    // loop just handled (presumably mirroring the scalar backend; confirm).
    unsafe {
        mix_stripe_neon(
            &mut v0,
            &mut v1,
            &mut v2,
            &mut v3,
            ptr.add(len - STRIPE_BYTES),
            secret.add(FINAL_STRIPE_SECRET_OFFSET),
        );
    }

    // Spill the vectors back into scalar lane order for the shared merge.
    let mut lanes = [0u64; 8];
    let out = lanes.as_mut_ptr();
    unsafe {
        vst1q_u64(out, v0);
        vst1q_u64(out.add(2), v1);
        vst1q_u64(out.add(4), v2);
        vst1q_u64(out.add(6), v3);
    }

    merge_acc(&lanes, len, acc_in)
}