use std::collections::HashMap;
use nodedb_types::{Surrogate, VectorQuantization};
use crate::flat::FlatIndex;
use crate::hnsw::{HnswIndex, HnswParams};
use crate::index_config::{IndexConfig, IndexType};
use super::codec_dispatch::CollectionCodec;
use super::payload_index::PayloadIndexSet;
use super::segment::{BuildRequest, BuildingSegment, DEFAULT_SEAL_THRESHOLD, SealedSegment};
/// A segmented vector collection.
///
/// Storage is split into three stages, as driven by the methods on this type:
/// - `growing`: a flat index receiving the newest vectors;
/// - `building`: flat segments that [`VectorCollection::seal`] has frozen and
///   for which a `BuildRequest` was emitted, awaiting an HNSW graph;
/// - `sealed`: segments whose HNSW graph was installed by
///   [`VectorCollection::complete_build`].
pub struct VectorCollection {
    /// Flat index holding the most recent, not-yet-sealed vectors.
    pub(crate) growing: FlatIndex,
    /// Base id of `growing` (presumably the id of its first vector); set to
    /// `next_id` each time `seal` starts a fresh growing segment.
    pub(crate) growing_base_id: u32,
    /// Segments whose HNSW build has completed.
    pub(crate) sealed: Vec<SealedSegment>,
    /// Segments that were sealed but whose HNSW build has not completed yet.
    pub(crate) building: Vec<BuildingSegment>,
    /// HNSW parameters sent with every `BuildRequest`; `params.metric` also
    /// configures newly created flat (growing) indexes.
    pub(crate) params: HnswParams,
    /// Next vector id. NOTE(review): presumably advanced by insert code that is
    /// not visible here; `seal` only reads it — confirm.
    pub(crate) next_id: u32,
    /// Id given to the next sealed segment; `seal` increments it.
    pub(crate) next_segment_id: u32,
    /// Dimensionality of every vector in the collection.
    pub(crate) dim: usize,
    /// Optional data directory (None at construction).
    /// NOTE(review): presumably used for mmap-backed segments — confirm.
    pub(crate) data_dir: Option<std::path::PathBuf>,
    /// RAM budget in bytes (0 at construction).
    /// NOTE(review): whether 0 means "unlimited" is not visible here — confirm.
    pub(crate) ram_budget_bytes: usize,
    /// NOTE(review): counters starting at 0; presumably maintained by the
    /// tier-resolution logic outside this view — confirm.
    pub(crate) mmap_fallback_count: u32,
    pub(crate) mmap_segment_count: u32,
    /// Surrogate for each local (u32) vector id.
    pub surrogate_map: HashMap<u32, Surrogate>,
    /// Local vector id for each surrogate (presumably the inverse of `surrogate_map`).
    pub surrogate_to_local: HashMap<Surrogate, u32>,
    /// Surrogates owning several local vector ids (presumably multi-vector documents).
    pub multi_doc_map: HashMap<Surrogate, Vec<u32>>,
    /// `needs_seal` returns true once `growing.len()` reaches this value.
    pub(crate) seal_threshold: usize,
    /// Index configuration; `complete_build` reads `index_type` and `pq_m`
    /// to choose between PQ and SQ8 for each sealed segment.
    pub(crate) index_config: IndexConfig,
    /// Collection-level codec (None at construction). NOTE(review): presumably
    /// populated by `build_codec_dispatch`, which `complete_build` calls for
    /// RaBitQ / Bbq quantization — confirm.
    pub codec_dispatch: Option<CollectionCodec>,
    /// Quantization mode; RaBitQ / Bbq bypass per-segment SQ8/PQ in `complete_build`.
    pub(crate) quantization: VectorQuantization,
    /// Payload (metadata) indexes; `configure_payload_indexes` adds equality indexes.
    pub payload: PayloadIndexSet,
    /// NOTE(review): meaning not visible in this file (None at construction) — confirm with callers.
    pub arena_index: Option<u32>,
}
impl VectorCollection {
pub fn new(dim: usize, params: HnswParams) -> Self {
Self::with_seal_threshold(dim, params, DEFAULT_SEAL_THRESHOLD)
}
/// Creates an empty collection with the given HNSW parameters and a custom
/// seal threshold.
///
/// Every other index setting comes from `IndexConfig::default()`.
pub fn with_seal_threshold(dim: usize, params: HnswParams, seal_threshold: usize) -> Self {
    // `params` is not used after this point, so move it into the config
    // instead of cloning it.
    let index_config = IndexConfig {
        hnsw: params,
        ..IndexConfig::default()
    };
    Self::with_seal_threshold_and_config(dim, index_config, seal_threshold)
}
/// Creates an empty collection from a full `IndexConfig`, using the default
/// seal threshold.
pub fn with_index_config(dim: usize, config: IndexConfig) -> Self {
    let threshold = DEFAULT_SEAL_THRESHOLD;
    Self::with_seal_threshold_and_config(dim, config, threshold)
}
/// Primary constructor: an empty collection with an explicit index
/// configuration and seal threshold.
///
/// The standalone `params` field starts as a copy of `config.hnsw`. The initial
/// growing flat index uses that metric. All counters start at zero, all maps
/// and segment lists start empty, and optional features start unset.
pub fn with_seal_threshold_and_config(
    dim: usize,
    config: IndexConfig,
    seal_threshold: usize,
) -> Self {
    let hnsw_params = config.hnsw.clone();
    let growing = FlatIndex::new(dim, hnsw_params.metric);
    Self {
        // Shape and configuration.
        dim,
        seal_threshold,
        params: hnsw_params,
        index_config: config,
        quantization: VectorQuantization::default(),
        // Segment storage.
        growing,
        growing_base_id: 0,
        sealed: Vec::new(),
        building: Vec::new(),
        // Id allocation.
        next_id: 0,
        next_segment_id: 0,
        // Storage tiering.
        data_dir: None,
        ram_budget_bytes: 0,
        mmap_fallback_count: 0,
        mmap_segment_count: 0,
        // Surrogate bookkeeping.
        surrogate_map: HashMap::new(),
        surrogate_to_local: HashMap::new(),
        multi_doc_map: HashMap::new(),
        // Optional auxiliary structures.
        codec_dispatch: None,
        payload: PayloadIndexSet::default(),
        arena_index: None,
    }
}
/// Equivalent to [`VectorCollection::new`].
///
/// The seed argument is currently ignored; it is accepted only so callers that
/// pass a seed keep compiling.
pub fn with_seed(dim: usize, params: HnswParams, _seed: u64) -> Self {
    Self::new(dim, params)
}
/// Returns true when the growing segment has reached the seal threshold.
pub fn needs_seal(&self) -> bool {
    let threshold = self.seal_threshold;
    threshold <= self.growing.len()
}
/// Freezes the current growing segment and returns a request to build its HNSW graph.
///
/// Returns `None` when the growing segment is empty. Otherwise:
/// - reserves a new segment id;
/// - copies the growing vectors into the request;
/// - replaces `growing` with a fresh flat index whose base id is `next_id`;
/// - parks the old flat index in `building` under the reserved id.
///
/// The caller is expected to run the build and hand the result to
/// `complete_build` with the same segment id.
pub fn seal(&mut self, key: &str) -> Option<BuildRequest> {
    if self.growing.is_empty() {
        return None;
    }

    let segment_id = self.next_segment_id;
    self.next_segment_id += 1;

    // NOTE(review): slots for which `get_vector` returns None are skipped. If
    // that can happen, later vectors shift down and the built graph's local ids
    // no longer line up with `base_id + i`. Confirm that `get_vector` is always
    // Some for `0..len()`.
    let row_count = self.growing.len();
    let mut snapshot = Vec::with_capacity(row_count);
    snapshot.extend(
        (0..row_count as u32).filter_map(|local| self.growing.get_vector(local).map(|v| v.to_vec())),
    );

    let fresh = FlatIndex::new(self.dim, self.params.metric);
    let retired = std::mem::replace(&mut self.growing, fresh);
    let retired_base = std::mem::replace(&mut self.growing_base_id, self.next_id);

    self.building.push(BuildingSegment {
        flat: retired,
        base_id: retired_base,
        segment_id,
    });

    Some(BuildRequest {
        key: String::from(key),
        segment_id,
        vectors: snapshot,
        dim: self.dim,
        params: self.params.clone(),
    })
}
/// Installs a finished HNSW graph for the building segment `segment_id`.
///
/// The matching entry is removed from `building` and pushed onto `sealed`,
/// together with the quantized form chosen by the collection's settings:
/// - `RaBitQ` / `Bbq` quantization: no per-segment SQ8/PQ. The collection-level
///   codec dispatch is rebuilt after the segment is added.
/// - `IndexType::HnswPq`: a PQ representation built with `index_config.pq_m`.
/// - otherwise: an SQ8 representation.
///
/// If no building segment has `segment_id` (for example a duplicate or stale
/// completion), the call does nothing and `index` is dropped.
pub fn complete_build(&mut self, segment_id: u32, index: HnswIndex) {
    if let Some(pos) = self
        .building
        .iter()
        .position(|b| b.segment_id == segment_id)
    {
        let building = self.building.remove(pos);

        // Resolve the codec-dispatch tag once. "Uses codec dispatch" and
        // "which tag" then come from the same match, so they cannot disagree
        // and no `unreachable!` is needed later.
        let codec_tag = match self.quantization {
            VectorQuantization::RaBitQ => Some("rabitq"),
            VectorQuantization::Bbq => Some("bbq"),
            _ => None,
        };

        let (sq8, pq) = match codec_tag {
            // Codec dispatch replaces per-segment SQ8/PQ.
            Some(_) => (None, None),
            None if self.index_config.index_type == IndexType::HnswPq => (
                None,
                Self::build_pq_for_index(&index, self.index_config.pq_m),
            ),
            None => (Self::build_sq8_for_index(&index), None),
        };

        let (tier, mmap_vectors) =
            self.resolve_tier_for_build(segment_id, building.base_id, &index);
        self.sealed.push(SealedSegment {
            index,
            base_id: building.base_id,
            sq8,
            pq,
            tier,
            mmap_vectors,
        });

        // Rebuild after the push so the new segment is included.
        if let Some(tag) = codec_tag {
            self.build_codec_dispatch(tag);
        }
    }
}
/// Read-only view of the segments whose HNSW build has completed.
pub fn sealed_segments(&self) -> &[SealedSegment] {
    self.sealed.as_slice()
}
/// Mutable access to the sealed-segment list itself.
///
/// Callers can modify segments in place and also add or remove entries.
pub fn sealed_segments_mut(&mut self) -> &mut Vec<SealedSegment> {
    let segments = &mut self.sealed;
    segments
}
/// Returns true when the growing (unsealed) segment reports itself empty.
pub fn growing_is_empty(&self) -> bool {
    let growing = &self.growing;
    growing.is_empty()
}
/// Total vector count across growing, sealed, and building segments.
///
/// Each segment reports its own `len()`. NOTE(review): this presumably
/// includes deleted entries; see `live_count` for the live-only total.
pub fn len(&self) -> usize {
    let sealed_total: usize = self.sealed.iter().map(|s| s.index.len()).sum();
    let building_total: usize = self.building.iter().map(|b| b.flat.len()).sum();
    self.growing.len() + sealed_total + building_total
}
/// Sum of each segment's `live_count()` across growing, sealed, and building segments.
pub fn live_count(&self) -> usize {
    let sealed_live: usize = self.sealed.iter().map(|s| s.index.live_count()).sum();
    let building_live: usize = self.building.iter().map(|b| b.flat.live_count()).sum();
    self.growing.live_count() + sealed_live + building_live
}
/// Returns true when no live vectors remain.
///
/// This is based on `live_count`, not `len`, so it can return true while
/// `len()` is non-zero.
pub fn is_empty(&self) -> bool {
    matches!(self.live_count(), 0)
}
/// Dimensionality of vectors stored in this collection.
pub fn dim(&self) -> usize {
    let d = self.dim;
    d
}
/// HNSW parameters used for future segment builds.
pub fn params(&self) -> &HnswParams {
    let p = &self.params;
    p
}
/// Replaces the HNSW parameters used for future segment builds.
///
/// The constructor starts `params` and `index_config.hnsw` equal. Both are
/// updated here so the stored index configuration does not go stale.
/// Existing segments are not rebuilt. The current growing flat index keeps the
/// metric it was created with; a new metric applies from the next `seal`
/// onwards, because that is when a fresh growing index is created.
pub fn set_params(&mut self, params: HnswParams) {
    self.index_config.hnsw = params.clone();
    self.params = params;
}
/// Sets the quantization mode.
///
/// `complete_build` reads this mode for each segment it installs; segments
/// already sealed are left untouched.
pub fn set_quantization(&mut self, q: VectorQuantization) {
    let mode = q;
    self.quantization = mode;
}
/// Current quantization mode.
pub fn quantization(&self) -> VectorQuantization {
    let mode = self.quantization;
    mode
}
/// Registers an equality payload index for each named field, in order.
pub fn configure_payload_indexes(&mut self, fields: &[String]) {
    use super::payload_index::PayloadIndexKind;
    for name in fields.iter().map(String::as_str) {
        self.payload.add_index(name, PayloadIndexKind::Equality);
    }
}
}