1use crate::hnsw::{HnswConfig, HnswIndex, Node};
26use crate::Vector;
27use crate::VectorError;
28use std::collections::HashSet;
29use std::io::{Read, Write};
30use std::path::Path;
31
/// Four-byte signature written at the very start of every snapshot stream.
const SNAPSHOT_MAGIC: &[u8; 4] = b"HNSW";

/// Format revision stored right after the magic bytes; `IndexSnapshot::load`
/// rejects any snapshot whose version differs from this value.
const SNAPSHOT_VERSION: u32 = 1;
37
/// In-memory description of a snapshot's fixed-size header.
///
/// The fields correspond one-to-one to the values `IndexSnapshot::save` writes
/// ahead of the node records. NOTE(review): nothing in this part of the file
/// constructs the type; the per-field notes below are derived from what `save`
/// writes for the matching position — confirm against any external users.
#[derive(Debug, Clone)]
pub struct SnapshotHeader {
    /// Signature bytes (`b"HNSW"` in a well-formed snapshot).
    pub magic: [u8; 4],
    /// Snapshot format revision.
    pub version: u32,
    /// Number of node records that follow the header.
    pub num_nodes: usize,
    /// Largest per-node layer count (0 when there are no nodes).
    pub num_layers: usize,
    /// Vector length taken from the first node (0 when there are no nodes).
    pub dimension: usize,
    /// `ef_construction` value from the index configuration.
    pub ef_construction: usize,
    /// `m` value from the index configuration.
    pub m: usize,
    /// `m_l0` value from the index configuration.
    pub m_l0: usize,
    /// Id of the entry-point node, or `None` when the index has none.
    pub entry_point: Option<usize>,
}
60
/// Stateless namespace for saving an `HnswIndex` to, and loading it from, the
/// binary snapshot format. All functionality lives in associated functions.
pub struct IndexSnapshot;
66
67impl IndexSnapshot {
68 pub fn save<W: Write>(index: &HnswIndex, writer: &mut W) -> Result<usize, VectorError> {
76 let mut written = 0usize;
77
78 writer
80 .write_all(SNAPSHOT_MAGIC)
81 .map_err(VectorError::IoError)?;
82 written += 4;
83
84 Self::write_u32(writer, SNAPSHOT_VERSION).map_err(VectorError::IoError)?;
86 written += 4;
87
88 let nodes = index.nodes();
89 let config = index.config();
90
91 let num_layers = nodes.iter().map(|n| n.connections.len()).max().unwrap_or(0);
93
94 let dimension = nodes.first().map(|n| n.vector_data_f32.len()).unwrap_or(0);
95
96 Self::write_u64(writer, nodes.len() as u64).map_err(VectorError::IoError)?;
98 written += 8;
99 Self::write_u64(writer, num_layers as u64).map_err(VectorError::IoError)?;
100 written += 8;
101 Self::write_u64(writer, dimension as u64).map_err(VectorError::IoError)?;
102 written += 8;
103 Self::write_u64(writer, config.ef_construction as u64).map_err(VectorError::IoError)?;
104 written += 8;
105 Self::write_u64(writer, config.m as u64).map_err(VectorError::IoError)?;
106 written += 8;
107 Self::write_u64(writer, config.m_l0 as u64).map_err(VectorError::IoError)?;
108 written += 8;
109
110 match index.entry_point() {
112 None => {
113 Self::write_u8(writer, 0).map_err(VectorError::IoError)?;
114 written += 1;
115 }
116 Some(ep) => {
117 Self::write_u8(writer, 1).map_err(VectorError::IoError)?;
118 written += 1;
119 Self::write_u64(writer, ep as u64).map_err(VectorError::IoError)?;
120 written += 8;
121 }
122 }
123
124 for node in nodes {
126 written += Self::write_node(writer, node).map_err(VectorError::IoError)?;
127 }
128
129 writer.flush().map_err(VectorError::IoError)?;
130 Ok(written)
131 }
132
133 pub fn load<R: Read>(reader: &mut R) -> Result<HnswIndex, VectorError> {
135 let mut magic = [0u8; 4];
137 reader
138 .read_exact(&mut magic)
139 .map_err(VectorError::IoError)?;
140 if &magic != SNAPSHOT_MAGIC {
141 return Err(VectorError::InvalidData(format!(
142 "Invalid snapshot magic: expected {:?}, got {:?}",
143 SNAPSHOT_MAGIC, magic
144 )));
145 }
146
147 let version = Self::read_u32(reader).map_err(VectorError::IoError)?;
149 if version != SNAPSHOT_VERSION {
150 return Err(VectorError::InvalidData(format!(
151 "Unsupported snapshot version: {}",
152 version
153 )));
154 }
155
156 let num_nodes = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
158 let _num_layers = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
159 let _dimension = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
160 let ef_construction = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
161 let m = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
162 let m_l0 = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
163
164 let has_entry = Self::read_u8(reader).map_err(VectorError::IoError)?;
166 let entry_point = if has_entry == 1 {
167 Some(Self::read_u64(reader).map_err(VectorError::IoError)? as usize)
168 } else {
169 None
170 };
171
172 let config = HnswConfig {
174 m,
175 m_l0,
176 ef_construction,
177 ..HnswConfig::default()
178 };
179
180 let mut nodes: Vec<Node> = Vec::with_capacity(num_nodes);
182 let mut uri_to_id: std::collections::HashMap<String, usize> =
183 std::collections::HashMap::with_capacity(num_nodes);
184
185 for idx in 0..num_nodes {
186 let node = Self::read_node(reader)?;
187 uri_to_id.insert(node.uri.clone(), idx);
188 nodes.push(node);
189 }
190
191 let mut index = HnswIndex::new_cpu_only(config);
193 *index.nodes_mut() = nodes;
195 *index.uri_to_id_mut() = uri_to_id;
196 index.set_entry_point(entry_point);
197
198 Ok(index)
199 }
200
201 pub fn save_to_file(index: &HnswIndex, path: &Path) -> Result<usize, VectorError> {
205 let tmp_path = path.with_extension("hnsw.tmp");
207 let file = std::fs::File::create(&tmp_path).map_err(VectorError::IoError)?;
208 let mut writer = std::io::BufWriter::new(file);
209
210 let written = Self::save(index, &mut writer)?;
211 drop(writer);
212
213 std::fs::rename(&tmp_path, path).map_err(VectorError::IoError)?;
214 Ok(written)
215 }
216
217 pub fn load_from_file(path: &Path) -> Result<HnswIndex, VectorError> {
219 let file = std::fs::File::open(path).map_err(VectorError::IoError)?;
220 let mut reader = std::io::BufReader::new(file);
221 Self::load(&mut reader)
222 }
223
224 fn write_node<W: Write>(writer: &mut W, node: &Node) -> std::io::Result<usize> {
229 let mut written = 0usize;
230
231 let uri_bytes = node.uri.as_bytes();
233 Self::write_u64(writer, uri_bytes.len() as u64)?;
234 written += 8;
235 writer.write_all(uri_bytes)?;
236 written += uri_bytes.len();
237
238 Self::write_u64(writer, node.vector_data_f32.len() as u64)?;
240 written += 8;
241 for &v in &node.vector_data_f32 {
242 Self::write_f32(writer, v)?;
243 written += 4;
244 }
245
246 Self::write_u64(writer, node.connections.len() as u64)?;
248 written += 8;
249 for layer_connections in &node.connections {
250 Self::write_u64(writer, layer_connections.len() as u64)?;
251 written += 8;
252 for &neighbor_id in layer_connections {
253 Self::write_u64(writer, neighbor_id as u64)?;
254 written += 8;
255 }
256 }
257
258 Ok(written)
259 }
260
261 fn read_node<R: Read>(reader: &mut R) -> Result<Node, VectorError> {
262 let uri_len = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
264 let mut uri_bytes = vec![0u8; uri_len];
265 reader
266 .read_exact(&mut uri_bytes)
267 .map_err(VectorError::IoError)?;
268 let uri = String::from_utf8(uri_bytes)
269 .map_err(|e| VectorError::InvalidData(format!("Invalid UTF-8 in URI: {}", e)))?;
270
271 let vec_len = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
273 let mut vector_data_f32 = Vec::with_capacity(vec_len);
274 for _ in 0..vec_len {
275 let v = Self::read_f32(reader).map_err(VectorError::IoError)?;
276 vector_data_f32.push(v);
277 }
278
279 let vector = Vector::new(vector_data_f32.clone());
281
282 let num_layers = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
284 let mut connections: Vec<HashSet<usize>> = Vec::with_capacity(num_layers);
285 for _ in 0..num_layers {
286 let num_conn = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
287 let mut layer_set = HashSet::with_capacity(num_conn);
288 for _ in 0..num_conn {
289 let neighbor = Self::read_u64(reader).map_err(VectorError::IoError)? as usize;
290 layer_set.insert(neighbor);
291 }
292 connections.push(layer_set);
293 }
294
295 let max_level = num_layers.saturating_sub(1);
296 let mut node = Node::new(uri, vector, max_level);
297 node.connections = connections;
298 node.vector_data_f32 = vector_data_f32;
299
300 Ok(node)
301 }
302
/// Emits `v` as eight little-endian bytes.
fn write_u64<W: Write>(w: &mut W, v: u64) -> std::io::Result<()> {
    let encoded: [u8; 8] = v.to_le_bytes();
    w.write_all(&encoded)
}
310
/// Consumes eight bytes from `r` and interprets them as a little-endian `u64`.
/// Fails if fewer than eight bytes are available.
fn read_u64<R: Read>(r: &mut R) -> std::io::Result<u64> {
    let mut raw = [0u8; 8];
    r.read_exact(&mut raw).map(|()| u64::from_le_bytes(raw))
}
316
/// Emits `v` as four little-endian bytes.
fn write_u32<W: Write>(w: &mut W, v: u32) -> std::io::Result<()> {
    let encoded: [u8; 4] = v.to_le_bytes();
    w.write_all(&encoded)
}
320
/// Consumes four bytes from `r` and interprets them as a little-endian `u32`.
/// Fails if fewer than four bytes are available.
fn read_u32<R: Read>(r: &mut R) -> std::io::Result<u32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
}
326
/// Emits the single byte `v`.
fn write_u8<W: Write>(w: &mut W, v: u8) -> std::io::Result<()> {
    w.write_all(std::slice::from_ref(&v))
}
330
/// Consumes exactly one byte from `r`. Fails if the reader is exhausted.
fn read_u8<R: Read>(r: &mut R) -> std::io::Result<u8> {
    let mut byte = 0u8;
    r.read_exact(std::slice::from_mut(&mut byte))?;
    Ok(byte)
}
336
/// Emits the bit pattern of `v` as four little-endian bytes.
fn write_f32<W: Write>(w: &mut W, v: f32) -> std::io::Result<()> {
    let encoded: [u8; 4] = v.to_le_bytes();
    w.write_all(&encoded)
}
340
/// Consumes four bytes from `r` and reinterprets them as a little-endian `f32`.
/// Fails if fewer than four bytes are available.
fn read_f32<R: Read>(r: &mut R) -> std::io::Result<f32> {
    let mut raw = [0u8; 4];
    r.read_exact(&mut raw).map(|()| f32::from_le_bytes(raw))
}
346}
347
#[cfg(test)]
mod tests {
    //! Round-trip and rejection tests for the snapshot format.

