use std::collections::HashMap;
use std::fs::OpenOptions;
use std::io::{Read, Seek, SeekFrom, Write};
use serde::{Deserialize, Serialize};
use crate::types::Record;
/// Value given to `max_neighbors_per_document` when `HNSW::load` cannot read
/// or parse an existing metadata file and starts a fresh index.
const DEFAULT_MMAX: usize = 16;
/// Value given to `max_layers` when `HNSW::load` cannot read or parse an
/// existing metadata file and starts a fresh index.
const DEFAULT_MAX_LAYERS: usize = 16;
/// Randomly picks the top layer for a new node.
///
/// Starts at layer 0 and climbs one layer for every successful
/// `rand::random_bool(0.5)` draw, stopping at the first failed draw or once
/// the highest permitted layer (`max_layers - 1`) is reached.
///
/// A `max_layers` of 0 is treated like 1: the result is always layer 0.
fn assign_layer(max_layers: usize) -> usize {
    let mut layer = 0;
    // `layer + 1 < max_layers` is equivalent to `layer < max_layers - 1` for
    // every `max_layers >= 1`, but does not underflow when `max_layers` is 0.
    while layer + 1 < max_layers && rand::random_bool(0.5) {
        layer += 1;
    }
    layer
}
/// Size in bytes of one neighbor slot: a `u32` count followed by room for
/// `max_neighbors` `u32` entries.
fn slot_size(max_neighbors: usize) -> usize {
    const WORD: usize = std::mem::size_of::<u32>();
    WORD + max_neighbors * WORD
}
/// Byte offset of the slot for `layer`, given `file_offset` as the position
/// of the node's layer-0 slot. Slots of consecutive layers are laid out back
/// to back, each `slot_size(max_neighbors)` bytes long.
fn layer_offset(file_offset: u64, layer: usize, max_neighbors: usize) -> u64 {
    let bytes_before_layer = layer * slot_size(max_neighbors);
    file_offset + bytes_before_layer as u64
}
/// Bookkeeping for an HNSW graph whose adjacency lists live in a separate
/// binary file.
///
/// The serializable fields are stored as JSON at `path`; neighbor slots are
/// stored in the file at `graph_path`.
#[derive(Debug, Serialize, Deserialize)]
pub struct HNSW {
    /// Capacity of every neighbor slot; `set_neighbors` rejects longer lists.
    pub max_neighbors_per_document: usize,
    /// Upper bound handed to the layer assignment on insert.
    max_layers: usize,
    /// Highest layer assigned to any node inserted so far.
    pub highest_layer: usize,
    /// Index of the node that owns `highest_layer`; `None` while the index is empty.
    pub entry_point: Option<u32>,
    /// Record id -> node index.
    id_to_index: HashMap<String, u32>,
    /// For each node index, the byte offset of its layer-0 slot in the graph file.
    pub node_offsets: Vec<u64>,
    /// Node index -> record id. Not serialized; rebuilt from `id_to_index` on load.
    #[serde(skip)]
    pub index_to_id: Vec<String>,
    /// Location of the JSON metadata file. Not serialized; set on load.
    #[serde(skip)]
    path: String,
    /// Location of the binary graph file. Not serialized; set on load.
    #[serde(skip)]
    graph_path: String,
}
impl HNSW {
    /// Loads index metadata from the JSON file at `path`, or starts an empty
    /// index when that file cannot be read or parsed (the error is discarded).
    ///
    /// `graph_path` names the binary file holding the neighbor slots; it is
    /// not opened here. Both paths are remembered on the returned value.
    pub fn load(path: &str, graph_path: &str) -> Self {
        let mut hnsw = std::fs::read_to_string(path)
            .ok()
            .and_then(|data| serde_json::from_str::<HNSW>(&data).ok())
            .unwrap_or_else(|| HNSW {
                max_neighbors_per_document: DEFAULT_MMAX,
                max_layers: DEFAULT_MAX_LAYERS,
                highest_layer: 0,
                entry_point: None,
                id_to_index: HashMap::new(),
                node_offsets: Vec::new(),
                index_to_id: Vec::new(),
                path: String::new(),
                graph_path: String::new(),
            });

        // Rebuild the reverse table with one entry per node. The node count
        // can exceed the number of distinct ids (an id that was inserted
        // twice owns two nodes but only one map entry), so size by
        // `node_offsets`, not by the map. Indices that point past the node
        // table (damaged metadata) are skipped instead of panicking.
        let mut index_to_id = vec![String::new(); hnsw.node_offsets.len()];
        for (id, &idx) in &hnsw.id_to_index {
            if let Some(entry) = index_to_id.get_mut(idx as usize) {
                entry.clone_from(id);
            }
        }
        hnsw.index_to_id = index_to_id;

        hnsw.path = path.to_string();
        hnsw.graph_path = graph_path.to_string();
        hnsw
    }

    /// Writes the serializable fields as pretty-printed JSON to `self.path`.
    ///
    /// # Panics
    /// Panics if serialization or the file write fails.
    fn save(&self) {
        let data = serde_json::to_string_pretty(self).expect("failed to serialize HNSW metadata");
        std::fs::write(&self.path, data).expect("failed to write HNSW metadata file");
    }

