use std::f32::consts::PI;
/// Scoring function used to compare a query against stored vectors.
///
/// Whether a larger or smaller score is preferred depends on the variant; see
/// `DistanceMetric::higher_is_better`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    L2,
    /// Cosine similarity.
    Cosine,
    /// Inner (dot) product.
    InnerProduct,
    /// Negated inner product.
    NegativeInnerProduct,
}
impl DistanceMetric {
    /// Returns `true` for metrics where a larger score means a better match
    /// (`Cosine` and `InnerProduct`), `false` for `L2` and `NegativeInnerProduct`.
    pub fn higher_is_better(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Cosine | Self::InnerProduct => true,
            Self::L2 | Self::NegativeInnerProduct => false,
        }
    }

    /// Returns `true` only for `Cosine`.
    pub fn requires_normalization(&self) -> bool {
        match self {
            Self::Cosine => true,
            Self::L2 | Self::InnerProduct | Self::NegativeInnerProduct => false,
        }
    }
}
/// Summary of how a set of vectors spreads around a centroid, measured through
/// their dot products with that centroid.
#[derive(Debug, Clone)]
pub struct SphericalCapMetadata {
    /// Centroid the dot products are taken against.
    pub centroid: Vec<f32>,
    /// Angular radius in radians: `acos` of the smallest dot product, clamped to `[-1, 1]`.
    pub theta_max: f32,
    /// Smallest dot product between a member vector and the centroid.
    pub min_dot_to_centroid: f32,
    /// Largest dot product between a member vector and the centroid.
    pub max_dot_to_centroid: f32,
    /// Number of vectors summarized.
    pub vector_count: u32,
    /// Arithmetic mean of the dot products.
    pub mean_dot_to_centroid: f32,
}
impl SphericalCapMetadata {
    /// Builds the statistics for `vectors` relative to `centroid`.
    ///
    /// `theta_max` is `acos` of the smallest dot product after clamping it to `[-1, 1]`.
    /// That value is a true angle only if the vectors and the centroid are unit length;
    /// nothing here normalizes or verifies that.
    ///
    /// An empty input yields placeholder statistics: all dot fields `1.0`, `theta_max` `0.0`,
    /// `vector_count` `0`.
    pub fn from_vectors(vectors: &[Vec<f32>], centroid: &[f32]) -> Self {
        Self::from_dots(vectors.iter().map(|v| dot_product(v, centroid)), centroid)
    }

    /// Same as [`Self::from_vectors`], but reads the vectors from one contiguous buffer
    /// holding `data.len() / dim` rows of `dim` values each.
    ///
    /// Trailing values that do not fill a whole row are ignored. A `dim` of zero is treated
    /// as "no vectors" instead of dividing by zero.
    pub fn from_flat_vectors(data: &[f32], dim: usize, centroid: &[f32]) -> Self {
        if dim == 0 {
            return Self::empty(centroid);
        }
        Self::from_dots(
            data.chunks_exact(dim).map(|row| dot_product(row, centroid)),
            centroid,
        )
    }

    /// Placeholder statistics for a set with no vectors.
    fn empty(centroid: &[f32]) -> Self {
        Self {
            centroid: centroid.to_vec(),
            theta_max: 0.0,
            min_dot_to_centroid: 1.0,
            max_dot_to_centroid: 1.0,
            vector_count: 0,
            mean_dot_to_centroid: 1.0,
        }
    }

    /// Folds a stream of per-vector dot products into the summary statistics.
    fn from_dots(dots: impl Iterator<Item = f32>, centroid: &[f32]) -> Self {
        let mut count: usize = 0;
        let mut min_dot = f32::MAX;
        let mut max_dot = f32::MIN;
        let mut sum_dot = 0.0f32;
        for dot in dots {
            min_dot = min_dot.min(dot);
            max_dot = max_dot.max(dot);
            sum_dot += dot;
            count += 1;
        }
        if count == 0 {
            return Self::empty(centroid);
        }
        Self {
            centroid: centroid.to_vec(),
            theta_max: min_dot.clamp(-1.0, 1.0).acos(),
            min_dot_to_centroid: min_dot,
            max_dot_to_centroid: max_dot,
            // Truncating cast: the count field is `u32`.
            vector_count: count as u32,
            mean_dot_to_centroid: sum_dot / count as f32,
        }
    }

