1use serde::{Deserialize, Serialize};
9
/// A collection of equally sized `f32` vectors stored in one flat buffer.
///
/// Vector `i` occupies `data[i * dim..i * dim + dim]`. Values are created with
/// [`MultiVector::from_vectors`] or [`MultiVector::from_raw`].
///
/// NOTE(review): the derived deserializers populate the fields directly; whether
/// they can yield a value with `data.len() != count * dim` depends on the input
/// source — confirm before trusting decoded values.
#[derive(
    Debug,
    Clone,
    PartialEq,
    Serialize,
    Deserialize,
    zerompk::ToMessagePack,
    zerompk::FromMessagePack,
)]
pub struct MultiVector {
    /// Components of all vectors, laid out one vector after another.
    data: Vec<f32>,
    /// Number of vectors stored in `data`.
    count: usize,
    /// Number of components in each vector.
    dim: usize,
}
32
33impl MultiVector {
34 pub fn from_vectors(vectors: Vec<Vec<f32>>) -> Result<Self, MultiVectorError> {
36 if vectors.is_empty() {
37 return Err(MultiVectorError::Empty);
38 }
39 let dim = vectors[0].len();
40 if dim == 0 {
41 return Err(MultiVectorError::ZeroDimension);
42 }
43 let count = vectors.len();
44 let mut data = Vec::with_capacity(count * dim);
45 for (i, v) in vectors.iter().enumerate() {
46 if v.len() != dim {
47 return Err(MultiVectorError::DimensionMismatch {
48 expected: dim,
49 got: v.len(),
50 index: i,
51 });
52 }
53 for &val in v {
54 if !val.is_finite() {
55 return Err(MultiVectorError::NonFiniteValue);
56 }
57 }
58 data.extend_from_slice(v);
59 }
60 Ok(Self { data, count, dim })
61 }
62
63 pub fn from_raw(data: Vec<f32>, count: usize, dim: usize) -> Result<Self, MultiVectorError> {
65 if count == 0 || dim == 0 {
66 return Err(MultiVectorError::Empty);
67 }
68 if data.len() != count * dim {
69 return Err(MultiVectorError::DataLengthMismatch {
70 expected: count * dim,
71 got: data.len(),
72 });
73 }
74 Ok(Self { data, count, dim })
75 }
76
77 pub fn count(&self) -> usize {
79 self.count
80 }
81
82 pub fn dim(&self) -> usize {
84 self.dim
85 }
86
87 pub fn get(&self, i: usize) -> Option<&[f32]> {
89 if i >= self.count {
90 return None;
91 }
92 let start = i * self.dim;
93 Some(&self.data[start..start + self.dim])
94 }
95
96 pub fn iter(&self) -> impl Iterator<Item = &[f32]> {
98 (0..self.count).map(move |i| {
99 let start = i * self.dim;
100 &self.data[start..start + self.dim]
101 })
102 }
103
104 pub fn to_vectors(&self) -> Vec<Vec<f32>> {
106 self.iter().map(|s| s.to_vec()).collect()
107 }
108
109 pub fn raw_data(&self) -> &[f32] {
111 &self.data
112 }
113}
114
/// Strategy used by [`MultiVectorScoreMode::aggregate`] to reduce a list of
/// scores to a single value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum MultiVectorScoreMode {
    /// Take the largest score.
    MaxSim,
    /// Take the arithmetic mean of the scores.
    AvgSim,
    /// Take the sum of the scores.
    SumSim,
}
126
127impl MultiVectorScoreMode {
128 pub fn parse(s: &str) -> Option<Self> {
130 match s.to_lowercase().as_str() {
131 "max_sim" | "maxsim" => Some(Self::MaxSim),
132 "avg_sim" | "avgsim" => Some(Self::AvgSim),
133 "sum_sim" | "sumsim" => Some(Self::SumSim),
134 _ => None,
135 }
136 }
137
138 pub fn aggregate(&self, scores: &[f32]) -> f32 {
140 if scores.is_empty() {
141 return 0.0;
142 }
143 match self {
144 Self::MaxSim => scores.iter().cloned().reduce(f32::max).unwrap_or(0.0),
145 Self::AvgSim => scores.iter().sum::<f32>() / scores.len() as f32,
146 Self::SumSim => scores.iter().sum(),
147 }
148 }
149}
150
/// Reasons a multi-vector can be rejected during construction.
///
/// Implements `PartialEq`/`Eq` so callers and tests can compare errors (and
/// whole `Result` values) directly.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum MultiVectorError {
    /// No vectors were supplied.
    Empty,
    /// The vectors have zero components.
    ZeroDimension,
    /// A component is NaN or infinite.
    NonFiniteValue,
    /// A vector's length differs from the expected dimension.
    DimensionMismatch {
        /// Dimension every vector is required to have.
        expected: usize,
        /// Length of the offending vector.
        got: usize,
        /// Position of the offending vector in the input.
        index: usize,
    },
    /// A flat buffer's length does not match the declared shape.
    DataLengthMismatch {
        /// Number of elements implied by the declared shape.
        expected: usize,
        /// Number of elements actually supplied.
        got: usize,
    },
}
168
169impl std::fmt::Display for MultiVectorError {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 match self {
172 Self::Empty => write!(f, "multi-vector must contain at least one vector"),
173 Self::ZeroDimension => write!(f, "vector dimension must be > 0"),
174 Self::NonFiniteValue => write!(f, "vector values must be finite"),
175 Self::DimensionMismatch {
176 expected,
177 got,
178 index,
179 } => write!(
180 f,
181 "dimension mismatch at vector {index}: expected {expected}, got {got}"
182 ),
183 Self::DataLengthMismatch { expected, got } => {
184 write!(f, "data length mismatch: expected {expected}, got {got}")
185 }
186 }
187 }
188}
189
// Marker impl: relies on the default `std::error::Error` methods, so no
// underlying source error is reported.
impl std::error::Error for MultiVectorError {}
191
// Unit tests for `MultiVector` construction, access, MessagePack round-tripping
// and `MultiVectorScoreMode::aggregate`.
#[cfg(test)]
mod tests {
    use super::*;

