Skip to main content

nodedb_types/
multi_vector.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Multi-vector type for ColBERT-style late interaction retrieval.
4//!
5//! Stores N vectors per document (one per token/passage), all same dimension.
6//! Contiguous `Vec<f32>` layout for cache efficiency: `[v0_d0..v0_dD, v1_d0..v1_dD, ...]`.
7
8use serde::{Deserialize, Serialize};
9
10/// A multi-vector: an array of dense vectors, all sharing the same dimension.
11///
12/// Used for ColBERT/ColPali-style late interaction where each document
13/// produces one embedding per token. All vectors share a doc_id and are
14/// inserted as separate HNSW nodes with shared doc_id metadata.
15#[derive(
16    Debug,
17    Clone,
18    PartialEq,
19    Serialize,
20    Deserialize,
21    zerompk::ToMessagePack,
22    zerompk::FromMessagePack,
23)]
24pub struct MultiVector {
25    /// Contiguous f32 data: `count × dim` elements.
26    data: Vec<f32>,
27    /// Number of vectors.
28    count: usize,
29    /// Dimensionality of each vector.
30    dim: usize,
31}
32
33impl MultiVector {
34    /// Create from a list of vectors. All must have the same dimension.
35    pub fn from_vectors(vectors: Vec<Vec<f32>>) -> Result<Self, MultiVectorError> {
36        if vectors.is_empty() {
37            return Err(MultiVectorError::Empty);
38        }
39        let dim = vectors[0].len();
40        if dim == 0 {
41            return Err(MultiVectorError::ZeroDimension);
42        }
43        let count = vectors.len();
44        let mut data = Vec::with_capacity(count * dim);
45        for (i, v) in vectors.iter().enumerate() {
46            if v.len() != dim {
47                return Err(MultiVectorError::DimensionMismatch {
48                    expected: dim,
49                    got: v.len(),
50                    index: i,
51                });
52            }
53            for &val in v {
54                if !val.is_finite() {
55                    return Err(MultiVectorError::NonFiniteValue);
56                }
57            }
58            data.extend_from_slice(v);
59        }
60        Ok(Self { data, count, dim })
61    }
62
63    /// Create from contiguous f32 data with known count and dim.
64    pub fn from_raw(data: Vec<f32>, count: usize, dim: usize) -> Result<Self, MultiVectorError> {
65        if count == 0 || dim == 0 {
66            return Err(MultiVectorError::Empty);
67        }
68        if data.len() != count * dim {
69            return Err(MultiVectorError::DataLengthMismatch {
70                expected: count * dim,
71                got: data.len(),
72            });
73        }
74        Ok(Self { data, count, dim })
75    }
76
77    /// Number of vectors.
78    pub fn count(&self) -> usize {
79        self.count
80    }
81
82    /// Dimensionality of each vector.
83    pub fn dim(&self) -> usize {
84        self.dim
85    }
86
87    /// Access the i-th vector as a slice.
88    pub fn get(&self, i: usize) -> Option<&[f32]> {
89        if i >= self.count {
90            return None;
91        }
92        let start = i * self.dim;
93        Some(&self.data[start..start + self.dim])
94    }
95
96    /// Iterate over all vectors as slices.
97    pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
98        (0..self.count).map(move |i| {
99            let start = i * self.dim;
100            &self.data[start..start + self.dim]
101        })
102    }
103
104    /// Extract all vectors as owned Vec<Vec<f32>>.
105    pub fn to_vectors(&self) -> Vec<Vec<f32>> {
106        self.iter().map(|s| s.to_vec()).collect()
107    }
108
109    /// Access the raw contiguous data.
110    pub fn raw_data(&self) -> &[f32] {
111        &self.data
112    }
113}
114
115/// Aggregation mode for multi-vector scoring.
116#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
117#[non_exhaustive]
118pub enum MultiVectorScoreMode {
119    /// MaxSim: max similarity across all document vectors (ColBERT scoring).
120    MaxSim,
121    /// Average similarity across all document vectors.
122    AvgSim,
123    /// Sum of similarities across all document vectors.
124    SumSim,
125}
126
127impl MultiVectorScoreMode {
128    /// Parse from string.
129    pub fn parse(s: &str) -> Option<Self> {
130        match s.to_lowercase().as_str() {
131            "max_sim" | "maxsim" => Some(Self::MaxSim),
132            "avg_sim" | "avgsim" => Some(Self::AvgSim),
133            "sum_sim" | "sumsim" => Some(Self::SumSim),
134            _ => None,
135        }
136    }
137
138    /// Aggregate a set of similarity scores for one document.
139    pub fn aggregate(&self, scores: &[f32]) -> f32 {
140        if scores.is_empty() {
141            return 0.0;
142        }
143        match self {
144            Self::MaxSim => scores.iter().cloned().reduce(f32::max).unwrap_or(0.0),
145            Self::AvgSim => scores.iter().sum::<f32>() / scores.len() as f32,
146            Self::SumSim => scores.iter().sum(),
147        }
148    }
149}
150
/// Errors from multi-vector construction.
///
/// Derives `PartialEq`/`Eq` so callers and tests can compare errors directly
/// (e.g. with `assert_eq!`) instead of only pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MultiVectorError {
    /// The input contains no vectors.
    Empty,
    /// The vectors have length zero.
    ZeroDimension,
    /// At least one element is NaN or infinite.
    NonFiniteValue,
    /// A vector's length differs from the expected dimension.
    DimensionMismatch {
        /// Expected dimension (the length of the first vector).
        expected: usize,
        /// Actual length of the offending vector.
        got: usize,
        /// Position of the offending vector in the input.
        index: usize,
    },
    /// A flat data buffer does not hold exactly `count * dim` elements.
    DataLengthMismatch {
        /// Required number of elements.
        expected: usize,
        /// Number of elements actually supplied.
        got: usize,
    },
}

