/// An identified bundle of sub-vectors, one entry per sub-space.
///
/// The same type is used both for items stored in a [`ProductSearchIndex`] and
/// for queries passed to [`ProductSearchIndex::search`]; for a query only
/// `vectors` is read.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiVector {
    /// Caller-chosen identifier. It is copied into the `id` of every search
    /// candidate produced for this item and is the key used by `remove`.
    pub id: usize,
    /// The sub-vectors, in sub-space order. Position `i` of a query is scored
    /// against position `i` of an item.
    pub vectors: Vec<Vec<f32>>,
}
/// Selects how two equally sized sub-vectors are turned into a score.
///
/// Scores are always "higher is better"; see the per-variant notes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Euclidean distance, negated so that closer vectors score higher.
    L2,
    /// Cosine similarity.
    Cosine,
    /// Plain dot product.
    DotProduct,
}
#[derive(Debug, Clone)]
pub struct ProductSearchConfig {
pub sub_dimensions: usize,
pub distance_metric: DistanceMetric,
}
/// One scored item returned by a search.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchCandidate {
    /// Identifier of the matched item.
    pub id: usize,
    /// Per-sub-vector scores, in sub-space order (a single entry for
    /// `search_sub`). Higher is better.
    pub scores: Vec<f32>,
    /// Score used for ranking: the arithmetic mean of `scores` for a full
    /// search, or the single sub-vector score for `search_sub`.
    pub combined_score: f32,
}
/// A brute-force search index over [`MultiVector`] items.
///
/// Items are kept in insertion order in a plain `Vec`; every search scores
/// all of them.
#[derive(Debug, Clone)]
pub struct ProductSearchIndex {
    /// Metric and declared sub-vector count for this index.
    config: ProductSearchConfig,
    /// Stored items, in insertion order.
    items: Vec<MultiVector>,
}
impl ProductSearchIndex {
    /// Creates an empty index that scores with `config.distance_metric`.
    pub fn new(config: ProductSearchConfig) -> Self {
        Self {
            config,
            items: Vec::new(),
        }
    }

    /// Appends `item` to the index.
    ///
    /// Nothing is validated: ids are not checked for uniqueness and the number
    /// of sub-vectors is not compared with the configured `sub_dimensions`.
    pub fn insert(&mut self, item: MultiVector) {
        self.items.push(item);
    }

    /// Returns at most `k` candidates, best combined score first.
    ///
    /// Each item is compared with `query` position by position over the first
    /// `min(query.vectors.len(), item.vectors.len())` sub-vectors; the combined
    /// score is the mean of those per-position scores. An item is left out when
    /// that count is zero or when any compared pair differs in length.
    ///
    /// Candidates with equal scores keep insertion order. Candidates whose
    /// combined score is NaN are ranked after all others.
    pub fn search(&self, query: &MultiVector, k: usize) -> Vec<SearchCandidate> {
        if k == 0 {
            // Nothing can be returned, so skip scoring entirely.
            return Vec::new();
        }
        let candidates = self
            .items
            .iter()
            .filter_map(|item| self.score_all(query, item))
            .collect();
        Self::top_k(candidates, k)
    }

    /// Returns at most `k` candidates ranked on a single sub-space.
    ///
    /// `query_sub` is scored against sub-vector `sub_idx` of every item. Items
    /// that have no sub-vector at `sub_idx`, or whose sub-vector length differs
    /// from `query_sub.len()`, are left out. Ordering rules match
    /// [`ProductSearchIndex::search`].
    pub fn search_sub(&self, query_sub: &[f32], sub_idx: usize, k: usize) -> Vec<SearchCandidate> {
        if k == 0 {
            return Vec::new();
        }
        let candidates = self
            .items
            .iter()
            .filter_map(|item| {
                let item_sub = item.vectors.get(sub_idx)?;
                if item_sub.len() != query_sub.len() {
                    return None;
                }
                let score = self.compute_score(query_sub, item_sub);
                Some(SearchCandidate {
                    id: item.id,
                    scores: vec![score],
                    combined_score: score,
                })
            })
            .collect();
        Self::top_k(candidates, k)
    }

