use std::sync::Arc;
use crate::error::{LaurusError, Result};
use crate::storage::Storage;
use crate::vector::core::vector::Vector;
use crate::vector::index::HnswIndexConfig;
use crate::vector::index::field::LegacyVectorFieldWriter;
use crate::vector::index::hnsw::graph::HnswGraph;
use crate::vector::writer::{VectorIndexWriter, VectorIndexWriterConfig};
use parking_lot::RwLock;
use rand::Rng;
use rayon::prelude::*;
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
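/// Read-only neighbor access shared by the frozen `HnswGraph` and the
/// lock-based `ConcurrentHnswGraph`, so `search_layer` can run against either.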
trait GraphView {
fn get_neighbors_view(&self, doc_id: u64, level: usize) -> Option<Vec<u64>>;
}
impl GraphView for HnswGraph {
fn get_neighbors_view(&self, doc_id: u64, level: usize) -> Option<Vec<u64>> {
self.get_neighbors(doc_id, level).cloned()
}
}
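/// HNSW adjacency lists guarded per node and per layer by a
/// `parking_lot::RwLock`, letting rayon workers update disjoint
/// neighborhoods concurrently during construction. The node set itself is
/// fixed before the parallel phase begins; only the neighbor lists mutate.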
struct ConcurrentHnswGraph {
max_level: usize,
nodes: HashMap<u64, Vec<RwLock<Vec<u64>>>>,
}
impl ConcurrentHnswGraph {
fn new(nodes_with_levels: Vec<(u64, usize)>, max_level: usize) -> Self {
let mut nodes = HashMap::new();
for (doc_id, level) in nodes_with_levels {
let mut layers = Vec::with_capacity(level + 1);
for _ in 0..=level {
layers.push(RwLock::new(Vec::new()));
}
nodes.insert(doc_id, layers);
}
Self { max_level, nodes }
}
fn set_neighbors(&self, doc_id: u64, level: usize, new_neighbors: Vec<u64>) {
if let Some(layers) = self.nodes.get(&doc_id)
&& let Some(lock) = layers.get(level)
{
*lock.write() = new_neighbors;
}
}
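/// Adds `neighbor_id` to `doc_id`'s list at `level`, pruning back to
/// `max_conn` entries when the list overflows. The write lock is dropped
/// while `prune_neighbors` recomputes distances, so an edge added by another
/// thread in that window can be overwritten by the pruned snapshot; the
/// builder tolerates this bounded staleness in exchange for not holding the
/// lock during distance computation.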
fn add_neighbor_with_pruning(
&self,
doc_id: u64,
level: usize,
neighbor_id: u64,
max_conn: usize,
writer: &HnswIndexWriter,
) -> Result<()> {
if let Some(layers) = self.nodes.get(&doc_id)
&& let Some(lock) = layers.get(level)
{
let needs_pruning = {
let mut neighbors = lock.write();
if !neighbors.contains(&neighbor_id) {
neighbors.push(neighbor_id);
}
if neighbors.len() > max_conn {
Some(neighbors.clone())
} else {
None
}
};
if let Some(snapshot) = needs_pruning {
let pruned = writer.prune_neighbors(doc_id, snapshot, max_conn)?;
*lock.write() = pruned;
}
}
Ok(())
}
fn get_neighbors_raw(&self, doc_id: u64, level: usize) -> Option<Vec<u64>> {
self.nodes
.get(&doc_id)
.and_then(|layers| layers.get(level).map(|lock| lock.read().clone()))
}
fn from_hnsw_graph(graph: HnswGraph, extended_max_level: usize) -> Self {
let mut nodes = HashMap::with_capacity(graph.node_count());
for (doc_id, layered_neighbors) in graph.into_iter_nodes() {
let mut layers = Vec::with_capacity(layered_neighbors.len());
for neighbors in layered_neighbors {
layers.push(RwLock::new(neighbors));
}
nodes.insert(doc_id, layers);
}
Self {
max_level: extended_max_level,
nodes,
}
}
fn add_nodes(&mut self, nodes_with_levels: Vec<(u64, usize)>) {
for (doc_id, level) in nodes_with_levels {
if self.nodes.contains_key(&doc_id) {
continue;
}
let mut layers = Vec::with_capacity(level + 1);
for _ in 0..=level {
layers.push(RwLock::new(Vec::new()));
}
self.nodes.insert(doc_id, layers);
}
}
}
impl GraphView for ConcurrentHnswGraph {
fn get_neighbors_view(&self, doc_id: u64, level: usize) -> Option<Vec<u64>> {
self.get_neighbors_raw(doc_id, level)
}
}
#[derive(Debug)]
pub struct HnswIndexWriter {
index_config: HnswIndexConfig,
writer_config: VectorIndexWriterConfig,
storage: Option<Arc<dyn Storage>>,
path: String,
_ml: f64,
vectors: Vec<(u64, String, Vector)>,
doc_id_map: HashMap<u64, usize>,
#[allow(dead_code)]
levels: Vec<Vec<u64>>,
entry_point: Option<u64>,
graph: Option<HnswGraph>,
is_finalized: bool,
total_vectors_to_add: Option<usize>,
next_vec_id: u64,
}
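/// Ordered ascending by distance, so a `BinaryHeap<Candidate>` acts as a
/// max-heap whose `peek` yields the current furthest (worst) result.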
#[derive(Debug, Clone, PartialEq)]
struct Candidate {
id: u64,
distance: f32,
similarity: f32,
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Candidate {
fn cmp(&self, other: &Self) -> Ordering {
self.distance
.partial_cmp(&other.distance)
.unwrap_or(Ordering::Equal)
}
}
impl HnswIndexWriter {
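/// Creates an in-memory writer with no storage attached; use
/// [`HnswIndexWriter::with_storage`] to persist. A minimal usage sketch,
/// assuming `HnswIndexConfig` and `VectorIndexWriterConfig` implement
/// `Default` (not verified here):
///
/// ```rust,ignore
/// let mut writer = HnswIndexWriter::new(
///     HnswIndexConfig { dimension: 3, ..Default::default() },
///     VectorIndexWriterConfig::default(),
///     "my_index",
/// )?;
/// writer.build(vec![(0, "body".into(), Vector::new(vec![0.1, 0.2, 0.3]))])?;
/// writer.finalize()?;
/// ```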
pub fn new(
index_config: HnswIndexConfig,
writer_config: VectorIndexWriterConfig,
path: impl Into<String>,
) -> Result<Self> {
if index_config.m < 2 {
return Err(LaurusError::InvalidOperation(
"HNSW parameter m must be >= 2".to_string(),
));
}
let max_level = Self::calculate_max_level(index_config.m, index_config.ef_construction);
let _ml = 1.0 / (index_config.m as f64).ln();
Ok(Self {
index_config,
writer_config,
storage: None,
path: path.into(),
_ml,
levels: vec![Vec::new(); max_level + 1],
entry_point: None,
vectors: Vec::new(),
doc_id_map: HashMap::new(),
graph: None,
is_finalized: false,
total_vectors_to_add: None,
next_vec_id: 0,
})
}
pub fn with_storage(
index_config: HnswIndexConfig,
writer_config: VectorIndexWriterConfig,
path: impl Into<String>,
storage: Arc<dyn Storage>,
) -> Result<Self> {
let path = path.into();
let file_name = format!("{}.hnsw", path);
if storage.file_exists(&file_name) {
return Self::load(index_config, writer_config, storage, &path);
}
if index_config.m < 2 {
return Err(LaurusError::InvalidOperation(
"HNSW parameter m must be >= 2".to_string(),
));
}
let max_level = Self::calculate_max_level(index_config.m, index_config.ef_construction);
let _ml = 1.0 / (index_config.m as f64).ln();
Ok(Self {
index_config,
writer_config,
storage: Some(storage),
path,
_ml,
levels: vec![Vec::new(); max_level + 1],
entry_point: None,
vectors: Vec::new(),
doc_id_map: HashMap::new(),
graph: None,
is_finalized: false,
total_vectors_to_add: None,
next_vec_id: 0,
})
}
pub fn into_field_writer(self, field_name: impl Into<String>) -> LegacyVectorFieldWriter<Self> {
LegacyVectorFieldWriter::new(field_name, self)
}
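/// Reads an index previously produced by [`VectorIndexWriter::write`]. All
/// integers are little-endian. Layout: a header of
/// `[num_vectors: u64][dimension: u32][m: u32][ef_construction: u32]`,
/// then per vector `[doc_id: u64][field_name_len: u32][field_name bytes]`
/// followed by `dimension` f32 values, then an optional graph section
/// introduced by a one-byte flag.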
pub fn load(
index_config: HnswIndexConfig,
writer_config: VectorIndexWriterConfig,
storage: Arc<dyn Storage>,
path: &str,
) -> Result<Self> {
use std::io::Read;
let file_name = format!("{}.hnsw", path);
let mut input = storage.open_input(&file_name)?;
let mut num_vectors_buf = [0u8; 8];
input.read_exact(&mut num_vectors_buf)?;
let num_vectors = u64::from_le_bytes(num_vectors_buf) as usize;
let mut dimension_buf = [0u8; 4];
input.read_exact(&mut dimension_buf)?;
let dimension = u32::from_le_bytes(dimension_buf) as usize;
let mut m_buf = [0u8; 4];
input.read_exact(&mut m_buf)?;
let _m = u32::from_le_bytes(m_buf) as usize;
let mut ef_construction_buf = [0u8; 4];
input.read_exact(&mut ef_construction_buf)?;
let _ef_construction = u32::from_le_bytes(ef_construction_buf) as usize;
if dimension != index_config.dimension {
return Err(LaurusError::InvalidOperation(format!(
"Dimension mismatch: expected {}, found {}",
index_config.dimension, dimension
)));
}
let mut vectors = Vec::with_capacity(num_vectors);
for _ in 0..num_vectors {
let mut doc_id_buf = [0u8; 8];
input.read_exact(&mut doc_id_buf)?;
let doc_id = u64::from_le_bytes(doc_id_buf);
let mut field_name_len_buf = [0u8; 4];
input.read_exact(&mut field_name_len_buf)?;
let field_name_len = u32::from_le_bytes(field_name_len_buf) as usize;
let mut field_name_buf = vec![0u8; field_name_len];
input.read_exact(&mut field_name_buf)?;
let field_name = String::from_utf8(field_name_buf).map_err(|e| {
LaurusError::InvalidOperation(format!("Invalid UTF-8 in field name: {}", e))
})?;
let mut values = vec![0.0f32; dimension];
for value in &mut values {
let mut value_buf = [0u8; 4];
input.read_exact(&mut value_buf)?;
*value = f32::from_le_bytes(value_buf);
}
vectors.push((doc_id, field_name, Vector::new(values)));
}
let mut doc_id_map = HashMap::new();
for (i, (doc_id, _, _)) in vectors.iter().enumerate() {
doc_id_map.insert(*doc_id, i);
}
let max_id = vectors.iter().map(|(id, _, _)| *id).max().unwrap_or(0);
let next_vec_id = if num_vectors > 0 { max_id + 1 } else { 0 };
if index_config.m < 2 {
return Err(LaurusError::InvalidOperation(
"HNSW parameter m must be >= 2".to_string(),
));
}
let max_level = Self::calculate_max_level(index_config.m, index_config.ef_construction);
let _ml = 1.0 / (index_config.m as f64).ln();
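// Tolerate EOF on the flag byte (e.g. files written before the graph
// section existed) and treat it as "no serialized graph".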
let mut has_graph_buf = [0u8; 1];
let graph = if input.read_exact(&mut has_graph_buf).is_ok() {
if has_graph_buf[0] == 1 {
let mut entry_point_buf = [0u8; 8];
input.read_exact(&mut entry_point_buf)?;
let entry_point_raw = u64::from_le_bytes(entry_point_buf);
let entry_point = if entry_point_raw == u64::MAX {
None
} else {
Some(entry_point_raw)
};
let mut max_level_buf = [0u8; 4];
input.read_exact(&mut max_level_buf)?;
let max_level = u32::from_le_bytes(max_level_buf) as usize;
let mut node_count_buf = [0u8; 8];
input.read_exact(&mut node_count_buf)?;
let node_count = u64::from_le_bytes(node_count_buf) as usize;
let mut nodes = HashMap::with_capacity(node_count);
for _ in 0..node_count {
let mut doc_id_buf = [0u8; 8];
input.read_exact(&mut doc_id_buf)?;
let doc_id = u64::from_le_bytes(doc_id_buf);
let mut layer_count_buf = [0u8; 4];
input.read_exact(&mut layer_count_buf)?;
let layer_count = u32::from_le_bytes(layer_count_buf) as usize;
let mut layers = Vec::with_capacity(layer_count);
for _ in 0..layer_count {
let mut neighbor_count_buf = [0u8; 4];
input.read_exact(&mut neighbor_count_buf)?;
let neighbor_count = u32::from_le_bytes(neighbor_count_buf) as usize;
let mut neighbors = Vec::with_capacity(neighbor_count);
for _ in 0..neighbor_count {
let mut neighbor_buf = [0u8; 8];
input.read_exact(&mut neighbor_buf)?;
neighbors.push(u64::from_le_bytes(neighbor_buf));
}
layers.push(neighbors);
}
nodes.insert(doc_id, layers);
}
Some(HnswGraph::new(
entry_point,
max_level,
nodes,
index_config.m,
index_config.m,
index_config.m * 2,
index_config.ef_construction,
_ml,
))
} else {
None
}
} else {
None
};
Ok(Self {
index_config,
writer_config,
storage: Some(storage),
path: path.to_string(),
_ml,
levels: vec![Vec::new(); max_level + 1],
entry_point: graph.as_ref().and_then(|g| g.entry_point),
vectors,
is_finalized: false,
total_vectors_to_add: Some(num_vectors),
next_vec_id,
doc_id_map,
graph,
})
}
pub fn with_hnsw_params(mut self, m: usize, ef_construction: usize) -> Self {
self.index_config.m = m;
self.index_config.ef_construction = ef_construction;
// Keep the level-sampling factor in sync with the new `m`.
self._ml = 1.0 / (m as f64).ln();
self
}
pub fn set_expected_vector_count(&mut self, count: usize) {
self.total_vectors_to_add = Some(count);
}
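/// Samples a node's top layer geometrically: keep climbing with probability
/// `_ml = 1 / ln(m)` per step, hard-capped at 16 to match
/// `calculate_max_level`.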
fn select_layer(&self) -> usize {
let mut layer = 0;
let mut rng = rand::rng();
while rng.random_range(0.0..1.0) < self._ml && layer < 16 {
layer += 1;
}
layer
}
/// The layer cap is currently fixed at 16 regardless of the supplied
/// parameters; `select_layer` enforces the same bound.
fn calculate_max_level(_m: usize, _ef_construction: usize) -> usize {
16
}
fn validate_vectors(&self, vectors: &[(u64, String, Vector)]) -> Result<()> {
if vectors.is_empty() {
return Ok(());
}
for (doc_id, _, vector) in vectors {
if vector.dimension() != self.index_config.dimension {
return Err(LaurusError::InvalidOperation(format!(
"Vector {} has dimension {}, expected {}",
doc_id,
vector.dimension(),
self.index_config.dimension
)));
}
if !vector.is_valid() {
return Err(LaurusError::InvalidOperation(format!(
"Vector {doc_id} contains invalid values (NaN or infinity)"
)));
}
}
Ok(())
}
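// Normalizes in place; batches larger than 100 vectors are processed in
// parallel via rayon when `parallel_build` is set, smaller ones serially.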
fn normalize_vectors_internal(
index_config: &HnswIndexConfig,
writer_config: &VectorIndexWriterConfig,
vectors: &mut Vec<(u64, String, Vector)>,
) {
if !index_config.normalize_vectors {
return;
}
if writer_config.parallel_build && vectors.len() > 100 {
vectors.par_iter_mut().for_each(|(_, _, vector)| {
vector.normalize();
});
} else {
for (_, _, vector) in vectors {
vector.normalize();
}
}
}
fn rebuild_doc_id_map(&mut self) {
self.doc_id_map.clear();
for (idx, (doc_id, _, _)) in self.vectors.iter().enumerate() {
self.doc_id_map.insert(*doc_id, idx);
}
}
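/// Builds (or incrementally extends) the HNSW graph in three phases:
/// 1. assign a random top layer to every new doc id (sorted for determinism),
/// 2. insert all new nodes in parallel over a `ConcurrentHnswGraph`,
///    descending greedily from the entry point and wiring reciprocal edges
///    with pruning,
/// 3. freeze the result back into an immutable `HnswGraph`.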
fn build_hnsw_graph(&mut self) -> Result<()> {
let count = self.vectors.len();
if count == 0 {
return Ok(());
}
self.rebuild_doc_id_map();
let m = self.index_config.m;
let m_max = m;
let m_max_0 = m * 2;
let ef_construction = self.index_config.ef_construction;
let mut new_node_levels = Vec::new();
let mut new_doc_ids_in_order = Vec::new();
let (graph, entry_point, max_level, search_entry_point) =
if let Some(existing_graph) = self.graph.take() {
for (doc_id, _, _) in &self.vectors {
if !existing_graph.contains_node(doc_id) {
new_doc_ids_in_order.push(*doc_id);
}
}
new_doc_ids_in_order.sort_unstable();
for doc_id in &new_doc_ids_in_order {
let level = self.select_layer();
new_node_levels.push((*doc_id, level));
}
let current_max_level = existing_graph.max_level;
let new_max_level = new_node_levels.iter().map(|(_, l)| *l).max().unwrap_or(0);
let total_max_level = current_max_level.max(new_max_level);
let old_ep = existing_graph.entry_point;
let mut ep = old_ep;
if new_max_level > current_max_level {
ep = new_node_levels
.iter()
.find(|(_, l)| *l == total_max_level)
.map(|(id, _)| *id)
.or(ep);
}
let mut concurrent_graph =
ConcurrentHnswGraph::from_hnsw_graph(existing_graph, total_max_level);
concurrent_graph.add_nodes(new_node_levels.clone());
let search_ep = old_ep.or(ep);
(concurrent_graph, ep, total_max_level, search_ep)
} else {
let mut doc_ids_in_order: Vec<u64> =
self.vectors.iter().map(|(id, _, _)| *id).collect();
doc_ids_in_order.sort_unstable();
for doc_id in &doc_ids_in_order {
let level = self.select_layer();
new_node_levels.push((*doc_id, level));
}
let max_level = new_node_levels.iter().map(|(_, l)| *l).max().unwrap_or(0);
let ep = new_node_levels
.iter()
.find(|(_, l)| *l == max_level)
.map(|(id, _)| *id);
new_doc_ids_in_order = doc_ids_in_order;
let concurrent_graph = ConcurrentHnswGraph::new(new_node_levels.clone(), max_level);
(concurrent_graph, ep, max_level, ep)
};
let writer_ref = &*self;
new_doc_ids_in_order
.into_par_iter()
.try_for_each(|doc_id| -> Result<()> {
let doc_vector_idx = *writer_ref.doc_id_map.get(&doc_id).ok_or_else(|| {
LaurusError::internal(format!("Doc ID {} not found in doc_id_map", doc_id))
})?;
let vector = &writer_ref.vectors[doc_vector_idx].2;
let start_node = match search_entry_point {
Some(sp) => sp,
None => return Ok(()),
};
if start_node == doc_id {
return Ok(());
}
let layers_len = graph.nodes.get(&doc_id).map(|l| l.len()).unwrap_or(0);
if layers_len == 0 {
return Ok(());
}
let level = layers_len - 1;
let max_level = graph.max_level;
let mut curr_obj = start_node;
let mut dist = writer_ref.calc_dist(vector, curr_obj)?;
for lc in (level + 1..=max_level).rev() {
let mut changed = true;
while changed {
changed = false;
if let Some(neighbors) = graph.get_neighbors_view(curr_obj, lc) {
for neighbor_id in neighbors {
let d = writer_ref.calc_dist(vector, neighbor_id)?;
if d < dist {
dist = d;
curr_obj = neighbor_id;
changed = true;
}
}
}
}
}
let top_level = usize::min(max_level, level);
for lc in (0..=top_level).rev() {
let candidates =
writer_ref.search_layer(&graph, curr_obj, vector, ef_construction, lc)?;
if let Some(min_cand) = candidates.iter().min_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(Ordering::Equal)
}) {
curr_obj = min_cand.id;
}
let neighbors = writer_ref.select_neighbors(&candidates, m, lc, m_max, m_max_0);
graph.set_neighbors(doc_id, lc, neighbors.clone());
for neighbor_id in neighbors {
let current_m_max = if lc == 0 { m_max_0 } else { m_max };
graph.add_neighbor_with_pruning(
neighbor_id,
lc,
doc_id,
current_m_max,
writer_ref,
)?;
}
}
Ok(())
})?;
let mut final_nodes = HashMap::new();
let mut final_levels_map = HashMap::new();
for (doc_id, layers) in graph.nodes {
let mut vec_layers = Vec::with_capacity(layers.len());
for lock in layers {
vec_layers.push(lock.into_inner());
}
final_levels_map.insert(doc_id, vec_layers.len() - 1);
final_nodes.insert(doc_id, vec_layers);
}
self.graph = Some(HnswGraph::new(
entry_point,
max_level,
final_nodes,
m,
m_max,
m_max_0,
ef_construction,
1.0 / (self.index_config.m as f64).ln(),
));
self.entry_point = entry_point;
let mut levels_vec = vec![Vec::new(); max_level + 1];
for (doc_id, level) in final_levels_map {
if level < levels_vec.len() {
levels_vec[level].push(doc_id);
}
}
self.levels = levels_vec;
Ok(())
}
fn calc_dist(&self, query: &Vector, doc_id: u64) -> Result<f32> {
let idx = *self
.doc_id_map
.get(&doc_id)
.ok_or_else(|| LaurusError::internal(format!("Doc ID {} not found in map", doc_id)))?;
let target = &self.vectors[idx].2;
self.index_config
.distance_metric
.distance(&query.data, &target.data)
}
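/// Greedy best-first search within a single layer: expands the closest
/// frontier node from a min-heap (`to_visit`) while keeping the `ef` best
/// results in a max-heap (`found`), stopping once the nearest frontier node
/// is farther than the worst retained result.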
fn search_layer<G: GraphView>(
&self,
graph: &G,
entry_point: u64,
query: &Vector,
ef: usize,
level: usize,
) -> Result<BinaryHeap<Candidate>> {
let mut visited = HashSet::new();
let dist = self.calc_dist(query, entry_point)?;
#[derive(Debug, Clone, PartialEq)]
struct VisitorCandidate {
id: u64,
distance: f32,
}
impl Eq for VisitorCandidate {}
impl Ord for VisitorCandidate {
fn cmp(&self, other: &Self) -> Ordering {
other
.distance
.partial_cmp(&self.distance)
.unwrap_or(Ordering::Equal)
}
}
impl PartialOrd for VisitorCandidate {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
let mut to_visit = BinaryHeap::new();
let mut found = BinaryHeap::new();
to_visit.push(VisitorCandidate {
id: entry_point,
distance: dist,
});
found.push(Candidate {
id: entry_point,
distance: dist,
similarity: 0.0,
});
visited.insert(entry_point);
while let Some(curr) = to_visit.pop() {
if let Some(furthest_found) = found.peek()
&& curr.distance > furthest_found.distance
&& found.len() >= ef
{
break;
}
if let Some(neighbors) = graph.get_neighbors_view(curr.id, level) {
for neighbor_id in neighbors {
if !visited.insert(neighbor_id) {
continue;
}
let neighbor_dist = self.calc_dist(query, neighbor_id)?;
let furthest_dist = found.peek().map(|c| c.distance).unwrap_or(f32::MAX);
if neighbor_dist < furthest_dist || found.len() < ef {
let c = Candidate {
id: neighbor_id,
distance: neighbor_dist,
similarity: 0.0,
};
let vc = VisitorCandidate {
id: neighbor_id,
distance: neighbor_dist,
};
found.push(c);
to_visit.push(vc);
if found.len() > ef {
found.pop();
}
}
}
}
}
Ok(found)
}
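/// Simple neighbor selection: keep the `m` closest candidates (Algorithm 3
/// in the HNSW paper). The unused `_level`/`_m_max` parameters leave room
/// for the diversity-aware heuristic (Algorithm 4).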
fn select_neighbors(
&self,
candidates: &BinaryHeap<Candidate>,
m: usize,
_level: usize,
_m_max: usize,
_m_max_0: usize,
) -> Vec<u64> {
let mut sorted: Vec<_> = candidates.iter().cloned().collect();
sorted.sort_unstable_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(Ordering::Equal)
});
sorted.truncate(m);
sorted.into_iter().map(|c| c.id).collect()
}
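/// Recomputes the distance from `doc_id` to each neighbor and keeps the
/// `max_conn` closest; invoked when a reciprocal edge pushes a node over
/// its connection budget.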
fn prune_neighbors(
&self,
doc_id: u64,
neighbors: Vec<u64>,
max_conn: usize,
) -> Result<Vec<u64>> {
if neighbors.len() <= max_conn {
return Ok(neighbors);
}
let idx = *self.doc_id_map.get(&doc_id).ok_or_else(|| {
LaurusError::internal(format!(
"Doc ID {} not found in doc_id_map during pruning",
doc_id
))
})?;
let doc_vec = &self.vectors[idx].2;
let mut candidates = Vec::new();
for nid in neighbors {
let dist = self.calc_dist(doc_vec, nid)?;
candidates.push(Candidate {
id: nid,
distance: dist,
similarity: 0.0,
});
}
candidates.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(Ordering::Equal)
});
candidates.truncate(max_conn);
Ok(candidates.into_iter().map(|c| c.id).collect())
}
fn check_memory_limit(&self) -> Result<()> {
if let Some(limit) = self.writer_config.memory_limit {
let current_usage = self.estimated_memory_usage();
if current_usage > limit {
return Err(LaurusError::ResourceExhausted(format!(
"Memory usage {current_usage} bytes exceeds limit {limit} bytes"
)));
}
}
Ok(())
}
pub fn vectors(&self) -> &[(u64, String, Vector)] {
&self.vectors
}
pub fn hnsw_params(&self) -> (usize, usize) {
(self.index_config.m, self.index_config.ef_construction)
}
}
#[async_trait::async_trait]
impl VectorIndexWriter for HnswIndexWriter {
fn next_vector_id(&self) -> u64 {
self.next_vec_id
}
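// `build` replaces any vectors already buffered in the writer; use
// `add_vectors` to merge into the existing set instead.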
fn build(&mut self, vectors: Vec<(u64, String, Vector)>) -> Result<()> {
if self.is_finalized {
return Err(LaurusError::InvalidOperation(
"Cannot build on finalized index".to_string(),
));
}
self.validate_vectors(&vectors)?;
self.vectors = vectors;
Self::normalize_vectors_internal(
&self.index_config,
&self.writer_config,
&mut self.vectors,
);
self.rebuild_doc_id_map();
if let Some((max_id, _, _)) = self.vectors.iter().max_by_key(|(id, _, _)| id)
&& *max_id >= self.next_vec_id
{
self.next_vec_id = *max_id + 1;
}
self.total_vectors_to_add = Some(self.vectors.len());
self.check_memory_limit()?;
Ok(())
}
fn add_vectors(&mut self, mut vectors: Vec<(u64, String, Vector)>) -> Result<()> {
if self.is_finalized {
self.is_finalized = false;
}
self.validate_vectors(&vectors)?;
Self::normalize_vectors_internal(&self.index_config, &self.writer_config, &mut vectors);
self.rebuild_doc_id_map();
for (doc_id, field, vector) in vectors {
if let Some(&idx) = self.doc_id_map.get(&doc_id) {
self.vectors[idx] = (doc_id, field, vector);
} else {
let idx = self.vectors.len();
self.vectors.push((doc_id, field, vector));
self.doc_id_map.insert(doc_id, idx);
}
}
if let Some((max_id, _, _)) = self.vectors.iter().max_by_key(|(id, _, _)| id)
&& *max_id >= self.next_vec_id
{
self.next_vec_id = *max_id + 1;
}
self.check_memory_limit()?;
Ok(())
}
fn finalize(&mut self) -> Result<()> {
if self.is_finalized {
return Ok(());
}
self.build_hnsw_graph()?;
self.is_finalized = true;
Ok(())
}
fn progress(&self) -> f32 {
if let Some(total) = self.total_vectors_to_add {
if total == 0 {
if self.is_finalized { 1.0 } else { 0.0 }
} else {
let current = self.vectors.len() as f32;
let progress = current / total as f32;
if self.is_finalized {
1.0
} else {
progress.min(0.99)
}
}
} else if self.is_finalized {
1.0
} else {
0.0
}
}
fn estimated_memory_usage(&self) -> usize {
// Rough per-vector estimate: 8-byte doc id + ~32 bytes of field-name
// overhead + 4 bytes per f32 component.
let vector_memory = self.vectors.len() * (8 + 32 + self.index_config.dimension * 4);
let avg_layers = 2.0;
let graph_memory =
self.vectors.len() * (self.index_config.m as f32 * avg_layers * 8.0) as usize;
let metadata_memory = self.vectors.len() * 128;
vector_memory + graph_memory + metadata_memory
}
fn vectors(&self) -> &[(u64, String, Vector)] {
&self.vectors
}
fn write(&self) -> Result<()> {
use std::io::Write;
if !self.is_finalized {
return Err(LaurusError::InvalidOperation(
"Index must be finalized before writing".to_string(),
));
}
let storage = self
.storage
.as_ref()
.ok_or_else(|| LaurusError::InvalidOperation("No storage configured".to_string()))?;
let file_name = format!("{}.hnsw", self.path);
let mut output = storage.create_output(&file_name)?;
output.write_all(&(self.vectors.len() as u64).to_le_bytes())?;
output.write_all(&(self.index_config.dimension as u32).to_le_bytes())?;
output.write_all(&(self.index_config.m as u32).to_le_bytes())?;
output.write_all(&(self.index_config.ef_construction as u32).to_le_bytes())?;
let mut sorted_vectors: Vec<_> = self.vectors.iter().collect();
sorted_vectors.sort_by_key(|(doc_id, _, _)| *doc_id);
for (doc_id, field_name, vector) in sorted_vectors {
output.write_all(&doc_id.to_le_bytes())?;
let field_name_bytes = field_name.as_bytes();
output.write_all(&(field_name_bytes.len() as u32).to_le_bytes())?;
output.write_all(field_name_bytes)?;
for value in vector.data.iter() {
output.write_all(&value.to_le_bytes())?;
}
}
if let Some(graph) = &self.graph {
let has_graph = 1u8;
output.write_all(&[has_graph])?;
let entry_point = graph.entry_point.unwrap_or(u64::MAX);
output.write_all(&entry_point.to_le_bytes())?;
output.write_all(&(graph.max_level as u32).to_le_bytes())?;
let node_count = graph.node_count() as u64;
output.write_all(&node_count.to_le_bytes())?;
let sorted_nodes = graph.sorted_nodes();
for (doc_id, layers) in sorted_nodes {
output.write_all(&doc_id.to_le_bytes())?;
let layer_count = layers.len() as u32;
output.write_all(&layer_count.to_le_bytes())?;
for neighbors in layers {
let neighbor_count = neighbors.len() as u32;
output.write_all(&neighbor_count.to_le_bytes())?;
for neighbor in neighbors {
output.write_all(&neighbor.to_le_bytes())?;
}
}
}
} else {
let has_graph = 0u8;
output.write_all(&[has_graph])?;
}
output.flush()?;
Ok(())
}
fn has_storage(&self) -> bool {
self.storage.is_some()
}
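// Deleting a document invalidates the graph entirely; it is rebuilt from
// scratch on the next `finalize` rather than repaired in place.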
fn delete_document(&mut self, doc_id: u64) -> Result<()> {
if self.is_finalized {
self.is_finalized = false;
}
let initial_len = self.vectors.len();
self.vectors.retain(|(id, _, _)| *id != doc_id);
if self.vectors.len() < initial_len {
self.rebuild_doc_id_map();
self.graph = None;
}
Ok(())
}
fn delete_documents(&mut self, _field: &str, _value: &str) -> Result<usize> {
if self.is_finalized {
return Err(LaurusError::InvalidOperation(
"Cannot delete documents from finalized index".to_string(),
));
}
Ok(0)
}
fn rollback(&mut self) -> Result<()> {
self.vectors.clear();
self.doc_id_map.clear();
self.graph = None;
self.entry_point = None;
self.is_finalized = false;
self.next_vec_id = 0;
Ok(())
}
fn pending_docs(&self) -> u64 {
if self.is_finalized {
0
} else {
self.vectors.len() as u64
}
}
fn close(&mut self) -> Result<()> {
self.vectors.clear();
self.doc_id_map.clear();
self.graph = None;
self.is_finalized = true;
Ok(())
}
fn is_closed(&self) -> bool {
self.is_finalized && self.vectors.is_empty()
}
fn build_reader(&self) -> Result<Arc<dyn crate::vector::reader::VectorIndexReader>> {
use crate::vector::index::hnsw::reader::HnswIndexReader;
let storage = self.storage.as_ref().ok_or_else(|| {
LaurusError::InvalidOperation("Cannot build reader: storage not configured".to_string())
})?;
let reader = HnswIndexReader::load(
storage.clone(),
&self.path,
self.index_config.distance_metric,
)?;
Ok(Arc::new(reader))
}
}