#![allow(dead_code)]
use crate::error::ShardexError;
use crate::identifiers::{DocumentId, TransactionId};
use crate::shardex_index::ShardexIndex;
use crate::transactions::{WalOperation, WalTransaction};
use crate::wal::WalSegment;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
/// Coordinates WAL-backed transactions over document/posting operations.
///
/// Operations are buffered in memory per transaction and are only written to
/// the WAL (and applied to the index) at commit time. The coordinator is not
/// internally synchronized — all mutating methods take `&mut self` — so
/// concurrent use requires external locking (see the tests, which wrap it in
/// a `tokio::sync::Mutex`).
pub struct DocumentTransactionCoordinator {
    // Shared WAL segment that committed transactions are appended to.
    wal_segment: Arc<WalSegment>,
    // Currently open transactions, keyed by id; entries are removed on
    // commit, abort, or timeout cleanup.
    active_transactions: HashMap<TransactionId, ActiveTransaction>,
    // Monotonic count of transactions ever started. Incremented in
    // begin_transaction; the value is not otherwise read in this file.
    transaction_counter: AtomicU64,
    // Maximum age an open transaction may reach before
    // cleanup_expired_transactions aborts it.
    transaction_timeout: Duration,
    // Upper bound on simultaneously open transactions; begin_transaction
    // fails once this is reached.
    max_active_transactions: usize,
}
/// In-memory record of an open transaction: its buffered operations and
/// lifecycle state. Nothing is persisted until commit.
#[derive(Debug)]
struct ActiveTransaction {
    // Transaction id (duplicates the map key so the record is self-contained).
    id: TransactionId,
    // Wall-clock start time; used for timeout expiry and statistics.
    start_time: SystemTime,
    // Operations buffered until commit, in submission order.
    operations: Vec<WalOperation>,
    // Current lifecycle state; only `Active` transactions accept operations.
    state: TransactionState,
}
/// Lifecycle state of a transaction.
///
/// `Committed` and `Aborted` are terminal: records reaching those states have
/// already been removed from the active map, so these variants mainly serve
/// debugging/logging of the final record before it is dropped.
#[derive(Debug, Clone, PartialEq)]
enum TransactionState {
    /// Open and accepting new operations.
    Active,
    /// Commit in progress: being validated and written to the WAL.
    Committing,
    /// Durably written to the WAL and applied to the index.
    Committed,
    /// Rolled back; buffered operations were discarded.
    Aborted,
}
/// Point-in-time snapshot of coordinator activity, as returned by
/// `DocumentTransactionCoordinator::get_transaction_statistics`.
#[derive(Debug, Clone)]
pub struct TransactionStatistics {
    /// Number of currently open transactions.
    pub active_transactions: usize,
    /// Age of the oldest open transaction, or `None` when there are none.
    pub oldest_transaction_age: Option<Duration>,
    /// Total operations buffered across all open transactions.
    pub total_operations_pending: usize,
}
impl DocumentTransactionCoordinator {
    /// Creates a coordinator with a default limit of 1000 concurrently
    /// active transactions.
    pub fn new(wal_segment: Arc<WalSegment>, transaction_timeout: Duration) -> Self {
        Self::with_max_transactions(wal_segment, transaction_timeout, 1000)
    }

    /// Creates a coordinator with an explicit limit on concurrently active
    /// transactions.
    ///
    /// `transaction_timeout` bounds how long a transaction may stay open
    /// before `cleanup_expired_transactions` aborts it.
    pub fn with_max_transactions(
        wal_segment: Arc<WalSegment>,
        transaction_timeout: Duration,
        max_active_transactions: usize,
    ) -> Self {
        Self {
            wal_segment,
            active_transactions: HashMap::new(),
            transaction_counter: AtomicU64::new(1),
            transaction_timeout,
            max_active_transactions,
        }
    }

