use std::collections::HashMap;
use std::sync::Arc;
use crate::error::{LaurusError, Result};
use crate::storage::Storage;
use crate::vector::core::distance::DistanceMetric;
use crate::vector::core::vector::Vector;
use crate::vector::index::hnsw::graph::HnswGraph;
use crate::vector::reader::{ValidationReport, VectorIndexMetadata, VectorStats};
use crate::vector::reader::{VectorIndexReader, VectorIterator};
use crate::maintenance::deletion::DeletionBitmap;
use crate::vector::index::storage::VectorStorage;
/// Read-only reader for a serialized HNSW vector index segment.
///
/// Created via [`HnswIndexReader::load`]; supports both eager (all vectors
/// decoded in memory) and lazy (offset table only, payloads fetched on
/// demand) loading modes, selected by the storage backend.
#[derive(Debug)]
pub struct HnswIndexReader {
    /// Vector payloads: fully owned in memory, or resolved on demand from
    /// storage using recorded file offsets.
    vectors: VectorStorage,
    /// Every (doc_id, field_name) pair present in the segment, in file order.
    vector_ids: Vec<(u64, String)>,
    /// Dimensionality shared by all vectors in this index.
    dimension: usize,
    /// Metric used when comparing vectors at query time.
    distance_metric: DistanceMetric,
    /// HNSW `M` parameter (max neighbors per node) read from the file header.
    m: usize,
    /// HNSW `ef_construction` parameter read from the file header.
    ef_construction: usize,
    /// Deserialized HNSW graph, if one was persisted with the segment.
    pub graph: Option<Arc<HnswGraph>>,
    /// Optional bitmap marking logically deleted doc ids; filters all reads.
    deletion_bitmap: Option<Arc<DeletionBitmap>>,
    /// field_name -> (doc_id -> raw vector-data address as usize); built only
    /// in eager mode. NOTE(review): these are `as_ptr()` addresses into the
    /// owned vector map — valid only while `vectors` (the owning Arc) lives.
    prefetch_index: HashMap<String, HashMap<u64, usize>>,
}
impl HnswIndexReader {
    /// Deprecated: the byte-slice constructor was replaced by [`Self::load`],
    /// which reads directly from a storage backend. Always returns
    /// `InvalidOperation`.
    pub fn from_bytes(_data: &[u8]) -> Result<Self> {
        Err(LaurusError::InvalidOperation(
            "from_bytes is deprecated, use load() instead".to_string(),
        ))
    }

