use super::vector::Int8QuantizedVector;
use std::sync::OnceLock;
pub mod fallback;
#[cfg(target_arch = "x86_64")]
pub mod x86_avx2;
/// Signature shared by every ADC dot-product kernel (portable and SIMD).
///
/// `query_sum` is expected to be the sum of `query`'s components; the tests in this
/// module compute it as `query.iter().sum()`.
pub type DotProductFn = fn(query: &[f32], qvec: &Int8QuantizedVector, query_sum: f32) -> f32;

/// Kernel chosen on first use by `get_best_dot_product_impl` and reused for every later call.
static DOT_PRODUCT_ADC_IMPL: OnceLock<DotProductFn> = OnceLock::new();
/// Safe-signature adapter around `x86_avx2::dot_product_adc`, so the AVX2 kernel can be
/// stored as a [`DotProductFn`].
///
/// Calling this is only sound on a CPU that supports AVX2. The kernel is presumably
/// compiled with `#[target_feature(enable = "avx2")]`; NOTE(review): confirm in `x86_avx2`.
/// The function is private. `get_best_dot_product_impl` hands it out only after
/// `is_x86_feature_detected!("avx2")` succeeds. The debug assertion below catches any
/// future call site that bypasses that check.
#[cfg(target_arch = "x86_64")]
fn avx2_dot_product_adc_safe(query: &[f32], qvec: &Int8QuantizedVector, query_sum: f32) -> f32 {
    debug_assert!(
        is_x86_feature_detected!("avx2"),
        "AVX2 ADC kernel invoked on a CPU without AVX2 support"
    );
    // SAFETY: this wrapper is only selected by the dispatcher after runtime AVX2 detection
    // (asserted above in debug builds). Any further preconditions of
    // `x86_avx2::dot_product_adc` (e.g. on slice lengths) are not visible from this file.
    // NOTE(review): verify them against that function's `# Safety` section.
    unsafe { x86_avx2::dot_product_adc(query, qvec, query_sum) }
}
/// Selects the fastest ADC dot-product kernel usable on the running CPU.
///
/// The selection order is AVX2 on x86_64 when the feature is detected at runtime, and the
/// portable `fallback::dot_product_adc` in every other case.
fn get_best_dot_product_impl() -> DotProductFn {
    detect_accelerated_dot_product_impl().unwrap_or(fallback::dot_product_adc)
}

/// Returns the SIMD kernel for this target when the CPU supports it, else `None`.
#[cfg(target_arch = "x86_64")]
fn detect_accelerated_dot_product_impl() -> Option<DotProductFn> {
    is_x86_feature_detected!("avx2").then_some(avx2_dot_product_adc_safe as DotProductFn)
}