    /// Folds one more vector into the statistics without revisiting earlier vectors.
    ///
    /// The stored centroid is not updated. When the metadata is still empty, the placeholder
    /// values are replaced outright so the result matches [`Self::from_vectors`] called on
    /// that single vector.
    pub fn add_vector(&mut self, vector: &[f32]) {
        let dot = dot_product(vector, &self.centroid);

        if self.vector_count == 0 {
            // The `1.0` placeholders describe no real vector; comparing against them would
            // leave a stale max (or min) behind. Seed with the same MAX/MIN fold that
            // `from_dots` uses so both construction paths agree.
            let min_dot = f32::MAX.min(dot);
            self.min_dot_to_centroid = min_dot;
            self.max_dot_to_centroid = f32::MIN.max(dot);
            self.theta_max = min_dot.clamp(-1.0, 1.0).acos();
            self.mean_dot_to_centroid = dot;
            self.vector_count = 1;
            return;
        }

        // Recover the running sum from the stored mean, then re-average.
        let old_sum = self.mean_dot_to_centroid * self.vector_count as f32;
        self.vector_count += 1;
        self.mean_dot_to_centroid = (old_sum + dot) / self.vector_count as f32;

        if dot < self.min_dot_to_centroid {
            self.min_dot_to_centroid = dot;
            // The angular radius only grows when a new minimum dot appears.
            self.theta_max = dot.clamp(-1.0, 1.0).acos();
        }
        if dot > self.max_dot_to_centroid {
            self.max_dot_to_centroid = dot;
        }
    }

    /// Angular radius in radians (the stored `theta_max`).
    pub fn angular_radius(&self) -> f32 {
        self.theta_max
    }

    /// Angular radius converted to degrees.
    pub fn angular_radius_degrees(&self) -> f32 {
        self.theta_max * 180.0 / PI
    }

    /// `1 - theta_max / PI`: `1.0` for a zero-radius cap, `0.0` when `theta_max` equals `PI`.
    pub fn tightness(&self) -> f32 {
        1.0 - (self.theta_max / PI)
    }
}
/// Euclidean summary of a set of vectors around a centroid.
#[derive(Debug, Clone)]
pub struct L2ListMetadata {
    /// Centroid the distances are measured from.
    pub centroid: Vec<f32>,
    /// Largest L2 distance from the centroid to any member vector.
    pub radius: f32,
    /// Mean L2 distance from the centroid to the member vectors.
    pub mean_radius: f32,
    /// Number of vectors summarized.
    pub vector_count: u32,
}
impl L2ListMetadata {
    /// Computes the maximum and mean L2 distance from `centroid` to each of `vectors`.
    ///
    /// An empty input produces a zero radius, zero mean radius and a count of zero.
    pub fn from_vectors(vectors: &[Vec<f32>], centroid: &[f32]) -> Self {
        let count = vectors.len();
        let (radius, mean_radius) = if count == 0 {
            // Avoids the 0/0 the general path would compute for the mean.
            (0.0, 0.0)
        } else {
            let (farthest, total) = vectors
                .iter()
                .map(|v| l2_distance(v, centroid))
                .fold((0.0f32, 0.0f32), |(far, sum), d| (far.max(d), sum + d));
            (farthest, total / count as f32)
        };
        Self {
            centroid: centroid.to_vec(),
            radius,
            mean_radius,
            vector_count: count as u32,
        }
    }

    /// Lower bound on the L2 distance from `query` to any member vector:
    /// the distance to the centroid minus the radius, floored at zero.
    pub fn lower_bound(&self, query: &[f32]) -> f32 {
        let gap = l2_distance(query, &self.centroid) - self.radius;
        gap.max(0.0)
    }
}
/// Computes per-list score bounds for one borrowed query under one metric.
pub struct ListBoundComputer<'a> {
    /// Query vector, borrowed for the lifetime of the computer.
    query: &'a [f32],
    /// L2 norm of `query`, captured at construction.
    query_norm: f32,
    /// Metric that selects which bound `compute_bound` returns.
    metric: DistanceMetric,
}
impl<'a> ListBoundComputer<'a> {
    /// Creates a computer for `query`, recording its L2 norm.
    ///
    /// NOTE(review): `query_norm` is stored but none of the methods in this `impl` read it;
    /// the bounds below use the raw dot product, which presumably expects a unit-length
    /// query and centroid — confirm against callers.
    pub fn new(query: &'a [f32], metric: DistanceMetric) -> Self {
        Self {
            query,
            query_norm: l2_norm(query),
            metric,
        }
    }

    /// Upper bound on the cosine between the query and any vector inside the cap.
    ///
    /// Takes the angle between query and centroid (from their clamped dot product),
    /// subtracts the cap's angular radius, floors the result at zero, and returns its cosine.
    pub fn cosine_upper_bound(&self, metadata: &SphericalCapMetadata) -> f32 {
        let cos_to_centroid = dot_product(self.query, &metadata.centroid).clamp(-1.0, 1.0);
        let closest_angle = (cos_to_centroid.acos() - metadata.theta_max).max(0.0);
        closest_angle.cos()
    }