    /// Begins a new transaction and returns its id.
    ///
    /// # Errors
    /// Returns `ShardexError::InvalidInput` when the active-transaction
    /// limit has been reached.
    pub async fn begin_transaction(&mut self) -> Result<TransactionId, ShardexError> {
        if self.active_transactions.len() >= self.max_active_transactions {
            return Err(ShardexError::InvalidInput {
                field: "transaction_count".to_string(),
                reason: format!(
                    "Maximum active transaction limit reached: {} (max: {})",
                    self.active_transactions.len(),
                    self.max_active_transactions
                ),
                suggestion: "Wait for existing transactions to complete or increase the limit".to_string(),
            });
        }
        // Bump the started-transaction counter; the previous value is unused
        // because ids are generated independently by TransactionId::new().
        let _ = self.transaction_counter.fetch_add(1, Ordering::SeqCst);
        let transaction_id = TransactionId::new();
        let transaction = ActiveTransaction {
            id: transaction_id,
            start_time: SystemTime::now(),
            operations: Vec::new(),
            state: TransactionState::Active,
        };
        self.active_transactions.insert(transaction_id, transaction);
        tracing::debug!("Transaction {} started", transaction_id);
        Ok(transaction_id)
    }

    /// Buffers an operation in an active transaction after structural
    /// validation (ids, coordinates, vector finiteness).
    ///
    /// Index-dependent checks (vector dimension, text-size limits) are
    /// deferred to commit time, when the index is available.
    ///
    /// # Errors
    /// Fails when the operation is structurally invalid, the transaction id
    /// is unknown, or the transaction is no longer `Active`.
    pub async fn add_operation(
        &mut self,
        transaction_id: TransactionId,
        operation: WalOperation,
    ) -> Result<(), ShardexError> {
        self.validate_operation(&operation)?;
        let transaction = self
            .active_transactions
            .get_mut(&transaction_id)
            .ok_or_else(|| ShardexError::InvalidInput {
                field: "transaction_id".to_string(),
                reason: format!("Transaction {} not found or expired", transaction_id),
                suggestion: "Begin a new transaction".to_string(),
            })?;
        if transaction.state != TransactionState::Active {
            return Err(ShardexError::InvalidInput {
                field: "transaction_state".to_string(),
                reason: format!(
                    "Transaction {} is not active (state: {:?})",
                    transaction_id, transaction.state
                ),
                suggestion: "Begin a new transaction".to_string(),
            });
        }
        transaction.operations.push(operation);
        Ok(())
    }

    /// Commits a transaction: validates against the index, persists to the
    /// WAL, then applies the operations to the in-memory index.
    ///
    /// The transaction is removed from the active set up front, so any
    /// failure (validation, WAL append, or apply) effectively aborts it —
    /// the same id cannot be retried. Operations are applied to the index
    /// only after the WAL append has been synced, so a failure mid-apply
    /// leaves a durable WAL record behind for recovery.
    ///
    /// # Errors
    /// Fails when the transaction id is unknown, an operation fails
    /// commit-time validation, or the WAL/index operations fail.
    pub async fn commit_transaction(
        &mut self,
        transaction_id: TransactionId,
        index: &mut ShardexIndex,
    ) -> Result<(), ShardexError> {
        let mut transaction = self
            .active_transactions
            .remove(&transaction_id)
            .ok_or_else(|| ShardexError::InvalidInput {
                field: "transaction_id".to_string(),
                reason: format!("Transaction {} not found", transaction_id),
                suggestion: "Check transaction ID".to_string(),
            })?;
        transaction.state = TransactionState::Committing;

        // Index-dependent validation that could not be done at add time.
        for operation in &transaction.operations {
            self.validate_operation_for_commit(operation, index)?;
        }

        // The operations are cloned because they are needed again below to
        // update the in-memory index after the WAL write succeeds.
        let wal_transaction = WalTransaction::with_id_and_timestamp(
            transaction_id,
            transaction.start_time,
            transaction.operations.clone(),
        )?;
        self.wal_segment.append_transaction(&wal_transaction)?;
        // Durability point: once sync returns, the transaction survives a crash.
        self.wal_segment.sync()?;

        for operation in &transaction.operations {
            self.apply_operation_to_index(operation, index).await?;
        }
        // Terminal state on the local record; logged via Debug if needed,
        // then dropped.
        transaction.state = TransactionState::Committed;
        // Fixed: previously logged the identical id twice
        // ("Transaction {} (id: {}) ..." with transaction_id and transaction.id).
        tracing::debug!(
            "Transaction {} committed with {} operations",
            transaction_id,
            transaction.operations.len()
        );
        Ok(())
    }

