1use crate::hnsw::{HnswConfig, HnswIndex};
7use crate::Vector;
8use anyhow::{anyhow, Result};
9use oxicode::{Decode, Encode};
10use serde::{Deserialize, Serialize};
11use std::fs::{File, OpenOptions};
12use std::io::{BufReader, BufWriter, Read, Write};
13use std::path::Path;
14
/// On-disk format version stored in every index header. Loading rejects a
/// file whose header carries any other value.
const PERSISTENCE_VERSION: u32 = 1;

/// Four-byte signature written at the very start of an index file and
/// checked before anything else when a file is loaded.
const MAGIC_NUMBER: &[u8; 4] = b"OxVe";
20
/// Compression applied to the serialized node payload of a persisted index.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Encode, Decode)]
pub enum CompressionAlgorithm {
    /// Store the node payload uncompressed.
    None,
    /// Compress the node payload with zstd at the given level.
    Zstd { level: i32 },
    /// Compress the node payload with zstd at level 21 (the level the save
    /// path passes for this variant).
    ZstdMax,
}
31
32impl Default for CompressionAlgorithm {
33 fn default() -> Self {
34 Self::Zstd { level: 3 } }
36}
37
/// Options controlling how an index is written to and read back from disk.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct PersistenceConfig {
    /// Compression applied to the node payload when saving.
    pub compression: CompressionAlgorithm,
    /// NOTE(review): not read by the code visible here; presumably toggles
    /// persisting additional per-node metadata — confirm against other users.
    pub include_metadata: bool,
    /// When `true`, loading finishes with an integrity validation pass over
    /// connections, the URI mapping and the entry point.
    pub validate_on_load: bool,
    /// NOTE(review): not read by the code visible here; presumably enables
    /// incremental persistence — confirm against other users.
    pub incremental: bool,
    /// Number of recorded operations after which `IncrementalPersistence`
    /// reports that a checkpoint is due.
    pub checkpoint_interval: usize,
}
52
53impl Default for PersistenceConfig {
54 fn default() -> Self {
55 Self {
56 compression: CompressionAlgorithm::default(),
57 include_metadata: true,
58 validate_on_load: true,
59 incremental: false,
60 checkpoint_interval: 10000,
61 }
62 }
63}
64
/// Metadata record written right after the magic number of an index file.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
struct IndexHeader {
    /// Format version; compared against `PERSISTENCE_VERSION` on load.
    version: u32,
    /// Compression used for the node payload that follows the header.
    compression: CompressionAlgorithm,
    /// Value of `index.len()` at the time the file was saved.
    node_count: usize,
    /// Dimensionality of the first node's vector, or 0 for an empty index.
    dimension: usize,
    /// HNSW configuration used to rebuild the index on load.
    config: HnswConfig,
    /// Save time in seconds since the UNIX epoch.
    timestamp: u64,
    /// NOTE(review): the save path always writes 0 here and the load path
    /// never checks it — presumably reserved for a future payload checksum.
    checksum: u64,
}
76
/// Flat, encodable form of a single HNSW node.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
struct SerializableNode {
    /// URI identifying the node.
    uri: String,
    /// The node's vector converted to `f32` components.
    vector_data: Vec<f32>,
    /// Neighbour ids per level: the outer index is the level, the inner
    /// vector holds the ids the node is connected to on that level.
    connections: Vec<Vec<usize>>,
    /// The node's level as reported by `Node::level()`.
    level: usize,
}
85
/// Saves HNSW indexes to disk and loads them back, following a
/// [`PersistenceConfig`].
pub struct PersistenceManager {
    /// Settings applied to every save and load performed by this manager.
    config: PersistenceConfig,
}
90
91impl PersistenceManager {
92 pub fn new(config: PersistenceConfig) -> Self {
94 Self { config }
95 }
96
97 pub fn save_index<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
99 let path = path.as_ref();
100 tracing::info!("Saving HNSW index to {:?}", path);
101
102 let file = OpenOptions::new()
103 .write(true)
104 .create(true)
105 .truncate(true)
106 .open(path)?;
107
108 let mut writer = BufWriter::new(file);
109
110 writer.write_all(MAGIC_NUMBER)?;
112
113 let header = IndexHeader {
115 version: PERSISTENCE_VERSION,
116 compression: self.config.compression,
117 node_count: index.len(),
118 dimension: if let Some(node) = index.nodes().first() {
119 node.vector.dimensions
120 } else {
121 0
122 },
123 config: index.config().clone(),
124 timestamp: std::time::SystemTime::now()
125 .duration_since(std::time::UNIX_EPOCH)
126 .expect("SystemTime should be after UNIX_EPOCH")
127 .as_secs(),
128 checksum: 0, };
130
131 let header_bytes = oxicode::serde::encode_to_vec(&header, oxicode::config::standard())
133 .map_err(|e| anyhow!("Failed to serialize header: {}", e))?;
134 let header_len = header_bytes.len() as u32;
135 writer.write_all(&header_len.to_le_bytes())?;
136 writer.write_all(&header_bytes)?;
137
138 let nodes = self.serialize_nodes(index)?;
140
141 let data = match self.config.compression {
143 CompressionAlgorithm::None => nodes,
144 CompressionAlgorithm::Zstd { level } => oxiarc_zstd::encode_all(&nodes, level)
145 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?,
146 CompressionAlgorithm::ZstdMax => oxiarc_zstd::encode_all(&nodes, 21)
147 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?,
148 };
149
150 let data_len = data.len() as u64;
152 writer.write_all(&data_len.to_le_bytes())?;
153 writer.write_all(&data)?;
154
155 let uri_mapping =
157 oxicode::serde::encode_to_vec(index.uri_to_id(), oxicode::config::standard())
158 .map_err(|e| anyhow!("Failed to serialize URI mapping: {}", e))?;
159 let mapping_len = uri_mapping.len() as u32;
160 writer.write_all(&mapping_len.to_le_bytes())?;
161 writer.write_all(&uri_mapping)?;
162
163 let entry_point =
165 oxicode::serde::encode_to_vec(&index.entry_point(), oxicode::config::standard())
166 .map_err(|e| anyhow!("Failed to serialize entry point: {}", e))?;
167 writer.write_all(&entry_point)?;
168
169 writer.flush()?;
170
171 tracing::info!(
172 "Successfully saved HNSW index with {} nodes (compression: {:?})",
173 index.len(),
174 self.config.compression
175 );
176
177 Ok(())
178 }
179
180 pub fn load_index<P: AsRef<Path>>(&self, path: P) -> Result<HnswIndex> {
182 let path = path.as_ref();
183 tracing::info!("Loading HNSW index from {:?}", path);
184
185 let file = File::open(path)?;
186 let mut reader = BufReader::new(file);
187
188 let mut magic = [0u8; 4];
190 reader.read_exact(&mut magic)?;
191 if &magic != MAGIC_NUMBER {
192 return Err(anyhow!("Invalid index file format"));
193 }
194
195 let mut header_len_bytes = [0u8; 4];
197 reader.read_exact(&mut header_len_bytes)?;
198 let header_len = u32::from_le_bytes(header_len_bytes) as usize;
199
200 let mut header_bytes = vec![0u8; header_len];
201 reader.read_exact(&mut header_bytes)?;
202 let (header, _): (IndexHeader, _) =
203 oxicode::serde::decode_from_slice(&header_bytes, oxicode::config::standard())
204 .map_err(|e| anyhow!("Failed to deserialize header: {}", e))?;
205
206 if header.version != PERSISTENCE_VERSION {
208 return Err(anyhow!(
209 "Unsupported index version: {} (expected {})",
210 header.version,
211 PERSISTENCE_VERSION
212 ));
213 }
214
215 let mut data_len_bytes = [0u8; 8];
217 reader.read_exact(&mut data_len_bytes)?;
218 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
219
220 let mut compressed_data = vec![0u8; data_len];
222 reader.read_exact(&mut compressed_data)?;
223
224 let nodes_data = match header.compression {
225 CompressionAlgorithm::None => compressed_data,
226 CompressionAlgorithm::Zstd { .. } | CompressionAlgorithm::ZstdMax => {
227 oxiarc_zstd::decode_all(&compressed_data)
228 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
229 }
230 };
231
232 let mut mapping_len_bytes = [0u8; 4];
234 reader.read_exact(&mut mapping_len_bytes)?;
235 let mapping_len = u32::from_le_bytes(mapping_len_bytes) as usize;
236
237 let mut mapping_bytes = vec![0u8; mapping_len];
238 reader.read_exact(&mut mapping_bytes)?;
239 let (uri_mapping, _): (std::collections::HashMap<String, usize>, _) =
240 oxicode::serde::decode_from_slice(&mapping_bytes, oxicode::config::standard())
241 .map_err(|e| anyhow!("Failed to deserialize URI mapping: {}", e))?;
242
243 let mut entry_point_bytes = Vec::new();
245 reader.read_to_end(&mut entry_point_bytes)?;
246 let (entry_point, _): (Option<usize>, _) =
247 oxicode::serde::decode_from_slice(&entry_point_bytes, oxicode::config::standard())
248 .map_err(|e| anyhow!("Failed to deserialize entry point: {}", e))?;
249
250 let mut index = HnswIndex::new(header.config)?;
252 self.deserialize_nodes(&nodes_data, &mut index)?;
253
254 *index.uri_to_id_mut() = uri_mapping;
256
257 index.set_entry_point(entry_point);
259
260 if self.config.validate_on_load {
262 self.validate_index(&index)?;
263 }
264
265 tracing::info!("Successfully loaded HNSW index with {} nodes", index.len());
266
267 Ok(index)
268 }
269
270 fn serialize_nodes(&self, index: &HnswIndex) -> Result<Vec<u8>> {
272 let serializable_nodes: Vec<SerializableNode> = index
273 .nodes()
274 .iter()
275 .map(|node| SerializableNode {
276 uri: node.uri.clone(),
277 vector_data: node.vector.as_f32(),
278 connections: node
279 .connections
280 .iter()
281 .map(|set| set.iter().copied().collect())
282 .collect(),
283 level: node.level(),
284 })
285 .collect();
286
287 oxicode::serde::encode_to_vec(&serializable_nodes, oxicode::config::standard())
288 .map_err(|e| anyhow!("Failed to serialize nodes: {}", e))
289 }
290
291 fn deserialize_nodes(&self, data: &[u8], index: &mut HnswIndex) -> Result<()> {
293 let (serializable_nodes, _): (Vec<SerializableNode>, _) =
294 oxicode::serde::decode_from_slice(data, oxicode::config::standard())
295 .map_err(|e| anyhow!("Failed to deserialize nodes: {}", e))?;
296
297 for node_data in serializable_nodes {
298 let vector = Vector::new(node_data.vector_data);
299 let mut node = crate::hnsw::Node::new(node_data.uri, vector, node_data.level);
300
301 for (level, connections) in node_data.connections.into_iter().enumerate() {
303 for conn_id in connections {
304 node.add_connection(level, conn_id);
305 }
306 }
307
308 index.nodes_mut().push(node);
309 }
310
311 Ok(())
312 }
313
314 fn validate_index(&self, index: &HnswIndex) -> Result<()> {
316 tracing::debug!("Validating index integrity");
317
318 for (node_id, node) in index.nodes().iter().enumerate() {
320 for level in 0..=node.level() {
321 if let Some(connections) = node.get_connections(level) {
322 for &conn_id in connections {
323 if conn_id >= index.len() {
324 return Err(anyhow!(
325 "Invalid connection: node {} has connection to non-existent node {}",
326 node_id,
327 conn_id
328 ));
329 }
330 }
331 }
332 }
333 }
334
335 for (uri, &node_id) in index.uri_to_id() {
337 if node_id >= index.len() {
338 return Err(anyhow!(
339 "Invalid URI mapping: {} points to non-existent node {}",
340 uri,
341 node_id
342 ));
343 }
344
345 let actual_uri = &index.nodes()[node_id].uri;
346 if uri != actual_uri {
347 return Err(anyhow!(
348 "URI mapping mismatch: expected '{}', found '{}'",
349 uri,
350 actual_uri
351 ));
352 }
353 }
354
355 if let Some(entry_id) = index.entry_point() {
357 if entry_id >= index.len() {
358 return Err(anyhow!(
359 "Invalid entry point: {} (index has {} nodes)",
360 entry_id,
361 index.len()
362 ));
363 }
364 }
365
366 tracing::debug!("Index validation passed");
367 Ok(())
368 }
369
370 pub fn create_snapshot<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
372 let path = path.as_ref();
373 let snapshot_path = path.with_extension(format!(
374 "snapshot.{}",
375 std::time::SystemTime::now()
376 .duration_since(std::time::UNIX_EPOCH)
377 .expect("SystemTime should be after UNIX_EPOCH")
378 .as_secs()
379 ));
380
381 self.save_index(index, snapshot_path)?;
382 Ok(())
383 }
384
385 pub fn estimate_compressed_size(&self, index: &HnswIndex) -> Result<usize> {
387 let nodes = self.serialize_nodes(index)?;
388
389 let compressed_size = match self.config.compression {
390 CompressionAlgorithm::None => nodes.len(),
391 CompressionAlgorithm::Zstd { level } => oxiarc_zstd::encode_all(&nodes, level)
392 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
393 .len(),
394 CompressionAlgorithm::ZstdMax => oxiarc_zstd::encode_all(&nodes, 21)
395 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
396 .len(),
397 };
398
399 Ok(compressed_size)
400 }
401}
402
/// Counts operations and writes a snapshot once enough have accumulated.
pub struct IncrementalPersistence {
    /// Persistence settings; `checkpoint_interval` is the operation threshold.
    config: PersistenceConfig,
    /// Operations recorded since the last successful checkpoint.
    operation_count: usize,
    /// Time of construction or of the last successful checkpoint. It is
    /// written but not read in the code visible here.
    last_checkpoint: std::time::Instant,
}
409
410impl IncrementalPersistence {
411 pub fn new(config: PersistenceConfig) -> Self {
412 Self {
413 config,
414 operation_count: 0,
415 last_checkpoint: std::time::Instant::now(),
416 }
417 }
418
419 pub fn record_operation(&mut self) {
421 self.operation_count += 1;
422 }
423
424 pub fn needs_checkpoint(&self) -> bool {
426 self.operation_count >= self.config.checkpoint_interval
427 }
428
429 pub fn checkpoint<P: AsRef<Path>>(&mut self, index: &HnswIndex, base_path: P) -> Result<()> {
431 if !self.needs_checkpoint() {
432 return Ok(());
433 }
434
435 let manager = PersistenceManager::new(self.config.clone());
436 manager.create_snapshot(index, base_path)?;
437
438 self.operation_count = 0;
439 self.last_checkpoint = std::time::Instant::now();
440
441 Ok(())
442 }
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;
    use crate::Vector;
    use std::env::temp_dir;

    /// Builds an index with the default configuration containing one vector
    /// per entry of `vectors`, named `vec_0`, `vec_1`, ... in order.
    fn index_with(vectors: Vec<Vec<f32>>) -> HnswIndex {
        let mut index = HnswIndex::new(HnswConfig::default()).unwrap();
        for (i, values) in vectors.into_iter().enumerate() {
            index
                .add_vector(format!("vec_{}", i), Vector::new(values))
                .unwrap();
        }
        index
    }

    /// A saved index can be loaded back with the same node and URI counts.
    #[test]
    fn test_save_and_load_index() {
        let index = index_with(
            (0..10)
                .map(|i| vec![i as f32, (i * 2) as f32, (i * 3) as f32])
                .collect(),
        );

        let temp_path = temp_dir().join("test_hnsw_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());

        manager.save_index(&index, &temp_path).unwrap();
        let loaded_index = manager.load_index(&temp_path).unwrap();

        assert_eq!(loaded_index.len(), 10);
        assert_eq!(loaded_index.uri_to_id().len(), 10);

        // Best-effort cleanup.
        std::fs::remove_file(temp_path).ok();
    }

    /// A zstd-compressed file is smaller than the uncompressed one.
    #[test]
    fn test_compression() {
        let index = index_with((0..50).map(|i| vec![i as f32; 128]).collect());

        let temp_path = temp_dir().join("test_compressed_index.bin");
        let compressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::Zstd { level: 3 },
            ..Default::default()
        });
        compressed_manager.save_index(&index, &temp_path).unwrap();
        let compressed_size = std::fs::metadata(&temp_path).unwrap().len();

        let temp_path2 = temp_dir().join("test_uncompressed_index.bin");
        let uncompressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::None,
            ..Default::default()
        });
        uncompressed_manager
            .save_index(&index, &temp_path2)
            .unwrap();
        let uncompressed_size = std::fs::metadata(&temp_path2).unwrap().len();

        assert!(compressed_size < uncompressed_size);

        // Best-effort cleanup.
        std::fs::remove_file(temp_path).ok();
        std::fs::remove_file(temp_path2).ok();
    }

    /// A freshly built index passes validation.
    #[test]
    fn test_validation() {
        let index = index_with((0..5).map(|i| vec![i as f32, 0.0, 0.0]).collect());

        let manager = PersistenceManager::new(PersistenceConfig {
            validate_on_load: true,
            ..Default::default()
        });

        manager.validate_index(&index).unwrap();
    }
}
549
550pub mod snapshot;
552pub use snapshot::{IndexSnapshot, SnapshotHeader};