1use crate::hnsw::{HnswConfig, HnswIndex};
7use crate::Vector;
8use anyhow::{anyhow, Result};
9use serde::{Deserialize, Serialize};
10use std::fs::{File, OpenOptions};
11use std::io::{BufReader, BufWriter, Read, Write};
12use std::path::Path;
13
/// On-disk format version stamped into every saved header; `load_index` refuses any
/// file whose header carries a different value.
const PERSISTENCE_VERSION: u32 = 1;

/// Four-byte signature (`b"OxVe"`) written at offset 0 of every index file and checked
/// before anything else is parsed.
const MAGIC_NUMBER: &[u8; 4] = &[b'O', b'x', b'V', b'e'];
19
20#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
22pub enum CompressionAlgorithm {
23 None,
25 Zstd { level: i32 },
27 ZstdMax,
29}
30
31impl Default for CompressionAlgorithm {
32 fn default() -> Self {
33 Self::Zstd { level: 3 } }
35}
36
/// Options controlling how [`PersistenceManager`] and [`IncrementalPersistence`]
/// save and load HNSW indexes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersistenceConfig {
    /// Compression applied to the serialized node payload when saving
    /// (also used by `estimate_compressed_size`).
    pub compression: CompressionAlgorithm,
    /// NOTE(review): not read anywhere in this file. Confirm whether metadata
    /// persistence is implemented elsewhere or whether this flag is vestigial.
    pub include_metadata: bool,
    /// When true, `PersistenceManager::load_index` validates the loaded index before
    /// returning it.
    pub validate_on_load: bool,
    /// NOTE(review): not read anywhere in this file. `IncrementalPersistence` only
    /// consults `checkpoint_interval`, regardless of this flag.
    pub incremental: bool,
    /// Number of recorded operations at which `IncrementalPersistence` considers a
    /// checkpoint due (compared with `>=`).
    pub checkpoint_interval: usize,
}
51
52impl Default for PersistenceConfig {
53 fn default() -> Self {
54 Self {
55 compression: CompressionAlgorithm::default(),
56 include_metadata: true,
57 validate_on_load: true,
58 incremental: false,
59 checkpoint_interval: 10000,
60 }
61 }
62}
63
/// Fixed metadata written (bincode-encoded, length-prefixed) right after the magic
/// number. Field order is part of the on-disk format, so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct IndexHeader {
    /// Format version; loading rejects anything other than `PERSISTENCE_VERSION`.
    version: u32,
    /// Algorithm used for the node payload, needed to decompress it on load.
    compression: CompressionAlgorithm,
    /// Value of `HnswIndex::len()` at save time.
    node_count: usize,
    /// Dimensionality of the first node's vector, or 0 for an empty index.
    /// Not currently validated on load.
    dimension: usize,
    /// HNSW configuration used to construct the empty index on load.
    config: HnswConfig,
    /// Save time in whole seconds since the Unix epoch.
    timestamp: u64,
    /// Integrity checksum slot; a value of 0 means no checksum was recorded.
    checksum: u64,
}
75
/// On-disk form of a single HNSW node. Field order is part of the bincode format,
/// so do not reorder fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SerializableNode {
    /// The node's URI.
    uri: String,
    /// Vector components as returned by `Vector::as_f32` at save time.
    vector_data: Vec<f32>,
    /// Neighbour ids per layer: element `i` holds the node's connections at level `i`.
    connections: Vec<Vec<usize>>,
    /// The node's level as reported by `Node::level`; validation inspects layers
    /// `0..=level`.
    level: usize,
}
84
85pub struct PersistenceManager {
87 config: PersistenceConfig,
88}
89
90impl PersistenceManager {
91 pub fn new(config: PersistenceConfig) -> Self {
93 Self { config }
94 }
95
96 pub fn save_index<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
98 let path = path.as_ref();
99 tracing::info!("Saving HNSW index to {:?}", path);
100
101 let file = OpenOptions::new()
102 .write(true)
103 .create(true)
104 .truncate(true)
105 .open(path)?;
106
107 let mut writer = BufWriter::new(file);
108
109 writer.write_all(MAGIC_NUMBER)?;
111
112 let header = IndexHeader {
114 version: PERSISTENCE_VERSION,
115 compression: self.config.compression,
116 node_count: index.len(),
117 dimension: if let Some(node) = index.nodes().first() {
118 node.vector.dimensions
119 } else {
120 0
121 },
122 config: index.config().clone(),
123 timestamp: std::time::SystemTime::now()
124 .duration_since(std::time::UNIX_EPOCH)
125 .unwrap()
126 .as_secs(),
127 checksum: 0, };
129
130 let header_bytes = bincode::serialize(&header)?;
132 let header_len = header_bytes.len() as u32;
133 writer.write_all(&header_len.to_le_bytes())?;
134 writer.write_all(&header_bytes)?;
135
136 let nodes = self.serialize_nodes(index)?;
138
139 let data = match self.config.compression {
141 CompressionAlgorithm::None => nodes,
142 CompressionAlgorithm::Zstd { level } => zstd::encode_all(&nodes[..], level)?,
143 CompressionAlgorithm::ZstdMax => zstd::encode_all(&nodes[..], 21)?,
144 };
145
146 let data_len = data.len() as u64;
148 writer.write_all(&data_len.to_le_bytes())?;
149 writer.write_all(&data)?;
150
151 let uri_mapping = bincode::serialize(&index.uri_to_id())?;
153 let mapping_len = uri_mapping.len() as u32;
154 writer.write_all(&mapping_len.to_le_bytes())?;
155 writer.write_all(&uri_mapping)?;
156
157 let entry_point = bincode::serialize(&index.entry_point())?;
159 writer.write_all(&entry_point)?;
160
161 writer.flush()?;
162
163 tracing::info!(
164 "Successfully saved HNSW index with {} nodes (compression: {:?})",
165 index.len(),
166 self.config.compression
167 );
168
169 Ok(())
170 }
171
172 pub fn load_index<P: AsRef<Path>>(&self, path: P) -> Result<HnswIndex> {
174 let path = path.as_ref();
175 tracing::info!("Loading HNSW index from {:?}", path);
176
177 let file = File::open(path)?;
178 let mut reader = BufReader::new(file);
179
180 let mut magic = [0u8; 4];
182 reader.read_exact(&mut magic)?;
183 if &magic != MAGIC_NUMBER {
184 return Err(anyhow!("Invalid index file format"));
185 }
186
187 let mut header_len_bytes = [0u8; 4];
189 reader.read_exact(&mut header_len_bytes)?;
190 let header_len = u32::from_le_bytes(header_len_bytes) as usize;
191
192 let mut header_bytes = vec![0u8; header_len];
193 reader.read_exact(&mut header_bytes)?;
194 let header: IndexHeader = bincode::deserialize(&header_bytes)?;
195
196 if header.version != PERSISTENCE_VERSION {
198 return Err(anyhow!(
199 "Unsupported index version: {} (expected {})",
200 header.version,
201 PERSISTENCE_VERSION
202 ));
203 }
204
205 let mut data_len_bytes = [0u8; 8];
207 reader.read_exact(&mut data_len_bytes)?;
208 let data_len = u64::from_le_bytes(data_len_bytes) as usize;
209
210 let mut compressed_data = vec![0u8; data_len];
212 reader.read_exact(&mut compressed_data)?;
213
214 let nodes_data = match header.compression {
215 CompressionAlgorithm::None => compressed_data,
216 CompressionAlgorithm::Zstd { .. } | CompressionAlgorithm::ZstdMax => {
217 zstd::decode_all(&compressed_data[..])?
218 }
219 };
220
221 let mut mapping_len_bytes = [0u8; 4];
223 reader.read_exact(&mut mapping_len_bytes)?;
224 let mapping_len = u32::from_le_bytes(mapping_len_bytes) as usize;
225
226 let mut mapping_bytes = vec![0u8; mapping_len];
227 reader.read_exact(&mut mapping_bytes)?;
228 let uri_mapping: std::collections::HashMap<String, usize> =
229 bincode::deserialize(&mapping_bytes)?;
230
231 let mut entry_point_bytes = Vec::new();
233 reader.read_to_end(&mut entry_point_bytes)?;
234 let entry_point: Option<usize> = bincode::deserialize(&entry_point_bytes)?;
235
236 let mut index = HnswIndex::new(header.config)?;
238 self.deserialize_nodes(&nodes_data, &mut index)?;
239
240 *index.uri_to_id_mut() = uri_mapping;
242
243 index.set_entry_point(entry_point);
245
246 if self.config.validate_on_load {
248 self.validate_index(&index)?;
249 }
250
251 tracing::info!("Successfully loaded HNSW index with {} nodes", index.len());
252
253 Ok(index)
254 }
255
256 fn serialize_nodes(&self, index: &HnswIndex) -> Result<Vec<u8>> {
258 let serializable_nodes: Vec<SerializableNode> = index
259 .nodes()
260 .iter()
261 .map(|node| SerializableNode {
262 uri: node.uri.clone(),
263 vector_data: node.vector.as_f32(),
264 connections: node
265 .connections
266 .iter()
267 .map(|set| set.iter().copied().collect())
268 .collect(),
269 level: node.level(),
270 })
271 .collect();
272
273 Ok(bincode::serialize(&serializable_nodes)?)
274 }
275
276 fn deserialize_nodes(&self, data: &[u8], index: &mut HnswIndex) -> Result<()> {
278 let serializable_nodes: Vec<SerializableNode> = bincode::deserialize(data)?;
279
280 for node_data in serializable_nodes {
281 let vector = Vector::new(node_data.vector_data);
282 let mut node = crate::hnsw::Node::new(node_data.uri, vector, node_data.level);
283
284 for (level, connections) in node_data.connections.into_iter().enumerate() {
286 for conn_id in connections {
287 node.add_connection(level, conn_id);
288 }
289 }
290
291 index.nodes_mut().push(node);
292 }
293
294 Ok(())
295 }
296
297 fn validate_index(&self, index: &HnswIndex) -> Result<()> {
299 tracing::debug!("Validating index integrity");
300
301 for (node_id, node) in index.nodes().iter().enumerate() {
303 for level in 0..=node.level() {
304 if let Some(connections) = node.get_connections(level) {
305 for &conn_id in connections {
306 if conn_id >= index.len() {
307 return Err(anyhow!(
308 "Invalid connection: node {} has connection to non-existent node {}",
309 node_id,
310 conn_id
311 ));
312 }
313 }
314 }
315 }
316 }
317
318 for (uri, &node_id) in index.uri_to_id() {
320 if node_id >= index.len() {
321 return Err(anyhow!(
322 "Invalid URI mapping: {} points to non-existent node {}",
323 uri,
324 node_id
325 ));
326 }
327
328 let actual_uri = &index.nodes()[node_id].uri;
329 if uri != actual_uri {
330 return Err(anyhow!(
331 "URI mapping mismatch: expected '{}', found '{}'",
332 uri,
333 actual_uri
334 ));
335 }
336 }
337
338 if let Some(entry_id) = index.entry_point() {
340 if entry_id >= index.len() {
341 return Err(anyhow!(
342 "Invalid entry point: {} (index has {} nodes)",
343 entry_id,
344 index.len()
345 ));
346 }
347 }
348
349 tracing::debug!("Index validation passed");
350 Ok(())
351 }
352
353 pub fn create_snapshot<P: AsRef<Path>>(&self, index: &HnswIndex, path: P) -> Result<()> {
355 let path = path.as_ref();
356 let snapshot_path = path.with_extension(format!(
357 "snapshot.{}",
358 std::time::SystemTime::now()
359 .duration_since(std::time::UNIX_EPOCH)
360 .unwrap()
361 .as_secs()
362 ));
363
364 self.save_index(index, snapshot_path)?;
365 Ok(())
366 }
367
368 pub fn estimate_compressed_size(&self, index: &HnswIndex) -> Result<usize> {
370 let nodes = self.serialize_nodes(index)?;
371
372 let compressed_size = match self.config.compression {
373 CompressionAlgorithm::None => nodes.len(),
374 CompressionAlgorithm::Zstd { level } => zstd::encode_all(&nodes[..], level)?.len(),
375 CompressionAlgorithm::ZstdMax => zstd::encode_all(&nodes[..], 21)?.len(),
376 };
377
378 Ok(compressed_size)
379 }
380}
381
382pub struct IncrementalPersistence {
384 config: PersistenceConfig,
385 operation_count: usize,
386 last_checkpoint: std::time::Instant,
387}
388
389impl IncrementalPersistence {
390 pub fn new(config: PersistenceConfig) -> Self {
391 Self {
392 config,
393 operation_count: 0,
394 last_checkpoint: std::time::Instant::now(),
395 }
396 }
397
398 pub fn record_operation(&mut self) {
400 self.operation_count += 1;
401 }
402
403 pub fn needs_checkpoint(&self) -> bool {
405 self.operation_count >= self.config.checkpoint_interval
406 }
407
408 pub fn checkpoint<P: AsRef<Path>>(&mut self, index: &HnswIndex, base_path: P) -> Result<()> {
410 if !self.needs_checkpoint() {
411 return Ok(());
412 }
413
414 let manager = PersistenceManager::new(self.config.clone());
415 manager.create_snapshot(index, base_path)?;
416
417 self.operation_count = 0;
418 self.last_checkpoint = std::time::Instant::now();
419
420 Ok(())
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::HnswConfig;
    use crate::Vector;
    use std::env::temp_dir;

    /// A temp-file path that is unique to this process and removed when dropped.
    ///
    /// Concurrent test runs therefore don't clobber each other, and files are
    /// cleaned up even when an assertion fails.
    struct TempPath(std::path::PathBuf);

    impl TempPath {
        fn new(name: &str) -> Self {
            let mut path = temp_dir();
            path.push(format!("{}_{}", std::process::id(), name));
            TempPath(path)
        }
    }

    impl Drop for TempPath {
        fn drop(&mut self) {
            std::fs::remove_file(&self.0).ok();
        }
    }

    /// Builds an index with `count` vectors named `vec_<i>`, produced by `make(i)`.
    fn build_index(count: usize, make: impl Fn(usize) -> Vector) -> HnswIndex {
        let mut index = HnswIndex::new(HnswConfig::default()).unwrap();
        for i in 0..count {
            index.add_vector(format!("vec_{}", i), make(i)).unwrap();
        }
        index
    }

    #[test]
    fn test_save_and_load_index() {
        let index = build_index(10, |i| {
            Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32])
        });
        let temp_path = TempPath::new("test_hnsw_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());

        manager.save_index(&index, &temp_path.0).unwrap();
        let loaded_index = manager.load_index(&temp_path.0).unwrap();

        assert_eq!(loaded_index.len(), 10);
        assert_eq!(loaded_index.uri_to_id().len(), 10);
        assert_eq!(loaded_index.uri_to_id(), index.uri_to_id());
        assert_eq!(loaded_index.entry_point(), index.entry_point());
        assert_eq!(loaded_index.nodes().len(), index.nodes().len());
        for (original, restored) in index.nodes().iter().zip(loaded_index.nodes()) {
            assert_eq!(restored.uri, original.uri);
            assert_eq!(restored.level(), original.level());
        }
    }

    #[test]
    fn test_compression() {
        let index = build_index(50, |i| Vector::new(vec![i as f32; 128]));

        let compressed_path = TempPath::new("test_compressed_index.bin");
        let compressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::Zstd { level: 3 },
            ..Default::default()
        });
        compressed_manager
            .save_index(&index, &compressed_path.0)
            .unwrap();
        let compressed_size = std::fs::metadata(&compressed_path.0).unwrap().len();

        let uncompressed_path = TempPath::new("test_uncompressed_index.bin");
        let uncompressed_manager = PersistenceManager::new(PersistenceConfig {
            compression: CompressionAlgorithm::None,
            ..Default::default()
        });
        uncompressed_manager
            .save_index(&index, &uncompressed_path.0)
            .unwrap();
        let uncompressed_size = std::fs::metadata(&uncompressed_path.0).unwrap().len();

        assert!(compressed_size < uncompressed_size);
    }

    #[test]
    fn test_validation() {
        let index = build_index(5, |i| Vector::new(vec![i as f32, 0.0, 0.0]));

        let manager = PersistenceManager::new(PersistenceConfig {
            validate_on_load: true,
            ..Default::default()
        });

        manager.validate_index(&index).unwrap();
    }

    #[test]
    fn test_load_rejects_foreign_file() {
        let temp_path = TempPath::new("test_not_an_index.bin");
        std::fs::write(&temp_path.0, b"definitely not an index").unwrap();

        let manager = PersistenceManager::new(PersistenceConfig::default());
        assert!(manager.load_index(&temp_path.0).is_err());
    }

    #[test]
    fn test_load_rejects_truncated_file() {
        let index = build_index(5, |i| Vector::new(vec![i as f32, 1.0, 2.0]));
        let temp_path = TempPath::new("test_truncated_index.bin");
        let manager = PersistenceManager::new(PersistenceConfig::default());
        manager.save_index(&index, &temp_path.0).unwrap();

        let bytes = std::fs::read(&temp_path.0).unwrap();
        std::fs::write(&temp_path.0, &bytes[..bytes.len() / 2]).unwrap();

        assert!(manager.load_index(&temp_path.0).is_err());
    }
}