use anyhow::{anyhow, Context, Result};
use memmap2::{Mmap, MmapMut, MmapOptions};
use std::collections::HashMap;
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
#[cfg(feature = "async")]
use tokio::sync::RwLock;
/// Configuration for a [`DiskHNSW`] index.
///
/// Of these fields, only `m` is read by the code in this file: it is stored in
/// the file header by `DiskHNSW::create` and restored by `DiskHNSW::open`
/// (every other field is reset to its `Default` value on `open`).
#[derive(Debug, Clone)]
pub struct DiskHNSWConfig {
/// Stored in the file header as a 16-bit value.
/// Presumably the HNSW "M" parameter (max connections per node) — TODO confirm.
pub m: usize,
/// NOTE(review): presumably the connection limit for layer 0 — not read in this file; confirm.
pub m_max0: usize,
/// NOTE(review): presumably the level-generation multiplier (default is `1 / ln(16)`) —
/// not read in this file; confirm.
pub ml: f32,
/// NOTE(review): presumably the candidate-list size used while building — not read in
/// this file; confirm.
pub ef_construction: usize,
/// NOTE(review): not read in this file; meaning must be confirmed against its users.
pub node_buffer_size: usize,
/// NOTE(review): not read in this file; no compaction logic is visible here.
pub enable_compaction: bool,
}
impl Default for DiskHNSWConfig {
fn default() -> Self {
Self {
m: 16,
m_max0: 32,
ml: 1.0 / (16.0_f32.ln()),
ef_construction: 200,
node_buffer_size: 1000,
enable_compaction: true,
}
}
}
/// One graph node as stored in the index file.
///
/// `DiskHNSW::add_node` serialises a node as: `id` (8 bytes, little-endian),
/// `layer` (1 byte), edge count (2 bytes, little-endian), then each edge as an
/// 8-byte little-endian value.
#[derive(Debug, Clone)]
pub struct HNSWNode {
/// Node identifier; used as the key of the in-memory offset table.
pub id: u64,
/// Layer value stored with the node; compared when choosing the entry point.
pub layer: u8,
/// Edge targets, stored as raw `u64` values (presumably ids of neighbouring nodes — confirm).
pub edges: Vec<u64>,
}
/// Fixed-size header found at the very start of an index file.
///
/// The struct is `#[repr(C)]` because its in-memory layout doubles as the
/// on-disk layout: the file's first `FileHeader::SIZE` bytes are interpreted
/// as this struct. Do not reorder, resize or insert fields without bumping
/// `FileHeader::VERSION`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
struct FileHeader {
// File-type tag; `validate` requires it to equal `FileHeader::MAGIC` (b"HNSW").
magic: [u8; 4],
// Format version; `validate` requires it to equal `FileHeader::VERSION`.
version: u32,
// Number of node records; initialised to 0 and rewritten by `DiskHNSW::add_node`.
node_count: u64,
// Initialised to 0 by `FileHeader::new`.
layer_count: u8,
// Copy of `DiskHNSWConfig::m`, narrowed to 16 bits.
m: u16,
// Initialised to 0 by `FileHeader::new`.
entry_point: u64,
// Byte offset just past the stored node data; starts at the header size and is
// advanced by `DiskHNSW::add_node`. `DiskHNSW::open` stops scanning here.
data_length: u64,
// Zero-filled by `FileHeader::new`; no reader in this file inspects it.
reserved: [u8; 24],
}
impl FileHeader {
const MAGIC: [u8; 4] = *b"HNSW";
const VERSION: u32 = 1;
const SIZE: usize = std::mem::size_of::<FileHeader>();
fn new(m: u16) -> Self {
Self {
magic: Self::MAGIC,
version: Self::VERSION,
node_count: 0,
layer_count: 0,
m,
entry_point: 0,
data_length: FileHeader::SIZE as u64,
reserved: [0; 24],
}
}
fn validate(&self) -> Result<()> {
if self.magic != Self::MAGIC {
return Err(anyhow!("Invalid magic number"));
}
if self.version != Self::VERSION {
return Err(anyhow!(
"Unsupported version: expected {}, got {}",
Self::VERSION,
self.version
));
}
Ok(())
}
}
/// HNSW graph whose node records live in a file that is read through a
/// memory mapping.
///
/// Writes go through ordinary file I/O on `file_path`; the mapping is
/// replaced afterwards so reads observe the new contents.
pub struct DiskHNSW {
/// Index configuration. `open` restores only `m` from the file header.
config: DiskHNSWConfig,
/// Path of the backing file; it is reopened for each write and each remap.
file_path: PathBuf,
/// Read-only mapping of the backing file (build without the `async` feature).
#[cfg(not(feature = "async"))]
mmap: Option<Mmap>,
/// Read-only mapping of the backing file, shared behind an async `RwLock`
/// (build with the `async` feature).
#[cfg(feature = "async")]
mmap: Option<Arc<RwLock<Mmap>>>,
/// Maps a node id to the byte offset of that node's record in the file.
node_offsets: HashMap<u64, u64>,
/// Initialised empty by `create` and `open`; nothing visible in this file fills it.
/// Its length is what `stats` reports as `layer_count`.
layer_sizes: Vec<usize>,
/// Id of the current entry-point node, or `None` for an empty index.
entry_point: Option<u64>,
/// Number of node records known to this handle.
node_count: u64,
}
impl DiskHNSW {
pub fn create(path: impl Into<PathBuf>, config: DiskHNSWConfig) -> Result<Self> {
let file_path = path.into();
let mut file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(true)
.open(&file_path)
.context("Failed to create HNSW file")?;
let header = FileHeader::new(config.m as u16);
let header_bytes = unsafe {
std::slice::from_raw_parts(&header as *const FileHeader as *const u8, FileHeader::SIZE)
};
file.write_all(header_bytes)
.context("Failed to write header")?;
file.set_len(1024 * 1024)
.context("Failed to set file size")?;
drop(file);
let file = OpenOptions::new()
.read(true)
.open(&file_path)
.context("Failed to open file for mapping")?;
let mmap = unsafe {
MmapOptions::new()
.map(&file)
.context("Failed to memory-map file")?
};
#[cfg(not(feature = "async"))]
let mmap_field = Some(mmap);
#[cfg(feature = "async")]
let mmap_field = Some(Arc::new(RwLock::new(mmap)));
Ok(Self {
config,
file_path,
mmap: mmap_field,
node_offsets: HashMap::new(),
layer_sizes: Vec::new(),
entry_point: None,
node_count: 0,
})
}
pub fn open(path: impl Into<PathBuf>) -> Result<Self> {
let file_path = path.into();
let file = OpenOptions::new()
.read(true)
.write(false)
.open(&file_path)
.context("Failed to open HNSW file")?;
let mmap = unsafe {
MmapOptions::new()
.map(&file)
.context("Failed to memory-map file")?
};
if mmap.len() < FileHeader::SIZE {
return Err(anyhow!("File too small to contain header"));
}
let header = unsafe { &*(mmap.as_ptr() as *const FileHeader) };
header.validate()?;
let mut node_offsets = HashMap::new();
let mut offset = FileHeader::SIZE as u64;
let mut node_count = 0;
let mut entry_point = None;
let data_end = header.data_length;
while offset < data_end && (offset as usize) < mmap.len() {
if (offset as usize) + 11 > mmap.len() {
break;
}
let peek = &mmap[offset as usize..offset as usize + 11];
let node_id = u64::from_le_bytes(peek[0..8].try_into().unwrap());
let layer = peek[8];
let num_edges = u16::from_le_bytes(peek[9..11].try_into().unwrap());
let node_size = 11 + (num_edges as usize * 8);
if (offset as usize) + node_size > mmap.len() {
break;
}
node_offsets.insert(node_id, offset);
node_count += 1;
if entry_point.is_none() && layer > 0 {
entry_point = Some(node_id);
}
offset += node_size as u64;
}
let config = DiskHNSWConfig {
m: header.m as usize,
..Default::default()
};
#[cfg(not(feature = "async"))]
let mmap_field = Some(mmap);
#[cfg(feature = "async")]
let mmap_field = Some(Arc::new(RwLock::new(mmap)));
Ok(Self {
config,
file_path,
mmap: mmap_field,
node_offsets,
layer_sizes: Vec::new(),
entry_point,
node_count,
})
}
pub fn add_node(&mut self, node: HNSWNode) -> Result<()> {
let mut file = OpenOptions::new()
.read(true)
.write(true)
.open(&self.file_path)
.context("Failed to open file for writing")?;
let file_len = file.metadata()?.len();
let mut offset = file
.seek(SeekFrom::End(0))
.context("Failed to seek to end of file")?;
let node_size = 11 + (node.edges.len() * 8);
let required_size = offset + node_size as u64;
if required_size > file_len {
let new_size = (required_size + 1024 * 1024).max(file_len * 2);
file.set_len(new_size)?;
}
file.write_all(&node.id.to_le_bytes())?;
file.write_all(&[node.layer])?;
file.write_all(&(node.edges.len() as u16).to_le_bytes())?;
for edge in &node.edges {
file.write_all(&edge.to_le_bytes())?;
}
file.flush()?;
let new_data_length = offset + node_size as u64;
file.seek(SeekFrom::Start(0))?;
let mut header_buf = vec![0u8; FileHeader::SIZE];
file.read_exact(&mut header_buf)?;
let header_ptr = header_buf.as_mut_ptr() as *mut FileHeader;
unsafe {
(*header_ptr).data_length = new_data_length;
(*header_ptr).node_count = self.node_count + 1;
}
file.seek(SeekFrom::Start(0))?;
file.write_all(&header_buf)?;
file.flush()?;
drop(file);
self.node_offsets.insert(node.id, offset);
self.node_count += 1;
if let Some(ep) = self.entry_point {
if node.layer > self.get_node_layer(ep).unwrap_or(0) {
self.entry_point = Some(node.id);
}
} else {
self.entry_point = Some(node.id);
}
self.remap()?;
Ok(())
}
pub fn get_node(&self, node_id: u64) -> Result<HNSWNode> {
let offset = *self
.node_offsets
.get(&node_id)
.ok_or_else(|| anyhow!("Node {} not found", node_id))?;
#[cfg(not(feature = "async"))]
let mmap = self
.mmap
.as_ref()
.ok_or_else(|| anyhow!("Index not mapped"))?;
#[cfg(feature = "async")]
let mmap = {
use std::sync::Arc;
return Err(anyhow!(
"get_node requires async context - use get_node_async"
));
};
let offset = offset as usize;
let id = u64::from_le_bytes(mmap[offset..offset + 8].try_into().unwrap());
let layer = mmap[offset + 8];
let num_edges = u16::from_le_bytes(mmap[offset + 9..offset + 11].try_into().unwrap());
let mut edges = Vec::with_capacity(num_edges as usize);
let mut edge_offset = offset + 11;
for _ in 0..num_edges {
let edge = u64::from_le_bytes(mmap[edge_offset..edge_offset + 8].try_into().unwrap());
edges.push(edge);
edge_offset += 8;
}
Ok(HNSWNode { id, layer, edges })
}
fn get_node_layer(&self, node_id: u64) -> Option<u8> {
let offset = *self.node_offsets.get(&node_id)? as usize;
#[cfg(not(feature = "async"))]
let mmap = self.mmap.as_ref()?;
#[cfg(feature = "async")]
return None;
#[cfg(not(feature = "async"))]
if offset + 9 <= mmap.len() {
Some(mmap[offset + 8])
} else {
None
}
#[cfg(feature = "async")]
None
}
fn remap(&mut self) -> Result<()> {
let file = OpenOptions::new()
.read(true)
.open(&self.file_path)
.context("Failed to open file for remapping")?;
let new_mmap = unsafe {
MmapOptions::new()
.map(&file)
.context("Failed to remap file")?
};
#[cfg(not(feature = "async"))]
{
self.mmap = Some(new_mmap);
}
#[cfg(feature = "async")]
{
self.mmap = Some(Arc::new(RwLock::new(new_mmap)));
}
Ok(())
}
pub fn stats(&self) -> DiskHNSWStats {
DiskHNSWStats {
node_count: self.node_count,
file_size_bytes: std::fs::metadata(&self.file_path)
.map(|m| m.len())
.unwrap_or(0),
layer_count: self.layer_sizes.len(),
}
}
}
/// Summary figures returned by `DiskHNSW::stats`.
#[derive(Debug, Clone)]
pub struct DiskHNSWStats {
/// Number of node records known to the index handle.
pub node_count: u64,
/// Current size of the backing file in bytes; 0 if its metadata could not be read.
pub file_size_bytes: u64,
/// Length of the index's `layer_sizes` vector at the time of the call.
pub layer_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Returns a scratch directory together with the path `test.hnsw` inside it.
    /// The directory guard must be kept alive for as long as the path is used.
    fn scratch_path() -> (TempDir, std::path::PathBuf) {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.hnsw");
        (dir, path)
    }

