// oxirs_vec/persistence.rs

//! Index persistence with compression.
//!
//! This module provides serialization and deserialization of HNSW vector indices
//! with support for compression, format versioning, and periodic checkpoint snapshots.
5
6use crate::hnsw::{HnswConfig, HnswIndex};
7use crate::Vector;
8use anyhow::{anyhow, Result};
9use serde::{Deserialize, Serialize};
10use std::fs::{File, OpenOptions};
11use std::io::{BufReader, BufWriter, Read, Write};
12use std::path::Path;
13
/// On-disk format version written into every header.
///
/// `load_index` refuses files whose header records a different value, so this must be
/// bumped whenever the file layout changes.
const PERSISTENCE_VERSION: u32 = 1;

/// Four-byte signature that opens every index file ("OxVe" = OxiRS Vector).
const MAGIC_NUMBER: &[u8; 4] = b"OxVe";
19
20/// Compression algorithm for index persistence
21#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
22pub enum CompressionAlgorithm {
23    /// No compression (fastest, largest)
24    None,
25    /// Zstd compression (balanced)
26    Zstd { level: i32 },
27    /// High compression (slowest, smallest)
28    ZstdMax,
29}
30
31impl Default for CompressionAlgorithm {
32    fn default() -> Self {
33        Self::Zstd { level: 3 } // Fast compression by default
34    }
35}
36
37/// Persistence configuration
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct PersistenceConfig {
40    /// Compression algorithm to use
41    pub compression: CompressionAlgorithm,
42    /// Include metadata in persistence
43    pub include_metadata: bool,
44    /// Validate data integrity on load
45    pub validate_on_load: bool,
46    /// Enable incremental persistence
47    pub incremental: bool,
48    /// Checkpoint interval for incremental persistence (in operations)
49    pub checkpoint_interval: usize,
50}
51
52impl Default for PersistenceConfig {
53    fn default() -> Self {
54        Self {
55            compression: CompressionAlgorithm::default(),
56            include_metadata: true,
57            validate_on_load: true,
58            incremental: false,
59            checkpoint_interval: 10000,
60        }
61    }
62}
63
64/// Serializable index header
65#[derive(Debug, Clone, Serialize, Deserialize)]
66struct IndexHeader {
67    version: u32,
68    compression: CompressionAlgorithm,
69    node_count: usize,
70    dimension: usize,
71    config: HnswConfig,
72    timestamp: u64,
73    checksum: u64,
74}
75
76/// Serializable node data
77#[derive(Debug, Clone, Serialize, Deserialize)]
78struct SerializableNode {
79    uri: String,
80    vector_data: Vec<f32>,
81    connections: Vec<Vec<usize>>,
82    level: usize,
83}
84
85/// Persistence manager for HNSW indices
86pub struct PersistenceManager {
87    config: PersistenceConfig,
88}
89
90impl PersistenceManager {
91    /// Create a new persistence manager
92    pub fn new(config: PersistenceConfig) -> Self {
93        Self { config }
94    }
95
96    /// Save HNSW index to disk
97    pub fn save_index<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
98        let path = path.as_ref();
99        tracing::info!("Saving HNSW index to {:?}", path);
100
101        let file = OpenOptions::new()
102            .write(true)
103            .create(true)
104            .truncate(true)
105            .open(path)?;
106
107        let mut writer = BufWriter::new(file);
108
109        // Write magic number
110        writer.write_all(MAGIC_NUMBER)?;
111
112        // Create header
113        let header = IndexHeader {
114            version: PERSISTENCE_VERSION,
115            compression: self.config.compression,
116            node_count: index.len(),
117            dimension: if let Some(node) = index.nodes().first() {
118                node.vector.dimensions
119            } else {
120                0
121            },
122            config: index.config().clone(),
123            timestamp: std::time::SystemTime::now()
124                .duration_since(std::time::UNIX_EPOCH)
125                .unwrap()
126                .as_secs(),
127            checksum: 0, // Will be calculated
128        };
129
130        // Serialize header
131        let header_bytes = bincode::serialize(&header)?;
132        let header_len = header_bytes.len() as u32;
133        writer.write_all(&header_len.to_le_bytes())?;
134        writer.write_all(&header_bytes)?;
135
136        // Serialize nodes
137        let nodes = self.serialize_nodes(index)?;
138
139        // Compress if needed
140        let data = match self.config.compression {
141            CompressionAlgorithm::None => nodes,
142            CompressionAlgorithm::Zstd { level } => zstd::encode_all(&nodes[..], level)?,
143            CompressionAlgorithm::ZstdMax => zstd::encode_all(&nodes[..], 21)?,
144        };
145
146        // Write data length and data
147        let data_len = data.len() as u64;
148        writer.write_all(&data_len.to_le_bytes())?;
149        writer.write_all(&data)?;
150
151        // Write URI mapping
152        let uri_mapping = bincode::serialize(&index.uri_to_id())?;
153        let mapping_len = uri_mapping.len() as u32;
154        writer.write_all(&mapping_len.to_le_bytes())?;
155        writer.write_all(&uri_mapping)?;
156
157        // Write entry point
158        let entry_point = bincode::serialize(&index.entry_point())?;
159        writer.write_all(&entry_point)?;
160
161        writer.flush()?;
162
163        tracing::info!(
164            "Successfully saved HNSW index with {} nodes (compression: {:?})",
165            index.len(),
166            self.config.compression
167        );
168
169        Ok(())
170    }
171
172    /// Load HNSW index from disk
173    pub fn load_index<P: AsRef<Path>>(&self, path: P) -> Result<HnswIndex> {
174        let path = path.as_ref();
175        tracing::info!("Loading HNSW index from {:?}", path);
176
177        let file = File::open(path)?;
178        let mut reader = BufReader::new(file);
179
180        // Verify magic number
181        let mut magic = [0u8; 4];
182        reader.read_exact(&mut magic)?;
183        if &magic != MAGIC_NUMBER {
184            return Err(anyhow!("Invalid index file format"));
185        }
186
187        // Read header
188        let mut header_len_bytes = [0u8; 4];
189        reader.read_exact(&mut header_len_bytes)?;
190        let header_len = u32::from_le_bytes(header_len_bytes) as usize;
191
192        let mut header_bytes = vec![0u8; header_len];
193        reader.read_exact(&mut header_bytes)?;
194        let header: IndexHeader = bincode::deserialize(&header_bytes)?;
195
196        // Verify version
197        if header.version != PERSISTENCE_VERSION {
198            return Err(anyhow!(
199                "Unsupported index version: {} (expected {})",
200                header.version,
201                PERSISTENCE_VERSION
202            ));
203        }
204
205        // Read data length
206        let mut data_len_bytes = [0u8; 8];
207        reader.read_exact(&mut data_len_bytes)?;
208        let data_len = u64::from_le_bytes(data_len_bytes) as usize;
209
210        // Read and decompress data
211        let mut compressed_data = vec![0u8; data_len];
212        reader.read_exact(&mut compressed_data)?;
213
214        let nodes_data = match header.compression {
215            CompressionAlgorithm::None => compressed_data,
216            CompressionAlgorithm::Zstd { .. } | CompressionAlgorithm::ZstdMax => {
217                zstd::decode_all(&compressed_data[..])?
218            }
219        };
220
221        // Read URI mapping
222        let mut mapping_len_bytes = [0u8; 4];
223        reader.read_exact(&mut mapping_len_bytes)?;
224        let mapping_len = u32::from_le_bytes(mapping_len_bytes) as usize;
225
226        let mut mapping_bytes = vec![0u8; mapping_len];
227        reader.read_exact(&mut mapping_bytes)?;
228        let uri_mapping: std::collections::HashMap<String, usize> =
229            bincode::deserialize(&mapping_bytes)?;
230
231        // Read entry point
232        let mut entry_point_bytes = Vec::new();
233        reader.read_to_end(&mut entry_point_bytes)?;
234        let entry_point: Option<usize> = bincode::deserialize(&entry_point_bytes)?;
235
236        // Reconstruct index
237        let mut index = HnswIndex::new(header.config)?;
238        self.deserialize_nodes(&nodes_data, &mut index)?;
239
240        // Restore URI mapping
241        *index.uri_to_id_mut() = uri_mapping;
242
243        // Restore entry point
244        index.set_entry_point(entry_point);
245
246        // Validate if requested
247        if self.config.validate_on_load {
248            self.validate_index(&index)?;
249        }
250
251        tracing::info!("Successfully loaded HNSW index with {} nodes", index.len());
252
253        Ok(index)
254    }
255
256    /// Serialize nodes to bytes
257    fn serialize_nodes(&self, index: &HnswIndex) -> Result<Vec<u8>> {
258        let serializable_nodes: Vec<SerializableNode> = index
259            .nodes()
260            .iter()
261            .map(|node| SerializableNode {
262                uri: node.uri.clone(),
263                vector_data: node.vector.as_f32(),
264                connections: node
265                    .connections
266                    .iter()
267                    .map(|set| set.iter().copied().collect())
268                    .collect(),
269                level: node.level(),
270            })
271            .collect();
272
273        Ok(bincode::serialize(&serializable_nodes)?)
274    }
275
276    /// Deserialize nodes from bytes
277    fn deserialize_nodes(&self, data: &[u8], index: &mut HnswIndex) -> Result<()> {
278        let serializable_nodes: Vec<SerializableNode> = bincode::deserialize(data)?;
279
280        for node_data in serializable_nodes {
281            let vector = Vector::new(node_data.vector_data);
282            let mut node = crate::hnsw::Node::new(node_data.uri, vector, node_data.level);
283
284            // Restore connections
285            for (level, connections) in node_data.connections.into_iter().enumerate() {
286                for conn_id in connections {
287                    node.add_connection(level, conn_id);
288                }
289            }
290
291            index.nodes_mut().push(node);
292        }
293
294        Ok(())
295    }
296
297    /// Validate index integrity
298    fn validate_index(&self, index: &HnswIndex) -> Result<()> {
299        tracing::debug!("Validating index integrity");
300
301        // Check that all connections are valid
302        for (node_id, node) in index.nodes().iter().enumerate() {
303            for level in 0..=node.level() {
304                if let Some(connections) = node.get_connections(level) {
305                    for &conn_id in connections {
306                        if conn_id >= index.len() {
307                            return Err(anyhow!(
308                                "Invalid connection: node {} has connection to non-existent node {}",
309                                node_id,
310                                conn_id
311                            ));
312                        }
313                    }
314                }
315            }
316        }
317
318        // Check URI mapping consistency
319        for (uri, &node_id) in index.uri_to_id() {
320            if node_id >= index.len() {
321                return Err(anyhow!(
322                    "Invalid URI mapping: {} points to non-existent node {}",
323                    uri,
324                    node_id
325                ));
326            }
327
328            let actual_uri = &index.nodes()[node_id].uri;
329            if uri != actual_uri {
330                return Err(anyhow!(
331                    "URI mapping mismatch: expected '{}', found '{}'",
332                    uri,
333                    actual_uri
334                ));
335            }
336        }
337
338        // Check entry point
339        if let Some(entry_id) = index.entry_point() {
340            if entry_id >= index.len() {
341                return Err(anyhow!(
342                    "Invalid entry point: {} (index has {} nodes)",
343                    entry_id,
344                    index.len()
345                ));
346            }
347        }
348
349        tracing::debug!("Index validation passed");
350        Ok(())
351    }
352
353    /// Create a snapshot of the index
354    pub fn create_snapshot<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
355        let path = path.as_ref();
356        let snapshot_path = path.with_extension(format!(
357            "snapshot.{}",
358            std::time::SystemTime::now()
359                .duration_since(std::time::UNIX_EPOCH)
360                .unwrap()
361                .as_secs()
362        ));
363
364        self.save_index(index, snapshot_path)?;
365        Ok(())
366    }
367
368    /// Estimate compressed size
369    pub fn estimate_compressed_size(&self, index: &HnswIndex) -> Result<usize> {
370        let nodes = self.serialize_nodes(index)?;
371
372        let compressed_size = match self.config.compression {
373            CompressionAlgorithm::None => nodes.len(),
374            CompressionAlgorithm::Zstd { level } => zstd::encode_all(&nodes[..], level)?.len(),
375            CompressionAlgorithm::ZstdMax => zstd::encode_all(&nodes[..], 21)?.len(),
376        };
377
378        Ok(compressed_size)
379    }
380}
381
382/// Incremental persistence manager
383pub struct IncrementalPersistence {
384    config: PersistenceConfig,
385    operation_count: usize,
386    last_checkpoint: std::time::Instant,
387}
388
389impl IncrementalPersistence {
390    pub fn new(config: PersistenceConfig) -> Self {
391        Self {
392            config,
393            operation_count: 0,
394            last_checkpoint: std::time::Instant::now(),
395        }
396    }
397
398    /// Record an operation
399    pub fn record_operation(&mut self) {
400        self.operation_count += 1;
401    }
402
403    /// Check if checkpoint is needed
404    pub fn needs_checkpoint(&self) -> bool {
405        self.operation_count >= self.config.checkpoint_interval
406    }
407
408    /// Create checkpoint
409    pub fn checkpoint<P: AsRef<Path>>(&mut self, index: &HnswIndex, base_path: P) -> Result<()> {
410        if !self.needs_checkpoint() {
411            return Ok(());
412        }
413
414        let manager = PersistenceManager::new(self.config.clone());
415        manager.create_snapshot(index, base_path)?;
416
417        self.operation_count = 0;
418        self.last_checkpoint = std::time::Instant::now();
419
420        Ok(())
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;
    use std::path::PathBuf;

    /// Temporary file that is deleted on drop, even when the test panics.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Per-process path so concurrent test runs cannot clobber each other's files.
        fn new(name: &str) -> Self {
            Self(temp_dir().join(format!("oxirs_vec_{}_{}", std::process::id(), name)))
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup: the file may legitimately not exist.
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Build an index holding `count` vectors named `vec_0`, `vec_1`, ... made by `make`.
    fn build_index(count: usize, make: impl Fn(usize) -> Vec<f32>) -> HnswIndex {
        let mut index = HnswIndex::new(HnswConfig::default()).unwrap();
        for i in 0..count {
            index
                .add_vector(format!("vec_{}", i), Vector::new(make(i)))
                .unwrap();
        }
        index
    }

    #[test]
    fn test_save_and_load_index() {
        let index = build_index(10, |i| vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
        let file = TempFile::new("test_hnsw_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());

        manager.save_index(&index, &file.0).unwrap();
        let loaded_index = manager.load_index(&file.0).unwrap();

        assert_eq!(loaded_index.len(), 10);
        assert_eq!(loaded_index.uri_to_id().len(), 10);
        assert_eq!(loaded_index.entry_point(), index.entry_point());
        for (uri, id) in index.uri_to_id() {
            assert_eq!(loaded_index.uri_to_id().get(uri), Some(id));
        }
    }

    #[test]
    fn test_compression() {
        let index = build_index(50, |i| vec![i as f32; 128]);

        let compressed_file = TempFile::new("test_compressed_index.bin");
        let compressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::Zstd { level: 3 },
            ..Default::default()
        });
        compressed_manager
            .save_index(&index, &compressed_file.0)
            .unwrap();
        let compressed_size = std::fs::metadata(&compressed_file.0).unwrap().len();

        let uncompressed_file = TempFile::new("test_uncompressed_index.bin");
        let uncompressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::None,
            ..Default::default()
        });
        uncompressed_manager
            .save_index(&index, &uncompressed_file.0)
            .unwrap();
        let uncompressed_size = std::fs::metadata(&uncompressed_file.0).unwrap().len();

        assert!(compressed_size < uncompressed_size);
    }

    #[test]
    fn test_rejects_invalid_magic() {
        let file = TempFile::new("test_bad_magic_index.bin");
        std::fs::write(&file.0, b"NOPE-not-an-index").unwrap();

        let manager = PersistenceManager::new(PersistenceConfig::default());
        assert!(manager.load_index(&file.0).is_err());
    }

    #[test]
    fn test_rejects_truncated_file() {
        let index = build_index(5, |i| vec![i as f32, 1.0, 2.0]);
        let file = TempFile::new("test_truncated_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());
        manager.save_index(&index, &file.0).unwrap();

        let bytes = std::fs::read(&file.0).unwrap();
        std::fs::write(&file.0, &bytes[..bytes.len() / 2]).unwrap();

        assert!(manager.load_index(&file.0).is_err());
    }

    #[test]
    fn test_validation() {
        let index = build_index(5, |i| vec![i as f32, 0.0, 0.0]);
        let manager = PersistenceManager::new(PersistenceConfig {
            validate_on_load: true,
            ..Default::default()
        });

        manager.validate_index(&index).unwrap();
    }
}