use super::buffer::{StreamBuffer, StreamBufferConfig};
use super::ops::UpdateStats;
use super::StreamingStats;
use crate::error::Result;
use std::time::Instant;
/// Configuration for a [`StreamingCoordinator`].
#[derive(Debug, Clone)]
pub struct StreamingConfig {
/// Settings passed (cloned) to the [`StreamBuffer`] when the coordinator is built.
pub buffer: StreamBufferConfig,
/// When `true`, `insert`, `delete` and `update` run a compaction as soon as the
/// buffer reports that it needs one.
pub auto_compact: bool,
/// When `true`, `search` merges buffered hits with main-index hits; when `false`,
/// only the main index is queried (ids with a pending delete are still hidden).
pub merge_search_results: bool,
}
impl Default for StreamingConfig {
fn default() -> Self {
Self {
buffer: StreamBufferConfig::default(),
auto_compact: true,
merge_search_results: true,
}
}
}
/// Pairs a main index of type `I` with a [`StreamBuffer`] that holds pending
/// inserts and deletes until a compaction applies them to the index.
pub struct StreamingCoordinator<I> {
/// The wrapped main index; exposed through `inner` / `inner_mut`.
index: I,
/// Pending inserts and deletes not yet applied to `index`.
buffer: StreamBuffer,
/// Settings this coordinator was constructed with.
config: StreamingConfig,
/// Number of inserts accepted by the buffer since construction.
total_inserts: u64,
/// Number of deletes recorded in the buffer since construction.
total_deletes: u64,
/// Number of times `compact` has run since construction.
total_compactions: u64,
}
impl<I> StreamingCoordinator<I> {
pub fn new(index: I) -> Self {
Self::with_config(index, StreamingConfig::default())
}
pub fn with_config(index: I, config: StreamingConfig) -> Self {
Self {
index,
buffer: StreamBuffer::with_config(config.buffer.clone()),
config,
total_inserts: 0,
total_deletes: 0,
total_compactions: 0,
}
}
pub fn buffer_insert(&mut self, id: u32, vector: Vec<f32>) -> Result<()> {
self.buffer.insert(id, vector)?;
self.total_inserts += 1;
Ok(())
}
pub fn buffer_delete(&mut self, id: u32) {
self.buffer.delete(id);
self.total_deletes += 1;
}
pub fn needs_compaction(&self) -> bool {
self.buffer.needs_compaction()
}
pub fn stats(&self) -> StreamingStats {
StreamingStats {
main_index_size: 0, buffer_size: self.buffer.insert_count(),
pending_deletes: self.buffer.delete_count(),
total_inserts: self.total_inserts,
total_deletes: self.total_deletes,
total_compactions: self.total_compactions,
}
}
pub fn inner(&self) -> &I {
&self.index
}
pub fn inner_mut(&mut self) -> &mut I {
&mut self.index
}
pub fn buffer(&self) -> &StreamBuffer {
&self.buffer
}
}
impl<I: IndexOps> StreamingCoordinator<I> {
    /// Buffers an insert of `vector` under `id`, then compacts if auto-compaction
    /// is enabled and the buffer reports it is due.
    ///
    /// # Errors
    /// Propagates errors from the buffer's `insert` (in which case nothing is
    /// counted) and from a triggered compaction.
    pub fn insert(&mut self, id: u32, vector: Vec<f32>) -> Result<()> {
        self.buffer.insert(id, vector)?;
        self.total_inserts += 1;
        self.compact_if_due()
    }

    /// Buffers a delete of `id`, then compacts if auto-compaction is enabled and
    /// the buffer reports it is due.
    ///
    /// # Errors
    /// Propagates errors from a triggered compaction.
    pub fn delete(&mut self, id: u32) -> Result<()> {
        self.buffer.delete(id);
        self.total_deletes += 1;
        self.compact_if_due()
    }

    /// Replaces the vector stored under `id` by buffering a delete followed by an
    /// insert, then compacts if auto-compaction is enabled and due.
    ///
    /// # Errors
    /// Propagates errors from the buffer's `insert` and from a triggered
    /// compaction. If the buffer rejects the new vector, the delete recorded just
    /// before it stays buffered (and is counted), so the old vector is still
    /// scheduled for removal.
    pub fn update(&mut self, id: u32, vector: Vec<f32>) -> Result<()> {
        self.buffer.delete(id);
        // Count the delete as soon as it is buffered so the counter matches the
        // buffer state even when the insert below fails.
        self.total_deletes += 1;
        self.buffer.insert(id, vector)?;
        self.total_inserts += 1;
        self.compact_if_due()
    }

    /// Runs a compaction when auto-compaction is enabled and the buffer asks for one.
    fn compact_if_due(&mut self) -> Result<()> {
        if self.config.auto_compact && self.buffer.needs_compaction() {
            self.compact()?;
        }
        Ok(())
    }

    /// Drains the buffer and applies its contents to the main index: every
    /// pending delete first, then every pending insert.
    ///
    /// Failures of individual index operations do not abort the run; they are
    /// tallied in `errors`, and the failed operation is not put back into the
    /// buffer. The compaction counter is incremented on every call, including
    /// calls that find the buffer empty.
    ///
    /// # Errors
    /// This implementation always returns `Ok`.
    pub fn compact(&mut self) -> Result<UpdateStats> {
        let start = Instant::now();
        let (inserts, deletes) = self.buffer.drain();
        let mut stats = UpdateStats::default();

        for id in deletes {
            match self.index.delete(id) {
                Ok(_) => stats.deletes_applied += 1,
                Err(_) => stats.errors += 1,
            }
        }
        for (id, vector) in inserts {
            match self.index.insert(id, vector) {
                Ok(_) => stats.inserts_applied += 1,
                Err(_) => stats.errors += 1,
            }
        }

        // A u64 of microseconds covers far more than any realistic run time.
        stats.duration_us = start.elapsed().as_micros() as u64;
        self.total_compactions += 1;
        Ok(stats)
    }