    /// Creates an empty index with the default configuration in a scratch directory.
    fn fresh_index() -> (TempDir, std::path::PathBuf, DiskHNSW) {
        let (dir, path) = scratch_path();
        let index = DiskHNSW::create(&path, DiskHNSWConfig::default()).unwrap();
        (dir, path, index)
    }

    /// Shorthand for building a node value.
    fn make_node(id: u64, layer: u8, edges: Vec<u64>) -> HNSWNode {
        HNSWNode { id, layer, edges }
    }

    #[test]
    fn test_create_disk_hnsw() {
        let (_dir, path) = scratch_path();
        let created = DiskHNSW::create(path, DiskHNSWConfig::default());
        assert!(created.is_ok());
        assert_eq!(created.unwrap().node_count, 0);
    }

    #[test]
    fn test_add_and_get_node() {
        let (_dir, _path, mut index) = fresh_index();
        let node = make_node(1, 0, vec![2, 3, 4]);
        index.add_node(node.clone()).unwrap();
        assert_eq!(index.node_count, 1);

        // The synchronous reader is only usable without the `async` feature.
        #[cfg(not(feature = "async"))]
        {
            let fetched = index.get_node(1).unwrap();
            assert_eq!(fetched.id, 1);
            assert_eq!(fetched.layer, 0);
            assert_eq!(fetched.edges, vec![2, 3, 4]);
        }
    }