    /// Loads an HNSW index segment from the file `{path}.hnsw` in `storage`.
    ///
    /// On-disk layout (all integers little-endian):
    /// 1. Header: vector count (u64), dimension (u32), M (u32),
    ///    ef_construction (u32).
    /// 2. Vector records: doc_id (u64), field-name byte length (u32),
    ///    UTF-8 field name, then `dimension` f32 values.
    /// 3. Optional graph section: present flag (u8 == 1), entry point (u64,
    ///    u64::MAX meaning "none"), max level (u32), node count (u64), then
    ///    per node its doc_id and layered neighbor lists.
    ///
    /// The storage's loading mode decides whether vector payloads are decoded
    /// eagerly into memory or only indexed by file offset for on-demand reads.
    pub fn load(
        storage: Arc<dyn Storage>,
        path: &str,
        distance_metric: DistanceMetric,
    ) -> Result<Self> {
        use std::io::Read;
        let file_name = format!("{}.hnsw", path);
        let mut input = storage.open_input(&file_name)?;
        // --- Fixed-size header ---
        let mut num_vectors_buf = [0u8; 8];
        input.read_exact(&mut num_vectors_buf)?;
        let num_vectors = u64::from_le_bytes(num_vectors_buf) as usize;
        let mut dimension_buf = [0u8; 4];
        input.read_exact(&mut dimension_buf)?;
        let dimension = u32::from_le_bytes(dimension_buf) as usize;
        let mut m_buf = [0u8; 4];
        input.read_exact(&mut m_buf)?;
        let m = u32::from_le_bytes(m_buf) as usize;
        let mut ef_construction_buf = [0u8; 4];
        input.read_exact(&mut ef_construction_buf)?;
        let ef_construction = u32::from_le_bytes(ef_construction_buf) as usize;
        // Deserializes the optional graph section that follows the vector
        // records. A failed read of the 1-byte flag is treated as "no graph"
        // (older segments may simply end after the vectors).
        let read_graph =
            |input: &mut dyn crate::storage::StorageInput| -> Result<Option<Arc<HnswGraph>>> {
                let mut has_graph_buf = [0u8; 1];
                if input.read_exact(&mut has_graph_buf).is_ok() && has_graph_buf[0] == 1 {
                    let mut entry_point_buf = [0u8; 8];
                    input.read_exact(&mut entry_point_buf)?;
                    let entry_point_raw = u64::from_le_bytes(entry_point_buf);
                    // u64::MAX is the serialized sentinel for "no entry point".
                    let entry_point = if entry_point_raw == u64::MAX {
                        None
                    } else {
                        Some(entry_point_raw)
                    };
                    let mut max_level_buf = [0u8; 4];
                    input.read_exact(&mut max_level_buf)?;
                    let max_level = u32::from_le_bytes(max_level_buf) as usize;
                    let mut node_count_buf = [0u8; 8];
                    input.read_exact(&mut node_count_buf)?;
                    let node_count = u64::from_le_bytes(node_count_buf) as usize;
                    let mut nodes = HashMap::with_capacity(node_count);
                    // Each node: doc_id, layer count, then per layer a
                    // neighbor count followed by neighbor doc ids.
                    for _ in 0..node_count {
                        let mut doc_id_buf = [0u8; 8];
                        input.read_exact(&mut doc_id_buf)?;
                        let doc_id = u64::from_le_bytes(doc_id_buf);
                        let mut layer_count_buf = [0u8; 4];
                        input.read_exact(&mut layer_count_buf)?;
                        let layer_count = u32::from_le_bytes(layer_count_buf) as usize;
                        let mut layers = Vec::with_capacity(layer_count);
                        for _ in 0..layer_count {
                            let mut neighbor_count_buf = [0u8; 4];
                            input.read_exact(&mut neighbor_count_buf)?;
                            let neighbor_count = u32::from_le_bytes(neighbor_count_buf) as usize;
                            let mut neighbors = Vec::with_capacity(neighbor_count);
                            for _ in 0..neighbor_count {
                                let mut neighbor_buf = [0u8; 8];
                                input.read_exact(&mut neighbor_buf)?;
                                neighbors.push(u64::from_le_bytes(neighbor_buf));
                            }
                            layers.push(neighbors);
                        }
                        nodes.insert(doc_id, layers);
                    }
                    // Rebuild the graph with parameters derived from the
                    // header: level-0 connection cap = 2*M, level multiplier
                    // = 1/ln(M). NOTE(review): 1/ln(M) is infinite for M == 1
                    // and M == 0 is only warned about in validate(); confirm
                    // HnswGraph::new tolerates these degenerate values.
                    Ok(Some(Arc::new(HnswGraph::new(
                        entry_point,
                        max_level,
                        nodes,
                        m,
                        m,
                        m * 2,
                        ef_construction,
                        1.0 / (m as f64).ln(),
                    ))))
                } else {
                    Ok(None)
                }
            };
        let (vectors, vector_ids, graph) = match storage.loading_mode() {
            // Eager: decode every vector payload into an owned in-memory map.
            crate::storage::LoadingMode::Eager => {
                let mut vectors = HashMap::with_capacity(num_vectors);
                let mut vector_ids = Vec::with_capacity(num_vectors);
                for _ in 0..num_vectors {
                    let mut doc_id_buf = [0u8; 8];
                    input.read_exact(&mut doc_id_buf)?;
                    let doc_id = u64::from_le_bytes(doc_id_buf);
                    let mut field_name_len_buf = [0u8; 4];
                    input.read_exact(&mut field_name_len_buf)?;
                    let field_name_len = u32::from_le_bytes(field_name_len_buf) as usize;
                    let mut field_name_buf = vec![0u8; field_name_len];
                    input.read_exact(&mut field_name_buf)?;
                    let field_name = String::from_utf8(field_name_buf).map_err(|e| {
                        LaurusError::InvalidOperation(format!("Invalid UTF-8 in field name: {}", e))
                    })?;
                    // Decode the f32 payload one little-endian word at a time.
                    let mut values = vec![0.0f32; dimension];
                    for value in &mut values {
                        let mut value_buf = [0u8; 4];
                        input.read_exact(&mut value_buf)?;
                        *value = f32::from_le_bytes(value_buf);
                    }
                    vector_ids.push((doc_id, field_name.clone()));
                    vectors.insert((doc_id, field_name), Vector::new(values));
                }
                let graph = read_graph(&mut input)?;
                (VectorStorage::Owned(Arc::new(vectors)), vector_ids, graph)
            }
            // Lazy: record only each record's start offset; payloads are read
            // from storage later, on demand.
            crate::storage::LoadingMode::Lazy => {
                let mut offsets = HashMap::with_capacity(num_vectors);
                let mut vector_ids = Vec::with_capacity(num_vectors);
                // Byte size of the fixed header consumed above. The header
                // reads already left the stream here, so this seek is a
                // defensive re-anchor at the first vector record.
                let start_pos = 8 + 4 + 4 + 4;
                input
                    .seek(std::io::SeekFrom::Start(start_pos as u64))
                    .map_err(LaurusError::Io)?;
                for _ in 0..num_vectors {
                    // Offset of the whole record (doc_id + name + payload).
                    let start_offset = input.stream_position().map_err(LaurusError::Io)?;
                    let mut doc_id_buf = [0u8; 8];
                    input.read_exact(&mut doc_id_buf)?;
                    let doc_id = u64::from_le_bytes(doc_id_buf);
                    let mut field_name_len_buf = [0u8; 4];
                    input.read_exact(&mut field_name_len_buf)?;
                    let field_name_len = u32::from_le_bytes(field_name_len_buf) as usize;
                    let mut field_name_buf = vec![0u8; field_name_len];
                    input.read_exact(&mut field_name_buf)?;
                    let field_name = String::from_utf8(field_name_buf).map_err(|e| {
                        LaurusError::InvalidOperation(format!("Invalid UTF-8 in field name: {}", e))
                    })?;
                    offsets.insert((doc_id, field_name.clone()), start_offset);
                    vector_ids.push((doc_id, field_name.clone()));
                    // Skip over the f32 payload without decoding it.
                    input
                        .seek(std::io::SeekFrom::Current((dimension * 4) as i64))
                        .map_err(LaurusError::Io)?;
                }
                let graph = read_graph(&mut input)?;
                (
                    VectorStorage::OnDemand {
                        storage: storage.clone(),
                        file_name: file_name.clone(),
                        offsets: Arc::new(offsets),
                    },
                    vector_ids,
                    graph,
                )
            }
        };
        // Eager mode only: build field -> (doc_id -> raw data address) as a
        // prefetch hint. Addresses stay valid while the Arc-owned map lives.
        let prefetch_index = if let VectorStorage::Owned(ref map) = vectors {
            let mut idx: HashMap<String, HashMap<u64, usize>> = HashMap::new();
            for ((doc_id, field_name), vector) in map.iter() {
                idx.entry(field_name.clone())
                    .or_default()
                    .insert(*doc_id, vector.data.as_ptr() as usize);
            }
            idx
        } else {
            HashMap::new()
        };
        Ok(Self {
            vectors,
            vector_ids,
            dimension,
            distance_metric,
            m,
            ef_construction,
            graph,
            deletion_bitmap: None,
            prefetch_index,
        })
    }