    use super::*;
    use crate::hnsw::{HnswConfig, HnswIndex};
    use crate::VectorIndex;

    /// Deletes the wrapped path on drop, so a failing assertion cannot leave
    /// the snapshot file behind in the temp directory.
    struct TempFileGuard(std::path::PathBuf);

    impl Drop for TempFileGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Builds a CPU-only index holding `n` deterministic vectors of length `dim`.
    fn make_index_with_vectors(n: usize, dim: usize) -> HnswIndex {
        let config = HnswConfig::default();
        let mut index = HnswIndex::new_cpu_only(config);

        for i in 0..n {
            let data: Vec<f32> = (0..dim).map(|j| (i * dim + j) as f32 / 1000.0).collect();
            let uri = format!("http://example.org/v{}", i);
            let vec = Vector::new(data);
            index.insert(uri, vec).expect("insert failed");
        }

        index
    }

    #[test]
    fn test_save_and_load_empty_index() {
        let index = HnswIndex::new_cpu_only(HnswConfig::default());
        let mut buf = Vec::new();
        let bytes = IndexSnapshot::save(&index, &mut buf).expect("save failed");
        assert!(bytes > 0);
        // The reported size must match what actually reached the writer.
        assert_eq!(bytes, buf.len());

        let loaded = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");
        assert_eq!(loaded.len(), 0);
        assert_eq!(loaded.entry_point(), None);
    }