    #[test]
    fn test_add_multiple_nodes() {
        let (_dir, _path, mut index) = fresh_index();
        (0..10u64)
            .map(|i| make_node(i, (i % 3) as u8, vec![(i + 1) % 10, (i + 2) % 10]))
            .for_each(|node| index.add_node(node).unwrap());
        assert_eq!(index.node_count, 10);
    }

    // Currently ignored; re-enable once `open` is confirmed to recover the
    // node count written by `add_node`.
    #[test]
    #[ignore]
    fn test_open_existing_index() {
        let (_dir, path) = scratch_path();
        {
            // The writing handle is dropped at the end of this scope, before reopening.
            let mut writer = DiskHNSW::create(&path, DiskHNSWConfig::default()).unwrap();
            for id in 0..5 {
                writer.add_node(make_node(id, 0, vec![])).unwrap();
            }
        }
        let reopened = DiskHNSW::open(&path).unwrap();
        assert_eq!(reopened.node_count, 5);
    }

    #[test]
    fn test_stats() {
        let (_dir, _path, mut index) = fresh_index();
        for id in 0..10 {
            index.add_node(make_node(id, 0, vec![])).unwrap();
        }
        let stats = index.stats();
        assert_eq!(stats.node_count, 10);
        assert!(stats.file_size_bytes > 0);
    }

    #[test]
    fn test_entry_point_tracking() {
        let (_dir, _path, mut index) = fresh_index();

        // The very first node becomes the entry point regardless of its layer.
        index.add_node(make_node(0, 0, vec![])).unwrap();
        assert_eq!(index.entry_point, Some(0));

        // A node on a higher layer takes over.
        index.add_node(make_node(1, 2, vec![])).unwrap();
        assert_eq!(index.entry_point, Some(1));
    }
}