use std::collections::{HashMap, HashSet};
use std::fs::{File, OpenOptions};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use memmap2::MmapMut;
use crate::distance::DistanceMetric;
use crate::error::{Result, SynaError};
use crate::hnsw::{HnswConfig, HnswIndex, HnswNode};
/// File magic: the ASCII bytes `MVEC` read as a big-endian `u32` (`0x4D56_4543`).
/// It is written little-endian, so the first four bytes on disk are `CEVM`.
const MMAP_MAGIC: u32 = u32::from_be_bytes(*b"MVEC");
/// On-disk format version stored right after the magic.
const MMAP_VERSION: u32 = 1;
/// Fixed size, in bytes, of the header that precedes the first entry.
const HEADER_SIZE: usize = 64;
/// Default number of entries the backing file is pre-sized for.
const DEFAULT_INITIAL_CAPACITY: usize = 100_000;
/// Default number of seconds between automatic checkpoints.
const DEFAULT_CHECKPOINT_SECS: u64 = 30;
/// Configuration for [`MmapVectorStore`].
#[derive(Debug, Clone)]
pub struct MmapVectorConfig {
/// Number of `f32` components per vector; the store accepts values in `64..=8192`.
pub dimensions: u16,
/// Distance metric used by brute-force search and by the HNSW index; also recorded in the file header.
pub metric: DistanceMetric,
/// Number of entries the backing file is pre-sized for (estimated with 256-byte keys); the file grows on demand.
pub initial_capacity: usize,
/// Entry count at which an HNSW index is built automatically and preferred for search; `0` disables automatic building.
pub index_threshold: usize,
/// Parameters handed to `HnswIndex::new` when the index is built.
pub hnsw_config: HnswConfig,
/// Minimum seconds between automatic checkpoints triggered by inserts while the index has unsaved changes; `0` disables them.
pub checkpoint_interval_secs: u64,
}
impl Default for MmapVectorConfig {
fn default() -> Self {
Self {
dimensions: 768,
metric: DistanceMetric::Cosine,
initial_capacity: DEFAULT_INITIAL_CAPACITY,
index_threshold: 10_000,
hnsw_config: HnswConfig::default(),
checkpoint_interval_secs: DEFAULT_CHECKPOINT_SECS,
}
}
}
/// One hit returned by [`MmapVectorStore::search`].
#[derive(Debug, Clone)]
pub struct MmapSearchResult {
/// Key the vector was stored under.
pub key: String,
/// Score from `DistanceMetric::distance` (brute-force path) or from the HNSW index;
/// brute-force results are ordered by ascending score.
pub score: f32,
/// Owned copy of the stored vector.
pub vector: Vec<f32>,
}
/// Append-only vector store backed by a memory-mapped file, with an optional HNSW index.
///
/// File layout: a 64-byte header followed by entries written back to back as
/// `[u16 key_len (LE)][key bytes][dimensions × f32 (LE)]`.
/// The counters are atomics, but every mutating method takes `&mut self`.
pub struct MmapVectorStore {
/// Path of the backing file; the HNSW index is saved next to it with `.hnsw` appended to the extension.
path: PathBuf,
/// Writable mapping of the whole backing file.
mmap: MmapMut,
/// Open handle, kept for resizing (`set_len`) and remapping.
file: File,
config: MmapVectorConfig,
/// Byte offset where the next entry is appended; stored in header bytes 24..32 on checkpoint.
write_offset: AtomicU64,
/// Number of appended entries; stored in header bytes 16..24 on checkpoint.
vector_count: AtomicU64,
/// Stored keys (the same set as the keys of `key_to_offset`).
keys: HashSet<String>,
/// Byte offset of each key's entry within the mapping.
key_to_offset: HashMap<String, u64>,
hnsw_index: Option<HnswIndex>,
/// True when the in-memory HNSW index has changes not yet saved to disk.
index_dirty: bool,
last_checkpoint: Instant,
checkpoint_interval: Duration,
}
impl MmapVectorStore {
pub fn new<P: AsRef<Path>>(path: P, config: MmapVectorConfig) -> Result<Self> {
if config.dimensions < 64 || config.dimensions > 8192 {
return Err(SynaError::InvalidDimensions(config.dimensions));
}
let path = path.as_ref().to_path_buf();
let exists = path.exists();
let vector_size = Self::vector_entry_size(config.dimensions, 256); let file_size = HEADER_SIZE + (config.initial_capacity * vector_size);
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&path)?;
let current_size = file.metadata()?.len() as usize;
if current_size < file_size {
file.set_len(file_size as u64)?;
}
let mmap = unsafe { MmapMut::map_mut(&file)? };
let checkpoint_interval = Duration::from_secs(config.checkpoint_interval_secs);
let mut store = Self {
path,
mmap,
file,
config,
write_offset: AtomicU64::new(HEADER_SIZE as u64),
vector_count: AtomicU64::new(0),
keys: HashSet::new(),
key_to_offset: HashMap::new(),
hnsw_index: None,
index_dirty: false,
last_checkpoint: Instant::now(),
checkpoint_interval,
};
if exists {
store.load_existing()?;
} else {
store.write_header()?;
}
store.try_load_hnsw_index();
Ok(store)
}
#[inline]
fn vector_entry_size(dimensions: u16, key_len: usize) -> usize {
2 + key_len + (dimensions as usize * 4) }
fn write_header(&mut self) -> Result<()> {
let header = &mut self.mmap[0..HEADER_SIZE];
header[0..4].copy_from_slice(&MMAP_MAGIC.to_le_bytes());
header[4..8].copy_from_slice(&MMAP_VERSION.to_le_bytes());
header[8..10].copy_from_slice(&self.config.dimensions.to_le_bytes());
header[10] = self.config.metric as u8;
header[16..24].copy_from_slice(&0u64.to_le_bytes());
header[24..32].copy_from_slice(&(HEADER_SIZE as u64).to_le_bytes());
Ok(())
}
fn load_existing(&mut self) -> Result<()> {
let magic = u32::from_le_bytes([self.mmap[0], self.mmap[1], self.mmap[2], self.mmap[3]]);
if magic != MMAP_MAGIC {
return Err(SynaError::CorruptedIndex(
"Invalid mmap vector file magic".to_string(),
));
}
let version = u32::from_le_bytes([self.mmap[4], self.mmap[5], self.mmap[6], self.mmap[7]]);
if version != MMAP_VERSION {
return Err(SynaError::CorruptedIndex(format!(
"Unsupported mmap vector file version: {}",
version
)));
}
let dimensions = u16::from_le_bytes([self.mmap[8], self.mmap[9]]);
if dimensions != self.config.dimensions {
return Err(SynaError::DimensionMismatch {
expected: self.config.dimensions,
got: dimensions,
});
}
let vector_count = u64::from_le_bytes([
self.mmap[16],
self.mmap[17],
self.mmap[18],
self.mmap[19],
self.mmap[20],
self.mmap[21],
self.mmap[22],
self.mmap[23],
]);
let write_offset = u64::from_le_bytes([
self.mmap[24],
self.mmap[25],
self.mmap[26],
self.mmap[27],
self.mmap[28],
self.mmap[29],
self.mmap[30],
self.mmap[31],
]);
self.vector_count.store(vector_count, Ordering::SeqCst);
self.write_offset.store(write_offset, Ordering::SeqCst);
self.rebuild_index_from_mmap()?;
Ok(())
}
fn rebuild_index_from_mmap(&mut self) -> Result<()> {
let mut offset = HEADER_SIZE as u64;
let write_offset = self.write_offset.load(Ordering::SeqCst);
let dims = self.config.dimensions as usize;
while offset < write_offset {
let key_len = u16::from_le_bytes(
self.mmap[offset as usize..(offset as usize + 2)]
.try_into()
.map_err(|_| {
SynaError::CorruptedIndex("Failed to read key length".to_string())
})?,
) as usize;
let key_start = offset as usize + 2;
let key_end = key_start + key_len;
let key = String::from_utf8(self.mmap[key_start..key_end].to_vec())
.map_err(|_| SynaError::CorruptedIndex("Invalid UTF-8 key".to_string()))?;
self.keys.insert(key.clone());
self.key_to_offset.insert(key, offset);
let entry_size = 2 + key_len + (dims * 4);
offset += entry_size as u64;
}
Ok(())
}
fn try_load_hnsw_index(&mut self) {
let hnsw_path = self.hnsw_index_path();
if hnsw_path.exists() {
if let Ok(index) =
HnswIndex::load_validated(&hnsw_path, self.config.dimensions, self.config.metric)
{
if index.len() == self.keys.len() {
self.hnsw_index = Some(index);
}
}
}
}
fn hnsw_index_path(&self) -> PathBuf {
let mut path = self.path.clone();
let ext = match path.extension() {
Some(e) => format!("{}.hnsw", e.to_string_lossy()),
None => "hnsw".to_string(),
};
path.set_extension(ext);
path
}
pub fn insert(&mut self, key: &str, vector: &[f32]) -> Result<()> {
if vector.len() != self.config.dimensions as usize {
return Err(SynaError::DimensionMismatch {
expected: self.config.dimensions,
got: vector.len() as u16,
});
}
if self.keys.contains(key) {
return Ok(()); }
let key_bytes = key.as_bytes();
let key_len = key_bytes.len();
let entry_size = 2 + key_len + (vector.len() * 4);
let offset = self.write_offset.load(Ordering::SeqCst) as usize;
if offset + entry_size > self.mmap.len() {
self.grow_file(entry_size)?;
}
self.mmap[offset..offset + 2].copy_from_slice(&(key_len as u16).to_le_bytes());
self.mmap[offset + 2..offset + 2 + key_len].copy_from_slice(key_bytes);
let vector_start = offset + 2 + key_len;
for (i, &val) in vector.iter().enumerate() {
let byte_offset = vector_start + i * 4;
self.mmap[byte_offset..byte_offset + 4].copy_from_slice(&val.to_le_bytes());
}
self.write_offset
.store((offset + entry_size) as u64, Ordering::SeqCst);
self.vector_count.fetch_add(1, Ordering::SeqCst);
self.keys.insert(key.to_string());
self.key_to_offset.insert(key.to_string(), offset as u64);
if self.hnsw_index.is_some() {
self.insert_to_hnsw_incremental(key, vector);
self.index_dirty = true;
} else if self.config.index_threshold > 0 && self.keys.len() >= self.config.index_threshold
{
self.build_index()?;
}
if self.index_dirty
&& self.checkpoint_interval.as_secs() > 0
&& self.last_checkpoint.elapsed() >= self.checkpoint_interval
{
self.checkpoint()?;
}
Ok(())
}
pub fn insert_batch(&mut self, keys: &[&str], vectors: &[&[f32]]) -> Result<usize> {
if keys.len() != vectors.len() {
return Err(SynaError::ShapeMismatch {
data_size: vectors.len(),
expected_size: keys.len(),
});
}
let dims = self.config.dimensions as usize;
let mut inserted = 0;
let mut offset = self.write_offset.load(Ordering::SeqCst) as usize;
for (key, vector) in keys.iter().zip(vectors.iter()) {
if vector.len() != dims {
return Err(SynaError::DimensionMismatch {
expected: self.config.dimensions,
got: vector.len() as u16,
});
}
if self.keys.contains(*key) {
continue;
}
let key_bytes = key.as_bytes();
let key_len = key_bytes.len();
let entry_size = 2 + key_len + (dims * 4);
if offset + entry_size > self.mmap.len() {
self.write_offset.store(offset as u64, Ordering::SeqCst);
self.grow_file(entry_size)?;
}
self.mmap[offset..offset + 2].copy_from_slice(&(key_len as u16).to_le_bytes());
self.mmap[offset + 2..offset + 2 + key_len].copy_from_slice(key_bytes);
let vector_start = offset + 2 + key_len;
for (i, &val) in vector.iter().enumerate() {
let byte_offset = vector_start + i * 4;
self.mmap[byte_offset..byte_offset + 4].copy_from_slice(&val.to_le_bytes());
}
self.keys.insert(key.to_string());
self.key_to_offset.insert(key.to_string(), offset as u64);
offset += entry_size;
inserted += 1;
}
self.write_offset.store(offset as u64, Ordering::SeqCst);
self.vector_count
.fetch_add(inserted as u64, Ordering::SeqCst);
Ok(inserted)
}
fn grow_file(&mut self, additional: usize) -> Result<()> {
let current_size = self.mmap.len();
let required = self.write_offset.load(Ordering::SeqCst) as usize + additional;
let new_size = (current_size * 2).max(required + 1024 * 1024);
self.mmap.flush()?;
self.file.set_len(new_size as u64)?;
self.mmap = unsafe { MmapMut::map_mut(&self.file)? };
Ok(())
}
pub fn get(&self, key: &str) -> Result<Option<Vec<f32>>> {
let offset = match self.key_to_offset.get(key) {
Some(&o) => o as usize,
None => return Ok(None),
};
let dims = self.config.dimensions as usize;
let key_len = u16::from_le_bytes([self.mmap[offset], self.mmap[offset + 1]]) as usize;
let vector_start = offset + 2 + key_len;
let vector_bytes = &self.mmap[vector_start..vector_start + dims * 4];
let vector: Vec<f32> = vector_bytes
.chunks_exact(4)
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect();
Ok(Some(vector))
}
pub fn get_slice(&self, key: &str) -> Option<&[f32]> {
let offset = *self.key_to_offset.get(key)? as usize;
let dims = self.config.dimensions as usize;
let key_len = u16::from_le_bytes(self.mmap[offset..offset + 2].try_into().ok()?) as usize;
let vector_start = offset + 2 + key_len;
let vector_bytes = &self.mmap[vector_start..vector_start + dims * 4];
let (prefix, floats, suffix) = unsafe { vector_bytes.align_to::<f32>() };
if prefix.is_empty() && suffix.is_empty() && floats.len() == dims {
Some(floats)
} else {
None
}
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<MmapSearchResult>> {
if query.len() != self.config.dimensions as usize {
return Err(SynaError::DimensionMismatch {
expected: self.config.dimensions,
got: query.len() as u16,
});
}
if let Some(ref index) = self.hnsw_index {
if self.keys.len() >= self.config.index_threshold {
return self.search_hnsw(index, query, k);
}
}
self.search_brute_force(query, k)
}
fn search_brute_force(&self, query: &[f32], k: usize) -> Result<Vec<MmapSearchResult>> {
let mut results: Vec<MmapSearchResult> = Vec::with_capacity(self.keys.len());
for key in &self.keys {
if let Some(vector) = self.get_slice(key) {
let score = self.config.metric.distance(query, vector);
results.push(MmapSearchResult {
key: key.clone(),
score,
vector: vector.to_vec(),
});
}
}
results.sort_by(|a, b| a.score.total_cmp(&b.score));
results.truncate(k);
Ok(results)
}
fn search_hnsw(
&self,
index: &HnswIndex,
query: &[f32],
k: usize,
) -> Result<Vec<MmapSearchResult>> {
let hnsw_results = index.search(query, k);
let mut results = Vec::with_capacity(hnsw_results.len());
for (key, score) in hnsw_results {
if let Some(vector) = self.get_slice(&key) {
results.push(MmapSearchResult {
key,
score,
vector: vector.to_vec(),
});
}
}
Ok(results)
}
pub fn build_index(&mut self) -> Result<()> {
let mut index = HnswIndex::new(
self.config.dimensions,
self.config.metric,
self.config.hnsw_config.clone(),
);
for key in &self.keys {
if let Some(vector) = self.get_slice(key) {
self.add_node_to_index(&mut index, key, vector);
}
}
let hnsw_path = self.hnsw_index_path();
index.save(&hnsw_path)?;
self.hnsw_index = Some(index);
self.index_dirty = false;
self.last_checkpoint = Instant::now();
Ok(())
}
fn add_node_to_index(&self, index: &mut HnswIndex, key: &str, vector: &[f32]) {
if index.key_to_id.contains_key(key) {
return;
}
let level = index.random_level();
let node = HnswNode::new(key.to_string(), vector.to_vec(), level);
let node_id = index.nodes.len();
index.nodes.push(node);
index.key_to_id.insert(key.to_string(), node_id);
let current_max_level = index.max_level();
if index.entry_point.is_none() || level > current_max_level {
index.entry_point = Some(node_id);
index.set_max_level(level);
}
if node_id > 0 {
let m = index.config().m;
let m_max = index.config().m_max;
for l in 0..=level {
let max_neighbors = if l == 0 { m } else { m_max };
let mut neighbors = Vec::new();
let mut distances: Vec<(usize, f32)> = index
.nodes
.iter()
.enumerate()
.filter(|(id, n)| *id != node_id && n.neighbors.len() > l)
.map(|(id, n)| (id, index.metric().distance(vector, &n.vector)))
.collect();
distances.sort_by(|a, b| a.1.total_cmp(&b.1));
for (neighbor_id, dist) in distances.into_iter().take(max_neighbors) {
neighbors.push((neighbor_id, dist));
if l < index.nodes[neighbor_id].neighbors.len() {
index.nodes[neighbor_id].neighbors[l].push((node_id, dist));
if index.nodes[neighbor_id].neighbors[l].len() > max_neighbors {
index.nodes[neighbor_id].neighbors[l]
.sort_by(|a, b| a.1.total_cmp(&b.1));
index.nodes[neighbor_id].neighbors[l].truncate(max_neighbors);
}
}
}
if l < index.nodes[node_id].neighbors.len() {
index.nodes[node_id].neighbors[l] = neighbors;
}
}
}
}
fn insert_to_hnsw_incremental(&mut self, key: &str, vector: &[f32]) {
let index = match self.hnsw_index.as_mut() {
Some(idx) => idx,
None => return,
};
if index.key_to_id.contains_key(key) {
return;
}
let level = index.random_level();
let node = HnswNode::new(key.to_string(), vector.to_vec(), level);
let node_id = index.nodes.len();
index.nodes.push(node);
index.key_to_id.insert(key.to_string(), node_id);
let current_max_level = index.max_level();
if index.entry_point.is_none() || level > current_max_level {
index.entry_point = Some(node_id);
index.set_max_level(level);
}
if node_id == 0 {
return;
}
let m = index.config().m;
let m_max = index.config().m_max;
let ef_construction = index.config().ef_construction;
let mut ep = index.entry_point.unwrap_or(0);
for lc in ((level + 1)..=index.max_level()).rev() {
let results = index.search_layer(vector, ep, 1, lc);
if !results.is_empty() {
ep = results[0].0;
}
}
let start_level = level.min(index.max_level());
for l in (0..=start_level).rev() {
let candidates = index.search_layer(vector, ep, ef_construction, l);
let max_neighbors = if l == 0 { m } else { m_max };
let neighbors: Vec<(usize, f32)> = candidates.into_iter().take(max_neighbors).collect();
if !neighbors.is_empty() {
ep = neighbors[0].0;
}
if l < index.nodes[node_id].neighbors.len() {
index.nodes[node_id].neighbors[l] = neighbors.clone();
}
for (neighbor_id, dist) in neighbors {
if l < index.nodes[neighbor_id].neighbors.len() {
index.nodes[neighbor_id].neighbors[l].push((node_id, dist));
if index.nodes[neighbor_id].neighbors[l].len() > max_neighbors {
index.nodes[neighbor_id].neighbors[l].sort_by(|a, b| a.1.total_cmp(&b.1));
index.nodes[neighbor_id].neighbors[l].truncate(max_neighbors);
}
}
}
}
}
pub fn checkpoint(&mut self) -> Result<()> {
let count = self.vector_count.load(Ordering::SeqCst);
let offset = self.write_offset.load(Ordering::SeqCst);
self.mmap[16..24].copy_from_slice(&count.to_le_bytes());
self.mmap[24..32].copy_from_slice(&offset.to_le_bytes());
self.mmap.flush()?;
if self.index_dirty {
if let Some(ref index) = self.hnsw_index {
let hnsw_path = self.hnsw_index_path();
index.save(&hnsw_path)?;
}
self.index_dirty = false;
}
self.last_checkpoint = Instant::now();
Ok(())
}
pub fn flush(&mut self) -> Result<()> {
self.checkpoint()
}
pub fn len(&self) -> usize {
self.keys.len()
}
pub fn is_empty(&self) -> bool {
self.keys.is_empty()
}
pub fn dimensions(&self) -> u16 {
self.config.dimensions
}
pub fn metric(&self) -> DistanceMetric {
self.config.metric
}
pub fn has_index(&self) -> bool {
self.hnsw_index.is_some()
}
pub fn is_dirty(&self) -> bool {
self.index_dirty
}
pub fn keys(&self) -> Vec<String> {
self.keys.iter().cloned().collect()
}
}
impl Drop for MmapVectorStore {
    /// Runs a final `checkpoint()` so unsaved state reaches disk.
    /// `drop` cannot return an error, so a failure is reported on stderr instead.
    fn drop(&mut self) {
        match self.checkpoint() {
            Ok(()) => {}
            Err(err) => eprintln!(
                "Warning: Failed to checkpoint MmapVectorStore on drop: {}",
                err
            ),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Deterministic test vector: `dims` values `(row * dims + col) * 0.001`.
    fn ramp(row: usize, dims: usize) -> Vec<f32> {
        (0..dims)
            .map(|col| (row * dims + col) as f32 * 0.001)
            .collect()
    }

    #[test]
    fn test_mmap_vector_store_basic() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mmap");
        let mut store = MmapVectorStore::new(
            &path,
            MmapVectorConfig {
                dimensions: 128,
                initial_capacity: 1000,
                ..Default::default()
            },
        )
        .unwrap();
        let v1: Vec<f32> = (0..128).map(|i| i as f32 * 0.01).collect();
        store.insert("v1", &v1).unwrap();
        let fetched = store.get("v1").unwrap().unwrap();
        assert_eq!(fetched.len(), 128);
        assert!(fetched[0].abs() < 0.001);
        assert!((fetched[1] - 0.01).abs() < 0.001);
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_mmap_vector_store_batch_insert() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mmap");
        let mut store = MmapVectorStore::new(
            &path,
            MmapVectorConfig {
                dimensions: 64,
                initial_capacity: 10000,
                ..Default::default()
            },
        )
        .unwrap();
        let names: Vec<String> = (0..100).map(|i| format!("v{}", i)).collect();
        let name_refs: Vec<&str> = names.iter().map(String::as_str).collect();
        let rows: Vec<Vec<f32>> = (0..100).map(|row| ramp(row, 64)).collect();
        let row_refs: Vec<&[f32]> = rows.iter().map(Vec::as_slice).collect();
        let inserted = store.insert_batch(&name_refs, &row_refs).unwrap();
        assert_eq!(inserted, 100);
        assert_eq!(store.len(), 100);
    }

    #[test]
    fn test_mmap_vector_store_search() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mmap");
        let mut store = MmapVectorStore::new(
            &path,
            MmapVectorConfig {
                dimensions: 64,
                initial_capacity: 1000,
                index_threshold: 0,
                metric: DistanceMetric::Euclidean,
                ..Default::default()
            },
        )
        .unwrap();
        for row in 0..10 {
            store.insert(&format!("v{}", row), &ramp(row, 64)).unwrap();
        }
        let hits = store.search(&ramp(0, 64), 3).unwrap();
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].key, "v0");
        assert!(hits[0].score < 0.001);
    }

    #[test]
    fn test_mmap_vector_store_persistence() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mmap");
        let config = || MmapVectorConfig {
            dimensions: 64,
            initial_capacity: 1000,
            ..Default::default()
        };
        {
            let mut writer = MmapVectorStore::new(&path, config()).unwrap();
            for row in 0..10 {
                writer.insert(&format!("v{}", row), &ramp(row, 64)).unwrap();
            }
            writer.flush().unwrap();
        }
        {
            let reader = MmapVectorStore::new(&path, config()).unwrap();
            assert_eq!(reader.len(), 10);
            let first = reader.get("v0").unwrap().unwrap();
            assert_eq!(first.len(), 64);
        }
    }

    #[test]
    fn test_mmap_vector_store_dimension_validation() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mmap");
        let too_small = MmapVectorConfig {
            dimensions: 32,
            ..Default::default()
        };
        assert!(MmapVectorStore::new(&path, too_small).is_err());
        let valid = MmapVectorConfig {
            dimensions: 128,
            ..Default::default()
        };
        let mut store = MmapVectorStore::new(&path, valid).unwrap();
        let short_vector = vec![0.1f32; 64];
        assert!(store.insert("v1", &short_vector).is_err());
    }

    #[test]
    fn test_mmap_vector_store_hnsw_index() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test.mmap");
        let mut store = MmapVectorStore::new(
            &path,
            MmapVectorConfig {
                dimensions: 64,
                initial_capacity: 1000,
                index_threshold: 5,
                metric: DistanceMetric::Euclidean,
                ..Default::default()
            },
        )
        .unwrap();
        for row in 0..10 {
            store.insert(&format!("v{}", row), &ramp(row, 64)).unwrap();
        }
        assert!(store.has_index());
        let hits = store.search(&ramp(0, 64), 3).unwrap();
        assert!(!hits.is_empty());
        assert!(hits[0].key.starts_with("v"));
    }
}