    /// Lower bound on the L2 distance from the query to any vector in the list:
    /// distance to the centroid minus the list radius, floored at zero.
    pub fn l2_lower_bound(&self, metadata: &L2ListMetadata) -> f32 {
        let gap = l2_distance(self.query, &metadata.centroid) - metadata.radius;
        gap.max(0.0)
    }

    /// Returns the bound appropriate for the configured metric.
    ///
    /// * `Cosine` / `InnerProduct`: the cosine upper bound.
    /// * `NegativeInnerProduct`: the negated cosine upper bound.
    /// * `L2`: the L2 lower bound when `l2` metadata is supplied; otherwise
    ///   `sqrt(max(0, 2 - 2 * cosine_upper_bound))` derived from the cap.
    pub fn compute_bound(&self, cap: &SphericalCapMetadata, l2: Option<&L2ListMetadata>) -> f32 {
        match (self.metric, l2) {
            (DistanceMetric::L2, Some(l2_meta)) => self.l2_lower_bound(l2_meta),
            (DistanceMetric::L2, None) => {
                let cos_bound = self.cosine_upper_bound(cap);
                (2.0 - 2.0 * cos_bound).max(0.0).sqrt()
            }
            (DistanceMetric::Cosine, _) | (DistanceMetric::InnerProduct, _) => {
                self.cosine_upper_bound(cap)
            }
            (DistanceMetric::NegativeInnerProduct, _) => -self.cosine_upper_bound(cap),
        }
    }
}
/// A list index paired with the score bound computed for that list.
#[derive(Debug, Clone)]
pub struct ListBound {
    /// Index of the list this bound belongs to.
    pub list_idx: u32,
    /// Bound value; how it is ranked depends on the metric passed to the ordering helpers.
    pub bound: f32,
}
impl ListBound {
    /// Sorts `bounds` in place into the order lists should be probed.
    ///
    /// `Cosine` and `InnerProduct` sort by descending bound; `L2` and `NegativeInnerProduct`
    /// sort by ascending bound. The sort is stable.
    ///
    /// A NaN bound no longer panics. NaN carries no pruning information, so such entries are
    /// placed first for every metric: they get probed rather than skipped by early
    /// termination.
    pub fn order_for_probing(bounds: &mut [ListBound], metric: DistanceMetric) {
        use std::cmp::Ordering;

        let descending = match metric {
            DistanceMetric::Cosine | DistanceMetric::InnerProduct => true,
            DistanceMetric::L2 | DistanceMetric::NegativeInnerProduct => false,
        };

        bounds.sort_by(|a, b| match (a.bound.is_nan(), b.bound.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            (false, false) => {
                // Both operands are non-NaN here, so `partial_cmp` always yields `Some`.
                let ascending = a.bound.partial_cmp(&b.bound).unwrap_or(Ordering::Equal);
                if descending {
                    ascending.reverse()
                } else {
                    ascending
                }
            }
        });
    }

    /// Returns `true` when `kth_score` is strictly better than `best_remaining_bound`.
    ///
    /// "Better" means greater for `Cosine` / `InnerProduct` and smaller for `L2` /
    /// `NegativeInnerProduct`. Any comparison involving NaN is `false`.
    pub fn can_terminate(
        kth_score: f32,
        best_remaining_bound: f32,
        metric: DistanceMetric,
    ) -> bool {
        match metric {
            DistanceMetric::Cosine | DistanceMetric::InnerProduct => {
                kth_score > best_remaining_bound
            }
            DistanceMetric::L2 | DistanceMetric::NegativeInnerProduct => {
                kth_score < best_remaining_bound
            }
        }
    }
}
/// Per-list metadata bundle: the spherical-cap summary, an optional L2 summary,
/// and the index of the list they describe.
#[derive(Debug, Clone)]
pub struct UnifiedListMetadata {
    /// Spherical-cap summary; always present.
    pub cap: SphericalCapMetadata,
    /// L2 summary, when one has been attached.
    pub l2: Option<L2ListMetadata>,
    /// Index of the list described.
    pub list_idx: u32,
}
impl UnifiedListMetadata {
    /// Creates metadata for list `list_idx` holding only the cap summary (`l2` is `None`).
    pub fn new(list_idx: u32, cap: SphericalCapMetadata) -> Self {
        let l2 = None;
        Self { cap, l2, list_idx }
    }

