Skip to main content

oxirs_vec/persistence/
mod.rs

1//! Index persistence with compression
2//!
3//! This module provides serialization and deserialization of vector indices
4//! with support for compression, versioning, and incremental updates.
5
6use crate::hnsw::{HnswConfig, HnswIndex};
7use crate::Vector;
8use anyhow::{anyhow, Result};
9use oxicode::{Decode, Encode};
10use serde::{Deserialize, Serialize};
11use std::fs::{File, OpenOptions};
12use std::io::{BufReader, BufWriter, Read, Write};
13use std::path::Path;
14
/// Index persistence format version.
///
/// Stored in `IndexHeader::version` by `save_index`; `load_index` rejects any
/// file whose stored version differs from this value.
const PERSISTENCE_VERSION: u32 = 1;

/// Magic number for index files (OxVe = OxiRS Vector).
///
/// These four bytes are the first thing `save_index` writes and the first
/// thing `load_index` checks before reading anything else.
const MAGIC_NUMBER: &[u8; 4] = b"OxVe";
20
21/// Compression algorithm for index persistence
22#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Encode, Decode)]
23pub enum CompressionAlgorithm {
24    /// No compression (fastest, largest)
25    None,
26    /// Zstd compression (balanced)
27    Zstd { level: i32 },
28    /// High compression (slowest, smallest)
29    ZstdMax,
30}
31
32impl Default for CompressionAlgorithm {
33    fn default() -> Self {
34        Self::Zstd { level: 3 } // Fast compression by default
35    }
36}
37
38/// Persistence configuration
39#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
40pub struct PersistenceConfig {
41    /// Compression algorithm to use
42    pub compression: CompressionAlgorithm,
43    /// Include metadata in persistence
44    pub include_metadata: bool,
45    /// Validate data integrity on load
46    pub validate_on_load: bool,
47    /// Enable incremental persistence
48    pub incremental: bool,
49    /// Checkpoint interval for incremental persistence (in operations)
50    pub checkpoint_interval: usize,
51}
52
53impl Default for PersistenceConfig {
54    fn default() -> Self {
55        Self {
56            compression: CompressionAlgorithm::default(),
57            include_metadata: true,
58            validate_on_load: true,
59            incremental: false,
60            checkpoint_interval: 10000,
61        }
62    }
63}
64
/// Serializable index header
///
/// Encoded right after the magic number, preceded by its own length as a
/// little-endian `u32`.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
struct IndexHeader {
    /// Format version; `save_index` writes `PERSISTENCE_VERSION` here.
    version: u32,
    /// Compression applied to the node payload; `load_index` uses this (not
    /// the manager's own config) to decide how to decode the payload.
    compression: CompressionAlgorithm,
    /// `index.len()` at save time.
    node_count: usize,
    /// Dimensions of the first node's vector, or 0 for an empty index.
    dimension: usize,
    /// HNSW configuration used to rebuild the index on load.
    config: HnswConfig,
    /// Save time in seconds since the Unix epoch.
    timestamp: u64,
    /// NOTE(review): `save_index` always writes 0 here and `load_index` never
    /// reads it — no checksum is actually computed.
    checksum: u64,
}

