use std::collections::{HashMap, HashSet};
use crate::bm25;
use crate::hnsw::index::HNSW;
use crate::index::{load_index, save_index, Index};
use crate::similarities::cosine_similarity;
use crate::storage::{append_record, retrieve_record};
use crate::types::Record;
/// Initial value of `Database::checkpoint_every`: the number of unsaved
/// inserts after which `insert_raw` calls `save_all`.
const DEFAULT_CHECKPOINT_EVERY: usize = 100_000;
/// Record store that pairs a data file with three lookup structures: an
/// id-to-offset index, a BM25 text index and an HNSW graph.
pub struct Database {
/// Path of the data file written by `append_record` and read by `retrieve_record`.
pub data_path: String,
/// Path handed to `load_index` / `save_index` for the id-to-offset index.
pub index_path: String,
/// Path the BM25 index is loaded from and saved to.
pub bm25_index_path: String,
/// Maps a record id to the offset returned by `append_record`.
pub index: Index,
/// BM25 index fed with the tokens of every inserted record.
pub bm25_index: bm25::index::Bm25Index,
/// Tokenizer applied to inserted text and to text queries.
pub bm25_tokenizer: bm25::tokenizer::Tokenizer,
/// HNSW graph used by `search_hnsw`; nodes are addressed by `u32` index.
pub hnsw: HNSW,
/// Graph path that was handed to `HNSW::load` at construction.
pub graph_path: String,
/// Number of unsaved inserts after which `insert_raw` calls `save_all`.
pub checkpoint_every: usize,
// Inserts performed since the last `save_all` (also reset by `wipe`).
dirty_count: usize,
}
impl Database {
/// Opens a database backed by the given files.
///
/// The parent directory of `data_path` is created when missing. The id
/// index, the BM25 index and the HNSW graph are then loaded from
/// `index_path`, `bm25_index_path` and `hnsw_path` / `graph_path`.
///
/// # Panics
///
/// Panics if the parent directory of `data_path` cannot be created.
pub fn new(data_path: &str, index_path: &str, bm25_index_path: &str, hnsw_path: &str, graph_path: &str) -> Database {
    let parent_dir = std::path::Path::new(data_path).parent();
    if let Some(dir) = parent_dir {
        std::fs::create_dir_all(dir).unwrap();
    }

    // Struct fields are evaluated in the order written, so the loads run in
    // the sequence: id index, BM25 index, tokenizer, HNSW graph.
    Database {
        index: load_index(index_path),
        bm25_index: bm25::index::Bm25Index::load(bm25_index_path),
        bm25_tokenizer: bm25::tokenizer::Tokenizer::new(),
        hnsw: HNSW::load(hnsw_path, graph_path),
        data_path: String::from(data_path),
        index_path: String::from(index_path),
        bm25_index_path: String::from(bm25_index_path),
        graph_path: String::from(graph_path),
        checkpoint_every: DEFAULT_CHECKPOINT_EVERY,
        dirty_count: 0,
    }
}
/// Persists the id index, the BM25 index and the HNSW graph, then resets the
/// count of unsaved inserts.
pub fn save_all(&mut self) {
// Id-to-offset index goes to `index_path`.
save_index(&self.index_path, &self.index);
// BM25 index goes to `bm25_index_path`.
self.bm25_index.save(&self.bm25_index_path);
// The HNSW graph takes no path here; presumably it remembers the paths it was loaded from.
self.hnsw.save();
// Everything is on disk, so nothing is pending any more.
self.dirty_count = 0;
}
/// Stores a new record and registers it in every index.
///
/// The record is appended to the data file, its tokens are added to the
/// BM25 index, its offset is stored in the id index and a node is added to
/// the HNSW graph and wired to its neighbours. Once `checkpoint_every`
/// inserts have accumulated since the last save, `save_all` is called.
///
/// * `vector` - embedding stored with the record.
/// * `text` - text stored as the record's metadata and indexed for BM25.
/// * `id` - optional caller-chosen id, forwarded to `Record::new`.
pub fn insert_raw(&mut self, vector: Vec<f32>, text: &str, id: Option<&str>) {
    let record = Record::new(vector, Some(text.to_owned()), id.map(String::from));
    let tokens = self.bm25_tokenizer.tokenize(text);

    // Write the record first so the indexes only ever point at stored data.
    let offset = append_record(&record, &self.data_path);
    self.bm25_index.index_record(&record.id, &tokens);
    self.index.insert(record.id.clone(), offset);

    // Register the graph node, then connect it starting from the entry point
    // that was current before this insert.
    let placement = self.hnsw.insert(&record);
    let (node_index, assigned_layer, prev_entry, prev_highest) = placement;
    self.make_connections(&record.vector, node_index, assigned_layer, prev_entry, prev_highest);

    self.dirty_count += 1;
    if self.dirty_count < self.checkpoint_every {
        return;
    }
    self.save_all();
}
/// Wires a freshly inserted HNSW node to its neighbours and adds the
/// reverse links.
///
/// For every layer up to `assigned_layer`, the best candidates returned by
/// `search_layers` (at most `max_neighbors_per_document`) become the new
/// node's neighbours. Each of those neighbours gets a back-link; when that
/// pushes a neighbour over the limit, its list is pruned to the entries
/// most similar to the neighbour itself.
///
/// Nodes whose record is no longer in the id index are skipped and dropped
/// from pruned lists instead of panicking. `delete` removes ids from the
/// index but leaves the graph untouched, so such stale links can exist.
///
/// * `vector` - embedding of the new node.
/// * `node_index` - graph index of the new node.
/// * `assigned_layer` - highest layer the new node lives on.
/// * `prev_entry` / `prev_highest` - entry point and top layer from before
///   the insert; `None` means there is nothing to connect to.
fn make_connections(&mut self, vector: &[f32], node_index: u32, assigned_layer: usize, prev_entry: Option<u32>, prev_highest: usize) {
    let Some(entry) = prev_entry else { return };
    let max_neighbors = self.hnsw.max_neighbors_per_document;
    let candidates_by_layer = self.search_layers(vector, max_neighbors * 2, Some((entry, prev_highest)));

    for (&layer, candidates) in &candidates_by_layer {
        if layer > assigned_layer {
            continue;
        }
        let neighbors: Vec<u32> = candidates
            .iter()
            .copied()
            .filter(|&n| n != node_index)
            .take(max_neighbors)
            .collect();
        self.hnsw.set_neighbors(node_index, layer, &neighbors);

        for &neighbor in &neighbors {
            let mut existing = self.hnsw.get_neighbors(neighbor, layer);
            if existing.contains(&node_index) {
                continue;
            }
            existing.push(node_index);

            if existing.len() > max_neighbors {
                // A neighbour without a stored record cannot be ranked
                // against; leave its list untouched rather than panic.
                let Some(neighbor_emb) = self.stored_vector(neighbor) else {
                    continue;
                };
                let mut scored: Vec<(f32, u32)> = existing
                    .iter()
                    .filter_map(|&n| {
                        // Links to deleted records are dropped here.
                        let vec = self.stored_vector(n)?;
                        Some((cosine_similarity(&neighbor_emb, &vec), n))
                    })
                    .collect();
                // Best first; NaN scores rank last so the order stays total.
                scored.sort_by(|a, b| {
                    b.0.partial_cmp(&a.0)
                        .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
                });
                existing = scored
                    .into_iter()
                    .take(max_neighbors)
                    .map(|(_, n)| n)
                    .collect();
            }
            self.hnsw.set_neighbors(neighbor, layer, &existing);
        }
    }
}

/// Loads the stored embedding of a graph node, or `None` when the node index
/// is unknown or its record is no longer present in the id index.
fn stored_vector(&self, node_index: u32) -> Option<Vec<f32>> {
    let uuid = self.hnsw.index_to_id.get(node_index as usize)?;
    let offset = self.index.get(uuid)?;
    Some(retrieve_record(*offset, &self.data_path).vector)
}
/// Returns up to `k` records matching `query` according to the BM25 index.
///
/// The query is tokenized with the same tokenizer used at insert time.
/// Scored ids that are no longer present in the id index are skipped
/// *before* the result is limited to `k`, so stale hits do not consume
/// result slots.
///
/// NOTE(review): this assumes `Bm25Index::score` yields hits best-first;
/// confirm against its implementation.
pub fn text_search(&self, query: &str, k: usize) -> Vec<Record> {
    let tokens = self.bm25_tokenizer.tokenize(query);
    let scores = self.bm25_index.score(&tokens);
    scores
        .into_iter()
        .filter_map(|(doc_id, _)| self.index.get(&doc_id))
        // The chain is lazy, so at most `k` records are read from disk.
        .take(k)
        .map(|offset| retrieve_record(*offset, &self.data_path))
        .collect()
}
/// Removes the record with the given id from the id index and the BM25
/// index, then saves both indexes.
///
/// Returns `false` when the id is unknown and `true` otherwise. The data
/// file and the HNSW graph are not modified by this method.
pub fn delete(&mut self, id: &str) -> bool {
    let offset = match self.index.get(id) {
        Some(&found) => found,
        None => return false,
    };

    // The stored text is re-tokenized so the BM25 index can remove exactly
    // the terms that were added at insert time.
    let record = retrieve_record(offset, &self.data_path);
    let text = match record.metadata.as_deref() {
        Some(stored) => stored,
        None => "",
    };
    let tokens = self.bm25_tokenizer.tokenize(text);

    self.bm25_index.remove_record(id, &tokens);
    self.index.remove(id);

    save_index(&self.index_path, &self.index);
    self.bm25_index.save(&self.bm25_index_path);
    true
}
/// Clears all in-memory state and deletes the data, id-index and BM25-index
/// files.
///
/// File removal is best-effort: errors (for example a file that does not
/// exist) are ignored.
pub fn wipe(&mut self) {
    self.index.clear();
    self.bm25_index = bm25::index::Bm25Index::new();
    self.hnsw.wipe();
    self.dirty_count = 0;

    for path in [&self.data_path, &self.index_path, &self.bm25_index_path] {
        let _ = std::fs::remove_file(path);
    }
}
/// Exhaustive nearest-neighbour search that also returns the scores.
///
/// Every record in the id index is loaded and compared with `query_vector`
/// using `cosine_similarity`. The `k` best `(score, record)` pairs are
/// returned, best first.
///
/// Records whose score is NaN are ranked last instead of causing a panic
/// during sorting.
pub fn search_scored(&self, query_vector: &[f32], k: usize) -> Vec<(f32, Record)> {
    let mut results: Vec<(f32, Record)> = self
        .index
        .iter()
        .map(|(_id, offset)| {
            let record = retrieve_record(*offset, &self.data_path);
            (cosine_similarity(&record.vector, query_vector), record)
        })
        .collect();
    // Descending by score; NaN compares as "worse than anything" so the
    // comparator is a total order and never unwraps a `None`.
    results.sort_by(|a, b| {
        b.0.partial_cmp(&a.0)
            .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
    });
    results.truncate(k);
    results
}
/// Scores a graph node against `query_vector` with `cosine_similarity`.
///
/// Returns `None` when `node_index` has no entry in `index_to_id` or when
/// its id is not present in the id index.
fn score_node(&self, node_index: u32, query_vector: &[f32]) -> Option<f32> {
    self.hnsw
        .index_to_id
        .get(node_index as usize)
        .and_then(|uuid| self.index.get(uuid))
        .map(|offset| {
            let record = retrieve_record(*offset, &self.data_path);
            cosine_similarity(&record.vector, query_vector)
        })
}
/// Approximate nearest-neighbour search over the HNSW graph.
///
/// The search first walks greedily from the graph's entry point down through
/// layers `highest_layer..=1`, always moving to the best-scoring neighbour
/// that beats the current node. From the node reached it then runs a
/// best-first expansion on layer 0, keeping a result list that is trimmed
/// to `ef` entries after each insertion.
///
/// Returns the records of the result list in the order the list is kept
/// (highest score first). Nodes that cannot be scored (see `score_node`)
/// are ignored. An empty vector is returned when the graph is empty, has no
/// entry point, or the starting node of the layer-0 phase cannot be scored.
pub fn search_hnsw(&self, query_vector: &[f32], ef: usize) -> Vec<Record> {
    let graph = &self.hnsw;
    if graph.node_offsets.is_empty() {
        return Vec::new();
    }
    let Some(entry) = graph.entry_point else {
        return Vec::new();
    };
    let score_of = |node: u32| self.score_node(node, query_vector);

    // Phase 1: greedy descent through the upper layers.
    let mut current = entry;
    for layer in (1..=graph.highest_layer).rev() {
        loop {
            // An unscorable current node loses against any scorable neighbour.
            let current_score = score_of(current).unwrap_or(f32::NEG_INFINITY);
            let improvement = graph
                .get_neighbors(current, layer)
                .into_iter()
                .filter_map(|n| score_of(n).map(|s| (s, n)))
                .filter(|&(s, _)| s > current_score)
                .max_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
            let Some((_, next)) = improvement else { break };
            current = next;
        }
    }

    // Phase 2: best-first expansion on layer 0.
    let Some(entrypoint_score) = score_of(current) else {
        return Vec::new();
    };
    let mut visited: HashSet<u32> = HashSet::new();
    visited.insert(current);
    // Both lists are kept sorted by descending score.
    let mut candidates: Vec<(f32, u32)> = vec![(entrypoint_score, current)];
    let mut results: Vec<(f32, u32)> = vec![(entrypoint_score, current)];

    while let Some(&(c_score, c_node)) = candidates.first() {
        let worst_result = results.last().map_or(f32::NEG_INFINITY, |&(s, _)| s);
        // The best open candidate is already worse than everything kept.
        if c_score < worst_result {
            break;
        }
        candidates.remove(0);

        for neighbor in graph.get_neighbors(c_node, 0) {
            if !visited.insert(neighbor) {
                continue;
            }
            let Some(n_score) = score_of(neighbor) else {
                continue;
            };
            let worst = results.last().map_or(f32::NEG_INFINITY, |&(s, _)| s);
            if n_score > worst || results.len() < ef {
                let at = candidates.partition_point(|&(s, _)| s > n_score);
                candidates.insert(at, (n_score, neighbor));
                let at = results.partition_point(|&(s, _)| s > n_score);
                results.insert(at, (n_score, neighbor));
                if results.len() > ef {
                    results.pop();
                }
            }
        }
    }

    results
        .into_iter()
        .filter_map(|(_, node)| {
            let uuid = graph.index_to_id.get(node as usize)?;
            let offset = self.index.get(uuid)?;
            Some(retrieve_record(*offset, &self.data_path))
        })
        .collect()
}
fn search_layers(&self, query_vector: &[f32], candidates_per_layer: usize, entry_override: Option<(u32, usize)>) -> HashMap<usize, Vec<u32>> {
if self.hnsw.node_offsets.is_empty() {
return HashMap::new();
}
let (entry_node, top_layer) = match entry_override {
Some(pair) => pair,
None => match self.hnsw.entry_point {
Some(ep) => (ep, self.hnsw.highest_layer),
None => return HashMap::new(),
},
};
let mut current_candidates: Vec<u32> = vec![entry_node];
let mut layer_candidates: HashMap<usize, Vec<u32>> = HashMap::new();
for layer in (0..=top_layer).rev() {
let mut seen: HashSet<u32> = current_candidates.iter().cloned().collect();
for &candidate in ¤t_candidates {
for neighbor in self.hnsw.get_neighbors(candidate, layer) {
seen.insert(neighbor);
}
}
let mut scored: Vec<(f32, u32)> = seen.iter()
.filter_map(|&node_index| {
let uuid = &self.hnsw.index_to_id[node_index as usize];
let data_offset = self.index.get(uuid)?;
let record = retrieve_record(*data_offset, &self.data_path);
let score = cosine_similarity(&record.vector, query_vector);
Some((score, node_index))
})
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
scored.truncate(candidates_per_layer);
layer_candidates.insert(layer, scored.iter().map(|&(_, idx)| idx).collect());
let next_size = (candidates_per_layer / 2).max(1);
current_candidates = scored.iter().take(next_size).map(|&(_, idx)| idx).collect();
}
layer_candidates
}
/// Exhaustive nearest-neighbour search.
///
/// Every record in the id index is loaded and compared with `query_vector`
/// using `cosine_similarity`; the `k` most similar records are returned,
/// best first.
///
/// Records whose similarity is NaN are ranked last instead of causing a
/// panic during sorting.
pub fn search(&self, query_vector: &[f32], k: usize) -> Vec<Record> {
    let mut results: Vec<(f32, Record)> = self
        .index
        .iter()
        .map(|(_id, offset)| {
            let record = retrieve_record(*offset, &self.data_path);
            (cosine_similarity(&record.vector, query_vector), record)
        })
        .collect();
    // Descending by similarity; NaN compares as "worse than anything" so
    // the comparator is a total order and never unwraps a `None`.
    results.sort_by(|a, b| {
        b.0.partial_cmp(&a.0)
            .unwrap_or_else(|| a.0.is_nan().cmp(&b.0.is_nan()))
    });
    results.into_iter().take(k).map(|(_, record)| record).collect()
}
}
impl Drop for Database {
    /// Saves all indexes on drop when inserts have happened since the last
    /// save; does nothing otherwise.
    fn drop(&mut self) {
        if self.dirty_count == 0 {
            return;
        }
        self.save_all();
    }
}