    /// Returns up to `k` `(id, distance)` pairs for `query`, sorted by ascending
    /// distance.
    ///
    /// When merging is disabled or the buffer holds no inserts, only the main
    /// index is queried (for `k` hits) and ids with a pending delete are removed
    /// afterwards, so fewer than `k` hits may come back. Otherwise the buffer is
    /// queried for `k` hits and the main index for `2 * k`, pending deletes are
    /// removed from the main-index hits, and the two lists are merged. If an id
    /// occurs in both lists the buffered hit is kept, because buffered inserts
    /// are what `compact` later writes to the index.
    ///
    /// # Errors
    /// Returns `InvalidParameter` when the buffer's configured distance metric
    /// differs from the index's, and propagates errors from the index search.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
        let index_metric = self.index.distance_metric();
        if self.config.buffer.distance_metric != index_metric {
            return Err(crate::RetrieveError::InvalidParameter(format!(
                "buffer distance metric ({:?}) does not match index metric ({:?}); \
                 merged search results would have inconsistent rankings",
                self.config.buffer.distance_metric, index_metric
            )));
        }

        if !self.config.merge_search_results || self.buffer.insert_count() == 0 {
            let mut results = self.index.search(query, k)?;
            results.retain(|(id, _)| !self.buffer.is_deleted(*id));
            return Ok(results);
        }

        // Buffered hits go first so that the dedup below prefers them.
        let mut combined = self.buffer.search(query, k);
        // Over-fetch from the index to make up for hits dropped by the filters.
        let mut main_results = self.index.search(query, k.saturating_mul(2))?;
        main_results.retain(|(id, _)| !self.buffer.is_deleted(*id));
        combined.extend(main_results);

        // Remove duplicate ids BEFORE truncating; otherwise duplicates would use
        // up result slots and the caller would get fewer than `k` hits.
        let mut seen = std::collections::HashSet::with_capacity(combined.len());
        combined.retain(|(id, _)| seen.insert(*id));

        combined
            .sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        combined.truncate(k);
        Ok(combined)
    }
}
pub trait IndexOps {
fn insert(&mut self, id: u32, vector: Vec<f32>) -> Result<()>;
fn delete(&mut self, id: u32) -> Result<()>;
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>>;
fn distance_metric(&self) -> crate::distance::DistanceMetric {
crate::distance::DistanceMetric::Cosine
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::distance::DistanceMetric;
    use std::cmp::Ordering;
    use std::collections::HashMap;

    /// In-memory [`IndexOps`] implementation that scans every stored vector.
    struct MockIndex {
        vectors: HashMap<u32, Vec<f32>>,
    }

    impl MockIndex {
        fn new() -> Self {
            MockIndex {
                vectors: HashMap::new(),
            }
        }
    }

    /// Euclidean distance over the components the two slices have in common.
    fn euclidean(a: &[f32], b: &[f32]) -> f32 {
        a.iter()
            .zip(b.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt()
    }

    impl IndexOps for MockIndex {
        fn insert(&mut self, id: u32, vector: Vec<f32>) -> Result<()> {
            self.vectors.insert(id, vector);
            Ok(())
        }

        fn delete(&mut self, id: u32) -> Result<()> {
            self.vectors.remove(&id);
            Ok(())
        }

        fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
            let mut scored = Vec::with_capacity(self.vectors.len());
            for (&id, stored) in &self.vectors {
                scored.push((id, euclidean(query, stored)));
            }
            scored.sort_unstable_by(|lhs: &(u32, f32), rhs: &(u32, f32)| {
                lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal)
            });
            scored.truncate(k);
            Ok(scored)
        }

        fn distance_metric(&self) -> DistanceMetric {
            DistanceMetric::L2
        }
    }

    /// Default config with the buffer metric switched to L2 so it matches `MockIndex`.
    fn mock_config() -> StreamingConfig {
        let mut config = StreamingConfig::default();
        config.buffer.distance_metric = DistanceMetric::L2;
        config
    }

    /// Fresh coordinator over an empty `MockIndex`.
    fn l2_coordinator() -> StreamingCoordinator<MockIndex> {
        StreamingCoordinator::with_config(MockIndex::new(), mock_config())
    }

    #[test]
    fn test_streaming_insert_search() {
        let mut coordinator = l2_coordinator();
        coordinator.insert(0, vec![0.0, 0.0]).unwrap();
        coordinator.insert(1, vec![1.0, 0.0]).unwrap();

        let hits = coordinator.search(&[0.1, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_streaming_delete() {
        let mut coordinator = l2_coordinator();
        coordinator.insert(0, vec![0.0, 0.0]).unwrap();
        coordinator.insert(1, vec![1.0, 0.0]).unwrap();
        coordinator.delete(0).unwrap();

        let hits = coordinator.search(&[0.0, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 1);
    }

    #[test]
    fn test_compaction() {
        let mut coordinator = l2_coordinator();
        coordinator.insert(0, vec![0.0, 0.0]).unwrap();
        coordinator.insert(1, vec![1.0, 0.0]).unwrap();

        let outcome = coordinator.compact().unwrap();
        assert_eq!(outcome.inserts_applied, 2);
        assert_eq!(coordinator.buffer().insert_count(), 0);
    }
}