    /// Attaches a deletion bitmap; subsequent reads skip deleted doc ids.
    pub fn set_deletion_bitmap(&mut self, bitmap: Arc<DeletionBitmap>) {
        self.deletion_bitmap = Some(bitmap);
    }

    /// True when `doc_id` is marked deleted in the attached bitmap
    /// (no bitmap means nothing is considered deleted).
    fn is_deleted(&self, doc_id: u64) -> bool {
        if let Some(bitmap) = &self.deletion_bitmap {
            bitmap.is_deleted(doc_id)
        } else {
            false
        }
    }

    /// Returns the `(M, ef_construction)` parameters read from the header.
    pub fn hnsw_params(&self) -> (usize, usize) {
        (self.m, self.ef_construction)
    }

    /// Per-field prefetch table (doc_id -> raw vector-data address); only
    /// populated when the index was loaded eagerly.
    pub(crate) fn field_prefetch_index(&self, field_name: &str) -> Option<&HashMap<u64, usize>> {
        self.prefetch_index.get(field_name)
    }
}
impl VectorIndexReader for HnswIndexReader {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn get_vector(&self, doc_id: u64, field_name: &str) -> Result<Option<Vector>> {
if self.is_deleted(doc_id) {
return Ok(None);
}
self.vectors
.get(&(doc_id, field_name.to_string()), self.dimension)
}
fn get_vectors_for_doc(&self, doc_id: u64) -> Result<Vec<(String, Vector)>> {
let mut result = Vec::new();
for (id, field) in &self.vector_ids {
if *id == doc_id
&& !self.is_deleted(*id)
&& let Some(vec) = self.vectors.get(&(*id, field.clone()), self.dimension)?
{
result.push((field.clone(), vec));
}
}
Ok(result)
}
fn get_vectors(&self, doc_ids: &[(u64, String)]) -> Result<Vec<Option<Vector>>> {
let mut result = Vec::with_capacity(doc_ids.len());
for (id, field) in doc_ids {
if self.is_deleted(*id) {
result.push(None);
} else {
result.push(self.vectors.get(&(*id, field.clone()), self.dimension)?);
}
}
Ok(result)
}
fn vector_ids(&self) -> Result<Vec<(u64, String)>> {
Ok(self.vector_ids.clone())
}
fn vector_count(&self) -> usize {
self.vectors.len()
}
fn dimension(&self) -> usize {
self.dimension
}
fn distance_metric(&self) -> DistanceMetric {
self.distance_metric
}
fn stats(&self) -> VectorStats {
let memory_usage = match &self.vectors {
VectorStorage::Owned(vectors) => vectors.len() * (8 + self.dimension * 4),
VectorStorage::OnDemand { offsets, .. } => {
offsets.len() * (8 + 32 + 8) }
};
VectorStats {
vector_count: self.vectors.len(),
dimension: self.dimension,
memory_usage,
build_time_ms: 0,
}
}
fn contains_vector(&self, doc_id: u64, field_name: &str) -> bool {
match &self.vectors {
VectorStorage::Owned(vectors) => {
vectors.contains_key(&(doc_id, field_name.to_string()))
}
VectorStorage::OnDemand { offsets, .. } => {
offsets.contains_key(&(doc_id, field_name.to_string()))
}
}
}
fn get_vector_range(
&self,
start_doc_id: u64,
end_doc_id: u64,
) -> Result<Vec<(u64, String, Vector)>> {
let mut result = Vec::new();
for (id, field) in &self.vector_ids {
if *id >= start_doc_id
&& *id < end_doc_id
&& !self.is_deleted(*id)
&& let Some(vec) = self.vectors.get(&(*id, field.clone()), self.dimension)?
{
result.push((*id, field.clone(), vec));
}
}
Ok(result)
}
fn get_vectors_by_field(&self, field_name: &str) -> Result<Vec<(u64, Vector)>> {
let mut result = Vec::new();
for (id, field) in &self.vector_ids {
if field == field_name
&& !self.is_deleted(*id)
&& let Some(vec) = self.vectors.get(&(*id, field.clone()), self.dimension)?
{
result.push((*id, vec));
}
}
Ok(result)
}
fn field_names(&self) -> Result<Vec<String>> {
use std::collections::HashSet;
let fields: HashSet<String> = self.vector_ids.iter().map(|val| val.1.clone()).collect();
Ok(fields.into_iter().collect())
}
fn vector_iterator(&self) -> Result<Box<dyn VectorIterator>> {
Ok(Box::new(HnswVectorIterator {
storage: self.vectors.clone(),
keys: self.vector_ids.clone(),
current: 0,
dimension: self.dimension,
deletion_bitmap: self.deletion_bitmap.clone(),
}))
}
fn metadata(&self) -> Result<VectorIndexMetadata> {
Ok(VectorIndexMetadata {
index_type: "hnsw".to_string(),
created_at: chrono::Utc::now(),
modified_at: chrono::Utc::now(),
version: "1".to_string(),
build_config: serde_json::json!({}),
custom_metadata: std::collections::HashMap::new(),
})
}
fn validate(&self) -> Result<ValidationReport> {
let mut errors = Vec::new();
let mut warnings = Vec::new();
if self.vector_ids.len() != self.vectors.len() {
errors.push(format!(
"Mismatch between vector_ids count ({}) and vectors count ({})",
self.vector_ids.len(),
self.vectors.len()
));
}
match &self.vectors {
VectorStorage::Owned(map) => {
for (id, field) in &self.vector_ids {
if let Some(vector) = map.get(&(*id, field.clone())) {
if vector.dimension() != self.dimension {
errors.push(format!(
"Vector {}:{} has dimension {}, expected {}",
id,
field,
vector.dimension(),
self.dimension
));
}
if !vector.is_valid() {
errors.push(format!(
"Vector {}:{} contains invalid values (NaN or infinity)",
id, field
));
}
} else {
errors.push(format!(
"Vector {}:{} found in keys but missing in storage",
id, field
));
}
}
}
VectorStorage::OnDemand { offsets, .. } => {
for (id, field) in &self.vector_ids {
if !offsets.contains_key(&(*id, field.clone())) {
errors.push(format!(
"Vector {}:{} in ids but missing in storage",
id, field
));
}
}
warnings.push("OnDemand mode: Deep vector validation skipped".to_string());
}
}
if self.m == 0 {
warnings.push("HNSW parameter M is 0, this may indicate a corrupted index".to_string());
}
if self.ef_construction == 0 {
warnings.push(
"HNSW parameter ef_construction is 0, this may indicate a corrupted index"
.to_string(),
);
}
Ok(ValidationReport {
repair_suggestions: Vec::new(),
is_valid: errors.is_empty(),
errors,
warnings,
})
}
}
/// Cursor over a snapshot of an HNSW index's (doc_id, field) keys,
/// resolving each key's vector from the shared storage handle.
struct HnswVectorIterator {
    /// Shared handle to the vector payloads (owned map or on-demand reader).
    storage: VectorStorage,
    /// Snapshot of all (doc_id, field_name) keys taken at creation time.
    keys: Vec<(u64, String)>,
    /// Index into `keys` of the next entry to yield.
    current: usize,
    /// Expected vector dimensionality, forwarded to storage reads.
    dimension: usize,
    /// Optional deletion bitmap; matching doc ids are skipped during iteration.
    deletion_bitmap: Option<Arc<DeletionBitmap>>,
}
impl VectorIterator for HnswVectorIterator {
    // Yields the next live vector, skipping deleted docs. A key that has no
    // backing vector in storage is reported as an internal error.
    fn next(&mut self) -> Result<Option<(u64, String, Vector)>> {
        loop {
            let Some((doc_id, field)) = self.keys.get(self.current) else {
                return Ok(None);
            };
            let deleted = self
                .deletion_bitmap
                .as_ref()
                .is_some_and(|bitmap| bitmap.is_deleted(*doc_id));
            if deleted {
                self.current += 1;
                continue;
            }
            return match self.storage.get(&(*doc_id, field.clone()), self.dimension)? {
                Some(vec) => {
                    let item = (*doc_id, field.clone(), vec);
                    self.current += 1;
                    Ok(Some(item))
                }
                None => Err(LaurusError::internal(format!(
                    "Vector {}:{} found in keys but missing in storage",
                    doc_id, field
                ))),
            };
        }
    }

    // Advances the cursor to the first key >= (doc_id, field_name) in
    // (id, field) order; returns whether such a key exists.
    fn skip_to(&mut self, doc_id: u64, field_name: &str) -> Result<bool> {
        for idx in self.current..self.keys.len() {
            let (id, field) = &self.keys[idx];
            let reached = *id > doc_id || (*id == doc_id && field.as_str() >= field_name);
            if reached {
                self.current = idx;
                return Ok(true);
            }
        }
        self.current = self.keys.len();
        Ok(false)
    }

    // Current cursor key, or the (u64::MAX, "") sentinel when exhausted.
    fn position(&self) -> (u64, String) {
        self.keys
            .get(self.current)
            .cloned()
            .unwrap_or_else(|| (u64::MAX, String::new()))
    }

    // Rewinds the cursor to the first key.
    fn reset(&mut self) -> Result<()> {
        self.current = 0;
        Ok(())
    }
}