    /// Aborts a transaction, discarding its buffered operations.
    ///
    /// Aborting an unknown (already committed/aborted) id is a no-op, which
    /// makes this safe to call from timeout cleanup.
    pub async fn abort_transaction(&mut self, transaction_id: TransactionId) -> Result<(), ShardexError> {
        if let Some(mut transaction) = self.active_transactions.remove(&transaction_id) {
            transaction.state = TransactionState::Aborted;
            tracing::debug!(
                "Transaction {} aborted with {} operations",
                transaction_id,
                transaction.operations.len()
            );
        }
        Ok(())
    }

    /// Structural validation performed when an operation is buffered:
    /// document id, posting coordinates, and vector contents.
    fn validate_operation(&self, operation: &WalOperation) -> Result<(), ShardexError> {
        match operation {
            WalOperation::StoreDocumentText { document_id, text } => {
                self.validate_document_id(*document_id)?;
                self.validate_text_content(text)?;
            }
            WalOperation::DeleteDocumentText { document_id } => {
                self.validate_document_id(*document_id)?;
            }
            WalOperation::AddPosting {
                document_id,
                start,
                length,
                vector,
            } => {
                self.validate_document_id(*document_id)?;
                self.validate_posting_coordinates(*start, *length)?;
                self.validate_vector(vector)?;
            }
            WalOperation::RemoveDocument { document_id } => {
                self.validate_document_id(*document_id)?;
            }
        }
        Ok(())
    }

    /// Commit-time validation against the target index: text storage must be
    /// enabled and within size limits, and vectors must match the configured
    /// dimension.
    fn validate_operation_for_commit(
        &self,
        operation: &WalOperation,
        index: &ShardexIndex,
    ) -> Result<(), ShardexError> {
        match operation {
            WalOperation::StoreDocumentText { text, .. } => {
                if !index.has_text_storage() {
                    return Err(ShardexError::InvalidInput {
                        field: "text_storage".to_string(),
                        reason: "Text storage not enabled for this index".to_string(),
                        suggestion: "Enable text storage in configuration".to_string(),
                    });
                }
                let config = index.get_config();
                if text.len() > config.max_document_text_size {
                    return Err(ShardexError::DocumentTooLarge {
                        size: text.len(),
                        max_size: config.max_document_text_size,
                    });
                }
            }
            WalOperation::AddPosting { vector, .. } => {
                let config = index.get_config();
                if vector.len() != config.vector_size {
                    return Err(ShardexError::InvalidDimension {
                        expected: config.vector_size,
                        actual: vector.len(),
                    });
                }
            }
            // Delete/remove operations need no index-dependent checks.
            _ => {}
        }
        Ok(())
    }

    /// Applies a single already-validated, already-persisted operation to the
    /// in-memory index.
    async fn apply_operation_to_index(
        &self,
        operation: &WalOperation,
        index: &mut ShardexIndex,
    ) -> Result<(), ShardexError> {
        tracing::trace!("Applying operation to index: {:?}", operation);
        match operation {
            WalOperation::StoreDocumentText { document_id, text } => {
                index.store_document_text(*document_id, text)?;
            }
            WalOperation::DeleteDocumentText { document_id } => {
                index.delete_document_text(*document_id)?;
            }
            WalOperation::AddPosting {
                document_id,
                start,
                length,
                vector,
            } => {
                index.add_posting(*document_id, *start, *length, vector.clone())?;
                tracing::debug!("Added posting for document: {} at {}:{}", document_id, start, length);
            }
            WalOperation::RemoveDocument { document_id } => {
                index.remove_document(*document_id)?;
                tracing::debug!("Removed all postings for document: {}", document_id);
            }
        }
        Ok(())
    }

    /// Aborts every transaction older than the configured timeout.
    ///
    /// Ids are collected first because aborting mutates the map being
    /// iterated. Transactions whose start time is in the future (clock skew)
    /// are left alone: `duration_since` errors are ignored.
    pub async fn cleanup_expired_transactions(&mut self) -> Result<(), ShardexError> {
        let now = SystemTime::now();
        let mut expired_transactions = Vec::new();
        for (id, transaction) in &self.active_transactions {
            if let Ok(elapsed) = now.duration_since(transaction.start_time) {
                if elapsed > self.transaction_timeout {
                    expired_transactions.push(*id);
                }
            }
        }
        for transaction_id in expired_transactions {
            tracing::warn!("Cleaning up expired transaction: {}", transaction_id);
            self.abort_transaction(transaction_id).await?;
        }
        Ok(())
    }