    /// Appends zero-filled slots for layers `0..=assigned_layer` to the graph
    /// file (creating it if needed) and returns the byte offset at which the
    /// new node's layer-0 slot starts.
    ///
    /// # Panics
    /// Panics if the graph file cannot be opened, measured, or written.
    fn init_node(&self, assigned_layer: usize) -> u64 {
        let mut file = OpenOptions::new()
            .create(true)
            .append(true)
            .read(true)
            .open(&self.graph_path)
            .expect("failed to open HNSW graph file for append");
        let offset = file
            .seek(SeekFrom::End(0))
            .expect("failed to seek to end of HNSW graph file");
        let zeroes = vec![0u8; (assigned_layer + 1) * slot_size(self.max_neighbors_per_document)];
        file.write_all(&zeroes)
            .expect("failed to append node slots to HNSW graph file");
        offset
    }

    /// Reads the neighbor list stored for `node_index` at `layer`.
    ///
    /// The stored count is clamped to `max_neighbors_per_document`, so a
    /// damaged count field can neither trigger an oversized allocation nor
    /// make the read run into the following slot.
    ///
    /// # Panics
    /// Panics if `node_index` is not a known node or if the graph file cannot
    /// be opened or read at the computed offset.
    pub fn get_neighbors(&self, node_index: u32, layer: usize) -> Vec<u32> {
        const WORD: usize = std::mem::size_of::<u32>();

        let mut file = OpenOptions::new()
            .read(true)
            .open(&self.graph_path)
            .expect("failed to open HNSW graph file for reading");
        let offset = layer_offset(
            self.node_offsets[node_index as usize],
            layer,
            self.max_neighbors_per_document,
        );
        file.seek(SeekFrom::Start(offset))
            .expect("failed to seek to neighbor slot");

        let mut count_buf = [0u8; WORD];
        file.read_exact(&mut count_buf)
            .expect("failed to read neighbor count");
        let count =
            (u32::from_le_bytes(count_buf) as usize).min(self.max_neighbors_per_document);

        // Fetch all entries with a single read instead of one read per neighbor.
        let mut payload = vec![0u8; count * WORD];
        file.read_exact(&mut payload)
            .expect("failed to read neighbor entries");
        payload
            .chunks_exact(WORD)
            .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    /// Overwrites the neighbor list stored for `node_index` at `layer`.
    ///
    /// The count and all entries are written with one `write_all` call.
    ///
    /// # Panics
    /// Panics if `neighbors` holds more than `max_neighbors_per_document`
    /// entries, if `node_index` is not a known node, or if the graph file
    /// cannot be opened or written.
    pub fn set_neighbors(&self, node_index: u32, layer: usize, neighbors: &[u32]) {
        assert!(neighbors.len() <= self.max_neighbors_per_document);
        let mut file = OpenOptions::new()
            .write(true)
            .open(&self.graph_path)
            .expect("failed to open HNSW graph file for writing");
        let offset = layer_offset(
            self.node_offsets[node_index as usize],
            layer,
            self.max_neighbors_per_document,
        );
        file.seek(SeekFrom::Start(offset))
            .expect("failed to seek to neighbor slot");

        // Assemble count + entries in memory so the slot is updated by a
        // single write rather than one write per neighbor.
        let mut buf = Vec::with_capacity((neighbors.len() + 1) * std::mem::size_of::<u32>());
        buf.extend_from_slice(&(neighbors.len() as u32).to_le_bytes());
        for &n in neighbors {
            buf.extend_from_slice(&n.to_le_bytes());
        }
        file.write_all(&buf)
            .expect("failed to write neighbor slot");
    }

    /// Deletes both backing files (errors from the deletion are ignored) and
    /// clears the in-memory graph state. The configured limits are kept.
    pub fn wipe(&mut self) {
        let _ = std::fs::remove_file(&self.path);
        let _ = std::fs::remove_file(&self.graph_path);
        self.highest_layer = 0;
        self.entry_point = None;
        self.id_to_index.clear();
        self.node_offsets.clear();
        self.index_to_id.clear();
    }

    /// Registers `record` as a new node, allocates its slots in the graph
    /// file, and persists the metadata.
    ///
    /// Returns `(node_index, layer, prev_entry_point, prev_highest_layer)`,
    /// where the last two are the entry point and highest layer as they were
    /// before this insertion.
    ///
    /// Inserting an id that is already present allocates a fresh node; the id
    /// then maps to the newest node.
    ///
    /// # Panics
    /// Panics if the node count no longer fits in `u32` or if any file
    /// operation fails.
    pub fn insert(&mut self, record: &Record) -> (u32, usize, Option<u32>, usize) {
        // Derive the index from the node table, not from `id_to_index`: the
        // map does not grow when an id is re-inserted, which would hand out
        // an index that is already in use.
        assert!(
            self.node_offsets.len() < u32::MAX as usize,
            "HNSW node count exceeds u32 range"
        );
        let node_index = self.node_offsets.len() as u32;
        let layer = assign_layer(self.max_layers);

        // Allocate on disk before touching in-memory state, so a failed
        // allocation does not leave the tables half-updated.
        let offset = self.init_node(layer);

        let prev_entry_point = self.entry_point;
        let prev_highest_layer = self.highest_layer;

        self.id_to_index.insert(record.id.clone(), node_index);
        self.index_to_id.push(record.id.clone());
        self.node_offsets.push(offset);
        if self.entry_point.is_none() || layer > self.highest_layer {
            self.highest_layer = layer;
            self.entry_point = Some(node_index);
        }

        self.save();
        (node_index, layer, prev_entry_point, prev_highest_layer)
    }
}