impl std::fmt::Display for MultiVectorError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("multi-vector must contain at least one vector"),
            Self::ZeroDimension => f.write_str("vector dimension must be > 0"),
            Self::NonFiniteValue => f.write_str("vector values must be finite"),
            Self::DimensionMismatch {
                expected,
                got,
                index,
            } => write!(
                f,
                "dimension mismatch at vector {index}: expected {expected}, got {got}"
            ),
            Self::DataLengthMismatch { expected, got } => {
                write!(f, "data length mismatch: expected {expected}, got {got}")
            }
        }
    }
}

impl std::error::Error for MultiVectorError {}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_vectors_basic() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]).unwrap();
        assert_eq!(mv.count(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.get(0).unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.get(1).unwrap(), &[4.0, 5.0, 6.0]);
        assert!(mv.get(2).is_none());
    }

    #[test]
    fn dimension_mismatch_rejected() {
        let err = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0]]).unwrap_err();
        assert!(matches!(
            err,
            MultiVectorError::DimensionMismatch { expected: 2, got: 1, index: 1 }
        ));
    }

    #[test]
    fn zero_dimension_rejected() {
        let err = MultiVector::from_vectors(vec![vec![], vec![]]).unwrap_err();
        assert!(matches!(err, MultiVectorError::ZeroDimension));
    }

    #[test]
    fn non_finite_rejected() {
        assert!(MultiVector::from_vectors(vec![vec![f32::NAN]]).is_err());
        assert!(matches!(
            MultiVector::from_vectors(vec![vec![1.0, f32::INFINITY]]),
            Err(MultiVectorError::NonFiniteValue)
        ));
    }

    #[test]
    fn empty_rejected() {
        assert!(MultiVector::from_vectors(vec![]).is_err());
    }

    #[test]
    fn iter_all_vectors() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        let collected: Vec<&[f32]> = mv.iter().collect();
        assert_eq!(collected.len(), 3);
        assert_eq!(collected[2], &[5.0, 6.0]);
    }

    #[test]
    fn from_raw_matches_row_major_layout() {
        let mv = MultiVector::from_raw(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3, 2).unwrap();
        assert_eq!(mv.count(), 3);
        assert_eq!(mv.dim(), 2);
        assert_eq!(mv.get(1).unwrap(), &[3.0, 4.0]);
        assert_eq!(
            mv.to_vectors(),
            vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]
        );
        assert_eq!(mv.raw_data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn from_raw_length_mismatch_rejected() {
        let err = MultiVector::from_raw(vec![1.0, 2.0, 3.0], 2, 2).unwrap_err();
        assert!(matches!(
            err,
            MultiVectorError::DataLengthMismatch { expected: 4, got: 3 }
        ));
    }

    #[test]
    fn from_raw_zero_count_rejected() {
        assert!(matches!(
            MultiVector::from_raw(vec![], 0, 4),
            Err(MultiVectorError::Empty)
        ));
    }

    #[test]
    fn serde_roundtrip() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let bytes = zerompk::to_msgpack_vec(&mv).unwrap();
        let restored: MultiVector = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(mv, restored);
    }

    #[test]
    fn score_modes() {
        let scores = vec![0.5, 0.8, 0.3];
        assert!((MultiVectorScoreMode::MaxSim.aggregate(&scores) - 0.8).abs() < 1e-6);
        assert!((MultiVectorScoreMode::SumSim.aggregate(&scores) - 1.6).abs() < 1e-6);
        let avg = MultiVectorScoreMode::AvgSim.aggregate(&scores);
        assert!((avg - 1.6 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn score_mode_aggregate_empty_is_zero() {
        for mode in [
            MultiVectorScoreMode::MaxSim,
            MultiVectorScoreMode::AvgSim,
            MultiVectorScoreMode::SumSim,
        ] {
            assert_eq!(mode.aggregate(&[]), 0.0);
        }
    }

    #[test]
    fn score_mode_parse() {
        assert_eq!(
            MultiVectorScoreMode::parse("MaxSim"),
            Some(MultiVectorScoreMode::MaxSim)
        );
        assert_eq!(
            MultiVectorScoreMode::parse("avg_sim"),
            Some(MultiVectorScoreMode::AvgSim)
        );
        assert_eq!(
            MultiVectorScoreMode::parse("SUM_SIM"),
            Some(MultiVectorScoreMode::SumSim)
        );
        assert_eq!(MultiVectorScoreMode::parse("min_sim"), None);
        assert_eq!(MultiVectorScoreMode::parse(""), None);
    }
}