    /// Returns a snapshot of current transaction counts and the oldest open
    /// transaction's age.
    pub fn get_transaction_statistics(&self) -> TransactionStatistics {
        // Capture a single timestamp so all ages are measured against the
        // same instant (previously SystemTime::now() was re-read per entry).
        let now = SystemTime::now();
        let oldest_age = self
            .active_transactions
            .values()
            .filter_map(|t| now.duration_since(t.start_time).ok())
            .max();
        TransactionStatistics {
            active_transactions: self.active_transactions.len(),
            oldest_transaction_age: oldest_age,
            total_operations_pending: self
                .active_transactions
                .values()
                .map(|t| t.operations.len())
                .sum(),
        }
    }

    /// Placeholder: all document ids are currently accepted.
    fn validate_document_id(&self, _document_id: DocumentId) -> Result<(), ShardexError> {
        Ok(())
    }

    /// Rejects empty document text.
    fn validate_text_content(&self, text: &str) -> Result<(), ShardexError> {
        if text.is_empty() {
            return Err(ShardexError::InvalidInput {
                field: "text".to_string(),
                reason: "Document text cannot be empty".to_string(),
                suggestion: "Provide non-empty text content".to_string(),
            });
        }
        Ok(())
    }

    /// Rejects zero-length postings and start/length pairs whose end offset
    /// would overflow `u32` (start + length must fit in u32).
    fn validate_posting_coordinates(&self, start: u32, length: u32) -> Result<(), ShardexError> {
        if length == 0 {
            return Err(ShardexError::InvalidPostingData {
                reason: "Posting length cannot be zero".to_string(),
                suggestion: "Provide a positive length value".to_string(),
            });
        }
        // Equivalent to the manual `start > u32::MAX - length` guard, but
        // states the intent directly.
        if start.checked_add(length).is_none() {
            return Err(ShardexError::InvalidPostingData {
                reason: "Start + length coordinates would overflow".to_string(),
                suggestion: "Reduce start position or length".to_string(),
            });
        }
        Ok(())
    }

