use super::distance::{DistanceMetric, distance};
use super::flat::FlatIndex;
use super::hnsw::{HnswIndex, HnswParams, SearchResult};
use super::mmap_segment::MmapVectorSegment;
use super::quantize::sq8::Sq8Codec;
use crate::storage::tier::StorageTier;
/// Seal threshold used by `VectorCollection::new` and `VectorCollection::with_seed`:
/// once the growing segment holds at least this many vectors, `needs_seal` returns true.
pub const DEFAULT_SEAL_THRESHOLD: usize = 65_536;
/// Work item returned by `VectorCollection::seal`, carrying everything needed to
/// build an HNSW index for the segment that was just frozen.
pub struct BuildRequest {
/// Caller-supplied key, copied verbatim from the `key` argument of `seal`.
pub key: String,
/// Identifier of the building segment; pass it back to `complete_build`.
pub segment_id: u32,
/// Owned copies of the vectors read from the frozen growing index, in local-id order.
pub vectors: Vec<Vec<f32>>,
/// Vector dimensionality of the collection.
pub dim: usize,
/// Clone of the collection's HNSW parameters.
pub params: HnswParams,
}
/// Result of an index build.
///
/// NOTE(review): not referenced in this part of the file; presumably the
/// counterpart of `BuildRequest` that is fed into `complete_build` — confirm
/// against the caller.
pub struct BuildComplete {
/// Key of the collection the build belongs to (presumably echoed from the request).
pub key: String,
/// Identifier of the segment the index was built for.
pub segment_id: u32,
/// The finished HNSW index.
pub index: HnswIndex,
}
/// A frozen former growing segment whose HNSW index has not been installed yet.
///
/// Until `complete_build` replaces it, searches and deletes are served from `flat`.
pub(super) struct BuildingSegment {
/// The flat index that was the growing segment when `seal` was called.
pub(super) flat: FlatIndex,
/// Global id corresponding to local id 0 of `flat`.
pub(super) base_id: u32,
/// Identifier handed out by `seal`; matched by `complete_build`.
pub(super) segment_id: u32,
}
/// A segment whose HNSW index has been built and installed by `complete_build`.
pub struct SealedSegment {
/// HNSW index over the segment's vectors, addressed by segment-local ids.
pub index: HnswIndex,
/// Global id corresponding to local id 0 of `index`.
pub(super) base_id: u32,
/// Optional SQ8 codec plus quantized vectors laid out as `dim` bytes per local id.
/// When present, `search` scans these bytes and reranks instead of calling the HNSW search.
pub(super) sq8: Option<(Sq8Codec, Vec<u8>)>,
/// Storage tier chosen by `resolve_tier_for_build` (defined outside this view).
pub(super) tier: StorageTier,
/// When present, `search` reads full-precision vectors for reranking from here
/// rather than from `index`.
pub(super) mmap_vectors: Option<MmapVectorSegment>,
}
/// A vector collection split into three kinds of segments: one growing flat
/// index receiving inserts, frozen segments awaiting an HNSW build, and sealed
/// segments backed by HNSW indexes. Global vector ids are `base_id + local_id`.
pub struct VectorCollection {
/// Flat index that receives every new insert.
pub(super) growing: FlatIndex,
/// Global id corresponding to local id 0 of `growing`.
pub(super) growing_base_id: u32,
/// Segments whose HNSW index has been installed.
pub(super) sealed: Vec<SealedSegment>,
/// Segments frozen by `seal` that are still served from their flat index.
pub(super) building: Vec<BuildingSegment>,
/// HNSW parameters; `params.metric` is also the distance metric used for flat and SQ8 search.
pub(super) params: HnswParams,
/// Next global vector id to hand out from `insert`.
pub(super) next_id: u32,
/// Next segment id to hand out from `seal`.
pub(super) next_segment_id: u32,
/// Vector dimensionality.
pub(super) dim: usize,
/// Initialised to `None` by the constructors here.
/// NOTE(review): presumably the on-disk location for mmap segments — confirm where it is set.
pub(super) data_dir: Option<std::path::PathBuf>,
/// Initialised to 0 by the constructors here.
/// NOTE(review): presumably consulted by the tiering logic outside this view — confirm.
pub(super) ram_budget_bytes: usize,
/// Initialised to 0 here; maintained outside this view.
pub(super) mmap_fallback_count: u32,
/// Initialised to 0 here; maintained outside this view.
pub(super) mmap_segment_count: u32,
/// Maps a global vector id to the document id supplied via `insert_with_doc_id`.
pub doc_id_map: std::collections::HashMap<u32, String>,
/// `needs_seal` returns true once `growing` holds at least this many vectors.
pub(super) seal_threshold: usize,
}
impl VectorCollection {
/// Creates an empty collection for `dim`-dimensional vectors using
/// `DEFAULT_SEAL_THRESHOLD` as the seal threshold.
pub fn new(dim: usize, params: HnswParams) -> Self {
Self::with_seal_threshold(dim, params, DEFAULT_SEAL_THRESHOLD)
}
/// Creates an empty collection for `dim`-dimensional vectors that reports
/// `needs_seal` once the growing segment reaches `seal_threshold` vectors.
///
/// All id counters start at zero, no segments exist yet, and the storage
/// related fields (`data_dir`, RAM budget, mmap counters) start unset/zero.
pub fn with_seal_threshold(dim: usize, params: HnswParams, seal_threshold: usize) -> Self {
    // Built up front because it reads `params.metric` before `params` moves into the struct.
    let growing = FlatIndex::new(dim, params.metric);
    Self {
        growing,
        growing_base_id: 0,
        sealed: Vec::new(),
        building: Vec::new(),
        params,
        next_id: 0,
        next_segment_id: 0,
        dim,
        data_dir: None,
        ram_budget_bytes: 0,
        mmap_fallback_count: 0,
        mmap_segment_count: 0,
        doc_id_map: std::collections::HashMap::new(),
        seal_threshold,
    }
}
/// Creates an empty collection with the default seal threshold.
///
/// The `_seed` argument is accepted for API compatibility but is currently
/// ignored: the result is identical to calling `new(dim, params)`.
pub fn with_seed(dim: usize, params: HnswParams, _seed: u64) -> Self {
    Self::new(dim, params)
}
/// Appends `vector` to the growing segment and returns the global id assigned to it.
///
/// Ids are handed out sequentially from `next_id`.
pub fn insert(&mut self, vector: Vec<f32>) -> u32 {
    let assigned = self.next_id;
    self.growing.insert(vector);
    // Advance the counter only after the vector has been stored.
    self.next_id = assigned + 1;
    assigned
}
/// Inserts `vector` like `insert` and records `doc_id` for the returned global id
/// in `doc_id_map`. Returns the assigned global id.
pub fn insert_with_doc_id(&mut self, vector: Vec<f32>, doc_id: String) -> u32 {
let id = self.insert(vector);
self.doc_id_map.insert(id, doc_id);
id
}
/// Returns the document id recorded for `vector_id`, or `None` if the vector
/// was inserted without one (or the id is unknown).
pub fn get_doc_id(&self, vector_id: u32) -> Option<&str> {
    let recorded = self.doc_id_map.get(&vector_id);
    recorded.map(String::as_str)
}
/// Marks the vector with global id `id` as deleted.
///
/// The growing segment is checked first, then sealed segments, then building
/// segments. The first segment whose id range `[base_id, base_id + len)`
/// contains `id` handles the request and its result is returned. Returns
/// `false` when no segment covers `id`.
pub fn delete(&mut self, id: u32) -> bool {
    // Translates a global id into a segment-local id when it falls inside the segment.
    fn local_id(id: u32, base_id: u32, len: usize) -> Option<u32> {
        id.checked_sub(base_id)
            .filter(|&local| (local as usize) < len)
    }

    if let Some(local) = local_id(id, self.growing_base_id, self.growing.len()) {
        return self.growing.delete(local);
    }
    for sealed in self.sealed.iter_mut() {
        if let Some(local) = local_id(id, sealed.base_id, sealed.index.len()) {
            return sealed.index.delete(local);
        }
    }
    for pending in self.building.iter_mut() {
        if let Some(local) = local_id(id, pending.base_id, pending.flat.len()) {
            return pending.flat.delete(local);
        }
    }
    false
}
/// Reverses a delete for the vector with global id `id`.
///
/// Only sealed segments are consulted: the first one whose id range contains
/// `id` handles the request and its result is returned. Returns `false` when
/// no sealed segment covers `id`.
///
/// NOTE(review): unlike `delete`, the growing and building segments are not
/// checked — confirm whether that asymmetry is intentional.
pub fn undelete(&mut self, id: u32) -> bool {
    for sealed in self.sealed.iter_mut() {
        // Ids below this segment's base cannot belong to it.
        let Some(local) = id.checked_sub(sealed.base_id) else {
            continue;
        };
        if (local as usize) < sealed.index.len() {
            return sealed.index.undelete(local);
        }
    }
    false
}
/// Returns up to `top_k` nearest neighbours of `query` across all segments,
/// ordered by ascending distance, with ids translated to global ids.
///
/// * Growing and building segments are searched through their flat index.
/// * Sealed segments without SQ8 data use the HNSW search with beam width `ef`.
/// * Sealed segments with SQ8 data are scanned in quantized form, the best
///   `max(3 * top_k, 20)` candidates are re-scored with full-precision vectors
///   (from the mmap segment when present, otherwise from the index) and the
///   best `top_k` of those are kept.
///
/// Distances that are NaN are ordered after every real distance. The previous
/// comparator treated NaN as equal to everything, which is not a total order:
/// `sort_by` / `select_nth_unstable_by` may panic on such comparators and NaN
/// results could end up anywhere in the ranking.
pub fn search(&self, query: &[f32], top_k: usize, ef: usize) -> Vec<SearchResult> {
    // Total order on distances: regular floats compare as usual, NaN sorts last.
    fn by_distance(a: f32, b: f32) -> std::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    let metric = self.params.metric;
    let mut all: Vec<SearchResult> = Vec::new();

    for mut r in self.growing.search(query, top_k) {
        r.id += self.growing_base_id;
        all.push(r);
    }

    for seg in &self.sealed {
        let results = if let Some((codec, sq8_data)) = &seg.sq8 {
            // Over-fetch so that quantization error does not push true neighbours out.
            let rerank_k = top_k.saturating_mul(3).max(20);
            let dim = seg.index.dim();

            // Pass 1: approximate distance of every live vector from its SQ8 bytes.
            let mut candidates: Vec<(u32, f32)> = Vec::with_capacity(seg.index.len());
            for i in 0..seg.index.len() {
                if seg.index.is_deleted(i as u32) {
                    continue;
                }
                let sq8_vec = &sq8_data[i * dim..(i + 1) * dim];
                let d = match metric {
                    DistanceMetric::L2 => codec.asymmetric_l2(query, sq8_vec),
                    DistanceMetric::Cosine => codec.asymmetric_cosine(query, sq8_vec),
                    DistanceMetric::InnerProduct => codec.asymmetric_ip(query, sq8_vec),
                    // No asymmetric kernel for this metric: dequantize and use the generic path.
                    _ => {
                        let dequant = codec.dequantize(sq8_vec);
                        distance(query, &dequant, metric)
                    }
                };
                candidates.push((i as u32, d));
            }

            // Keep only the `rerank_k` best approximate candidates (unordered).
            if candidates.len() > rerank_k {
                candidates.select_nth_unstable_by(rerank_k, |a, b| by_distance(a.1, b.1));
                candidates.truncate(rerank_k);
            }

            // Pass 2: exact distances from full-precision vectors; candidates whose
            // vector cannot be fetched are dropped.
            let mut reranked: Vec<SearchResult> = candidates
                .iter()
                .filter_map(|&(id, _)| {
                    let v = if let Some(mmap) = &seg.mmap_vectors {
                        mmap.get_vector(id)?
                    } else {
                        seg.index.get_vector(id)?
                    };
                    Some(SearchResult {
                        id,
                        distance: distance(query, v, metric),
                    })
                })
                .collect();
            reranked.sort_by(|a, b| by_distance(a.distance, b.distance));
            reranked.truncate(top_k);
            reranked
        } else {
            seg.index.search(query, top_k, ef)
        };
        for mut r in results {
            r.id += seg.base_id;
            all.push(r);
        }
    }

    for seg in &self.building {
        for mut r in seg.flat.search(query, top_k) {
            r.id += seg.base_id;
            all.push(r);
        }
    }

    // Merge the per-segment top lists into one global top-k.
    all.sort_by(|a, b| by_distance(a.distance, b.distance));
    all.truncate(top_k);
    all
}
/// Filtered variant of `search`: returns up to `top_k` nearest neighbours of
/// `query`, ordered by ascending distance, with ids translated to global ids.
///
/// `bitmap` is forwarded unchanged to every segment's filtered search; sealed
/// segments always use the HNSW filtered search with beam width `ef` (the SQ8
/// path is not used here).
///
/// NOTE(review): the same `bitmap` bytes are handed to every segment while
/// each segment works with local ids — confirm how the segment-level filters
/// interpret the bitmap.
///
/// Distances that are NaN are ordered after every real distance. The previous
/// comparator treated NaN as equal to everything, which is not a total order:
/// `sort_by` may panic on such comparators and NaN results could end up
/// anywhere in the ranking.
pub fn search_with_bitmap_bytes(
    &self,
    query: &[f32],
    top_k: usize,
    ef: usize,
    bitmap: &[u8],
) -> Vec<SearchResult> {
    // Total order on distances: regular floats compare as usual, NaN sorts last.
    fn by_distance(a: f32, b: f32) -> std::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }

