Skip to main content

nodedb_types/
multi_vector.rs

//! Multi-vector type for ColBERT-style late-interaction retrieval.
//!
//! Stores N vectors per document (one per token/passage), all of the same dimension.
//! Contiguous `Vec<f32>` layout for cache efficiency: `[v0_d0..v0_dD, v1_d0..v1_dD, ...]`.
5
6use serde::{Deserialize, Serialize};
7
8/// A multi-vector: an array of dense vectors, all sharing the same dimension.
9///
10/// Used for ColBERT/ColPali-style late interaction where each document
11/// produces one embedding per token. All vectors share a doc_id and are
12/// inserted as separate HNSW nodes with shared doc_id metadata.
13#[derive(
14    Debug,
15    Clone,
16    PartialEq,
17    Serialize,
18    Deserialize,
19    zerompk::ToMessagePack,
20    zerompk::FromMessagePack,
21)]
22pub struct MultiVector {
23    /// Contiguous f32 data: `count × dim` elements.
24    data: Vec<f32>,
25    /// Number of vectors.
26    count: usize,
27    /// Dimensionality of each vector.
28    dim: usize,
29}
30
31impl MultiVector {
32    /// Create from a list of vectors. All must have the same dimension.
33    pub fn from_vectors(vectors: Vec<Vec<f32>>) -> Result<Self, MultiVectorError> {
34        if vectors.is_empty() {
35            return Err(MultiVectorError::Empty);
36        }
37        let dim = vectors[0].len();
38        if dim == 0 {
39            return Err(MultiVectorError::ZeroDimension);
40        }
41        let count = vectors.len();
42        let mut data = Vec::with_capacity(count * dim);
43        for (i, v) in vectors.iter().enumerate() {
44            if v.len() != dim {
45                return Err(MultiVectorError::DimensionMismatch {
46                    expected: dim,
47                    got: v.len(),
48                    index: i,
49                });
50            }
51            for &val in v {
52                if !val.is_finite() {
53                    return Err(MultiVectorError::NonFiniteValue);
54                }
55            }
56            data.extend_from_slice(v);
57        }
58        Ok(Self { data, count, dim })
59    }
60
61    /// Create from contiguous f32 data with known count and dim.
62    pub fn from_raw(data: Vec<f32>, count: usize, dim: usize) -> Result<Self, MultiVectorError> {
63        if count == 0 || dim == 0 {
64            return Err(MultiVectorError::Empty);
65        }
66        if data.len() != count * dim {
67            return Err(MultiVectorError::DataLengthMismatch {
68                expected: count * dim,
69                got: data.len(),
70            });
71        }
72        Ok(Self { data, count, dim })
73    }
74
75    /// Number of vectors.
76    pub fn count(&self) -> usize {
77        self.count
78    }
79
80    /// Dimensionality of each vector.
81    pub fn dim(&self) -> usize {
82        self.dim
83    }
84
85    /// Access the i-th vector as a slice.
86    pub fn get(&self, i: usize) -> Option<&[f32]> {
87        if i >= self.count {
88            return None;
89        }
90        let start = i * self.dim;
91        Some(&self.data[start..start + self.dim])
92    }
93
94    /// Iterate over all vectors as slices.
95    pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
96        (0..self.count).map(move |i| {
97            let start = i * self.dim;
98            &self.data[start..start + self.dim]
99        })
100    }
101
102    /// Extract all vectors as owned Vec<Vec<f32>>.
103    pub fn to_vectors(&self) -> Vec<Vec<f32>> {
104        self.iter().map(|s| s.to_vec()).collect()
105    }
106
107    /// Access the raw contiguous data.
108    pub fn raw_data(&self) -> &[f32] {
109        &self.data
110    }
111}
112
113/// Aggregation mode for multi-vector scoring.
114#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
115pub enum MultiVectorScoreMode {
116    /// MaxSim: max similarity across all document vectors (ColBERT scoring).
117    MaxSim,
118    /// Average similarity across all document vectors.
119    AvgSim,
120    /// Sum of similarities across all document vectors.
121    SumSim,
122}
123
124impl MultiVectorScoreMode {
125    /// Parse from string.
126    pub fn parse(s: &str) -> Option<Self> {
127        match s.to_lowercase().as_str() {
128            "max_sim" | "maxsim" => Some(Self::MaxSim),
129            "avg_sim" | "avgsim" => Some(Self::AvgSim),
130            "sum_sim" | "sumsim" => Some(Self::SumSim),
131            _ => None,
132        }
133    }
134
135    /// Aggregate a set of similarity scores for one document.
136    pub fn aggregate(&self, scores: &[f32]) -> f32 {
137        if scores.is_empty() {
138            return 0.0;
139        }
140        match self {
141            Self::MaxSim => scores.iter().cloned().reduce(f32::max).unwrap_or(0.0),
142            Self::AvgSim => scores.iter().sum::<f32>() / scores.len() as f32,
143            Self::SumSim => scores.iter().sum(),
144        }
145    }
146}
147
/// Errors from multi-vector construction.
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare errors
/// directly instead of pattern-matching every time.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MultiVectorError {
    /// The multi-vector would contain no vectors.
    Empty,
    /// The vector dimension is zero.
    ZeroDimension,
    /// A component is NaN or infinite.
    NonFiniteValue,
    /// A vector's length differs from the expected dimension.
    DimensionMismatch {
        /// Dimension every vector is required to have.
        expected: usize,
        /// Actual length of the offending vector.
        got: usize,
        /// Position of the offending vector in the input.
        index: usize,
    },
    /// The flat data buffer does not hold the expected number of values.
    DataLengthMismatch {
        /// Required number of values.
        expected: usize,
        /// Number of values actually supplied.
        got: usize,
    },
}
164
165impl std::fmt::Display for MultiVectorError {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        match self {
168            Self::Empty => write!(f, "multi-vector must contain at least one vector"),
169            Self::ZeroDimension => write!(f, "vector dimension must be > 0"),
170            Self::NonFiniteValue => write!(f, "vector values must be finite"),
171            Self::DimensionMismatch {
172                expected,
173                got,
174                index,
175            } => write!(
176                f,
177                "dimension mismatch at vector {index}: expected {expected}, got {got}"
178            ),
179            Self::DataLengthMismatch { expected, got } => {
180                write!(f, "data length mismatch: expected {expected}, got {got}")
181            }
182        }
183    }
184}
185
186impl std::error::Error for MultiVectorError {}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: vectors are stored in order and addressable by index.
    #[test]
    fn from_vectors_basic() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]).unwrap();
        assert_eq!(mv.count(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.get(0).unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.get(1).unwrap(), &[4.0, 5.0, 6.0]);
        assert!(mv.get(2).is_none());
    }

    /// The error pinpoints the offending vector and both lengths.
    #[test]
    fn dimension_mismatch_rejected() {
        let err = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0]]).unwrap_err();
        assert!(matches!(
            err,
            MultiVectorError::DimensionMismatch {
                expected: 2,
                got: 1,
                index: 1
            }
        ));
    }

    /// NaN and infinities are both refused.
    #[test]
    fn non_finite_rejected() {
        for bad in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let err = MultiVector::from_vectors(vec![vec![bad]]).unwrap_err();
            assert!(matches!(err, MultiVectorError::NonFiniteValue));
        }
    }

    #[test]
    fn empty_rejected() {
        let err = MultiVector::from_vectors(vec![]).unwrap_err();
        assert!(matches!(err, MultiVectorError::Empty));
    }

    /// A first vector with no components cannot define a dimension.
    #[test]
    fn zero_dimension_rejected() {
        let err = MultiVector::from_vectors(vec![vec![]]).unwrap_err();
        assert!(matches!(err, MultiVectorError::ZeroDimension));
    }

    #[test]
    fn iter_all_vectors() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        let collected: Vec<&[f32]> = mv.iter().collect();
        assert_eq!(collected.len(), 3);
        assert_eq!(collected[2], &[5.0, 6.0]);
    }

    /// `to_vectors` and `raw_data` expose the same values in two layouts.
    #[test]
    fn to_vectors_and_raw_data_agree() {
        let input = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let mv = MultiVector::from_vectors(input.clone()).unwrap();
        assert_eq!(mv.to_vectors(), input);
        assert_eq!(mv.raw_data(), &[1.0, 2.0, 3.0, 4.0]);
    }

    /// `from_raw` produces the same value as `from_vectors` for the same data.
    #[test]
    fn from_raw_matches_from_vectors() {
        let raw = MultiVector::from_raw(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
        let built = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        assert_eq!(raw, built);
    }

    #[test]
    fn from_raw_rejects_bad_shapes() {
        assert!(MultiVector::from_raw(vec![], 0, 0).is_err());
        let err = MultiVector::from_raw(vec![1.0, 2.0, 3.0], 2, 2).unwrap_err();
        assert!(matches!(
            err,
            MultiVectorError::DataLengthMismatch {
                expected: 4,
                got: 3
            }
        ));
    }

    /// Round-trips through the zerompk MessagePack encoding.
    #[test]
    fn msgpack_roundtrip() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let bytes = zerompk::to_msgpack_vec(&mv).unwrap();
        let restored: MultiVector = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(mv, restored);
    }

    #[test]
    fn score_modes() {
        let scores = vec![0.5, 0.8, 0.3];
        assert!((MultiVectorScoreMode::MaxSim.aggregate(&scores) - 0.8).abs() < 1e-6);
        assert!((MultiVectorScoreMode::SumSim.aggregate(&scores) - 1.6).abs() < 1e-6);
        let avg = MultiVectorScoreMode::AvgSim.aggregate(&scores);
        assert!((avg - 0.5333).abs() < 0.01);
    }

    /// Every mode falls back to 0.0 when there is nothing to aggregate.
    #[test]
    fn score_modes_empty_input() {
        for mode in [
            MultiVectorScoreMode::MaxSim,
            MultiVectorScoreMode::AvgSim,
            MultiVectorScoreMode::SumSim,
        ] {
            assert_eq!(mode.aggregate(&[]), 0.0);
        }
    }

    #[test]
    fn score_mode_parse() {
        assert_eq!(
            MultiVectorScoreMode::parse("max_sim"),
            Some(MultiVectorScoreMode::MaxSim)
        );
        assert_eq!(
            MultiVectorScoreMode::parse("AvgSim"),
            Some(MultiVectorScoreMode::AvgSim)
        );
        assert_eq!(
            MultiVectorScoreMode::parse("SUM_SIM"),
            Some(MultiVectorScoreMode::SumSim)
        );
        assert_eq!(MultiVectorScoreMode::parse("median"), None);
    }
}