    /// Consumes `self` and returns it with `l2` attached, replacing any previous L2 summary.
    pub fn with_l2(self, l2: L2ListMetadata) -> Self {
        Self {
            l2: Some(l2),
            ..self
        }
    }
}
/// Sum of element-wise products of `a` and `b`.
///
/// Debug builds assert the slices have equal length; in release builds the extra elements
/// of the longer slice are ignored because `zip` stops at the shorter one.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Euclidean length of `v`: the square root of the sum of squared components.
#[inline]
fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&c| c * c).sum();
    sum_of_squares.sqrt()
}
/// Euclidean distance between `a` and `b`.
///
/// Debug builds assert the slices have equal length; in release builds only the common
/// prefix is compared because `zip` stops at the shorter slice.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta.powi(2)
        })
        .sum();
    squared.sqrt()
}
/// Scales `v` in place to unit L2 length.
///
/// Vectors whose norm is not greater than `1e-10` (including a NaN norm) are left untouched.
pub fn normalize_inplace(v: &mut [f32]) {
    let magnitude = l2_norm(v);
    if magnitude > 1e-10 {
        v.iter_mut().for_each(|component| *component /= magnitude);
    }
}
/// Returns a unit-length copy of `v`.
///
/// When the norm is not greater than `1e-10` (including a NaN norm) the copy is returned
/// unscaled.
pub fn normalize(v: &[f32]) -> Vec<f32> {
    let magnitude = l2_norm(v);
    let mut unit = v.to_vec();
    if magnitude > 1e-10 {
        for component in unit.iter_mut() {
            *component = *component / magnitude;
        }
    }
    unit
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Vectors tilted slightly away from the centroid produce a small, non-zero cap.
    #[test]
    fn test_spherical_cap_metadata() {
        let centroid = vec![1.0, 0.0, 0.0];
        let vectors = vec![
            normalize(&[1.0, 0.1, 0.0]),
            normalize(&[1.0, -0.1, 0.0]),
            normalize(&[1.0, 0.0, 0.1]),
            normalize(&[1.0, 0.0, -0.1]),
        ];
        let metadata = SphericalCapMetadata::from_vectors(&vectors, &centroid);
        assert!(metadata.theta_max > 0.0);
        assert!(metadata.theta_max < PI / 4.0);
        assert!(metadata.tightness() > 0.5);
    }

    /// The bound is 1 for a query on the centroid, and cos(pi/2 - theta) = sin(theta)
    /// for a query orthogonal to it.
    #[test]
    fn test_cosine_upper_bound() {
        let centroid = vec![1.0, 0.0, 0.0];
        let metadata = SphericalCapMetadata {
            centroid: centroid.clone(),
            theta_max: 0.3,
            min_dot_to_centroid: 0.3_f32.cos(),
            max_dot_to_centroid: 1.0,
            vector_count: 10,
            mean_dot_to_centroid: 0.95,
        };

        let query = vec![1.0, 0.0, 0.0];
        let computer = ListBoundComputer::new(&query, DistanceMetric::Cosine);
        let bound = computer.cosine_upper_bound(&metadata);
        assert!((bound - 1.0).abs() < 0.01);

        let query2 = vec![0.0, 1.0, 0.0];
        let computer2 = ListBoundComputer::new(&query2, DistanceMetric::Cosine);
        let bound2 = computer2.cosine_upper_bound(&metadata);
        assert!((bound2 - 0.3_f32.sin()).abs() < 0.01);
    }

    /// Outside the ball the bound is distance minus radius; inside it is zero.
    #[test]
    fn test_l2_lower_bound() {
        let centroid = vec![0.0, 0.0, 0.0];
        let metadata = L2ListMetadata {
            centroid,
            radius: 1.0,
            mean_radius: 0.5,
            vector_count: 100,
        };

        let query = vec![2.0, 0.0, 0.0];
        let computer = ListBoundComputer::new(&query, DistanceMetric::L2);
        let lb = computer.l2_lower_bound(&metadata);
        assert!((lb - 1.0).abs() < 0.01);

        let query2 = vec![0.5, 0.0, 0.0];
        let computer2 = ListBoundComputer::new(&query2, DistanceMetric::L2);
        let lb2 = computer2.l2_lower_bound(&metadata);
        assert!((lb2 - 0.0).abs() < 0.01);
    }

    /// Cosine probes the highest bound first; L2 probes the lowest bound first.
    #[test]
    fn test_list_ordering() {
        let mut bounds = vec![
            ListBound {
                list_idx: 0,
                bound: 0.5,
            },
            ListBound {
                list_idx: 1,
                bound: 0.9,
            },
            ListBound {
                list_idx: 2,
                bound: 0.3,
            },
        ];

        ListBound::order_for_probing(&mut bounds, DistanceMetric::Cosine);
        assert_eq!(bounds[0].list_idx, 1);
        assert_eq!(bounds[1].list_idx, 0);
        assert_eq!(bounds[2].list_idx, 2);

        ListBound::order_for_probing(&mut bounds, DistanceMetric::L2);
        assert_eq!(bounds[0].list_idx, 2);
        assert_eq!(bounds[1].list_idx, 0);
        assert_eq!(bounds[2].list_idx, 1);
    }
}