use crate::error::{Result, RuvectorError};
use crate::types::{DistanceMetric, SearchResult, VectorId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A single document stored in a [`MultiVectorIndex`]: one embedding per token
/// plus per-token norms and optional metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorEntry {
/// Identifier of the document; also used as the key inside the index.
pub doc_id: VectorId,
/// Token-level embeddings. `MultiVectorIndex::insert` only accepts a
/// non-empty list whose vectors all share the same non-zero dimension.
pub token_embeddings: Vec<Vec<f32>>,
/// Euclidean (L2) norm of each vector in `token_embeddings`, in the same
/// order. Filled in by `MultiVectorIndex::insert` and used by cosine scoring.
pub norms: Vec<f32>,
/// Optional key/value metadata; copied into search results for this document.
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// Strategy used to combine token-to-token similarities into one document score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScoringVariant {
/// For every query token take the best similarity over all document tokens,
/// then sum those maxima over the query tokens.
MaxSim,
/// Mean similarity over all (query token, document token) pairs.
AvgSim,
/// For every document token take the best similarity over all query tokens,
/// then sum those maxima over the document tokens.
SumMax,
}
/// Configuration of a [`MultiVectorIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorConfig {
/// Metric used to turn a pair of token embeddings into a similarity value.
pub metric: DistanceMetric,
/// How token-level similarities are aggregated into a document score.
pub scoring: ScoringVariant,
}
impl Default for MultiVectorConfig {
fn default() -> Self {
Self {
metric: DistanceMetric::Cosine,
scoring: ScoringVariant::MaxSim,
}
}
}
/// In-memory index of documents that are each represented by several token
/// embeddings. Searching scores every stored document against the query
/// (see the `search` method) using the metric and scoring from `config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorIndex {
/// Metric and scoring settings used by `search`.
pub config: MultiVectorConfig,
// Stored documents keyed by their id; inserting an existing id replaces the entry.
entries: HashMap<VectorId, MultiVectorEntry>,
}
impl MultiVectorIndex {
    /// Creates an empty index that scores documents according to `config`.
    pub fn new(config: MultiVectorConfig) -> Self {
        Self {
            config,
            entries: HashMap::new(),
        }
    }

    /// Stores a document made of one embedding per token.
    ///
    /// The norm of every token embedding is computed once here and cached on
    /// the entry. An existing entry with the same `doc_id` is replaced.
    ///
    /// # Errors
    ///
    /// * `RuvectorError::InvalidParameter` if `embeddings` is empty or the
    ///   embeddings have zero dimensions.
    /// * `RuvectorError::DimensionMismatch` if an embedding's length differs
    ///   from the length of the first embedding.
    pub fn insert(
        &mut self,
        doc_id: VectorId,
        embeddings: Vec<Vec<f32>>,
        metadata: Option<HashMap<String, serde_json::Value>>,
    ) -> Result<()> {
        let dim = match embeddings.first() {
            Some(first) => first.len(),
            None => {
                return Err(RuvectorError::InvalidParameter(
                    "Token embeddings cannot be empty".into(),
                ));
            }
        };
        // The first embedding defines the dimension, so a zero dimension is
        // always reported against index 0.
        if dim == 0 {
            return Err(RuvectorError::InvalidParameter(
                "Embedding at index 0 has zero dimensions".into(),
            ));
        }
        if let Some(mismatched) = embeddings.iter().find(|emb| emb.len() != dim) {
            return Err(RuvectorError::DimensionMismatch {
                expected: dim,
                actual: mismatched.len(),
            });
        }

        let norms: Vec<f32> = embeddings.iter().map(|emb| compute_norm(emb)).collect();
        self.entries.insert(
            doc_id.clone(),
            MultiVectorEntry {
                doc_id,
                token_embeddings: embeddings,
                norms,
                metadata,
            },
        );
        Ok(())
    }

    /// Removes the document with the given id and returns it, or `None` if no
    /// such document is stored.
    pub fn remove(&mut self, doc_id: &str) -> Option<MultiVectorEntry> {
        self.entries.remove(doc_id)
    }

    /// Number of documents currently stored.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the index holds no documents.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Scores every stored document against the query tokens using the
    /// configured scoring variant and returns at most `top_k` results, best
    /// score first.
    ///
    /// Results carry a clone of the document metadata and `vector: None`.
    /// Documents with equal scores are ordered by id, and documents whose
    /// score is NaN are placed last.
    ///
    /// Query dimensions are not validated: when a query token and a document
    /// token differ in length, only the common prefix is compared.
    ///
    /// # Errors
    ///
    /// `RuvectorError::InvalidParameter` if `query_embeddings` is empty.
    pub fn search(
        &self,
        query_embeddings: &[Vec<f32>],
        top_k: usize,
    ) -> Result<Vec<SearchResult>> {
        self.search_scored(query_embeddings, top_k, self.config.scoring)
    }

