use super::simd::{dot_product_adc, l2_squared_adc};
use super::vector::Int8QuantizedVector;
use hnsw_rs::prelude::Distance;
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
thread_local! {
    /// Per-thread ADC query context consulted by the `Distance` impls in this module.
    ///
    /// `Some(ctx)` switches distance evaluation to the asymmetric path: the f32 query in
    /// `ctx` is compared against the quantized stored vector. `None` makes the impls
    /// dequantize both vectors and compare them symmetrically. Because this is
    /// thread-local, a context set on one thread is invisible to every other thread.
    static ADC_QUERY_CONTEXT: RefCell<Option<AdcQueryContext>> = const { RefCell::new(None) };
}
/// Full-precision query state for asymmetric distance computation (ADC) against
/// int8-quantized stored vectors.
///
/// `AdcQueryContext::new` precomputes `sum` and `norm_sq` from `query`, so the
/// asymmetric distance functions can read them instead of recomputing them per
/// evaluation. The fields are public, so a hand-built value is not guaranteed to keep
/// `sum`/`norm_sq` consistent with `query`.
#[derive(Debug, Clone)]
pub struct AdcQueryContext {
    /// The unquantized query vector.
    pub query: Vec<f32>,
    /// Sum of the components of `query` (as computed by `new`).
    pub sum: f32,
    /// Sum of the squared components of `query`, i.e. its squared L2 norm (as computed by `new`).
    pub norm_sq: f32,
    /// Metric applied when this context drives distance evaluation.
    pub metric: AdcDistanceMetric,
}
/// Distance metric for int8-quantized vectors.
///
/// The comparison is either asymmetric (f32 query vs. quantized stored vector) or
/// symmetric (both vectors dequantized). Serializable with serde; `Cosine` is the
/// `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum AdcDistanceMetric {
    /// Cosine distance `1 - cosine_similarity`, clamped below at 0.
    /// Yields 1.0 when either vector has zero norm.
    #[default]
    Cosine,
    /// Squared Euclidean (L2) distance.
    L2Squared,
    /// Distance derived from the dot product: smaller for larger dot products, clamped
    /// below at 0.
    Dot,
}
impl AdcQueryContext {
pub fn new(query: Vec<f32>, metric: AdcDistanceMetric) -> Self {
let sum: f32 = query.iter().sum();
let norm_sq: f32 = query.iter().map(|&x| x * x).sum();
Self {
query,
sum,
norm_sq,
metric,
}
}
pub fn with_query(query: Vec<f32>) -> Self {
Self::new(query, AdcDistanceMetric::default())
}
}
/// Installs an ADC query context for `query`/`metric` on the current thread.
///
/// Any previously installed context is replaced. `query` is copied into the context.
pub fn set_adc_query_context(query: &[f32], metric: AdcDistanceMetric) {
    let fresh = AdcQueryContext::new(query.to_vec(), metric);
    ADC_QUERY_CONTEXT.with(|slot| {
        slot.borrow_mut().replace(fresh);
    });
}
/// Removes the current thread's ADC query context, if any.
///
/// Subsequent distance evaluations on this thread use the dequantized path.
pub fn clear_adc_query_context() {
    ADC_QUERY_CONTEXT.with(|slot| {
        slot.borrow_mut().take();
    });
}
/// Runs `f` against the current thread's ADC query context.
///
/// Returns `Some(f(ctx))` when a context is installed and `None` otherwise; `f` is not
/// called in the `None` case.
///
/// # Panics
/// The context stays immutably borrowed while `f` runs. Calling
/// `set_adc_query_context` or `clear_adc_query_context` from inside `f` panics with a
/// `RefCell` borrow error.
pub fn with_adc_query_context<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&AdcQueryContext) -> R,
{
    ADC_QUERY_CONTEXT.with(|slot| {
        let current = slot.borrow();
        current.as_ref().map(f)
    })
}
/// Reports whether the current thread has an ADC query context installed.
pub fn has_adc_query_context() -> bool {
    ADC_QUERY_CONTEXT.with(|slot| {
        let current = slot.borrow();
        current.is_some()
    })
}
/// Asymmetric cosine distance between the f32 query in `query_ctx` and the quantized
/// `qvec`.
///
/// The dot product comes from `dot_product_adc`. It is normalised by the query norm
/// (from the precomputed `norm_sq`) and by the stored vector's norm from its metadata.
/// Returns 1.0 if either norm is zero; otherwise `1 - similarity`, clamped below at 0.
fn asymmetric_cosine_distance(query_ctx: &AdcQueryContext, qvec: &Int8QuantizedVector) -> f32 {
    let dot = dot_product_adc(&query_ctx.query, qvec, query_ctx.sum);
    let q_len = query_ctx.norm_sq.sqrt();
    let s_len = qvec.metadata.norm();
    let degenerate = q_len == 0.0 || s_len == 0.0;
    if degenerate {
        1.0
    } else {
        let similarity = dot / (q_len * s_len);
        (1.0 - similarity).max(0.0)
    }
}
/// Asymmetric squared-L2 distance between the f32 query in `query_ctx` and the
/// quantized `qvec`.
///
/// Delegates to `l2_squared_adc` with the precomputed query sum and squared norm.
///
/// The result is clamped below at 0, matching the cosine and dot variants. Passing the
/// precomputed `norm_sq` suggests `l2_squared_adc` evaluates the expanded form
/// `|q|^2 - 2 q.x + |x|^2`. That form can dip slightly below zero through
/// floating-point cancellation when the vectors are nearly identical, and a negative
/// "distance" is never meaningful.
fn asymmetric_l2_squared_distance(query_ctx: &AdcQueryContext, qvec: &Int8QuantizedVector) -> f32 {
    l2_squared_adc(&query_ctx.query, qvec, query_ctx.sum, query_ctx.norm_sq).max(0.0)
}
/// Asymmetric dot-product distance between the f32 query in `query_ctx` and the
/// quantized `qvec`.
///
/// Maps the ADC dot product `d` to `(1 - d) / 2`, clamped below at 0. For
/// (approximately) unit-norm inputs this lands in roughly `[0, 1]`.
fn asymmetric_dot_distance(query_ctx: &AdcQueryContext, qvec: &Int8QuantizedVector) -> f32 {
    let similarity = dot_product_adc(&query_ctx.query, qvec, query_ctx.sum);
    let halved_gap = (1.0 - similarity) / 2.0;
    f32::max(halved_gap, 0.0)
}
impl Distance<Int8QuantizedVector> for AdcDistanceMetric {
    /// Distance between the point in `query` and the point in `stored`.
    ///
    /// Each slice holds one point; only element `[0]` is read.
    ///
    /// * If this thread has an ADC query context (see `set_adc_query_context`), the
    ///   context's f32 query is compared against `stored[0]` using the context's own
    ///   metric. `query` is not used at all.
    /// * Otherwise both points are dequantized with `to_f32` and compared using `self`.
    ///
    /// Both paths report `Dot` as `(1 - dot) / 2`, clamped below at 0, so their scores
    /// are on the same scale.
    ///
    /// NOTE(review): `query` is ignored while a context is set. A context left installed
    /// after a search therefore redirects every later evaluation on this thread (e.g. an
    /// insert) to the stale query, so clear it once the search is done. The context is
    /// thread-local, so evaluations running on other threads always take the
    /// dequantized path.
    ///
    /// # Panics
    /// Panics if `stored` is empty, or if `query` is empty and no context is set.
    fn eval(&self, query: &[Int8QuantizedVector], stored: &[Int8QuantizedVector]) -> f32 {
        let stored = &stored[0];
        // Evaluate while the thread-local is borrowed instead of cloning the context out.
        // Cloning copied the entire query vector on every distance call in the search loop.
        let asymmetric = with_adc_query_context(|ctx| match ctx.metric {
            AdcDistanceMetric::Cosine => asymmetric_cosine_distance(ctx, stored),
            AdcDistanceMetric::L2Squared => asymmetric_l2_squared_distance(ctx, stored),
            AdcDistanceMetric::Dot => asymmetric_dot_distance(ctx, stored),
        });
        asymmetric.unwrap_or_else(|| symmetric_distance(*self, &query[0], stored))
    }
}

