use shardex::concurrent::{ConcurrencyConfig, ConcurrentShardex};
use shardex::cow_index::CowShardexIndex;
use shardex::shardex_index::ShardexIndex;
use shardex::ShardexConfig;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::task::JoinSet;
use tokio::time::timeout;
mod common;
use common::{create_test_concurrent_shardex, TestEnvironment};
/// Build a `ConcurrentShardex` backed by a freshly created index inside
/// `test_env`, using the supplied concurrency configuration instead of the
/// defaults that `create_test_concurrent_shardex` applies.
fn create_test_concurrent_shardex_with_config(
    test_env: &TestEnvironment,
    concurrency_config: ConcurrencyConfig,
) -> ConcurrentShardex {
    let shardex_config = ShardexConfig::new()
        .directory_path(test_env.path())
        .vector_size(64)
        .shard_size(100);
    let base_index = ShardexIndex::create(shardex_config).expect("Failed to create index");
    ConcurrentShardex::with_config(CowShardexIndex::new(base_index), concurrency_config)
}
/// High-contention read test: 50 concurrent reader tasks each perform 20 reads
/// and must all succeed, with per-reader and total wall-clock bounds that would
/// be violated if readers serialized behind one another.
#[tokio::test]
async fn test_high_contention_reader_performance() {
    let _test_env = TestEnvironment::new("test_high_contention_reader_performance");
    let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
    const NUM_READERS: usize = 50;
    const READS_PER_READER: usize = 20;
    let start_time = Instant::now();
    let mut tasks = JoinSet::new();
    for reader_id in 0..NUM_READERS {
        let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
        tasks.spawn(async move {
            let mut successful_reads = 0;
            let reader_start = Instant::now();
            for _ in 0..READS_PER_READER {
                let result = concurrent_clone.read_operation(|index| {
                    let shard_count = index.shard_count();
                    // Simulated read work (10us). NOTE(review): std::thread::sleep
                    // blocks the runtime worker thread (read_operation takes a sync
                    // closure), so real reader overlap depends on the tokio runtime
                    // having multiple worker threads — confirm the test runtime flavor.
                    std::thread::sleep(std::time::Duration::from_micros(10));
                    Ok(shard_count)
                });
                if result.is_ok() {
                    successful_reads += 1;
                }
            }
            let reader_duration = reader_start.elapsed();
            (reader_id, successful_reads, reader_duration)
        });
    }
    // Drain the JoinSet; results arrive in completion order, which is fine
    // because each tuple carries its own reader_id.
    let mut results = Vec::new();
    while let Some(result) = tasks.join_next().await {
        let (reader_id, successful_reads, duration) = result.expect("Task should not panic");
        results.push((reader_id, successful_reads, duration));
    }
    let total_duration = start_time.elapsed();
    assert_eq!(results.len(), NUM_READERS);
    // Every reader must have completed all of its reads successfully.
    for (reader_id, successful_reads, _duration) in &results {
        assert_eq!(
            *successful_reads, READS_PER_READER,
            "Reader {} failed some read operations",
            reader_id
        );
    }
    let max_reader_duration = results
        .iter()
        .map(|(_, _, duration)| *duration)
        .max()
        .unwrap();
    // Generous budget: 1ms per read (actual simulated work is 10us per read)
    // times a 5x overhead factor for scheduling/coordination noise.
    let base_expected_time = std::time::Duration::from_millis(READS_PER_READER as u64);
    let reasonable_overhead_factor = 5;
    let max_acceptable_duration = base_expected_time * reasonable_overhead_factor;
    assert!(
        max_reader_duration < max_acceptable_duration,
        "Reader took too long, suggesting blocking: max_duration = {:?}, expected < {:?}",
        max_reader_duration,
        max_acceptable_duration
    );
    // With genuine concurrency the total wall-clock time should be close to a
    // single reader's time (readers overlap), not NUM_READERS times it.
    let total_expected_time = base_expected_time;
    let max_acceptable_total = total_expected_time * reasonable_overhead_factor;
    assert!(
        total_duration < max_acceptable_total,
        "Total test took too long, suggesting serialization: total = {:?}, expected < {:?}",
        total_duration,
        max_acceptable_total
    );
    println!(
        "High contention test: {} readers, {} reads each, total time: {:?}, max reader time: {:?}",
        NUM_READERS, READS_PER_READER, total_duration, max_reader_duration
    );
}
/// Deadlock-detection test: run 20 reader tasks and 5 writer tasks
/// concurrently (10 operations each) and require the whole workload to drain
/// within 30 seconds; a hang indicates a coordination deadlock.
#[tokio::test]
async fn test_mixed_read_write_operations_no_deadlock() {
    let _test_env = TestEnvironment::new("test_mixed_read_write_operations_no_deadlock");
    let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
    const NUM_READERS: usize = 20;
    const NUM_WRITERS: usize = 5;
    const OPERATIONS_PER_TASK: usize = 10;
    // Shared success counter so the timeout branch can report partial progress.
    let successful_operations = Arc::new(AtomicUsize::new(0));
    let mut tasks = JoinSet::new();
    for reader_id in 0..NUM_READERS {
        let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
        let success_counter = Arc::clone(&successful_operations);
        tasks.spawn(async move {
            for op_id in 0..OPERATIONS_PER_TASK {
                let result = concurrent_clone.read_operation(|index| {
                    // Variable per-operation work (1-3ms) to stagger readers.
                    // NOTE(review): std::thread::sleep blocks the runtime worker
                    // thread; read_operation takes a sync closure, so this is a
                    // deliberate (if blunt) work simulation.
                    let work_duration = (reader_id * op_id % 3) + 1;
                    std::thread::sleep(std::time::Duration::from_millis(work_duration as u64));
                    Ok(index.shard_count())
                });
                if result.is_ok() {
                    success_counter.fetch_add(1, Ordering::SeqCst);
                }
            }
        });
    }
    for writer_id in 0..NUM_WRITERS {
        let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
        let success_counter = Arc::clone(&successful_operations);
        tasks.spawn(async move {
            for op_id in 0..OPERATIONS_PER_TASK {
                let result = concurrent_clone
                    .write_operation(|writer| {
                        // Variable write work (1-5ms); same blocking caveat as above.
                        let work_duration = (writer_id * op_id % 5) + 1;
                        std::thread::sleep(std::time::Duration::from_millis(work_duration as u64));
                        Ok(writer.index().shard_count())
                    })
                    .await;
                if result.is_ok() {
                    success_counter.fetch_add(1, Ordering::SeqCst);
                }
            }
        });
    }
    // A 30s ceiling on draining the JoinSet doubles as deadlock detection.
    let test_timeout = Duration::from_secs(30);
    let test_result = timeout(test_timeout, async {
        let mut completed_tasks = 0;
        let total_tasks = NUM_READERS + NUM_WRITERS;
        while let Some(result) = tasks.join_next().await {
            result.expect("Task should not panic");
            completed_tasks += 1;
            // Progress log every 10 completed tasks.
            if completed_tasks % 10 == 0 {
                println!("Completed {}/{} tasks", completed_tasks, total_tasks);
            }
        }
        println!("All {} tasks completed successfully", completed_tasks);
    })
    .await;
    if test_result.is_err() {
        let partial_operations = successful_operations.load(Ordering::SeqCst);
        panic!(
            "Test timed out after 30s - possible deadlock detected. \
            Completed operations: {}/{}, \
            Consider debugging concurrency coordination logic",
            partial_operations,
            (NUM_READERS + NUM_WRITERS) * OPERATIONS_PER_TASK
        );
    }
    let total_successful = successful_operations.load(Ordering::SeqCst);
    let expected_operations = (NUM_READERS + NUM_WRITERS) * OPERATIONS_PER_TASK;
    // Allow up to 20% of operations to fail under contention.
    assert!(
        total_successful >= expected_operations * 8 / 10,
        "Too many operations failed: {}/{}",
        total_successful,
        expected_operations
    );
    println!(
        "Mixed operations test: {}/{} operations succeeded",
        total_successful, expected_operations
    );
}
/// A write operation under tight timeout settings must either complete fast
/// or fail with an error attributable to timeout/contention/coordination.
#[tokio::test]
async fn test_write_operation_timeout_behavior() {
    let _test_env = TestEnvironment::new("test_write_operation_timeout_behavior");
    let config = ConcurrencyConfig {
        write_timeout: Duration::from_millis(100),
        coordination_lock_timeout: Duration::from_millis(50),
        ..Default::default()
    };
    let concurrent = create_test_concurrent_shardex_with_config(&_test_env, config);
    let started = Instant::now();
    let outcome = concurrent.write_operation(|_writer| Ok(0)).await;
    let elapsed = started.elapsed();
    match outcome {
        Ok(_) => {
            // The no-op write went through; it must have finished well inside
            // the configured write timeout.
            assert!(
                elapsed < Duration::from_millis(100),
                "Operation should complete quickly: {:?}",
                elapsed
            );
        }
        Err(err) => {
            // Any failure must come from the timeout/contention machinery,
            // not from some unrelated error path.
            let error_msg = err.to_string();
            assert!(
                error_msg.contains("timed out") || error_msg.contains("contention") || error_msg.contains("coordination"),
                "Unexpected error: {}",
                error_msg
            );
        }
    }
}
/// Coordination statistics should start zeroed and count exactly the writes
/// we issue, with no timeouts and only modest contention.
#[tokio::test]
async fn test_coordination_statistics_accuracy() {
    let _test_env = TestEnvironment::new("test_coordination_statistics_accuracy");
    let concurrent = create_test_concurrent_shardex(&_test_env);
    // A fresh instance reports no activity at all.
    let initial_stats = concurrent.coordination_stats().await;
    assert_eq!(initial_stats.total_writes, 0);
    assert_eq!(initial_stats.contended_writes, 0);
    assert_eq!(initial_stats.timeout_count, 0);
    const NUM_WRITES: u64 = 5;
    for _ in 0..NUM_WRITES {
        let outcome = concurrent
            .write_operation(|writer| Ok(writer.index().shard_count()))
            .await;
        assert!(outcome.is_ok(), "Write operation should succeed");
    }
    let updated_stats = concurrent.coordination_stats().await;
    // Every write we issued — and nothing else — must be accounted for,
    // and none of them should have timed out.
    assert_eq!(updated_stats.total_writes, NUM_WRITES);
    assert_eq!(updated_stats.timeout_count, 0);
    // A single sequential writer should produce little to no contention.
    assert!(
        updated_stats.contention_rate() <= 20.0,
        "Contention rate too high: {:.2}%",
        updated_stats.contention_rate()
    );
    println!(
        "Coordination statistics: total_writes={}, contention_rate={:.2}%, timeout_rate={:.2}%",
        updated_stats.total_writes,
        updated_stats.contention_rate(),
        updated_stats.timeout_rate()
    );
}
/// Concurrency metrics should report an idle system before and after a read,
/// and the epoch counter must never move backwards.
#[tokio::test]
async fn test_concurrency_metrics_real_time_tracking() {
    let _test_env = TestEnvironment::new("test_concurrency_metrics_real_time_tracking");
    let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
    // Nothing is running yet: no readers, writers, or queued writes.
    let before = concurrent.concurrency_metrics().await;
    assert_eq!(before.active_readers, 0);
    assert_eq!(before.active_writers, 0);
    assert_eq!(before.pending_writes, 0);
    let read = concurrent.read_operation(|index| Ok(index.shard_count()));
    assert!(read.is_ok(), "Read operation should succeed");
    // The completed read must have deregistered itself.
    let after = concurrent.concurrency_metrics().await;
    assert_eq!(
        after.active_readers, 0,
        "Should have 0 active readers after read completes"
    );
    // Epochs are monotonic.
    assert!(
        after.current_epoch >= before.current_epoch,
        "Epoch should not go backwards: final={}, initial={}",
        after.current_epoch,
        before.current_epoch
    );
}
/// Snapshot-isolation test: a long-running reader and an overlapping writer
/// must both succeed, and the reader must capture a consistent shard count.
#[tokio::test]
async fn test_reader_writer_isolation() {
    let _test_env = TestEnvironment::new("test_reader_writer_isolation");
    let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
    let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
    // The reader publishes the shard count it observed through this mutex so
    // the test body can inspect it after both tasks finish.
    let reader_snapshot = Arc::new(std::sync::Mutex::new(None));
    let reader_snapshot_clone = Arc::clone(&reader_snapshot);
    let reader_task = tokio::spawn(async move {
        concurrent_clone.read_operation(|index| {
            let shard_count = index.shard_count();
            if let Ok(mut snapshot) = reader_snapshot_clone.lock() {
                *snapshot = Some(shard_count);
            }
            // Hold the read open for 100ms so the writer spawned below
            // overlaps it in time.
            std::thread::sleep(Duration::from_millis(100));
            Ok(shard_count)
        })
    });
    // Give the reader a 25ms head start so it is mid-read when the write
    // begins. NOTE(review): this is a timing assumption — on a heavily loaded
    // machine the reader may not have started yet; the test still passes, it
    // just no longer exercises the overlap.
    tokio::time::sleep(Duration::from_millis(25)).await;
    let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
    let writer_task = tokio::spawn(async move {
        concurrent_clone
            .write_operation(|writer| {
                let _current_count = writer.index().shard_count();
                // 50ms of simulated write work, finishing inside the
                // reader's 100ms window.
                std::thread::sleep(Duration::from_millis(50));
                Ok(())
            })
            .await
    });
    let reader_result = reader_task.await.expect("Reader task should complete");
    let writer_result = writer_task.await.expect("Writer task should complete");
    // Neither side may be failed or blocked by the other.
    assert!(reader_result.is_ok(), "Reader should succeed");
    assert!(writer_result.is_ok(), "Writer should succeed");
    let snapshot_value = {
        let snapshot = reader_snapshot.lock().unwrap();
        *snapshot
    };
    assert!(snapshot_value.is_some(), "Reader should have captured a snapshot");
    // Read the post-write state for the diagnostic printout below.
    let current_state = concurrent
        .read_operation(|index| Ok(index.shard_count()))
        .expect("Current state read should succeed");
    println!(
        "Reader isolation test: snapshot={:?}, current_state={}",
        snapshot_value, current_state
    );
}
/// Each write must advance the coordination epoch; reads must leave it alone.
#[tokio::test]
async fn test_epoch_based_coordination() {
    let _test_env = TestEnvironment::new("test_epoch_based_coordination");
    let concurrent = create_test_concurrent_shardex(&_test_env);
    let initial_epoch = concurrent.concurrency_metrics().await.current_epoch;
    let mut previous_epoch = initial_epoch;
    for i in 0..5 {
        // A successful write should bump the epoch past its previous value.
        let write = concurrent
            .write_operation(|writer| Ok(writer.index().shard_count()))
            .await;
        assert!(write.is_ok(), "Write operation {} should succeed", i);
        let epoch_after_write = concurrent.concurrency_metrics().await.current_epoch;
        assert!(
            epoch_after_write > previous_epoch,
            "Epoch should advance after write operation {}: {} <= {}",
            i,
            epoch_after_write,
            previous_epoch
        );
        previous_epoch = epoch_after_write;
        // A read, by contrast, must not touch the epoch at all.
        let read_result = concurrent.read_operation(|index| Ok(index.shard_count()));
        assert!(read_result.is_ok(), "Read operation {} should succeed", i);
        let epoch_after_read = concurrent.concurrency_metrics().await.current_epoch;
        assert_eq!(
            epoch_after_read, previous_epoch,
            "Read operation should not advance epoch"
        );
    }
    // Five writes means the epoch moved forward at least five times.
    let final_epoch = concurrent.concurrency_metrics().await.current_epoch;
    assert!(
        final_epoch > initial_epoch + 4,
        "Epoch should have advanced significantly: {} vs {}",
        final_epoch,
        initial_epoch
    );
}
/// Errors returned from read/write closures must propagate to the caller and
/// must not poison the system: subsequent operations still succeed.
#[tokio::test]
async fn test_graceful_error_handling() {
    let _test_env = TestEnvironment::new("test_graceful_error_handling");
    let concurrent = create_test_concurrent_shardex(&_test_env);
    // A failing read closure surfaces its error verbatim.
    let read_result: Result<(), _> =
        concurrent.read_operation(|_index| Err(shardex::ShardexError::Config("Simulated read error".to_string())));
    assert!(read_result.is_err());
    let read_err = read_result.unwrap_err().to_string();
    assert!(read_err.contains("Simulated read error"));
    // Same contract for a failing write closure.
    let write_result: Result<(), _> = concurrent
        .write_operation(|_writer| Err(shardex::ShardexError::Config("Simulated write error".to_string())))
        .await;
    assert!(write_result.is_err());
    let write_err = write_result.unwrap_err().to_string();
    assert!(write_err.contains("Simulated write error"));
    // After the induced failures, both the read and write paths must recover.
    let recovery_read = concurrent.read_operation(|index| Ok(index.shard_count()));
    assert!(recovery_read.is_ok(), "System should recover from errors");
    let recovery_write = concurrent
        .write_operation(|writer| Ok(writer.index().shard_count()))
        .await;
    assert!(recovery_write.is_ok(), "System should recover from errors");
}
/// Stress test: hammer the index with 50 reader tasks and 10 writer tasks for
/// a fixed duration, then verify overall throughput, that readers out-ran
/// writers, and that no operations remain registered as active.
#[tokio::test]
async fn test_stress_concurrent_operations() {
    let _test_env = TestEnvironment::new("test_stress_concurrent_operations");
    let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
    const STRESS_DURATION: Duration = Duration::from_secs(10);
    const NUM_READER_TASKS: usize = 50;
    const NUM_WRITER_TASKS: usize = 10;
    let start_time = Instant::now();
    let operation_counter = Arc::new(AtomicUsize::new(0));
    // Every task resolves to (is_reader, operations_completed). Tagging the
    // result is essential: JoinSet::join_next yields tasks in COMPLETION
    // order, not spawn order, so the previous scheme of counting "the first
    // NUM_READER_TASKS results are readers" misattributed operations.
    let mut tasks = JoinSet::new();
    for reader_id in 0..NUM_READER_TASKS {
        let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
        let counter_clone = Arc::clone(&operation_counter);
        tasks.spawn(async move {
            let mut operations = 0;
            while start_time.elapsed() < STRESS_DURATION {
                let result = concurrent_clone.read_operation(|index| {
                    // 1-10us of simulated read work, varied per task.
                    let work_us = (reader_id % 10) + 1;
                    std::thread::sleep(std::time::Duration::from_micros(work_us as u64));
                    Ok(index.shard_count())
                });
                if result.is_ok() {
                    operations += 1;
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                }
                // Yield between reads so tasks interleave on the runtime.
                tokio::time::sleep(Duration::from_micros(100)).await;
            }
            (true, operations)
        });
    }
    for writer_id in 0..NUM_WRITER_TASKS {
        let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
        let counter_clone = Arc::clone(&operation_counter);
        tasks.spawn(async move {
            let mut operations = 0;
            while start_time.elapsed() < STRESS_DURATION {
                let result = concurrent_clone
                    .write_operation(|writer| {
                        // 1-5ms of simulated write work, varied per task.
                        let work_ms = (writer_id % 5) + 1;
                        std::thread::sleep(std::time::Duration::from_millis(work_ms as u64));
                        Ok(writer.index().shard_count())
                    })
                    .await;
                if result.is_ok() {
                    operations += 1;
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                }
                // Writers pace themselves much more slowly than readers.
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
            (false, operations)
        });
    }
    // Attribute each task's operation count by its tag, not its position in
    // the completion stream.
    let mut total_reader_ops = 0;
    let mut total_writer_ops = 0;
    while let Some(result) = tasks.join_next().await {
        let (is_reader, operations) = result.expect("Stress task should not panic");
        if is_reader {
            total_reader_ops += operations;
        } else {
            total_writer_ops += operations;
        }
    }
    let actual_duration = start_time.elapsed();
    let total_operations = operation_counter.load(Ordering::SeqCst);
    println!(
        "Stress test completed: {} total operations in {:?} ({:.1} ops/sec)",
        total_operations,
        actual_duration,
        total_operations as f64 / actual_duration.as_secs_f64()
    );
    println!(
        "Operation breakdown: {} reader ops, {} writer ops",
        total_reader_ops, total_writer_ops
    );
    assert!(
        total_operations > 100,
        "Stress test should achieve reasonable throughput: {} operations",
        total_operations
    );
    // 50 fast readers vs 10 slow writers: readers must dominate.
    assert!(
        total_reader_ops > total_writer_ops,
        "Readers should perform more operations than writers: {} vs {}",
        total_reader_ops,
        total_writer_ops
    );
    // After all tasks drain, the metrics must show an idle system.
    let final_metrics = concurrent.concurrency_metrics().await;
    assert_eq!(
        final_metrics.active_readers, 0,
        "No readers should be active after stress test"
    );
    assert_eq!(
        final_metrics.active_writers, 0,
        "No writers should be active after stress test"
    );
}