    /// Same as [`search`](Self::search) but uses `scoring` instead of the
    /// configured scoring variant. The index and its configuration are left
    /// untouched.
    ///
    /// # Errors
    ///
    /// `RuvectorError::InvalidParameter` if `query_embeddings` is empty.
    pub fn search_with_scoring(
        &self,
        query_embeddings: &[Vec<f32>],
        top_k: usize,
        scoring: ScoringVariant,
    ) -> Result<Vec<SearchResult>> {
        // Scoring is passed down explicitly so no copy of the index is needed.
        self.search_scored(query_embeddings, top_k, scoring)
    }

    /// Shared implementation of `search` and `search_with_scoring`.
    fn search_scored(
        &self,
        query_embeddings: &[Vec<f32>],
        top_k: usize,
        scoring: ScoringVariant,
    ) -> Result<Vec<SearchResult>> {
        if query_embeddings.is_empty() {
            return Err(RuvectorError::InvalidParameter(
                "Query embeddings cannot be empty".into(),
            ));
        }
        let query_norms: Vec<f32> = query_embeddings.iter().map(|q| compute_norm(q)).collect();

        // Rank borrowed entries; ids and metadata are cloned only for the
        // results that survive truncation.
        let mut ranked: Vec<(&MultiVectorEntry, f32)> = self
            .entries
            .values()
            .map(|entry| {
                let score = self.compute_score(
                    scoring,
                    query_embeddings,
                    &query_norms,
                    &entry.token_embeddings,
                    &entry.norms,
                );
                (entry, score)
            })
            .collect();
        ranked.sort_by(|a, b| {
            Self::cmp_desc_nan_last(a.1, b.1).then_with(|| a.0.doc_id.cmp(&b.0.doc_id))
        });
        ranked.truncate(top_k);

        Ok(ranked
            .into_iter()
            .map(|(entry, score)| SearchResult {
                id: entry.doc_id.clone(),
                score,
                vector: None,
                metadata: entry.metadata.clone(),
            })
            .collect())
    }

    /// Orders two scores so that higher scores come first and NaN comes last.
    ///
    /// `partial_cmp` alone is not a total order once NaN is involved, which
    /// `sort_by` requires; the fallback on `is_nan` restores one.
    fn cmp_desc_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        b.partial_cmp(&a)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    /// Dispatches to the aggregation selected by `scoring`.
    fn compute_score(
        &self,
        scoring: ScoringVariant,
        query_embeddings: &[Vec<f32>],
        query_norms: &[f32],
        doc_embeddings: &[Vec<f32>],
        doc_norms: &[f32],
    ) -> f32 {
        match scoring {
            ScoringVariant::MaxSim => {
                self.maxsim(query_embeddings, query_norms, doc_embeddings, doc_norms)
            }
            ScoringVariant::AvgSim => {
                self.avgsim(query_embeddings, query_norms, doc_embeddings, doc_norms)
            }
            ScoringVariant::SumMax => {
                self.summax(query_embeddings, query_norms, doc_embeddings, doc_norms)
            }
        }
    }

    /// Sum over query tokens of the best similarity to any document token.
    ///
    /// Each norm slice is expected to hold one norm per embedding, in order.
    fn maxsim(
        &self,
        query_embeddings: &[Vec<f32>],
        query_norms: &[f32],
        doc_embeddings: &[Vec<f32>],
        doc_norms: &[f32],
    ) -> f32 {
        query_embeddings
            .iter()
            .zip(query_norms)
            .map(|(q, &q_norm)| {
                doc_embeddings
                    .iter()
                    .zip(doc_norms)
                    .map(|(d, &d_norm)| self.token_similarity(q, q_norm, d, d_norm))
                    .fold(f32::NEG_INFINITY, f32::max)
            })
            .sum()
    }

    /// Mean similarity over all (query token, document token) pairs; `0.0`
    /// when either side has no tokens.
    fn avgsim(
        &self,
        query_embeddings: &[Vec<f32>],
        query_norms: &[f32],
        doc_embeddings: &[Vec<f32>],
        doc_norms: &[f32],
    ) -> f32 {
        let total_pairs = (query_embeddings.len() * doc_embeddings.len()) as f32;
        if total_pairs == 0.0 {
            return 0.0;
        }
        let sum: f32 = query_embeddings
            .iter()
            .zip(query_norms)
            .flat_map(|(q, &q_norm)| {
                doc_embeddings
                    .iter()
                    .zip(doc_norms)
                    .map(move |(d, &d_norm)| self.token_similarity(q, q_norm, d, d_norm))
            })
            .sum();
        sum / total_pairs
    }