    let mut all: Vec<SearchResult> = Vec::new();

    for mut r in self.growing.search_filtered(query, top_k, bitmap) {
        r.id += self.growing_base_id;
        all.push(r);
    }
    for seg in &self.sealed {
        for mut r in seg.index.search_with_bitmap_bytes(query, top_k, ef, bitmap) {
            r.id += seg.base_id;
            all.push(r);
        }
    }
    for seg in &self.building {
        for mut r in seg.flat.search_filtered(query, top_k, bitmap) {
            r.id += seg.base_id;
            all.push(r);
        }
    }

    // Merge the per-segment top lists into one global top-k.
    all.sort_by(|a, b| by_distance(a.distance, b.distance));
    all.truncate(top_k);
    all
}
/// Returns true once the growing segment holds at least `seal_threshold` vectors
/// (deleted slots included, since `len` is used rather than `live_count`).
pub fn needs_seal(&self) -> bool {
self.growing.len() >= self.seal_threshold
}
/// Freezes the growing segment and returns a request to build its HNSW index.
///
/// Returns `None` (and changes nothing) when the growing segment is empty.
/// Otherwise a fresh empty growing index is installed whose base id is the
/// current `next_id`, the old index is parked in `building` under a newly
/// allocated segment id, and the request carries owned copies of its vectors
/// together with `key`, the dimensionality and a clone of the HNSW params.
///
/// NOTE(review): slots for which `get_vector` returns `None` are skipped, so
/// later vectors would shift position inside `BuildRequest::vectors` — confirm
/// that the flat index never returns `None` for ids below `len()`.
pub fn seal(&mut self, key: &str) -> Option<BuildRequest> {
    if self.growing.is_empty() {
        return None;
    }

    let segment_id = self.next_segment_id;
    self.next_segment_id += 1;

    // Swap in a fresh growing index; the frozen one is read from below.
    let fresh = FlatIndex::new(self.dim, self.params.metric);
    let frozen = std::mem::replace(&mut self.growing, fresh);
    let base_id = std::mem::replace(&mut self.growing_base_id, self.next_id);

    let slot_count = frozen.len();
    let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(slot_count);
    vectors.extend(
        (0..slot_count as u32)
            .filter_map(|local| frozen.get_vector(local))
            .map(|v| v.to_vec()),
    );

    self.building.push(BuildingSegment {
        flat: frozen,
        base_id,
        segment_id,
    });

    Some(BuildRequest {
        key: key.to_owned(),
        segment_id,
        vectors,
        dim: self.dim,
        params: self.params.clone(),
    })
}
/// Installs a finished HNSW `index` for the building segment `segment_id`.
///
/// The building segment is removed and replaced by a sealed segment that
/// keeps its base id; SQ8 data and the storage tier / mmap vectors are
/// derived from the new index. Does nothing when no building segment has
/// the given id.
pub fn complete_build(&mut self, segment_id: u32, index: HnswIndex) {
    let Some(pos) = self
        .building
        .iter()
        .position(|pending| pending.segment_id == segment_id)
    else {
        // Unknown or already completed segment: ignore.
        return;
    };

    let finished = self.building.remove(pos);
    let sq8 = Self::build_sq8_for_index(&index);
    let (tier, mmap_vectors) = self.resolve_tier_for_build(segment_id, &index);
    let sealed = SealedSegment {
        index,
        base_id: finished.base_id,
        sq8,
        tier,
        mmap_vectors,
    };
    self.sealed.push(sealed);
}
/// Builds SQ8-quantized data for `index`, or returns `None` when the index has
/// fewer than 1000 live vectors or no live vector can be read.
///
/// The codec is calibrated on live (non-deleted) vectors only. The returned
/// byte buffer holds `dim` bytes for every local id in `0..index.len()`:
/// the quantized vector when it can be read, zeros otherwise, so that the
/// buffer can be sliced by local id.
pub(super) fn build_sq8_for_index(index: &HnswIndex) -> Option<(Sq8Codec, Vec<u8>)> {
    if index.live_count() < 1000 {
        return None;
    }
    let dim = index.dim();
    let total = index.len();

    // Calibration set: every readable vector that is not tombstoned.
    let mut live: Vec<&[f32]> = Vec::with_capacity(total);
    for slot in 0..total {
        let id = slot as u32;
        if index.is_deleted(id) {
            continue;
        }
        if let Some(v) = index.get_vector(id) {
            live.push(v);
        }
    }
    if live.is_empty() {
        return None;
    }
    let codec = Sq8Codec::calibrate(&live, dim);

    // One `dim`-byte stripe per local id keeps offsets aligned with ids.
    let mut data = Vec::with_capacity(dim * total);
    for slot in 0..total {
        match index.get_vector(slot as u32) {
            Some(v) => data.extend(codec.quantize(v)),
            None => data.resize(data.len() + dim, 0u8),
        }
    }
    Some((codec, data))
}
/// Returns the sealed segments in the order they were installed by `complete_build`.
pub fn sealed_segments(&self) -> &[SealedSegment] {
&self.sealed
}
/// Compacts the HNSW index of every sealed segment and returns the sum of
/// the counts reported by each index's `compact`.
///
/// Growing and building segments are not touched.
///
/// NOTE(review): a segment's `sq8` bytes and `mmap_vectors` are not refreshed
/// here; if `HnswIndex::compact` renumbers local ids they would go stale —
/// confirm against the index implementation.
pub fn compact(&mut self) -> usize {
    self.sealed
        .iter_mut()
        .map(|seg| seg.index.compact())
        .sum()
}
/// Total number of vector slots across the growing, sealed and building
/// segments, as reported by each index's `len`.
pub fn len(&self) -> usize {
    let sealed: usize = self.sealed.iter().map(|seg| seg.index.len()).sum();
    let building: usize = self.building.iter().map(|seg| seg.flat.len()).sum();
    self.growing.len() + sealed + building
}
/// Number of live vectors across the growing, sealed and building segments,
/// as reported by each index's `live_count`.
pub fn live_count(&self) -> usize {
    let sealed: usize = self.sealed.iter().map(|seg| seg.index.live_count()).sum();
    let building: usize = self
        .building
        .iter()
        .map(|seg| seg.flat.live_count())
        .sum();
    self.growing.live_count() + sealed + building
}
/// Returns true when the collection holds no live vectors. Note that this is
/// based on `live_count`, so it can be true while `len()` is non-zero.
pub fn is_empty(&self) -> bool {
self.live_count() == 0
}
/// Returns the vector dimensionality the collection was created with.
pub fn dim(&self) -> usize {
self.dim
}
/// Returns the HNSW parameters the collection was created with.
pub fn params(&self) -> &HnswParams {
&self.params
}
}
// Unit tests covering insert/search, the seal -> build -> sealed lifecycle,
// checkpointing and deletion.
#[cfg(test)]
mod tests {
use super::*;
// Helper: empty 3-dimensional collection using the L2 metric.
fn make_collection() -> VectorCollection {
VectorCollection::new(
3,
HnswParams {
metric: DistanceMetric::L2,
..HnswParams::default()
},
)
}
// Vectors inserted into the growing segment are found by exact-match search.
#[test]
fn insert_and_search() {
let mut coll = make_collection();
for i in 0..100u32 {
coll.insert(vec![i as f32, 0.0, 0.0]);
}
assert_eq!(coll.len(), 100);
let results = coll.search(&[50.0, 0.0, 0.0], 3, 64);
assert_eq!(results.len(), 3);
assert_eq!(results[0].id, 50);
}
// Sealing at the default threshold moves the growing data into a building
// segment, which must remain searchable.
#[test]
fn seal_moves_to_building() {
let mut coll = VectorCollection::new(2, HnswParams::default());
for i in 0..DEFAULT_SEAL_THRESHOLD {
coll.insert(vec![i as f32, 0.0]);
}
assert!(coll.needs_seal());
let req = coll.seal("test_key").unwrap();
assert_eq!(req.vectors.len(), DEFAULT_SEAL_THRESHOLD);
assert_eq!(coll.building.len(), 1);
assert_eq!(coll.growing.len(), 0);
let results = coll.search(&[100.0, 0.0], 1, 64);
assert!(!results.is_empty());
}
// Completing a build replaces the building segment with a sealed one.
#[test]
fn complete_build_promotes_to_sealed() {
let mut coll = VectorCollection::new(2, HnswParams::default());
for i in 0..100 {
coll.insert(vec![i as f32, 0.0]);
}
let req = coll.seal("test").unwrap();
let mut index = HnswIndex::new(req.dim, req.params);
for v in &req.vectors {
index.insert(v.clone());
}
coll.complete_build(req.segment_id, index);
assert_eq!(coll.building.len(), 0);
assert_eq!(coll.sealed.len(), 1);
let results = coll.search(&[50.0, 0.0], 3, 64);
assert!(!results.is_empty());
}
// A checkpoint restores length, dimensionality and search results.
// `checkpoint_to_bytes` / `from_checkpoint` are defined outside this part of the file.
#[test]
fn checkpoint_roundtrip() {
let mut coll = make_collection();
for i in 0..50u32 {
coll.insert(vec![i as f32, 0.0, 0.0]);
}
let bytes = coll.checkpoint_to_bytes();
let restored = VectorCollection::from_checkpoint(&bytes).unwrap();
assert_eq!(restored.len(), 50);
assert_eq!(restored.dim(), 3);
let results = restored.search(&[25.0, 0.0, 0.0], 1, 64);
assert_eq!(results[0].id, 25);
}
// Results from a sealed segment (ids 0..100) and the growing segment
// (ids 100..200) are merged, and ids are reported as global ids.
#[test]
fn multi_segment_search_merges() {
let mut coll = VectorCollection::new(
2,
HnswParams {
metric: DistanceMetric::L2,
..HnswParams::default()
},
);
for i in 0..100 {
coll.insert(vec![i as f32, 0.0]);
}
let req = coll.seal("test").unwrap();
let mut idx = HnswIndex::new(2, req.params);
for v in &req.vectors {
idx.insert(v.clone());
}
coll.complete_build(req.segment_id, idx);
for i in 100..200 {
coll.insert(vec![i as f32, 0.0]);
}
let results = coll.search(&[150.0, 0.0], 3, 64);
assert_eq!(results.len(), 3);
assert_eq!(results[0].id, 150);
}
// A deleted vector no longer counts as live and never appears in results.
// Despite the name, only the growing segment is exercised here.
#[test]
fn delete_across_segments() {
let mut coll = VectorCollection::new(2, HnswParams::default());
for i in 0..10 {
coll.insert(vec![i as f32, 0.0]);
}
assert!(coll.delete(5));
assert_eq!(coll.live_count(), 9);
let results = coll.search(&[5.0, 0.0], 10, 64);
assert!(results.iter().all(|r| r.id != 5));
}
}