    /// Number of items currently stored.
    pub fn item_count(&self) -> usize {
        self.items.len()
    }

    /// The `sub_dimensions` value this index was configured with.
    pub fn sub_dimension_count(&self) -> usize {
        self.config.sub_dimensions
    }

    /// Removes every item whose id equals `id`.
    ///
    /// Returns `true` when at least one item was removed.
    pub fn remove(&mut self, id: usize) -> bool {
        let before = self.items.len();
        self.items.retain(|item| item.id != id);
        self.items.len() < before
    }

    /// Sorts `candidates` best-first and keeps the leading `k`.
    fn top_k(mut candidates: Vec<SearchCandidate>, k: usize) -> Vec<SearchCandidate> {
        // Stable sort: ties stay in insertion order.
        candidates.sort_by(Self::rank_descending);
        candidates.truncate(k);
        candidates
    }

    /// Orders candidates by descending combined score, with NaN scores last.
    ///
    /// Treating NaN as "equal to everything" is not a consistent ordering and
    /// lets a NaN candidate block better ones from moving ahead of it, so NaN
    /// is handled explicitly to keep the comparison a total order.
    fn rank_descending(a: &SearchCandidate, b: &SearchCandidate) -> std::cmp::Ordering {
        let (x, y) = (a.combined_score, b.combined_score);
        // `false < true`, so non-NaN scores sort ahead of NaN ones.
        x.is_nan()
            .cmp(&y.is_nan())
            .then_with(|| y.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Equal))
    }

    /// Scores `item` against `query` across their shared sub-vector positions.
    ///
    /// Returns `None` when there is no shared position or when any compared
    /// pair has different lengths.
    fn score_all(&self, query: &MultiVector, item: &MultiVector) -> Option<SearchCandidate> {
        // `zip` stops at the shorter list; one mismatched pair rejects the item.
        let scores = query
            .vectors
            .iter()
            .zip(&item.vectors)
            .map(|(qv, iv)| (qv.len() == iv.len()).then(|| self.compute_score(qv, iv)))
            .collect::<Option<Vec<f32>>>()?;
        if scores.is_empty() {
            return None;
        }
        let combined_score = scores.iter().sum::<f32>() / scores.len() as f32;
        Some(SearchCandidate {
            id: item.id,
            scores,
            combined_score,
        })
    }

    /// Scores one pair of sub-vectors with the configured metric.
    ///
    /// The L2 distance is negated so that, for every metric, a larger value
    /// means a better match.
    fn compute_score(&self, a: &[f32], b: &[f32]) -> f32 {
        match &self.config.distance_metric {
            DistanceMetric::L2 => -l2_distance(a, b),
            DistanceMetric::Cosine => cosine_sim(a, b),
            DistanceMetric::DotProduct => dot_product(a, b),
        }
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length the extra
/// trailing elements of the longer one are ignored.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
/// Cosine similarity between `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero magnitude, so the result is never
/// produced by dividing by zero. A NaN input propagates to a NaN result.
///
/// The slices are expected to have equal length. If they do not, the dot
/// product covers only the common prefix while each norm covers its whole
/// slice.
pub fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    let norm_a = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    // Rounding in the sums and square roots can push the ratio a few ULPs
    // past +/-1 (e.g. for parallel vectors); clamp so callers can rely on the
    // mathematical range, for instance before calling `acos`.
    (dot_product(a, b) / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
/// Dot product of `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length the extra
/// trailing elements of the longer one are ignored.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps one slice as a payload with a single sub-vector.
    fn one_sub(values: &[f32]) -> Vec<Vec<f32>> {
        vec![values.to_vec()]
    }

    /// Wraps two slices as a payload with two sub-vectors.
    fn two_subs(first: &[f32], second: &[f32]) -> Vec<Vec<f32>> {
        vec![first.to_vec(), second.to_vec()]
    }

    /// Builds an empty index with the given sub-vector count and metric.
    fn index_with(sub_dimensions: usize, distance_metric: DistanceMetric) -> ProductSearchIndex {
        ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions,
            distance_metric,
        })
    }

    /// Builds an empty index configured for one sub-vector per item.
    fn single_sub_index(metric: DistanceMetric) -> ProductSearchIndex {
        index_with(1, metric)
    }

    /// Shorthand constructor for a `MultiVector`.
    fn entry(id: usize, vectors: Vec<Vec<f32>>) -> MultiVector {
        MultiVector { id, vectors }
    }

    // ---- free distance functions ----

    #[test]
    fn test_l2_distance_zero() {
        let point = [1.0, 2.0];
        let distance = l2_distance(&point, &point);
        assert!(distance.abs() < 1e-6);
    }

    #[test]
    fn test_l2_distance_known() {
        let distance = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((distance - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_cosine_sim_identical() {
        let unit = [1.0f32, 0.0, 0.0];
        let similarity = cosine_sim(&unit, &unit);
        assert!((similarity - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_sim_orthogonal() {
        let similarity = cosine_sim(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(similarity.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_sim_opposite() {
        let similarity = cosine_sim(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((similarity + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_sim_zero_vector() {
        let similarity = cosine_sim(&[0.0, 0.0], &[1.0, 0.0]);
        assert_eq!(similarity, 0.0);
    }

    #[test]
    fn test_dot_product_basic() {
        let product = dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((product - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_zero() {
        let product = dot_product(&[1.0, 0.0], &[0.0, 1.0]);
        assert_eq!(product, 0.0);
    }

    // ---- insertion and configuration ----

    #[test]
    fn test_insert_increments_count() {
        let mut index = single_sub_index(DistanceMetric::L2);
        index.insert(entry(1, one_sub(&[1.0])));
        assert_eq!(index.item_count(), 1);
    }

    #[test]
    fn test_insert_multiple() {
        let mut index = single_sub_index(DistanceMetric::L2);
        index.insert(entry(1, one_sub(&[1.0])));
        index.insert(entry(2, one_sub(&[2.0])));
        assert_eq!(index.item_count(), 2);
    }

    #[test]
    fn test_sub_dimension_count() {
        let index = index_with(3, DistanceMetric::Cosine);
        assert_eq!(index.sub_dimension_count(), 3);
    }

    // ---- full search ----

    #[test]
    fn test_search_l2_nearest_neighbor() {
        let mut index = single_sub_index(DistanceMetric::L2);
        index.insert(entry(1, one_sub(&[0.0])));
        index.insert(entry(2, one_sub(&[10.0])));
        let query = entry(0, one_sub(&[0.5]));
        let found = index.search(&query, 1);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, 1);
    }

    #[test]
    fn test_search_l2_same_vector_best_score() {
        let mut index = single_sub_index(DistanceMetric::L2);
        index.insert(entry(1, one_sub(&[1.0, 2.0, 3.0])));
        index.insert(entry(2, one_sub(&[10.0, 10.0, 10.0])));
        let query = entry(0, one_sub(&[1.0, 2.0, 3.0]));
        let found = index.search(&query, 2);
        assert_eq!(found[0].id, 1);
    }

    #[test]
    fn test_search_l2_k_limit() {
        let mut index = single_sub_index(DistanceMetric::L2);
        for id in 0..10usize {
            index.insert(entry(id, one_sub(&[id as f32])));
        }
        let query = entry(99, one_sub(&[0.0]));
        let found = index.search(&query, 3);
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_search_cosine_identical_is_top() {
        let mut index = single_sub_index(DistanceMetric::Cosine);
        index.insert(entry(1, one_sub(&[1.0, 0.0])));
        index.insert(entry(2, one_sub(&[0.0, 1.0])));
        let query = entry(0, one_sub(&[1.0, 0.0]));
        let found = index.search(&query, 2);
        assert_eq!(found[0].id, 1);
    }

    #[test]
    fn test_search_dot_product() {
        let mut index = single_sub_index(DistanceMetric::DotProduct);
        index.insert(entry(1, one_sub(&[1.0, 2.0])));
        index.insert(entry(2, one_sub(&[3.0, 4.0])));
        let query = entry(0, one_sub(&[1.0, 1.0]));
        let found = index.search(&query, 2);
        assert_eq!(found[0].id, 2);
    }

    #[test]
    fn test_search_multi_vector_combination() {
        let mut index = index_with(2, DistanceMetric::Cosine);
        index.insert(entry(1, two_subs(&[1.0, 0.0], &[0.0, 1.0])));
        index.insert(entry(2, two_subs(&[1.0, 0.0], &[1.0, 0.0])));
        let query = entry(0, two_subs(&[1.0, 0.0], &[1.0, 0.0]));
        let found = index.search(&query, 2);
        assert_eq!(found[0].id, 2);
    }

    #[test]
    fn test_search_candidate_scores_count_equals_sub_vectors() {
        let mut index = index_with(3, DistanceMetric::Cosine);
        index.insert(entry(1, vec![vec![1.0], vec![1.0], vec![1.0]]));
        let query = entry(0, vec![vec![1.0], vec![1.0], vec![1.0]]);
        let found = index.search(&query, 1);
        assert_eq!(found[0].scores.len(), 3);
    }

    // ---- single sub-space search ----

    #[test]
    fn test_search_sub_single_dimension() {
        let mut index = index_with(2, DistanceMetric::L2);
        index.insert(entry(1, two_subs(&[0.0], &[10.0])));
        index.insert(entry(2, two_subs(&[5.0], &[10.0])));
        let found = index.search_sub(&[0.0], 0, 1);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].id, 1);
    }

    #[test]
    fn test_search_sub_k_limit() {
        let mut index = single_sub_index(DistanceMetric::Cosine);
        for id in 0..5usize {
            index.insert(entry(id, one_sub(&[id as f32 + 1.0])));
        }
        let found = index.search_sub(&[1.0], 0, 2);
        assert_eq!(found.len(), 2);
    }

    // ---- removal ----

    #[test]
    fn test_remove_existing_item() {
        let mut index = single_sub_index(DistanceMetric::L2);
        index.insert(entry(42, one_sub(&[1.0])));
        let removed = index.remove(42);
        assert!(removed);
        assert_eq!(index.item_count(), 0);
    }

    #[test]
    fn test_remove_nonexistent_item() {
        let mut index = single_sub_index(DistanceMetric::L2);
        let removed = index.remove(99);
        assert!(!removed);
    }

    #[test]
    fn test_remove_does_not_affect_other_items() {
        let mut index = single_sub_index(DistanceMetric::L2);
        index.insert(entry(1, one_sub(&[1.0])));
        index.insert(entry(2, one_sub(&[2.0])));
        index.remove(1);
        assert_eq!(index.item_count(), 1);
        let query = entry(0, one_sub(&[2.0]));
        let found = index.search(&query, 1);
        assert_eq!(found[0].id, 2);
    }

    // ---- empty index ----

    #[test]
    fn test_search_empty_index() {
        let index = single_sub_index(DistanceMetric::L2);
        let query = entry(0, one_sub(&[1.0]));
        let found = index.search(&query, 5);
        assert!(found.is_empty());
    }

    #[test]
    fn test_search_sub_empty_index() {
        let index = single_sub_index(DistanceMetric::L2);
        let found = index.search_sub(&[1.0], 0, 5);
        assert!(found.is_empty());
    }

    // ---- score aggregation ----

    #[test]
    fn test_combined_score_is_mean_of_scores() {
        let mut index = index_with(2, DistanceMetric::Cosine);
        index.insert(entry(1, two_subs(&[1.0, 0.0], &[1.0, 0.0])));
        let query = entry(0, two_subs(&[1.0, 0.0], &[1.0, 0.0]));
        let found = index.search(&query, 1);
        let best = &found[0];
        let total: f32 = best.scores.iter().sum();
        let mean = total / best.scores.len() as f32;
        assert!((best.combined_score - mean).abs() < 1e-5);
    }
}