use crate::error::{Result, RetrieveError};
use std::collections::{HashMap, HashSet};
/// Thresholds that decide when a [`StreamBuffer`] should be compacted.
///
/// Each limit is checked independently by [`StreamBuffer::needs_compaction`];
/// reaching (`>=`) any one of them requests compaction.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamBufferConfig {
    /// Number of buffered inserts at which compaction is requested.
    pub max_buffer_size: usize,
    /// Number of pending delete tombstones at which compaction is requested.
    pub max_pending_deletes: usize,
    /// Bytes of buffered vector data (`len * size_of::<f32>()` per vector) at
    /// which compaction is requested.
    pub max_memory_bytes: usize,
}

impl Default for StreamBufferConfig {
    /// Defaults: 10,000 inserts, 1,000 pending deletes, or 100 MiB of vector data.
    fn default() -> Self {
        Self {
            max_buffer_size: 10_000,
            max_pending_deletes: 1_000,
            max_memory_bytes: 100 * 1024 * 1024,
        }
    }
}
/// In-memory staging area for vector inserts and deletes.
///
/// Pending changes are handed off in bulk via `drain`, presumably to a main
/// index (confirm with callers); `needs_compaction` reports when any configured
/// limit has been reached.
#[derive(Debug)]
pub struct StreamBuffer {
    /// Buffered vectors keyed by id; inserting an existing id overwrites its vector.
    inserts: HashMap<u32, Vec<f32>>,
    /// Delete tombstones for ids that have no buffered insert (kept disjoint from
    /// `inserts` by `insert` and `delete`).
    deletes: HashSet<u32>,
    /// Limits consulted by `needs_compaction`.
    config: StreamBufferConfig,
    /// Vector length fixed by the first successful insert; `None` until then.
    /// Not reset by `drain`.
    dimension: Option<usize>,
    /// Bytes of vector payload held in `inserts` (`len * size_of::<f32>()` per
    /// vector); hash-map overhead is not counted.
    memory_bytes: usize,
}
impl StreamBuffer {
    /// Creates an empty buffer with [`StreamBufferConfig::default`] thresholds.
    pub fn new() -> Self {
        Self::with_config(StreamBufferConfig::default())
    }

    /// Creates an empty buffer whose [`needs_compaction`](Self::needs_compaction)
    /// checks use `config`.
    ///
    /// The vector dimension stays unset until the first successful insert.
    pub fn with_config(config: StreamBufferConfig) -> Self {
        Self {
            inserts: HashMap::new(),
            deletes: HashSet::new(),
            config,
            dimension: None,
            memory_bytes: 0,
        }
    }

    /// Buffers `vector` under `id`, replacing any vector already buffered for it.
    ///
    /// The first insert fixes the buffer's dimension, and [`drain`](Self::drain)
    /// does not reset it. Inserting an `id` also clears any pending delete for it.
    ///
    /// # Errors
    ///
    /// Returns [`RetrieveError::DimensionMismatch`] (with `query_dim` set to the
    /// established dimension and `doc_dim` to `vector.len()`) when the length
    /// differs from the established dimension. In that case the buffer is left
    /// unchanged.
    pub fn insert(&mut self, id: u32, vector: Vec<f32>) -> Result<()> {
        match self.dimension {
            None => self.dimension = Some(vector.len()),
            Some(dim) if dim != vector.len() => {
                return Err(RetrieveError::DimensionMismatch {
                    query_dim: dim,
                    doc_dim: vector.len(),
                });
            }
            Some(_) => {}
        }
        self.deletes.remove(&id);
        let added = Self::payload_bytes(&vector);
        if let Some(previous) = self.inserts.insert(id, vector) {
            // Overwrite: release the replaced vector's bytes before counting the new one.
            let released = Self::payload_bytes(&previous);
            self.memory_bytes = self.memory_bytes.saturating_sub(released);
        }
        self.memory_bytes += added;
        Ok(())
    }

    /// Marks `id` as deleted.
    ///
    /// If `id` has a buffered insert, that vector is dropped and no tombstone is
    /// recorded. Otherwise a tombstone is added, which [`drain`](Self::drain)
    /// later returns.
    ///
    /// NOTE(review): if a buffered insert overwrote a vector that already lives in
    /// the main index, dropping it here without a tombstone would leave the old
    /// copy alive after compaction. Confirm callers never upsert existing ids
    /// through this buffer.
    pub fn delete(&mut self, id: u32) {
        if let Some(vector) = self.inserts.remove(&id) {
            let released = Self::payload_bytes(&vector);
            self.memory_bytes = self.memory_bytes.saturating_sub(released);
        } else {
            self.deletes.insert(id);
        }
    }

