use crate::error::ShardexError;
use crate::identifiers::{DocumentId, TransactionId};
use bytemuck::{Pod, Zeroable};
use serde::{Deserialize, Serialize};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
use tokio::time::{interval, Interval};
/// A single logged operation in the write-ahead log.
///
/// Operations are grouped into a [`WalTransaction`], serialized with bincode,
/// checksummed, and replayed during recovery. Variant/field order matters for
/// the serialized representation — do not reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum WalOperation {
    /// Add a posting (a text span plus its embedding vector) for a document.
    AddPosting {
        document_id: DocumentId,
        // Byte offset of the span within the document text — TODO confirm
        // units (bytes vs. characters) against the indexing code.
        start: u32,
        // Span length; validated to be non-zero and non-overflowing.
        length: u32,
        // Embedding vector; validated to be non-empty and all-finite.
        vector: Vec<f32>,
    },
    /// Remove all postings for a document.
    RemoveDocument { document_id: DocumentId },
    /// Store (or replace) the full source text for a document.
    StoreDocumentText { document_id: DocumentId, text: String },
    /// Delete the stored source text for a document.
    DeleteDocumentText { document_id: DocumentId },
}
/// An atomic group of WAL operations, written and replayed as a unit.
///
/// The `checksum` is a CRC32 over the bincode-serialized `operations` and is
/// verified on deserialization and via [`WalTransaction::verify_checksum`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WalTransaction {
    // Unique transaction identifier.
    pub id: TransactionId,
    // Creation time; persisted as microseconds since the Unix epoch.
    pub timestamp: SystemTime,
    // The operations applied atomically by this transaction.
    pub operations: Vec<WalOperation>,
    // CRC32 of the serialized operations payload.
    pub checksum: u32,
}
/// Fixed-size on-disk header preceding a transaction's serialized operations.
///
/// `#[repr(C)]` with a field layout that packs without padding (16 + 8 + 4 +
/// 4 + 4 + 4 = 40 bytes), which is required for the `Pod` impl below. Field
/// order is part of the on-disk format — do not reorder.
#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(C)]
pub struct WalTransactionHeader {
    // 16-byte transaction id (see tests: constructed from [u8; 16]).
    pub id: TransactionId,
    // Microseconds since the Unix epoch.
    pub timestamp_micros: u64,
    // Number of operations in the payload that follows.
    pub operation_count: u32,
    // Byte length of the serialized operations payload.
    pub operations_data_size: u32,
    // CRC32 of the serialized operations payload.
    pub checksum: u32,
    // Padding/reserved bytes; always zero.
    pub reserved: [u8; 4],
}
// SAFETY: WalTransactionHeader is #[repr(C)], Copy, and its field sizes sum
// to the struct size (40 bytes, 8-byte aligned), so there is no padding.
// Assumes TransactionId is itself a plain 16-byte POD value — TODO confirm
// against the identifiers module.
unsafe impl Pod for WalTransactionHeader {}
// SAFETY: the all-zero bit pattern is a valid (empty/invalid-marker) header.
unsafe impl Zeroable for WalTransactionHeader {}
impl WalOperation {
    /// The document this operation applies to.
    pub fn document_id(&self) -> DocumentId {
        match self {
            Self::AddPosting { document_id, .. }
            | Self::RemoveDocument { document_id }
            | Self::StoreDocumentText { document_id, .. }
            | Self::DeleteDocumentText { document_id } => *document_id,
        }
    }

    /// True for the `AddPosting` variant.
    pub fn is_add_posting(&self) -> bool {
        matches!(self, Self::AddPosting { .. })
    }

    /// True for the `RemoveDocument` variant.
    pub fn is_remove_document(&self) -> bool {
        matches!(self, Self::RemoveDocument { .. })
    }

    /// True for the `StoreDocumentText` variant.
    pub fn is_store_document_text(&self) -> bool {
        matches!(self, Self::StoreDocumentText { .. })
    }

    /// True for the `DeleteDocumentText` variant.
    pub fn is_delete_document_text(&self) -> bool {
        matches!(self, Self::DeleteDocumentText { .. })
    }

    /// Rough serialized size in bytes, used for batch size accounting.
    /// This is an estimate of the bincode encoding, not an exact size.
    pub fn estimated_serialized_size(&self) -> usize {
        // Every variant carries a 1-byte tag and a 16-byte document id.
        const TAG_AND_ID: usize = 1 + 16;
        match self {
            // start (4) + length (4) + vec length prefix (4) + f32 payload.
            Self::AddPosting { vector, .. } => TAG_AND_ID + 4 + 4 + 4 + vector.len() * 4,
            // string length prefix (4) + UTF-8 bytes.
            Self::StoreDocumentText { text, .. } => TAG_AND_ID + 4 + text.len(),
            Self::RemoveDocument { .. } | Self::DeleteDocumentText { .. } => TAG_AND_ID,
        }
    }

    /// Validate an `AddPosting` payload: range arithmetic, non-empty finite
    /// vector, and (when requested) the expected dimensionality.
    fn validate_add_posting(
        start: u32,
        length: u32,
        vector: &[f32],
        expected_vector_dimension: Option<usize>,
    ) -> Result<(), ShardexError> {
        if start > u32::MAX - length {
            return Err(ShardexError::Wal(
                "AddPosting start + length would overflow u32".to_string(),
            ));
        }
        if length == 0 {
            return Err(ShardexError::Wal("AddPosting length cannot be zero".to_string()));
        }
        if vector.is_empty() {
            return Err(ShardexError::Wal("AddPosting vector cannot be empty".to_string()));
        }
        if let Some(expected) = expected_vector_dimension {
            if vector.len() != expected {
                return Err(ShardexError::InvalidDimension {
                    expected,
                    actual: vector.len(),
                });
            }
        }
        // Reject NaN / infinity so checksummed replay is deterministic.
        if let Some((i, value)) = vector.iter().enumerate().find(|(_, v)| !v.is_finite()) {
            return Err(ShardexError::Wal(format!(
                "Invalid vector value at index {}: {} (must be finite)",
                i, value
            )));
        }
        Ok(())
    }

