#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon {
use std::arch::aarch64::*;
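/// Widens `count` packed `u8` values into `u32` lanes, 16 values per NEON
/// iteration; a scalar tail handles the remainder. The caller must ensure
/// `input` and `output` hold at least `count` elements.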
#[target_feature(enable = "neon")]
pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 16;
let remainder = count % 16;
for chunk in 0..chunks {
let base = chunk * 16;
let in_ptr = input.as_ptr().add(base);
let bytes = vld1q_u8(in_ptr);
let low8 = vget_low_u8(bytes);
let high8 = vget_high_u8(bytes);
let low16 = vmovl_u8(low8);
let high16 = vmovl_u8(high8);
let v0 = vmovl_u16(vget_low_u16(low16));
let v1 = vmovl_u16(vget_high_u16(low16));
let v2 = vmovl_u16(vget_low_u16(high16));
let v3 = vmovl_u16(vget_high_u16(high16));
let out_ptr = output.as_mut_ptr().add(base);
vst1q_u32(out_ptr, v0);
vst1q_u32(out_ptr.add(4), v1);
vst1q_u32(out_ptr.add(8), v2);
vst1q_u32(out_ptr.add(12), v3);
}
let base = chunks * 16;
for i in 0..remainder {
output[base + i] = input[base + i] as u32;
}
}
#[target_feature(enable = "neon")]
pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 8;
let remainder = count % 8;
for chunk in 0..chunks {
let base = chunk * 8;
let in_ptr = input.as_ptr().add(base * 2) as *const u16;
let vals = vld1q_u16(in_ptr);
let low = vmovl_u16(vget_low_u16(vals));
let high = vmovl_u16(vget_high_u16(vals));
let out_ptr = output.as_mut_ptr().add(base);
vst1q_u32(out_ptr, low);
vst1q_u32(out_ptr.add(4), high);
}
let base = chunks * 8;
for i in 0..remainder {
let idx = (base + i) * 2;
output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
}
}
#[target_feature(enable = "neon")]
pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 4;
let remainder = count % 4;
let in_ptr = input.as_ptr() as *const u32;
let out_ptr = output.as_mut_ptr();
for chunk in 0..chunks {
let vals = vld1q_u32(in_ptr.add(chunk * 4));
vst1q_u32(out_ptr.add(chunk * 4), vals);
}
let base = chunks * 4;
for i in 0..remainder {
let idx = (base + i) * 4;
output[base + i] =
u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
}
}
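/// Inclusive prefix sum of the four `u32` lanes of `v`, built from two
/// shift-and-add steps.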
#[inline]
#[target_feature(enable = "neon")]
unsafe fn prefix_sum_4(v: uint32x4_t) -> uint32x4_t {
let shifted1 = vextq_u32::<3>(vdupq_n_u32(0), v);
let sum1 = vaddq_u32(v, shifted1);
let shifted2 = vextq_u32::<2>(vdupq_n_u32(0), sum1);
vaddq_u32(sum1, shifted2)
}
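/// Reconstructs ascending doc IDs from deltas stored as `gap - 1` (hence the
/// `+ 1` applied to every lane). For example, IDs `[10, 12, 15]` are encoded
/// as `first_doc_id = 10`, `deltas = [1, 2]`. Groups of four deltas are
/// prefix-summed in a register and the last lane is carried into the next
/// group; the tail uses scalar wrapping arithmetic.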
#[target_feature(enable = "neon")]
pub unsafe fn delta_decode(
output: &mut [u32],
deltas: &[u32],
first_doc_id: u32,
count: usize,
) {
if count == 0 {
return;
}
output[0] = first_doc_id;
if count == 1 {
return;
}
let ones = vdupq_n_u32(1);
let mut carry = vdupq_n_u32(first_doc_id);
let full_groups = (count - 1) / 4;
let remainder = (count - 1) % 4;
for group in 0..full_groups {
let base = group * 4;
let d = vld1q_u32(deltas[base..].as_ptr());
let gaps = vaddq_u32(d, ones);
let prefix = prefix_sum_4(gaps);
let result = vaddq_u32(prefix, carry);
vst1q_u32(output[base + 1..].as_mut_ptr(), result);
carry = vdupq_n_u32(vgetq_lane_u32::<3>(result));
}
let base = full_groups * 4;
let mut scalar_carry = vgetq_lane_u32::<0>(carry);
for j in 0..remainder {
scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
#[target_feature(enable = "neon")]
pub unsafe fn add_one(values: &mut [u32], count: usize) {
let ones = vdupq_n_u32(1);
let chunks = count / 4;
let remainder = count % 4;
for chunk in 0..chunks {
let base = chunk * 4;
let ptr = values.as_mut_ptr().add(base);
let v = vld1q_u32(ptr);
let result = vaddq_u32(v, ones);
vst1q_u32(ptr, result);
}
let base = chunks * 4;
for i in 0..remainder {
values[base + i] += 1;
}
}
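/// Fused variant: unpacks 8-bit deltas and delta-decodes them in one pass,
/// avoiding an intermediate `u32` buffer. Each group reads four delta bytes
/// with an unaligned 4-byte load.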
#[target_feature(enable = "neon")]
pub unsafe fn unpack_8bit_delta_decode(
input: &[u8],
output: &mut [u32],
first_value: u32,
count: usize,
) {
output[0] = first_value;
if count <= 1 {
return;
}
let ones = vdupq_n_u32(1);
let mut carry = vdupq_n_u32(first_value);
let full_groups = (count - 1) / 4;
let remainder = (count - 1) % 4;
for group in 0..full_groups {
let base = group * 4;
let raw = std::ptr::read_unaligned(input.as_ptr().add(base) as *const u32);
let bytes = vreinterpret_u8_u32(vdup_n_u32(raw));
let u16s = vmovl_u8(bytes);
let d = vmovl_u16(vget_low_u16(u16s));
let gaps = vaddq_u32(d, ones);
let prefix = prefix_sum_4(gaps);
let result = vaddq_u32(prefix, carry);
vst1q_u32(output[base + 1..].as_mut_ptr(), result);
carry = vdupq_n_u32(vgetq_lane_u32::<3>(result));
}
let base = full_groups * 4;
let mut scalar_carry = vgetq_lane_u32::<0>(carry);
for j in 0..remainder {
scalar_carry = scalar_carry
.wrapping_add(input[base + j] as u32)
.wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
#[target_feature(enable = "neon")]
pub unsafe fn unpack_16bit_delta_decode(
input: &[u8],
output: &mut [u32],
first_value: u32,
count: usize,
) {
output[0] = first_value;
if count <= 1 {
return;
}
let ones = vdupq_n_u32(1);
let mut carry = vdupq_n_u32(first_value);
let full_groups = (count - 1) / 4;
let remainder = (count - 1) % 4;
for group in 0..full_groups {
let base = group * 4;
let in_ptr = input.as_ptr().add(base * 2) as *const u16;
let vals = vld1_u16(in_ptr);
let d = vmovl_u16(vals);
let gaps = vaddq_u32(d, ones);
let prefix = prefix_sum_4(gaps);
let result = vaddq_u32(prefix, carry);
vst1q_u32(output[base + 1..].as_mut_ptr(), result);
carry = vdupq_n_u32(vgetq_lane_u32::<3>(result));
}
let base = full_groups * 4;
let mut scalar_carry = vgetq_lane_u32::<0>(carry);
for j in 0..remainder {
let idx = (base + j) * 2;
let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
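/// Hamming distance over byte slices: XOR followed by per-byte popcount
/// (`vcnt`). Popcounts accumulate in `u8` lanes, so blocks are processed in
/// batches of at most 31 (31 * 8 = 248 < 256) before widening to `u64`.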
#[target_feature(enable = "neon")]
pub unsafe fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
let len = a.len();
let chunks16 = len / 16;
let mut total = 0u32;
let mut i = 0;
while i < chunks16 {
let batch_end = (i + 31).min(chunks16);
let mut acc = vdupq_n_u8(0);
for j in i..batch_end {
let off = j * 16;
let va = vld1q_u8(a.as_ptr().add(off));
let vb = vld1q_u8(b.as_ptr().add(off));
let popcnt = vcntq_u8(veorq_u8(va, vb));
acc = vaddq_u8(acc, popcnt);
}
let sum64 = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(acc)));
total += vgetq_lane_u64::<0>(sum64) as u32 + vgetq_lane_u64::<1>(sum64) as u32;
i = batch_end;
}
let base = chunks16 * 16;
for k in base..len {
total += (a[k] ^ b[k]).count_ones();
}
total
}
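/// NEON is a baseline AArch64 feature, so no runtime detection is needed.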
#[inline]
pub fn is_available() -> bool {
true
}
}
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod sse {
use std::arch::x86_64::*;
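/// SSE counterpart of the NEON unpack: `_mm_cvtepu8_epi32` (SSE4.1) widens
/// four bytes at a time, 16 bytes per loop iteration.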
#[target_feature(enable = "sse2", enable = "sse4.1")]
pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 16;
let remainder = count % 16;
for chunk in 0..chunks {
let base = chunk * 16;
let in_ptr = input.as_ptr().add(base);
let bytes = _mm_loadu_si128(in_ptr as *const __m128i);
let v0 = _mm_cvtepu8_epi32(bytes);
let v1 = _mm_cvtepu8_epi32(_mm_srli_si128::<4>(bytes));
let v2 = _mm_cvtepu8_epi32(_mm_srli_si128::<8>(bytes));
let v3 = _mm_cvtepu8_epi32(_mm_srli_si128::<12>(bytes));
let out_ptr = output.as_mut_ptr().add(base);
_mm_storeu_si128(out_ptr as *mut __m128i, v0);
_mm_storeu_si128(out_ptr.add(4) as *mut __m128i, v1);
_mm_storeu_si128(out_ptr.add(8) as *mut __m128i, v2);
_mm_storeu_si128(out_ptr.add(12) as *mut __m128i, v3);
}
let base = chunks * 16;
for i in 0..remainder {
output[base + i] = input[base + i] as u32;
}
}
#[target_feature(enable = "sse2", enable = "sse4.1")]
pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 8;
let remainder = count % 8;
for chunk in 0..chunks {
let base = chunk * 8;
let in_ptr = input.as_ptr().add(base * 2);
let vals = _mm_loadu_si128(in_ptr as *const __m128i);
let low = _mm_cvtepu16_epi32(vals);
let high = _mm_cvtepu16_epi32(_mm_srli_si128::<8>(vals));
let out_ptr = output.as_mut_ptr().add(base);
_mm_storeu_si128(out_ptr as *mut __m128i, low);
_mm_storeu_si128(out_ptr.add(4) as *mut __m128i, high);
}
let base = chunks * 8;
for i in 0..remainder {
let idx = (base + i) * 2;
output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
}
}
#[target_feature(enable = "sse2")]
pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 4;
let remainder = count % 4;
let in_ptr = input.as_ptr() as *const __m128i;
let out_ptr = output.as_mut_ptr() as *mut __m128i;
for chunk in 0..chunks {
let vals = _mm_loadu_si128(in_ptr.add(chunk));
_mm_storeu_si128(out_ptr.add(chunk), vals);
}
let base = chunks * 4;
for i in 0..remainder {
let idx = (base + i) * 4;
output[base + i] =
u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
}
}
#[inline]
#[target_feature(enable = "sse2")]
unsafe fn prefix_sum_4(v: __m128i) -> __m128i {
let shifted1 = _mm_slli_si128::<4>(v);
let sum1 = _mm_add_epi32(v, shifted1);
let shifted2 = _mm_slli_si128::<8>(sum1);
_mm_add_epi32(sum1, shifted2)
}
#[target_feature(enable = "sse2", enable = "sse4.1")]
pub unsafe fn delta_decode(
output: &mut [u32],
deltas: &[u32],
first_doc_id: u32,
count: usize,
) {
if count == 0 {
return;
}
output[0] = first_doc_id;
if count == 1 {
return;
}
let ones = _mm_set1_epi32(1);
let mut carry = _mm_set1_epi32(first_doc_id as i32);
let full_groups = (count - 1) / 4;
let remainder = (count - 1) % 4;
for group in 0..full_groups {
let base = group * 4;
let d = _mm_loadu_si128(deltas[base..].as_ptr() as *const __m128i);
let gaps = _mm_add_epi32(d, ones);
let prefix = prefix_sum_4(gaps);
let result = _mm_add_epi32(prefix, carry);
_mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
carry = _mm_shuffle_epi32::<0xFF>(result);
}
let base = full_groups * 4;
let mut scalar_carry = _mm_extract_epi32::<0>(carry) as u32;
for j in 0..remainder {
scalar_carry = scalar_carry.wrapping_add(deltas[base + j]).wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
#[target_feature(enable = "sse2")]
pub unsafe fn add_one(values: &mut [u32], count: usize) {
let ones = _mm_set1_epi32(1);
let chunks = count / 4;
let remainder = count % 4;
for chunk in 0..chunks {
let base = chunk * 4;
let ptr = values.as_mut_ptr().add(base) as *mut __m128i;
let v = _mm_loadu_si128(ptr);
let result = _mm_add_epi32(v, ones);
_mm_storeu_si128(ptr, result);
}
let base = chunks * 4;
for i in 0..remainder {
values[base + i] += 1;
}
}
#[target_feature(enable = "sse2", enable = "sse4.1")]
pub unsafe fn unpack_8bit_delta_decode(
input: &[u8],
output: &mut [u32],
first_value: u32,
count: usize,
) {
output[0] = first_value;
if count <= 1 {
return;
}
let ones = _mm_set1_epi32(1);
let mut carry = _mm_set1_epi32(first_value as i32);
let full_groups = (count - 1) / 4;
let remainder = (count - 1) % 4;
for group in 0..full_groups {
let base = group * 4;
let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
input.as_ptr().add(base) as *const i32
));
let d = _mm_cvtepu8_epi32(bytes);
let gaps = _mm_add_epi32(d, ones);
let prefix = prefix_sum_4(gaps);
let result = _mm_add_epi32(prefix, carry);
_mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
carry = _mm_shuffle_epi32::<0xFF>(result);
}
let base = full_groups * 4;
let mut scalar_carry = _mm_extract_epi32::<0>(carry) as u32;
for j in 0..remainder {
scalar_carry = scalar_carry
.wrapping_add(input[base + j] as u32)
.wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
#[target_feature(enable = "sse2", enable = "sse4.1")]
pub unsafe fn unpack_16bit_delta_decode(
input: &[u8],
output: &mut [u32],
first_value: u32,
count: usize,
) {
output[0] = first_value;
if count <= 1 {
return;
}
let ones = _mm_set1_epi32(1);
let mut carry = _mm_set1_epi32(first_value as i32);
let full_groups = (count - 1) / 4;
let remainder = (count - 1) % 4;
for group in 0..full_groups {
let base = group * 4;
let in_ptr = input.as_ptr().add(base * 2);
let vals = _mm_loadl_epi64(in_ptr as *const __m128i);
let d = _mm_cvtepu16_epi32(vals);
let gaps = _mm_add_epi32(d, ones);
let prefix = prefix_sum_4(gaps);
let result = _mm_add_epi32(prefix, carry);
_mm_storeu_si128(output[base + 1..].as_mut_ptr() as *mut __m128i, result);
carry = _mm_shuffle_epi32::<0xFF>(result);
}
let base = full_groups * 4;
let mut scalar_carry = _mm_extract_epi32::<0>(carry) as u32;
for j in 0..remainder {
let idx = (base + j) * 2;
let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
#[inline]
pub fn is_available() -> bool {
is_x86_feature_detected!("sse4.1")
}
}
#[cfg(target_arch = "x86_64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod avx2 {
use std::arch::x86_64::*;
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 32;
let remainder = count % 32;
for chunk in 0..chunks {
let base = chunk * 32;
let in_ptr = input.as_ptr().add(base);
let bytes_lo = _mm_loadu_si128(in_ptr as *const __m128i);
let bytes_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
let v0 = _mm256_cvtepu8_epi32(bytes_lo);
let v1 = _mm256_cvtepu8_epi32(_mm_srli_si128::<8>(bytes_lo));
let v2 = _mm256_cvtepu8_epi32(bytes_hi);
let v3 = _mm256_cvtepu8_epi32(_mm_srli_si128::<8>(bytes_hi));
let out_ptr = output.as_mut_ptr().add(base);
_mm256_storeu_si256(out_ptr as *mut __m256i, v0);
_mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
_mm256_storeu_si256(out_ptr.add(16) as *mut __m256i, v2);
_mm256_storeu_si256(out_ptr.add(24) as *mut __m256i, v3);
}
let base = chunks * 32;
for i in 0..remainder {
output[base + i] = input[base + i] as u32;
}
}
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 16;
let remainder = count % 16;
for chunk in 0..chunks {
let base = chunk * 16;
let in_ptr = input.as_ptr().add(base * 2);
let vals_lo = _mm_loadu_si128(in_ptr as *const __m128i);
let vals_hi = _mm_loadu_si128(in_ptr.add(16) as *const __m128i);
let v0 = _mm256_cvtepu16_epi32(vals_lo);
let v1 = _mm256_cvtepu16_epi32(vals_hi);
let out_ptr = output.as_mut_ptr().add(base);
_mm256_storeu_si256(out_ptr as *mut __m256i, v0);
_mm256_storeu_si256(out_ptr.add(8) as *mut __m256i, v1);
}
let base = chunks * 16;
for i in 0..remainder {
let idx = (base + i) * 2;
output[base + i] = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
}
}
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
let chunks = count / 8;
let remainder = count % 8;
let in_ptr = input.as_ptr() as *const __m256i;
let out_ptr = output.as_mut_ptr() as *mut __m256i;
for chunk in 0..chunks {
let vals = _mm256_loadu_si256(in_ptr.add(chunk));
_mm256_storeu_si256(out_ptr.add(chunk), vals);
}
let base = chunks * 8;
for i in 0..remainder {
let idx = (base + i) * 4;
output[base + i] =
u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
}
}
#[target_feature(enable = "avx2")]
pub unsafe fn add_one(values: &mut [u32], count: usize) {
let ones = _mm256_set1_epi32(1);
let chunks = count / 8;
let remainder = count % 8;
for chunk in 0..chunks {
let base = chunk * 8;
let ptr = values.as_mut_ptr().add(base) as *mut __m256i;
let v = _mm256_loadu_si256(ptr);
let result = _mm256_add_epi32(v, ones);
_mm256_storeu_si256(ptr, result);
}
let base = chunks * 8;
for i in 0..remainder {
values[base + i] += 1;
}
}
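/// Inclusive prefix sum across eight lanes. `_mm256_slli_si256` shifts only
/// within each 128-bit lane, so the low-lane total is broadcast and added
/// into the high lane as a final step.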
#[inline]
#[target_feature(enable = "avx2")]
unsafe fn prefix_sum_8(v: __m256i) -> __m256i {
let s1 = _mm256_slli_si256::<4>(v);
let r1 = _mm256_add_epi32(v, s1);
let s2 = _mm256_slli_si256::<8>(r1);
let r2 = _mm256_add_epi32(r1, s2);
let lo_sum = _mm256_shuffle_epi32::<0xFF>(r2);
let carry = _mm256_permute2x128_si256::<0x00>(lo_sum, lo_sum);
let carry_hi = _mm256_blend_epi32::<0xF0>(_mm256_setzero_si256(), carry);
_mm256_add_epi32(r2, carry_hi)
}
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_8bit_delta_decode(
input: &[u8],
output: &mut [u32],
first_value: u32,
count: usize,
) {
output[0] = first_value;
if count <= 1 {
return;
}
let ones = _mm256_set1_epi32(1);
let mut carry = _mm256_set1_epi32(first_value as i32);
let broadcast_idx = _mm256_set1_epi32(7);
let full_groups = (count - 1) / 8;
let remainder = (count - 1) % 8;
for group in 0..full_groups {
let base = group * 8;
let bytes = _mm_loadl_epi64(input.as_ptr().add(base) as *const __m128i);
let d = _mm256_cvtepu8_epi32(bytes);
let gaps = _mm256_add_epi32(d, ones);
let prefix = prefix_sum_8(gaps);
let result = _mm256_add_epi32(prefix, carry);
_mm256_storeu_si256(output[base + 1..].as_mut_ptr() as *mut __m256i, result);
carry = _mm256_permutevar8x32_epi32(result, broadcast_idx);
}
let base = full_groups * 8;
let mut scalar_carry = _mm256_extract_epi32::<0>(carry) as u32;
for j in 0..remainder {
scalar_carry = scalar_carry
.wrapping_add(input[base + j] as u32)
.wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
#[target_feature(enable = "avx2")]
pub unsafe fn unpack_16bit_delta_decode(
input: &[u8],
output: &mut [u32],
first_value: u32,
count: usize,
) {
output[0] = first_value;
if count <= 1 {
return;
}
let ones = _mm256_set1_epi32(1);
let mut carry = _mm256_set1_epi32(first_value as i32);
let broadcast_idx = _mm256_set1_epi32(7);
let full_groups = (count - 1) / 8;
let remainder = (count - 1) % 8;
for group in 0..full_groups {
let base = group * 8;
let in_ptr = input.as_ptr().add(base * 2);
let vals = _mm_loadu_si128(in_ptr as *const __m128i);
let d = _mm256_cvtepu16_epi32(vals);
let gaps = _mm256_add_epi32(d, ones);
let prefix = prefix_sum_8(gaps);
let result = _mm256_add_epi32(prefix, carry);
_mm256_storeu_si256(output[base + 1..].as_mut_ptr() as *mut __m256i, result);
carry = _mm256_permutevar8x32_epi32(result, broadcast_idx);
}
let base = full_groups * 8;
let mut scalar_carry = _mm256_extract_epi32::<0>(carry) as u32;
for j in 0..remainder {
let idx = (base + j) * 2;
let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
scalar_carry = scalar_carry.wrapping_add(delta).wrapping_add(1);
output[base + j + 1] = scalar_carry;
}
}
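/// Hamming distance via the nibble-lookup popcount (`pshufb` against a
/// 16-entry table) with `_mm256_sad_epu8` for the horizontal reduction.
/// Batches of at most 31 blocks keep the per-byte accumulator below 256.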
#[target_feature(enable = "avx2")]
pub unsafe fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
let len = a.len();
let chunks32 = len / 32;
let low_mask = _mm256_set1_epi8(0x0f);
let lookup = _mm256_setr_epi8(
0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2,
3, 3, 4,
);
let mut total = 0u64;
let mut i = 0;
while i < chunks32 {
let batch_end = (i + 31).min(chunks32);
let mut acc = _mm256_setzero_si256();
for j in i..batch_end {
let off = j * 32;
let va = _mm256_loadu_si256(a.as_ptr().add(off) as *const __m256i);
let vb = _mm256_loadu_si256(b.as_ptr().add(off) as *const __m256i);
let xored = _mm256_xor_si256(va, vb);
let lo = _mm256_and_si256(xored, low_mask);
let hi = _mm256_and_si256(_mm256_srli_epi16::<4>(xored), low_mask);
let popcnt = _mm256_add_epi8(
_mm256_shuffle_epi8(lookup, lo),
_mm256_shuffle_epi8(lookup, hi),
);
acc = _mm256_add_epi8(acc, popcnt);
}
let sad = _mm256_sad_epu8(acc, _mm256_setzero_si256());
total += _mm256_extract_epi64::<0>(sad) as u64
+ _mm256_extract_epi64::<1>(sad) as u64
+ _mm256_extract_epi64::<2>(sad) as u64
+ _mm256_extract_epi64::<3>(sad) as u64;
i = batch_end;
}
let base = chunks32 * 32;
for k in base..len {
total += (a[k] ^ b[k]).count_ones() as u64;
}
total as u32
}
#[inline]
pub fn is_available() -> bool {
is_x86_feature_detected!("avx2")
}
}
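/// Portable scalar fallbacks used when no SIMD path is compiled in or
/// detected at runtime.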
#[allow(dead_code)]
mod scalar {
#[inline]
pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
for i in 0..count {
output[i] = input[i] as u32;
}
}
#[inline]
pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
for (i, out) in output.iter_mut().enumerate().take(count) {
let idx = i * 2;
*out = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
}
}
#[inline]
pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
for (i, out) in output.iter_mut().enumerate().take(count) {
let idx = i * 4;
*out = u32::from_le_bytes([input[idx], input[idx + 1], input[idx + 2], input[idx + 3]]);
}
}
#[inline]
pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_doc_id: u32, count: usize) {
if count == 0 {
return;
}
output[0] = first_doc_id;
let mut carry = first_doc_id;
for i in 0..count - 1 {
carry = carry.wrapping_add(deltas[i]).wrapping_add(1);
output[i + 1] = carry;
}
}
#[inline]
pub fn add_one(values: &mut [u32], count: usize) {
for val in values.iter_mut().take(count) {
*val += 1;
}
}
}
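/// The public wrappers from here on dispatch at runtime: NEON on AArch64,
/// AVX2 then SSE4.1 on x86_64, and the scalar fallback otherwise.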
#[inline]
pub fn unpack_8bit(input: &[u8], output: &mut [u32], count: usize) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::unpack_8bit(input, output, count);
}
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
unsafe {
avx2::unpack_8bit(input, output, count);
}
return;
}
if sse::is_available() {
unsafe {
sse::unpack_8bit(input, output, count);
}
return;
}
}
scalar::unpack_8bit(input, output, count);
}
#[inline]
pub fn unpack_16bit(input: &[u8], output: &mut [u32], count: usize) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::unpack_16bit(input, output, count);
}
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
unsafe {
avx2::unpack_16bit(input, output, count);
}
return;
}
if sse::is_available() {
unsafe {
sse::unpack_16bit(input, output, count);
}
return;
}
}
scalar::unpack_16bit(input, output, count);
}
#[inline]
pub fn unpack_32bit(input: &[u8], output: &mut [u32], count: usize) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::unpack_32bit(input, output, count);
}
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
unsafe {
avx2::unpack_32bit(input, output, count);
}
} else {
unsafe {
sse::unpack_32bit(input, output, count);
}
}
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
{
scalar::unpack_32bit(input, output, count);
}
}
#[inline]
pub fn delta_decode(output: &mut [u32], deltas: &[u32], first_value: u32, count: usize) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::delta_decode(output, deltas, first_value, count);
}
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if sse::is_available() {
unsafe {
sse::delta_decode(output, deltas, first_value, count);
}
return;
}
}
scalar::delta_decode(output, deltas, first_value, count);
}
#[inline]
pub fn add_one(values: &mut [u32], count: usize) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::add_one(values, count);
}
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
unsafe {
avx2::add_one(values, count);
}
} else {
unsafe {
sse::add_one(values, count);
}
}
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
{
scalar::add_one(values, count);
}
}
#[inline]
pub fn bits_needed(val: u32) -> u8 {
if val == 0 {
0
} else {
32 - val.leading_zeros() as u8
}
}
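/// Bit widths rounded up to byte-aligned sizes (0, 8, 16, or 32 bits) so that
/// packed values can be decoded with plain widening loads instead of
/// arbitrary bit shifts.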
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum RoundedBitWidth {
Zero = 0,
Bits8 = 8,
Bits16 = 16,
Bits32 = 32,
}
impl RoundedBitWidth {
#[inline]
pub fn from_exact(bits: u8) -> Self {
match bits {
0 => RoundedBitWidth::Zero,
1..=8 => RoundedBitWidth::Bits8,
9..=16 => RoundedBitWidth::Bits16,
_ => RoundedBitWidth::Bits32,
}
}
#[inline]
pub fn from_u8(bits: u8) -> Self {
match bits {
0 => RoundedBitWidth::Zero,
8 => RoundedBitWidth::Bits8,
16 => RoundedBitWidth::Bits16,
32 => RoundedBitWidth::Bits32,
_ => RoundedBitWidth::Bits32,
}
}
#[inline]
pub fn bytes_per_value(self) -> usize {
match self {
RoundedBitWidth::Zero => 0,
RoundedBitWidth::Bits8 => 1,
RoundedBitWidth::Bits16 => 2,
RoundedBitWidth::Bits32 => 4,
}
}
#[inline]
pub fn as_u8(self) -> u8 {
self as u8
}
}
#[inline]
pub fn round_bit_width(bits: u8) -> u8 {
RoundedBitWidth::from_exact(bits).as_u8()
}
#[inline]
pub fn pack_rounded(values: &[u32], bit_width: RoundedBitWidth, output: &mut [u8]) -> usize {
let count = values.len();
match bit_width {
RoundedBitWidth::Zero => 0,
RoundedBitWidth::Bits8 => {
for (i, &v) in values.iter().enumerate() {
output[i] = v as u8;
}
count
}
RoundedBitWidth::Bits16 => {
for (i, &v) in values.iter().enumerate() {
let bytes = (v as u16).to_le_bytes();
output[i * 2] = bytes[0];
output[i * 2 + 1] = bytes[1];
}
count * 2
}
RoundedBitWidth::Bits32 => {
for (i, &v) in values.iter().enumerate() {
let bytes = v.to_le_bytes();
output[i * 4] = bytes[0];
output[i * 4 + 1] = bytes[1];
output[i * 4 + 2] = bytes[2];
output[i * 4 + 3] = bytes[3];
}
count * 4
}
}
}
#[inline]
pub fn unpack_rounded(input: &[u8], bit_width: RoundedBitWidth, output: &mut [u32], count: usize) {
match bit_width {
RoundedBitWidth::Zero => {
for out in output.iter_mut().take(count) {
*out = 0;
}
}
RoundedBitWidth::Bits8 => unpack_8bit(input, output, count),
RoundedBitWidth::Bits16 => unpack_16bit(input, output, count),
RoundedBitWidth::Bits32 => unpack_32bit(input, output, count),
}
}
#[inline]
pub fn unpack_rounded_delta_decode(
input: &[u8],
bit_width: RoundedBitWidth,
output: &mut [u32],
first_value: u32,
count: usize,
) {
match bit_width {
RoundedBitWidth::Zero => {
let mut val = first_value;
for out in output.iter_mut().take(count) {
*out = val;
val = val.wrapping_add(1);
}
}
RoundedBitWidth::Bits8 => unpack_8bit_delta_decode(input, output, first_value, count),
RoundedBitWidth::Bits16 => unpack_16bit_delta_decode(input, output, first_value, count),
RoundedBitWidth::Bits32 => {
if count > 0 {
output[0] = first_value;
let mut carry = first_value;
for i in 0..count - 1 {
let idx = i * 4;
let delta = u32::from_le_bytes([
input[idx],
input[idx + 1],
input[idx + 2],
input[idx + 3],
]);
carry = carry.wrapping_add(delta).wrapping_add(1);
output[i + 1] = carry;
}
}
}
}
}
#[inline]
pub fn unpack_8bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
if count == 0 {
return;
}
output[0] = first_value;
if count == 1 {
return;
}
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::unpack_8bit_delta_decode(input, output, first_value, count);
}
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
unsafe {
avx2::unpack_8bit_delta_decode(input, output, first_value, count);
}
return;
}
if sse::is_available() {
unsafe {
sse::unpack_8bit_delta_decode(input, output, first_value, count);
}
return;
}
}
let mut carry = first_value;
for i in 0..count - 1 {
carry = carry.wrapping_add(input[i] as u32).wrapping_add(1);
output[i + 1] = carry;
}
}
#[inline]
pub fn unpack_16bit_delta_decode(input: &[u8], output: &mut [u32], first_value: u32, count: usize) {
if count == 0 {
return;
}
output[0] = first_value;
if count == 1 {
return;
}
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
neon::unpack_16bit_delta_decode(input, output, first_value, count);
}
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
unsafe {
avx2::unpack_16bit_delta_decode(input, output, first_value, count);
}
return;
}
if sse::is_available() {
unsafe {
sse::unpack_16bit_delta_decode(input, output, first_value, count);
}
return;
}
}
let mut carry = first_value;
for i in 0..count - 1 {
let idx = i * 2;
let delta = u16::from_le_bytes([input[idx], input[idx + 1]]) as u32;
carry = carry.wrapping_add(delta).wrapping_add(1);
output[i + 1] = carry;
}
}
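/// Delta-decodes values packed at an arbitrary `bit_width`. Widths 0/8/16/32
/// take the fast paths above; other widths use a bit-shifting loop whose
/// unaligned 8-byte reads assume the input buffer extends a few bytes (up to
/// 7) past the packed data.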
#[inline]
pub fn unpack_delta_decode(
input: &[u8],
bit_width: u8,
output: &mut [u32],
first_value: u32,
count: usize,
) {
if count == 0 {
return;
}
output[0] = first_value;
if count == 1 {
return;
}
match bit_width {
0 => {
let mut val = first_value;
for item in output.iter_mut().take(count).skip(1) {
val = val.wrapping_add(1);
*item = val;
}
}
8 => unpack_8bit_delta_decode(input, output, first_value, count),
16 => unpack_16bit_delta_decode(input, output, first_value, count),
32 => {
let mut carry = first_value;
for i in 0..count - 1 {
let idx = i * 4;
let delta = u32::from_le_bytes([
input[idx],
input[idx + 1],
input[idx + 2],
input[idx + 3],
]);
carry = carry.wrapping_add(delta).wrapping_add(1);
output[i + 1] = carry;
}
}
_ => {
let mask = (1u64 << bit_width) - 1;
let bit_width_usize = bit_width as usize;
let mut bit_pos = 0usize;
let input_ptr = input.as_ptr();
let mut carry = first_value;
for i in 0..count - 1 {
let byte_idx = bit_pos >> 3;
let bit_offset = bit_pos & 7;
let word = unsafe { (input_ptr.add(byte_idx) as *const u64).read_unaligned() };
let delta = ((word >> bit_offset) & mask) as u32;
carry = carry.wrapping_add(delta).wrapping_add(1);
output[i + 1] = carry;
bit_pos += bit_width_usize;
}
}
}
}
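/// Dequantizes `u8` values back to `f32` as `input[i] as f32 * scale + min_val`.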
#[inline]
pub fn dequantize_uint8(input: &[u8], output: &mut [f32], scale: f32, min_val: f32, count: usize) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
unsafe {
dequantize_uint8_neon(input, output, scale, min_val, count);
}
return;
}
}
#[cfg(target_arch = "x86_64")]
{
if sse::is_available() {
unsafe {
dequantize_uint8_sse(input, output, scale, min_val, count);
}
return;
}
}
for i in 0..count {
output[i] = input[i] as f32 * scale + min_val;
}
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dequantize_uint8_neon(
input: &[u8],
output: &mut [f32],
scale: f32,
min_val: f32,
count: usize,
) {
use std::arch::aarch64::*;
let scale_v = vdupq_n_f32(scale);
let min_v = vdupq_n_f32(min_val);
let chunks = count / 16;
let remainder = count % 16;
for chunk in 0..chunks {
let base = chunk * 16;
let in_ptr = input.as_ptr().add(base);
let bytes = vld1q_u8(in_ptr);
let low8 = vget_low_u8(bytes);
let high8 = vget_high_u8(bytes);
let low16 = vmovl_u8(low8);
let high16 = vmovl_u8(high8);
let u32_0 = vmovl_u16(vget_low_u16(low16));
let u32_1 = vmovl_u16(vget_high_u16(low16));
let u32_2 = vmovl_u16(vget_low_u16(high16));
let u32_3 = vmovl_u16(vget_high_u16(high16));
let f32_0 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_0), scale_v);
let f32_1 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_1), scale_v);
let f32_2 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_2), scale_v);
let f32_3 = vfmaq_f32(min_v, vcvtq_f32_u32(u32_3), scale_v);
let out_ptr = output.as_mut_ptr().add(base);
vst1q_f32(out_ptr, f32_0);
vst1q_f32(out_ptr.add(4), f32_1);
vst1q_f32(out_ptr.add(8), f32_2);
vst1q_f32(out_ptr.add(12), f32_3);
}
let base = chunks * 16;
for i in 0..remainder {
output[base + i] = input[base + i] as f32 * scale + min_val;
}
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dequantize_uint8_sse(
input: &[u8],
output: &mut [f32],
scale: f32,
min_val: f32,
count: usize,
) {
use std::arch::x86_64::*;
let scale_v = _mm_set1_ps(scale);
let min_v = _mm_set1_ps(min_val);
let chunks = count / 4;
let remainder = count % 4;
for chunk in 0..chunks {
let base = chunk * 4;
let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
input.as_ptr().add(base) as *const i32
));
let ints = _mm_cvtepu8_epi32(bytes);
let floats = _mm_cvtepi32_ps(ints);
let scaled = _mm_add_ps(_mm_mul_ps(floats, scale_v), min_v);
_mm_storeu_ps(output.as_mut_ptr().add(base), scaled);
}
let base = chunks * 4;
for i in 0..remainder {
output[base + i] = input[base + i] as f32 * scale + min_val;
}
}
#[inline]
pub fn dot_product_f32(a: &[f32], b: &[f32], count: usize) -> f32 {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
return unsafe { dot_product_f32_neon(a, b, count) };
}
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return unsafe { dot_product_f32_avx512(a, b, count) };
}
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { dot_product_f32_avx2(a, b, count) };
}
if sse::is_available() {
return unsafe { dot_product_f32_sse(a, b, count) };
}
}
let mut sum = 0.0f32;
for i in 0..count {
sum += a[i] * b[i];
}
sum
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::aarch64::*;
let chunks16 = count / 16;
let remainder = count % 16;
let mut acc0 = vdupq_n_f32(0.0);
let mut acc1 = vdupq_n_f32(0.0);
let mut acc2 = vdupq_n_f32(0.0);
let mut acc3 = vdupq_n_f32(0.0);
for c in 0..chunks16 {
let base = c * 16;
acc0 = vfmaq_f32(
acc0,
vld1q_f32(a.as_ptr().add(base)),
vld1q_f32(b.as_ptr().add(base)),
);
acc1 = vfmaq_f32(
acc1,
vld1q_f32(a.as_ptr().add(base + 4)),
vld1q_f32(b.as_ptr().add(base + 4)),
);
acc2 = vfmaq_f32(
acc2,
vld1q_f32(a.as_ptr().add(base + 8)),
vld1q_f32(b.as_ptr().add(base + 8)),
);
acc3 = vfmaq_f32(
acc3,
vld1q_f32(a.as_ptr().add(base + 12)),
vld1q_f32(b.as_ptr().add(base + 12)),
);
}
let acc = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
let mut sum = vaddvq_f32(acc);
let base = chunks16 * 16;
for i in 0..remainder {
sum += a[base + i] * b[base + i];
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::x86_64::*;
let chunks32 = count / 32;
let remainder = count % 32;
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for c in 0..chunks32 {
let base = c * 32;
acc0 = _mm256_fmadd_ps(
_mm256_loadu_ps(a.as_ptr().add(base)),
_mm256_loadu_ps(b.as_ptr().add(base)),
acc0,
);
acc1 = _mm256_fmadd_ps(
_mm256_loadu_ps(a.as_ptr().add(base + 8)),
_mm256_loadu_ps(b.as_ptr().add(base + 8)),
acc1,
);
acc2 = _mm256_fmadd_ps(
_mm256_loadu_ps(a.as_ptr().add(base + 16)),
_mm256_loadu_ps(b.as_ptr().add(base + 16)),
acc2,
);
acc3 = _mm256_fmadd_ps(
_mm256_loadu_ps(a.as_ptr().add(base + 24)),
_mm256_loadu_ps(b.as_ptr().add(base + 24)),
acc3,
);
}
let acc = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
let hi = _mm256_extractf128_ps::<1>(acc);
let lo = _mm256_castps256_ps128(acc);
let sum128 = _mm_add_ps(lo, hi);
let shuf = _mm_shuffle_ps::<0b10_11_00_01>(sum128, sum128);
let sums = _mm_add_ps(sum128, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let final_sum = _mm_add_ss(sums, shuf2);
let mut sum = _mm_cvtss_f32(final_sum);
let base = chunks32 * 32;
for i in 0..remainder {
sum += a[base + i] * b[base + i];
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::x86_64::*;
let chunks = count / 4;
let remainder = count % 4;
let mut acc = _mm_setzero_ps();
for chunk in 0..chunks {
let base = chunk * 4;
let va = _mm_loadu_ps(a.as_ptr().add(base));
let vb = _mm_loadu_ps(b.as_ptr().add(base));
acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
}
let shuf = _mm_shuffle_ps::<0b10_11_00_01>(acc, acc);
let sums = _mm_add_ps(acc, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let final_sum = _mm_add_ss(sums, shuf2);
let mut sum = _mm_cvtss_f32(final_sum);
let base = chunks * 4;
for i in 0..remainder {
sum += a[base + i] * b[base + i];
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f32_avx512(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::x86_64::*;
let chunks64 = count / 64;
let remainder = count % 64;
let mut acc0 = _mm512_setzero_ps();
let mut acc1 = _mm512_setzero_ps();
let mut acc2 = _mm512_setzero_ps();
let mut acc3 = _mm512_setzero_ps();
for c in 0..chunks64 {
let base = c * 64;
acc0 = _mm512_fmadd_ps(
_mm512_loadu_ps(a.as_ptr().add(base)),
_mm512_loadu_ps(b.as_ptr().add(base)),
acc0,
);
acc1 = _mm512_fmadd_ps(
_mm512_loadu_ps(a.as_ptr().add(base + 16)),
_mm512_loadu_ps(b.as_ptr().add(base + 16)),
acc1,
);
acc2 = _mm512_fmadd_ps(
_mm512_loadu_ps(a.as_ptr().add(base + 32)),
_mm512_loadu_ps(b.as_ptr().add(base + 32)),
acc2,
);
acc3 = _mm512_fmadd_ps(
_mm512_loadu_ps(a.as_ptr().add(base + 48)),
_mm512_loadu_ps(b.as_ptr().add(base + 48)),
acc3,
);
}
let acc = _mm512_add_ps(_mm512_add_ps(acc0, acc1), _mm512_add_ps(acc2, acc3));
let mut sum = _mm512_reduce_add_ps(acc);
let base = chunks64 * 64;
for i in 0..remainder {
sum += a[base + i] * b[base + i];
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx512f")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_avx512(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
use std::arch::x86_64::*;
let chunks64 = count / 64;
let remainder = count % 64;
let mut d0 = _mm512_setzero_ps();
let mut d1 = _mm512_setzero_ps();
let mut d2 = _mm512_setzero_ps();
let mut d3 = _mm512_setzero_ps();
let mut n0 = _mm512_setzero_ps();
let mut n1 = _mm512_setzero_ps();
let mut n2 = _mm512_setzero_ps();
let mut n3 = _mm512_setzero_ps();
for c in 0..chunks64 {
let base = c * 64;
let vb0 = _mm512_loadu_ps(b.as_ptr().add(base));
d0 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base)), vb0, d0);
n0 = _mm512_fmadd_ps(vb0, vb0, n0);
let vb1 = _mm512_loadu_ps(b.as_ptr().add(base + 16));
d1 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 16)), vb1, d1);
n1 = _mm512_fmadd_ps(vb1, vb1, n1);
let vb2 = _mm512_loadu_ps(b.as_ptr().add(base + 32));
d2 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 32)), vb2, d2);
n2 = _mm512_fmadd_ps(vb2, vb2, n2);
let vb3 = _mm512_loadu_ps(b.as_ptr().add(base + 48));
d3 = _mm512_fmadd_ps(_mm512_loadu_ps(a.as_ptr().add(base + 48)), vb3, d3);
n3 = _mm512_fmadd_ps(vb3, vb3, n3);
}
let acc_dot = _mm512_add_ps(_mm512_add_ps(d0, d1), _mm512_add_ps(d2, d3));
let acc_norm = _mm512_add_ps(_mm512_add_ps(n0, n1), _mm512_add_ps(n2, n3));
let mut dot = _mm512_reduce_add_ps(acc_dot);
let mut norm = _mm512_reduce_add_ps(acc_norm);
let base = chunks64 * 64;
for i in 0..remainder {
dot += a[base + i] * b[base + i];
norm += b[base + i] * b[base + i];
}
(dot, norm)
}
#[inline]
pub fn max_f32(values: &[f32], count: usize) -> f32 {
if count == 0 {
return f32::NEG_INFINITY;
}
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
return unsafe { max_f32_neon(values, count) };
}
}
#[cfg(target_arch = "x86_64")]
{
if sse::is_available() {
return unsafe { max_f32_sse(values, count) };
}
}
values[..count]
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max)
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn max_f32_neon(values: &[f32], count: usize) -> f32 {
use std::arch::aarch64::*;
let chunks = count / 4;
let remainder = count % 4;
let mut max_v = vdupq_n_f32(f32::NEG_INFINITY);
for chunk in 0..chunks {
let base = chunk * 4;
let v = vld1q_f32(values.as_ptr().add(base));
max_v = vmaxq_f32(max_v, v);
}
let mut max_val = vmaxvq_f32(max_v);
let base = chunks * 4;
for i in 0..remainder {
max_val = max_val.max(values[base + i]);
}
max_val
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn max_f32_sse(values: &[f32], count: usize) -> f32 {
use std::arch::x86_64::*;
let chunks = count / 4;
let remainder = count % 4;
let mut max_v = _mm_set1_ps(f32::NEG_INFINITY);
for chunk in 0..chunks {
let base = chunk * 4;
let v = _mm_loadu_ps(values.as_ptr().add(base));
max_v = _mm_max_ps(max_v, v);
}
let shuf = _mm_shuffle_ps::<0b10_11_00_01>(max_v, max_v);
let max1 = _mm_max_ps(max_v, shuf);
let shuf2 = _mm_movehl_ps(max1, max1);
let final_max = _mm_max_ss(max1, shuf2);
let mut max_val = _mm_cvtss_f32(final_max);
let base = chunks * 4;
for i in 0..remainder {
max_val = max_val.max(values[base + i]);
}
max_val
}
#[inline]
fn fused_dot_norm(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
return unsafe { fused_dot_norm_neon(a, b, count) };
}
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx512f") {
return unsafe { fused_dot_norm_avx512(a, b, count) };
}
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe { fused_dot_norm_avx2(a, b, count) };
}
if sse::is_available() {
return unsafe { fused_dot_norm_sse(a, b, count) };
}
}
let mut dot = 0.0f32;
let mut norm_b = 0.0f32;
for i in 0..count {
dot += a[i] * b[i];
norm_b += b[i] * b[i];
}
(dot, norm_b)
}
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_neon(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
use std::arch::aarch64::*;
let chunks16 = count / 16;
let remainder = count % 16;
let mut d0 = vdupq_n_f32(0.0);
let mut d1 = vdupq_n_f32(0.0);
let mut d2 = vdupq_n_f32(0.0);
let mut d3 = vdupq_n_f32(0.0);
let mut n0 = vdupq_n_f32(0.0);
let mut n1 = vdupq_n_f32(0.0);
let mut n2 = vdupq_n_f32(0.0);
let mut n3 = vdupq_n_f32(0.0);
for c in 0..chunks16 {
let base = c * 16;
let va0 = vld1q_f32(a.as_ptr().add(base));
let vb0 = vld1q_f32(b.as_ptr().add(base));
d0 = vfmaq_f32(d0, va0, vb0);
n0 = vfmaq_f32(n0, vb0, vb0);
let va1 = vld1q_f32(a.as_ptr().add(base + 4));
let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
d1 = vfmaq_f32(d1, va1, vb1);
n1 = vfmaq_f32(n1, vb1, vb1);
let va2 = vld1q_f32(a.as_ptr().add(base + 8));
let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
d2 = vfmaq_f32(d2, va2, vb2);
n2 = vfmaq_f32(n2, vb2, vb2);
let va3 = vld1q_f32(a.as_ptr().add(base + 12));
let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
d3 = vfmaq_f32(d3, va3, vb3);
n3 = vfmaq_f32(n3, vb3, vb3);
}
let acc_dot = vaddq_f32(vaddq_f32(d0, d1), vaddq_f32(d2, d3));
let acc_norm = vaddq_f32(vaddq_f32(n0, n1), vaddq_f32(n2, n3));
let mut dot = vaddvq_f32(acc_dot);
let mut norm = vaddvq_f32(acc_norm);
let base = chunks16 * 16;
for i in 0..remainder {
dot += a[base + i] * b[base + i];
norm += b[base + i] * b[base + i];
}
(dot, norm)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_avx2(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
use std::arch::x86_64::*;
let chunks32 = count / 32;
let remainder = count % 32;
let mut d0 = _mm256_setzero_ps();
let mut d1 = _mm256_setzero_ps();
let mut d2 = _mm256_setzero_ps();
let mut d3 = _mm256_setzero_ps();
let mut n0 = _mm256_setzero_ps();
let mut n1 = _mm256_setzero_ps();
let mut n2 = _mm256_setzero_ps();
let mut n3 = _mm256_setzero_ps();
for c in 0..chunks32 {
let base = c * 32;
let vb0 = _mm256_loadu_ps(b.as_ptr().add(base));
d0 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base)), vb0, d0);
n0 = _mm256_fmadd_ps(vb0, vb0, n0);
let vb1 = _mm256_loadu_ps(b.as_ptr().add(base + 8));
d1 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 8)), vb1, d1);
n1 = _mm256_fmadd_ps(vb1, vb1, n1);
let vb2 = _mm256_loadu_ps(b.as_ptr().add(base + 16));
d2 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 16)), vb2, d2);
n2 = _mm256_fmadd_ps(vb2, vb2, n2);
let vb3 = _mm256_loadu_ps(b.as_ptr().add(base + 24));
d3 = _mm256_fmadd_ps(_mm256_loadu_ps(a.as_ptr().add(base + 24)), vb3, d3);
n3 = _mm256_fmadd_ps(vb3, vb3, n3);
}
let acc_dot = _mm256_add_ps(_mm256_add_ps(d0, d1), _mm256_add_ps(d2, d3));
let acc_norm = _mm256_add_ps(_mm256_add_ps(n0, n1), _mm256_add_ps(n2, n3));
let hi_d = _mm256_extractf128_ps::<1>(acc_dot);
let lo_d = _mm256_castps256_ps128(acc_dot);
let sum_d = _mm_add_ps(lo_d, hi_d);
let shuf_d = _mm_shuffle_ps::<0b10_11_00_01>(sum_d, sum_d);
let sums_d = _mm_add_ps(sum_d, shuf_d);
let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
let hi_n = _mm256_extractf128_ps::<1>(acc_norm);
let lo_n = _mm256_castps256_ps128(acc_norm);
let sum_n = _mm_add_ps(lo_n, hi_n);
let shuf_n = _mm_shuffle_ps::<0b10_11_00_01>(sum_n, sum_n);
let sums_n = _mm_add_ps(sum_n, shuf_n);
let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
let base = chunks32 * 32;
for i in 0..remainder {
dot += a[base + i] * b[base + i];
norm += b[base + i] * b[base + i];
}
(dot, norm)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_sse(a: &[f32], b: &[f32], count: usize) -> (f32, f32) {
use std::arch::x86_64::*;
let chunks = count / 4;
let remainder = count % 4;
let mut acc_dot = _mm_setzero_ps();
let mut acc_norm = _mm_setzero_ps();
for chunk in 0..chunks {
let base = chunk * 4;
let va = _mm_loadu_ps(a.as_ptr().add(base));
let vb = _mm_loadu_ps(b.as_ptr().add(base));
acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
}
let shuf_d = _mm_shuffle_ps::<0b10_11_00_01>(acc_dot, acc_dot);
let sums_d = _mm_add_ps(acc_dot, shuf_d);
let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
let final_d = _mm_add_ss(sums_d, shuf2_d);
let mut dot = _mm_cvtss_f32(final_d);
let shuf_n = _mm_shuffle_ps::<0b10_11_00_01>(acc_norm, acc_norm);
let sums_n = _mm_add_ps(acc_norm, shuf_n);
let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
let final_n = _mm_add_ss(sums_n, shuf2_n);
let mut norm = _mm_cvtss_f32(final_n);
let base = chunks * 4;
for i in 0..remainder {
dot += a[base + i] * b[base + i];
norm += b[base + i] * b[base + i];
}
(dot, norm)
}
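/// Fast approximate `1 / sqrt(x)`: the classic bit-level initial guess
/// (constant `0x5F37_5A86`) refined by two Newton-Raphson steps, which is
/// roughly ppm-accurate and sufficient for score normalisation.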
#[inline]
pub fn fast_inv_sqrt(x: f32) -> f32 {
let half = 0.5 * x;
let i = 0x5F37_5A86_u32.wrapping_sub(x.to_bits() >> 1);
let y = f32::from_bits(i);
let y = y * (1.5 - half * y * y);
y * (1.5 - half * y * y)
}
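/// Cosine similarity of `query` against `n` contiguous `dim`-length rows of
/// `vectors`, written into `scores`. Each row uses a fused dot-product/norm
/// pass and `fast_inv_sqrt` for normalisation; zero-norm inputs score 0.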
#[inline]
pub fn batch_cosine_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
let n = scores.len();
debug_assert!(vectors.len() >= n * dim);
debug_assert_eq!(query.len(), dim);
if dim == 0 || n == 0 {
return;
}
let norm_q_sq = dot_product_f32(query, query, dim);
if norm_q_sq < f32::EPSILON {
for s in scores.iter_mut() {
*s = 0.0;
}
return;
}
let inv_norm_q = fast_inv_sqrt(norm_q_sq);
for i in 0..n {
let vec = &vectors[i * dim..(i + 1) * dim];
let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
if norm_v_sq < f32::EPSILON {
scores[i] = 0.0;
} else {
scores[i] = dot * inv_norm_q * fast_inv_sqrt(norm_v_sq);
}
}
}
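/// Converts IEEE 754 binary32 to binary16 with round-to-nearest-even,
/// handling subnormal results, overflow to infinity, and NaN/Inf inputs.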
#[inline]
pub fn f32_to_f16(value: f32) -> u16 {
let bits = value.to_bits();
let sign = (bits >> 16) & 0x8000;
let exp = ((bits >> 23) & 0xFF) as i32;
let mantissa = bits & 0x7F_FFFF;
if exp == 255 {
return (sign | 0x7C00 | ((mantissa >> 13) & 0x3FF)) as u16;
}
let exp16 = exp - 127 + 15;
if exp16 >= 31 {
return (sign | 0x7C00) as u16;
}
if exp16 <= 0 {
if exp16 < -10 {
return sign as u16;
}
let shift = (1 - exp16) as u32;
let m = (mantissa | 0x80_0000) >> shift;
let round_bit = (m >> 12) & 1;
let sticky = m & 0xFFF;
let m13 = m >> 13;
let rounded = m13 + (round_bit & (m13 | if sticky != 0 { 1 } else { 0 }));
return (sign | rounded) as u16;
}
let round_bit = (mantissa >> 12) & 1;
let sticky = mantissa & 0xFFF;
let m13 = mantissa >> 13;
let rounded = m13 + (round_bit & (m13 | if sticky != 0 { 1 } else { 0 }));
if rounded > 0x3FF {
let exp16_inc = exp16 as u32 + 1;
if exp16_inc >= 31 {
return (sign | 0x7C00) as u16;
}
(sign | (exp16_inc << 10)) as u16
} else {
(sign | ((exp16 as u32) << 10) | rounded) as u16
}
}
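/// Converts IEEE 754 binary16 back to binary32, renormalising subnormal halves.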
#[inline]
pub fn f16_to_f32(half: u16) -> f32 {
let sign = ((half & 0x8000) as u32) << 16;
let exp = ((half >> 10) & 0x1F) as u32;
let mantissa = (half & 0x3FF) as u32;
if exp == 0 {
if mantissa == 0 {
return f32::from_bits(sign);
}
let mut e = 0u32;
let mut m = mantissa;
while (m & 0x400) == 0 {
m <<= 1;
e += 1;
}
return f32::from_bits(sign | ((127 - 15 + 1 - e) << 23) | ((m & 0x3FF) << 13));
}
if exp == 31 {
return f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13));
}
f32::from_bits(sign | ((exp + 127 - 15) << 23) | (mantissa << 13))
}
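// u8 quantization maps the clamped range [-1.0, 1.0] linearly onto [0, 255].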
const U8_SCALE: f32 = 127.5;
const U8_INV_SCALE: f32 = 1.0 / 127.5;
#[inline]
pub fn f32_to_u8_saturating(value: f32) -> u8 {
((value.clamp(-1.0, 1.0) + 1.0) * U8_SCALE) as u8
}
#[inline]
pub fn u8_to_f32(byte: u8) -> f32 {
byte as f32 * U8_INV_SCALE - 1.0
}
pub fn batch_f32_to_f16(src: &[f32], dst: &mut [u16]) {
debug_assert_eq!(src.len(), dst.len());
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = f32_to_f16(*s);
}
}
pub fn batch_f32_to_u8(src: &[f32], dst: &mut [u8]) {
debug_assert_eq!(src.len(), dst.len());
for (s, d) in src.iter().zip(dst.iter_mut()) {
*d = f32_to_u8_saturating(*s);
}
}
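/// NEON kernels for scoring against quantized vectors (f16 and u8 encodings).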
#[cfg(target_arch = "aarch64")]
#[allow(unsafe_op_in_unsafe_fn)]
mod neon_quant {
use std::arch::aarch64::*;
#[allow(clippy::incompatible_msrv)]
#[target_feature(enable = "neon")]
pub unsafe fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
let chunks16 = dim / 16;
let remainder = dim % 16;
let mut acc_dot0 = vdupq_n_f32(0.0);
let mut acc_dot1 = vdupq_n_f32(0.0);
let mut acc_norm0 = vdupq_n_f32(0.0);
let mut acc_norm1 = vdupq_n_f32(0.0);
for c in 0..chunks16 {
let base = c * 16;
let v_raw0 = vld1q_u16(vec_f16.as_ptr().add(base));
let v_lo0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw0)));
let v_hi0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw0)));
let q_raw0 = vld1q_u16(query_f16.as_ptr().add(base));
let q_lo0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw0)));
let q_hi0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw0)));
acc_dot0 = vfmaq_f32(acc_dot0, q_lo0, v_lo0);
acc_dot0 = vfmaq_f32(acc_dot0, q_hi0, v_hi0);
acc_norm0 = vfmaq_f32(acc_norm0, v_lo0, v_lo0);
acc_norm0 = vfmaq_f32(acc_norm0, v_hi0, v_hi0);
let v_raw1 = vld1q_u16(vec_f16.as_ptr().add(base + 8));
let v_lo1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw1)));
let v_hi1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw1)));
let q_raw1 = vld1q_u16(query_f16.as_ptr().add(base + 8));
let q_lo1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw1)));
let q_hi1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw1)));
acc_dot1 = vfmaq_f32(acc_dot1, q_lo1, v_lo1);
acc_dot1 = vfmaq_f32(acc_dot1, q_hi1, v_hi1);
acc_norm1 = vfmaq_f32(acc_norm1, v_lo1, v_lo1);
acc_norm1 = vfmaq_f32(acc_norm1, v_hi1, v_hi1);
}
let mut dot = vaddvq_f32(vaddq_f32(acc_dot0, acc_dot1));
let mut norm = vaddvq_f32(vaddq_f32(acc_norm0, acc_norm1));
let base = chunks16 * 16;
for i in 0..remainder {
let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
dot += q * v;
norm += v * v;
}
(dot, norm)
}
#[target_feature(enable = "neon")]
pub unsafe fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
let scale = vdupq_n_f32(super::U8_INV_SCALE);
let offset = vdupq_n_f32(-1.0);
let chunks16 = dim / 16;
let remainder = dim % 16;
let mut acc_dot = vdupq_n_f32(0.0);
let mut acc_norm = vdupq_n_f32(0.0);
for c in 0..chunks16 {
let base = c * 16;
let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
let lo8 = vget_low_u8(bytes);
let hi8 = vget_high_u8(bytes);
let lo16 = vmovl_u8(lo8);
let hi16 = vmovl_u8(hi8);
let f0 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
offset,
);
let f1 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
offset,
);
let f2 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
offset,
);
let f3 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
offset,
);
let q0 = vld1q_f32(query.as_ptr().add(base));
let q1 = vld1q_f32(query.as_ptr().add(base + 4));
let q2 = vld1q_f32(query.as_ptr().add(base + 8));
let q3 = vld1q_f32(query.as_ptr().add(base + 12));
acc_dot = vfmaq_f32(acc_dot, q0, f0);
acc_dot = vfmaq_f32(acc_dot, q1, f1);
acc_dot = vfmaq_f32(acc_dot, q2, f2);
acc_dot = vfmaq_f32(acc_dot, q3, f3);
acc_norm = vfmaq_f32(acc_norm, f0, f0);
acc_norm = vfmaq_f32(acc_norm, f1, f1);
acc_norm = vfmaq_f32(acc_norm, f2, f2);
acc_norm = vfmaq_f32(acc_norm, f3, f3);
}
let mut dot = vaddvq_f32(acc_dot);
let mut norm = vaddvq_f32(acc_norm);
let base = chunks16 * 16;
for i in 0..remainder {
let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
dot += *query.get_unchecked(base + i) * v;
norm += v * v;
}
(dot, norm)
}
#[allow(clippy::incompatible_msrv)]
#[target_feature(enable = "neon")]
pub unsafe fn dot_product_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
let chunks8 = dim / 8;
let remainder = dim % 8;
let mut acc = vdupq_n_f32(0.0);
for c in 0..chunks8 {
let base = c * 8;
let v_raw = vld1q_u16(vec_f16.as_ptr().add(base));
let v_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(v_raw)));
let v_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(v_raw)));
let q_raw = vld1q_u16(query_f16.as_ptr().add(base));
let q_lo = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(q_raw)));
let q_hi = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(q_raw)));
acc = vfmaq_f32(acc, q_lo, v_lo);
acc = vfmaq_f32(acc, q_hi, v_hi);
}
let mut dot = vaddvq_f32(acc);
let base = chunks8 * 8;
for i in 0..remainder {
let v = super::f16_to_f32(*vec_f16.get_unchecked(base + i));
let q = super::f16_to_f32(*query_f16.get_unchecked(base + i));
dot += q * v;
}
dot
}
#[target_feature(enable = "neon")]
pub unsafe fn dot_product_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
let scale = vdupq_n_f32(super::U8_INV_SCALE);
let offset = vdupq_n_f32(-1.0);
let chunks16 = dim / 16;
let remainder = dim % 16;
let mut acc = vdupq_n_f32(0.0);
for c in 0..chunks16 {
let base = c * 16;
let bytes = vld1q_u8(vec_u8.as_ptr().add(base));
let lo8 = vget_low_u8(bytes);
let hi8 = vget_high_u8(bytes);
let lo16 = vmovl_u8(lo8);
let hi16 = vmovl_u8(hi8);
let f0 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(lo16))), scale),
offset,
);
let f1 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(lo16))), scale),
offset,
);
let f2 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_low_u16(hi16))), scale),
offset,
);
let f3 = vaddq_f32(
vmulq_f32(vcvtq_f32_u32(vmovl_u16(vget_high_u16(hi16))), scale),
offset,
);
let q0 = vld1q_f32(query.as_ptr().add(base));
let q1 = vld1q_f32(query.as_ptr().add(base + 4));
let q2 = vld1q_f32(query.as_ptr().add(base + 8));
let q3 = vld1q_f32(query.as_ptr().add(base + 12));
acc = vfmaq_f32(acc, q0, f0);
acc = vfmaq_f32(acc, q1, f1);
acc = vfmaq_f32(acc, q2, f2);
acc = vfmaq_f32(acc, q3, f3);
}
let mut dot = vaddvq_f32(acc);
let base = chunks16 * 16;
for i in 0..remainder {
let v = super::u8_to_f32(*vec_u8.get_unchecked(base + i));
dot += *query.get_unchecked(base + i) * v;
}
dot
}
}
#[allow(dead_code)]
fn fused_dot_norm_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
let mut dot = 0.0f32;
let mut norm = 0.0f32;
for i in 0..dim {
let v = f16_to_f32(vec_f16[i]);
let q = f16_to_f32(query_f16[i]);
dot += q * v;
norm += v * v;
}
(dot, norm)
}
#[allow(dead_code)]
fn fused_dot_norm_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
let mut dot = 0.0f32;
let mut norm = 0.0f32;
for i in 0..dim {
let v = u8_to_f32(vec_u8[i]);
dot += query[i] * v;
norm += v * v;
}
(dot, norm)
}
#[allow(dead_code)]
fn dot_product_f16_scalar(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
let mut dot = 0.0f32;
for i in 0..dim {
dot += f16_to_f32(query_f16[i]) * f16_to_f32(vec_f16[i]);
}
dot
}
#[allow(dead_code)]
fn dot_product_u8_scalar(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
let mut dot = 0.0f32;
for i in 0..dim {
dot += query[i] * u8_to_f32(vec_u8[i]);
}
dot
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_f16_sse(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
use std::arch::x86_64::*;
let chunks = dim / 4;
let remainder = dim % 4;
let mut acc_dot = _mm_setzero_ps();
let mut acc_norm = _mm_setzero_ps();
for chunk in 0..chunks {
let base = chunk * 4;
let v0 = f16_to_f32(*vec_f16.get_unchecked(base));
let v1 = f16_to_f32(*vec_f16.get_unchecked(base + 1));
let v2 = f16_to_f32(*vec_f16.get_unchecked(base + 2));
let v3 = f16_to_f32(*vec_f16.get_unchecked(base + 3));
let vb = _mm_set_ps(v3, v2, v1, v0);
let q0 = f16_to_f32(*query_f16.get_unchecked(base));
let q1 = f16_to_f32(*query_f16.get_unchecked(base + 1));
let q2 = f16_to_f32(*query_f16.get_unchecked(base + 2));
let q3 = f16_to_f32(*query_f16.get_unchecked(base + 3));
let va = _mm_set_ps(q3, q2, q1, q0);
acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
}
let shuf_d = _mm_shuffle_ps::<0b10_11_00_01>(acc_dot, acc_dot);
let sums_d = _mm_add_ps(acc_dot, shuf_d);
let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
let shuf_n = _mm_shuffle_ps::<0b10_11_00_01>(acc_norm, acc_norm);
let sums_n = _mm_add_ps(acc_norm, shuf_n);
let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
let base = chunks * 4;
for i in 0..remainder {
let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
let q = f16_to_f32(*query_f16.get_unchecked(base + i));
dot += q * v;
norm += v * v;
}
(dot, norm)
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
use std::arch::x86_64::*;
let scale = _mm_set1_ps(U8_INV_SCALE);
let offset = _mm_set1_ps(-1.0);
let chunks = dim / 4;
let remainder = dim % 4;
let mut acc_dot = _mm_setzero_ps();
let mut acc_norm = _mm_setzero_ps();
for chunk in 0..chunks {
let base = chunk * 4;
let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
vec_u8.as_ptr().add(base) as *const i32
));
let ints = _mm_cvtepu8_epi32(bytes);
let floats = _mm_cvtepi32_ps(ints);
let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
let va = _mm_loadu_ps(query.as_ptr().add(base));
acc_dot = _mm_add_ps(acc_dot, _mm_mul_ps(va, vb));
acc_norm = _mm_add_ps(acc_norm, _mm_mul_ps(vb, vb));
}
let shuf_d = _mm_shuffle_ps(acc_dot, acc_dot, 0b10_11_00_01);
let sums_d = _mm_add_ps(acc_dot, shuf_d);
let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
let shuf_n = _mm_shuffle_ps(acc_norm, acc_norm, 0b10_11_00_01);
let sums_n = _mm_add_ps(acc_norm, shuf_n);
let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
let base = chunks * 4;
for i in 0..remainder {
let v = u8_to_f32(*vec_u8.get_unchecked(base + i));
dot += *query.get_unchecked(base + i) * v;
norm += v * v;
}
(dot, norm)
}
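// AVX + F16C path: _mm256_cvtph_ps widens eight half-floats per load, and two
// independent dot/norm accumulator pairs per 16-element chunk keep the FMA units busy.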
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn fused_dot_norm_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
use std::arch::x86_64::*;
let chunks16 = dim / 16;
let remainder = dim % 16;
let mut acc_dot0 = _mm256_setzero_ps();
let mut acc_dot1 = _mm256_setzero_ps();
let mut acc_norm0 = _mm256_setzero_ps();
let mut acc_norm1 = _mm256_setzero_ps();
for c in 0..chunks16 {
let base = c * 16;
let v_raw0 = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
let vb0 = _mm256_cvtph_ps(v_raw0);
let q_raw0 = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
let qa0 = _mm256_cvtph_ps(q_raw0);
acc_dot0 = _mm256_fmadd_ps(qa0, vb0, acc_dot0);
acc_norm0 = _mm256_fmadd_ps(vb0, vb0, acc_norm0);
let v_raw1 = _mm_loadu_si128(vec_f16.as_ptr().add(base + 8) as *const __m128i);
let vb1 = _mm256_cvtph_ps(v_raw1);
let q_raw1 = _mm_loadu_si128(query_f16.as_ptr().add(base + 8) as *const __m128i);
let qa1 = _mm256_cvtph_ps(q_raw1);
acc_dot1 = _mm256_fmadd_ps(qa1, vb1, acc_dot1);
acc_norm1 = _mm256_fmadd_ps(vb1, vb1, acc_norm1);
}
let acc_dot = _mm256_add_ps(acc_dot0, acc_dot1);
let acc_norm = _mm256_add_ps(acc_norm0, acc_norm1);
let hi_d = _mm256_extractf128_ps(acc_dot, 1);
let lo_d = _mm256_castps256_ps128(acc_dot);
let sum_d = _mm_add_ps(lo_d, hi_d);
let shuf_d = _mm_shuffle_ps(sum_d, sum_d, 0b10_11_00_01);
let sums_d = _mm_add_ps(sum_d, shuf_d);
let shuf2_d = _mm_movehl_ps(sums_d, sums_d);
let mut dot = _mm_cvtss_f32(_mm_add_ss(sums_d, shuf2_d));
let hi_n = _mm256_extractf128_ps(acc_norm, 1);
let lo_n = _mm256_castps256_ps128(acc_norm);
let sum_n = _mm_add_ps(lo_n, hi_n);
let shuf_n = _mm_shuffle_ps(sum_n, sum_n, 0b10_11_00_01);
let sums_n = _mm_add_ps(sum_n, shuf_n);
let shuf2_n = _mm_movehl_ps(sums_n, sums_n);
let mut norm = _mm_cvtss_f32(_mm_add_ss(sums_n, shuf2_n));
let base = chunks16 * 16;
for i in 0..remainder {
let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
let q = f16_to_f32(*query_f16.get_unchecked(base + i));
dot += q * v;
norm += v * v;
}
(dot, norm)
}
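// Dot-only F16C variant (no norm accumulation) used by the dot-product score paths.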
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx", enable = "f16c", enable = "fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_f16_f16c(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
use std::arch::x86_64::*;
let chunks = dim / 8;
let remainder = dim % 8;
let mut acc = _mm256_setzero_ps();
for chunk in 0..chunks {
let base = chunk * 8;
let v_raw = _mm_loadu_si128(vec_f16.as_ptr().add(base) as *const __m128i);
let vb = _mm256_cvtph_ps(v_raw);
let q_raw = _mm_loadu_si128(query_f16.as_ptr().add(base) as *const __m128i);
let qa = _mm256_cvtph_ps(q_raw);
acc = _mm256_fmadd_ps(qa, vb, acc);
}
let hi = _mm256_extractf128_ps(acc, 1);
let lo = _mm256_castps256_ps128(acc);
let sum = _mm_add_ps(lo, hi);
let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
let sums = _mm_add_ps(sum, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
let base = chunks * 8;
for i in 0..remainder {
let v = f16_to_f32(*vec_f16.get_unchecked(base + i));
let q = f16_to_f32(*query_f16.get_unchecked(base + i));
dot += q * v;
}
dot
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2", enable = "sse4.1")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn dot_product_u8_sse(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
use std::arch::x86_64::*;
let scale = _mm_set1_ps(U8_INV_SCALE);
let offset = _mm_set1_ps(-1.0);
let chunks = dim / 4;
let remainder = dim % 4;
let mut acc = _mm_setzero_ps();
for chunk in 0..chunks {
let base = chunk * 4;
let bytes = _mm_cvtsi32_si128(std::ptr::read_unaligned(
vec_u8.as_ptr().add(base) as *const i32
));
let ints = _mm_cvtepu8_epi32(bytes);
let floats = _mm_cvtepi32_ps(ints);
let vb = _mm_add_ps(_mm_mul_ps(floats, scale), offset);
let va = _mm_loadu_ps(query.as_ptr().add(base));
acc = _mm_add_ps(acc, _mm_mul_ps(va, vb));
}
let shuf = _mm_shuffle_ps(acc, acc, 0b10_11_00_01);
let sums = _mm_add_ps(acc, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let mut dot = _mm_cvtss_f32(_mm_add_ss(sums, shuf2));
let base = chunks * 4;
for i in 0..remainder {
dot += *query.get_unchecked(base + i) * u8_to_f32(*vec_u8.get_unchecked(base + i));
}
dot
}
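// Runtime dispatchers: aarch64 always takes the NEON kernels, x86_64 prefers
// F16C + FMA where it applies, then SSE, and everything else falls back to scalar.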
#[inline]
fn fused_dot_norm_f16(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> (f32, f32) {
#[cfg(target_arch = "aarch64")]
{
return unsafe { neon_quant::fused_dot_norm_f16(query_f16, vec_f16, dim) };
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
return unsafe { fused_dot_norm_f16_f16c(query_f16, vec_f16, dim) };
}
if sse::is_available() {
return unsafe { fused_dot_norm_f16_sse(query_f16, vec_f16, dim) };
}
}
#[allow(unreachable_code)]
fused_dot_norm_f16_scalar(query_f16, vec_f16, dim)
}
#[inline]
fn fused_dot_norm_u8(query: &[f32], vec_u8: &[u8], dim: usize) -> (f32, f32) {
#[cfg(target_arch = "aarch64")]
{
return unsafe { neon_quant::fused_dot_norm_u8(query, vec_u8, dim) };
}
#[cfg(target_arch = "x86_64")]
{
if sse::is_available() {
return unsafe { fused_dot_norm_u8_sse(query, vec_u8, dim) };
}
}
#[allow(unreachable_code)]
fused_dot_norm_u8_scalar(query, vec_u8, dim)
}
#[inline]
fn dot_product_f16_quant(query_f16: &[u16], vec_f16: &[u16], dim: usize) -> f32 {
#[cfg(target_arch = "aarch64")]
{
return unsafe { neon_quant::dot_product_f16(query_f16, vec_f16, dim) };
}
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("f16c") && is_x86_feature_detected!("fma") {
return unsafe { dot_product_f16_f16c(query_f16, vec_f16, dim) };
}
}
#[allow(unreachable_code)]
dot_product_f16_scalar(query_f16, vec_f16, dim)
}
#[inline]
fn dot_product_u8_quant(query: &[f32], vec_u8: &[u8], dim: usize) -> f32 {
#[cfg(target_arch = "aarch64")]
{
return unsafe { neon_quant::dot_product_u8(query, vec_u8, dim) };
}
#[cfg(target_arch = "x86_64")]
{
if sse::is_available() {
return unsafe { dot_product_u8_sse(query, vec_u8, dim) };
}
}
#[allow(unreachable_code)]
dot_product_u8_scalar(query, vec_u8, dim)
}
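// Batch cosine scores against f16-quantized vectors stored as raw little-endian bytes
// (dim * 2 bytes per row). The query is converted to f16 once so every row can reuse
// the same f16 dot/norm kernel; each score is dot / (|q| * |v|), with 0.0 for
// degenerate near-zero vectors.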
#[inline]
pub fn batch_cosine_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
let n = scores.len();
if dim == 0 || n == 0 {
return;
}
let norm_q_sq = dot_product_f32(query, query, dim);
if norm_q_sq < f32::EPSILON {
for s in scores.iter_mut() {
*s = 0.0;
}
return;
}
let inv_norm_q = fast_inv_sqrt(norm_q_sq);
let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
let vec_bytes = dim * 2;
debug_assert!(vectors_raw.len() >= n * vec_bytes);
debug_assert!(
(vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
"f16 vector data not 2-byte aligned"
);
for i in 0..n {
let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
let (dot, norm_v_sq) = fused_dot_norm_f16(&query_f16, f16_slice, dim);
scores[i] = if norm_v_sq < f32::EPSILON {
0.0
} else {
dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
};
}
}
#[inline]
pub fn batch_cosine_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
let n = scores.len();
if dim == 0 || n == 0 {
return;
}
let norm_q_sq = dot_product_f32(query, query, dim);
if norm_q_sq < f32::EPSILON {
for s in scores.iter_mut() {
*s = 0.0;
}
return;
}
let inv_norm_q = fast_inv_sqrt(norm_q_sq);
debug_assert!(vectors_raw.len() >= n * dim);
for i in 0..n {
let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
scores[i] = if norm_v_sq < f32::EPSILON {
0.0
} else {
dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
};
}
}
#[inline]
pub fn batch_dot_scores(query: &[f32], vectors: &[f32], dim: usize, scores: &mut [f32]) {
let n = scores.len();
debug_assert!(vectors.len() >= n * dim);
debug_assert_eq!(query.len(), dim);
if dim == 0 || n == 0 {
return;
}
let norm_q_sq = dot_product_f32(query, query, dim);
if norm_q_sq < f32::EPSILON {
for s in scores.iter_mut() {
*s = 0.0;
}
return;
}
let inv_norm_q = fast_inv_sqrt(norm_q_sq);
for i in 0..n {
let vec = &vectors[i * dim..(i + 1) * dim];
let dot = dot_product_f32(query, vec, dim);
scores[i] = dot * inv_norm_q;
}
}
#[inline]
pub fn batch_dot_scores_f16(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
let n = scores.len();
if dim == 0 || n == 0 {
return;
}
let norm_q_sq = dot_product_f32(query, query, dim);
if norm_q_sq < f32::EPSILON {
for s in scores.iter_mut() {
*s = 0.0;
}
return;
}
let inv_norm_q = fast_inv_sqrt(norm_q_sq);
let query_f16: Vec<u16> = query.iter().map(|&v| f32_to_f16(v)).collect();
let vec_bytes = dim * 2;
debug_assert!(vectors_raw.len() >= n * vec_bytes);
debug_assert!(
(vectors_raw.as_ptr() as usize).is_multiple_of(std::mem::align_of::<u16>()),
"f16 vector data not 2-byte aligned"
);
for i in 0..n {
let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
let dot = dot_product_f16_quant(&query_f16, f16_slice, dim);
scores[i] = dot * inv_norm_q;
}
}
#[inline]
pub fn batch_dot_scores_u8(query: &[f32], vectors_raw: &[u8], dim: usize, scores: &mut [f32]) {
let n = scores.len();
if dim == 0 || n == 0 {
return;
}
let norm_q_sq = dot_product_f32(query, query, dim);
if norm_q_sq < f32::EPSILON {
for s in scores.iter_mut() {
*s = 0.0;
}
return;
}
let inv_norm_q = fast_inv_sqrt(norm_q_sq);
debug_assert!(vectors_raw.len() >= n * dim);
for i in 0..n {
let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
let dot = dot_product_u8_quant(query, u8_slice, dim);
scores[i] = dot * inv_norm_q;
}
}
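// *_precomp variants: the caller supplies a precomputed 1/|q| (and, for f16, an
// already-converted query), so the query-side work is not repeated on every call.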
#[inline]
pub fn batch_cosine_scores_precomp(
query: &[f32],
vectors: &[f32],
dim: usize,
scores: &mut [f32],
inv_norm_q: f32,
) {
let n = scores.len();
debug_assert!(vectors.len() >= n * dim);
for i in 0..n {
let vec = &vectors[i * dim..(i + 1) * dim];
let (dot, norm_v_sq) = fused_dot_norm(query, vec, dim);
scores[i] = if norm_v_sq < f32::EPSILON {
0.0
} else {
dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
};
}
}
#[inline]
pub fn batch_cosine_scores_f16_precomp(
query_f16: &[u16],
vectors_raw: &[u8],
dim: usize,
scores: &mut [f32],
inv_norm_q: f32,
) {
let n = scores.len();
let vec_bytes = dim * 2;
debug_assert!(vectors_raw.len() >= n * vec_bytes);
for i in 0..n {
let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
let (dot, norm_v_sq) = fused_dot_norm_f16(query_f16, f16_slice, dim);
scores[i] = if norm_v_sq < f32::EPSILON {
0.0
} else {
dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
};
}
}
#[inline]
pub fn batch_cosine_scores_u8_precomp(
query: &[f32],
vectors_raw: &[u8],
dim: usize,
scores: &mut [f32],
inv_norm_q: f32,
) {
let n = scores.len();
debug_assert!(vectors_raw.len() >= n * dim);
for i in 0..n {
let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
let (dot, norm_v_sq) = fused_dot_norm_u8(query, u8_slice, dim);
scores[i] = if norm_v_sq < f32::EPSILON {
0.0
} else {
dot * inv_norm_q * fast_inv_sqrt(norm_v_sq)
};
}
}
#[inline]
pub fn batch_dot_scores_precomp(
query: &[f32],
vectors: &[f32],
dim: usize,
scores: &mut [f32],
inv_norm_q: f32,
) {
let n = scores.len();
debug_assert!(vectors.len() >= n * dim);
for i in 0..n {
let vec = &vectors[i * dim..(i + 1) * dim];
scores[i] = dot_product_f32(query, vec, dim) * inv_norm_q;
}
}
#[inline]
pub fn batch_dot_scores_f16_precomp(
query_f16: &[u16],
vectors_raw: &[u8],
dim: usize,
scores: &mut [f32],
inv_norm_q: f32,
) {
let n = scores.len();
let vec_bytes = dim * 2;
debug_assert!(vectors_raw.len() >= n * vec_bytes);
for i in 0..n {
let raw = &vectors_raw[i * vec_bytes..(i + 1) * vec_bytes];
let f16_slice = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const u16, dim) };
scores[i] = dot_product_f16_quant(query_f16, f16_slice, dim) * inv_norm_q;
}
}
#[inline]
pub fn batch_dot_scores_u8_precomp(
query: &[f32],
vectors_raw: &[u8],
dim: usize,
scores: &mut [f32],
inv_norm_q: f32,
) {
let n = scores.len();
debug_assert!(vectors_raw.len() >= n * dim);
for i in 0..n {
let u8_slice = &vectors_raw[i * dim..(i + 1) * dim];
scores[i] = dot_product_u8_quant(query, u8_slice, dim) * inv_norm_q;
}
}
#[inline]
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
let count = a.len();
if count == 0 {
return 0.0;
}
let dot = dot_product_f32(a, b, count);
let norm_a = dot_product_f32(a, a, count);
let norm_b = dot_product_f32(b, b, count);
let denom = (norm_a * norm_b).sqrt();
if denom < f32::EPSILON {
return 0.0;
}
dot / denom
}
#[inline]
pub fn squared_euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
let count = a.len();
if count == 0 {
return 0.0;
}
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
return unsafe { squared_euclidean_neon(a, b, count) };
}
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() && is_x86_feature_detected!("fma") {
return unsafe { squared_euclidean_avx2(a, b, count) };
}
if sse::is_available() {
return unsafe { squared_euclidean_sse(a, b, count) };
}
}
a.iter()
.zip(b.iter())
.map(|(&x, &y)| {
let d = x - y;
d * d
})
.sum()
}
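// Squared Euclidean distance kernels: four independent accumulators per chunk break
// the dependency chain on the fused multiply-adds; the remainder is handled scalar.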
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn squared_euclidean_neon(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::aarch64::*;
let chunks16 = count / 16;
let remainder = count % 16;
let mut acc0 = vdupq_n_f32(0.0);
let mut acc1 = vdupq_n_f32(0.0);
let mut acc2 = vdupq_n_f32(0.0);
let mut acc3 = vdupq_n_f32(0.0);
for c in 0..chunks16 {
let base = c * 16;
let va0 = vld1q_f32(a.as_ptr().add(base));
let vb0 = vld1q_f32(b.as_ptr().add(base));
let d0 = vsubq_f32(va0, vb0);
acc0 = vfmaq_f32(acc0, d0, d0);
let va1 = vld1q_f32(a.as_ptr().add(base + 4));
let vb1 = vld1q_f32(b.as_ptr().add(base + 4));
let d1 = vsubq_f32(va1, vb1);
acc1 = vfmaq_f32(acc1, d1, d1);
let va2 = vld1q_f32(a.as_ptr().add(base + 8));
let vb2 = vld1q_f32(b.as_ptr().add(base + 8));
let d2 = vsubq_f32(va2, vb2);
acc2 = vfmaq_f32(acc2, d2, d2);
let va3 = vld1q_f32(a.as_ptr().add(base + 12));
let vb3 = vld1q_f32(b.as_ptr().add(base + 12));
let d3 = vsubq_f32(va3, vb3);
acc3 = vfmaq_f32(acc3, d3, d3);
}
let combined = vaddq_f32(vaddq_f32(acc0, acc1), vaddq_f32(acc2, acc3));
let mut sum = vaddvq_f32(combined);
let base = chunks16 * 16;
for i in 0..remainder {
let d = a[base + i] - b[base + i];
sum += d * d;
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn squared_euclidean_sse(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::x86_64::*;
let chunks16 = count / 16;
let remainder = count % 16;
let mut acc0 = _mm_setzero_ps();
let mut acc1 = _mm_setzero_ps();
let mut acc2 = _mm_setzero_ps();
let mut acc3 = _mm_setzero_ps();
for c in 0..chunks16 {
let base = c * 16;
let d0 = _mm_sub_ps(
_mm_loadu_ps(a.as_ptr().add(base)),
_mm_loadu_ps(b.as_ptr().add(base)),
);
acc0 = _mm_add_ps(acc0, _mm_mul_ps(d0, d0));
let d1 = _mm_sub_ps(
_mm_loadu_ps(a.as_ptr().add(base + 4)),
_mm_loadu_ps(b.as_ptr().add(base + 4)),
);
acc1 = _mm_add_ps(acc1, _mm_mul_ps(d1, d1));
let d2 = _mm_sub_ps(
_mm_loadu_ps(a.as_ptr().add(base + 8)),
_mm_loadu_ps(b.as_ptr().add(base + 8)),
);
acc2 = _mm_add_ps(acc2, _mm_mul_ps(d2, d2));
let d3 = _mm_sub_ps(
_mm_loadu_ps(a.as_ptr().add(base + 12)),
_mm_loadu_ps(b.as_ptr().add(base + 12)),
);
acc3 = _mm_add_ps(acc3, _mm_mul_ps(d3, d3));
}
let combined = _mm_add_ps(_mm_add_ps(acc0, acc1), _mm_add_ps(acc2, acc3));
let shuf = _mm_shuffle_ps(combined, combined, 0b10_11_00_01);
let sums = _mm_add_ps(combined, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let final_sum = _mm_add_ss(sums, shuf2);
let mut sum = _mm_cvtss_f32(final_sum);
let base = chunks16 * 16;
for i in 0..remainder {
let d = a[base + i] - b[base + i];
sum += d * d;
}
sum
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn squared_euclidean_avx2(a: &[f32], b: &[f32], count: usize) -> f32 {
use std::arch::x86_64::*;
let chunks32 = count / 32;
let remainder = count % 32;
let mut acc0 = _mm256_setzero_ps();
let mut acc1 = _mm256_setzero_ps();
let mut acc2 = _mm256_setzero_ps();
let mut acc3 = _mm256_setzero_ps();
for c in 0..chunks32 {
let base = c * 32;
let d0 = _mm256_sub_ps(
_mm256_loadu_ps(a.as_ptr().add(base)),
_mm256_loadu_ps(b.as_ptr().add(base)),
);
acc0 = _mm256_fmadd_ps(d0, d0, acc0);
let d1 = _mm256_sub_ps(
_mm256_loadu_ps(a.as_ptr().add(base + 8)),
_mm256_loadu_ps(b.as_ptr().add(base + 8)),
);
acc1 = _mm256_fmadd_ps(d1, d1, acc1);
let d2 = _mm256_sub_ps(
_mm256_loadu_ps(a.as_ptr().add(base + 16)),
_mm256_loadu_ps(b.as_ptr().add(base + 16)),
);
acc2 = _mm256_fmadd_ps(d2, d2, acc2);
let d3 = _mm256_sub_ps(
_mm256_loadu_ps(a.as_ptr().add(base + 24)),
_mm256_loadu_ps(b.as_ptr().add(base + 24)),
);
acc3 = _mm256_fmadd_ps(d3, d3, acc3);
}
let combined = _mm256_add_ps(_mm256_add_ps(acc0, acc1), _mm256_add_ps(acc2, acc3));
let high = _mm256_extractf128_ps(combined, 1);
let low = _mm256_castps256_ps128(combined);
let sum128 = _mm_add_ps(low, high);
let shuf = _mm_shuffle_ps(sum128, sum128, 0b10_11_00_01);
let sums = _mm_add_ps(sum128, shuf);
let shuf2 = _mm_movehl_ps(sums, sums);
let final_sum = _mm_add_ss(sums, shuf2);
let mut sum = _mm_cvtss_f32(final_sum);
let base = chunks32 * 32;
for i in 0..remainder {
let d = a[base + i] - b[base + i];
sum += d * d;
}
sum
}
#[inline]
pub fn batch_squared_euclidean_distances(
query: &[f32],
vectors: &[Vec<f32>],
distances: &mut [f32],
) {
debug_assert_eq!(vectors.len(), distances.len());
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() && is_x86_feature_detected!("fma") {
for (i, vec) in vectors.iter().enumerate() {
debug_assert_eq!(vec.len(), query.len());
distances[i] = unsafe { squared_euclidean_avx2(query, vec, query.len()) };
}
return;
}
}
for (i, vec) in vectors.iter().enumerate() {
distances[i] = squared_euclidean_distance(query, vec);
}
}
#[inline]
pub fn hamming_distance(a: &[u8], b: &[u8]) -> u32 {
debug_assert_eq!(a.len(), b.len());
#[cfg(target_arch = "aarch64")]
unsafe {
neon::hamming_distance(a, b)
}
#[cfg(target_arch = "x86_64")]
{
if avx2::is_available() {
return unsafe { avx2::hamming_distance(a, b) };
}
hamming_distance_scalar(a, b)
}
#[cfg(not(any(target_arch = "aarch64", target_arch = "x86_64")))]
hamming_distance_scalar(a, b)
}
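// Scalar Hamming fallback: XOR eight bytes at a time through unaligned u64 reads and
// count_ones, then finish the tail byte by byte.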
#[inline]
#[allow(dead_code)]
fn hamming_distance_scalar(a: &[u8], b: &[u8]) -> u32 {
let len = a.len();
let chunks = len / 8;
let remainder = len % 8;
let mut total = 0u32;
for i in 0..chunks {
let off = i * 8;
let va = unsafe { std::ptr::read_unaligned(a.as_ptr().add(off) as *const u64) };
let vb = unsafe { std::ptr::read_unaligned(b.as_ptr().add(off) as *const u64) };
total += (va ^ vb).count_ones();
}
let base = chunks * 8;
for i in 0..remainder {
total += (a[base + i] ^ b[base + i]).count_ones();
}
total
}
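// Converts Hamming distances to similarity scores: score = 1 - distance / dim_bits,
// so identical codes score 1.0 and (assuming dim_bits matches the code length in
// bits) fully complementary codes score 0.0.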
pub fn batch_hamming_scores(
query: &[u8],
db: &[u8],
byte_len: usize,
dim_bits: usize,
scores: &mut [f32],
) {
let n = scores.len();
debug_assert_eq!(query.len(), byte_len);
debug_assert!(db.len() >= n * byte_len);
if byte_len == 0 || n == 0 || dim_bits == 0 {
return;
}
let inv_dim = 1.0 / dim_bits as f32;
for i in 0..n {
let vec = &db[i * byte_len..(i + 1) * byte_len];
let dist = hamming_distance(query, vec);
scores[i] = 1.0 - dist as f32 * inv_dim;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_unpack_8bit() {
let input: Vec<u8> = (0..128).collect();
let mut output = vec![0u32; 128];
unpack_8bit(&input, &mut output, 128);
for (i, &v) in output.iter().enumerate() {
assert_eq!(v, i as u32);
}
}
#[test]
fn test_unpack_16bit() {
let mut input = vec![0u8; 256];
for i in 0..128 {
let val = (i * 100) as u16;
input[i * 2] = val as u8;
input[i * 2 + 1] = (val >> 8) as u8;
}
let mut output = vec![0u32; 128];
unpack_16bit(&input, &mut output, 128);
for (i, &v) in output.iter().enumerate() {
assert_eq!(v, (i * 100) as u32);
}
}
#[test]
fn test_unpack_32bit() {
let mut input = vec![0u8; 512];
for i in 0..128 {
let val = (i * 1000) as u32;
let bytes = val.to_le_bytes();
input[i * 4..i * 4 + 4].copy_from_slice(&bytes);
}
let mut output = vec![0u32; 128];
unpack_32bit(&input, &mut output, 128);
for (i, &v) in output.iter().enumerate() {
assert_eq!(v, (i * 1000) as u32);
}
}
#[test]
fn test_delta_decode() {
let deltas = vec![4u32, 4, 9, 19];
let mut output = vec![0u32; 5];
delta_decode(&mut output, &deltas, 10, 5);
assert_eq!(output, vec![10, 15, 20, 30, 50]);
}
#[test]
fn test_add_one() {
let mut values = vec![0u32, 1, 2, 3, 4, 5, 6, 7];
add_one(&mut values, 8);
assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8]);
}
#[test]
fn test_bits_needed() {
assert_eq!(bits_needed(0), 0);
assert_eq!(bits_needed(1), 1);
assert_eq!(bits_needed(2), 2);
assert_eq!(bits_needed(3), 2);
assert_eq!(bits_needed(4), 3);
assert_eq!(bits_needed(255), 8);
assert_eq!(bits_needed(256), 9);
assert_eq!(bits_needed(u32::MAX), 32);
}
#[test]
fn test_unpack_8bit_delta_decode() {
let input: Vec<u8> = vec![4, 4, 9, 19];
let mut output = vec![0u32; 5];
unpack_8bit_delta_decode(&input, &mut output, 10, 5);
assert_eq!(output, vec![10, 15, 20, 30, 50]);
}
#[test]
fn test_unpack_16bit_delta_decode() {
let mut input = vec![0u8; 8];
for (i, &delta) in [499u16, 499, 999, 1999].iter().enumerate() {
input[i * 2] = delta as u8;
input[i * 2 + 1] = (delta >> 8) as u8;
}
let mut output = vec![0u32; 5];
unpack_16bit_delta_decode(&input, &mut output, 100, 5);
assert_eq!(output, vec![100, 600, 1100, 2100, 4100]);
}
#[test]
fn test_fused_vs_separate_8bit() {
let input: Vec<u8> = (0..127).collect();
let first_value = 1000u32;
let count = 128;
let mut unpacked = vec![0u32; 128];
unpack_8bit(&input, &mut unpacked, 127);
let mut separate_output = vec![0u32; 128];
delta_decode(&mut separate_output, &unpacked, first_value, count);
let mut fused_output = vec![0u32; 128];
unpack_8bit_delta_decode(&input, &mut fused_output, first_value, count);
assert_eq!(separate_output, fused_output);
}
#[test]
fn test_round_bit_width() {
assert_eq!(round_bit_width(0), 0);
assert_eq!(round_bit_width(1), 8);
assert_eq!(round_bit_width(5), 8);
assert_eq!(round_bit_width(8), 8);
assert_eq!(round_bit_width(9), 16);
assert_eq!(round_bit_width(12), 16);
assert_eq!(round_bit_width(16), 16);
assert_eq!(round_bit_width(17), 32);
assert_eq!(round_bit_width(24), 32);
assert_eq!(round_bit_width(32), 32);
}
#[test]
fn test_rounded_bitwidth_from_exact() {
assert_eq!(RoundedBitWidth::from_exact(0), RoundedBitWidth::Zero);
assert_eq!(RoundedBitWidth::from_exact(1), RoundedBitWidth::Bits8);
assert_eq!(RoundedBitWidth::from_exact(8), RoundedBitWidth::Bits8);
assert_eq!(RoundedBitWidth::from_exact(9), RoundedBitWidth::Bits16);
assert_eq!(RoundedBitWidth::from_exact(16), RoundedBitWidth::Bits16);
assert_eq!(RoundedBitWidth::from_exact(17), RoundedBitWidth::Bits32);
assert_eq!(RoundedBitWidth::from_exact(32), RoundedBitWidth::Bits32);
}
#[test]
fn test_pack_unpack_rounded_8bit() {
let values: Vec<u32> = (0..128).map(|i| i % 256).collect();
let mut packed = vec![0u8; 128];
let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits8, &mut packed);
assert_eq!(bytes_written, 128);
let mut unpacked = vec![0u32; 128];
unpack_rounded(&packed, RoundedBitWidth::Bits8, &mut unpacked, 128);
assert_eq!(values, unpacked);
}
#[test]
fn test_pack_unpack_rounded_16bit() {
let values: Vec<u32> = (0..128).map(|i| i * 100).collect();
let mut packed = vec![0u8; 256];
let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits16, &mut packed);
assert_eq!(bytes_written, 256);
let mut unpacked = vec![0u32; 128];
unpack_rounded(&packed, RoundedBitWidth::Bits16, &mut unpacked, 128);
assert_eq!(values, unpacked);
}
#[test]
fn test_pack_unpack_rounded_32bit() {
let values: Vec<u32> = (0..128).map(|i| i * 100000).collect();
let mut packed = vec![0u8; 512];
let bytes_written = pack_rounded(&values, RoundedBitWidth::Bits32, &mut packed);
assert_eq!(bytes_written, 512);
let mut unpacked = vec![0u32; 128];
unpack_rounded(&packed, RoundedBitWidth::Bits32, &mut unpacked, 128);
assert_eq!(values, unpacked);
}
#[test]
fn test_unpack_rounded_delta_decode() {
let input: Vec<u8> = vec![4, 4, 9, 19];
let mut output = vec![0u32; 5];
unpack_rounded_delta_decode(&input, RoundedBitWidth::Bits8, &mut output, 10, 5);
assert_eq!(output, vec![10, 15, 20, 30, 50]);
}
#[test]
fn test_unpack_rounded_delta_decode_zero() {
let input: Vec<u8> = vec![];
let mut output = vec![0u32; 5];
unpack_rounded_delta_decode(&input, RoundedBitWidth::Zero, &mut output, 100, 5);
assert_eq!(output, vec![100, 101, 102, 103, 104]);
}
#[test]
fn test_dequantize_uint8() {
let input: Vec<u8> = vec![0, 128, 255, 64, 192];
let mut output = vec![0.0f32; 5];
let scale = 0.1;
let min_val = 1.0;
dequantize_uint8(&input, &mut output, scale, min_val, 5);
assert!((output[0] - 1.0).abs() < 1e-6);
assert!((output[1] - 13.8).abs() < 1e-6);
assert!((output[2] - 26.5).abs() < 1e-6);
assert!((output[3] - 7.4).abs() < 1e-6);
assert!((output[4] - 20.2).abs() < 1e-6);
}
#[test]
fn test_dequantize_uint8_large() {
let input: Vec<u8> = (0..128).collect();
let mut output = vec![0.0f32; 128];
let scale = 2.0;
let min_val = -10.0;
dequantize_uint8(&input, &mut output, scale, min_val, 128);
for (i, &out) in output.iter().enumerate().take(128) {
let expected = i as f32 * scale + min_val;
assert!(
(out - expected).abs() < 1e-5,
"Mismatch at {}: expected {}, got {}",
i,
expected,
out
);
}
}
#[test]
fn test_dot_product_f32() {
let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0];
let result = dot_product_f32(&a, &b, 5);
assert!((result - 70.0).abs() < 1e-5);
}
#[test]
fn test_dot_product_f32_large() {
let a: Vec<f32> = (0..128).map(|i| i as f32).collect();
let b: Vec<f32> = (0..128).map(|i| (i + 1) as f32).collect();
let result = dot_product_f32(&a, &b, 128);
let expected: f32 = (0..128).map(|i| (i as f32) * ((i + 1) as f32)).sum();
assert!(
(result - expected).abs() < 1e-3,
"Expected {}, got {}",
expected,
result
);
}
#[test]
fn test_max_f32() {
let values = vec![1.0f32, 5.0, 3.0, 9.0, 2.0, 7.0];
let result = max_f32(&values, 6);
assert!((result - 9.0).abs() < 1e-6);
}
#[test]
fn test_max_f32_large() {
let mut values: Vec<f32> = (0..128).map(|i| i as f32).collect();
values[77] = 1000.0;
let result = max_f32(&values, 128);
assert!((result - 1000.0).abs() < 1e-5);
}
#[test]
fn test_max_f32_negative() {
let values = vec![-5.0f32, -2.0, -10.0, -1.0, -3.0];
let result = max_f32(&values, 5);
assert!((result - (-1.0)).abs() < 1e-6);
}
#[test]
fn test_max_f32_empty() {
let values: Vec<f32> = vec![];
let result = max_f32(&values, 0);
assert_eq!(result, f32::NEG_INFINITY);
}
#[test]
fn test_fused_dot_norm() {
let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let expected_norm: f32 = b.iter().map(|x| x * x).sum();
assert!(
(dot - expected_dot).abs() < 1e-5,
"dot: expected {}, got {}",
expected_dot,
dot
);
assert!(
(norm_b - expected_norm).abs() < 1e-5,
"norm: expected {}, got {}",
expected_norm,
norm_b
);
}
#[test]
fn test_fused_dot_norm_large() {
let a: Vec<f32> = (0..768).map(|i| (i as f32) * 0.01).collect();
let b: Vec<f32> = (0..768).map(|i| (i as f32) * 0.02 + 0.5).collect();
let (dot, norm_b) = fused_dot_norm(&a, &b, a.len());
let expected_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let expected_norm: f32 = b.iter().map(|x| x * x).sum();
assert!(
(dot - expected_dot).abs() < 1.0,
"dot: expected {}, got {}",
expected_dot,
dot
);
assert!(
(norm_b - expected_norm).abs() < 1.0,
"norm: expected {}, got {}",
expected_norm,
norm_b
);
}
#[test]
fn test_batch_cosine_scores() {
let query = vec![1.0f32, 0.0, 0.0];
let vectors = vec![
1.0, 0.0, 0.0,
0.0, 1.0, 0.0,
-1.0, 0.0, 0.0,
0.5, 0.5, 0.0,
];
let mut scores = vec![0f32; 4];
batch_cosine_scores(&query, &vectors, 3, &mut scores);
assert!((scores[0] - 1.0).abs() < 1e-5, "identical: {}", scores[0]);
assert!(scores[1].abs() < 1e-5, "orthogonal: {}", scores[1]);
assert!((scores[2] - (-1.0)).abs() < 1e-5, "opposite: {}", scores[2]);
let expected_45 = 0.5f32 / (0.5f32.powi(2) + 0.5f32.powi(2)).sqrt();
assert!(
(scores[3] - expected_45).abs() < 1e-5,
"45deg: expected {}, got {}",
expected_45,
scores[3]
);
}
#[test]
fn test_batch_cosine_scores_matches_individual() {
let query: Vec<f32> = (0..128).map(|i| (i as f32) * 0.1).collect();
let n = 50;
let dim = 128;
let vectors: Vec<f32> = (0..n * dim).map(|i| ((i * 7 + 3) as f32) * 0.01).collect();
let mut batch_scores = vec![0f32; n];
batch_cosine_scores(&query, &vectors, dim, &mut batch_scores);
for i in 0..n {
let vec_i = &vectors[i * dim..(i + 1) * dim];
let individual = cosine_similarity(&query, vec_i);
assert!(
(batch_scores[i] - individual).abs() < 1e-5,
"vec {}: batch={}, individual={}",
i,
batch_scores[i],
individual
);
}
}
#[test]
fn test_batch_cosine_scores_empty() {
let query = vec![1.0f32, 2.0, 3.0];
let vectors: Vec<f32> = vec![];
let mut scores: Vec<f32> = vec![];
batch_cosine_scores(&query, &vectors, 3, &mut scores);
assert!(scores.is_empty());
}
#[test]
fn test_batch_cosine_scores_zero_query() {
let query = vec![0.0f32, 0.0, 0.0];
let vectors = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let mut scores = vec![0f32; 2];
batch_cosine_scores(&query, &vectors, 3, &mut scores);
assert_eq!(scores[0], 0.0);
assert_eq!(scores[1], 0.0);
}
#[test]
fn test_squared_euclidean_distance() {
let a = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let b = vec![2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
let result = squared_euclidean_distance(&a, &b);
assert!(
(result - expected).abs() < 1e-5,
"expected {}, got {}",
expected,
result
);
}
#[test]
fn test_squared_euclidean_distance_large() {
let a: Vec<f32> = (0..128).map(|i| i as f32 * 0.1).collect();
let b: Vec<f32> = (0..128).map(|i| (i as f32 * 0.1) + 0.5).collect();
let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum();
let result = squared_euclidean_distance(&a, &b);
assert!(
(result - expected).abs() < 1e-3,
"expected {}, got {}",
expected,
result
);
}
#[test]
fn test_f16_roundtrip_normal() {
for &v in &[0.0f32, 1.0, -1.0, 0.5, -0.5, 0.333, 65504.0] {
let h = f32_to_f16(v);
let back = f16_to_f32(h);
let err = (back - v).abs() / v.abs().max(1e-6);
assert!(
err < 0.002,
"f16 roundtrip {v} → {h:#06x} → {back}, rel err {err}"
);
}
}
#[test]
fn test_f16_special() {
assert_eq!(f16_to_f32(f32_to_f16(0.0)), 0.0);
assert_eq!(f32_to_f16(-0.0), 0x8000);
assert!(f16_to_f32(f32_to_f16(f32::INFINITY)).is_infinite());
assert!(f16_to_f32(f32_to_f16(f32::NAN)).is_nan());
}
#[test]
fn test_f16_embedding_range() {
let values: Vec<f32> = (-100..=100).map(|i| i as f32 / 100.0).collect();
for &v in &values {
let back = f16_to_f32(f32_to_f16(v));
assert!((back - v).abs() < 0.001, "f16 error for {v}: got {back}");
}
}
#[test]
fn test_u8_roundtrip() {
assert_eq!(f32_to_u8_saturating(-1.0), 0);
assert_eq!(f32_to_u8_saturating(1.0), 255);
assert_eq!(f32_to_u8_saturating(0.0), 127);
assert_eq!(f32_to_u8_saturating(-2.0), 0);
assert_eq!(f32_to_u8_saturating(2.0), 255);
}
#[test]
fn test_u8_dequantize() {
assert!((u8_to_f32(0) - (-1.0)).abs() < 0.01);
assert!((u8_to_f32(255) - 1.0).abs() < 0.01);
assert!((u8_to_f32(127) - 0.0).abs() < 0.01);
}
#[test]
fn test_batch_cosine_scores_f16() {
let query = vec![0.6f32, 0.8, 0.0, 0.0];
let dim = 4;
let vecs_f32 = vec![
0.6f32, 0.8, 0.0, 0.0,
0.0, 0.0, 0.6, 0.8,
];
let mut f16_buf = vec![0u16; 8];
batch_f32_to_f16(&vecs_f32, &mut f16_buf);
let raw: &[u8] =
unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
let mut scores = vec![0f32; 2];
batch_cosine_scores_f16(&query, raw, dim, &mut scores);
assert!(
(scores[0] - 1.0).abs() < 0.01,
"identical vectors: {}",
scores[0]
);
assert!(scores[1].abs() < 0.01, "orthogonal vectors: {}", scores[1]);
}
#[test]
fn test_batch_cosine_scores_u8() {
let query = vec![0.6f32, 0.8, 0.0, 0.0];
let dim = 4;
let vecs_f32 = vec![
0.6f32, 0.8, 0.0, 0.0,
-0.6, -0.8, 0.0, 0.0,
];
let mut u8_buf = vec![0u8; 8];
batch_f32_to_u8(&vecs_f32, &mut u8_buf);
let mut scores = vec![0f32; 2];
batch_cosine_scores_u8(&query, &u8_buf, dim, &mut scores);
assert!(scores[0] > 0.95, "similar vectors: {}", scores[0]);
assert!(scores[1] < -0.95, "opposite vectors: {}", scores[1]);
}
#[test]
fn test_batch_cosine_scores_f16_large_dim() {
let dim = 768;
let query: Vec<f32> = (0..dim).map(|i| (i as f32 / dim as f32) - 0.5).collect();
let vec2: Vec<f32> = query.iter().map(|x| x * 0.9 + 0.01).collect();
let mut all_vecs = query.clone();
all_vecs.extend_from_slice(&vec2);
let mut f16_buf = vec![0u16; all_vecs.len()];
batch_f32_to_f16(&all_vecs, &mut f16_buf);
let raw: &[u8] =
unsafe { std::slice::from_raw_parts(f16_buf.as_ptr() as *const u8, f16_buf.len() * 2) };
let mut scores = vec![0f32; 2];
batch_cosine_scores_f16(&query, raw, dim, &mut scores);
assert!((scores[0] - 1.0).abs() < 0.01, "self-sim: {}", scores[0]);
assert!(scores[1] > 0.99, "scaled-sim: {}", scores[1]);
}
#[test]
fn test_hamming_distance_identical() {
let a = vec![0xAA; 64];
assert_eq!(hamming_distance(&a, &a), 0);
}
#[test]
fn test_hamming_distance_opposite() {
let a = vec![0xFF; 32];
let b = vec![0x00; 32];
assert_eq!(hamming_distance(&a, &b), 256);
}
#[test]
fn test_hamming_distance_known() {
let a = vec![0xAA];
let b = vec![0x55];
assert_eq!(hamming_distance(&a, &b), 8);
let a = vec![0xFF, 0x00];
let b = vec![0x00, 0x00];
assert_eq!(hamming_distance(&a, &b), 8);
}
#[test]
fn test_hamming_distance_single_bit() {
let a = vec![0x00; 16];
let mut b = vec![0x00; 16];
b[7] = 0x01;
assert_eq!(hamming_distance(&a, &b), 1);
}
#[test]
fn test_hamming_distance_empty() {
let a: Vec<u8> = vec![];
assert_eq!(hamming_distance(&a, &a), 0);
}
#[test]
fn test_hamming_distance_remainder_path() {
let a = vec![0xFF; 17];
let b = vec![0x00; 17];
assert_eq!(hamming_distance(&a, &b), 136);
let a = vec![0xFF; 33];
let b = vec![0x00; 33];
assert_eq!(hamming_distance(&a, &b), 264);
}
#[test]
fn test_hamming_distance_large() {
let a = vec![0xFF; 4096];
let b = vec![0x00; 4096];
assert_eq!(hamming_distance(&a, &b), 32768);
}
#[test]
fn test_hamming_distance_scalar_matches() {
for size in [1, 7, 8, 15, 16, 31, 32, 63, 64, 100, 128, 255, 256] {
let a: Vec<u8> = (0..size).map(|i| (i * 37 + 13) as u8).collect();
let b: Vec<u8> = (0..size).map(|i| (i * 53 + 7) as u8).collect();
let expected = hamming_distance_scalar(&a, &b);
let got = hamming_distance(&a, &b);
assert_eq!(got, expected, "mismatch at size {size}");
}
}
#[test]
fn test_batch_hamming_scores_identical() {
let query = vec![0xAA; 16];
let db = vec![0xAA; 16];
let mut scores = vec![0f32; 1];
batch_hamming_scores(&query, &db, 16, 128, &mut scores);
assert!((scores[0] - 1.0).abs() < 1e-6, "identical: {}", scores[0]);
}
#[test]
fn test_batch_hamming_scores_opposite() {
let query = vec![0xFF; 16];
let db = vec![0x00; 16];
let mut scores = vec![0f32; 1];
batch_hamming_scores(&query, &db, 16, 128, &mut scores);
assert!((scores[0] - 0.0).abs() < 1e-6, "opposite: {}", scores[0]);
}
#[test]
fn test_batch_hamming_scores_multiple() {
let byte_len = 8;
let dim_bits = 64;
let query = vec![0xFF; byte_len];
let mut db = Vec::new();
db.extend_from_slice(&vec![0xFF; byte_len]);
db.extend_from_slice(&vec![0x00; byte_len]);
db.extend_from_slice(&vec![0x0F; byte_len]);
let mut scores = vec![0f32; 3];
batch_hamming_scores(&query, &db, byte_len, dim_bits, &mut scores);
assert!((scores[0] - 1.0).abs() < 1e-6, "identical: {}", scores[0]);
assert!((scores[1] - 0.0).abs() < 1e-6, "opposite: {}", scores[1]);
assert!((scores[2] - 0.5).abs() < 1e-6, "half: {}", scores[2]);
}
#[test]
fn test_batch_hamming_scores_empty() {
let query = vec![0xFF; 8];
let db: Vec<u8> = vec![];
let mut scores: Vec<f32> = vec![];
batch_hamming_scores(&query, &db, 8, 64, &mut scores);
assert!(scores.is_empty());
}
#[test]
fn test_batch_hamming_scores_zero_byte_len() {
let query: Vec<u8> = vec![];
let db: Vec<u8> = vec![];
let mut scores = vec![0f32; 1];
batch_hamming_scores(&query, &db, 0, 0, &mut scores);
assert_eq!(scores[0], 0.0);
}
}
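// Returns the first index whose value is >= target (slice.len() if none). The slice
// is assumed to be sorted ascending: the scalar fallback binary-searches via
// partition_point, while the SIMD paths scan linearly for the first qualifying lane.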
#[inline]
pub fn find_first_ge_u32(slice: &[u32], target: u32) -> usize {
#[cfg(target_arch = "aarch64")]
{
if neon::is_available() {
return unsafe { find_first_ge_u32_neon(slice, target) };
}
}
#[cfg(target_arch = "x86_64")]
{
if sse::is_available() {
return unsafe { find_first_ge_u32_sse(slice, target) };
}
}
slice.partition_point(|&d| d < target)
}
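// NEON has no movemask equivalent: AND-ing each comparison mask with {1, 2, 4, 8} and
// horizontally adding (vaddvq_u32) produces a 4-bit mask whose trailing_zeros is the
// first matching lane.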
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn find_first_ge_u32_neon(slice: &[u32], target: u32) -> usize {
use std::arch::aarch64::*;
let n = slice.len();
let ptr = slice.as_ptr();
let target_vec = vdupq_n_u32(target);
let bit_mask: uint32x4_t = core::mem::transmute([1u32, 2u32, 4u32, 8u32]);
let chunks = n / 16;
let mut base = 0usize;
for _ in 0..chunks {
let v0 = vld1q_u32(ptr.add(base));
let v1 = vld1q_u32(ptr.add(base + 4));
let v2 = vld1q_u32(ptr.add(base + 8));
let v3 = vld1q_u32(ptr.add(base + 12));
let c0 = vcgeq_u32(v0, target_vec);
let c1 = vcgeq_u32(v1, target_vec);
let c2 = vcgeq_u32(v2, target_vec);
let c3 = vcgeq_u32(v3, target_vec);
let m0 = vaddvq_u32(vandq_u32(c0, bit_mask));
if m0 != 0 {
return base + m0.trailing_zeros() as usize;
}
let m1 = vaddvq_u32(vandq_u32(c1, bit_mask));
if m1 != 0 {
return base + 4 + m1.trailing_zeros() as usize;
}
let m2 = vaddvq_u32(vandq_u32(c2, bit_mask));
if m2 != 0 {
return base + 8 + m2.trailing_zeros() as usize;
}
let m3 = vaddvq_u32(vandq_u32(c3, bit_mask));
if m3 != 0 {
return base + 12 + m3.trailing_zeros() as usize;
}
base += 16;
}
while base + 4 <= n {
let vals = vld1q_u32(ptr.add(base));
let cmp = vcgeq_u32(vals, target_vec);
let mask = vaddvq_u32(vandq_u32(cmp, bit_mask));
if mask != 0 {
return base + mask.trailing_zeros() as usize;
}
base += 4;
}
while base < n {
if *slice.get_unchecked(base) >= target {
return base;
}
base += 1;
}
n
}
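// SSE2 lacks an unsigned 32-bit compare, so both operands are XORed with i32::MIN;
// that bias maps unsigned order onto signed order, letting cmpeq | cmpgt emulate >=.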
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn find_first_ge_u32_sse(slice: &[u32], target: u32) -> usize {
use std::arch::x86_64::*;
let n = slice.len();
let ptr = slice.as_ptr();
let sign_flip = _mm_set1_epi32(i32::MIN);
let target_xor = _mm_xor_si128(_mm_set1_epi32(target as i32), sign_flip);
let chunks = n / 16;
let mut base = 0usize;
for _ in 0..chunks {
let v0 = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
let v1 = _mm_xor_si128(
_mm_loadu_si128(ptr.add(base + 4) as *const __m128i),
sign_flip,
);
let v2 = _mm_xor_si128(
_mm_loadu_si128(ptr.add(base + 8) as *const __m128i),
sign_flip,
);
let v3 = _mm_xor_si128(
_mm_loadu_si128(ptr.add(base + 12) as *const __m128i),
sign_flip,
);
let ge0 = _mm_or_si128(
_mm_cmpeq_epi32(v0, target_xor),
_mm_cmpgt_epi32(v0, target_xor),
);
let m0 = _mm_movemask_ps(_mm_castsi128_ps(ge0)) as u32;
if m0 != 0 {
return base + m0.trailing_zeros() as usize;
}
let ge1 = _mm_or_si128(
_mm_cmpeq_epi32(v1, target_xor),
_mm_cmpgt_epi32(v1, target_xor),
);
let m1 = _mm_movemask_ps(_mm_castsi128_ps(ge1)) as u32;
if m1 != 0 {
return base + 4 + m1.trailing_zeros() as usize;
}
let ge2 = _mm_or_si128(
_mm_cmpeq_epi32(v2, target_xor),
_mm_cmpgt_epi32(v2, target_xor),
);
let m2 = _mm_movemask_ps(_mm_castsi128_ps(ge2)) as u32;
if m2 != 0 {
return base + 8 + m2.trailing_zeros() as usize;
}
let ge3 = _mm_or_si128(
_mm_cmpeq_epi32(v3, target_xor),
_mm_cmpgt_epi32(v3, target_xor),
);
let m3 = _mm_movemask_ps(_mm_castsi128_ps(ge3)) as u32;
if m3 != 0 {
return base + 12 + m3.trailing_zeros() as usize;
}
base += 16;
}
while base + 4 <= n {
let vals = _mm_xor_si128(_mm_loadu_si128(ptr.add(base) as *const __m128i), sign_flip);
let ge = _mm_or_si128(
_mm_cmpeq_epi32(vals, target_xor),
_mm_cmpgt_epi32(vals, target_xor),
);
let mask = _mm_movemask_ps(_mm_castsi128_ps(ge)) as u32;
if mask != 0 {
return base + mask.trailing_zeros() as usize;
}
base += 4;
}
while base < n {
if *slice.get_unchecked(base) >= target {
return base;
}
base += 1;
}
n
}
#[cfg(test)]
mod find_first_ge_tests {
use super::find_first_ge_u32;
#[test]
fn test_find_first_ge_basic() {
let data: Vec<u32> = (0..128).map(|i| i * 3).collect();
assert_eq!(find_first_ge_u32(&data, 0), 0);
assert_eq!(find_first_ge_u32(&data, 1), 1);
assert_eq!(find_first_ge_u32(&data, 3), 1);
assert_eq!(find_first_ge_u32(&data, 4), 2);
assert_eq!(find_first_ge_u32(&data, 381), 127);
assert_eq!(find_first_ge_u32(&data, 382), 128);
}
#[test]
fn test_find_first_ge_matches_partition_point() {
let data: Vec<u32> = vec![1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75];
for target in 0..80 {
let expected = data.partition_point(|&d| d < target);
let actual = find_first_ge_u32(&data, target);
assert_eq!(actual, expected, "target={}", target);
}
}
#[test]
fn test_find_first_ge_small_slices() {
assert_eq!(find_first_ge_u32(&[], 5), 0);
assert_eq!(find_first_ge_u32(&[10], 5), 0);
assert_eq!(find_first_ge_u32(&[10], 10), 0);
assert_eq!(find_first_ge_u32(&[10], 11), 1);
assert_eq!(find_first_ge_u32(&[2, 4, 6], 5), 2);
}
#[test]
fn test_find_first_ge_full_block() {
let data: Vec<u32> = (100..228).collect();
assert_eq!(find_first_ge_u32(&data, 100), 0);
assert_eq!(find_first_ge_u32(&data, 150), 50);
assert_eq!(find_first_ge_u32(&data, 227), 127);
assert_eq!(find_first_ge_u32(&data, 228), 128);
assert_eq!(find_first_ge_u32(&data, 99), 0);
}
#[test]
fn test_find_first_ge_u32_max() {
let data = vec![u32::MAX - 10, u32::MAX - 5, u32::MAX - 1, u32::MAX];
assert_eq!(find_first_ge_u32(&data, u32::MAX - 10), 0);
assert_eq!(find_first_ge_u32(&data, u32::MAX - 7), 1);
assert_eq!(find_first_ge_u32(&data, u32::MAX), 3);
}
}