use crate::flat::FlatIndex;
use crate::hnsw::{HnswIndex, HnswParams};
use crate::quantize::sq8::Sq8Codec;
use super::segment::{BuildRequest, BuildingSegment, DEFAULT_SEAL_THRESHOLD, SealedSegment};
/// Segmented vector collection addressed by global `u32` ids.
///
/// New vectors are appended to a flat `growing` segment. `seal` freezes that
/// segment into a `building` entry and emits a `BuildRequest`. `complete_build`
/// later replaces the building entry with an HNSW-indexed `sealed` segment.
/// Every segment owns a contiguous id range starting at its base id; a vector's
/// global id is `base_id + local slot`.
pub struct VectorCollection {
/// Flat segment that receives vectors from `insert`.
pub(crate) growing: FlatIndex,
/// Global id of slot 0 in `growing`; reset to `next_id` by `seal`.
pub(crate) growing_base_id: u32,
/// Segments whose HNSW index has been installed by `complete_build`.
pub(crate) sealed: Vec<SealedSegment>,
/// Segments frozen by `seal` that are waiting for `complete_build`.
pub(crate) building: Vec<BuildingSegment>,
/// HNSW parameters; `params.metric` also configures each new growing `FlatIndex`.
pub(crate) params: HnswParams,
/// Next global vector id to hand out.
pub(crate) next_id: u32,
/// Next segment id to assign in `seal`.
pub(crate) next_segment_id: u32,
/// Dimensionality of every stored vector.
pub(crate) dim: usize,
/// `None` by default; presumably the directory for on-disk segment files — confirm.
pub(crate) data_dir: Option<std::path::PathBuf>,
/// 0 by default; not read in this file — presumably consulted by tier resolution, confirm.
pub(crate) ram_budget_bytes: usize,
/// 0 by default; not read in this file — presumably counts failed mmap attempts, confirm.
pub(crate) mmap_fallback_count: u32,
/// Number of mmap-backed segments, reported by `stats`.
pub(crate) mmap_segment_count: u32,
/// Global vector id -> external document id.
pub doc_id_map: std::collections::HashMap<u32, String>,
/// External document id -> vector ids inserted through `insert_multi_vector`.
pub multi_doc_map: std::collections::HashMap<String, Vec<u32>>,
/// Growing-segment size at which `needs_seal` returns `true`.
pub(crate) seal_threshold: usize,
}
impl VectorCollection {
    /// Creates an empty collection of `dim`-dimensional vectors using
    /// `DEFAULT_SEAL_THRESHOLD`.
    pub fn new(dim: usize, params: HnswParams) -> Self {
        Self::with_seal_threshold(dim, params, DEFAULT_SEAL_THRESHOLD)
    }

    /// Creates an empty collection whose `needs_seal` fires once the growing
    /// segment holds `seal_threshold` vectors.
    ///
    /// Ids and segment ids start at 0. No data directory or RAM budget is set.
    pub fn with_seal_threshold(dim: usize, params: HnswParams, seal_threshold: usize) -> Self {
        Self {
            growing: FlatIndex::new(dim, params.metric),
            growing_base_id: 0,
            sealed: Vec::new(),
            building: Vec::new(),
            params,
            next_id: 0,
            next_segment_id: 0,
            dim,
            data_dir: None,
            ram_budget_bytes: 0,
            mmap_fallback_count: 0,
            mmap_segment_count: 0,
            doc_id_map: std::collections::HashMap::new(),
            multi_doc_map: std::collections::HashMap::new(),
            seal_threshold,
        }
    }

    /// Same as [`VectorCollection::new`]; `_seed` is currently ignored.
    pub fn with_seed(dim: usize, params: HnswParams, _seed: u64) -> Self {
        Self::with_seal_threshold(dim, params, DEFAULT_SEAL_THRESHOLD)
    }