    #[test]
    fn test_save_and_load_roundtrip() {
        let original = make_index_with_vectors(20, 8);
        assert_eq!(original.len(), 20);

        let mut buf = Vec::new();
        IndexSnapshot::save(&original, &mut buf).expect("save failed");

        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");

        assert_eq!(restored.len(), original.len());

        for uri in original.uri_to_id().keys() {
            assert!(
                restored.uri_to_id().contains_key(uri),
                "URI {} missing after restore",
                uri
            );
        }

        assert_eq!(original.entry_point(), restored.entry_point());
    }

    #[test]
    fn test_save_and_load_vectors_preserved() {
        let original = make_index_with_vectors(10, 4);

        let mut buf = Vec::new();
        IndexSnapshot::save(&original, &mut buf).expect("save failed");
        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");

        // `zip` stops at the shorter side; without this check a restore that
        // dropped nodes would pass vacuously.
        assert_eq!(original.nodes().len(), restored.nodes().len());

        for (orig_node, rest_node) in original.nodes().iter().zip(restored.nodes().iter()) {
            assert_eq!(orig_node.uri, rest_node.uri);
            assert_eq!(
                orig_node.vector_data_f32.len(),
                rest_node.vector_data_f32.len()
            );
            for (a, b) in orig_node
                .vector_data_f32
                .iter()
                .zip(rest_node.vector_data_f32.iter())
            {
                // The format stores raw f32 bytes, so the round trip is bit-exact.
                assert_eq!(
                    a.to_bits(),
                    b.to_bits(),
                    "Vector data mismatch: {} vs {}",
                    a,
                    b
                );
            }
        }
    }

