use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;
use tracing::info;
#[cfg(feature = "zerocopy")]
use std::io::Write;
#[cfg(feature = "mmap")]
use memmap2::Mmap;
#[cfg(feature = "zerocopy")]
use rkyv::{Archive, Deserialize as RkyvDeserialize, Serialize as RkyvSerialize};
/// Serializes `index` as pretty-printed JSON and writes it to `path`.
///
/// The file is created via [`File::create`], so an existing file at `path`
/// is truncated.
///
/// # Errors
///
/// Returns an error if the file cannot be created, if serialization fails,
/// or if the buffered output cannot be flushed to the file.
pub fn save_index<T: Serialize, P: AsRef<Path>>(index: &T, path: P) -> Result<()> {
    let path = path.as_ref();
    info!("Saving index to: {}", path.display());

    let file =
        File::create(path).with_context(|| format!("Failed to create file: {}", path.display()))?;
    let mut writer = BufWriter::new(file);

    // Serialize through a mutable borrow so the writer is still available
    // for the explicit flush below.
    serde_json::to_writer_pretty(&mut writer, index)
        .with_context(|| format!("Failed to serialize index to: {}", path.display()))?;

    // `BufWriter` discards I/O errors when it is dropped; flushing here makes
    // a failed final write surface as an error instead of a silent success.
    std::io::Write::flush(&mut writer)
        .with_context(|| format!("Failed to flush index to: {}", path.display()))?;

    info!("Index saved successfully");
    Ok(())
}
/// Reads the JSON file at `path` and deserializes its contents into a `T`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or if its contents cannot
/// be deserialized into `T`.
pub fn load_index<T: for<'de> Deserialize<'de>, P: AsRef<Path>>(path: P) -> Result<T> {
    let source = path.as_ref();
    info!("Loading index from: {}", source.display());

    // Open the file and wrap it in a buffered reader in one step.
    let reader = File::open(source)
        .map(BufReader::new)
        .with_context(|| format!("Failed to open file: {}", source.display()))?;

    let loaded: T = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to deserialize index from: {}", source.display()))?;

    info!("Index loaded successfully");
    Ok(loaded)
}
/// Returns the length in bytes of `index` when serialized as compact JSON.
///
/// The value is serialized into an in-memory string solely to measure it.
///
/// # Errors
///
/// Returns an error if `index` cannot be serialized.
pub fn get_serialized_size<T: Serialize>(index: &T) -> Result<usize> {
    serde_json::to_string(index)
        .map(|encoded| encoded.len())
        .context("Failed to serialize index for size calculation")
}
/// Returns `true` if `path` refers to an existing regular file.
///
/// Missing paths and directories both yield `false`. A single
/// [`Path::is_file`] call is sufficient: it already reports `false` when the
/// path does not exist, so no separate existence check is needed.
pub fn index_file_exists<P: AsRef<Path>>(path: P) -> bool {
    path.as_ref().is_file()
}
/// Serializes `index` with rkyv and writes the resulting bytes to `path`.
///
/// Serialization happens before the file is created, so a serialization
/// failure leaves any existing file at `path` untouched.
///
/// # Errors
///
/// Returns an error if rkyv serialization fails, if the file cannot be
/// created, or if the bytes cannot be written.
#[cfg(feature = "zerocopy")]
pub fn save_index_binary<T, P>(index: &T, path: P) -> Result<()>
where
    T: for<'a> RkyvSerialize<
        rkyv::rancor::Strategy<
            rkyv::ser::Serializer<
                rkyv::util::AlignedVec,
                rkyv::ser::allocator::ArenaHandle<'a>,
                rkyv::ser::sharing::Share,
            >,
            rkyv::rancor::Error,
        >,
    >,
    P: AsRef<Path>,
{
    let target = path.as_ref();
    info!("Saving index (binary) to: {}", target.display());

    let archive = match rkyv::to_bytes::<rkyv::rancor::Error>(index) {
        Ok(buffer) => buffer,
        Err(e) => return Err(anyhow::anyhow!("Failed to serialize index: {}", e)),
    };

    // Create the file and write the whole archive in one chained expression;
    // each step keeps its own error context.
    File::create(target)
        .with_context(|| format!("Failed to create file: {}", target.display()))?
        .write_all(&archive)
        .with_context(|| format!("Failed to write to file: {}", target.display()))?;

    info!("Index saved successfully ({} bytes)", archive.len());
    Ok(())
}
/// Reads an rkyv archive from `path` and deserializes it into an owned `T`.
///
/// The file contents are copied into an [`rkyv::util::AlignedVec`] before
/// being accessed, because a plain `Vec<u8>` gives no alignment guarantee
/// for the archived root.
///
/// # Trust requirements
///
/// The archive is accessed **without validation**. The file must contain
/// bytes produced by rkyv for this exact `T` (for example by
/// `save_index_binary`) with a compatible rkyv configuration. Loading a
/// corrupted, truncated, or foreign file is undefined behavior. Do not use
/// this function on untrusted input.
///
/// # Errors
///
/// Returns an error if the file cannot be read or if deserializing the
/// archived value fails.
#[cfg(feature = "zerocopy")]
pub fn load_index_binary<T, P>(path: P) -> Result<T>
where
    T: Archive,
    T::Archived: RkyvDeserialize<T, rkyv::rancor::Strategy<rkyv::de::Pool, rkyv::rancor::Error>>,
    P: AsRef<Path>,
{
    let path = path.as_ref();
    info!("Loading index (binary) from: {}", path.display());

    let raw =
        std::fs::read(path).with_context(|| format!("Failed to read file: {}", path.display()))?;

    // Move the bytes into an aligned buffer; the unaligned copy is released
    // immediately so both buffers are not kept alive during deserialization.
    let mut bytes: rkyv::util::AlignedVec = rkyv::util::AlignedVec::with_capacity(raw.len());
    bytes.extend_from_slice(&raw);
    drop(raw);

    // SAFETY: `bytes` is an `AlignedVec`, so the buffer meets rkyv's alignment
    // expectation. The *contents* are not checked here: soundness relies on
    // the caller-facing contract documented above (the file holds a valid
    // archive of `T`).
    // NOTE(review): switching to the validating `rkyv::access` would remove
    // this trust requirement but needs extra trait bounds on `T::Archived`,
    // which would change this function's public signature.
    let archived = unsafe { rkyv::access_unchecked::<T::Archived>(&bytes) };

    let mut deserializer = rkyv::de::Pool::new();
    let index: T = archived
        .deserialize(rkyv::rancor::Strategy::wrap(&mut deserializer))
        .map_err(|e| anyhow::anyhow!("Failed to deserialize archived data: {}", e))?;

    info!("Index loaded successfully");
    Ok(index)
}
/// A read-only memory mapping of an index file on disk.
///
/// Bytes are served directly from the mapping; the file contents are not
/// copied onto the heap.
#[cfg(feature = "mmap")]
pub struct MappedIndex {
    /// The live mapping; it must outlive every slice handed out by `as_bytes`.
    mmap: Mmap,
}

