use serde::{Deserialize, Serialize};
use crate::Vector;
/// Identifies how two multi-vectors are compared.
///
/// Only one variant exists, and it is the `Default`.
// NOTE(review): the code that reads this value is not in this file; presumably
// it dispatches to the `max_sim*` scoring methods on `MultiVector` — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MultiVectorComparator {
/// The default (and currently only) comparator.
#[default]
MaxSim,
}
/// Serializable configuration for multi-vector comparison.
///
/// `Default` yields a config whose comparator is
/// `MultiVectorComparator::MaxSim`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MultiVectorConfig {
/// Comparator to use.
pub comparator: MultiVectorComparator,
}
/// A group of `f32` sub-vectors that all share one dimension.
///
/// The fields are private; `MultiVector::new` and `MultiVector::from_single`
/// only build values that hold at least one sub-vector, where every
/// sub-vector has exactly `dim` (non-zero) elements.
// NOTE(review): the derived `Deserialize` fills the fields directly and does
// not run the checks in `new`, so a deserialized value may be empty or carry
// a `dim` that disagrees with its sub-vectors — confirm whether untrusted
// input can reach this path.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MultiVector {
// The stored sub-vectors, in insertion order.
vectors: Vec<Vec<f32>>,
// Length shared by every sub-vector (as checked by the constructors).
dim: usize,
}
impl MultiVector {
    /// Creates a multi-vector from a list of sub-vectors.
    ///
    /// The dimension is taken from the first sub-vector.
    ///
    /// # Errors
    ///
    /// Returns a static message when `vectors` is empty, when the first
    /// sub-vector has no elements, or when the sub-vectors do not all have
    /// the same length (checked in that order).
    pub fn new(vectors: Vec<Vec<f32>>) -> Result<Self, &'static str> {
        let dim = match vectors.first() {
            Some(first) => first.len(),
            None => return Err("MultiVector cannot be empty"),
        };
        if dim == 0 {
            return Err("Sub-vectors cannot be empty");
        }
        if vectors.iter().any(|v| v.len() != dim) {
            return Err("All sub-vectors must have the same dimension");
        }
        Ok(Self { vectors, dim })
    }

    /// Wraps a single vector as a multi-vector holding exactly one sub-vector.
    ///
    /// # Errors
    ///
    /// Returns a static message when `vector` has no elements.
    pub fn from_single(vector: Vec<f32>) -> Result<Self, &'static str> {
        if vector.is_empty() {
            return Err("Vector cannot be empty");
        }
        let dim = vector.len();
        Ok(Self {
            vectors: vec![vector],
            dim,
        })
    }

    /// Returns the dimension recorded for the sub-vectors.
    #[inline]
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the number of stored sub-vectors.
    #[inline]
    pub fn len(&self) -> usize {
        self.vectors.len()
    }

    /// Returns `true` when no sub-vectors are stored.
    ///
    /// Values built through [`MultiVector::new`] or
    /// [`MultiVector::from_single`] always hold at least one sub-vector.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.vectors.is_empty()
    }

    /// Returns all sub-vectors as a slice.
    #[inline]
    pub fn vectors(&self) -> &[Vec<f32>] {
        &self.vectors
    }

    /// Returns the first sub-vector, or `None` when nothing is stored.
    #[inline]
    pub fn first(&self) -> Option<&Vec<f32>> {
        self.vectors.first()
    }

    /// Builds a `Vector` from a clone of the first sub-vector; all other
    /// sub-vectors are ignored.
    ///
    /// # Panics
    ///
    /// Panics if no sub-vectors are stored. [`MultiVector::new`] and
    /// [`MultiVector::from_single`] never produce such a value.
    // NOTE(review): values produced by the derived `Deserialize` impl are not
    // checked by those constructors — confirm an empty value cannot reach here.
    pub fn to_single_vector(&self) -> Vector {
        Vector::new(self.vectors[0].clone())
    }

    /// Scores `other` against `self` using dot products.
    ///
    /// For every sub-vector of `self`, the largest dot product against any
    /// sub-vector of `other` is taken; the result is the sum of those maxima.
    ///
    /// Returns `0.0` when the two dimensions differ. A similarity that is NaN
    /// never wins the `>` comparison, so it is ignored.
    pub fn max_sim(&self, other: &MultiVector) -> f32 {
        if self.dim != other.dim {
            return 0.0;
        }
        let mut total_score = 0.0;
        for query_vec in &self.vectors {
            let mut best = f32::NEG_INFINITY;
            for doc_vec in &other.vectors {
                let sim = dot_product(query_vec, doc_vec);
                if sim > best {
                    best = sim;
                }
            }
            // `best` stays at negative infinity when nothing was comparable;
            // such a query sub-vector contributes nothing.
            if best > f32::NEG_INFINITY {
                total_score += best;
            }
        }
        total_score
    }

    /// Scores `other` against `self` using cosine similarity.
    ///
    /// For every sub-vector of `self`, the largest cosine similarity against
    /// any sub-vector of `other` is taken; the result is the sum of those
    /// maxima. Sub-vectors whose norm is below `f32::EPSILON` are skipped on
    /// either side.
    ///
    /// Returns `0.0` when the two dimensions differ.
    pub fn max_sim_cosine(&self, other: &MultiVector) -> f32 {
        if self.dim != other.dim {
            return 0.0;
        }
        // A document norm does not depend on the query sub-vector, so compute
        // each one once instead of once per (query, document) pair.
        let doc_norms: Vec<f32> = other.vectors.iter().map(|doc_vec| norm(doc_vec)).collect();

        let mut total_score = 0.0;
        for query_vec in &self.vectors {
            let query_norm = norm(query_vec);
            if query_norm < f32::EPSILON {
                continue;
            }
            let mut best = f32::NEG_INFINITY;
            for (doc_vec, &doc_norm) in other.vectors.iter().zip(&doc_norms) {
                if doc_norm < f32::EPSILON {
                    continue;
                }
                let sim = dot_product(query_vec, doc_vec) / (query_norm * doc_norm);
                if sim > best {
                    best = sim;
                }
            }
            // Every document sub-vector may have been skipped; then this
            // query sub-vector contributes nothing.
            if best > f32::NEG_INFINITY {
                total_score += best;
            }
        }
        total_score
    }

    /// Scores `other` against `self` using Euclidean distance.
    ///
    /// For every sub-vector of `self`, the smallest distance to any
    /// sub-vector of `other` is taken; the result is the negated sum of those
    /// minima, so larger (closer to zero) means closer.
    ///
    /// Returns `f32::NEG_INFINITY` when the two dimensions differ (unlike the
    /// dot-product and cosine variants, which return `0.0`).
    pub fn max_sim_l2(&self, other: &MultiVector) -> f32 {
        if self.dim != other.dim {
            return f32::NEG_INFINITY;
        }
        let mut total_score = 0.0;
        for query_vec in &self.vectors {
            let mut closest = f32::INFINITY;
            for doc_vec in &other.vectors {
                let dist = l2_distance(query_vec, doc_vec);
                if dist < closest {
                    closest = dist;
                }
            }
            // `closest` stays at infinity when nothing was comparable.
            if closest < f32::INFINITY {
                total_score -= closest;
            }
        }
        total_score
    }
}
/// Dot product of two slices.
///
/// Elements are paired by position; if the lengths differ, the trailing
/// elements of the longer slice are ignored.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
/// Euclidean (L2) norm of a slice: the square root of the sum of squares.
#[inline]
fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
/// Euclidean distance between two slices.
///
/// Elements are paired by position; if the lengths differ, the trailing
/// elements of the longer slice are ignored.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point scores.
    const TOLERANCE: f32 = 1e-6;

    #[test]
    fn test_multivector_creation() {
        let mv = MultiVector::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]).unwrap();
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.len(), 2);
    }

    #[test]
    fn test_max_sim_identical() {
        let mv1 = MultiVector::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
        let mv2 = MultiVector::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
        let score = mv1.max_sim(&mv2);
        assert!((score - 2.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_max_sim_different() {
        let query = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![0.5, 0.5], vec![1.0, 0.0]]).unwrap();
        let score = query.max_sim(&doc);
        assert!((score - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_max_sim_cosine() {
        // Cosine similarity ignores magnitude: parallel vectors score 1.
        let query = MultiVector::new(vec![vec![2.0, 0.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();
        let score = query.max_sim_cosine(&doc);
        assert!((score - 1.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_new_rejects_invalid_input() {
        assert_eq!(
            MultiVector::new(Vec::new()).unwrap_err(),
            "MultiVector cannot be empty"
        );
        assert_eq!(
            MultiVector::new(vec![Vec::new()]).unwrap_err(),
            "Sub-vectors cannot be empty"
        );
        assert_eq!(
            MultiVector::new(vec![vec![1.0], vec![1.0, 2.0]]).unwrap_err(),
            "All sub-vectors must have the same dimension"
        );
    }

    #[test]
    fn test_from_single() {
        let mv = MultiVector::from_single(vec![1.0, 2.0]).unwrap();
        assert_eq!(mv.len(), 1);
        assert_eq!(mv.dim(), 2);
        assert!(!mv.is_empty());
        assert_eq!(mv.first(), Some(&vec![1.0, 2.0]));
        assert_eq!(mv.vectors(), &[vec![1.0, 2.0]][..]);

        assert_eq!(
            MultiVector::from_single(Vec::new()).unwrap_err(),
            "Vector cannot be empty"
        );
    }

    #[test]
    fn test_dimension_mismatch_scores() {
        let two_d = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();
        let three_d = MultiVector::new(vec![vec![1.0, 0.0, 0.0]]).unwrap();
        assert_eq!(two_d.max_sim(&three_d), 0.0);
        assert_eq!(two_d.max_sim_cosine(&three_d), 0.0);
        assert_eq!(two_d.max_sim_l2(&three_d), f32::NEG_INFINITY);
    }

    #[test]
    fn test_max_sim_l2() {
        let mv = MultiVector::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
        // Every sub-vector has an exact match, so all minimum distances are 0.
        assert!(mv.max_sim_l2(&mv.clone()).abs() < TOLERANCE);

        let query = MultiVector::new(vec![vec![0.0, 0.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![6.0, 8.0], vec![3.0, 4.0]]).unwrap();
        // Nearest document sub-vector is at distance 5; the score is negated.
        assert!((query.max_sim_l2(&doc) + 5.0).abs() < TOLERANCE);
    }

    #[test]
    fn test_max_sim_cosine_skips_zero_vectors() {
        let zero = MultiVector::new(vec![vec![0.0, 0.0]]).unwrap();
        let unit = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();
        // A zero-norm sub-vector on either side contributes nothing.
        assert_eq!(zero.max_sim_cosine(&unit), 0.0);
        assert_eq!(unit.max_sim_cosine(&zero), 0.0);

        // A zero document sub-vector next to a real one does not affect the best match.
        let mixed = MultiVector::new(vec![vec![0.0, 0.0], vec![3.0, 0.0]]).unwrap();
        assert!((unit.max_sim_cosine(&mixed) - 1.0).abs() < TOLERANCE);
    }
}