/// No SIMD kernel is provided for this target architecture.
#[cfg(not(target_arch = "x86_64"))]
fn detect_accelerated_dot_product_impl() -> Option<DotProductFn> {
    None
}
pub fn dot_product_adc(query: &[f32], qvec: &Int8QuantizedVector, query_sum: f32) -> f32 {
let implementation = DOT_PRODUCT_ADC_IMPL.get_or_init(get_best_dot_product_impl);
implementation(query, qvec, query_sum)
}
/// Squared Euclidean distance between `query` and a quantized vector.
///
/// The distance is computed through the expansion `‖q‖² + ‖x‖² − 2·q·x`, where `‖x‖²`
/// comes from `qvec.metadata.squared_sum` and `q·x` from [`dot_product_adc`].
/// `query_norm_sq` is expected to be the sum of squares of `query`'s components.
/// The result is clamped at zero. Rounding and quantization error can push the expansion
/// slightly negative, and a squared distance never is.
pub fn l2_squared_adc(
    query: &[f32],
    qvec: &Int8QuantizedVector,
    query_sum: f32,
    query_norm_sq: f32,
) -> f32 {
    let cross_term = 2.0 * dot_product_adc(query, qvec, query_sum);
    let expanded = query_norm_sq + qvec.metadata.squared_sum - cross_term;
    f32::max(expanded, 0.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::quantization::{Dequantize, Quantize};

    /// Dot product of `query` with the dequantized form of `qvec`, which is the value
    /// the ADC kernels are expected to reproduce.
    ///
    /// Takes `qvec` by value so this works whether `Dequantize` consumes `self` or borrows it.
    fn reference_dot(query: &[f32], qvec: Int8QuantizedVector) -> f32 {
        let dequantized = qvec.dequantize();
        query.iter().zip(dequantized.iter()).map(|(q, d)| q * d).sum()
    }

    /// Magnitude-scaled tolerance. A fixed absolute bound would be too strict for long
    /// vectors, where summation-order differences between kernels grow with the result.
    fn tolerance(expected: f32, rel: f32) -> f32 {
        rel * expected.abs().max(1.0)
    }

    #[test]
    fn test_dot_product_adc_basic() {
        let original = vec![0.1f32, 0.5, -0.2, 0.8];
        let query = vec![0.2f32, 0.4, 0.1, 0.7];
        let qvec = original.quantize();
        let query_sum: f32 = query.iter().sum();
        let result = dot_product_adc(&query, &qvec, query_sum);
        assert!(result.is_finite());
        let expected_dot = reference_dot(&query, qvec);
        assert!((result - expected_dot).abs() < 1e-3);
    }

    #[test]
    fn test_various_dimensions() {
        // Sizes straddle SIMD lane widths so both full-chunk and remainder paths run.
        for dim in [1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            let original: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.01).collect();
            let query: Vec<f32> = (0..dim).map(|i| (i as f32) * 0.02).collect();
            let qvec = original.quantize();
            let query_sum: f32 = query.iter().sum();
            let result = dot_product_adc(&query, &qvec, query_sum);
            assert!(result.is_finite(), "Failed for dimension {}", dim);
            // Finiteness alone would not catch a kernel that mishandles the tail elements.
            let expected = reference_dot(&query, qvec);
            assert!(
                (result - expected).abs() < tolerance(expected, 1e-3),
                "dimension {dim}: got {result}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_dispatch_matches_fallback() {
        for dim in [1, 7, 8, 17, 33, 64, 129] {
            let original: Vec<f32> = (0..dim).map(|i| ((i % 11) as f32 - 5.0) * 0.1).collect();
            let query: Vec<f32> = (0..dim).map(|i| ((i % 7) as f32 - 3.0) * 0.05).collect();
            let qvec = original.quantize();
            let query_sum: f32 = query.iter().sum();
            let dispatched = dot_product_adc(&query, &qvec, query_sum);
            let portable = fallback::dot_product_adc(&query, &qvec, query_sum);
            assert!(
                (dispatched - portable).abs() < tolerance(portable, 1e-3),
                "dimension {dim}: dispatched {dispatched} vs fallback {portable}"
            );
        }
    }

    #[test]
    fn test_l2_squared_adc() {
        let original = vec![0.1f32, 0.5, -0.2, 0.8];
        let far_query = vec![-1.0f32, -1.0, -1.0, -1.0];
        let exact_far: f32 = far_query
            .iter()
            .zip(&original)
            .map(|(q, o)| (q - o) * (q - o))
            .sum();
        let qvec = original.clone().quantize();

        let far_sum: f32 = far_query.iter().sum();
        let far_norm_sq: f32 = far_query.iter().map(|q| q * q).sum();
        let far = l2_squared_adc(&far_query, &qvec, far_sum, far_norm_sq);
        assert!(far.is_finite() && far >= 0.0);
        assert!(
            (far - exact_far).abs() < 0.1,
            "far distance {far} vs exact {exact_far}"
        );

        // Distance to (approximately) itself must stay non-negative thanks to the clamp.
        let self_sum: f32 = original.iter().sum();
        let self_norm_sq: f32 = original.iter().map(|x| x * x).sum();
        let near = l2_squared_adc(&original, &qvec, self_sum, self_norm_sq);
        assert!(near >= 0.0 && near < 0.1, "self distance {near}");
    }
}