#[cfg(feature = "mmap")]
impl MappedIndex {
    /// Opens the file at `path` and memory-maps it read-only.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or cannot be mapped.
    pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        info!("Memory-mapping index from: {}", path.display());

        let file =
            File::open(path).with_context(|| format!("Failed to open file: {}", path.display()))?;

        // SAFETY: `Mmap::map` is unsafe because the mapped memory can change
        // if the underlying file is modified or truncated while mapped. This
        // function cannot enforce that; callers must make sure the index file
        // is not written to for the lifetime of the returned `MappedIndex`.
        let mmap = unsafe { Mmap::map(&file) }
            .with_context(|| format!("Failed to memory-map file: {}", path.display()))?;

        info!("Index memory-mapped successfully ({} bytes)", mmap.len());
        Ok(Self { mmap })
    }

    /// Returns the mapped file contents as a byte slice.
    pub fn as_bytes(&self) -> &[u8] {
        &self.mmap[..]
    }

    /// Deserializes the mapped bytes as JSON into a `T`.
    ///
    /// # Errors
    ///
    /// Returns an error if the mapped bytes are not valid JSON for `T`.
    pub fn deserialize<T: for<'de> Deserialize<'de>>(&self) -> Result<T> {
        serde_json::from_slice(self.as_bytes()).context("Failed to deserialize memory-mapped index")
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hnsw::{HnswConfig, HnswIndex};
    use crate::ivf::{IvfPqConfig, IvfPqIndex};
    use crate::search::VectorSearchIndex;
    use crate::types::SearchConfig;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Three small 3-dimensional embeddings shared by most tests.
    fn create_test_embeddings() -> HashMap<String, Vec<f32>> {
        HashMap::from([
            ("doc1".to_string(), vec![0.1, 0.2, 0.3]),
            ("doc2".to_string(), vec![0.4, 0.5, 0.6]),
            ("doc3".to_string(), vec![0.7, 0.8, 0.9]),
        ])
    }

    /// Builds `count` embeddings named `doc{i}`; component `k` of vector `i`
    /// is `(i + k) * 0.001`.
    fn sequential_embeddings(count: usize, dim: usize) -> HashMap<String, Vec<f32>> {
        (0..count)
            .map(|i| {
                let vector: Vec<f32> = (0..dim).map(|k| (i + k) as f32 * 0.001).collect();
                (format!("doc{}", i), vector)
            })
            .collect()
    }

    /// Builds an HNSW index with the default configuration over `embeddings`.
    fn build_hnsw(embeddings: &HashMap<String, Vec<f32>>) -> HnswIndex {
        let mut hnsw = HnswIndex::new(HnswConfig::default());
        hnsw.build(embeddings).unwrap();
        hnsw
    }

    /// Saves an IVF-PQ index built from `vector_count` 4-dimensional vectors,
    /// reloads it, and checks that a search on the reloaded index returns hits.
    fn assert_ivf_pq_round_trip(file_name: &str, vector_count: usize, config: IvfPqConfig) {
        let workspace = TempDir::new().unwrap();
        let file_path = workspace.path().join(file_name);

        let embeddings = sequential_embeddings(vector_count, 4);
        let mut original = IvfPqIndex::new(config);
        original.build(&embeddings).unwrap();
        save_index(&original, &file_path).unwrap();

        let restored: IvfPqIndex = load_index(&file_path).unwrap();
        let query = vec![0.5, 0.6, 0.7, 0.8];
        let hits = restored.search(&query, 5).unwrap();
        assert!(!hits.is_empty());
    }

    #[test]
    fn test_save_and_load_hnsw() {
        let workspace = TempDir::new().unwrap();
        let file_path = workspace.path().join("hnsw_index.json");

        let original = build_hnsw(&create_test_embeddings());
        save_index(&original, &file_path).unwrap();
        assert!(index_file_exists(&file_path));

        let restored: HnswIndex = load_index(&file_path).unwrap();
        let query = vec![0.2, 0.3, 0.4];
        let hits = restored.search(&query, 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_save_and_load_exact_search() {
        let workspace = TempDir::new().unwrap();
        let file_path = workspace.path().join("exact_index.json");

        let embeddings = create_test_embeddings();
        let mut original = VectorSearchIndex::new(SearchConfig::default());
        original.build(&embeddings).unwrap();
        save_index(&original, &file_path).unwrap();

        let restored: VectorSearchIndex = load_index(&file_path).unwrap();
        let query = vec![0.5, 0.6, 0.7];
        let hits = restored.search(&query, 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_save_and_load_ivf_pq() {
        let config = IvfPqConfig {
            nclusters: 8,
            nsubvectors: 4,
            nbits: 4,
            nprobe: 2,
            max_kmeans_iterations: 20,
            ..IvfPqConfig::default()
        };
        assert_ivf_pq_round_trip("ivf_index.json", 500, config);
    }

    // Larger variant of the IVF-PQ round trip; ignored by default.
    #[test]
    #[ignore]
    fn test_save_and_load_ivf_pq_full() {
        let config = IvfPqConfig {
            nclusters: 16,
            nsubvectors: 4,
            nprobe: 4,
            ..IvfPqConfig::default()
        };
        assert_ivf_pq_round_trip("ivf_index_full.json", 1000, config);
    }

    #[test]
    fn test_get_serialized_size() {
        let hnsw = build_hnsw(&create_test_embeddings());
        let size = get_serialized_size(&hnsw).unwrap();
        assert!(size > 0);
        assert!(size < 100000);
    }

    #[test]
    fn test_index_file_exists() {
        let workspace = TempDir::new().unwrap();
        let file_path = workspace.path().join("test_index.json");
        assert!(!index_file_exists(&file_path));

        let hnsw = build_hnsw(&create_test_embeddings());
        save_index(&hnsw, &file_path).unwrap();
        assert!(index_file_exists(&file_path));
    }

    #[test]
    fn test_load_nonexistent_file() {
        let outcome = load_index::<HnswIndex, _>("/nonexistent/path/index.json");
        assert!(outcome.is_err());
    }

    #[test]
    fn test_save_to_invalid_path() {
        let hnsw = build_hnsw(&create_test_embeddings());
        let outcome = save_index(&hnsw, "/invalid/nonexistent/path/index.json");
        assert!(outcome.is_err());
    }

    #[test]
    #[cfg(feature = "mmap")]
    fn test_mmap_index_creation() {
        let workspace = TempDir::new().unwrap();
        let file_path = workspace.path().join("mmap_index.json");

        let original = build_hnsw(&create_test_embeddings());
        save_index(&original, &file_path).unwrap();

        let mapped = MappedIndex::new(&file_path).unwrap();
        assert!(!mapped.as_bytes().is_empty());

        let restored: HnswIndex = mapped.deserialize().unwrap();
        let query = vec![0.2, 0.3, 0.4];
        let hits = restored.search(&query, 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    #[cfg(feature = "mmap")]
    fn test_mmap_nonexistent_file() {
        let outcome = MappedIndex::new("/nonexistent/file.json");
        assert!(outcome.is_err());
    }

    #[test]
    #[cfg(feature = "mmap")]
    fn test_mmap_large_index() {
        let workspace = TempDir::new().unwrap();
        let file_path = workspace.path().join("mmap_large_index.json");

        let original = build_hnsw(&sequential_embeddings(1000, 3));
        save_index(&original, &file_path).unwrap();

        let mapped = MappedIndex::new(&file_path).unwrap();
        let restored: HnswIndex = mapped.deserialize().unwrap();
        let query = vec![0.5, 0.6, 0.7];
        let hits = restored.search(&query, 10).unwrap();
        assert_eq!(hits.len(), 10);
    }
}