1use crate::e8::{E8Codec, E8EncodedVector};
9use crate::error::{EmbedVecError, Result};
10use crate::quantization::Quantization;
11use serde::{Deserialize, Serialize};
12
/// A single stored vector, kept either as raw `f32` components or in
/// E8-encoded form.
///
/// This type derives `Serialize`/`Deserialize`, so the variant names and
/// order are part of its serialized representation. NOTE(review): presumably
/// persisted to disk; confirm before reordering or renaming variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StoredVector {
    /// Uncompressed vector components.
    Raw(Vec<f32>),
    /// Components encoded by an `E8Codec`; decoding them back to `f32`
    /// requires a codec.
    E8(E8EncodedVector),
}
21
22impl StoredVector {
23 pub fn to_f32(&self, codec: Option<&E8Codec>) -> Vec<f32> {
25 match self {
26 StoredVector::Raw(v) => v.clone(),
27 StoredVector::E8(encoded) => {
28 if let Some(c) = codec {
29 c.decode(encoded)
30 } else {
31 vec![0.0; encoded.points.len() * 8]
33 }
34 }
35 }
36 }
37
38 pub fn size_bytes(&self) -> usize {
40 match self {
41 StoredVector::Raw(v) => v.len() * 4,
42 StoredVector::E8(encoded) => encoded.size_bytes(),
43 }
44 }
45}
46
/// In-memory collection of vectors addressed by their insertion index.
///
/// Each vector is stored raw or E8-encoded according to the storage's
/// quantization setting, and a running byte count of the stored payloads
/// is maintained alongside.
#[derive(Debug)]
pub struct VectorStorage {
    // Declared dimension of the stored vectors.
    dimension: usize,
    // Stored vectors; a vector's id is its index in this Vec.
    vectors: Vec<StoredVector>,
    // Encoding applied to vectors as they are stored.
    quantization: Quantization,
    // Sum of `StoredVector::size_bytes()` over `vectors`.
    memory_bytes: usize,
}
62
63impl VectorStorage {
64 pub fn new(dimension: usize, quantization: Quantization) -> Self {
70 Self {
71 dimension,
72 vectors: Vec::new(),
73 quantization,
74 memory_bytes: 0,
75 }
76 }
77
78 pub fn add(&mut self, vector: &[f32], codec: Option<&E8Codec>) -> Result<usize> {
87 let stored = match &self.quantization {
88 Quantization::None => StoredVector::Raw(vector.to_vec()),
89 Quantization::E8 { .. } => {
90 if let Some(c) = codec {
91 let encoded = c.encode(vector)?;
92 StoredVector::E8(encoded)
93 } else {
94 return Err(EmbedVecError::QuantizationError(
95 "E8 codec required for E8 quantization".to_string(),
96 ));
97 }
98 }
99 };
100
101 self.memory_bytes += stored.size_bytes();
102 let id = self.vectors.len();
103 self.vectors.push(stored);
104 Ok(id)
105 }
106
107 #[inline]
116 pub fn get(&self, id: usize, codec: Option<&E8Codec>) -> Result<Vec<f32>> {
117 self.vectors
118 .get(id)
119 .map(|v| v.to_f32(codec))
120 .ok_or(EmbedVecError::VectorNotFound(id))
121 }
122
123 #[inline]
126 pub fn get_raw_slice(&self, id: usize) -> Option<&[f32]> {
127 match self.vectors.get(id) {
128 Some(StoredVector::Raw(v)) => Some(v.as_slice()),
129 _ => None,
130 }
131 }
132
133 #[inline]
135 pub fn get_stored(&self, id: usize) -> Option<&StoredVector> {
136 self.vectors.get(id)
137 }
138
139 pub fn get_batch(&self, ids: &[usize], codec: Option<&E8Codec>) -> Vec<Option<Vec<f32>>> {
141 ids.iter()
142 .map(|&id| self.vectors.get(id).map(|v| v.to_f32(codec)))
143 .collect()
144 }
145
146 pub fn len(&self) -> usize {
148 self.vectors.len()
149 }
150
151 pub fn is_empty(&self) -> bool {
153 self.vectors.is_empty()
154 }
155
156 pub fn clear(&mut self) {
158 self.vectors.clear();
159 self.memory_bytes = 0;
160 }
161
162 pub fn memory_bytes(&self) -> usize {
164 self.memory_bytes
165 }
166
167 pub fn bytes_per_vector(&self) -> f32 {
169 if self.vectors.is_empty() {
170 0.0
171 } else {
172 self.memory_bytes as f32 / self.vectors.len() as f32
173 }
174 }
175
176 pub fn set_quantization(
182 &mut self,
183 new_quantization: Quantization,
184 codec: Option<&E8Codec>,
185 ) -> Result<()> {
186 if self.quantization == new_quantization {
187 return Ok(());
188 }
189
190 let mut new_vectors = Vec::with_capacity(self.vectors.len());
192 let mut new_memory = 0usize;
193
194 for stored in &self.vectors {
195 let raw = stored.to_f32(codec);
197
198 let new_stored = match &new_quantization {
200 Quantization::None => StoredVector::Raw(raw),
201 Quantization::E8 { .. } => {
202 if let Some(c) = codec {
203 let encoded = c.encode(&raw)?;
204 StoredVector::E8(encoded)
205 } else {
206 return Err(EmbedVecError::QuantizationError(
207 "E8 codec required for E8 quantization".to_string(),
208 ));
209 }
210 }
211 };
212
213 new_memory += new_stored.size_bytes();
214 new_vectors.push(new_stored);
215 }
216
217 self.vectors = new_vectors;
218 self.memory_bytes = new_memory;
219 self.quantization = new_quantization;
220
221 Ok(())
222 }
223
224 pub fn quantization(&self) -> &Quantization {
226 &self.quantization
227 }
228
229 pub fn dimension(&self) -> usize {
231 self.dimension
232 }
233
234 pub fn compute_distance(
238 &self,
239 query: &[f32],
240 id: usize,
241 codec: Option<&E8Codec>,
242 distance_fn: impl Fn(&[f32], &[f32]) -> f32,
243 ) -> Result<f32> {
244 let stored = self
245 .vectors
246 .get(id)
247 .ok_or(EmbedVecError::VectorNotFound(id))?;
248
249 match stored {
250 StoredVector::Raw(v) => Ok(distance_fn(query, v)),
251 StoredVector::E8(encoded) => {
252 if let Some(c) = codec {
253 let decoded = c.decode(encoded);
255 Ok(distance_fn(query, &decoded))
256 } else {
257 Err(EmbedVecError::QuantizationError(
258 "E8 codec required for distance computation".to_string(),
259 ))
260 }
261 }
262 }
263 }
264
265 pub fn iter_ids(&self) -> impl Iterator<Item = usize> {
267 0..self.vectors.len()
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw vectors round-trip exactly and receive sequential ids.
    #[test]
    fn test_raw_storage() {
        let mut store = VectorStorage::new(4, Quantization::None);
        let original = vec![1.0, 2.0, 3.0, 4.0];

        let assigned = store.add(&original, None).unwrap();
        assert_eq!(assigned, 0);

        let fetched = store.get(0, None).unwrap();
        assert_eq!(fetched, original);
    }

    /// E8 storage decodes to the right length with bounded reconstruction error.
    #[test]
    fn test_e8_storage() {
        let codec = E8Codec::new(16, 10, true, 42);
        let mut store = VectorStorage::new(16, Quantization::e8_default());

        let original: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        let assigned = store.add(&original, Some(&codec)).unwrap();
        assert_eq!(assigned, 0);

        let fetched = store.get(0, Some(&codec)).unwrap();
        assert_eq!(fetched.len(), 16);

        let squared_error = original
            .iter()
            .zip(&fetched)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>();
        let mse = squared_error / 16.0;
        assert!(mse < 10.0, "MSE too high: {}", mse);
    }

    /// Raw vectors are accounted at four bytes per component.
    #[test]
    fn test_memory_tracking() {
        let mut store = VectorStorage::new(768, Quantization::None);
        let zeros: Vec<f32> = vec![0.0; 768];

        for _ in 0..10 {
            store.add(&zeros, None).unwrap();
        }

        assert_eq!(store.memory_bytes(), 768 * 4 * 10);
    }

    /// Clearing empties the storage and resets the byte counter.
    #[test]
    fn test_clear() {
        let mut store = VectorStorage::new(4, Quantization::None);
        store.add(&[1.0, 2.0, 3.0, 4.0], None).unwrap();
        assert_eq!(store.len(), 1);

        store.clear();

        assert_eq!(store.len(), 0);
        assert_eq!(store.memory_bytes(), 0);
    }
}