    /// Returns `true` once any configured limit is reached (`>=`).
    ///
    /// The limits are the buffered insert count, the pending delete count, and
    /// the buffered vector bytes.
    pub fn needs_compaction(&self) -> bool {
        self.inserts.len() >= self.config.max_buffer_size
            || self.deletes.len() >= self.config.max_pending_deletes
            || self.memory_bytes >= self.config.max_memory_bytes
    }

    /// Removes and returns all buffered inserts and delete tombstones, and resets
    /// the byte counter.
    ///
    /// The dimension fixed by the first insert is kept, so later inserts must
    /// still match it.
    pub fn drain(&mut self) -> (HashMap<u32, Vec<f32>>, HashSet<u32>) {
        let inserts = std::mem::take(&mut self.inserts);
        let deletes = std::mem::take(&mut self.deletes);
        self.memory_bytes = 0;
        (inserts, deletes)
    }

    /// Number of buffered inserts.
    pub fn insert_count(&self) -> usize {
        self.inserts.len()
    }

    /// Number of pending delete tombstones.
    pub fn delete_count(&self) -> usize {
        self.deletes.len()
    }

    /// Returns `true` if `id` has a pending delete tombstone.
    ///
    /// An id whose buffered insert was removed by [`delete`](Self::delete) is not
    /// reported, because no tombstone is recorded for it.
    pub fn is_deleted(&self, id: u32) -> bool {
        self.deletes.contains(&id)
    }

    /// Returns the buffered vector for `id`, if any.
    pub fn get(&self, id: u32) -> Option<&Vec<f32>> {
        self.inserts.get(&id)
    }

    /// Iterates over buffered inserts as `(id, vector)` pairs, in arbitrary
    /// (hash-map) order.
    pub fn iter(&self) -> impl Iterator<Item = (u32, &Vec<f32>)> {
        self.inserts.iter().map(|(&id, vec)| (id, vec))
    }

    /// Returns up to `k` buffered vectors nearest to `query` by Euclidean distance.
    ///
    /// Results are `(id, distance)` pairs sorted by ascending distance. Equal
    /// distances are ordered by ascending id, so the output does not depend on
    /// hash-map iteration order. NaN distances (e.g. from non-finite components)
    /// sort last.
    ///
    /// `query` is not checked against the buffer's dimension: `euclidean_distance`
    /// compares only the overlapping prefix of the two slices.
    pub fn search(&self, query: &[f32], k: usize) -> Vec<(u32, f32)> {
        let mut hits: Vec<(u32, f32)> = self
            .inserts
            .iter()
            .filter(|&(id, _)| !self.deletes.contains(id))
            .map(|(&id, vector)| (id, euclidean_distance(query, vector)))
            .collect();
        // The comparator must be a total order. `partial_cmp(..).unwrap_or(Equal)`
        // is not transitive once NaN appears, and std sorts may panic on such
        // comparators.
        hits.sort_unstable_by(Self::compare_hits);
        hits.truncate(k);
        hits
    }

    /// Byte size of a vector's payload, as tracked in `memory_bytes`.
    fn payload_bytes(vector: &[f32]) -> usize {
        std::mem::size_of_val(vector)
    }

    /// Total order on search hits: by distance with NaN last, then by id.
    fn compare_hits(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
        // `false < true`, so non-NaN distances come first.
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            // Only `None` when both are NaN; treat those as equal distances.
            .then_with(|| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .then_with(|| a.0.cmp(&b.0))
    }
}
impl Default for StreamBuffer {
fn default() -> Self {
Self::new()
}
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so only the first `min(a.len(), b.len())`
/// components are compared. Extra components of the longer slice are ignored
/// rather than rejected.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = x - y;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_delete() {
        let mut buffer = StreamBuffer::new();
        buffer.insert(0, vec![1.0, 2.0]).unwrap();
        buffer.insert(1, vec![3.0, 4.0]).unwrap();
        assert_eq!(buffer.insert_count(), 2);
        buffer.delete(0);
        assert_eq!(buffer.insert_count(), 1);
        // Deleting a buffered insert discards it without recording a tombstone.
        assert!(!buffer.is_deleted(0));
    }

