use super::super::vector::Int8QuantizedVector;
use std::arch::x86_64::*;
#[target_feature(enable = "avx2")]
/// Computes the inner product between a full-precision `query` and an
/// int8-quantized vector using asymmetric distance computation (ADC): the
/// quantized codes are widened i8 → i16 → i32 → f32 on the fly, multiplied
/// against the f32 query with FMA, and the quantization affine transform is
/// undone once at the end.
///
/// `query_sum` must be the precomputed sum of all `query` elements; it lets
/// the per-vector bias be factored out of the accumulated product:
/// `dot = (Σ query_i * code_i - bias * Σ query_i) / scale`
/// (this mirrors the scalar `fallback::dot_product_adc` — the exact code/value
/// transform lives in the quantizer; presumably `code = x * scale + bias`).
///
/// # Safety
/// The caller must ensure the CPU supports both AVX2 **and** FMA (e.g. via
/// `is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")`) —
/// `_mm256_fmadd_ps` is an FMA intrinsic, and AVX2 support does not imply FMA.
/// The caller must also guarantee `query.len() == qvec.len()`; the scalar
/// tail loop indexes the code slice unchecked under that invariant.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn dot_product_adc(query: &[f32], qvec: &Int8QuantizedVector, query_sum: f32) -> f32 {
    let q_slice = qvec.as_slice();
    let dimension = qvec.len();
    debug_assert_eq!(
        query.len(),
        dimension,
        "query and quantized vector must have the same dimension"
    );
    let n_chunks = dimension / 16;
    // Two independent accumulators break the FMA dependency chain so
    // consecutive iterations can overlap in the pipeline.
    let mut acc0 = _mm256_setzero_ps();
    let mut acc1 = _mm256_setzero_ps();
    for i in 0..n_chunks {
        // For large vectors, prefetch the codes 4 chunks ahead so they are
        // already in L1 when the loads issue.
        if dimension > 1536 && i + 4 < n_chunks {
            let prefetch_ptr = q_slice.as_ptr().add((i + 4) * 16);
            _mm_prefetch::<_MM_HINT_T0>(prefetch_ptr);
        }
        // Load 16 int8 codes and widen them to two __m256 of 8 f32 each.
        let i8_ptr = q_slice.as_ptr().add(i * 16) as *const __m128i;
        let i8_vec = _mm_loadu_si128(i8_ptr);
        let i16_vec = _mm256_cvtepi8_epi16(i8_vec);
        let low_i16 = _mm256_extracti128_si256(i16_vec, 0);
        let high_i16 = _mm256_extracti128_si256(i16_vec, 1);
        let low_i32 = _mm256_cvtepi16_epi32(low_i16);
        let high_i32 = _mm256_cvtepi16_epi32(high_i16);
        let low_f32 = _mm256_cvtepi32_ps(low_i32);
        let high_f32 = _mm256_cvtepi32_ps(high_i32);
        // Matching 16 query floats, as two unaligned 8-lane loads.
        let q_ptr = query.as_ptr().add(i * 16);
        let q_low = _mm256_loadu_ps(q_ptr);
        let q_high = _mm256_loadu_ps(q_ptr.add(8));
        acc0 = _mm256_fmadd_ps(q_low, low_f32, acc0);
        acc1 = _mm256_fmadd_ps(q_high, high_f32, acc1);
    }
    // Horizontal reduction: 256-bit -> 128-bit -> 64-bit -> scalar.
    let sum01 = _mm256_add_ps(acc0, acc1);
    let low128 = _mm256_castps256_ps128(sum01);
    let high128 = _mm256_extractf128_ps(sum01, 1);
    let sum128 = _mm_add_ps(low128, high128);
    let shuffled = _mm_movehl_ps(sum128, sum128);
    let sum64 = _mm_add_ps(sum128, shuffled);
    let shuffled2 = _mm_shuffle_ps(sum64, sum64, 0x01);
    let sum32 = _mm_add_ss(sum64, shuffled2);
    let mut total = _mm_cvtss_f32(sum32);
    // Scalar tail for the remaining `dimension % 16` elements.
    for (i, &q_val) in query.iter().enumerate().skip(n_chunks * 16) {
        debug_assert!(
            i < q_slice.len(),
            "Index {} out of bounds for q_slice of len {}",
            i,
            q_slice.len()
        );
        // SAFETY: i < query.len() == dimension == q_slice.len(), per the
        // caller contract (debug-asserted at function entry).
        total += q_val * (*q_slice.get_unchecked(i) as f32);
    }
    // Undo the affine quantization in one step using the precomputed
    // query_sum instead of de-biasing every element.
    (total - qvec.metadata.bias * query_sum) / qvec.metadata.scale
}
/// Computes the squared Euclidean distance between `query` and the quantized
/// vector via the identity `‖q − x‖² = ‖q‖² + ‖x‖² − 2·⟨q, x⟩`, with the
/// inner product evaluated by the SIMD ADC kernel.
///
/// `query_sum` is the precomputed element sum of `query` (forwarded to
/// [`dot_product_adc`]); `query_norm_sq` is the precomputed `‖q‖²`;
/// `qvec.metadata.squared_sum` supplies `‖x‖²`. The result is clamped at
/// `0.0` to absorb tiny negative values from floating-point cancellation.
///
/// # Safety
/// Same contract as [`dot_product_adc`]: the CPU must support AVX2 and FMA,
/// and `query.len()` must equal `qvec.len()`.
#[target_feature(enable = "avx2", enable = "fma")]
pub unsafe fn l2_squared_distance(
    query: &[f32],
    qvec: &Int8QuantizedVector,
    query_sum: f32,
    query_norm_sq: f32,
) -> f32 {
    let dot_adc = dot_product_adc(query, qvec, query_sum);
    (query_norm_sq + qvec.metadata.squared_sum - 2.0 * dot_adc).max(0.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::quantization::{Dequantize, Quantize};

    /// The kernels execute FMA instructions in addition to AVX2 ones, and
    /// AVX2 support does not imply FMA support, so both features must be
    /// detected at runtime before invoking the unsafe functions.
    #[cfg(target_arch = "x86_64")]
    fn simd_supported() -> bool {
        is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma")
    }

    /// Sanity check: the SIMD ADC dot product should closely match the dot
    /// product computed against the dequantized vector.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_basic() {
        if !simd_supported() {
            return;
        }
        let original = vec![0.1f32, 0.5, -0.2, 0.8, 0.3, 0.9, -0.5, 0.4];
        let query = vec![0.2f32, 0.4, 0.1, 0.7, -0.3, 0.6, 0.2, 0.5];
        let qvec = original.quantize();
        let query_sum: f32 = query.iter().sum();
        let result = unsafe { dot_product_adc(&query, &qvec, query_sum) };
        assert!(result.is_finite());
        let dequantized = qvec.dequantize();
        let expected: f32 = query
            .iter()
            .zip(dequantized.iter())
            .map(|(q, d)| q * d)
            .sum();
        assert!(
            (result - expected).abs() < 1e-3,
            "AVX2 result {} doesn't match expected {}",
            result,
            expected
        );
    }

    /// Exercises chunk-boundary edge cases: dimensions below, at, and just
    /// above multiples of the 16-lane chunk size, so both the SIMD loop and
    /// the scalar tail are covered.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_various_dimensions() {
        if !simd_supported() {
            return;
        }
        for dim in [
            1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129,
        ] {
            let original: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
            let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();
            let qvec = original.quantize();
            let query_sum: f32 = query.iter().sum();
            let result = unsafe { dot_product_adc(&query, &qvec, query_sum) };
            assert!(result.is_finite(), "Failed for dimension {}", dim);
        }
    }

    /// The SIMD path must agree with the scalar fallback implementation on
    /// identical inputs.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn test_avx2_matches_fallback() {
        if !simd_supported() {
            return;
        }
        use super::super::fallback;
        for dim in [8, 16, 32, 64, 128, 256] {
            let original: Vec<f32> = (0..dim).map(|i| ((i * 7) % 100) as f32 * 0.01).collect();
            let query: Vec<f32> = (0..dim).map(|i| ((i * 13) % 100) as f32 * 0.01).collect();
            let qvec = original.quantize();
            let query_sum: f32 = query.iter().sum();
            let avx2_result = unsafe { dot_product_adc(&query, &qvec, query_sum) };
            let fallback_result = fallback::dot_product_adc(&query, &qvec, query_sum);
            assert!(
                (avx2_result - fallback_result).abs() < 1e-4,
                "Dimension {}: AVX2 {} != Fallback {}",
                dim,
                avx2_result,
                fallback_result
            );
        }
    }
}