    /// Validate this operation against the index configuration.
    ///
    /// `expected_vector_dimension` (when `Some`) is enforced on posting
    /// vectors; `max_document_text_size` caps stored document text.
    ///
    /// # Errors
    /// `ShardexError::Wal` for structural problems, `InvalidDimension` for a
    /// vector length mismatch, and `DocumentTooLarge` for oversized text.
    pub fn validate(
        &self,
        expected_vector_dimension: Option<usize>,
        max_document_text_size: usize,
    ) -> Result<(), ShardexError> {
        match self {
            Self::AddPosting {
                vector, start, length, ..
            } => Self::validate_add_posting(*start, *length, vector, expected_vector_dimension),
            Self::StoreDocumentText { text, .. } => {
                if text.is_empty() {
                    return Err(ShardexError::Wal("StoreDocumentText text cannot be empty".to_string()));
                }
                if text.len() > max_document_text_size {
                    return Err(ShardexError::DocumentTooLarge {
                        size: text.len(),
                        max_size: max_document_text_size,
                    });
                }
                Ok(())
            }
            // Removal operations carry only a document id; nothing to check.
            Self::RemoveDocument { .. } | Self::DeleteDocumentText { .. } => Ok(()),
        }
    }
}
impl WalTransaction {
    /// Create a transaction with a fresh id and the current wall-clock time.
    ///
    /// # Errors
    /// `ShardexError::Wal` if `operations` is empty or serialization fails.
    pub fn new(operations: Vec<WalOperation>) -> Result<Self, ShardexError> {
        // Delegate to the explicit constructor so the validation and
        // checksum logic exists in exactly one place.
        Self::with_id_and_timestamp(TransactionId::new(), SystemTime::now(), operations)
    }

    /// Create a transaction with an explicit id and timestamp (e.g. when
    /// reconstructing during recovery or in tests).
    ///
    /// # Errors
    /// `ShardexError::Wal` if `operations` is empty or serialization fails.
    pub fn with_id_and_timestamp(
        id: TransactionId,
        timestamp: SystemTime,
        operations: Vec<WalOperation>,
    ) -> Result<Self, ShardexError> {
        if operations.is_empty() {
            return Err(ShardexError::Wal("Transaction cannot have zero operations".to_string()));
        }
        // Checksum covers the bincode-serialized operations payload.
        let operations_data = Self::serialize_operations(&operations)?;
        let checksum = crc32fast::hash(&operations_data);
        Ok(Self {
            id,
            timestamp,
            operations,
            checksum,
        })
    }

    /// Number of operations in this transaction.
    pub fn operation_count(&self) -> usize {
        self.operations.len()
    }

    /// Sorted, deduplicated list of document ids touched by this transaction.
    pub fn affected_document_ids(&self) -> Vec<DocumentId> {
        let mut doc_ids: Vec<DocumentId> = self.operations.iter().map(|op| op.document_id()).collect();
        doc_ids.sort();
        doc_ids.dedup();
        doc_ids
    }

    /// Estimated serialized size (header plus per-operation estimates).
    /// Not exact — see [`WalOperation::estimated_serialized_size`].
    pub fn estimated_serialized_size(&self) -> usize {
        std::mem::size_of::<WalTransactionHeader>()
            + self
                .operations
                .iter()
                .map(|op| op.estimated_serialized_size())
                .sum::<usize>()
    }

    /// Validate every operation and sanity-check the timestamp.
    ///
    /// A timestamp more than 60 seconds in the future is rejected; a small
    /// tolerance allows for clock skew between writers.
    ///
    /// # Errors
    /// `ShardexError::Wal` wrapping the first invalid operation, or for a
    /// far-future timestamp.
    pub fn validate(
        &self,
        expected_vector_dimension: Option<usize>,
        max_document_text_size: usize,
    ) -> Result<(), ShardexError> {
        if self.operations.is_empty() {
            return Err(ShardexError::Wal(
                "Transaction must contain at least one operation".to_string(),
            ));
        }
        for (i, operation) in self.operations.iter().enumerate() {
            operation
                .validate(expected_vector_dimension, max_document_text_size)
                .map_err(|e| {
                    ShardexError::Wal(format!("Operation {} in transaction {} is invalid: {}", i, self.id, e))
                })?;
        }
        let now = SystemTime::now();
        const FUTURE_TOLERANCE_SECONDS: u64 = 60;
        // Pre-epoch timestamps fall through silently; only future drift is
        // treated as corruption here.
        if let Ok(duration_since_epoch) = self.timestamp.duration_since(UNIX_EPOCH) {
            if let Ok(now_duration) = now.duration_since(UNIX_EPOCH) {
                if duration_since_epoch.as_secs() > now_duration.as_secs() + FUTURE_TOLERANCE_SECONDS {
                    return Err(ShardexError::Wal(format!(
                        "Transaction timestamp is too far in the future: transaction time {:?}, current time {:?}",
                        self.timestamp, now
                    )));
                }
            }
        }
        Ok(())
    }

    /// Re-serialize the operations and compare the CRC32 against the stored
    /// checksum.
    ///
    /// # Errors
    /// `ShardexError::Wal` on serialization failure or checksum mismatch.
    pub fn verify_checksum(&self) -> Result<(), ShardexError> {
        let operations_data = Self::serialize_operations(&self.operations)?;
        let calculated_checksum = crc32fast::hash(&operations_data);
        if calculated_checksum != self.checksum {
            return Err(ShardexError::Wal(format!(
                "Transaction {} checksum mismatch: expected {}, calculated {}",
                self.id, self.checksum, calculated_checksum
            )));
        }
        Ok(())
    }

    /// Serialize operations with bincode (the payload format behind the
    /// header and the input to the checksum).
    fn serialize_operations(operations: &[WalOperation]) -> Result<Vec<u8>, ShardexError> {
        bincode::serialize(operations).map_err(|e| ShardexError::Wal(format!("Failed to serialize operations: {}", e)))
    }

    /// Build the fixed-size on-disk header for this transaction.
    ///
    /// # Errors
    /// `ShardexError::Wal` if the timestamp predates the Unix epoch, or if
    /// the operation count / payload size does not fit in the header's `u32`
    /// fields (previously these were silently truncated with `as u32`).
    pub fn to_header(&self) -> Result<WalTransactionHeader, ShardexError> {
        let operations_data = Self::serialize_operations(&self.operations)?;
        let timestamp_micros = self
            .timestamp
            .duration_since(UNIX_EPOCH)
            .map_err(|e| ShardexError::Wal(format!("Invalid timestamp: {}", e)))?
            .as_micros() as u64;
        let operation_count = u32::try_from(self.operations.len()).map_err(|_| {
            ShardexError::Wal(format!(
                "Too many operations for WAL header: {}",
                self.operations.len()
            ))
        })?;
        let operations_data_size = u32::try_from(operations_data.len()).map_err(|_| {
            ShardexError::Wal(format!(
                "Serialized operations too large for WAL header: {} bytes",
                operations_data.len()
            ))
        })?;
        Ok(WalTransactionHeader {
            id: self.id,
            timestamp_micros,
            operation_count,
            operations_data_size,
            checksum: self.checksum,
            reserved: [0; 4],
        })
    }

