use crate::error::ShardexError;
use crate::layout::DirectoryLayout;
use crate::transactions::{BatchConfig, WalBatchManager, WalOperation, WalTransaction};
use crate::wal::WalManager;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio::time::interval;
use tracing::{debug, error, warn};
/// Buffers WAL operations and flushes them to write-ahead-log segments from a
/// background tokio task.
///
/// Operations added before `start()` are buffered locally in
/// `pending_operations`; after `start()` they are forwarded over a command
/// channel to the spawned task, which owns the `WalBatchManager` and
/// `WalManager` and performs timer-, size-, and command-triggered flushes.
pub struct BatchProcessor {
// How often the background timer attempts a batch flush.
batch_interval: Duration,
// Operations accepted before the background task was started; drained
// into the task by `start()`.
pending_operations: Vec<WalOperation>,
// Handle to the spawned background task; `None` until `start()`.
timer_handle: Option<JoinHandle<()>>,
// `true` while stopped or shutting down; shared with the background task.
shutdown_signal: Arc<AtomicBool>,
// Batching configuration handed to the task's `WalBatchManager`.
batch_config: BatchConfig,
// Passed to `WalBatchManager::new`; presumably the vector dimension used
// to validate added postings — confirm against WalBatchManager.
expected_vector_dimension: Option<usize>,
// Command channel into the background task; `None` until `start()`.
command_sender: Option<mpsc::Sender<BatchProcessorCommand>>,
// Directory layout used to construct the task-owned `WalManager`.
layout: DirectoryLayout,
// Segment size forwarded to `WalManager::new`.
wal_segment_size: usize,
}
/// Commands sent from `BatchProcessor` methods to the background task.
#[derive(Debug)]
enum BatchProcessorCommand {
// Add a WAL operation to the current batch (may trigger a size-based flush).
AddOperation(WalOperation),
// Flush the current batch immediately and report the result on the channel.
FlushNow(oneshot::Sender<Result<(), ShardexError>>),
// Flush any remaining batch, then stop the background task.
Shutdown,
}
impl BatchProcessor {
pub fn new(
batch_interval: Duration,
batch_config: BatchConfig,
expected_vector_dimension: Option<usize>,
layout: DirectoryLayout,
wal_segment_size: usize,
) -> Self {
let shutdown_signal = Arc::new(AtomicBool::new(true));
Self {
batch_interval,
pending_operations: Vec::new(),
timer_handle: None,
shutdown_signal,
batch_config,
expected_vector_dimension,
command_sender: None,
layout,
wal_segment_size,
}
}
pub async fn start(&mut self) -> Result<(), ShardexError> {
if self.timer_handle.is_some() {
return Err(ShardexError::Wal("Batch processor already started".to_string()));
}
let (command_sender, mut command_receiver) = mpsc::channel::<BatchProcessorCommand>(1000);
self.command_sender = Some(command_sender);
let mut batch_manager = WalBatchManager::new(self.batch_config.clone(), self.expected_vector_dimension);
let mut wal_manager = WalManager::new(self.layout.clone(), self.wal_segment_size);
wal_manager.initialize()?;
let pending_ops = std::mem::take(&mut self.pending_operations);
let batch_interval = self.batch_interval;
let shutdown_signal = self.shutdown_signal.clone();
let handle = tokio::spawn(async move {
let mut timer = interval(batch_interval);
for operation in pending_ops {
match batch_manager.add_operation(operation) {
Ok(should_flush) => {
if should_flush {
if let Err(e) = Self::flush_batch(&mut batch_manager, &mut wal_manager).await {
error!("Failed to flush batch from pending operations: {}", e);
}
}
}
Err(e) => {
error!("Failed to add pending operation to batch: {}", e);
}
}
}
loop {
tokio::select! {
_ = timer.tick() => {
if shutdown_signal.load(Ordering::SeqCst) {
debug!("Timer task shutting down due to shutdown signal");
break;
}
if !batch_manager.batch_stats().is_empty {
debug!("Timer triggered batch flush");
if let Err(e) = Self::flush_batch(&mut batch_manager, &mut wal_manager).await {
error!("Timer-triggered batch flush failed: {}", e);
}
}
}
command = command_receiver.recv() => {
match command {
Some(BatchProcessorCommand::Shutdown) => {
debug!("Timer task received shutdown command");
if !batch_manager.batch_stats().is_empty {
if let Err(e) = Self::flush_batch(&mut batch_manager, &mut wal_manager).await {
error!("Shutdown batch flush failed: {}", e);
}
}
break;
}
Some(BatchProcessorCommand::FlushNow(response_tx)) => {
debug!("Timer task received flush command");
let result = if batch_manager.batch_stats().is_empty {
Ok(()) } else {
Self::flush_batch(&mut batch_manager, &mut wal_manager).await
};
let _ = response_tx.send(result);
}
Some(BatchProcessorCommand::AddOperation(operation)) => {
debug!("Timer task received add operation command");
match batch_manager.add_operation(operation) {
Ok(should_flush) => {
if should_flush {
if let Err(e) = Self::flush_batch(&mut batch_manager, &mut wal_manager).await {
error!("Batch size-triggered flush failed: {}", e);
}
}
}
Err(e) => {
error!("Failed to add operation to batch: {}", e);
}
}
}
None => {
debug!("Command channel closed, shutting down timer task");
break;
}
}
}
}
}
debug!("Batch processor background task completed");
});
self.timer_handle = Some(handle);
self.shutdown_signal.store(false, Ordering::SeqCst);
Ok(())
}
pub async fn flush_now(&mut self) -> Result<(), ShardexError> {
if let Some(ref command_sender) = self.command_sender {
let (response_tx, response_rx) = oneshot::channel();
match command_sender
.send(BatchProcessorCommand::FlushNow(response_tx))
.await
{
Ok(_) => {
let _ = response_rx
.await
.map_err(|_| ShardexError::Wal("Failed to receive flush response".to_string()))?;
Ok(())
}
Err(_) => {
debug!("Batch processor channel closed during flush_now - assuming shutdown");
Ok(())
}
}
} else {
Ok(()) }
}
pub async fn shutdown(&mut self) -> Result<(), ShardexError> {
self.shutdown_signal.store(true, Ordering::SeqCst);
if let Some(ref command_sender) = self.command_sender {
let _ = command_sender.send(BatchProcessorCommand::Shutdown).await;
}
if let Some(handle) = self.timer_handle.take() {
match handle.await {
Ok(_) => debug!("Background task completed successfully"),
Err(e) => warn!("Background task completed with error: {}", e),
}
}
self.command_sender = None;
Ok(())
}
pub async fn add_operation(&mut self, _operation: WalOperation) -> Result<(), ShardexError> {
if let Some(ref command_sender) = self.command_sender {
match command_sender
.send(BatchProcessorCommand::AddOperation(_operation.clone()))
.await
{
Ok(()) => Ok(()),
Err(_) => {
warn!("Batch processor channel closed, restarting background task");
self.restart().await?;
if let Some(ref command_sender) = self.command_sender {
command_sender
.send(BatchProcessorCommand::AddOperation(_operation))
.await
.map_err(|_| {
ShardexError::Wal("Failed to send add operation command after restart".to_string())
})?;
}
Ok(())
}
}
} else {
self.pending_operations.push(_operation);
Ok(())
}
}
async fn restart(&mut self) -> Result<(), ShardexError> {
warn!("Restarting batch processor due to background task failure");
self.command_sender = None;
if let Some(handle) = self.timer_handle.take() {
handle.abort();
}
self.start().await
}
#[cfg(test)]
pub fn is_running(&self) -> bool {
self.timer_handle.is_some() && !self.shutdown_signal.load(Ordering::SeqCst)
}
#[cfg(test)]
pub fn batch_interval(&self) -> Duration {
self.batch_interval
}
pub fn pending_operation_count(&self) -> usize {
self.pending_operations.len()
}
async fn flush_batch(
batch_manager: &mut WalBatchManager,
wal_manager: &mut WalManager,
) -> Result<(), ShardexError> {
let current_segment = wal_manager.current_segment()?;
let write_result = batch_manager
.flush_batch(|transaction: &WalTransaction| {
let serialized_data = transaction.serialize()?;
current_segment.append(&serialized_data)?;
debug!(
"Successfully wrote transaction {} with {} operations to WAL segment",
transaction.id,
transaction.operations.len()
);
Ok(())
})
.await;
match write_result {
Ok(Some(transaction_id)) => {
debug!("Batch flush completed, transaction ID: {}", transaction_id);
Ok(())
}
Ok(None) => {
debug!("Batch flush completed, no operations to flush");
Ok(())
}
Err(e) => {
error!("Batch flush failed: {}", e);
Err(e)
}
}
}
}
/// Unit and integration tests for `BatchProcessor`: construction defaults,
/// double-start rejection, WAL integration, timer-driven flushing, and
/// graceful shutdown with queued operations.
#[cfg(test)]
mod tests {
use super::*;
use crate::identifiers::DocumentId;
use crate::layout::DirectoryLayout;
use crate::test_utils::TestEnvironment;
use crate::transactions::BatchConfig;
use crate::wal::WalManager;
// A freshly constructed processor is stopped and holds no pending operations.
#[tokio::test]
async fn test_batch_processor_creation() {
let _test_env = TestEnvironment::new("test_batch_processor_creation");
let layout = DirectoryLayout::new(_test_env.path());
layout.create_directories().unwrap();
let batch_interval = Duration::from_millis(100);
let batch_config = BatchConfig::default();
let processor = BatchProcessor::new(batch_interval, batch_config, Some(128), layout, 8192);
assert_eq!(processor.batch_interval(), batch_interval);
assert!(!processor.is_running());
assert_eq!(processor.pending_operation_count(), 0);
}
// start() succeeds once; a second start() while running must return an error.
#[tokio::test]
async fn test_batch_processor_start() {
let _test_env = TestEnvironment::new("test_batch_processor_start");
let layout = DirectoryLayout::new(_test_env.path());
layout.create_directories().unwrap();
let batch_interval = Duration::from_millis(50);
let batch_config = BatchConfig::default();
let mut processor = BatchProcessor::new(batch_interval, batch_config, Some(128), layout, 8192);
let result = processor.start().await;
assert!(result.is_ok());
let result2 = processor.start().await;
assert!(result2.is_err());
let _ = processor.shutdown().await;
}
// End-to-end: add operations, force a flush, and shut down cleanly.
#[tokio::test]
async fn test_batch_processor_with_wal_integration() {
let _test_env = TestEnvironment::new("test_batch_processor_with_wal_integration");
let layout = DirectoryLayout::new(_test_env.path());
layout.create_directories().unwrap();
// NOTE(review): this WalManager only pre-initializes the WAL; the processor
// constructs its own internally — confirm this setup step is required.
let mut wal_manager = WalManager::new(layout.clone(), 8192); wal_manager.initialize().unwrap();
let batch_config = BatchConfig {
batch_write_interval_ms: 50,
max_operations_per_batch: 5,
max_batch_size_bytes: 1024,
max_document_text_size: 10 * 1024 * 1024,
};
let batch_interval = Duration::from_millis(50);
let mut processor = BatchProcessor::new(batch_interval, batch_config, Some(3), layout.clone(), 8192);
processor.start().await.unwrap();
assert!(processor.is_running());
let doc_id = DocumentId::new();
let operations = vec![
WalOperation::AddPosting {
document_id: doc_id,
start: 0,
length: 100,
vector: vec![1.0, 2.0, 3.0],
},
WalOperation::RemoveDocument { document_id: doc_id },
];
for operation in operations {
processor.add_operation(operation).await.unwrap();
}
processor.flush_now().await.unwrap();
// Give the background task time to finish any in-flight work.
tokio::time::sleep(Duration::from_millis(100)).await;
processor.shutdown().await.unwrap();
assert!(!processor.is_running());
}
// The interval timer alone (no explicit flush_now) should flush the batch.
#[tokio::test]
async fn test_batch_processor_timer_based_flushing() {
let _test_env = TestEnvironment::new("test_batch_processor_timer_based_flushing");
let layout = DirectoryLayout::new(_test_env.path());
layout.create_directories().unwrap();
// High count/size limits so only the 20ms timer can trigger the flush.
let batch_config = BatchConfig {
batch_write_interval_ms: 20, max_operations_per_batch: 100, max_batch_size_bytes: 10000,
max_document_text_size: 10 * 1024 * 1024,
};
let batch_interval = Duration::from_millis(20);
let mut processor = BatchProcessor::new(batch_interval, batch_config, Some(3), layout, 8192);
processor.start().await.unwrap();
let doc_id = DocumentId::new();
let operation = WalOperation::AddPosting {
document_id: doc_id,
start: 0,
length: 100,
vector: vec![1.0, 2.0, 3.0],
};
processor.add_operation(operation).await.unwrap();
// Sleep past at least one timer tick so the flush can run.
tokio::time::sleep(Duration::from_millis(50)).await;
processor.shutdown().await.unwrap();
}
// Shutdown before the (long) timer fires: queued operations are flushed on exit.
#[tokio::test]
async fn test_batch_processor_graceful_shutdown_with_pending_operations() {
let _test_env = TestEnvironment::new("test_batch_processor_graceful_shutdown_with_pending_operations");
let layout = DirectoryLayout::new(_test_env.path());
layout.create_directories().unwrap();
// 1s interval guarantees the timer cannot fire before shutdown() below.
let batch_config = BatchConfig {
batch_write_interval_ms: 1000, max_operations_per_batch: 100,
max_batch_size_bytes: 10000,
max_document_text_size: 10 * 1024 * 1024,
};
let batch_interval = Duration::from_millis(1000);
let mut processor = BatchProcessor::new(batch_interval, batch_config, Some(3), layout, 8192);
processor.start().await.unwrap();
let doc_id = DocumentId::new();
for i in 0..3 {
let operation = WalOperation::AddPosting {
document_id: doc_id,
start: i * 100,
length: 100,
vector: vec![1.0 + i as f32, 2.0 + i as f32, 3.0 + i as f32],
};
processor.add_operation(operation).await.unwrap();
}
processor.shutdown().await.unwrap();
assert!(!processor.is_running());
}
// Minimal start → add → flush → shutdown lifecycle smoke test.
#[tokio::test]
async fn test_batch_processor_basic_lifecycle() {
let _test_env = TestEnvironment::new("test_batch_processor_basic_lifecycle");
let layout = DirectoryLayout::new(_test_env.path());
layout.create_directories().unwrap();
let batch_interval = Duration::from_millis(50);
let batch_config = BatchConfig::default();
let mut processor = BatchProcessor::new(batch_interval, batch_config, Some(128), layout, 8192);
processor.start().await.unwrap();
assert!(processor.is_running());
let doc_id = DocumentId::new();
let operation = WalOperation::AddPosting {
document_id: doc_id,
start: 0,
length: 100,
vector: vec![1.0, 2.0, 3.0],
};
processor.add_operation(operation).await.unwrap();
processor.flush_now().await.unwrap();
processor.shutdown().await.unwrap();
assert!(!processor.is_running());
}
}