/// Symmetric fallback for [`AdcDistanceMetric`]'s `Distance` impl: dequantizes both
/// vectors with `to_f32` and compares them in f32 under `metric`.
///
/// Components are paired with `zip` for every metric. If the vectors differ in length,
/// the trailing components of the longer one are ignored rather than causing a panic.
fn symmetric_distance(
    metric: AdcDistanceMetric,
    a: &Int8QuantizedVector,
    b: &Int8QuantizedVector,
) -> f32 {
    let a = a.to_f32();
    let b = b.to_f32();
    let pairs = || a.iter().zip(b.iter());
    match metric {
        AdcDistanceMetric::Cosine => {
            let (mut dot, mut norm_a, mut norm_b) = (0.0f32, 0.0f32, 0.0f32);
            for (x, y) in pairs() {
                dot += x * y;
                norm_a += x * x;
                norm_b += y * y;
            }
            if norm_a == 0.0 || norm_b == 0.0 {
                // A zero vector has no direction; report the orthogonal distance.
                return 1.0;
            }
            (1.0 - dot / (norm_a.sqrt() * norm_b.sqrt())).max(0.0)
        }
        AdcDistanceMetric::L2Squared => pairs().map(|(x, y)| (x - y) * (x - y)).sum(),
        AdcDistanceMetric::Dot => {
            let dot: f32 = pairs().map(|(x, y)| x * y).sum();
            // Same (1 - dot) / 2 mapping as `asymmetric_dot_distance`, so a search scored
            // with or without a query context reports comparable values.
            ((1.0 - dot) / 2.0).max(0.0)
        }
    }
}
/// `hnsw_rs` distance functor over `Int8QuantizedVector`s.
///
/// It forwards to the `Distance` implementation of its `metric`. `Default` uses
/// `AdcDistanceMetric`'s default (`Cosine`).
#[derive(Debug, Clone, Copy, Default)]
pub struct Int8AdcDistance {
    /// Metric used when no thread-local ADC query context is set. When a context is set,
    /// the context's own metric is used instead.
    pub metric: AdcDistanceMetric,
}
impl Int8AdcDistance {
pub fn new(metric: AdcDistanceMetric) -> Self {
Self { metric }
}
pub fn cosine() -> Self {
Self::new(AdcDistanceMetric::Cosine)
}
pub fn l2_squared() -> Self {
Self::new(AdcDistanceMetric::L2Squared)
}
pub fn dot() -> Self {
Self::new(AdcDistanceMetric::Dot)
}
}
impl Distance<Int8QuantizedVector> for Int8AdcDistance {
fn eval(&self, query: &[Int8QuantizedVector], stored: &[Int8QuantizedVector]) -> f32 {
self.metric.eval(query, stored)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::search::quantization::Quantize;

    #[test]
    fn test_adc_query_context_creation() {
        let query = vec![0.1, 0.2, 0.3, 0.4];
        let ctx = AdcQueryContext::with_query(query.clone());
        assert_eq!(ctx.query, query);
        assert_eq!(ctx.metric, AdcDistanceMetric::Cosine);
        assert!((ctx.sum - 1.0).abs() < 1e-6);
        assert!((ctx.norm_sq - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_set_and_clear_adc_context() {
        // Assumes this test's thread starts with no context (libtest runs each test on
        // its own thread by default).
        assert!(!has_adc_query_context());
        let query = vec![0.1, 0.2, 0.3];
        set_adc_query_context(&query, AdcDistanceMetric::Cosine);
        assert!(has_adc_query_context());
        // `expect` makes a missing context a failure. Previously the closure's asserts
        // would silently not run if the context was missing.
        let (stored_query, metric) = with_adc_query_context(|ctx| (ctx.query.clone(), ctx.metric))
            .expect("context was just set");
        assert_eq!(stored_query, vec![0.1, 0.2, 0.3]);
        assert_eq!(metric, AdcDistanceMetric::Cosine);
        clear_adc_query_context();
        assert!(!has_adc_query_context());
        assert!(with_adc_query_context(|_| ()).is_none());
    }

    #[test]
    fn test_int8_adc_distance_trait() {
        let query = vec![1.0, 0.0, 0.0, 0.0];
        set_adc_query_context(&query, AdcDistanceMetric::Cosine);
        let stored_f32 = vec![0.5, 0.5, 0.5, 0.5];
        let stored = stored_f32.quantize();
        let distance_fn = Int8AdcDistance::cosine();
        let distance =
            distance_fn.eval(std::slice::from_ref(&stored), std::slice::from_ref(&stored));
        // Clear before asserting so the context never outlives the evaluation.
        clear_adc_query_context();
        assert!(
            (0.0..=1.0).contains(&distance),
            "cosine ADC distance out of range: {distance}"
        );
    }

    #[test]
    fn test_symmetric_fallback_self_distance_is_zero() {
        // No context installed: both vectors are dequantized and compared symmetrically.
        assert!(!has_adc_query_context());
        let stored_f32 = vec![0.5, -0.25, 0.75, 0.1];
        let stored = stored_f32.quantize();
        let pair = std::slice::from_ref(&stored);
        let cosine = Int8AdcDistance::cosine().eval(pair, pair);
        assert!(cosine < 1e-5, "cosine self-distance: {cosine}");
        let l2 = Int8AdcDistance::l2_squared().eval(pair, pair);
        assert_eq!(l2, 0.0);
    }
}