    // Two 3-component vectors: shape accessors and indexed lookup, including
    // an out-of-range index.
    #[test]
    fn from_vectors_basic() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]).unwrap();
        assert_eq!(mv.count(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.get(0).unwrap(), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.get(1).unwrap(), &[4.0, 5.0, 6.0]);
        assert!(mv.get(2).is_none());
    }

    // A second vector of different length must produce `DimensionMismatch`.
    #[test]
    fn dimension_mismatch_rejected() {
        let err = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0]]).unwrap_err();
        assert!(matches!(err, MultiVectorError::DimensionMismatch { .. }));
    }

    // NaN components are refused.
    #[test]
    fn non_finite_rejected() {
        assert!(MultiVector::from_vectors(vec![vec![f32::NAN]]).is_err());
    }

    // An empty list of vectors is refused.
    #[test]
    fn empty_rejected() {
        assert!(MultiVector::from_vectors(vec![]).is_err());
    }

    // `iter` yields every vector, in order.
    #[test]
    fn iter_all_vectors() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]])
            .unwrap();
        let collected: Vec<&[f32]> = mv.iter().collect();
        assert_eq!(collected.len(), 3);
        assert_eq!(collected[2], &[5.0, 6.0]);
    }

    // Encoding with `zerompk` and decoding again gives back an equal value.
    #[test]
    fn serde_roundtrip() {
        let mv = MultiVector::from_vectors(vec![vec![1.0, 2.0], vec![3.0, 4.0]]).unwrap();
        let bytes = zerompk::to_msgpack_vec(&mv).unwrap();
        let restored: MultiVector = zerompk::from_msgpack(&bytes).unwrap();
        assert_eq!(mv, restored);
    }

    // Max, sum and average over the same three scores, compared with a tolerance.
    #[test]
    fn score_modes() {
        let scores = vec![0.5, 0.8, 0.3];
        assert!((MultiVectorScoreMode::MaxSim.aggregate(&scores) - 0.8).abs() < 1e-6);
        assert!((MultiVectorScoreMode::SumSim.aggregate(&scores) - 1.6).abs() < 1e-6);
        let avg = MultiVectorScoreMode::AvgSim.aggregate(&scores);
        assert!((avg - 0.5333).abs() < 0.01);
    }
}