    #[test]
    fn test_delete_from_main() {
        let mut buffer = StreamBuffer::new();
        buffer.delete(42);
        assert!(buffer.is_deleted(42));
        assert_eq!(buffer.delete_count(), 1);
    }

    #[test]
    fn test_search() {
        let mut buffer = StreamBuffer::new();
        buffer.insert(0, vec![0.0, 0.0]).unwrap();
        buffer.insert(1, vec![1.0, 0.0]).unwrap();
        buffer.insert(2, vec![0.0, 1.0]).unwrap();
        let query = vec![0.1, 0.1];
        let results = buffer.search(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_search_k_exceeds_len() {
        let mut buffer = StreamBuffer::new();
        buffer.insert(7, vec![1.0]).unwrap();
        buffer.insert(8, vec![2.0]).unwrap();
        assert_eq!(buffer.search(&[0.0], 10), vec![(7, 1.0), (8, 2.0)]);
        assert!(buffer.search(&[0.0], 0).is_empty());
    }

    #[test]
    fn test_dimension_mismatch_rejected() {
        let mut buffer = StreamBuffer::new();
        buffer.insert(0, vec![1.0, 2.0]).unwrap();
        let err = buffer.insert(1, vec![1.0, 2.0, 3.0]).unwrap_err();
        assert!(matches!(
            err,
            RetrieveError::DimensionMismatch {
                query_dim: 2,
                doc_dim: 3
            }
        ));
        // A rejected insert leaves the buffer untouched.
        assert_eq!(buffer.insert_count(), 1);
        assert!(buffer.get(1).is_none());
    }

    #[test]
    fn test_reinsert_clears_pending_delete() {
        let mut buffer = StreamBuffer::new();
        buffer.delete(5);
        assert!(buffer.is_deleted(5));
        buffer.insert(5, vec![1.0]).unwrap();
        assert!(!buffer.is_deleted(5));
        assert_eq!(buffer.delete_count(), 0);
        assert_eq!(buffer.get(5), Some(&vec![1.0]));
    }

    #[test]
    fn test_compaction_thresholds() {
        let config = StreamBufferConfig {
            max_buffer_size: 2,
            max_pending_deletes: 1,
            max_memory_bytes: usize::MAX,
        };

        let mut by_inserts = StreamBuffer::with_config(config.clone());
        by_inserts.insert(0, vec![0.0]).unwrap();
        assert!(!by_inserts.needs_compaction());
        by_inserts.insert(1, vec![0.0]).unwrap();
        assert!(by_inserts.needs_compaction());

        let mut by_deletes = StreamBuffer::with_config(config);
        assert!(!by_deletes.needs_compaction());
        by_deletes.delete(9);
        assert!(by_deletes.needs_compaction());
    }

    #[test]
    fn test_memory_accounting() {
        // Limit = two 2-d f32 vectors.
        let config = StreamBufferConfig {
            max_buffer_size: usize::MAX,
            max_pending_deletes: usize::MAX,
            max_memory_bytes: 4 * std::mem::size_of::<f32>(),
        };
        let mut buffer = StreamBuffer::with_config(config);
        buffer.insert(0, vec![1.0, 2.0]).unwrap();
        assert!(!buffer.needs_compaction());
        // Overwriting releases the old vector's bytes before counting the new one.
        buffer.insert(0, vec![3.0, 4.0]).unwrap();
        assert!(!buffer.needs_compaction());
        buffer.insert(1, vec![5.0, 6.0]).unwrap();
        assert!(buffer.needs_compaction());
        buffer.delete(1);
        assert!(!buffer.needs_compaction());
    }

    #[test]
    fn test_drain_empties_buffer() {
        let mut buffer = StreamBuffer::new();
        buffer.insert(0, vec![1.0, 2.0]).unwrap();
        buffer.delete(7);
        let (inserts, deletes) = buffer.drain();
        assert_eq!(inserts.len(), 1);
        assert_eq!(inserts.get(&0), Some(&vec![1.0, 2.0]));
        assert!(deletes.contains(&7));
        assert_eq!(buffer.insert_count(), 0);
        assert_eq!(buffer.delete_count(), 0);
        assert!(!buffer.needs_compaction());
        // The dimension fixed by the first insert survives a drain.
        assert!(buffer.insert(1, vec![1.0]).is_err());
    }
}