    /// Serialize as `header || operations_payload`.
    ///
    /// # Errors
    /// Propagates header construction / serialization failures.
    pub fn serialize(&self) -> Result<Vec<u8>, ShardexError> {
        let header = self.to_header()?;
        let operations_data = Self::serialize_operations(&self.operations)?;
        let mut result = Vec::with_capacity(std::mem::size_of::<WalTransactionHeader>() + operations_data.len());
        result.extend_from_slice(bytemuck::bytes_of(&header));
        result.extend_from_slice(&operations_data);
        Ok(result)
    }

    /// Parse a transaction from `header || operations_payload` bytes,
    /// verifying total length, checksum, and operation count.
    ///
    /// Note: the round-tripped timestamp has microsecond precision, so it may
    /// differ from the original by sub-microsecond amounts.
    ///
    /// # Errors
    /// `ShardexError::Wal` for truncated data, size mismatch, checksum
    /// mismatch, deserialization failure, or a count mismatch.
    pub fn deserialize(data: &[u8]) -> Result<Self, ShardexError> {
        if data.len() < std::mem::size_of::<WalTransactionHeader>() {
            return Err(ShardexError::Wal("Transaction data too short for header".to_string()));
        }
        let header_bytes = &data[0..std::mem::size_of::<WalTransactionHeader>()];
        // Unaligned read: `data` is an arbitrary byte slice.
        let header: WalTransactionHeader = bytemuck::pod_read_unaligned(header_bytes);
        let expected_total_size = std::mem::size_of::<WalTransactionHeader>() + header.operations_data_size as usize;
        if data.len() != expected_total_size {
            return Err(ShardexError::Wal(format!(
                "Transaction data size mismatch: expected {}, got {}",
                expected_total_size,
                data.len()
            )));
        }
        let operations_data_start = std::mem::size_of::<WalTransactionHeader>();
        let operations_data = &data[operations_data_start..];
        // Verify integrity before attempting to deserialize the payload.
        let calculated_checksum = crc32fast::hash(operations_data);
        if calculated_checksum != header.checksum {
            return Err(ShardexError::Wal(format!(
                "Transaction checksum mismatch: expected {}, calculated {}",
                header.checksum, calculated_checksum
            )));
        }
        let operations: Vec<WalOperation> = bincode::deserialize(operations_data)
            .map_err(|e| ShardexError::Wal(format!("Failed to deserialize operations: {}", e)))?;
        if operations.len() != header.operation_count as usize {
            return Err(ShardexError::Wal(format!(
                "Operation count mismatch: header says {}, found {}",
                header.operation_count,
                operations.len()
            )));
        }
        let timestamp = UNIX_EPOCH + Duration::from_micros(header.timestamp_micros);
        Ok(Self {
            id: header.id,
            timestamp,
            operations,
            checksum: header.checksum,
        })
    }
}
impl WalTransactionHeader {
pub fn new_zero() -> Self {
Self::zeroed()
}
pub fn is_valid(&self) -> bool {
self.operation_count > 0 && self.operations_data_size > 0
}
pub fn total_size(&self) -> usize {
std::mem::size_of::<WalTransactionHeader>() + self.operations_data_size as usize
}
}
/// Tuning knobs for WAL batching.
#[derive(Debug, Clone)]
pub struct BatchConfig {
    // Interval between timer-driven flushes, in milliseconds.
    pub batch_write_interval_ms: u64,
    // Flush once a batch reaches this many operations (inclusive).
    pub max_operations_per_batch: usize,
    // Flush once a batch's estimated size exceeds this many bytes.
    pub max_batch_size_bytes: usize,
    // Maximum allowed length of stored document text, in bytes.
    pub max_document_text_size: usize,
}
impl Default for BatchConfig {
    /// Defaults: flush every 100 ms, at most 1000 operations or 1 MiB per
    /// batch, and a 10 MiB cap on stored document text.
    fn default() -> Self {
        Self {
            batch_write_interval_ms: 100,
            max_operations_per_batch: 1000,
            max_batch_size_bytes: 1024 * 1024,
            max_document_text_size: 10 * 1024 * 1024,
        }
    }
}
/// Commands sent from a [`WalBatchHandle`] to the batch manager event loop.
#[derive(Debug)]
pub enum BatchCommand {
    // Validate and enqueue one operation into the current batch.
    AddOperation(WalOperation),
    // Flush the current batch immediately.
    Flush,
    // Flush whatever is pending, then stop the event loop.
    Shutdown,
}
/// Replies sent from the batch manager event loop back to the handle.
#[derive(Debug)]
pub enum BatchResponse {
    // The operation was validated and queued.
    OperationAdded,
    // A batch was flushed as the given transaction.
    BatchFlushed(TransactionId),
    // A command failed; carries the underlying error.
    Error(ShardexError),
    // The event loop has stopped.
    Shutdown,
}
/// Accumulates WAL operations into batches and flushes them as single
/// transactions, either on demand or on a timer.
pub struct WalBatchManager {
    // Operations queued since the last flush.
    current_batch: Vec<WalOperation>,
    // Batching limits and timer interval.
    config: BatchConfig,
    // Running estimate (bytes) of the queued operations' serialized size.
    current_batch_size: usize,
    // Periodic trigger for time-based flushing.
    flush_timer: Interval,
    // When `Some`, enforced on every AddPosting vector.
    expected_vector_dimension: Option<usize>,
}
impl WalBatchManager {
    /// Create a manager with the given configuration.
    ///
    /// `expected_vector_dimension` (when `Some`) is enforced on every
    /// `AddPosting` operation added to a batch.
    pub fn new(config: BatchConfig, expected_vector_dimension: Option<usize>) -> Self {
        let flush_timer = interval(Duration::from_millis(config.batch_write_interval_ms));
        Self {
            current_batch: Vec::new(),
            config,
            current_batch_size: 0,
            flush_timer,
            expected_vector_dimension,
        }
    }

