embedvec/
storage.rs

//! Vector Storage Module
//!
//! ## Table of Contents
//! - **VectorStorage**: Main storage for vectors (raw f32 or E8-quantized)
//! - **StoredVector**: Enum representing stored vector format
//! - **Storage operations**: add, get, clear, re-quantization
7
8use crate::e8::{E8Codec, E8EncodedVector};
9use crate::error::{EmbedVecError, Result};
10use crate::quantization::Quantization;
11use serde::{Deserialize, Serialize};
12
/// Stored vector representation.
///
/// A vector is kept either as its original `f32` components or as an
/// E8-encoded payload. Turning an `E8` payload back into `f32` values needs an
/// `E8Codec`; see `StoredVector::to_f32`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StoredVector {
    /// Raw f32 vector (no quantization); the components are stored verbatim.
    Raw(Vec<f32>),
    /// E8-quantized vector, held in its encoded form.
    E8(E8EncodedVector),
}
21
22impl StoredVector {
23    /// Get the raw f32 vector (decoding if necessary)
24    pub fn to_f32(&self, codec: Option<&E8Codec>) -> Vec<f32> {
25        match self {
26            StoredVector::Raw(v) => v.clone(),
27            StoredVector::E8(encoded) => {
28                if let Some(c) = codec {
29                    c.decode(encoded)
30                } else {
31                    // Fallback: return zeros if no codec (shouldn't happen)
32                    vec![0.0; encoded.points.len() * 8]
33                }
34            }
35        }
36    }
37
38    /// Get size in bytes
39    pub fn size_bytes(&self) -> usize {
40        match self {
41            StoredVector::Raw(v) => v.len() * 4,
42            StoredVector::E8(encoded) => encoded.size_bytes(),
43        }
44    }
45}
46
/// Vector storage with optional quantization
///
/// Manages storage of vectors in either raw f32 format or E8-quantized format.
/// Supports dynamic re-quantization when changing modes.
///
/// Vector IDs are positions in the internal list: the first vector added gets
/// ID 0, the next ID 1, and so on.
#[derive(Debug)]
pub struct VectorStorage {
    /// Vector dimension, as passed to `new`. It is recorded and reported by
    /// `dimension()`; `add` does not check incoming vectors against it.
    dimension: usize,
    /// Stored vectors, indexed by vector ID.
    vectors: Vec<StoredVector>,
    /// Current quantization mode; decides how `add` stores new vectors.
    quantization: Quantization,
    /// Total memory usage in bytes: the running sum of `size_bytes()` over
    /// `vectors`, maintained by `add`, `clear` and `set_quantization`.
    memory_bytes: usize,
}
62
63impl VectorStorage {
64    /// Create new vector storage
65    ///
66    /// # Arguments
67    /// * `dimension` - Vector dimension
68    /// * `quantization` - Quantization mode
69    pub fn new(dimension: usize, quantization: Quantization) -> Self {
70        Self {
71            dimension,
72            vectors: Vec::new(),
73            quantization,
74            memory_bytes: 0,
75        }
76    }
77
78    /// Add a vector to storage
79    ///
80    /// # Arguments
81    /// * `vector` - Raw f32 vector to store
82    /// * `codec` - Optional E8 codec for quantization
83    ///
84    /// # Returns
85    /// Assigned vector ID
86    pub fn add(&mut self, vector: &[f32], codec: Option<&E8Codec>) -> Result<usize> {
87        let stored = match &self.quantization {
88            Quantization::None => StoredVector::Raw(vector.to_vec()),
89            Quantization::E8 { .. } => {
90                if let Some(c) = codec {
91                    let encoded = c.encode(vector)?;
92                    StoredVector::E8(encoded)
93                } else {
94                    return Err(EmbedVecError::QuantizationError(
95                        "E8 codec required for E8 quantization".to_string(),
96                    ));
97                }
98            }
99        };
100
101        self.memory_bytes += stored.size_bytes();
102        let id = self.vectors.len();
103        self.vectors.push(stored);
104        Ok(id)
105    }
106
107    /// Get a vector by ID
108    ///
109    /// # Arguments
110    /// * `id` - Vector ID
111    /// * `codec` - Optional E8 codec for decoding
112    ///
113    /// # Returns
114    /// Raw f32 vector (decoded if quantized)
115    #[inline]
116    pub fn get(&self, id: usize, codec: Option<&E8Codec>) -> Result<Vec<f32>> {
117        self.vectors
118            .get(id)
119            .map(|v| v.to_f32(codec))
120            .ok_or(EmbedVecError::VectorNotFound(id))
121    }
122
123    /// Get raw vector slice for unquantized storage (zero-copy)
124    /// Returns None if vector is quantized or ID is invalid
125    #[inline]
126    pub fn get_raw_slice(&self, id: usize) -> Option<&[f32]> {
127        match self.vectors.get(id) {
128            Some(StoredVector::Raw(v)) => Some(v.as_slice()),
129            _ => None,
130        }
131    }
132
133    /// Get stored vector reference by ID
134    #[inline]
135    pub fn get_stored(&self, id: usize) -> Option<&StoredVector> {
136        self.vectors.get(id)
137    }
138    
139    /// Batch get multiple vectors by IDs (more efficient than individual gets)
140    pub fn get_batch(&self, ids: &[usize], codec: Option<&E8Codec>) -> Vec<Option<Vec<f32>>> {
141        ids.iter()
142            .map(|&id| self.vectors.get(id).map(|v| v.to_f32(codec)))
143            .collect()
144    }
145
146    /// Get number of stored vectors
147    pub fn len(&self) -> usize {
148        self.vectors.len()
149    }
150
151    /// Check if storage is empty
152    pub fn is_empty(&self) -> bool {
153        self.vectors.is_empty()
154    }
155
156    /// Clear all vectors
157    pub fn clear(&mut self) {
158        self.vectors.clear();
159        self.memory_bytes = 0;
160    }
161
162    /// Get total memory usage in bytes
163    pub fn memory_bytes(&self) -> usize {
164        self.memory_bytes
165    }
166
167    /// Get memory usage per vector (average)
168    pub fn bytes_per_vector(&self) -> f32 {
169        if self.vectors.is_empty() {
170            0.0
171        } else {
172            self.memory_bytes as f32 / self.vectors.len() as f32
173        }
174    }
175
176    /// Change quantization mode (re-quantizes all vectors)
177    ///
178    /// # Arguments
179    /// * `new_quantization` - New quantization mode
180    /// * `codec` - E8 codec (required if switching to E8)
181    pub fn set_quantization(
182        &mut self,
183        new_quantization: Quantization,
184        codec: Option<&E8Codec>,
185    ) -> Result<()> {
186        if self.quantization == new_quantization {
187            return Ok(());
188        }
189
190        // Re-quantize all vectors
191        let mut new_vectors = Vec::with_capacity(self.vectors.len());
192        let mut new_memory = 0usize;
193
194        for stored in &self.vectors {
195            // First decode to f32
196            let raw = stored.to_f32(codec);
197
198            // Then encode with new quantization
199            let new_stored = match &new_quantization {
200                Quantization::None => StoredVector::Raw(raw),
201                Quantization::E8 { .. } => {
202                    if let Some(c) = codec {
203                        let encoded = c.encode(&raw)?;
204                        StoredVector::E8(encoded)
205                    } else {
206                        return Err(EmbedVecError::QuantizationError(
207                            "E8 codec required for E8 quantization".to_string(),
208                        ));
209                    }
210                }
211            };
212
213            new_memory += new_stored.size_bytes();
214            new_vectors.push(new_stored);
215        }
216
217        self.vectors = new_vectors;
218        self.memory_bytes = new_memory;
219        self.quantization = new_quantization;
220
221        Ok(())
222    }
223
224    /// Get current quantization mode
225    pub fn quantization(&self) -> &Quantization {
226        &self.quantization
227    }
228
229    /// Get vector dimension
230    pub fn dimension(&self) -> usize {
231        self.dimension
232    }
233
234    /// Compute distance between query and stored vector
235    ///
236    /// Uses asymmetric distance for quantized vectors (query in f32, db decoded on-the-fly)
237    pub fn compute_distance(
238        &self,
239        query: &[f32],
240        id: usize,
241        codec: Option<&E8Codec>,
242        distance_fn: impl Fn(&[f32], &[f32]) -> f32,
243    ) -> Result<f32> {
244        let stored = self
245            .vectors
246            .get(id)
247            .ok_or(EmbedVecError::VectorNotFound(id))?;
248
249        match stored {
250            StoredVector::Raw(v) => Ok(distance_fn(query, v)),
251            StoredVector::E8(encoded) => {
252                if let Some(c) = codec {
253                    // Asymmetric: decode on-the-fly
254                    let decoded = c.decode(encoded);
255                    Ok(distance_fn(query, &decoded))
256                } else {
257                    Err(EmbedVecError::QuantizationError(
258                        "E8 codec required for distance computation".to_string(),
259                    ))
260                }
261            }
262        }
263    }
264
265    /// Iterate over all vector IDs
266    pub fn iter_ids(&self) -> impl Iterator<Item = usize> {
267        0..self.vectors.len()
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// An unquantized vector gets ID 0 and is returned unchanged.
    #[test]
    fn test_raw_storage() {
        let original = vec![1.0, 2.0, 3.0, 4.0];
        let mut storage = VectorStorage::new(4, Quantization::None);

        let assigned = storage.add(&original, None).unwrap();
        assert_eq!(assigned, 0);

        let fetched = storage.get(0, None).unwrap();
        assert_eq!(fetched, original);
    }

    /// An E8-quantized vector keeps its length and stays within a loose error bound.
    #[test]
    fn test_e8_storage() {
        let codec = E8Codec::new(16, 10, true, 42);
        let mut storage = VectorStorage::new(16, Quantization::e8_default());

        let original: Vec<f32> = (0..16).map(|i| i as f32 * 0.1).collect();
        assert_eq!(storage.add(&original, Some(&codec)).unwrap(), 0);

        let decoded = storage.get(0, Some(&codec)).unwrap();
        assert_eq!(decoded.len(), 16);

        // Quantization is lossy, so compare by mean squared error.
        // The bound is intentionally generous for the current E8 implementation.
        let squared_error: f32 = original
            .iter()
            .zip(decoded.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f32>();
        let mse = squared_error / 16.0;
        assert!(mse < 10.0, "MSE too high: {}", mse);
    }

    /// Raw storage accounts four bytes per component for every vector added.
    #[test]
    fn test_memory_tracking() {
        let mut storage = VectorStorage::new(768, Quantization::None);
        let zeros: Vec<f32> = vec![0.0; 768];

        (0..10).for_each(|_| {
            storage.add(&zeros, None).unwrap();
        });

        assert_eq!(storage.memory_bytes(), 768 * 4 * 10);
    }

    /// Clearing removes every vector and resets the memory counter.
    #[test]
    fn test_clear() {
        let mut storage = VectorStorage::new(4, Quantization::None);
        storage.add(&[1.0, 2.0, 3.0, 4.0], None).unwrap();
        assert_eq!(storage.len(), 1);

        storage.clear();

        assert_eq!(storage.len(), 0);
        assert_eq!(storage.memory_bytes(), 0);
    }
}