1use serde::{Deserialize, Serialize};
7
/// A group of `count` vectors that all have `dim` components, stored
/// back-to-back in one flat `Vec<f32>`.
///
/// Vector `i` occupies `data[i * dim..(i + 1) * dim]`. Instances are built
/// through `from_vectors` or `from_raw`, which check that the buffer length
/// matches `count * dim`.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct MultiVector {
    /// Flattened components of every vector, in vector order.
    data: Vec<f32>,
    /// Number of vectors held in `data`.
    count: usize,
    /// Number of components in each vector.
    dim: usize,
}
30
31impl MultiVector {
32 pub fn from_vectors(vectors: Vec<Vec<f32>>) -> Result<Self, MultiVectorError> {
34 if vectors.is_empty() {
35 return Err(MultiVectorError::Empty);
36 }
37 let dim = vectors[0].len();
38 if dim == 0 {
39 return Err(MultiVectorError::ZeroDimension);
40 }
41 let count = vectors.len();
42 let mut data = Vec::with_capacity(count * dim);
43 for (i, v) in vectors.iter().enumerate() {
44 if v.len() != dim {
45 return Err(MultiVectorError::DimensionMismatch {
46 expected: dim,
47 got: v.len(),
48 index: i,
49 });
50 }
51 for &val in v {
52 if !val.is_finite() {
53 return Err(MultiVectorError::NonFiniteValue);
54 }
55 }
56 data.extend_from_slice(v);
57 }
58 Ok(Self { data, count, dim })
59 }
60
61 pub fn from_raw(data: Vec<f32>, count: usize, dim: usize) -> Result<Self, MultiVectorError> {
63 if count == 0 || dim == 0 {
64 return Err(MultiVectorError::Empty);
65 }
66 if data.len() != count * dim {
67 return Err(MultiVectorError::DataLengthMismatch {
68 expected: count * dim,
69 got: data.len(),
70 });
71 }
72 Ok(Self { data, count, dim })
73 }
74
75 pub fn count(&self) -> usize {
77 self.count
78 }
79
80 pub fn dim(&self) -> usize {
82 self.dim
83 }
84
85 pub fn get(&self, i: usize) -> Option<&[f32]> {
87 if i >= self.count {
88 return None;
89 }
90 let start = i * self.dim;
91 Some(&self.data[start..start + self.dim])
92 }
93
94 pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
96 (0..self.count).map(move |i| {
97 let start = i * self.dim;
98 &self.data[start..start + self.dim]
99 })
100 }
101
102 pub fn to_vectors(&self) -> Vec<Vec<f32>> {
104 self.iter().map(|s| s.to_vec()).collect()
105 }
106
107 pub fn raw_data(&self) -> &[f32] {
109 &self.data
110 }
111}
112
/// Selects how `aggregate` reduces a slice of scores to a single value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MultiVectorScoreMode {
    /// Use the largest score.
    MaxSim,
    /// Use the arithmetic mean of the scores.
    AvgSim,
    /// Use the sum of the scores.
    SumSim,
}
123
124impl MultiVectorScoreMode {
125 pub fn parse(s: &str) -> Option<Self> {
127 match s.to_lowercase().as_str() {
128 "max_sim" | "maxsim" => Some(Self::MaxSim),
129 "avg_sim" | "avgsim" => Some(Self::AvgSim),
130 "sum_sim" | "sumsim" => Some(Self::SumSim),
131 _ => None,
132 }
133 }
134
135 pub fn aggregate(&self, scores: &[f32]) -> f32 {
137 if scores.is_empty() {
138 return 0.0;
139 }
140 match self {
141 Self::MaxSim => scores.iter().cloned().reduce(f32::max).unwrap_or(0.0),
142 Self::AvgSim => scores.iter().sum::<f32>() / scores.len() as f32,
143 Self::SumSim => scores.iter().sum(),
144 }
145 }
146}
147
/// Reasons a `MultiVector` constructor can reject its input.
#[derive(Debug, Clone)]
pub enum MultiVectorError {
    /// The multi-vector would contain no vectors.
    Empty,
    /// The vector dimension is zero.
    ZeroDimension,
    /// A component is NaN or infinite.
    NonFiniteValue,
    /// A vector's length differs from the expected dimension.
    DimensionMismatch {
        /// Dimension every vector must have.
        expected: usize,
        /// Length of the offending vector.
        got: usize,
        /// Position of the offending vector in the input.
        index: usize,
    },
    /// The flat buffer's length does not match the requested shape.
    DataLengthMismatch {
        /// Length the buffer was required to have.
        expected: usize,
        /// Length the buffer actually had.
        got: usize,
    },
}
164
165impl std::fmt::Display for MultiVectorError {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 match self {
168 Self::Empty => write!(f, "multi-vector must contain at least one vector"),
169 Self::ZeroDimension => write!(f, "vector dimension must be > 0"),
170 Self::NonFiniteValue => write!(f, "vector values must be finite"),
171 Self::DimensionMismatch {
172 expected,
173 got,
174 index,
175 } => write!(
176 f,
177 "dimension mismatch at vector {index}: expected {expected}, got {got}"
178 ),
179 Self::DataLengthMismatch { expected, got } => {
180 write!(f, "data length mismatch: expected {expected}, got {got}")
181 }
182 }
183 }
184}
185
// Empty impl: every `std::error::Error` method keeps its default, building on
// the `Debug` derive and `Display` impl of `MultiVectorError`.
impl std::error::Error for MultiVectorError {}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Two 3-component vectors are stored and retrievable by index.
    #[test]
    fn from_vectors_basic() {
        let rows = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        let packed = MultiVector::from_vectors(rows).unwrap();

        assert_eq!((packed.count(), packed.dim()), (2, 3));
        assert_eq!(packed.get(0), Some(&[1.0, 2.0, 3.0][..]));
        assert_eq!(packed.get(1), Some(&[4.0, 5.0, 6.0][..]));
        assert_eq!(packed.get(2), None);
    }

    /// Vectors of differing lengths are reported as a dimension mismatch.
    #[test]
    fn dimension_mismatch_rejected() {
        let outcome = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0]]);
        assert!(matches!(
            outcome,
            Err(MultiVectorError::DimensionMismatch { .. })
        ));
    }

    /// A NaN component makes construction fail.
    #[test]
    fn non_finite_rejected() {
        let nan_only = vec![vec![f32::NAN]];
        assert!(MultiVector::from_vectors(nan_only).is_err());
    }

    /// An empty list of vectors makes construction fail.
    #[test]
    fn empty_rejected() {
        let no_vectors: Vec<Vec<f32>> = Vec::new();
        assert!(MultiVector::from_vectors(no_vectors).is_err());
    }

    /// `iter` yields every stored vector, in order.
    #[test]
    fn iter_all_vectors() {
        let rows = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let packed = MultiVector::from_vectors(rows).unwrap();

        let slices: Vec<&[f32]> = packed.iter().collect();
        assert_eq!(slices.len(), 3);
        assert_eq!(slices[2], &[5.0, 6.0]);
    }

    /// A value survives a MessagePack encode/decode round trip unchanged.
    #[test]
    fn serde_roundtrip() {
        let original = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let encoded = zerompk::to_msgpack_vec(&original).unwrap();
        let decoded: MultiVector = zerompk::from_msgpack(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Each score mode aggregates the same input to its expected value.
    #[test]
    fn score_modes() {
        let scores = vec![0.5, 0.8, 0.3];

        let max = MultiVectorScoreMode::MaxSim.aggregate(&scores);
        assert!((max - 0.8).abs() < 1e-6);

        let sum = MultiVectorScoreMode::SumSim.aggregate(&scores);
        assert!((sum - 1.6).abs() < 1e-6);

        let avg = MultiVectorScoreMode::AvgSim.aggregate(&scores);
        assert!((avg - 0.5333).abs() < 0.01);
    }
}