/// Serializable node data
///
/// Flat, encodable mirror of an HNSW node as produced by `serialize_nodes`.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
struct SerializableNode {
    /// URI identifying the node.
    uri: String,
    /// Vector components as `f32` (from `node.vector.as_f32()`).
    vector_data: Vec<f32>,
    /// Neighbour ids, one inner list per level; the position in the outer
    /// list is the level the connections are restored at.
    connections: Vec<Vec<usize>>,
    /// Value of `node.level()` at save time.
    level: usize,
}
85
/// Persistence manager for HNSW indices
///
/// Holds the `PersistenceConfig` that controls compression on save and
/// validation on load.
pub struct PersistenceManager {
    /// Configuration applied to every save/load performed by this manager.
    config: PersistenceConfig,
}
90
91impl PersistenceManager {
92    /// Create a new persistence manager
93    pub fn new(config: PersistenceConfig) -> Self {
94        Self { config }
95    }
96
97    /// Save HNSW index to disk
98    pub fn save_index<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
99        let path = path.as_ref();
100        tracing::info!("Saving HNSW index to {:?}", path);
101
102        let file = OpenOptions::new()
103            .write(true)
104            .create(true)
105            .truncate(true)
106            .open(path)?;
107
108        let mut writer = BufWriter::new(file);
109
110        // Write magic number
111        writer.write_all(MAGIC_NUMBER)?;
112
113        // Create header
114        let header = IndexHeader {
115            version: PERSISTENCE_VERSION,
116            compression: self.config.compression,
117            node_count: index.len(),
118            dimension: if let Some(node) = index.nodes().first() {
119                node.vector.dimensions
120            } else {
121                0
122            },
123            config: index.config().clone(),
124            timestamp: std::time::SystemTime::now()
125                .duration_since(std::time::UNIX_EPOCH)
126                .expect("SystemTime should be after UNIX_EPOCH")
127                .as_secs(),
128            checksum: 0, // Will be calculated
129        };
130
131        // Serialize header
132        let header_bytes = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
133            .map_err(|e| anyhow!("Failed to serialize header: {}", e))?;
134        let header_len = header_bytes.len() as u32;
135        writer.write_all(&header_len.to_le_bytes())?;
136        writer.write_all(&header_bytes)?;
137
138        // Serialize nodes
139        let nodes = self.serialize_nodes(index)?;
140
141        // Compress if needed
142        let data = match self.config.compression {
143            CompressionAlgorithm::None => nodes,
144            CompressionAlgorithm::Zstd { level } => oxiarc_zstd::encode_all(&nodes, level)
145                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?,
146            CompressionAlgorithm::ZstdMax => oxiarc_zstd::encode_all(&nodes, 21)
147                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?,
148        };
149
150        // Write data length and data
151        let data_len = data.len() as u64;
152        writer.write_all(&data_len.to_le_bytes())?;
153        writer.write_all(&data)?;
154
155        // Write URI mapping
156        let uri_mapping =
157            oxicode::serde::encode_to_vec(index.uri_to_id(), oxicode::config::standard())
158                .map_err(|e| anyhow!("Failed to serialize URI mapping: {}", e))?;
159        let mapping_len = uri_mapping.len() as u32;
160        writer.write_all(&mapping_len.to_le_bytes())?;
161        writer.write_all(&uri_mapping)?;
162
163        // Write entry point
164        let entry_point =
165            oxicode::serde::encode_to_vec(&index.entry_point(), oxicode::config::standard())
166                .map_err(|e| anyhow!("Failed to serialize entry point: {}", e))?;
167        writer.write_all(&entry_point)?;
168
169        writer.flush()?;
170
171        tracing::info!(
172            "Successfully saved HNSW index with {} nodes (compression: {:?})",
173            index.len(),
174            self.config.compression
175        );
176
177        Ok(())
178    }
179
180    /// Load HNSW index from disk
181    pub fn load_index<P: AsRef<Path>>(&self, path: P) -> Result<HnswIndex> {
182        let path = path.as_ref();
183        tracing::info!("Loading HNSW index from {:?}", path);
184
185        let file = File::open(path)?;
186        let mut reader = BufReader::new(file);
187
188        // Verify magic number
189        let mut magic = [0u8; 4];
190        reader.read_exact(&mut magic)?;
191        if &magic != MAGIC_NUMBER {
192            return Err(anyhow!("Invalid index file format"));
193        }
194
195        // Read header
196        let mut header_len_bytes = [0u8; 4];
197        reader.read_exact(&mut header_len_bytes)?;
198        let header_len = u32::from_le_bytes(header_len_bytes) as usize;
199
200        let mut header_bytes = vec![0u8; header_len];
201        reader.read_exact(&mut header_bytes)?;
202        let (header, _): (IndexHeader, _) =
203            oxicode::serde::decode_from_slice(&header_bytes, oxicode::config::standard())
204                .map_err(|e| anyhow!("Failed to deserialize header: {}", e))?;
205
206        // Verify version
207        if header.version != PERSISTENCE_VERSION {
208            return Err(anyhow!(
209                "Unsupported index version: {} (expected {})",
210                header.version,
211                PERSISTENCE_VERSION
212            ));
213        }
214
215        // Read data length
216        let mut data_len_bytes = [0u8; 8];
217        reader.read_exact(&mut data_len_bytes)?;
218        let data_len = u64::from_le_bytes(data_len_bytes) as usize;
219
220        // Read and decompress data
221        let mut compressed_data = vec![0u8; data_len];
222        reader.read_exact(&mut compressed_data)?;
223
224        let nodes_data = match header.compression {
225            CompressionAlgorithm::None => compressed_data,
226            CompressionAlgorithm::Zstd { .. } | CompressionAlgorithm::ZstdMax => {
227                oxiarc_zstd::decode_all(&compressed_data)
228                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
229            }
230        };
231
232        // Read URI mapping
233        let mut mapping_len_bytes = [0u8; 4];
234        reader.read_exact(&mut mapping_len_bytes)?;
235        let mapping_len = u32::from_le_bytes(mapping_len_bytes) as usize;
236
237        let mut mapping_bytes = vec![0u8; mapping_len];
238        reader.read_exact(&mut mapping_bytes)?;
239        let (uri_mapping, _): (std::collections::HashMap<String, usize>, _) =
240            oxicode::serde::decode_from_slice(&mapping_bytes, oxicode::config::standard())
241                .map_err(|e| anyhow!("Failed to deserialize URI mapping: {}", e))?;
242
243        // Read entry point
244        let mut entry_point_bytes = Vec::new();
245        reader.read_to_end(&mut entry_point_bytes)?;
246        let (entry_point, _): (Option<usize>, _) =
247            oxicode::serde::decode_from_slice(&entry_point_bytes, oxicode::config::standard())
248                .map_err(|e| anyhow!("Failed to deserialize entry point: {}", e))?;
249
250        // Reconstruct index
251        let mut index = HnswIndex::new(header.config)?;
252        self.deserialize_nodes(&nodes_data, &mut index)?;
253
254        // Restore URI mapping
255        *index.uri_to_id_mut() = uri_mapping;
256
257        // Restore entry point
258        index.set_entry_point(entry_point);
259
260        // Validate if requested
261        if self.config.validate_on_load {
262            self.validate_index(&index)?;
263        }
264
265        tracing::info!("Successfully loaded HNSW index with {} nodes", index.len());
266
267        Ok(index)
268    }
269
270    /// Serialize nodes to bytes
271    fn serialize_nodes(&self, index: &HnswIndex) -> Result<Vec<u8>> {
272        let serializable_nodes: Vec<SerializableNode> = index
273            .nodes()
274            .iter()
275            .map(|node| SerializableNode {
276                uri: node.uri.clone(),
277                vector_data: node.vector.as_f32(),
278                connections: node
279                    .connections
280                    .iter()
281                    .map(|set| set.iter().copied().collect())
282                    .collect(),
283                level: node.level(),
284            })
285            .collect();
286
287        oxicode::serde::encode_to_vec(&serializable_nodes, oxicode::config::standard())
288            .map_err(|e| anyhow!("Failed to serialize nodes: {}", e))
289    }
290
291    /// Deserialize nodes from bytes
292    fn deserialize_nodes(&self, data: &[u8], index: &mut HnswIndex) -> Result<()> {
293        let (serializable_nodes, _): (Vec<SerializableNode>, _) =
294            oxicode::serde::decode_from_slice(data, oxicode::config::standard())
295                .map_err(|e| anyhow!("Failed to deserialize nodes: {}", e))?;
296
297        for node_data in serializable_nodes {
298            let vector = Vector::new(node_data.vector_data);
299            let mut node = crate::hnsw::Node::new(node_data.uri, vector, node_data.level);
300
301            // Restore connections
302            for (level, connections) in node_data.connections.into_iter().enumerate() {
303                for conn_id in connections {
304                    node.add_connection(level, conn_id);
305                }
306            }
307
308            index.nodes_mut().push(node);
309        }
310
311        Ok(())
312    }
313
314    /// Validate index integrity
315    fn validate_index(&self, index: &HnswIndex) -> Result<()> {
316        tracing::debug!("Validating index integrity");
317
318        // Check that all connections are valid
319        for (node_id, node) in index.nodes().iter().enumerate() {
320            for level in 0..=node.level() {
321                if let Some(connections) = node.get_connections(level) {
322                    for &conn_id in connections {
323                        if conn_id >= index.len() {
324                            return Err(anyhow!(
325                                "Invalid connection: node {} has connection to non-existent node {}",
326                                node_id,
327                                conn_id
328                            ));
329                        }
330                    }
331                }
332            }
333        }
334
335        // Check URI mapping consistency
336        for (uri, &node_id) in index.uri_to_id() {
337            if node_id >= index.len() {
338                return Err(anyhow!(
339                    "Invalid URI mapping: {} points to non-existent node {}",
340                    uri,
341                    node_id
342                ));
343            }
344
345            let actual_uri = &index.nodes()[node_id].uri;
346            if uri != actual_uri {
347                return Err(anyhow!(
348                    "URI mapping mismatch: expected '{}', found '{}'",
349                    uri,
350                    actual_uri
351                ));
352            }
353        }
354
355        // Check entry point
356        if let Some(entry_id) = index.entry_point() {
357            if entry_id >= index.len() {
358                return Err(anyhow!(
359                    "Invalid entry point: {} (index has {} nodes)",
360                    entry_id,
361                    index.len()
362                ));
363            }
364        }
365
366        tracing::debug!("Index validation passed");
367        Ok(())
368    }
369
370    /// Create a snapshot of the index
371    pub fn create_snapshot<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
372        let path = path.as_ref();
373        let snapshot_path = path.with_extension(format!(
374            "snapshot.{}",
375            std::time::SystemTime::now()
376                .duration_since(std::time::UNIX_EPOCH)
377                .expect("SystemTime should be after UNIX_EPOCH")
378                .as_secs()
379        ));
380
381        self.save_index(index, snapshot_path)?;
382        Ok(())
383    }
384
385    /// Estimate compressed size
386    pub fn estimate_compressed_size(&self, index: &HnswIndex) -> Result<usize> {
387        let nodes = self.serialize_nodes(index)?;
388
389        let compressed_size = match self.config.compression {
390            CompressionAlgorithm::None => nodes.len(),
391            CompressionAlgorithm::Zstd { level } => oxiarc_zstd::encode_all(&nodes, level)
392                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
393                .len(),
394            CompressionAlgorithm::ZstdMax => oxiarc_zstd::encode_all(&nodes, 21)
395                .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
396                .len(),
397        };
398
399        Ok(compressed_size)
400    }
401}
402
403/// Incremental persistence manager
404pub struct IncrementalPersistence {
405    config: PersistenceConfig,
406    operation_count: usize,
407    last_checkpoint: std::time::Instant,
408}
409
410impl IncrementalPersistence {
411    pub fn new(config: PersistenceConfig) -> Self {
412        Self {
413            config,
414            operation_count: 0,
415            last_checkpoint: std::time::Instant::now(),
416        }
417    }
418
419    /// Record an operation
420    pub fn record_operation(&mut self) {
421        self.operation_count += 1;
422    }
423
424    /// Check if checkpoint is needed
425    pub fn needs_checkpoint(&self) -> bool {
426        self.operation_count >= self.config.checkpoint_interval
427    }
428
429    /// Create checkpoint
430    pub fn checkpoint<P: AsRef<Path>>(&mut self, index: &HnswIndex, base_path: P) -> Result<()> {
431        if !self.needs_checkpoint() {
432            return Ok(());
433        }
434
435        let manager = PersistenceManager::new(self.config.clone());
436        manager.create_snapshot(index, base_path)?;
437
438        self.operation_count = 0;
439        self.last_checkpoint = std::time::Instant::now();
440
441        Ok(())
442    }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;
    use crate::Vector;
    use anyhow::Result;
    use std::env::temp_dir;
    use std::path::PathBuf;

    /// Temp-dir file that is removed when dropped.
    ///
    /// The process id is part of the file name so concurrent test runs on the
    /// same machine do not overwrite each other's files, and the `Drop` impl
    /// cleans up even when a test returns early through `?`.
    struct TempIndexFile(PathBuf);

    impl TempIndexFile {
        fn new(name: &str) -> Self {
            let mut path = temp_dir();
            path.push(format!("{}_{}.bin", name, std::process::id()));
            Self(path)
        }
    }

    impl Drop for TempIndexFile {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may not exist if the test failed
            // before writing it.
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Saving and loading with default settings preserves node count, URI
    /// mapping size and entry point.
    #[test]
    fn test_save_and_load_index() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(config)?;

        // Add some vectors
        for i in 0..10 {
            let vec = Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
            index.add_vector(format!("vec_{}", i), vec)?;
        }

        let temp = TempIndexFile::new("test_hnsw_index");

        let persistence_config = PersistenceConfig::default();
        let manager = PersistenceManager::new(persistence_config);

        // Save index
        manager.save_index(&index, &temp.0)?;

        // Load index
        let loaded_index = manager.load_index(&temp.0)?;

        assert_eq!(loaded_index.len(), 10);
        assert_eq!(loaded_index.uri_to_id().len(), 10);
        assert_eq!(loaded_index.entry_point(), index.entry_point());

        Ok(())
    }

    /// A Zstd-compressed file is smaller than the uncompressed file for the
    /// same index.
    #[test]
    fn test_compression() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(config)?;

        // Add vectors
        for i in 0..50 {
            let vec = Vector::new(vec![i as f32; 128]);
            index.add_vector(format!("vec_{}", i), vec)?;
        }

        // Test with compression
        let compressed_file = TempIndexFile::new("test_compressed_index");
        let compressed_config = PersistenceConfig {
            compression: CompressionAlgorithm::Zstd { level: 3 },
            ..Default::default()
        };
        let compressed_manager = PersistenceManager::new(compressed_config);
        compressed_manager.save_index(&index, &compressed_file.0)?;

        let compressed_size = std::fs::metadata(&compressed_file.0)?.len();

        // Test without compression
        let uncompressed_file = TempIndexFile::new("test_uncompressed_index");
        let uncompressed_config = PersistenceConfig {
            compression: CompressionAlgorithm::None,
            ..Default::default()
        };
        let uncompressed_manager = PersistenceManager::new(uncompressed_config);
        uncompressed_manager.save_index(&index, &uncompressed_file.0)?;

        let uncompressed_size = std::fs::metadata(&uncompressed_file.0)?.len();

        // Compressed should be smaller
        assert!(compressed_size < uncompressed_size);

        Ok(())
    }

    /// `validate_index` accepts a freshly built in-memory index.
    #[test]
    fn test_validation() -> Result<()> {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new(config)?;

        for i in 0..5 {
            let vec = Vector::new(vec![i as f32, 0.0, 0.0]);
            index.add_vector(format!("vec_{}", i), vec)?;
        }

        let persistence_config = PersistenceConfig {
            validate_on_load: true,
            ..Default::default()
        };
        let manager = PersistenceManager::new(persistence_config);

        // Validation should pass
        manager.validate_index(&index)?;
        Ok(())
    }
}
551
552// Sub-module: dependency-free binary snapshot
553pub mod snapshot;
554pub use snapshot::{IndexSnapshot, SnapshotHeader};