#![cfg_attr(not(feature = "persistence"), allow(dead_code))]
#[allow(unused_imports)] use super::dispatch::{simd_level, SimdLevel};
#[cfg(target_arch = "x86_64")]
use super::reduction::hsum_avx256;
/// Computes ADC distances between one query (encoded as `lut`) and a batch of PQ codes.
///
/// `lut` is a flattened table of `m` consecutive rows of `k = lut.len() / m` entries.
/// The distance for a code is the sum over subspaces `s` in `0..m` of
/// `lut[s * k + code[s]]`. The kernel is chosen from `simd_level()` (AVX2/AVX-512,
/// NEON, or the scalar fallback). Results are returned in the same order as `codes`.
///
/// # Errors
///
/// Returns [`crate::error::Error::InvalidVector`] if:
/// - `m == 0`;
/// - `lut.len()` is not a multiple of `m`;
/// - any code does not have exactly `m` entries;
/// - any code entry is `>= k`.
pub(crate) fn adc_distances_batch(
    lut: &[f32],
    codes: &[&[u16]],
    m: usize,
) -> crate::error::Result<Vec<f32>> {
    if m == 0 {
        return Err(crate::error::Error::InvalidVector(
            "ADC subspace count m must be > 0".into(),
        ));
    }
    if !lut.len().is_multiple_of(m) {
        return Err(crate::error::Error::InvalidVector(format!(
            "ADC lookup table length {} is not divisible by m={}",
            lut.len(),
            m
        )));
    }
    let k = lut.len() / m;
    // The SIMD kernels read `lut` without bounds checks (AVX2 gather, NEON
    // `get_unchecked`). Every code must therefore be validated here; otherwise a
    // malformed code would cause out-of-bounds reads in release builds, where the
    // kernels' `debug_assert!`s are compiled out.
    for (position, code) in codes.iter().enumerate() {
        if code.len() != m {
            return Err(crate::error::Error::InvalidVector(format!(
                "ADC code {position} has length {} but m={m}",
                code.len()
            )));
        }
        if let Some(&bad) = code.iter().find(|&&c| usize::from(c) >= k) {
            return Err(crate::error::Error::InvalidVector(format!(
                "ADC code {position} contains PQ code {bad} out of range (k={k})"
            )));
        }
    }
    Ok(match simd_level() {
        // `_mm256_i32gather_ps` takes signed 32-bit lane indices. Tables whose
        // flat indices might not fit in `i32` fall through to the scalar arm.
        #[cfg(target_arch = "x86_64")]
        SimdLevel::Avx2 | SimdLevel::Avx512 if i32::try_from(lut.len()).is_ok() => codes
            .iter()
            .enumerate()
            .map(|(i, code)| {
                if let Some(&next) = codes.get(i + 1) {
                    super::prefetch::prefetch_vector_from_u16(next);
                }
                // SAFETY: `simd_level()` reported AVX2-capable hardware. Every code
                // was checked above to hold exactly `m` entries, each `< k`.
                // `lut.len() == m * k` fits in `i32`, so every gather index is in
                // bounds, as `adc_single_avx2` requires.
                unsafe { adc_single_avx2(lut, code, m, k) }
            })
            .collect(),
        #[cfg(target_arch = "aarch64")]
        SimdLevel::Neon => codes
            .iter()
            .enumerate()
            .map(|(i, code)| {
                if let Some(&next) = codes.get(i + 1) {
                    super::prefetch::prefetch_vector_from_u16(next);
                }
                // SAFETY: `simd_level()` reported NEON support. Every code was
                // checked above to hold exactly `m` entries, each `< k`, and
                // `lut.len() == m * k`, as `adc_single_neon` requires.
                unsafe { adc_single_neon(lut, code, m, k) }
            })
            .collect(),
        _ => adc_batch_scalar(lut, codes, m, k),
    })
}
/// Portable fallback: scores each code in `codes` with [`adc_single_scalar`].
///
/// Distances are returned in the same order as `codes`.
fn adc_batch_scalar(lut: &[f32], codes: &[&[u16]], m: usize, k: usize) -> Vec<f32> {
    let mut distances = Vec::with_capacity(codes.len());
    for &code in codes {
        distances.push(adc_single_scalar(lut, code, m, k));
    }
    distances
}
/// Scalar ADC distance of one PQ `code`.
///
/// Returns the sum over subspaces `s` in `0..m` of `lut[s * k + code[s]]`. Entries of
/// `code` beyond index `m - 1` are ignored. Indexing is checked, so the function panics
/// if `code` is shorter than `m` or a lookup falls past the end of `lut`.
#[inline]
fn adc_single_scalar(lut: &[f32], code: &[u16], m: usize, k: usize) -> f32 {
    let lookup = |subspace: usize| {
        let row_start = subspace * k;
        lut[row_start + usize::from(code[subspace])]
    };
    (0..m).map(lookup).sum()
}
/// Flat position of `code[subspace]` in the `m × k` lookup table, narrowed to the
/// signed 32-bit lane index that `_mm256_i32gather_ps` consumes.
///
/// The narrowing is an unchecked `as` cast: the result wraps if
/// `subspace * k + code[subspace]` exceeds `i32::MAX`.
#[cfg(target_arch = "x86_64")]
#[inline]
// Narrowing is intentional; callers are expected to keep tables within `i32` range.
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
fn lane_index(code: &[u16], subspace: usize, k: usize) -> i32 {
    let row_start = subspace * k;
    let column = usize::from(code[subspace]);
    (row_start + column) as i32
}
/// AVX2 kernel: ADC distance of one PQ `code` against the flattened `m × k` `lut`.
///
/// Subspaces are processed eight at a time with `_mm256_i32gather_ps` and reduced
/// with `hsum_avx256`. The remaining `m % 8` subspaces are added afterwards with
/// checked scalar lookups.
///
/// # Safety
///
/// The caller must ensure that:
/// - the running CPU supports AVX2;
/// - `code[s] < k` for every `s < m`;
/// - `lut.len() >= m * k`;
/// - `m * k` fits in an `i32`, because gather lane indices are signed 32-bit.
///
/// The gather performs no bounds checks, so violating any of these reads out of
/// bounds. A `code` shorter than `m` panics through checked indexing instead.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn adc_single_avx2(lut: &[f32], code: &[u16], m: usize, k: usize) -> f32 {
    use std::arch::x86_64::{
        __m256, _mm256_add_ps, _mm256_i32gather_ps, _mm256_setr_epi32, _mm256_setzero_ps,
    };
    debug_assert_eq!(code.len(), m, "code length must equal m");
    debug_assert!(
        code.iter().all(|&c| usize::from(c) < k),
        "PQ code out of range: all codes must be < k ({k})"
    );
    debug_assert!(
        m.checked_mul(k)
            .is_some_and(|cells| cells <= lut.len() && i32::try_from(cells).is_ok()),
        "lookup table must hold m * k entries addressable by 32-bit gather indices"
    );
    let full_chunks = m / 8;
    let mut acc: __m256 = _mm256_setzero_ps();
    for chunk in 0..full_chunks {
        let base = chunk * 8;
        let indices = _mm256_setr_epi32(
            lane_index(code, base, k),
            lane_index(code, base + 1, k),
            lane_index(code, base + 2, k),
            lane_index(code, base + 3, k),
            lane_index(code, base + 4, k),
            lane_index(code, base + 5, k),
            lane_index(code, base + 6, k),
            lane_index(code, base + 7, k),
        );
        // SAFETY: AVX2 is enabled for this function. Each lane index is
        // `s * k + code[s]` with `s < m` and `code[s] < k`, so it lies in
        // `0..m * k`. That range is within `lut` and fits in `i32` (caller
        // contract). The scale of 4 is `size_of::<f32>()`.
        let gathered = unsafe { _mm256_i32gather_ps::<4>(lut.as_ptr(), indices) };
        acc = _mm256_add_ps(acc, gathered);
    }
    // Leftover `m % 8` subspaces: checked lookups, added in subspace order onto
    // the horizontal sum of the vector accumulator.
    ((full_chunks * 8)..m).fold(hsum_avx256(acc), |total, subspace| {
        total + lut[subspace * k + usize::from(code[subspace])]
    })
}
/// NEON kernel: ADC distance of one PQ `code` against the flattened `m × k` `lut`.
///
/// Four table entries are looked up per iteration and accumulated in a
/// four-lane vector, which is reduced with `vaddvq_f32`. The remaining `m % 4`
/// subspaces are added afterwards with checked scalar lookups.
///
/// # Safety
///
/// The caller must ensure that:
/// - the running CPU supports NEON;
/// - `code.len() >= m`;
/// - `code[s] < k` for every `s < m`;
/// - `lut.len() >= m * k`.
///
/// The four-wide loop uses unchecked indexing on both `code` and `lut`, so
/// violating any of these is undefined behavior.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn adc_single_neon(lut: &[f32], code: &[u16], m: usize, k: usize) -> f32 {
    use std::arch::aarch64::*;
    debug_assert_eq!(code.len(), m, "code length must equal m");
    debug_assert!(
        code.iter().all(|&c| usize::from(c) < k),
        "PQ code out of range: all codes must be < k ({k})"
    );
    let full_chunks = m / 4;
    let mut acc = vdupq_n_f32(0.0);
    for chunk in 0..full_chunks {
        let base = chunk * 4;
        // SAFETY: `base + 3 < m <= code.len()`. Each flat index
        // `(base + j) * k + code[base + j]` is `< m * k <= lut.len()` because every
        // code is `< k` (caller contract). `vals` is a live four-element array, so
        // `vld1q_f32` reads exactly its four floats.
        let v = unsafe {
            let vals: [f32; 4] = [
                *lut.get_unchecked(base * k + usize::from(*code.get_unchecked(base))),
                *lut.get_unchecked((base + 1) * k + usize::from(*code.get_unchecked(base + 1))),
                *lut.get_unchecked((base + 2) * k + usize::from(*code.get_unchecked(base + 2))),
                *lut.get_unchecked((base + 3) * k + usize::from(*code.get_unchecked(base + 3))),
            ];
            vld1q_f32(vals.as_ptr())
        };
        acc = vaddq_f32(acc, v);
    }
    let mut total = vaddvq_f32(acc);
    // Leftover `m % 4` subspaces, using checked indexing.
    for subspace in (full_chunks * 4)..m {
        total += lut[subspace * k + usize::from(code[subspace])];
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lookup table whose entry at flat index `i` is `i as f32`, so a code's
    /// distance is `sum_s (s * k + code[s])`.
    fn make_sequential_lut(m: usize, k: usize) -> Vec<f32> {
        (0..m * k)
            .map(|i| {
                #[allow(clippy::cast_precision_loss)]
                let v = i as f32;
                v
            })
            .collect()
    }

    /// Closed-form distance of `code` against `make_sequential_lut(_, k)`.
    fn sequential_expected(code: &[u16], k: usize) -> f32 {
        let total: usize = code
            .iter()
            .enumerate()
            .map(|(s, &c)| s * k + usize::from(c))
            .sum();
        #[allow(clippy::cast_precision_loss)]
        let v = total as f32;
        v
    }

    #[test]
    fn adc_scalar_correct_sum() {
        let m = 4;
        let k = 4;
        let lut = make_sequential_lut(m, k);
        let codes: Vec<u16> = vec![0, 1, 2, 3];
        let codes_ref: Vec<&[u16]> = vec![codes.as_slice()];
        let result = adc_distances_batch(&lut, &codes_ref, m).expect("test: valid ADC input");
        assert_eq!(result.len(), 1);
        assert!(
            (result[0] - 30.0).abs() < 1e-6,
            "expected 30.0, got {}",
            result[0]
        );
    }

    #[test]
    fn adc_batch_multiple_codes() {
        let m = 2;
        let k = 4;
        let lut = make_sequential_lut(m, k);
        let c1: Vec<u16> = vec![0, 0];
        let c2: Vec<u16> = vec![3, 3];
        let codes_ref: Vec<&[u16]> = vec![c1.as_slice(), c2.as_slice()];
        let result = adc_distances_batch(&lut, &codes_ref, m).expect("test: valid ADC input");
        assert_eq!(result.len(), 2);
        assert!((result[0] - 4.0).abs() < 1e-6);
        assert!((result[1] - 10.0).abs() < 1e-6);
    }

    #[test]
    fn adc_m8_k256_standard_config() {
        let m = 8;
        let k = 256;
        let lut = make_sequential_lut(m, k);
        let codes: Vec<u16> = vec![0; 8];
        let codes_ref: Vec<&[u16]> = vec![codes.as_slice()];
        let result = adc_distances_batch(&lut, &codes_ref, m).expect("test: valid ADC input");
        assert!(
            (result[0] - 7168.0).abs() < 1e-2,
            "expected 7168.0, got {}",
            result[0]
        );
    }

    #[test]
    fn adc_m_not_divisible_by_8() {
        let m = 5;
        let k = 4;
        let lut = make_sequential_lut(m, k);
        let codes: Vec<u16> = vec![1, 1, 1, 1, 1];
        let codes_ref: Vec<&[u16]> = vec![codes.as_slice()];
        let result = adc_distances_batch(&lut, &codes_ref, m).expect("test: valid ADC input");
        assert!(
            (result[0] - 45.0).abs() < 1e-6,
            "expected 45.0, got {}",
            result[0]
        );
    }

    #[test]
    fn adc_lut_size_m8_k256() {
        let m = 8;
        let k = 256;
        let lut = make_sequential_lut(m, k);
        assert_eq!(lut.len() * std::mem::size_of::<f32>(), 8192);
    }

    #[test]
    fn adc_avx2_matches_scalar() {
        let m = 8;
        let k = 16;
        let lut = make_sequential_lut(m, k);
        let codes: Vec<u16> = vec![3, 7, 1, 15, 0, 8, 12, 5];
        let codes_ref: Vec<&[u16]> = vec![codes.as_slice()];
        let scalar_result = adc_batch_scalar(&lut, &codes_ref, m, k);
        let dispatch_result =
            adc_distances_batch(&lut, &codes_ref, m).expect("test: valid ADC input");
        assert!(
            (scalar_result[0] - dispatch_result[0]).abs() < 1e-4,
            "SIMD dispatch ({}) != scalar ({}) beyond f32 epsilon",
            dispatch_result[0],
            scalar_result[0]
        );
    }

    #[test]
    fn adc_simd_matches_scalar_with_tail() {
        // m = 13 exercises one full 8-lane AVX2 chunk plus a 5-subspace tail
        // (or three 4-lane NEON chunks plus a 1-subspace tail), over several codes
        // so the prefetch-next path also runs.
        let m_u16: u16 = 13;
        let m = usize::from(m_u16);
        let k = 16;
        let lut = make_sequential_lut(m, k);
        let codes: Vec<Vec<u16>> = (0..5u16)
            .map(|row| (0..m_u16).map(|s| (row * 3 + s * 5) % 16).collect())
            .collect();
        let codes_ref: Vec<&[u16]> = codes.iter().map(Vec::as_slice).collect();
        let scalar = adc_batch_scalar(&lut, &codes_ref, m, k);
        let dispatched =
            adc_distances_batch(&lut, &codes_ref, m).expect("test: valid ADC input");
        assert_eq!(scalar.len(), codes.len());
        assert_eq!(dispatched.len(), codes.len());
        for ((code, &s), &d) in codes.iter().zip(&scalar).zip(&dispatched) {
            let expected = sequential_expected(code, k);
            assert!((s - expected).abs() < 1e-4, "scalar {s} != expected {expected}");
            assert!((d - expected).abs() < 1e-4, "dispatch {d} != expected {expected}");
        }
    }

    #[test]
    fn adc_empty_batch_returns_empty() {
        let lut = make_sequential_lut(4, 4);
        let codes_ref: Vec<&[u16]> = Vec::new();
        let result = adc_distances_batch(&lut, &codes_ref, 4).expect("test: valid ADC input");
        assert!(result.is_empty());
    }

    #[test]
    fn adc_rejects_zero_subspaces() {
        let lut = make_sequential_lut(2, 4);
        let codes: Vec<u16> = vec![0, 0];
        let codes_ref: Vec<&[u16]> = vec![codes.as_slice()];
        assert!(matches!(
            adc_distances_batch(&lut, &codes_ref, 0),
            Err(crate::error::Error::InvalidVector(_))
        ));
    }

    #[test]
    fn adc_rejects_lut_not_divisible_by_m() {
        // 7 entries cannot be split into 2 equal subspace rows.
        let lut = make_sequential_lut(1, 7);
        let codes_ref: Vec<&[u16]> = Vec::new();
        assert!(matches!(
            adc_distances_batch(&lut, &codes_ref, 2),
            Err(crate::error::Error::InvalidVector(_))
        ));
    }
}