    /// Sum over document tokens of the best similarity to any query token.
    fn summax(
        &self,
        query_embeddings: &[Vec<f32>],
        query_norms: &[f32],
        doc_embeddings: &[Vec<f32>],
        doc_norms: &[f32],
    ) -> f32 {
        doc_embeddings
            .iter()
            .zip(doc_norms)
            .map(|(d, &d_norm)| {
                query_embeddings
                    .iter()
                    .zip(query_norms)
                    .map(|(q, &q_norm)| self.token_similarity(q, q_norm, d, d_norm))
                    .fold(f32::NEG_INFINITY, f32::max)
            })
            .sum()
    }

    /// Dot product over the common prefix of `a` and `b`.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
    }

    /// Similarity between two token embeddings under the configured metric.
    ///
    /// * Cosine: `dot / (norm_a * norm_b)`, or `0.0` when the norm product is
    ///   below `f32::EPSILON`.
    /// * DotProduct: the raw dot product.
    /// * Euclidean / Manhattan: the distance `d` mapped to `1 / (1 + d)` so
    ///   that larger values mean "more similar".
    #[inline]
    fn token_similarity(&self, a: &[f32], norm_a: f32, b: &[f32], norm_b: f32) -> f32 {
        // The dot product is only evaluated by the metrics that need it.
        match self.config.metric {
            DistanceMetric::Cosine => {
                let denom = norm_a * norm_b;
                if denom < f32::EPSILON {
                    0.0
                } else {
                    Self::dot(a, b) / denom
                }
            }
            DistanceMetric::DotProduct => Self::dot(a, b),
            DistanceMetric::Euclidean => {
                let dist_sq: f32 = a
                    .iter()
                    .zip(b.iter())
                    .map(|(x, y)| (x - y).powi(2))
                    .sum();
                1.0 / (1.0 + dist_sq.sqrt())
            }
            DistanceMetric::Manhattan => {
                let dist: f32 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
                1.0 / (1.0 + dist)
            }
        }
    }
}
/// Returns the Euclidean (L2) norm of `v`: the square root of the sum of its
/// squared components.
#[inline]
fn compute_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
#[cfg(test)]
mod tests {
    //! Unit tests for `MultiVectorIndex`: insertion validation, every scoring
    //! variant, every distance metric, truncation, removal and metadata.
    use super::*;

    /// Index using the default configuration (cosine metric, MaxSim scoring).
    fn default_index() -> MultiVectorIndex {
        MultiVectorIndex::new(MultiVectorConfig::default())
    }

    /// Index using an explicit metric and scoring variant.
    fn index_with(metric: DistanceMetric, scoring: ScoringVariant) -> MultiVectorIndex {
        MultiVectorIndex::new(MultiVectorConfig { metric, scoring })
    }

    /// Asserts that two scores agree up to a small absolute tolerance.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-5,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_insert_and_len() {
        let mut index = default_index();
        assert!(index.is_empty());
        index
            .insert("d1".into(), vec![vec![1.0, 0.0], vec![0.0, 1.0]], None)
            .unwrap();
        assert_eq!(index.len(), 1);
        index
            .insert("d2".into(), vec![vec![0.5, 0.5]], None)
            .unwrap();
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn test_insert_same_id_replaces_entry() {
        let mut index = default_index();
        index
            .insert("d1".into(), vec![vec![1.0, 0.0]], None)
            .unwrap();
        index
            .insert("d1".into(), vec![vec![0.0, 1.0]], None)
            .unwrap();
        assert_eq!(index.len(), 1);
        let results = index.search(&[vec![0.0, 1.0]], 10).unwrap();
        assert_close(results[0].score, 1.0);
    }

    #[test]
    fn test_insert_empty_embeddings_error() {
        let mut index = default_index();
        let res = index.insert("d1".into(), vec![], None);
        assert!(res.is_err());
    }

    #[test]
    fn test_insert_zero_dimension_error() {
        let mut index = default_index();
        let res = index.insert("d1".into(), vec![vec![]], None);
        assert!(res.is_err());
        assert!(index.is_empty());
    }

    #[test]
    fn test_insert_dimension_mismatch_error() {
        let mut index = default_index();
        let res = index.insert("d1".into(), vec![vec![1.0, 0.0], vec![1.0]], None);
        assert!(res.is_err());
    }

