use crate::distance::DistanceMetric;
use crate::index::hnsw::HnswIndex;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
/// Tunable settings for [`AsyncIndexBuilder`].
///
/// Both fields are optional when deserializing: a missing `merge_threshold`
/// falls back to `default_merge_threshold()` and a missing `segment_count`
/// falls back to `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AsyncIndexBuilderConfig {
    /// Buffer length at (or above) which `AsyncIndexBuilder::enqueue` reports
    /// that the merge threshold has been reached.
    #[serde(default = "default_merge_threshold")]
    pub merge_threshold: usize,

    /// Optional segment count.
    ///
    /// NOTE(review): this value is not read anywhere in this file — confirm
    /// where (and whether) it is consumed.
    #[serde(default)]
    pub segment_count: Option<usize>,
}
/// Returns the merge threshold used when a configuration does not specify one.
///
/// Referenced both by the serde `default` attribute on
/// `AsyncIndexBuilderConfig::merge_threshold` and by that type's `Default` impl.
fn default_merge_threshold() -> usize {
    const DEFAULT_MERGE_THRESHOLD: usize = 10_000;
    DEFAULT_MERGE_THRESHOLD
}
impl Default for AsyncIndexBuilderConfig {
    /// Builds the configuration with the same values serde substitutes for
    /// missing fields: the default merge threshold and no segment count.
    fn default() -> Self {
        let merge_threshold = default_merge_threshold();
        Self {
            merge_threshold,
            segment_count: None,
        }
    }
}
/// Buffers `(id, vector)` pairs and hands them to an HNSW index in batches.
///
/// NOTE(review): `dead_code` is allowed here, presumably because the type is
/// not yet used on every build path — confirm and drop the attribute if so.
#[allow(dead_code)]
pub struct AsyncIndexBuilder {
    /// Pending `(id, vector)` pairs that have not been drained yet.
    buffer: RwLock<Vec<(u64, Vec<f32>)>>,
    /// Settings supplied at construction time.
    config: AsyncIndexBuilderConfig,
    /// `true` while a flush or background build holds the build slot; shared
    /// through an `Arc` so a background thread can clear it.
    building: Arc<AtomicBool>,
}
#[allow(dead_code)] impl AsyncIndexBuilder {
#[must_use]
pub fn new(config: AsyncIndexBuilderConfig) -> Self {
Self {
buffer: RwLock::new(Vec::new()),
config,
building: Arc::new(AtomicBool::new(false)),
}
}
pub fn enqueue(&self, vectors: Vec<(u64, Vec<f32>)>) -> bool {
let mut buf = self.buffer.write();
buf.extend(vectors);
buf.len() >= self.config.merge_threshold
}
#[must_use]
pub fn buffer_len(&self) -> usize {
self.buffer.read().len()
}
pub fn drain_buffer(&self) -> Vec<(u64, Vec<f32>)> {
let mut buf = self.buffer.write();
std::mem::take(&mut *buf)
}
#[must_use]
pub fn search_buffer(
&self,
query: &[f32],
k: usize,
metric: DistanceMetric,
) -> Vec<(u64, f32)> {
let buf = self.buffer.read();
if buf.is_empty() {
return Vec::new();
}
let mut results: Vec<(u64, f32)> = buf
.iter()
.filter(|(_, v)| v.len() == query.len())
.map(|(id, v)| {
let dist = metric.calculate(query, v);
(*id, dist)
})
.collect();
metric.sort_results(&mut results);
results.truncate(k);
results
}
pub fn flush_sync(&self, hnsw_index: &HnswIndex) -> crate::error::Result<usize> {
if self.building.swap(true, Ordering::AcqRel) {
return Ok(0);
}
let vectors = self.drain_buffer();
let count = vectors.len();
if count == 0 {
self.building.store(false, Ordering::Release);
return Ok(0);
}
let pairs: Vec<(u64, &[f32])> = vectors.iter().map(|(id, v)| (*id, v.as_slice())).collect();
let inserted = hnsw_index.insert_batch_parallel(pairs);
self.building.store(false, Ordering::Release);
tracing::debug!("AsyncIndexBuilder::flush_sync: indexed {inserted}/{count} vectors");
Ok(inserted)
}
#[must_use]
pub fn is_building(&self) -> bool {
self.building.load(Ordering::Acquire)
}
pub fn trigger_build_async(&self, hnsw_index: &Arc<HnswIndex>) {
if self.building.swap(true, Ordering::AcqRel) {
return; }
let vectors = self.drain_buffer();
if vectors.is_empty() {
self.building.store(false, Ordering::Release);
return;
}
let index = Arc::clone(hnsw_index);
let flag = Arc::clone(&self.building);
let count = vectors.len();
std::thread::spawn(move || {
let pairs: Vec<(u64, &[f32])> =
vectors.iter().map(|(id, v)| (*id, v.as_slice())).collect();
let _ = index.insert_batch_parallel(pairs);
flag.store(false, Ordering::Release);
tracing::debug!("AsyncIndexBuilder: background build complete ({count} vectors)");
});
}
#[must_use]
pub fn merge_threshold(&self) -> usize {
self.config.merge_threshold
}
}