    #[test]
    fn test_save_and_load_connections_preserved() {
        let original = make_index_with_vectors(30, 8);

        let mut buf = Vec::new();
        IndexSnapshot::save(&original, &mut buf).expect("save failed");
        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load failed");

        assert_eq!(original.nodes().len(), restored.nodes().len());

        for (i, (orig, rest)) in original
            .nodes()
            .iter()
            .zip(restored.nodes().iter())
            .enumerate()
        {
            assert_eq!(
                orig.connections.len(),
                rest.connections.len(),
                "Node {} layer count mismatch",
                i
            );
            for (layer, (oc, rc)) in orig
                .connections
                .iter()
                .zip(rest.connections.iter())
                .enumerate()
            {
                assert_eq!(oc, rc, "Node {} layer {} connections mismatch", i, layer);
            }
        }
    }

    #[test]
    fn test_file_save_and_load() {
        let original = make_index_with_vectors(15, 6);

        // A per-process name keeps concurrent test runs that share the temp
        // directory from overwriting each other's snapshot.
        let file_name = format!("oxirs_snapshot_test_{}.hnsw", std::process::id());
        let path = std::env::temp_dir().join(file_name);
        let _cleanup = TempFileGuard(path.clone());

        let bytes = IndexSnapshot::save_to_file(&original, &path).expect("save_to_file failed");
        assert!(bytes > 0);
        assert!(path.exists());

        let restored = IndexSnapshot::load_from_file(&path).expect("load_from_file failed");
        assert_eq!(restored.len(), original.len());
    }

    #[test]
    fn test_corrupt_magic_rejected() {
        let mut buf = vec![0u8; 100];
        buf[0] = b'X';
        let result = IndexSnapshot::load(&mut buf.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn test_unsupported_version_rejected() {
        let mut buf = Vec::new();
        buf.extend_from_slice(SNAPSHOT_MAGIC);
        buf.extend_from_slice(&(SNAPSHOT_VERSION + 1).to_le_bytes());
        buf.extend_from_slice(&[0u8; 64]);

        let result = IndexSnapshot::load(&mut buf.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn test_truncated_snapshot_rejected() {
        let original = make_index_with_vectors(5, 4);

        let mut buf = Vec::new();
        IndexSnapshot::save(&original, &mut buf).expect("save failed");
        // Dropping the final byte cuts into the last node record.
        buf.pop();

        let result = IndexSnapshot::load(&mut buf.as_slice());
        assert!(result.is_err());
    }

    #[test]
    fn test_config_restored() {
        let config = HnswConfig {
            m: 8,
            m_l0: 16,
            ef_construction: 50,
            ..Default::default()
        };

        let mut index = HnswIndex::new_cpu_only(config);
        let vec_a = Vector::new(vec![1.0, 0.0, 0.0, 0.0]);
        index
            .insert("http://example.org/a".to_string(), vec_a)
            .expect("insert");

        let mut buf = Vec::new();
        IndexSnapshot::save(&index, &mut buf).expect("save");
        let restored = IndexSnapshot::load(&mut buf.as_slice()).expect("load");

        assert_eq!(restored.config().m, 8);
        assert_eq!(restored.config().m_l0, 16);
        assert_eq!(restored.config().ef_construction, 50);
    }
}