//! Concurrent read/write coordination for a copy-on-write Shardex index.
//!
//! Reads run synchronously against an index snapshot and never block on
//! writers; writes are serialized through a coordinator that holds at most one
//! active writer plus a bounded FIFO queue of waiters, with configurable
//! timeouts and coordination statistics.
use crate::cow_index::{CowShardexIndex, IndexWriter};
use crate::error::ShardexError;
use crate::shardex_index::ShardexIndex;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tokio::time::timeout;
use tracing::{debug, info, warn};
use uuid::Uuid;
/// Default end-to-end bound on a single write operation.
const DEFAULT_WRITE_TIMEOUT: Duration = Duration::from_secs(30);
/// Default bound on waiting for the write-coordination slot.
const COORDINATION_LOCK_TIMEOUT: Duration = Duration::from_secs(5);
/// Default cap on the number of writers queued behind the active one.
const MAX_PENDING_WRITES: usize = 100;
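/// Coordinates concurrent access to a [`CowShardexIndex`]: reads run against a
/// snapshot and never wait on the write coordinator, while writes are
/// serialized through a single-writer coordinator with a bounded FIFO queue.
///
/// Usage sketch (marked `ignore` since it assumes an index has already been
/// created on disk; the tests at the bottom of this module are the runnable
/// equivalent):
///
/// ```ignore
/// let concurrent = ConcurrentShardex::new(CowShardexIndex::new(index));
/// // Reads are synchronous and never wait on the write coordinator.
/// let shards = concurrent.read_operation(|idx| Ok(idx.shard_count()))?;
/// // Writes are async and serialized through the coordinator.
/// let shards = concurrent
///     .write_operation(|writer| Ok(writer.index().shard_count()))
///     .await?;
/// ```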
pub struct ConcurrentShardex {
    /// Copy-on-write index shared by readers and writers.
    index: Arc<CowShardexIndex>,
    /// Serializes writers; holds the active writer and the pending queue.
    write_coordinator: Arc<Mutex<WriteCoordinator>>,
    /// Number of read operations currently in flight (metrics and logs only).
    active_readers: Arc<AtomicUsize>,
    /// Monotonic counter bumped once per write attempt; exposed in metrics.
    epoch: Arc<AtomicU64>,
    /// Timeouts, queue limits, and logging knobs.
    config: ConcurrencyConfig,
}
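/// Tunable timeouts and limits for read/write coordination.
///
/// Sketch of overriding one field while keeping the remaining defaults
/// (struct-update syntax, as used in the tests below):
///
/// ```ignore
/// let config = ConcurrencyConfig {
///     write_timeout: Duration::from_secs(60),
///     ..Default::default()
/// };
/// let concurrent = ConcurrentShardex::with_config(cow_index, config);
/// ```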
#[derive(Debug, Clone)]
pub struct ConcurrencyConfig {
    /// End-to-end limit for a single write operation, including queue wait.
    pub write_timeout: Duration,
    /// Limit on waiting for the write-coordination slot.
    pub coordination_lock_timeout: Duration,
    /// Maximum number of writers allowed to queue behind the active one.
    pub max_pending_writes: usize,
    /// Emit per-operation `debug!` logs when true.
    pub enable_detailed_logging: bool,
}
impl Default for ConcurrencyConfig {
fn default() -> Self {
Self {
write_timeout: DEFAULT_WRITE_TIMEOUT,
coordination_lock_timeout: COORDINATION_LOCK_TIMEOUT,
max_pending_writes: MAX_PENDING_WRITES,
enable_detailed_logging: false,
}
}
}
/// Shared writer state: at most one active writer plus a FIFO queue of
/// waiting writers.
#[derive(Debug)]
struct WriteCoordinator {
    active_writer: Option<WriterHandle>,
    pending_writes: VecDeque<PendingWrite>,
    stats: CoordinationStats,
}
/// Identifies the operation currently holding the write slot.
#[derive(Debug)]
struct WriterHandle {
    writer_id: Uuid,
}
/// A queued writer waiting to be handed the write slot.
#[derive(Debug)]
struct PendingWrite {
    operation_id: Uuid,
    /// Completed by the departing writer to hand over the slot.
    notify: tokio::sync::oneshot::Sender<()>,
}
/// Kinds of write operations that can be coordinated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WriteOperationType {
AddPostings,
RemoveDocuments,
Flush,
Maintenance,
}
/// Aggregate statistics about write coordination, served by
/// [`ConcurrentShardex::coordination_stats`].
#[derive(Debug, Clone, Default)]
pub struct CoordinationStats {
    /// Writes that acquired the coordination slot.
    pub total_writes: u64,
    /// Writes that had to queue behind an active writer.
    pub contended_writes: u64,
    /// Sum of all coordination wait times.
    pub total_coordination_wait_time: Duration,
    /// Longest single coordination wait observed.
    pub max_coordination_wait_time: Duration,
    /// Write operations that exceeded `write_timeout`.
    pub timeout_count: u64,
}
impl ConcurrentShardex {
    /// Creates a `ConcurrentShardex` with the default [`ConcurrencyConfig`].
    pub fn new(index: CowShardexIndex) -> Self {
        Self::with_config(index, ConcurrencyConfig::default())
    }
    /// Creates a `ConcurrentShardex` with an explicit configuration.
    pub fn with_config(index: CowShardexIndex, config: ConcurrencyConfig) -> Self {
if config.enable_detailed_logging {
info!("Initializing ConcurrentShardex with detailed logging enabled");
}
Self {
index: Arc::new(index),
write_coordinator: Arc::new(Mutex::new(WriteCoordinator {
active_writer: None,
pending_writes: VecDeque::new(),
stats: CoordinationStats::default(),
})),
active_readers: Arc::new(AtomicUsize::new(0)),
epoch: Arc::new(AtomicU64::new(1)),
config,
}
}
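    /// Runs `operation` against a read snapshot of the index. Readers never
    /// wait on writers; they only bump the `active_readers` counter for
    /// metrics.
    ///
    /// Sketch (`ignore`: assumes a constructed `ConcurrentShardex`):
    ///
    /// ```ignore
    /// let shard_count = concurrent.read_operation(|idx| Ok(idx.shard_count()))?;
    /// ```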
pub fn read_operation<F, R>(&self, operation: F) -> Result<R, ShardexError>
where
F: FnOnce(&ShardexIndex) -> Result<R, ShardexError> + Send,
R: Send,
{
let previous_readers = self.active_readers.fetch_add(1, Ordering::SeqCst);
let current_epoch = self.epoch.load(Ordering::Acquire);
if self.config.enable_detailed_logging {
debug!(
"Starting read operation: active_readers={}, epoch={}",
previous_readers + 1,
current_epoch
);
}
        let index_snapshot = self.index.read();
        let result = operation(&index_snapshot);
        // If `operation` panics, this decrement is skipped and the counter
        // leaks; `active_readers` is only used for metrics and logging.
        let remaining_readers = self.active_readers.fetch_sub(1, Ordering::SeqCst) - 1;
if self.config.enable_detailed_logging {
debug!(
"Completed read operation: remaining_readers={}, result={:?}",
remaining_readers,
result.is_ok()
);
}
result
}
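    /// Runs `operation` with exclusive write access: changes are committed on
    /// `Ok` and discarded on `Err`, and the whole operation (including queue
    /// wait) is bounded by `config.write_timeout`.
    ///
    /// Sketch (`ignore`: assumes a constructed `ConcurrentShardex`):
    ///
    /// ```ignore
    /// let shard_count = concurrent
    ///     .write_operation(|writer| Ok(writer.index().shard_count()))
    ///     .await?;
    /// ```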
pub async fn write_operation<F, R>(&self, operation: F) -> Result<R, ShardexError>
where
F: FnOnce(&mut IndexWriter) -> Result<R, ShardexError> + Send,
R: Send,
{
let operation_id = Uuid::new_v4();
let start_time = Instant::now();
if self.config.enable_detailed_logging {
debug!("Starting write operation: id={}", operation_id);
}
let write_result = timeout(
self.config.write_timeout,
self.perform_coordinated_write(operation_id, operation),
)
.await;
let total_duration = start_time.elapsed();
match write_result {
Ok(result) => {
if self.config.enable_detailed_logging {
debug!(
"Write operation completed: id={}, duration={:?}, success={}",
operation_id,
total_duration,
result.is_ok()
);
}
result
}
            Err(_) => {
                // On timeout the coordinated-write future is dropped; if it
                // held the coordination guard, the guard's `Drop` releases the
                // write slot and wakes the next waiter.
                let mut coordinator = self.write_coordinator.lock().await;
                coordinator.stats.timeout_count += 1;
warn!(
"Write operation timed out: id={}, duration={:?}",
operation_id, total_duration
);
Err(ShardexError::Config(format!(
"Write operation timed out after {:?}",
self.config.write_timeout
)))
}
}
}
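    /// Acquires the write slot (bounded by `coordination_lock_timeout`), runs
    /// `operation` on a writer cloned from the current index, then commits on
    /// success or discards on failure. The coordination guard stays alive
    /// until after the commit/discard, so no second writer can start early.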
async fn perform_coordinated_write<F, R>(&self, operation_id: Uuid, operation: F) -> Result<R, ShardexError>
where
F: FnOnce(&mut IndexWriter) -> Result<R, ShardexError> + Send,
R: Send,
{
let coordination_start = Instant::now();
let coordination_result = timeout(
self.config.coordination_lock_timeout,
self.acquire_write_coordination(operation_id),
)
.await;
let coordinator_acquired = match coordination_result {
Ok(result) => result?,
Err(_) => {
return Err(ShardexError::Config(format!(
"Failed to acquire write coordination within {:?}",
self.config.coordination_lock_timeout
)));
}
};
let coordination_duration = coordination_start.elapsed();
if self.config.enable_detailed_logging {
debug!(
"Acquired write coordination: id={}, wait_time={:?}",
operation_id, coordination_duration
);
}
        let write_start = Instant::now();
        // The epoch advances once per write *attempt* (even if the operation
        // later fails), so it counts attempts rather than commits.
        let current_epoch = self.epoch.fetch_add(1, Ordering::SeqCst) + 1;
        let mut writer = self.index.clone_for_write()?;
let operation_result = operation(&mut writer);
match operation_result {
Ok(result) => {
writer.commit_changes()?;
self.update_coordination_stats(coordination_duration, false)
.await;
drop(coordinator_acquired);
if self.config.enable_detailed_logging {
debug!(
"Write operation committed: id={}, epoch={}, write_time={:?}",
operation_id,
current_epoch,
write_start.elapsed()
);
}
Ok(result)
}
Err(error) => {
writer.discard();
self.update_coordination_stats(coordination_duration, false)
.await;
drop(coordinator_acquired);
if self.config.enable_detailed_logging {
debug!("Write operation failed: id={}, error={}", operation_id, error);
}
Err(error)
}
}
}
    /// Acquires the exclusive write slot, queueing behind the active writer if
    /// necessary. On handoff the departing writer installs the waiter as the
    /// active writer before notifying it (see `WriteCoordinationGuard`), so a
    /// newly arriving writer cannot steal the slot in between.
    async fn acquire_write_coordination(&self, operation_id: Uuid) -> Result<WriteCoordinationGuard, ShardexError> {
        let notify_receiver = {
            let mut coordinator = self.write_coordinator.lock().await;
            if coordinator.pending_writes.len() >= self.config.max_pending_writes {
                return Err(ShardexError::Config(format!(
                    "Too many pending write operations: {} >= {}",
                    coordinator.pending_writes.len(),
                    self.config.max_pending_writes
                )));
            }
            if coordinator.active_writer.is_none() {
                // Fast path: the slot is free, take it immediately.
                coordinator.active_writer = Some(WriterHandle {
                    writer_id: operation_id,
                });
                coordinator.stats.total_writes += 1;
                return Ok(WriteCoordinationGuard {
                    operation_id,
                    coordinator: Arc::clone(&self.write_coordinator),
                });
            }
            // Slow path: queue behind the active writer and wait for handoff.
            let (notify_sender, notify_receiver) = tokio::sync::oneshot::channel();
            coordinator.pending_writes.push_back(PendingWrite {
                operation_id,
                notify: notify_sender,
            });
            coordinator.stats.contended_writes += 1;
            notify_receiver
        };
        match notify_receiver.await {
            Ok(()) => {
                // The departing writer installed this operation as the active
                // writer before sending the notification, so construct the
                // guard immediately: if this future is cancelled from here on,
                // the guard's `Drop` still releases the slot.
                let guard = WriteCoordinationGuard {
                    operation_id,
                    coordinator: Arc::clone(&self.write_coordinator),
                };
                self.write_coordinator.lock().await.stats.total_writes += 1;
                Ok(guard)
            }
            Err(_) => Err(ShardexError::Config(
                "Write coordination channel closed while waiting".to_string(),
            )),
        }
    }
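    /// Folds one operation's coordination wait into the running statistics.
    /// Queue contention is already counted at enqueue time, so current callers
    /// pass `contended = false` to avoid double counting.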
async fn update_coordination_stats(&self, wait_duration: Duration, contended: bool) {
let mut coordinator = self.write_coordinator.lock().await;
coordinator.stats.total_coordination_wait_time += wait_duration;
if wait_duration > coordinator.stats.max_coordination_wait_time {
coordinator.stats.max_coordination_wait_time = wait_duration;
}
if contended {
coordinator.stats.contended_writes += 1;
}
}
    /// Returns a snapshot of the coordination statistics.
    pub async fn coordination_stats(&self) -> CoordinationStats {
        self.write_coordinator.lock().await.stats.clone()
    }
    /// Returns a point-in-time snapshot of reader and writer activity.
    pub async fn concurrency_metrics(&self) -> ConcurrencyMetrics {
        let active_readers = self.active_readers.load(Ordering::Acquire);
        let current_epoch = self.epoch.load(Ordering::Acquire);
        let coordinator = self.write_coordinator.lock().await;
        ConcurrencyMetrics {
            active_readers,
            active_writers: usize::from(coordinator.active_writer.is_some()),
            pending_writes: coordinator.pending_writes.len(),
            current_epoch,
        }
    }
}
/// RAII guard for the exclusive write slot; dropping it releases the slot and
/// hands it to the next live waiter in FIFO order.
struct WriteCoordinationGuard {
    operation_id: Uuid,
    coordinator: Arc<Mutex<WriteCoordinator>>,
}
impl WriteCoordinationGuard {
    /// Clears the slot (if still owned by `operation_id`) and hands it to the
    /// next waiter whose receiver is still alive, all under the coordinator
    /// lock so the handoff is atomic; timed-out waiters are skipped.
    fn release(coordinator: &mut WriteCoordinator, operation_id: Uuid) {
        if coordinator
            .active_writer
            .as_ref()
            .is_some_and(|w| w.writer_id == operation_id)
        {
            coordinator.active_writer = None;
        }
        // Only hand off when the slot is actually free; a stale guard must not
        // displace a different active writer.
        if coordinator.active_writer.is_none() {
            while let Some(pending) = coordinator.pending_writes.pop_front() {
                let writer_id = pending.operation_id;
                if pending.notify.send(()).is_ok() {
                    coordinator.active_writer = Some(WriterHandle { writer_id });
                    break;
                }
            }
        }
    }
}
impl Drop for WriteCoordinationGuard {
    fn drop(&mut self) {
        // `Drop` cannot await, so try the lock first; if it is contended,
        // defer the release to a task rather than silently leaking the slot.
        if let Ok(mut coordinator) = self.coordinator.try_lock() {
            Self::release(&mut coordinator, self.operation_id);
        } else if let Ok(handle) = tokio::runtime::Handle::try_current() {
            let coordinator = Arc::clone(&self.coordinator);
            let operation_id = self.operation_id;
            handle.spawn(async move {
                Self::release(&mut *coordinator.lock().await, operation_id);
            });
        }
    }
}
/// Point-in-time snapshot of concurrent activity, served by
/// [`ConcurrentShardex::concurrency_metrics`].
#[derive(Debug, Clone)]
pub struct ConcurrencyMetrics {
    /// Read operations currently in flight.
    pub active_readers: usize,
    /// 1 if a writer holds the coordination slot, otherwise 0.
    pub active_writers: usize,
    /// Writers queued behind the active writer.
    pub pending_writes: usize,
    /// Monotonic write-epoch counter (starts at 1).
    pub current_epoch: u64,
}
impl CoordinationStats {
    /// Mean coordination wait across all coordinated writes, or zero if none.
    pub fn average_coordination_wait_time(&self) -> Duration {
        if self.total_writes > 0 {
            // `Duration` division only accepts a `u32` divisor; saturate
            // rather than letting an `as` cast silently truncate.
            let divisor = u32::try_from(self.total_writes).unwrap_or(u32::MAX);
            self.total_coordination_wait_time / divisor
        } else {
            Duration::ZERO
        }
    }
    /// Contended writes as a percentage of all coordinated writes.
    pub fn contention_rate(&self) -> f64 {
        if self.total_writes > 0 {
            (self.contended_writes as f64 / self.total_writes as f64) * 100.0
        } else {
            0.0
        }
    }
    /// Timed-out writes as a percentage of all coordinated writes.
    pub fn timeout_rate(&self) -> f64 {
        if self.total_writes > 0 {
            (self.timeout_count as f64 / self.total_writes as f64) * 100.0
        } else {
            0.0
        }
    }
}
// SAFETY: every field is behind an `Arc`, and all mutation goes through the
// coordinator `Mutex` or atomics. These impls are only needed (and only sound)
// if `CowShardexIndex` is itself safe to share across threads; if it already
// implements `Send + Sync`, the compiler derives both automatically and the
// `unsafe` impls should be removed.
unsafe impl Send for ConcurrentShardex {}
unsafe impl Sync for ConcurrentShardex {}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::ShardexConfig;
use crate::test_utils::TestEnvironment;
use std::sync::Arc;
use tokio::task::JoinSet;
#[tokio::test]
async fn test_concurrent_shardex_creation() {
let _test_env = TestEnvironment::new("test_concurrent_shardex_creation");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let concurrent = ConcurrentShardex::new(cow_index);
assert_eq!(concurrent.active_readers.load(Ordering::Acquire), 0);
assert_eq!(concurrent.epoch.load(Ordering::Acquire), 1);
}
#[tokio::test]
async fn test_non_blocking_read_operations() {
let _test_env = TestEnvironment::new("test_non_blocking_read_operations");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let concurrent = Arc::new(ConcurrentShardex::new(cow_index));
let mut tasks = JoinSet::new();
for _i in 0..10 {
let concurrent_clone = Arc::clone(&concurrent);
tasks.spawn(async move {
concurrent_clone.read_operation(|index| {
let shard_count = index.shard_count();
Ok(shard_count)
})
});
}
let mut results = Vec::new();
while let Some(result) = tasks.join_next().await {
let result = result
.expect("Task should not panic")
.expect("Read operation should succeed");
results.push(result);
}
assert_eq!(results.len(), 10);
assert!(results.iter().all(|&count| count == results[0]));
}
#[tokio::test]
async fn test_write_operation_coordination() {
let _test_env = TestEnvironment::new("test_write_operation_coordination");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let concurrent = ConcurrentShardex::new(cow_index);
let result = concurrent
.write_operation(|writer| {
let shard_count = writer.index().shard_count();
Ok(shard_count)
})
.await;
assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0);
    }
    // The blocking sleeps below only overlap when the runtime has more than
    // one worker thread, so this test needs the multi-thread flavor.
    #[tokio::test(flavor = "multi_thread")]
async fn test_concurrent_readers_during_write() {
let _test_env = TestEnvironment::new("test_concurrent_readers_during_write");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let concurrent = Arc::new(ConcurrentShardex::new(cow_index));
let mut tasks = JoinSet::new();
for i in 0..5 {
let concurrent_clone = Arc::clone(&concurrent);
tasks.spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(i * 10)).await;
concurrent_clone.read_operation(|index| {
std::thread::sleep(std::time::Duration::from_millis(50));
Ok(index.shard_count())
})
});
}
let concurrent_clone = Arc::clone(&concurrent);
tasks.spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(25)).await;
concurrent_clone
.write_operation(|writer| {
std::thread::sleep(std::time::Duration::from_millis(50));
Ok(writer.index().shard_count())
})
.await
});
let mut success_count = 0;
while let Some(result) = tasks.join_next().await {
if result.expect("Task should not panic").is_ok() {
success_count += 1;
}
}
        assert_eq!(success_count, 6);
    }
#[tokio::test]
async fn test_coordination_statistics() {
let _test_env = TestEnvironment::new("test_coordination_statistics");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let concurrent = ConcurrentShardex::new(cow_index);
let initial_stats = concurrent.coordination_stats().await;
assert_eq!(initial_stats.total_writes, 0);
assert_eq!(initial_stats.contended_writes, 0);
assert_eq!(initial_stats.timeout_count, 0);
let _result = concurrent
.write_operation(|writer| Ok(writer.index().shard_count()))
.await
.expect("Write operation should succeed");
let updated_stats = concurrent.coordination_stats().await;
assert_eq!(updated_stats.total_writes, 1);
}
#[tokio::test]
async fn test_concurrency_metrics() {
let _test_env = TestEnvironment::new("test_concurrency_metrics");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let concurrent = Arc::new(ConcurrentShardex::new(cow_index));
let initial_metrics = concurrent.concurrency_metrics().await;
assert_eq!(initial_metrics.active_readers, 0);
assert_eq!(initial_metrics.active_writers, 0);
assert_eq!(initial_metrics.current_epoch, 1);
        let result = concurrent.read_operation(|index| Ok(index.shard_count()));
assert!(result.is_ok());
let final_metrics = concurrent.concurrency_metrics().await;
assert_eq!(final_metrics.active_readers, 0);
}
#[tokio::test]
async fn test_write_operation_timeout() {
let _test_env = TestEnvironment::new("test_write_operation_timeout");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
        let timeout_config = ConcurrencyConfig {
            write_timeout: Duration::from_millis(30),
            ..Default::default()
        };
let concurrent = ConcurrentShardex::with_config(cow_index, timeout_config);
        // A fast write should complete well within the 30 ms timeout.
        let start_time = std::time::Instant::now();
        let result = concurrent.write_operation(|_writer| Ok(1)).await;
let duration = start_time.elapsed();
assert!(
result.is_ok(),
"Fast operation should not timeout, result: {:?}",
result
);
assert!(
duration < Duration::from_millis(100),
"Fast operation took too long: {:?}",
duration
);
let stats = concurrent.coordination_stats().await;
assert_eq!(stats.total_writes, 1);
assert_eq!(stats.timeout_count, 0);
}
#[tokio::test]
async fn test_configuration_options() {
let _test_env = TestEnvironment::new("test_configuration_options");
let config = ShardexConfig::new()
.directory_path(_test_env.path())
.vector_size(128);
let index = crate::shardex_index::ShardexIndex::create(config).expect("Failed to create index");
let cow_index = CowShardexIndex::new(index);
let custom_config = ConcurrencyConfig {
write_timeout: Duration::from_secs(60),
coordination_lock_timeout: Duration::from_secs(10),
max_pending_writes: 50,
enable_detailed_logging: true,
};
let concurrent = ConcurrentShardex::with_config(cow_index, custom_config.clone());
assert_eq!(concurrent.config.write_timeout, custom_config.write_timeout);
assert_eq!(
concurrent.config.coordination_lock_timeout,
custom_config.coordination_lock_timeout
);
assert_eq!(concurrent.config.max_pending_writes, custom_config.max_pending_writes);
assert_eq!(
concurrent.config.enable_detailed_logging,
custom_config.enable_detailed_logging
);
}
}