use crate::hnsw::{HnswConfig, HnswIndex};
use crate::Vector;
use anyhow::{anyhow, Context, Result};
use bincode::{Decode, Encode};
use serde::{Deserialize, Serialize};
use std::fs::{File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
14
/// On-disk format version written into every index header; loading rejects any other value.
const PERSISTENCE_VERSION: u32 = 1;

/// Four-byte signature ("OxVe") that every index file starts with.
const MAGIC_NUMBER: &[u8; 4] = &[b'O', b'x', b'V', b'e'];
20
/// Compression applied to the serialized node payload of an index file.
///
/// NOTE(review): this value is bincode-encoded into every file header, so the variant
/// order is presumably part of the on-disk format — append new variants, never reorder.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize, Encode, Decode)]
pub enum CompressionAlgorithm {
    /// Store the node payload uncompressed.
    None,
    /// zstd at the given compression level.
    Zstd { level: i32 },
    /// zstd at level 21, as applied by `PersistenceManager`.
    ZstdMax,
}
31
32impl Default for CompressionAlgorithm {
33 fn default() -> Self {
34 Self::Zstd { level: 3 } }
36}
37
/// Options controlling how HNSW indices are saved, loaded, and checkpointed.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct PersistenceConfig {
    /// Compression applied to the node payload when saving.
    pub compression: CompressionAlgorithm,
    /// NOTE(review): not consulted by the save/load code in this file — confirm intended use.
    pub include_metadata: bool,
    /// Check the loaded index's integrity (via `PersistenceManager::validate_index`) after loading.
    pub validate_on_load: bool,
    /// NOTE(review): not consulted by the code in this file — confirm intended use.
    pub incremental: bool,
    /// Number of recorded operations after which `IncrementalPersistence` takes a snapshot.
    pub checkpoint_interval: usize,
}
52
53impl Default for PersistenceConfig {
54 fn default() -> Self {
55 Self {
56 compression: CompressionAlgorithm::default(),
57 include_metadata: true,
58 validate_on_load: true,
59 incremental: false,
60 checkpoint_interval: 10000,
61 }
62 }
63}
64
/// Metadata block written after the magic number (behind a `u32` length prefix) in every index file.
///
/// NOTE(review): bincode encodes fields in declaration order, so reordering or inserting
/// fields changes the on-disk format and should come with a `PERSISTENCE_VERSION` bump.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
struct IndexHeader {
    /// Format version; loading rejects files whose version differs from `PERSISTENCE_VERSION`.
    version: u32,
    /// Compression applied to the node payload that follows the header.
    compression: CompressionAlgorithm,
    /// Value of `index.len()` at save time.
    node_count: usize,
    /// Vector dimensionality of the first node, or 0 for an empty index.
    dimension: usize,
    /// HNSW configuration used to rebuild the index on load.
    config: HnswConfig,
    /// Save time in whole seconds since the Unix epoch.
    timestamp: u64,
    /// Checksum slot for the node payload; a value of `0` means no checksum was recorded.
    checksum: u64,
}
76
/// On-disk form of a single HNSW graph node.
///
/// NOTE(review): field order is part of the bincode wire format — do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
struct SerializableNode {
    /// URI under which the node's vector was inserted.
    uri: String,
    /// Vector components, as produced by `Vector::as_f32`.
    vector_data: Vec<f32>,
    /// Neighbor node ids per layer: `connections[l]` holds the neighbors at layer `l`.
    connections: Vec<Vec<usize>>,
    /// Highest layer this node participates in.
    level: usize,
}
85
/// Saves and loads `HnswIndex` instances to and from single binary files.
pub struct PersistenceManager {
    /// Options applied to every save and load performed by this manager.
    config: PersistenceConfig,
}
90
91impl PersistenceManager {
92 pub fn new(config: PersistenceConfig) -> Self {
94 Self { config }
95 }
96
97 pub fn save_index<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
99 let path = path.as_ref();
100 tracing::info!("Saving HNSW index to {:?}", path);
101
102 let file = OpenOptions::new()
103 .write(true)
104 .create(true)
105 .truncate(true)
106 .open(path)?;
107
108 let mut writer = BufWriter::new(file);
109
110 writer.write_all(MAGIC_NUMBER)?;
112
113 let header = IndexHeader {
115 version: PERSISTENCE_VERSION,
116 compression: self.config.compression,
117 node_count: index.len(),
118 dimension: if let Some(node) = index.nodes().first() {
119 node.vector.dimensions
120 } else {
121 0
122 },
123 config: index.config().clone(),
124 timestamp: std::time::SystemTime::now()
125 .duration_since(std::time::UNIX_EPOCH)
126 .unwrap()
127 .as_secs(),
128 checksum: 0, };
130
131 let header_bytes = bincode::encode_to_vec(&header, bincode::config::standard())
133 .map_err(|e| anyhow!("Failed to serialize header: {}", e))?;
134 let header_len = header_bytes.len() as u32;
135 writer.write_all(&header_len.to_le_bytes())?;
136 writer.write_all(&header_bytes)?;
137
138 let nodes = self.serialize_nodes(index)?;
140
141 let data = match self.config.compression {
143 CompressionAlgorithm::None => nodes,
144 CompressionAlgorithm::Zstd { level } => zstd::encode_all(&nodes[..], level)?,
145 CompressionAlgorithm::ZstdMax => zstd::encode_all(&nodes[..], 21)?,
146 };
147
148 let data_len = data.len() as u64;
150 writer.write_all(&data_len.to_le_bytes())?;
151 writer.write_all(&data)?;
152
153 let uri_mapping = bincode::encode_to_vec(&index.uri_to_id(), bincode::config::standard())
155 .map_err(|e| anyhow!("Failed to serialize URI mapping: {}", e))?;
156 let mapping_len = uri_mapping.len() as u32;
157 writer.write_all(&mapping_len.to_le_bytes())?;
158 writer.write_all(&uri_mapping)?;
159
160 let entry_point = bincode::encode_to_vec(&index.entry_point(), bincode::config::standard())
162 .map_err(|e| anyhow!("Failed to serialize entry point: {}", e))?;
163 writer.write_all(&entry_point)?;
164
165 writer.flush()?;
166
167 tracing::info!(
168 "Successfully saved HNSW index with {} nodes (compression: {:?})",
169 index.len(),
170 self.config.compression
171 );
172
173 Ok(())
174 }
175
176 pub fn load_index<P: AsRef<Path>>(&self, path: P) -> Result<HnswIndex> {
178 let path = path.as_ref();
179 tracing::info!("Loading HNSW index from {:?}", path);
180
181 let file = File::open(path)?;
182 let mut reader = BufReader::new(file);
183
184 let mut magic = [0u8; 4];
186 reader.read_exact(&mut magic)?;
187 if &magic != MAGIC_NUMBER {
188 return Err(anyhow!("Invalid index file format"));
189 }
190
191 let mut header_len_bytes = [0u8; 4];
193 reader.read_exact(&mut header_len_bytes)?;
194 let header_len = u32::from_le_bytes(header_len_bytes) as usize;
195
196 let mut header_bytes = vec![0u8; header_len];
197 reader.read_exact(&mut header_bytes)?;
198 let (header, _): (IndexHeader, _) =
199 bincode::decode_from_slice(&header_bytes, bincode::config::standard())
200 .map_err(|e| anyhow!("Failed to deserialize header: {}", e))?;
201
202 if header.version != PERSISTENCE_VERSION {
204 return Err(anyhow!(
205 "Unsupported index version: {} (expected {})",
206 header.version,
207 PERSISTENCE_VERSION
208 ));
209 }
210
211 let mut data_len_bytes = [0u8; 8];
213 reader.read_exact(&mut data_len_bytes)?;
214 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
215
216 let mut compressed_data = vec![0u8; data_len];
218 reader.read_exact(&mut compressed_data)?;
219
220 let nodes_data = match header.compression {
221 CompressionAlgorithm::None => compressed_data,
222 CompressionAlgorithm::Zstd { .. } | CompressionAlgorithm::ZstdMax => {
223 zstd::decode_all(&compressed_data[..])?
224 }
225 };
226
227 let mut mapping_len_bytes = [0u8; 4];
229 reader.read_exact(&mut mapping_len_bytes)?;
230 let mapping_len = u32::from_le_bytes(mapping_len_bytes) as usize;
231
232 let mut mapping_bytes = vec![0u8; mapping_len];
233 reader.read_exact(&mut mapping_bytes)?;
234 let (uri_mapping, _): (std::collections::HashMap<String, usize>, _) =
235 bincode::decode_from_slice(&mapping_bytes, bincode::config::standard())
236 .map_err(|e| anyhow!("Failed to deserialize URI mapping: {}", e))?;
237
238 let mut entry_point_bytes = Vec::new();
240 reader.read_to_end(&mut entry_point_bytes)?;
241 let (entry_point, _): (Option<usize>, _) =
242 bincode::decode_from_slice(&entry_point_bytes, bincode::config::standard())
243 .map_err(|e| anyhow!("Failed to deserialize entry point: {}", e))?;
244
245 let mut index = HnswIndex::new(header.config)?;
247 self.deserialize_nodes(&nodes_data, &mut index)?;
248
249 *index.uri_to_id_mut() = uri_mapping;
251
252 index.set_entry_point(entry_point);
254
255 if self.config.validate_on_load {
257 self.validate_index(&index)?;
258 }
259
260 tracing::info!("Successfully loaded HNSW index with {} nodes", index.len());
261
262 Ok(index)
263 }
264
265 fn serialize_nodes(&self, index: &HnswIndex) -> Result<Vec<u8>> {
267 let serializable_nodes: Vec<SerializableNode> = index
268 .nodes()
269 .iter()
270 .map(|node| SerializableNode {
271 uri: node.uri.clone(),
272 vector_data: node.vector.as_f32(),
273 connections: node
274 .connections
275 .iter()
276 .map(|set| set.iter().copied().collect())
277 .collect(),
278 level: node.level(),
279 })
280 .collect();
281
282 Ok(
283 bincode::encode_to_vec(&serializable_nodes, bincode::config::standard())
284 .map_err(|e| anyhow!("Failed to serialize nodes: {}", e))?,
285 )
286 }
287
288 fn deserialize_nodes(&self, data: &[u8], index: &mut HnswIndex) -> Result<()> {
290 let (serializable_nodes, _): (Vec<SerializableNode>, _) =
291 bincode::decode_from_slice(data, bincode::config::standard())
292 .map_err(|e| anyhow!("Failed to deserialize nodes: {}", e))?;
293
294 for node_data in serializable_nodes {
295 let vector = Vector::new(node_data.vector_data);
296 let mut node = crate::hnsw::Node::new(node_data.uri, vector, node_data.level);
297
298 for (level, connections) in node_data.connections.into_iter().enumerate() {
300 for conn_id in connections {
301 node.add_connection(level, conn_id);
302 }
303 }
304
305 index.nodes_mut().push(node);
306 }
307
308 Ok(())
309 }
310
311 fn validate_index(&self, index: &HnswIndex) -> Result<()> {
313 tracing::debug!("Validating index integrity");
314
315 for (node_id, node) in index.nodes().iter().enumerate() {
317 for level in 0..=node.level() {
318 if let Some(connections) = node.get_connections(level) {
319 for &conn_id in connections {
320 if conn_id >= index.len() {
321 return Err(anyhow!(
322 "Invalid connection: node {} has connection to non-existent node {}",
323 node_id,
324 conn_id
325 ));
326 }
327 }
328 }
329 }
330 }
331
332 for (uri, &node_id) in index.uri_to_id() {
334 if node_id >= index.len() {
335 return Err(anyhow!(
336 "Invalid URI mapping: {} points to non-existent node {}",
337 uri,
338 node_id
339 ));
340 }
341
342 let actual_uri = &index.nodes()[node_id].uri;
343 if uri != actual_uri {
344 return Err(anyhow!(
345 "URI mapping mismatch: expected '{}', found '{}'",
346 uri,
347 actual_uri
348 ));
349 }
350 }
351
352 if let Some(entry_id) = index.entry_point() {
354 if entry_id >= index.len() {
355 return Err(anyhow!(
356 "Invalid entry point: {} (index has {} nodes)",
357 entry_id,
358 index.len()
359 ));
360 }
361 }
362
363 tracing::debug!("Index validation passed");
364 Ok(())
365 }
366
367 pub fn create_snapshot<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
369 let path = path.as_ref();
370 let snapshot_path = path.with_extension(format!(
371 "snapshot.{}",
372 std::time::SystemTime::now()
373 .duration_since(std::time::UNIX_EPOCH)
374 .unwrap()
375 .as_secs()
376 ));
377
378 self.save_index(index, snapshot_path)?;
379 Ok(())
380 }
381
382 pub fn estimate_compressed_size(&self, index: &HnswIndex) -> Result<usize> {
384 let nodes = self.serialize_nodes(index)?;
385
386 let compressed_size = match self.config.compression {
387 CompressionAlgorithm::None => nodes.len(),
388 CompressionAlgorithm::Zstd { level } => zstd::encode_all(&nodes[..], level)?.len(),
389 CompressionAlgorithm::ZstdMax => zstd::encode_all(&nodes[..], 21)?.len(),
390 };
391
392 Ok(compressed_size)
393 }
394}
395
/// Counts recorded operations and writes a timestamped snapshot once enough have accumulated.
pub struct IncrementalPersistence {
    /// Options passed to the `PersistenceManager` used for snapshots.
    config: PersistenceConfig,
    /// Operations recorded since the last successful checkpoint.
    operation_count: usize,
    /// When the last successful checkpoint finished (or when this tracker was created).
    last_checkpoint: std::time::Instant,
}
402
403impl IncrementalPersistence {
404 pub fn new(config: PersistenceConfig) -> Self {
405 Self {
406 config,
407 operation_count: 0,
408 last_checkpoint: std::time::Instant::now(),
409 }
410 }
411
412 pub fn record_operation(&mut self) {
414 self.operation_count += 1;
415 }
416
417 pub fn needs_checkpoint(&self) -> bool {
419 self.operation_count >= self.config.checkpoint_interval
420 }
421
422 pub fn checkpoint<P: AsRef<Path>>(&mut self, index: &HnswIndex, base_path: P) -> Result<()> {
424 if !self.needs_checkpoint() {
425 return Ok(());
426 }
427
428 let manager = PersistenceManager::new(self.config.clone());
429 manager.create_snapshot(index, base_path)?;
430
431 self.operation_count = 0;
432 self.last_checkpoint = std::time::Instant::now();
433
434 Ok(())
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;
    use crate::Vector;
    use std::env::temp_dir;
    use std::path::{Path, PathBuf};

    /// A path in the system temp dir that is deleted on drop, even when an assertion panics.
    ///
    /// The process id keeps concurrent test runs from clobbering each other's files.
    struct TempFile(PathBuf);

    impl TempFile {
        fn new(name: &str) -> Self {
            Self(temp_dir().join(format!("{}_{}", std::process::id(), name)))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Builds an index holding `count` vectors named `vec_{i}`, with components from `make(i)`.
    fn build_index(count: usize, make: impl Fn(usize) -> Vec<f32>) -> HnswIndex {
        let mut index = HnswIndex::new(HnswConfig::default()).unwrap();
        for i in 0..count {
            index
                .add_vector(format!("vec_{}", i), Vector::new(make(i)))
                .unwrap();
        }
        index
    }

    #[test]
    fn test_save_and_load_index() {
        let index = build_index(10, |i| vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
        let file = TempFile::new("test_hnsw_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());

        manager.save_index(&index, file.path()).unwrap();
        let loaded_index = manager.load_index(file.path()).unwrap();

        assert_eq!(loaded_index.len(), 10);
        assert_eq!(loaded_index.uri_to_id().len(), 10);
        assert_eq!(loaded_index.uri_to_id(), index.uri_to_id());
        assert_eq!(loaded_index.entry_point(), index.entry_point());
        for (loaded, original) in loaded_index.nodes().iter().zip(index.nodes()) {
            assert_eq!(loaded.uri, original.uri);
            assert_eq!(loaded.level(), original.level());
        }
    }

    #[test]
    fn test_compression() {
        let index = build_index(50, |i| vec![i as f32; 128]);

        let compressed_file = TempFile::new("test_compressed_index.bin");
        let compressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::Zstd { level: 3 },
            ..Default::default()
        });
        compressed_manager
            .save_index(&index, compressed_file.path())
            .unwrap();
        let compressed_size = std::fs::metadata(compressed_file.path()).unwrap().len();

        let uncompressed_file = TempFile::new("test_uncompressed_index.bin");
        let uncompressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::None,
            ..Default::default()
        });
        uncompressed_manager
            .save_index(&index, uncompressed_file.path())
            .unwrap();
        let uncompressed_size = std::fs::metadata(uncompressed_file.path()).unwrap().len();

        assert!(compressed_size < uncompressed_size);
    }

    #[test]
    fn test_validation() {
        let index = build_index(5, |i| vec![i as f32, 0.0, 0.0]);

        let manager = PersistenceManager::new(PersistenceConfig {
            validate_on_load: true,
            ..Default::default()
        });

        manager.validate_index(&index).unwrap();
    }

    #[test]
    fn test_load_rejects_bad_magic() {
        let file = TempFile::new("test_bad_magic_index.bin");
        std::fs::write(file.path(), b"nope, not an index").unwrap();

        let manager = PersistenceManager::new(PersistenceConfig::default());
        assert!(manager.load_index(file.path()).is_err());
    }

    #[test]
    fn test_load_rejects_truncated_file() {
        let index = build_index(10, |i| vec![i as f32, 1.0, 2.0]);
        let file = TempFile::new("test_truncated_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());
        manager.save_index(&index, file.path()).unwrap();

        let bytes = std::fs::read(file.path()).unwrap();
        std::fs::write(file.path(), &bytes[..bytes.len() / 2]).unwrap();

        assert!(manager.load_index(file.path()).is_err());
    }
}