    /// Rejects empty vectors and vectors containing NaN/infinite components.
    fn validate_vector(&self, vector: &[f32]) -> Result<(), ShardexError> {
        if vector.is_empty() {
            return Err(ShardexError::InvalidPostingData {
                reason: "Vector cannot be empty".to_string(),
                suggestion: "Provide a non-empty vector".to_string(),
            });
        }
        for (i, &value) in vector.iter().enumerate() {
            if !value.is_finite() {
                return Err(ShardexError::InvalidPostingData {
                    reason: format!("Invalid vector value at index {}: {} (must be finite)", i, value),
                    suggestion: "Remove NaN or infinite values from vector".to_string(),
                });
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ShardexConfig;
    use crate::identifiers::ShardId;
    use crate::layout::DirectoryLayout;
    use crate::shardex_index::ShardexIndex;
    use crate::test_utils::TestEnvironment;
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a coordinator backed by a fresh 8 KiB WAL segment plus an index
    /// with text storage enabled and a single shard (vector size 3, 1024-byte
    /// text limit, 30 s transaction timeout). The returned `TestEnvironment`
    /// must be kept alive for the duration of the test (it owns the temp dir).
    async fn setup_coordinator() -> (DocumentTransactionCoordinator, ShardexIndex, TestEnvironment) {
        let test_env = TestEnvironment::new("doc_tx_coord_test");
        let layout = DirectoryLayout::new(test_env.path());
        let wal_segment_path = layout.wal_segment_path(1);
        let wal_segment = Arc::new(WalSegment::create(1, wal_segment_path, 8192).unwrap());
        let config = ShardexConfig::new()
            .directory_path(test_env.path())
            .vector_size(3)
            .max_document_text_size(1024);
        let mut index = ShardexIndex::create_new(layout.clone(), config.clone()).unwrap();
        index.enable_text_storage().unwrap();
        // The shards directory must exist before the shard files are created.
        let shard_id = ShardId::new();
        let shard_path = layout.shard_vectors_path(&shard_id);
        std::fs::create_dir_all(shard_path.parent().unwrap()).unwrap();
        let shard = crate::shard::Shard::create(
            shard_id,
            1000, // shard capacity
            config.vector_size,
            layout.shards_dir().to_path_buf(),
        )
        .unwrap();
        index.add_shard(shard).unwrap();
        println!("Shard count after adding: {}", index.shard_count());
        assert!(index.shard_count() > 0, "No shards were added to the index");
        let coordinator = DocumentTransactionCoordinator::new(wal_segment, Duration::from_secs(30));
        (coordinator, index, test_env)
    }

    // Happy path: begin -> add one operation -> commit clears all pending state.
    #[tokio::test]
    async fn test_transaction_lifecycle_basic() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let operation = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "Test document".to_string(),
        };
        coordinator.add_operation(tx_id, operation).await.unwrap();
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_operations_pending, 1);
        coordinator
            .commit_transaction(tx_id, &mut index)
            .await
            .unwrap();
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 0);
        assert_eq!(stats.total_operations_pending, 0);
    }

    // Aborting discards buffered operations and removes the transaction.
    #[tokio::test]
    async fn test_transaction_abort() {
        let (mut coordinator, _index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let operation = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "Test document".to_string(),
        };
        coordinator.add_operation(tx_id, operation).await.unwrap();
        coordinator.abort_transaction(tx_id).await.unwrap();
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 0);
        assert_eq!(stats.total_operations_pending, 0);
    }

    // A single transaction can buffer multiple heterogeneous operations and
    // commit them together.
    #[tokio::test]
    async fn test_multiple_operations_in_transaction() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let doc_id2 = DocumentId::new();
        let operations = vec![
            WalOperation::StoreDocumentText {
                document_id: doc_id,
                text: "Test document".to_string(),
            },
            WalOperation::StoreDocumentText {
                document_id: doc_id2,
                text: "Another document".to_string(),
            },
            WalOperation::DeleteDocumentText { document_id: doc_id },
        ];
        for operation in operations {
            coordinator.add_operation(tx_id, operation).await.unwrap();
        }
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_operations_pending, 3);
        coordinator
            .commit_transaction(tx_id, &mut index)
            .await
            .unwrap();
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 0);
    }

    // Structural validation happens at add_operation time: empty text, empty
    // vector, NaN components, and start+length overflow are all rejected.
    #[tokio::test]
    async fn test_operation_validation() {
        let (mut coordinator, _index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let invalid_text_op = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "".to_string(),
        };
        let result = coordinator.add_operation(tx_id, invalid_text_op).await;
        assert!(result.is_err());
        let invalid_vector_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 10,
            vector: vec![],
        };
        let result = coordinator.add_operation(tx_id, invalid_vector_op).await;
        assert!(result.is_err());
        let nan_vector_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 10,
            vector: vec![1.0, f32::NAN, 3.0],
        };
        let result = coordinator.add_operation(tx_id, nan_vector_op).await;
        assert!(result.is_err());
        let overflow_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: u32::MAX,
            length: 1,
            vector: vec![1.0, 2.0, 3.0],
        };
        let result = coordinator.add_operation(tx_id, overflow_op).await;
        assert!(result.is_err());
    }

    // Index-dependent validation (vector dimension) happens at commit time:
    // a 4-dim vector against a 3-dim index fails with InvalidDimension.
    #[tokio::test]
    async fn test_commit_validation() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let wrong_dim_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: 0,
            length: 10,
            // Index is configured for 3-dimensional vectors.
            vector: vec![1.0, 2.0, 3.0, 4.0],
        };
        coordinator
            .add_operation(tx_id, wrong_dim_op)
            .await
            .unwrap();
        let result = coordinator.commit_transaction(tx_id, &mut index).await;
        assert!(result.is_err());
        if let Err(ShardexError::InvalidDimension { expected, actual }) = result {
            assert_eq!(expected, 3);
            assert_eq!(actual, 4);
        } else {
            panic!("Expected InvalidDimension error");
        }
    }

    // Operations and commits against an id that was never begun both fail.
    #[tokio::test]
    async fn test_transaction_not_found_errors() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let fake_tx_id = TransactionId::new();
        let doc_id = DocumentId::new();
        let operation = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "Test".to_string(),
        };
        let result = coordinator.add_operation(fake_tx_id, operation).await;
        assert!(result.is_err());
        let result = coordinator.commit_transaction(fake_tx_id, &mut index).await;
        assert!(result.is_err());
    }

    // A committed transaction is removed from the active set, so further
    // add_operation calls with its id fail.
    #[tokio::test]
    async fn test_inactive_transaction_error() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let operation = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "Test".to_string(),
        };
        coordinator
            .add_operation(tx_id, operation.clone())
            .await
            .unwrap();
        coordinator
            .commit_transaction(tx_id, &mut index)
            .await
            .unwrap();
        let result = coordinator.add_operation(tx_id, operation).await;
        assert!(result.is_err());
    }

    // A transaction left open past the (10 ms) timeout is aborted by
    // cleanup_expired_transactions.
    #[tokio::test]
    async fn test_transaction_timeout_cleanup() {
        let test_env = TestEnvironment::new("timeout_test");
        let layout = DirectoryLayout::new(test_env.path());
        let wal_segment_path = layout.wal_segment_path(1);
        let wal_segment = Arc::new(WalSegment::create(1, wal_segment_path, 8192).unwrap());
        let mut coordinator = DocumentTransactionCoordinator::new(
            wal_segment,
            // Very short timeout so the transaction expires quickly.
            Duration::from_millis(10),
        );
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let operation = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "Test".to_string(),
        };
        coordinator.add_operation(tx_id, operation).await.unwrap();
        // Sleep past the timeout, then cleanup should abort the transaction.
        tokio::time::sleep(Duration::from_millis(20)).await;
        coordinator.cleanup_expired_transactions().await.unwrap();
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 0);
    }

    // Statistics aggregate across transactions: counts, pending operations,
    // and the presence of an oldest-transaction age.
    #[tokio::test]
    async fn test_transaction_statistics() {
        let (mut coordinator, _index, _test_env) = setup_coordinator().await;
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 0);
        assert_eq!(stats.total_operations_pending, 0);
        assert!(stats.oldest_transaction_age.is_none());
        let tx1 = coordinator.begin_transaction().await.unwrap();
        let tx2 = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let operation = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: "Test".to_string(),
        };
        coordinator
            .add_operation(tx1, operation.clone())
            .await
            .unwrap();
        coordinator
            .add_operation(tx2, operation.clone())
            .await
            .unwrap();
        coordinator.add_operation(tx2, operation).await.unwrap();
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 2);
        assert_eq!(stats.total_operations_pending, 3);
        assert!(stats.oldest_transaction_age.is_some());
    }

    // A document exactly at the configured size limit (1024 bytes) is
    // accepted by the `>` size check; the Err arm tolerates environments
    // with a smaller effective limit.
    #[tokio::test]
    async fn test_large_document_validation() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        // Matches max_document_text_size configured in setup_coordinator.
        let max_size = 1024;
        let large_text = "a".repeat(max_size);
        let large_doc_op = WalOperation::StoreDocumentText {
            document_id: doc_id,
            text: large_text,
        };
        coordinator
            .add_operation(tx_id, large_doc_op)
            .await
            .unwrap();
        let result = coordinator.commit_transaction(tx_id, &mut index).await;
        match result {
            Ok(()) => {
                // Exactly at the limit: accepted.
            }
            Err(ShardexError::DocumentTooLarge { size, max_size: limit }) => {
                assert_eq!(size, max_size);
                assert!(limit < max_size);
            }
            Err(e) => {
                panic!("Unexpected error for large document: {:?}", e);
            }
        }
    }

    // Three tasks each open a transaction, synchronize on a barrier, wait past
    // the 50 ms timeout, then race to run cleanup; afterwards no transactions
    // remain active. The coordinator is shared via a tokio Mutex since it is
    // not internally synchronized.
    #[tokio::test]
    async fn test_concurrent_transaction_cleanup() {
        use std::sync::Arc;
        use tokio::sync::Barrier;
        let test_env = TestEnvironment::new("concurrent_cleanup_test");
        let layout = DirectoryLayout::new(test_env.path());
        let wal_segment_path = layout.wal_segment_path(1);
        let wal_segment = Arc::new(WalSegment::create(1, wal_segment_path, 8192).unwrap());
        let coordinator = Arc::new(tokio::sync::Mutex::new(DocumentTransactionCoordinator::new(
            wal_segment,
            Duration::from_millis(50),
        )));
        let barrier = Arc::new(Barrier::new(3));
        let mut handles = vec![];
        for i in 0..3 {
            let coordinator = coordinator.clone();
            let barrier = barrier.clone();
            let handle = tokio::spawn(async move {
                // Lock scopes are kept narrow so tasks interleave.
                let tx_id = {
                    let mut coord = coordinator.lock().await;
                    coord.begin_transaction().await.unwrap()
                };
                let doc_id = DocumentId::new();
                let operation = WalOperation::StoreDocumentText {
                    document_id: doc_id,
                    text: format!("Test document {}", i),
                };
                {
                    let mut coord = coordinator.lock().await;
                    coord.add_operation(tx_id, operation).await.unwrap();
                }
                // Wait until all three transactions exist, then let them expire.
                barrier.wait().await;
                tokio::time::sleep(Duration::from_millis(100)).await;
                let stats = {
                    let mut coord = coordinator.lock().await;
                    coord.cleanup_expired_transactions().await.unwrap();
                    coord.get_transaction_statistics()
                };
                stats
            });
            handles.push(handle);
        }
        let mut results = vec![];
        for handle in handles {
            results.push(handle.await);
        }
        let final_stats = {
            let coord = coordinator.lock().await;
            coord.get_transaction_statistics()
        };
        assert_eq!(final_stats.active_transactions, 0);
        for result in results {
            assert!(result.is_ok());
        }
    }

    // start + length == u32::MAX is the largest accepted posting; exceeding
    // u32::MAX by one is rejected at add time.
    #[tokio::test]
    async fn test_transaction_boundary_conditions() {
        let (mut coordinator, mut index, _test_env) = setup_coordinator().await;
        let tx_id = coordinator.begin_transaction().await.unwrap();
        let doc_id = DocumentId::new();
        let max_safe_start = u32::MAX - 1000;
        let max_safe_length = 1000;
        let boundary_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: max_safe_start,
            length: max_safe_length,
            vector: vec![1.0, 2.0, 3.0],
        };
        coordinator.add_operation(tx_id, boundary_op).await.unwrap();
        coordinator
            .commit_transaction(tx_id, &mut index)
            .await
            .unwrap();
        let tx_id2 = coordinator.begin_transaction().await.unwrap();
        let overflow_op = WalOperation::AddPosting {
            document_id: doc_id,
            start: u32::MAX - 1,
            length: 2, // (u32::MAX - 1) + 2 overflows u32
            vector: vec![1.0, 2.0, 3.0],
        };
        let result = coordinator.add_operation(tx_id2, overflow_op).await;
        assert!(result.is_err());
    }

    // With a limit of 2, the third begin_transaction fails with a
    // transaction_count InvalidInput error; the two open transactions remain.
    #[tokio::test]
    async fn test_transaction_count_limit() {
        let test_env = TestEnvironment::new("transaction_limit_test");
        let layout = DirectoryLayout::new(test_env.path());
        let wal_segment_path = layout.wal_segment_path(1);
        let wal_segment = Arc::new(WalSegment::create(1, wal_segment_path, 8192).unwrap());
        let mut coordinator = DocumentTransactionCoordinator::with_max_transactions(
            wal_segment,
            Duration::from_secs(30),
            2, // deliberately tiny cap to trigger the limit
        );
        let tx1 = coordinator.begin_transaction().await;
        assert!(tx1.is_ok());
        let tx2 = coordinator.begin_transaction().await;
        assert!(tx2.is_ok());
        let tx3 = coordinator.begin_transaction().await;
        assert!(tx3.is_err());
        if let Err(ShardexError::InvalidInput { field, reason, .. }) = tx3 {
            assert_eq!(field, "transaction_count");
            assert!(reason.contains("Maximum active transaction limit reached"));
            assert!(reason.contains("2"));
        } else {
            panic!("Expected InvalidInput error for transaction limit");
        }
        let stats = coordinator.get_transaction_statistics();
        assert_eq!(stats.active_transactions, 2);
    }
}