use crate::similarity::SimilarityResult;
use crate::{index::VectorIndex, VectorId};
use anyhow::{anyhow, Result};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use tokio::sync::{watch, Mutex};
use tokio::time::interval;
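/// A single mutation to apply to the vector index; `Batch` groups
/// operations so they can be submitted and applied together.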
#[derive(Debug, Clone)]
pub enum UpdateOperation {
Insert {
id: VectorId,
vector: Vec<f32>,
metadata: HashMap<String, String>,
},
Update {
id: VectorId,
vector: Vec<f32>,
metadata: Option<HashMap<String, String>>,
},
Delete { id: VectorId },
Batch { operations: Vec<UpdateOperation> },
}
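/// A timestamped group of operations tagged with a scheduling priority.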
#[derive(Debug, Clone)]
pub struct UpdateBatch {
pub operations: Vec<UpdateOperation>,
pub timestamp: Instant,
pub priority: UpdatePriority,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum UpdatePriority {
Low = 0,
Normal = 1,
High = 2,
Critical = 3,
}
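/// Tuning parameters for batching, queue capacity, background compaction,
/// and threshold-based index rebuilds.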
#[derive(Debug, Clone)]
pub struct RealTimeConfig {
pub max_batch_size: usize,
pub max_batch_wait: Duration,
pub buffer_capacity: usize,
pub background_compaction: bool,
pub compaction_interval: Duration,
pub enable_rebuilding: bool,
pub rebuild_threshold: f64,
}
impl Default for RealTimeConfig {
fn default() -> Self {
Self {
max_batch_size: 1000,
max_batch_wait: Duration::from_millis(100),
buffer_capacity: 10000,
background_compaction: true,
            compaction_interval: Duration::from_secs(300),
            enable_rebuilding: true,
            rebuild_threshold: 0.3,
        }
}
}
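/// Applies queued update operations to a shared vector index, batching them
/// in background tasks and optionally compacting the index on a timer.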
pub struct RealTimeVectorUpdater {
config: RealTimeConfig,
update_queue: Arc<Mutex<VecDeque<UpdateOperation>>>,
batch_processor: Arc<Mutex<BatchProcessor>>,
index: Arc<RwLock<dyn VectorIndex + Send + Sync>>,
stats: Arc<RwLock<UpdateStats>>,
shutdown: watch::Sender<bool>,
tasks: Vec<tokio::task::JoinHandle<()>>,
}
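/// Running counters describing update throughput, batch sizes, and
/// processing latency.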
#[derive(Debug, Clone, Default)]
pub struct UpdateStats {
pub total_updates: u64,
pub total_inserts: u64,
pub total_deletes: u64,
pub total_batches: u64,
pub failed_updates: u64,
pub average_batch_size: f64,
pub average_processing_time: Duration,
pub last_compaction: Option<Instant>,
pub index_size: usize,
pub pending_updates: usize,
}
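/// Accumulates operations for the next batch and tracks how much the index
/// has churned since the last rebuild.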
pub struct BatchProcessor {
pending_batch: Vec<UpdateOperation>,
batch_start_time: Option<Instant>,
total_updates_since_rebuild: usize,
last_rebuild: Option<Instant>,
}
impl RealTimeVectorUpdater {
pub fn new(
index: Arc<RwLock<dyn VectorIndex + Send + Sync>>,
config: RealTimeConfig,
) -> Result<Self> {
let (shutdown_tx, _shutdown_rx) = watch::channel(false);
        let updater = Self {
            config,
            update_queue: Arc::new(Mutex::new(VecDeque::new())),
            batch_processor: Arc::new(Mutex::new(BatchProcessor::new())),
            index,
            stats: Arc::new(RwLock::new(UpdateStats::default())),
            shutdown: shutdown_tx,
            tasks: Vec::new(),
        };
Ok(updater)
}
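    /// Spawns the background batch-processing task and, if enabled, the
    /// periodic compaction task.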
pub async fn start(&mut self) -> Result<()> {
let shutdown_rx = self.shutdown.subscribe();
let batch_task = self
.start_batch_processing_task(shutdown_rx.clone())
.await?;
self.tasks.push(batch_task);
if self.config.background_compaction {
let compaction_task = self.start_compaction_task(shutdown_rx.clone()).await?;
self.tasks.push(compaction_task);
}
Ok(())
}
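    /// Signals the background tasks to stop, waits for them to finish, and
    /// flushes any updates still in the queue.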
pub async fn stop(&mut self) -> Result<()> {
self.shutdown
.send(true)
.map_err(|_| anyhow!("Failed to send shutdown signal"))?;
for task in self.tasks.drain(..) {
task.await.map_err(|e| anyhow!("Task join error: {}", e))?;
}
self.flush_pending_updates().await?;
Ok(())
}
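    /// Enqueues a single operation, failing if the queue has reached
    /// `buffer_capacity`.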
pub async fn submit_update(&self, operation: UpdateOperation) -> Result<()> {
let mut queue = self.update_queue.lock().await;
if queue.len() >= self.config.buffer_capacity {
return Err(anyhow!("Update queue is full"));
}
queue.push_back(operation);
Ok(())
}
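    /// Enqueues several operations as one `Batch` so they are applied
    /// together.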
pub async fn submit_batch(&self, operations: Vec<UpdateOperation>) -> Result<()> {
let batch_op = UpdateOperation::Batch { operations };
self.submit_update(batch_op).await
}
pub fn get_stats(&self) -> UpdateStats {
self.stats
.read()
.expect("rwlock should not be poisoned")
.clone()
}
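    /// Records that a compaction pass ran; actual compaction of the
    /// underlying index is not performed here yet.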
pub async fn compact_index(&self) -> Result<()> {
let _index = self.index.read().expect("rwlock should not be poisoned");
let mut stats = self.stats.write().expect("rwlock should not be poisoned");
stats.last_compaction = Some(Instant::now());
Ok(())
}
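    /// Rebuilds the index when the number of updates since the last rebuild,
    /// relative to the current index size, exceeds `rebuild_threshold`.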
pub async fn rebuild_index_if_needed(&self) -> Result<bool> {
let index_size = {
let stats = self.stats.read().expect("rwlock should not be poisoned");
stats.index_size
};
if index_size == 0 {
return Ok(false);
}
let processor = self.batch_processor.lock().await;
let update_ratio = processor.total_updates_since_rebuild as f64 / index_size as f64;
if update_ratio >= self.config.rebuild_threshold {
drop(processor);
self.rebuild_index().await?;
Ok(true)
} else {
Ok(false)
}
}
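    /// Resets the churn counters that drive `rebuild_index_if_needed`; a
    /// full index rebuild is not performed here yet.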
pub async fn rebuild_index(&self) -> Result<()> {
let mut processor = self.batch_processor.lock().await;
processor.total_updates_since_rebuild = 0;
processor.last_rebuild = Some(Instant::now());
Ok(())
}
    /// Drains every queued operation into the pending batch and applies it
    /// immediately, bypassing the batching timer.
    pub async fn flush_pending_updates(&self) -> Result<()> {
        let mut queue = self.update_queue.lock().await;
        let mut processor = self.batch_processor.lock().await;
        while let Some(operation) = queue.pop_front() {
            processor.pending_batch.push(operation);
        }
        // Release the queue lock before applying the batch so new submissions
        // are not blocked while the index is being written.
        drop(queue);
        if !processor.pending_batch.is_empty() {
            self.process_batch(&mut processor).await?;
        }
        Ok(())
    }
async fn start_batch_processing_task(
&self,
mut shutdown_rx: watch::Receiver<bool>,
) -> Result<tokio::task::JoinHandle<()>> {
let queue = self.update_queue.clone();
let processor = self.batch_processor.clone();
let index = self.index.clone();
let stats = self.stats.clone();
let config = self.config.clone();
let task = tokio::spawn(async move {
let mut interval = interval(config.max_batch_wait);
loop {
tokio::select! {
_ = interval.tick() => {
if let Err(e) = Self::process_pending_batch(
&queue, &processor, &index, &stats, &config
).await {
eprintln!("Batch processing error: {e}");
}
}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
}
}
}
});
Ok(task)
}
async fn start_compaction_task(
&self,
mut shutdown_rx: watch::Receiver<bool>,
) -> Result<tokio::task::JoinHandle<()>> {
let index = self.index.clone();
let stats = self.stats.clone();
let config = self.config.clone();
let task = tokio::spawn(async move {
let mut interval = interval(config.compaction_interval);
loop {
tokio::select! {
_ = interval.tick() => {
if let Err(e) = Self::perform_compaction(&index, &stats).await {
eprintln!("Compaction error: {e}");
}
}
_ = shutdown_rx.changed() => {
if *shutdown_rx.borrow() {
break;
}
}
}
}
});
Ok(task)
}
async fn process_pending_batch(
queue: &Arc<Mutex<VecDeque<UpdateOperation>>>,
processor: &Arc<Mutex<BatchProcessor>>,
index: &Arc<RwLock<dyn VectorIndex + Send + Sync>>,
stats: &Arc<RwLock<UpdateStats>>,
config: &RealTimeConfig,
) -> Result<()> {
let operations = {
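            // Top up the pending batch from the queue until it reaches
            // max_batch_size, then take it if anything is waiting.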
let mut queue_guard = queue.lock().await;
let mut processor_guard = processor.lock().await;
let mut batch_size = processor_guard.pending_batch.len();
while batch_size < config.max_batch_size && !queue_guard.is_empty() {
if let Some(operation) = queue_guard.pop_front() {
processor_guard.pending_batch.push(operation);
batch_size += 1;
}
}
if !processor_guard.pending_batch.is_empty() {
std::mem::take(&mut processor_guard.pending_batch)
} else {
return Ok(());
}
};
let start_time = Instant::now();
        let (successful_ops, failed_ops) = {
            let mut index_ref = index
                .write()
                .map_err(|_| anyhow!("Failed to acquire index lock"))?;
            let mut successful = 0u64;
            let mut failed = 0u64;
            for operation in &operations {
                // Count leaf operations so nested batches are tallied the
                // same way as in `process_batch`.
                match Self::apply_operation(&mut *index_ref, operation) {
                    Ok(_) => successful += Self::count_operations(operation),
                    Err(_) => failed += Self::count_operations(operation),
                }
            }
            (successful, failed)
        };
let processing_time = start_time.elapsed();
        {
            // Fold this batch into the running averages.
            let mut stats_ref = stats.write().expect("rwlock should not be poisoned");
            stats_ref.total_batches += 1;
            stats_ref.total_updates += successful_ops;
            stats_ref.failed_updates += failed_ops;
            stats_ref.average_batch_size = (stats_ref.average_batch_size
                * (stats_ref.total_batches - 1) as f64
                + operations.len() as f64)
                / stats_ref.total_batches as f64;
            let total_time = stats_ref.average_processing_time.as_nanos() as f64
                * (stats_ref.total_batches - 1) as f64
                + processing_time.as_nanos() as f64;
            stats_ref.average_processing_time =
                Duration::from_nanos((total_time / stats_ref.total_batches as f64) as u64);
        }
{
let mut processor_guard = processor.lock().await;
processor_guard.total_updates_since_rebuild += successful_ops as usize;
}
Ok(())
}
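    /// Counts the leaf operations represented by `operation`, recursing into
    /// nested batches.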
fn count_operations(operation: &UpdateOperation) -> u64 {
match operation {
UpdateOperation::Insert { .. }
| UpdateOperation::Update { .. }
| UpdateOperation::Delete { .. } => 1,
UpdateOperation::Batch { operations } => {
operations.iter().map(Self::count_operations).sum()
}
}
}
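    /// Applies one operation to the index, recursing into nested batches.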
fn apply_operation(index: &mut dyn VectorIndex, operation: &UpdateOperation) -> Result<()> {
match operation {
UpdateOperation::Insert {
id,
vector,
metadata,
} => {
let vector_obj = crate::Vector::new(vector.clone());
index.add_vector(id.clone(), vector_obj, Some(metadata.clone()))?;
}
UpdateOperation::Update {
id,
vector,
metadata,
} => {
let vector_obj = crate::Vector::new(vector.clone());
index.update_vector(id.clone(), vector_obj)?;
if let Some(meta) = metadata {
index.update_metadata(id.clone(), meta.clone())?;
}
}
UpdateOperation::Delete { id } => {
index.remove_vector(id.clone())?;
}
UpdateOperation::Batch { operations } => {
for op in operations {
Self::apply_operation(index, op)?;
}
}
}
Ok(())
}
async fn perform_compaction(
index: &Arc<RwLock<dyn VectorIndex + Send + Sync>>,
stats: &Arc<RwLock<UpdateStats>>,
) -> Result<()> {
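        // Placeholder: real compaction work would go here. Taking the read
        // lock briefly serializes this pass with any in-flight batch before
        // the timestamp is recorded.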
let index_guard = index.read().expect("rwlock should not be poisoned");
drop(index_guard);
let mut stats_guard = stats.write().expect("rwlock should not be poisoned");
stats_guard.last_compaction = Some(Instant::now());
Ok(())
}
async fn process_batch(&self, processor: &mut BatchProcessor) -> Result<()> {
if processor.pending_batch.is_empty() {
return Ok(());
}
let start_time = Instant::now();
let operations = std::mem::take(&mut processor.pending_batch);
let mut index = self.index.write().expect("rwlock should not be poisoned");
let mut successful_ops = 0;
let mut failed_ops = 0;
for operation in &operations {
match Self::apply_operation(&mut *index, operation) {
Ok(_) => {
successful_ops += Self::count_operations(operation);
}
Err(_) => {
failed_ops += Self::count_operations(operation);
}
}
}
drop(index);
let processing_time = start_time.elapsed();
let mut stats = self.stats.write().expect("rwlock should not be poisoned");
stats.total_batches += 1;
stats.total_updates += successful_ops;
stats.failed_updates += failed_ops;
let total_time = stats.average_processing_time.as_nanos() as f64
* (stats.total_batches - 1) as f64
+ processing_time.as_nanos() as f64;
stats.average_processing_time =
Duration::from_nanos((total_time / stats.total_batches as f64) as u64);
processor.total_updates_since_rebuild += successful_ops as usize;
processor.batch_start_time = None;
Ok(())
}
}
impl BatchProcessor {
fn new() -> Self {
Self {
pending_batch: Vec::new(),
batch_start_time: None,
total_updates_since_rebuild: 0,
last_rebuild: None,
}
}
}
type SearchCache = Arc<RwLock<HashMap<String, (Vec<SimilarityResult>, Instant)>>>;
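/// A search front-end over the updater's index that caches results per query
/// hash for a short TTL.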
pub struct RealTimeVectorSearch {
updater: Arc<RealTimeVectorUpdater>,
search_cache: SearchCache,
cache_ttl: Duration,
}
impl RealTimeVectorSearch {
pub fn new(updater: Arc<RealTimeVectorUpdater>) -> Self {
Self {
updater,
search_cache: Arc::new(RwLock::new(HashMap::new())),
            cache_ttl: Duration::from_secs(60),
        }
}
pub async fn similarity_search(
&self,
query_vector: &[f32],
k: usize,
) -> Result<Vec<SimilarityResult>> {
let query_hash = self.compute_query_hash(query_vector, k);
if let Some(cached_results) = self.get_cached_results(&query_hash) {
return Ok(cached_results);
}
let index = self
.updater
.index
.read()
.expect("rwlock should not be poisoned");
let query_vec = crate::Vector::new(query_vector.to_vec());
let search_results = index.search_knn(&query_vec, k)?;
drop(index);
let results: Vec<crate::similarity::SimilarityResult> = search_results
.into_iter()
.enumerate()
.map(
|(idx, (uri, similarity))| crate::similarity::SimilarityResult {
id: format!(
"rt_{}_{}",
idx,
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_millis()
),
uri,
similarity,
metrics: std::collections::HashMap::new(),
metadata: None,
},
)
.collect();
self.cache_results(query_hash, &results);
Ok(results)
}
pub fn invalidate_cache(&self) {
let mut cache = self
.search_cache
.write()
.expect("rwlock should not be poisoned");
cache.clear();
}
fn get_cached_results(&self, query_hash: &str) -> Option<Vec<SimilarityResult>> {
let cache = self
.search_cache
.read()
.expect("rwlock should not be poisoned");
cache.get(query_hash).and_then(|(results, timestamp)| {
if timestamp.elapsed() < self.cache_ttl {
Some(results.clone())
} else {
None
}
})
}
fn cache_results(&self, query_hash: String, results: &[SimilarityResult]) {
let mut cache = self
.search_cache
.write()
.expect("rwlock should not be poisoned");
cache.insert(query_hash, (results.to_vec(), Instant::now()));
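        // Evict expired entries while we already hold the write lock.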
cache.retain(|_, (_, timestamp)| timestamp.elapsed() < self.cache_ttl);
}
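    /// Folds `k` and the query's raw f32 bits into a 64-bit hash used as the
    /// cache key.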
fn compute_query_hash(&self, query_vector: &[f32], k: usize) -> String {
let mut hash = k as u64;
for &value in query_vector {
hash = hash.wrapping_mul(31).wrapping_add(value.to_bits() as u64);
}
hash.to_string()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::MemoryVectorIndex;
#[tokio::test]
async fn test_real_time_updater() -> Result<()> {
let index = Arc::new(RwLock::new(MemoryVectorIndex::new()));
let config = RealTimeConfig::default();
let updater = RealTimeVectorUpdater::new(index, config)?;
let operation = UpdateOperation::Insert {
id: "1".to_string(),
vector: vec![1.0, 2.0, 3.0],
metadata: HashMap::new(),
};
updater.submit_update(operation).await?;
updater.flush_pending_updates().await?;
let stats = updater.get_stats();
assert!(stats.total_updates > 0);
Ok(())
}
#[tokio::test]
async fn test_batch_operations() -> Result<()> {
let index = Arc::new(RwLock::new(MemoryVectorIndex::new()));
let config = RealTimeConfig::default();
let updater = RealTimeVectorUpdater::new(index, config)?;
let operations = vec![
UpdateOperation::Insert {
id: "1".to_string(),
vector: vec![1.0, 0.0],
metadata: HashMap::new(),
},
UpdateOperation::Insert {
id: "2".to_string(),
vector: vec![0.0, 1.0],
metadata: HashMap::new(),
},
];
updater.submit_batch(operations).await?;
updater.flush_pending_updates().await?;
let stats = updater.get_stats();
assert_eq!(stats.total_updates, 2);
Ok(())
}
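    // A sketch of the backpressure path: with `buffer_capacity` set to 1, a
    // second submission must be rejected until the queue drains.
    #[tokio::test]
    async fn test_queue_capacity_limit() -> Result<()> {
        let index = Arc::new(RwLock::new(MemoryVectorIndex::new()));
        let config = RealTimeConfig {
            buffer_capacity: 1,
            ..RealTimeConfig::default()
        };
        let updater = RealTimeVectorUpdater::new(index, config)?;
        updater
            .submit_update(UpdateOperation::Delete { id: "1".to_string() })
            .await?;
        assert!(updater
            .submit_update(UpdateOperation::Delete { id: "2".to_string() })
            .await
            .is_err());
        Ok(())
    }
    // A sketch of the cached read path. This assumes `MemoryVectorIndex`
    // answers `search_knn` for a query of matching dimension; the second,
    // identical query should be served from the cache, so the synthesized
    // result ids (which embed a timestamp) come back unchanged.
    #[tokio::test]
    async fn test_cached_similarity_search() -> Result<()> {
        let index = Arc::new(RwLock::new(MemoryVectorIndex::new()));
        let updater = Arc::new(RealTimeVectorUpdater::new(
            index,
            RealTimeConfig::default(),
        )?);
        updater
            .submit_update(UpdateOperation::Insert {
                id: "1".to_string(),
                vector: vec![1.0, 0.0],
                metadata: HashMap::new(),
            })
            .await?;
        updater.flush_pending_updates().await?;
        let search = RealTimeVectorSearch::new(updater);
        let first = search.similarity_search(&[1.0, 0.0], 1).await?;
        let second = search.similarity_search(&[1.0, 0.0], 1).await?;
        assert_eq!(
            first.iter().map(|r| r.id.as_str()).collect::<Vec<_>>(),
            second.iter().map(|r| r.id.as_str()).collect::<Vec<_>>()
        );
        Ok(())
    }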
}