    /// Validate and queue one operation.
    ///
    /// Returns `Ok(true)` when the batch has reached a limit and should be
    /// flushed by the caller.
    ///
    /// # Errors
    /// Propagates [`WalOperation::validate`] failures; the operation is not
    /// queued in that case.
    pub fn add_operation(&mut self, operation: WalOperation) -> Result<bool, ShardexError> {
        operation.validate(self.expected_vector_dimension, self.config.max_document_text_size)?;
        let operation_size = operation.estimated_serialized_size();
        self.current_batch.push(operation);
        self.current_batch_size += operation_size;
        // The count limit is inclusive (>=) while the size limit triggers
        // only once strictly exceeded (>), matching existing caller behavior.
        let should_flush_count = self.current_batch.len() >= self.config.max_operations_per_batch;
        let should_flush_size = self.current_batch_size > self.config.max_batch_size_bytes;
        Ok(should_flush_count || should_flush_size)
    }

    /// Drain the current batch into one [`WalTransaction`] and pass it to
    /// `write_fn`. Returns the transaction id, or `None` if nothing was
    /// pending.
    ///
    /// On any failure the batch is discarded (operations are NOT restored);
    /// a failed flush is treated as a rolled-back batch.
    ///
    /// # Errors
    /// Propagates transaction construction, validation, and `write_fn`
    /// failures.
    pub async fn flush_batch<F>(&mut self, write_fn: F) -> Result<Option<TransactionId>, ShardexError>
    where
        F: Fn(&WalTransaction) -> Result<(), ShardexError>,
    {
        if self.current_batch.is_empty() {
            return Ok(None);
        }
        // Take the batch and reset the size estimate together: previously the
        // size was reset only after a successful write, so a failed flush left
        // an empty batch with a stale non-zero size, causing premature
        // size-based flush triggers and wrong `batch_stats` afterwards.
        let operations = std::mem::take(&mut self.current_batch);
        self.current_batch_size = 0;
        let transaction = WalTransaction::new(operations)?;
        let transaction_id = transaction.id;
        transaction.validate(self.expected_vector_dimension, self.config.max_document_text_size)?;
        write_fn(&transaction)?;
        Ok(Some(transaction_id))
    }

    /// Snapshot of the current batch (count, estimated bytes, emptiness).
    pub fn batch_stats(&self) -> BatchStats {
        BatchStats {
            operation_count: self.current_batch.len(),
            estimated_size_bytes: self.current_batch_size,
            is_empty: self.current_batch.is_empty(),
        }
    }

    /// Wait for the next flush-timer tick; returns whether there is anything
    /// to flush.
    pub async fn should_flush_due_to_timer(&mut self) -> bool {
        self.flush_timer.tick().await;
        !self.current_batch.is_empty()
    }