    #[test]
    fn test_maxsim_search_basic() {
        let mut index = default_index();
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0], vec![0.0, 1.0]], None)
            .unwrap();
        index
            .insert("doc2".into(), vec![vec![1.0, 0.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        assert_eq!(results.len(), 2);
        assert_close(results[0].score, 1.0);
    }

    #[test]
    fn test_maxsim_multi_query_tokens() {
        let mut index = default_index();
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0], vec![0.0, 1.0]], None)
            .unwrap();
        index
            .insert("doc2".into(), vec![vec![1.0, 0.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0], vec![0.0, 1.0]], 10).unwrap();
        assert_eq!(results[0].id, "doc1");
        assert_close(results[0].score, 2.0);
        assert_eq!(results[1].id, "doc2");
        assert_close(results[1].score, 1.0);
    }

    #[test]
    fn test_avgsim_scoring() {
        let mut index = index_with(DistanceMetric::Cosine, ScoringVariant::AvgSim);
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0], vec![0.0, 1.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        assert_close(results[0].score, 0.5);
    }

    #[test]
    fn test_summax_scoring() {
        let mut index = index_with(DistanceMetric::Cosine, ScoringVariant::SumMax);
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0], vec![0.0, 1.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        assert_close(results[0].score, 1.0);
    }

    #[test]
    fn test_dot_product_metric() {
        let mut index = index_with(DistanceMetric::DotProduct, ScoringVariant::MaxSim);
        index
            .insert("doc1".into(), vec![vec![2.0, 0.0], vec![0.0, 3.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        assert_close(results[0].score, 2.0);
    }

    #[test]
    fn test_euclidean_metric() {
        let mut index = index_with(DistanceMetric::Euclidean, ScoringVariant::MaxSim);
        index
            .insert("near".into(), vec![vec![1.0, 0.0]], None)
            .unwrap();
        index
            .insert("far".into(), vec![vec![4.0, 4.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        // Distance 0 maps to 1.0; distance 5 maps to 1 / (1 + 5).
        assert_eq!(results[0].id, "near");
        assert_close(results[0].score, 1.0);
        assert_eq!(results[1].id, "far");
        assert_close(results[1].score, 1.0 / 6.0);
    }

    #[test]
    fn test_manhattan_metric() {
        let mut index = index_with(DistanceMetric::Manhattan, ScoringVariant::MaxSim);
        index
            .insert("near".into(), vec![vec![1.0, 0.0]], None)
            .unwrap();
        index
            .insert("far".into(), vec![vec![2.0, 2.0]], None)
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        // Distance 0 maps to 1.0; distance 1 + 2 = 3 maps to 1 / (1 + 3).
        assert_eq!(results[0].id, "near");
        assert_close(results[0].score, 1.0);
        assert_eq!(results[1].id, "far");
        assert_close(results[1].score, 0.25);
    }

    #[test]
    fn test_search_empty_query_error() {
        let index = default_index();
        let res = index.search(&[], 10);
        assert!(res.is_err());
    }

    #[test]
    fn test_search_empty_index() {
        let index = default_index();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_top_k_truncation() {
        let mut index = default_index();
        for i in 0..10 {
            let val = (i as f32) / 10.0;
            index
                .insert(format!("d{}", i), vec![vec![val, 1.0 - val]], None)
                .unwrap();
        }
        let results = index.search(&[vec![1.0, 0.0]], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_remove_document() {
        let mut index = default_index();
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0]], None)
            .unwrap();
        assert_eq!(index.len(), 1);
        let removed = index.remove("doc1");
        assert!(removed.is_some());
        assert!(index.is_empty());
    }

    #[test]
    fn test_metadata_preserved() {
        let mut index = default_index();
        let mut meta = HashMap::new();
        meta.insert("source".into(), serde_json::json!("colbert"));
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0]], Some(meta))
            .unwrap();
        let results = index.search(&[vec![1.0, 0.0]], 10).unwrap();
        let result_meta = results[0].metadata.as_ref().unwrap();
        assert_eq!(result_meta.get("source").unwrap(), "colbert");
    }

    #[test]
    fn test_search_with_scoring_override() {
        let mut index = default_index();
        index
            .insert("doc1".into(), vec![vec![1.0, 0.0], vec![0.0, 1.0]], None)
            .unwrap();
        let results = index
            .search_with_scoring(&[vec![1.0, 0.0]], 10, ScoringVariant::AvgSim)
            .unwrap();
        assert_close(results[0].score, 0.5);
        // The override must not leak into the index configuration.
        assert_eq!(index.config.scoring, ScoringVariant::MaxSim);
    }
}