    /// Appends `vector` to the growing segment and returns its global id.
    ///
    /// # Panics
    ///
    /// Panics if the `u32` id space is exhausted.
    pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
        let id = self.next_id;
        // Reserve the successor id before touching the growing segment. A plain
        // `+= 1` wraps to 0 in release builds and would hand out ids that are
        // already in use; checking first also avoids inserting a vector that
        // never receives an id.
        let next = id
            .checked_add(1)
            .expect("VectorCollection: u32 vector id space exhausted");
        self.growing.insert(vector);
        self.next_id = next;
        id
    }

    /// Inserts `vector` and records `doc_id` as its owning document.
    pub fn insert_with_doc_id(&mut self, vector: Vec<f32>, doc_id: String) -> u32 {
        let id = self.insert(vector);
        self.doc_id_map.insert(id, doc_id);
        id
    }

    /// Inserts each slice of `vectors` as a separate vector owned by `doc_id`.
    ///
    /// Returns the new ids in input order. Each id is mapped to `doc_id` in
    /// `doc_id_map`, and the ids are appended to `doc_id`'s list in
    /// `multi_doc_map`. Appending rather than overwriting keeps vectors from
    /// earlier calls for the same document reachable by
    /// [`VectorCollection::delete_multi_vector`].
    pub fn insert_multi_vector(&mut self, vectors: &[&[f32]], doc_id: String) -> Vec<u32> {
        let mut ids = Vec::with_capacity(vectors.len());
        for &v in vectors {
            let id = self.insert(v.to_vec());
            self.doc_id_map.insert(id, doc_id.clone());
            ids.push(id);
        }
        self.multi_doc_map
            .entry(doc_id)
            .or_default()
            .extend_from_slice(&ids);
        ids
    }

    /// Deletes every vector registered for `doc_id` in `multi_doc_map`.
    ///
    /// Removes the document's `multi_doc_map` entry and each id's `doc_id_map`
    /// entry. Returns how many ids [`VectorCollection::delete`] reported as
    /// deleted, or 0 if `doc_id` is unknown.
    pub fn delete_multi_vector(&mut self, doc_id: &str) -> usize {
        let Some(ids) = self.multi_doc_map.remove(doc_id) else {
            return 0;
        };
        let mut deleted = 0;
        for id in ids {
            if self.delete(id) {
                deleted += 1;
            }
            self.doc_id_map.remove(&id);
        }
        deleted
    }

    /// Returns the document id recorded for `vector_id`, if any.
    pub fn get_doc_id(&self, vector_id: u32) -> Option<&str> {
        self.doc_id_map.get(&vector_id).map(String::as_str)
    }

    /// Tombstones the vector with global id `id`.
    ///
    /// Looks for the owning segment in this order: growing, then sealed, then
    /// building. Returns the owning segment's `delete` result, or `false` when
    /// no segment covers `id`.
    pub fn delete(&mut self, id: u32) -> bool {
        // Local slot of `id` in a segment covering `[base, base + len)`.
        let local_slot = |base: u32, len: usize| {
            id.checked_sub(base).filter(|&local| (local as usize) < len)
        };

        if let Some(local) = local_slot(self.growing_base_id, self.growing.len()) {
            return self.growing.delete(local);
        }
        for seg in &mut self.sealed {
            if let Some(local) = local_slot(seg.base_id, seg.index.len()) {
                return seg.index.delete(local);
            }
        }
        for seg in &mut self.building {
            if let Some(local) = local_slot(seg.base_id, seg.flat.len()) {
                return seg.flat.delete(local);
            }
        }
        false
    }

    /// Clears the tombstone on global id `id`.
    ///
    /// Only sealed segments are searched. Ids in the growing or building
    /// segments return `false`.
    pub fn undelete(&mut self, id: u32) -> bool {
        let local_slot = |base: u32, len: usize| {
            id.checked_sub(base).filter(|&local| (local as usize) < len)
        };

        for seg in &mut self.sealed {
            if let Some(local) = local_slot(seg.base_id, seg.index.len()) {
                return seg.index.undelete(local);
            }
        }
        false
    }

    /// Returns `true` once the growing segment has reached `seal_threshold`.
    pub fn needs_seal(&self) -> bool {
        self.growing.len() >= self.seal_threshold
    }

    /// Freezes the growing segment into a building segment and returns the
    /// request for building its HNSW index.
    ///
    /// Returns `None` (and changes nothing) when the growing segment is empty.
    /// Otherwise a fresh growing segment starting at `next_id` replaces the old
    /// one, and `key` is copied into the request.
    pub fn seal(&mut self, key: &str) -> Option<BuildRequest> {
        if self.growing.is_empty() {
            return None;
        }
        let segment_id = self.next_segment_id;
        self.next_segment_id += 1;

        let count = self.growing.len();
        let mut vectors = Vec::with_capacity(count);
        for i in 0..count as u32 {
            // NOTE(review): slots whose `get_vector` is `None` are skipped, which
            // would shift every later slot in the built index relative to
            // `base_id + slot`. Confirm `FlatIndex::get_vector` never returns
            // `None` for `i < len`. Note also that tombstone state is not carried
            // into `BuildRequest`.
            if let Some(v) = self.growing.get_vector(i) {
                vectors.push(v.to_vec());
            }
        }

        let frozen = std::mem::replace(
            &mut self.growing,
            FlatIndex::new(self.dim, self.params.metric),
        );
        // The new growing segment starts exactly where id assignment left off.
        let frozen_base = std::mem::replace(&mut self.growing_base_id, self.next_id);
        self.building.push(BuildingSegment {
            flat: frozen,
            base_id: frozen_base,
            segment_id,
        });

        Some(BuildRequest {
            key: key.to_string(),
            segment_id,
            vectors,
            dim: self.dim,
            params: self.params.clone(),
        })
    }

    /// Installs the finished `index` for building segment `segment_id` as a
    /// sealed segment.
    ///
    /// The new segment keeps the building segment's base id, gets SQ8 data from
    /// [`VectorCollection::build_sq8_for_index`], and gets its storage tier from
    /// `resolve_tier_for_build`. Unknown `segment_id`s are ignored.
    pub fn complete_build(&mut self, segment_id: u32, index: HnswIndex) {
        let Some(pos) = self
            .building
            .iter()
            .position(|b| b.segment_id == segment_id)
        else {
            return;
        };
        let building = self.building.remove(pos);
        // NOTE(review): deletions applied to `building.flat` after `seal` are not
        // replayed onto `index` here (only `base_id` is kept). Confirm the build
        // path accounts for them, or they will reappear as live vectors.
        let sq8 = Self::build_sq8_for_index(&index);
        let (tier, mmap_vectors) = self.resolve_tier_for_build(segment_id, &index);
        self.sealed.push(SealedSegment {
            index,
            base_id: building.base_id,
            sq8,
            tier,
            mmap_vectors,
        });
    }

    /// Builds an SQ8 codec and quantized codes for `index`.
    ///
    /// Returns `None` when `index` has fewer than 1000 live vectors or no live
    /// vector could be read. The codec is calibrated on live vectors only.
    /// The byte buffer holds `dim` codes for every slot `0..index.len()`,
    /// deleted slots included, so slot `i` starts at byte `i * dim`.
    /// Unreadable slots are zero-filled to keep that layout.
    pub fn build_sq8_for_index(index: &HnswIndex) -> Option<(Sq8Codec, Vec<u8>)> {
        // Below this many live vectors no SQ8 data is built.
        const MIN_LIVE_VECTORS: usize = 1000;

        if index.live_count() < MIN_LIVE_VECTORS {
            return None;
        }
        let dim = index.dim();
        let n = index.len();

        let mut calibration: Vec<&[f32]> = Vec::with_capacity(n);
        for i in 0..n {
            if !index.is_deleted(i as u32)
                && let Some(v) = index.get_vector(i as u32)
            {
                calibration.push(v);
            }
        }
        if calibration.is_empty() {
            return None;
        }
        let codec = Sq8Codec::calibrate(&calibration, dim);

        let mut data = Vec::with_capacity(dim * n);
        for i in 0..n {
            match index.get_vector(i as u32) {
                Some(v) => data.extend(codec.quantize(v)),
                // Keep the fixed `i * dim` layout for unreadable slots.
                None => data.resize(data.len() + dim, 0u8),
            }
        }
        Some((codec, data))
    }

    /// Read-only view of the sealed segments.
    pub fn sealed_segments(&self) -> &[SealedSegment] {
        &self.sealed
    }

    /// Mutable access to the sealed segment list.
    pub fn sealed_segments_mut(&mut self) -> &mut Vec<SealedSegment> {
        &mut self.sealed
    }

    /// Returns `true` when the growing segment holds no vectors.
    pub fn growing_is_empty(&self) -> bool {
        self.growing.is_empty()
    }

    /// Compacts every sealed segment's HNSW index and returns the sum of
    /// their `compact()` results.
    ///
    /// Growing and building segments are not compacted.
    pub fn compact(&mut self) -> usize {
        // NOTE(review): `seg.sq8` (laid out per slot) and `seg.mmap_vectors` are
        // not refreshed here. If `HnswIndex::compact` renumbers slots, they — and
        // global ids `base_id + slot` — would no longer line up. Confirm.
        self.sealed
            .iter_mut()
            .map(|seg| seg.index.compact())
            .sum()
    }

    /// Exports `(global id, vector, doc id)` for every readable slot.
    ///
    /// Segments are visited in this order: growing, then sealed (via
    /// `HnswIndex::export_vectors`), then building. The result is therefore not
    /// sorted by id.
    pub fn export_snapshot(&self) -> Vec<(u32, Vec<f32>, Option<String>)> {
        let doc_of = |vid: u32| self.doc_id_map.get(&vid).cloned();
        let mut result = Vec::new();

        for i in 0..self.growing.len() as u32 {
            if let Some(data) = self.growing.get_vector(i) {
                let vid = self.growing_base_id + i;
                result.push((vid, data.to_vec(), doc_of(vid)));
            }
        }
        for seg in &self.sealed {
            // Assumes `export_vectors` yields one entry per slot, in slot order.
            for (i, vec_data) in seg.index.export_vectors().into_iter().enumerate() {
                let vid = seg.base_id + i as u32;
                result.push((vid, vec_data, doc_of(vid)));
            }
        }
        for seg in &self.building {
            for i in 0..seg.flat.len() as u32 {
                if let Some(data) = seg.flat.get_vector(i) {
                    let vid = seg.base_id + i;
                    result.push((vid, data.to_vec(), doc_of(vid)));
                }
            }
        }
        result
    }

    /// Total number of stored slots across all segments, tombstones included.
    pub fn len(&self) -> usize {
        self.growing.len()
            + self.sealed.iter().map(|s| s.index.len()).sum::<usize>()
            + self.building.iter().map(|b| b.flat.len()).sum::<usize>()
    }

    /// Number of non-deleted vectors across all segments.
    pub fn live_count(&self) -> usize {
        self.growing.live_count()
            + self.sealed.iter().map(|s| s.index.live_count()).sum::<usize>()
            + self.building.iter().map(|b| b.flat.live_count()).sum::<usize>()
    }

    /// Returns `true` when no live vectors remain.
    ///
    /// This can be `true` while [`VectorCollection::len`] is non-zero, because
    /// `len` also counts tombstoned slots.
    pub fn is_empty(&self) -> bool {
        self.live_count() == 0
    }

    /// Dimensionality of the stored vectors.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Current HNSW parameters.
    pub fn params(&self) -> &HnswParams {
        &self.params
    }

    /// Replaces the HNSW parameters used by subsequent seals and builds.
    pub fn set_params(&mut self, params: HnswParams) {
        // NOTE(review): the current growing `FlatIndex` keeps the metric it was
        // created with. Only segments created after the next `seal` use
        // `params.metric`.
        self.params = params;
    }

    /// Summarises segment counts, tombstones, quantization, memory, and
    /// configuration.
    ///
    /// `memory_bytes` is an estimate: HNSW `memory_usage_bytes` plus SQ8 code
    /// bytes, plus `dim * 4` bytes per growing or building slot.
    pub fn stats(&self) -> nodedb_types::VectorIndexStats {
        let growing_vectors = self.growing.len();
        let sealed_vectors: usize = self.sealed.iter().map(|s| s.index.len()).sum();
        let building_vectors: usize = self.building.iter().map(|b| b.flat.len()).sum();
        let total = growing_vectors + sealed_vectors + building_vectors;

        let tombstone_count: usize = self.growing.tombstone_count()
            + self
                .sealed
                .iter()
                .map(|s| s.index.tombstone_count())
                .sum::<usize>()
            + self
                .building
                .iter()
                .map(|b| b.flat.tombstone_count())
                .sum::<usize>();
        let tombstone_ratio = if total > 0 {
            tombstone_count as f64 / total as f64
        } else {
            0.0
        };

        let quantization = if self.sealed.iter().any(|s| s.sq8.is_some()) {
            nodedb_types::VectorIndexQuantization::Sq8
        } else {
            nodedb_types::VectorIndexQuantization::None
        };

        let hnsw_mem: usize = self
            .sealed
            .iter()
            .map(|s| s.index.memory_usage_bytes())
            .sum();
        let sq8_mem: usize = self
            .sealed
            .iter()
            .filter_map(|s| s.sq8.as_ref().map(|(_, codes)| codes.len()))
            .sum();
        let raw_vector_bytes = self.dim * std::mem::size_of::<f32>();
        let flat_mem = (growing_vectors + building_vectors) * raw_vector_bytes;
        let memory_bytes = hnsw_mem + sq8_mem + flat_mem;

        let disk_bytes: usize = self
            .sealed
            .iter()
            .filter_map(|s| s.mmap_vectors.as_ref().map(|m| m.file_size()))
            .sum();

        let metric_name = format!("{:?}", self.params.metric).to_lowercase();

        nodedb_types::VectorIndexStats {
            sealed_count: self.sealed.len(),
            building_count: self.building.len(),
            growing_vectors,
            sealed_vectors,
            live_count: self.live_count(),
            tombstone_count,
            tombstone_ratio,
            quantization,
            memory_bytes,
            disk_bytes,
            build_in_progress: !self.building.is_empty(),
            index_type: nodedb_types::VectorIndexType::Hnsw,
            hnsw_m: self.params.m,
            hnsw_m0: self.params.m0,
            hnsw_ef_construction: self.params.ef_construction,
            metric: metric_name,
            dimensions: self.dim,
            seal_threshold: self.seal_threshold,
            mmap_segment_count: self.mmap_segment_count,
        }
    }
}