    /// Run the command/response event loop until `Shutdown` or the command
    /// channel closes. Consumes the manager.
    ///
    /// Response-channel send failures are deliberately ignored (`let _`): a
    /// departed handle must not stall or kill the loop.
    pub async fn run_event_loop<F>(
        mut self,
        mut receiver: mpsc::Receiver<BatchCommand>,
        response_sender: mpsc::Sender<BatchResponse>,
        write_fn: F,
    ) where
        F: Fn(&WalTransaction) -> Result<(), ShardexError> + Send + 'static,
    {
        loop {
            tokio::select! {
                command = receiver.recv() => {
                    match command {
                        Some(BatchCommand::AddOperation(operation)) => {
                            match self.add_operation(operation) {
                                Ok(should_flush) => {
                                    // NOTE(review): OperationAdded is sent before any
                                    // auto-flush, so a BatchFlushed reply may follow an
                                    // add — callers must tolerate interleaved replies.
                                    let _ = response_sender.send(BatchResponse::OperationAdded).await;
                                    if should_flush {
                                        match self.flush_batch(&write_fn).await {
                                            Ok(Some(transaction_id)) => {
                                                let _ = response_sender.send(BatchResponse::BatchFlushed(transaction_id)).await;
                                            }
                                            Ok(None) => {
                                                // Nothing was pending; no reply needed.
                                            }
                                            Err(e) => {
                                                let _ = response_sender.send(BatchResponse::Error(e)).await;
                                            }
                                        }
                                    }
                                }
                                Err(e) => {
                                    let _ = response_sender.send(BatchResponse::Error(e)).await;
                                }
                            }
                        }
                        Some(BatchCommand::Flush) => {
                            match self.flush_batch(&write_fn).await {
                                Ok(Some(transaction_id)) => {
                                    let _ = response_sender.send(BatchResponse::BatchFlushed(transaction_id)).await;
                                }
                                Ok(None) => {
                                    // NOTE(review): an explicit Flush on an empty batch
                                    // still replies BatchFlushed with a freshly minted id
                                    // even though no transaction was written, so that
                                    // `WalBatchHandle::flush` always gets an answer.
                                    // Callers cannot distinguish this from a real flush.
                                    let _ = response_sender.send(BatchResponse::BatchFlushed(TransactionId::new())).await;
                                }
                                Err(e) => {
                                    let _ = response_sender.send(BatchResponse::Error(e)).await;
                                }
                            }
                        }
                        Some(BatchCommand::Shutdown) => {
                            // Best-effort final flush, then acknowledge and stop.
                            let _ = self.flush_batch(&write_fn).await;
                            let _ = response_sender.send(BatchResponse::Shutdown).await;
                            break;
                        }
                        None => {
                            // Command channel closed: flush what's left and exit.
                            let _ = self.flush_batch(&write_fn).await;
                            break;
                        }
                    }
                }
                _ = self.should_flush_due_to_timer() => {
                    match self.flush_batch(&write_fn).await {
                        Ok(Some(transaction_id)) => {
                            let _ = response_sender.send(BatchResponse::BatchFlushed(transaction_id)).await;
                        }
                        Ok(None) => {
                            // Timer fired but the batch emptied in the meantime.
                        }
                        Err(e) => {
                            let _ = response_sender.send(BatchResponse::Error(e)).await;
                        }
                    }
                }
            }
        }
    }
}
/// Point-in-time view of a [`WalBatchManager`]'s pending batch.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchStats {
    // Number of queued operations.
    pub operation_count: usize,
    // Estimated serialized size of the queued operations, in bytes.
    pub estimated_size_bytes: usize,
    // True when no operations are queued.
    pub is_empty: bool,
}
/// Client-side handle for talking to a [`WalBatchManager`] event loop over
/// command/response channels.
pub struct WalBatchHandle {
    // Sends commands to the manager's event loop.
    command_sender: mpsc::Sender<BatchCommand>,
    // Receives the loop's replies, one per command.
    response_receiver: mpsc::Receiver<BatchResponse>,
}
impl WalBatchHandle {
pub fn new(config: BatchConfig, expected_vector_dimension: Option<usize>) -> (Self, WalBatchManager) {
let (command_sender, _command_receiver) = mpsc::channel(1000);
let (_response_sender, response_receiver) = mpsc::channel(1000);
let manager = WalBatchManager::new(config, expected_vector_dimension);
let handle = Self {
command_sender,
response_receiver,
};
(handle, manager)
}
pub async fn add_operation(&mut self, operation: WalOperation) -> Result<(), ShardexError> {
self.command_sender
.send(BatchCommand::AddOperation(operation))
.await
.map_err(|_| ShardexError::Wal("Batch manager channel closed".to_string()))?;
match self.response_receiver.recv().await {
Some(BatchResponse::OperationAdded) => Ok(()),
Some(BatchResponse::Error(e)) => Err(e),
Some(response) => Err(ShardexError::Wal(format!("Unexpected response: {:?}", response))),
None => Err(ShardexError::Wal("Batch manager response channel closed".to_string())),
}
}
pub async fn flush(&mut self) -> Result<TransactionId, ShardexError> {
self.command_sender
.send(BatchCommand::Flush)
.await
.map_err(|_| ShardexError::Wal("Batch manager channel closed".to_string()))?;
match self.response_receiver.recv().await {
Some(BatchResponse::BatchFlushed(transaction_id)) => Ok(transaction_id),
Some(BatchResponse::Error(e)) => Err(e),
Some(response) => Err(ShardexError::Wal(format!("Unexpected response: {:?}", response))),
None => Err(ShardexError::Wal("Batch manager response channel closed".to_string())),
}
}
pub async fn shutdown(&mut self) -> Result<(), ShardexError> {
self.command_sender
.send(BatchCommand::Shutdown)
.await
.map_err(|_| ShardexError::Wal("Batch manager channel closed".to_string()))?;
match self.response_receiver.recv().await {
Some(BatchResponse::Shutdown) => Ok(()),
Some(BatchResponse::Error(e)) => Err(e),
Some(response) => Err(ShardexError::Wal(format!("Unexpected response: {:?}", response))),
None => Err(ShardexError::Wal("Batch manager response channel closed".to_string())),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Accessors and variant predicates on WalOperation.
    #[test]
    fn test_wal_operation_document_id() {
        let doc_id = DocumentId::new();
        let vector = vec![1.0, 2.0, 3.0];
        let add_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector,
        };
        let remove_op = WalOperation::RemoveDocument { document_id: doc_id };
        assert_eq!(add_op.document_id(), doc_id);
        assert_eq!(remove_op.document_id(), doc_id);
        assert!(add_op.is_add_posting());
        assert!(!add_op.is_remove_document());
        assert!(!remove_op.is_add_posting());
        assert!(remove_op.is_remove_document());
    }

    // validate(): dimension checks, NaN/inf rejection, zero length, empty
    // vector, and start+length overflow.
    #[test]
    fn test_wal_operation_validation() {
        let doc_id = DocumentId::new();
        let valid_add = WalOperation::AddPosting {
            document_id: doc_id,
            start: 100,
            length: 50,
            vector: vec![1.0, 2.0, 3.0],
        };
        assert!(valid_add.validate(Some(3), 10 * 1024 * 1024).is_ok());
        assert!(valid_add.validate(None, 10 * 1024 * 1024).is_ok());
        assert!(valid_add.validate(Some(4), 10 * 1024 * 1024).is_err());
        let invalid_nan = WalOperation::AddPosting {
            document_id: doc_id,
            start: 100,
            length: 50,
            vector: vec![1.0, f32::NAN, 3.0],
        };
        assert!(invalid_nan.validate(None, 10 * 1024 * 1024).is_err());
        let invalid_inf = WalOperation::AddPosting {
            document_id: doc_id,
            start: 100,
            length: 50,
            vector: vec![1.0, f32::INFINITY, 3.0],
        };
        assert!(invalid_inf.validate(None, 10 * 1024 * 1024).is_err());
        let zero_length = WalOperation::AddPosting {
            document_id: doc_id,
            start: 100,
            length: 0,
            vector: vec![1.0, 2.0, 3.0],
        };
        assert!(zero_length.validate(None, 10 * 1024 * 1024).is_err());
        let empty_vector = WalOperation::AddPosting {
            document_id: doc_id,
            start: 100,
            length: 50,
            vector: vec![],
        };
        assert!(empty_vector.validate(None, 10 * 1024 * 1024).is_err());
        let overflow = WalOperation::AddPosting {
            document_id: doc_id,
            start: u32::MAX,
            length: 1,
            vector: vec![1.0, 2.0, 3.0],
        };
        assert!(overflow.validate(None, 10 * 1024 * 1024).is_err());
        let remove_doc = WalOperation::RemoveDocument { document_id: doc_id };
        assert!(remove_doc.validate(None, 10 * 1024 * 1024).is_ok());
        assert!(remove_doc.validate(Some(128), 10 * 1024 * 1024).is_ok());
    }

    // A new transaction keeps its operations and computes a checksum.
    #[test]
    fn test_wal_transaction_creation() {
        let doc_id = DocumentId::new();
        let operations = vec![
            WalOperation::AddPosting {
                document_id: doc_id,
                start: 0,
                length: 100,
                vector: vec![1.0, 2.0, 3.0],
            },
            WalOperation::RemoveDocument { document_id: doc_id },
        ];
        let transaction = WalTransaction::new(operations.clone()).unwrap();
        assert_eq!(transaction.operations, operations);
        assert_eq!(transaction.operation_count(), 2);
        assert!(transaction.checksum != 0);
    }

    // Empty operation lists are rejected at construction time.
    #[test]
    fn test_wal_transaction_empty_operations() {
        let result = WalTransaction::new(vec![]);
        if let Err(ShardexError::Wal(msg)) = result {
            assert!(msg.contains("cannot have zero operations"));
        } else {
            panic!("Expected Wal error");
        }
    }

    // affected_document_ids() deduplicates and sorts.
    #[test]
    fn test_wal_transaction_affected_documents() {
        let doc_id1 = DocumentId::new();
        let doc_id2 = DocumentId::new();
        let operations = vec![
            WalOperation::AddPosting {
                document_id: doc_id1,
                start: 0,
                length: 100,
                vector: vec![1.0, 2.0, 3.0],
            },
            WalOperation::AddPosting {
                document_id: doc_id2,
                start: 50,
                length: 75,
                vector: vec![4.0, 5.0, 6.0],
            },
            WalOperation::RemoveDocument { document_id: doc_id1 },
        ];
        let transaction = WalTransaction::new(operations).unwrap();
        let affected = transaction.affected_document_ids();
        assert_eq!(affected.len(), 2);
        assert!(affected.contains(&doc_id1));
        assert!(affected.contains(&doc_id2));
        assert!(affected[0] < affected[1]);
    }

    // Checksum verification passes for an intact transaction and fails after
    // tampering with the stored checksum.
    #[test]
    fn test_wal_transaction_checksum_verification() {
        let doc_id = DocumentId::new();
        let operations = vec![WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0],
        }];
        let transaction = WalTransaction::new(operations).unwrap();
        assert!(transaction.verify_checksum().is_ok());
        let mut bad_transaction = transaction.clone();
        bad_transaction.checksum = 12345;
        assert!(bad_transaction.verify_checksum().is_err());
    }

    // Round-trip serialize/deserialize; timestamp precision is microseconds,
    // so only sub-microsecond drift is tolerated.
    #[test]
    fn test_wal_transaction_serialization() {
        let doc_id = DocumentId::new();
        let operations = vec![
            WalOperation::AddPosting {
                document_id: doc_id,
                start: 0,
                length: 100,
                vector: vec![1.0, 2.0, 3.0],
            },
            WalOperation::RemoveDocument { document_id: doc_id },
        ];
        let transaction = WalTransaction::new(operations).unwrap();
        let serialized = transaction.serialize().unwrap();
        assert!(!serialized.is_empty());
        let deserialized = WalTransaction::deserialize(&serialized).unwrap();
        assert_eq!(transaction.id, deserialized.id);
        assert_eq!(transaction.operations, deserialized.operations);
        assert_eq!(transaction.checksum, deserialized.checksum);
        let time_diff = if transaction.timestamp > deserialized.timestamp {
            transaction
                .timestamp
                .duration_since(deserialized.timestamp)
                .unwrap()
        } else {
            deserialized
                .timestamp
                .duration_since(transaction.timestamp)
                .unwrap()
        };
        assert!(time_diff.as_micros() < 10);
    }

    // Header fields mirror the transaction; total_size() is header + payload.
    #[test]
    fn test_wal_transaction_header_operations() {
        let doc_id = DocumentId::new();
        let operations = vec![WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0],
        }];
        let transaction = WalTransaction::new(operations).unwrap();
        let header = transaction.to_header().unwrap();
        assert_eq!(header.id, transaction.id);
        assert_eq!(header.operation_count, 1);
        assert!(header.operations_data_size > 0);
        assert_eq!(header.checksum, transaction.checksum);
        assert!(header.is_valid());
        let total_size = header.total_size();
        assert_eq!(
            total_size,
            std::mem::size_of::<WalTransactionHeader>() + header.operations_data_size as usize
        );
    }

    // Header survives a bytemuck bytes-of / pod-read round trip.
    #[test]
    fn test_wal_transaction_header_bytemuck() {
        let header = WalTransactionHeader {
            id: TransactionId::new(),
            // Fixed sample timestamp in epoch microseconds.
            timestamp_micros: 1640995200000000,
            operation_count: 5,
            operations_data_size: 1024,
            checksum: 0x12345678,
            reserved: [0; 4],
        };
        let bytes: &[u8] = bytemuck::bytes_of(&header);
        assert_eq!(bytes.len(), std::mem::size_of::<WalTransactionHeader>());
        let header_restored: WalTransactionHeader = bytemuck::pod_read_unaligned(bytes);
        assert_eq!(header, header_restored);
    }

    // Pathological header values (max timestamp/count/size) still round-trip
    // at the byte level — semantic validation happens elsewhere.
    #[tokio::test]
    async fn test_invalid_transaction_data() {
        let invalid_header = WalTransactionHeader {
            id: TransactionId::from_bytes([0xFF; 16]),
            timestamp_micros: 0,
            operation_count: 1,
            operations_data_size: 100,
            checksum: 0,
            reserved: [0; 4],
        };
        let bytes = bytemuck::bytes_of(&invalid_header);
        let result: Result<WalTransactionHeader, _> = bytemuck::try_pod_read_unaligned(bytes);
        assert!(result.is_ok());
        let timestamp_header = WalTransactionHeader {
            id: TransactionId::new(),
            timestamp_micros: u64::MAX,
            operation_count: 1,
            operations_data_size: 100,
            checksum: 0,
            reserved: [0; 4],
        };
        // NOTE(review): source had mojibake "×tamp_header" here (HTML
        // entity damage of "&timestamp_header"); reconstructed.
        let bytes = bytemuck::bytes_of(&timestamp_header);
        let header: WalTransactionHeader = bytemuck::pod_read_unaligned(bytes);
        assert_eq!(header.timestamp_micros, u64::MAX);
        let overflow_header = WalTransactionHeader {
            id: TransactionId::new(),
            timestamp_micros: 1234567890,
            operation_count: u32::MAX,
            operations_data_size: 100,
            checksum: 0,
            reserved: [0; 4],
        };
        let bytes = bytemuck::bytes_of(&overflow_header);
        let header: WalTransactionHeader = bytemuck::pod_read_unaligned(bytes);
        assert_eq!(header.operation_count, u32::MAX);
        let size_mismatch_header = WalTransactionHeader {
            id: TransactionId::new(),
            timestamp_micros: 1234567890,
            operation_count: 1,
            operations_data_size: u32::MAX,
            checksum: 0,
            reserved: [0; 4],
        };
        let bytes = bytemuck::bytes_of(&size_mismatch_header);
        let header: WalTransactionHeader = bytemuck::pod_read_unaligned(bytes);
        assert_eq!(header.operations_data_size, u32::MAX);
    }

    // add_operation() rejects invalid operations and accepts valid ones,
    // including a very large (but legal) vector when no dimension is set.
    #[tokio::test]
    async fn test_invalid_batch_operations() {
        let config = BatchConfig::default();
        let mut manager = WalBatchManager::new(config, None);
        let operation = WalOperation::AddPosting {
            document_id: DocumentId::new(),
            start: 0,
            length: 100,
            vector: vec![],
        };
        let result = manager.add_operation(operation);
        assert!(result.is_err());
        let large_vector = vec![1.0; 1_000_000];
        let operation = WalOperation::AddPosting {
            document_id: DocumentId::new(),
            start: 0,
            length: 100,
            vector: large_vector,
        };
        let result = manager.add_operation(operation);
        assert!(result.is_ok());
        let operation = WalOperation::AddPosting {
            document_id: DocumentId::new(),
            start: u32::MAX,
            length: u32::MAX,
            vector: vec![1.0, 2.0, 3.0],
        };
        let result = manager.add_operation(operation);
        assert!(result.is_err());
        let operation = WalOperation::RemoveDocument {
            document_id: DocumentId::new(),
        };
        let result = manager.add_operation(operation);
        assert!(result.is_ok());
    }

    // Add one operation, check stats, flush through a closure, and confirm
    // the batch is empty afterwards.
    #[tokio::test]
    async fn test_batch_manager_basic_operations() {
        let config = BatchConfig::default();
        let mut manager = WalBatchManager::new(config, Some(3));
        let doc_id = DocumentId::new();
        let operation = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0],
        };
        let should_flush = manager.add_operation(operation).unwrap();
        assert!(!should_flush);
        let stats = manager.batch_stats();
        assert_eq!(stats.operation_count, 1);
        assert!(!stats.is_empty);
        assert!(stats.estimated_size_bytes > 0);
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::sync::Arc;
        let transaction_written = Arc::new(AtomicBool::new(false));
        let transaction_written_clone = transaction_written.clone();
        let write_fn = move |transaction: &WalTransaction| -> Result<(), ShardexError> {
            assert_eq!(transaction.operations.len(), 1);
            transaction_written_clone.store(true, Ordering::SeqCst);
            Ok(())
        };
        let result = manager.flush_batch(write_fn).await.unwrap();
        assert!(result.is_some());
        assert!(transaction_written.load(Ordering::SeqCst));
        let stats = manager.batch_stats();
        assert_eq!(stats.operation_count, 0);
        assert!(stats.is_empty);
    }

    // The operation-count limit (here 2) triggers the flush signal.
    #[tokio::test]
    async fn test_batch_manager_size_limits() {
        let config = BatchConfig {
            batch_write_interval_ms: 1000,
            max_operations_per_batch: 2,
            max_batch_size_bytes: 1024 * 1024,
            max_document_text_size: 10 * 1024 * 1024,
        };
        let mut manager = WalBatchManager::new(config, Some(3));
        let doc_id = DocumentId::new();
        let operation = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0],
        };
        let should_flush = manager.add_operation(operation.clone()).unwrap();
        assert!(!should_flush);
        let should_flush = manager.add_operation(operation).unwrap();
        let stats = manager.batch_stats();
        assert!(
            should_flush,
            "Second operation should trigger flush. Count: {}, should_flush: {}",
            stats.operation_count, should_flush
        );
    }

    // Default config values.
    #[tokio::test]
    async fn test_batch_config_defaults() {
        let config = BatchConfig::default();
        assert_eq!(config.batch_write_interval_ms, 100);
        assert_eq!(config.max_operations_per_batch, 1000);
        assert_eq!(config.max_batch_size_bytes, 1024 * 1024);
    }

    // Flushing an empty batch is a no-op: write_fn must not be called.
    #[tokio::test]
    async fn test_empty_batch_flush() {
        let config = BatchConfig::default();
        let mut manager = WalBatchManager::new(config, None);
        let write_fn = |_: &WalTransaction| -> Result<(), ShardexError> {
            panic!("Write function should not be called for empty batch");
        };
        let result = manager.flush_batch(write_fn).await.unwrap();
        assert!(result.is_none());
    }

    // A dimension mismatch is rejected before the operation is queued.
    #[tokio::test]
    async fn test_batch_manager_validation() {
        let config = BatchConfig::default();
        let mut manager = WalBatchManager::new(config, Some(3));
        let doc_id = DocumentId::new();
        let invalid_operation = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            // Two values where three are expected.
            vector: vec![1.0, 2.0],
        };
        let result = manager.add_operation(invalid_operation);
        assert!(result.is_err());
        if let Err(ShardexError::InvalidDimension { expected, actual }) = result {
            assert_eq!(expected, 3);
            assert_eq!(actual, 2);
        } else {
            panic!("Expected InvalidDimension error");
        }
        let stats = manager.batch_stats();
        assert!(stats.is_empty);
    }

    // End-to-end: flush a batch into a real WAL segment on disk.
    #[tokio::test]
    async fn test_batch_manager_with_wal_integration() {
        use crate::layout::DirectoryLayout;
        use crate::test_utils::TestEnvironment;
        use crate::wal::WalSegment;
        let _test_env = TestEnvironment::new("test_batch_wal_integration");
        let layout = DirectoryLayout::new(_test_env.path());
        let segment_path = layout.wal_segment_path(1);
        let capacity = 8192;
        let segment = std::sync::Arc::new(WalSegment::create(1, segment_path, capacity).unwrap());
        let config = BatchConfig {
            batch_write_interval_ms: 50,
            max_operations_per_batch: 3,
            max_batch_size_bytes: 1024,
            max_document_text_size: 10 * 1024 * 1024,
        };
        let mut manager = WalBatchManager::new(config, Some(3));
        let doc_id = DocumentId::new();
        let operations = vec![
            WalOperation::AddPosting {
                document_id: doc_id,
                start: 0,
                length: 100,
                vector: vec![1.0, 2.0, 3.0],
            },
            WalOperation::RemoveDocument { document_id: doc_id },
        ];
        for operation in operations {
            let should_flush = manager.add_operation(operation).unwrap();
            assert!(!should_flush);
        }
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let write_count = Arc::new(AtomicUsize::new(0));
        let write_count_clone = write_count.clone();
        let segment_clone = segment.clone();
        let write_fn = move |transaction: &WalTransaction| -> Result<(), ShardexError> {
            segment_clone.append_transaction(transaction)?;
            segment_clone.sync()?;
            write_count_clone.fetch_add(1, Ordering::SeqCst);
            Ok(())
        };
        let result = manager.flush_batch(write_fn).await.unwrap();
        assert!(result.is_some());
        assert_eq!(write_count.load(Ordering::SeqCst), 1);
        assert!(segment.write_pointer() > crate::wal::initial_write_position());
    }

    // All queued operations land in ONE committed transaction, which passes
    // validation and checksum verification inside the write callback.
    #[tokio::test]
    async fn test_atomic_batch_commits() {
        let config = BatchConfig::default();
        let mut manager = WalBatchManager::new(config, Some(3));
        let doc_id1 = DocumentId::new();
        let doc_id2 = DocumentId::new();
        let operations = vec![
            WalOperation::AddPosting {
                document_id: doc_id1,
                start: 0,
                length: 100,
                vector: vec![1.0, 2.0, 3.0],
            },
            WalOperation::AddPosting {
                document_id: doc_id2,
                start: 50,
                length: 75,
                vector: vec![4.0, 5.0, 6.0],
            },
            WalOperation::RemoveDocument { document_id: doc_id1 },
        ];
        for operation in operations {
            manager.add_operation(operation).unwrap();
        }
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
        use std::sync::Arc;
        let commit_called = Arc::new(AtomicBool::new(false));
        let operation_count = Arc::new(AtomicUsize::new(0));
        let commit_called_clone = commit_called.clone();
        let operation_count_clone = operation_count.clone();
        let write_fn = move |transaction: &WalTransaction| -> Result<(), ShardexError> {
            commit_called_clone.store(true, Ordering::SeqCst);
            operation_count_clone.store(transaction.operations.len(), Ordering::SeqCst);
            transaction.validate(Some(3), 10 * 1024 * 1024)?;
            transaction.verify_checksum()?;
            Ok(())
        };
        let result = manager.flush_batch(write_fn).await.unwrap();
        assert!(result.is_some());
        assert!(commit_called.load(Ordering::SeqCst));
        assert_eq!(operation_count.load(Ordering::SeqCst), 3);
        let stats = manager.batch_stats();
        assert!(stats.is_empty);
    }

    // A failed write surfaces the error and the batch is discarded (the
    // documented "rollback" behavior: operations are NOT restored).
    #[tokio::test]
    async fn test_failed_batch_commit_rollback() {
        let config = BatchConfig::default();
        let mut manager = WalBatchManager::new(config, Some(3));
        let doc_id = DocumentId::new();
        let operation = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0],
        };
        manager.add_operation(operation).unwrap();
        assert_eq!(manager.batch_stats().operation_count, 1);
        let write_fn = |_transaction: &WalTransaction| -> Result<(), ShardexError> {
            Err(ShardexError::Wal("Simulated write failure".to_string()))
        };
        let result = manager.flush_batch(write_fn).await;
        assert!(result.is_err());
        if let Err(ShardexError::Wal(msg)) = result {
            assert!(msg.contains("Simulated write failure"));
        } else {
            panic!("Expected Wal error");
        }
        let stats = manager.batch_stats();
        assert!(stats.is_empty);
    }

    // Size estimates: AddPosting = 1+16+4+4+4 + 4*len; RemoveDocument = 1+16.
    #[test]
    fn test_operation_estimated_size() {
        let doc_id = DocumentId::new();
        let add_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0, 4.0],
        };
        let remove_op = WalOperation::RemoveDocument { document_id: doc_id };
        assert_eq!(add_op.estimated_serialized_size(), 45);
        assert_eq!(remove_op.estimated_serialized_size(), 17);
    }

    // Transaction-level validation: dimension propagation and the
    // far-future-timestamp guard (tolerance is 60 s; 3600 s must fail).
    #[test]
    fn test_transaction_validation() {
        let doc_id = DocumentId::new();
        let valid_ops = vec![WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 100,
            vector: vec![1.0, 2.0, 3.0],
        }];
        let valid_transaction = WalTransaction::new(valid_ops).unwrap();
        assert!(valid_transaction
            .validate(Some(3), 10 * 1024 * 1024)
            .is_ok());
        assert!(valid_transaction
            .validate(Some(4), 10 * 1024 * 1024)
            .is_err());
        let future_time = SystemTime::now() + std::time::Duration::from_secs(3600);
        let future_ops = vec![WalOperation::RemoveDocument { document_id: doc_id }];
        let future_transaction =
            WalTransaction::with_id_and_timestamp(TransactionId::new(), future_time, future_ops).unwrap();
        assert!(future_transaction.validate(None, 10 * 1024 * 1024).is_err());
    }

    // Predicates and size estimation for the document-text variants.
    #[test]
    fn test_document_text_wal_operations() {
        let doc_id = DocumentId::new();
        let text = "The quick brown fox jumps over the lazy dog.".to_string();
        let store_op = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: text.clone(),
        };
        assert_eq!(store_op.document_id(), doc_id);
        assert!(store_op.is_store_document_text());
        assert!(!store_op.is_delete_document_text());
        assert!(!store_op.is_add_posting());
        assert!(!store_op.is_remove_document());
        let delete_op = WalOperation::DeleteDocumentText { document_id: doc_id };
        assert_eq!(delete_op.document_id(), doc_id);
        assert!(delete_op.is_delete_document_text());
        assert!(!delete_op.is_store_document_text());
        assert!(!delete_op.is_add_posting());
        assert!(!delete_op.is_remove_document());
        let store_size = store_op.estimated_serialized_size();
        let delete_size = delete_op.estimated_serialized_size();
        assert!(store_size > delete_size);
        assert!(store_size > text.len());
    }

    // Validation of document-text operations (no size limit hit here).
    #[test]
    fn test_document_text_operation_validation() {
        let doc_id = DocumentId::new();
        let valid_text = "Hello, world!".to_string();
        let valid_op = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: valid_text,
        };
        assert!(valid_op.validate(None, 10 * 1024 * 1024).is_ok());
        let large_text = "x".repeat(1000);
        let large_op = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: large_text,
        };
        assert!(large_op.validate(None, 10 * 1024 * 1024).is_ok());
        let delete_op = WalOperation::DeleteDocumentText { document_id: doc_id };
        assert!(delete_op.validate(None, 10 * 1024 * 1024).is_ok());
        assert!(delete_op.validate(Some(128), 10 * 1024 * 1024).is_ok());
    }

    // Mixed text + posting operations stay together through validation and a
    // serialize/deserialize round trip.
    #[test]
    fn test_document_text_transaction_atomicity() {
        let doc_id = DocumentId::new();
        let text = "Updated document content".to_string();
        let operations = vec![
            WalOperation::StoreDocumentText {
                document_id: doc_id,
                text,
            },
            WalOperation::RemoveDocument { document_id: doc_id },
            WalOperation::AddPosting {
                document_id: doc_id,
                start: 0,
                length: 7,
                vector: vec![1.0, 2.0, 3.0],
            },
        ];
        let transaction = WalTransaction::new(operations).unwrap();
        assert_eq!(transaction.operation_count(), 3);
        let affected_docs = transaction.affected_document_ids();
        assert_eq!(affected_docs.len(), 1);
        assert_eq!(affected_docs[0], doc_id);
        assert!(transaction.validate(Some(3), 10 * 1024 * 1024).is_ok());
        let serialized = transaction.serialize().unwrap();
        let deserialized = WalTransaction::deserialize(&serialized).unwrap();
        assert_eq!(transaction.operations, deserialized.operations);
    }
}