use shardex::concurrent::ConcurrentShardex;
use shardex::cow_index::CowShardexIndex;
use shardex::shardex_index::ShardexIndex;
use shardex::ShardexConfig;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::task::JoinSet;
use tokio::time::timeout;
mod common;
use common::{create_test_concurrent_shardex, TestEnvironment};
#[tokio::test]
async fn test_concurrent_reader_writer_failures() {
let _test_env = TestEnvironment::new("test_concurrent_reader_writer_failures");
let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
let failure_count = Arc::new(AtomicUsize::new(0));
let success_count = Arc::new(AtomicUsize::new(0));
let mut tasks = JoinSet::new();
for reader_id in 0..5 {
let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
let failure_counter = Arc::clone(&failure_count);
let success_counter = Arc::clone(&success_count);
tasks.spawn(async move {
for _op_id in 0..20 {
let result = concurrent_clone.read_operation(|index| {
if reader_id % 3 == 0 && _op_id % 7 == 0 {
return Err(shardex::error::ShardexError::InvalidInput {
field: "simulated_read_failure".to_string(),
reason: "Test failure injection".to_string(),
suggestion: "This is a simulated failure for testing".to_string(),
});
}
Ok(index.shard_count())
});
match result {
Ok(_) => success_counter.fetch_add(1, Ordering::SeqCst),
Err(_) => failure_counter.fetch_add(1, Ordering::SeqCst),
};
tokio::time::sleep(Duration::from_millis(1)).await;
}
});
}
for writer_id in 0..3 {
let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
let failure_counter = Arc::clone(&failure_count);
let success_counter = Arc::clone(&success_count);
tasks.spawn(async move {
for op_id in 0..10 {
let result = concurrent_clone
.write_operation(|_writer| {
if writer_id == 1 && op_id % 5 == 0 {
return Err(shardex::error::ShardexError::InvalidInput {
field: "simulated_write_failure".to_string(),
reason: "Test failure injection".to_string(),
suggestion: "This is a simulated failure for testing".to_string(),
});
}
Ok(writer_id + op_id)
})
.await;
match result {
Ok(_) => success_counter.fetch_add(1, Ordering::SeqCst),
Err(_) => failure_counter.fetch_add(1, Ordering::SeqCst),
};
tokio::time::sleep(Duration::from_millis(5)).await;
}
});
}
let test_timeout = Duration::from_secs(30);
let test_result = timeout(test_timeout, async {
while let Some(result) = tasks.join_next().await {
result.expect("Task should not panic even with simulated failures");
}
})
.await;
assert!(test_result.is_ok(), "Test timed out");
let total_failures = failure_count.load(Ordering::SeqCst);
let total_successes = success_count.load(Ordering::SeqCst);
assert!(total_successes > 0, "Should have some successful operations");
assert!(total_failures > 0, "Should have some simulated failures");
println!(
"Completed with {} successes and {} expected failures",
total_successes, total_failures
);
}
#[tokio::test]
async fn test_memory_pressure_behavior() {
let _test_env = TestEnvironment::new("test_memory_pressure_behavior");
let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
let mut tasks = JoinSet::new();
let memory_pressure_operations = Arc::new(AtomicUsize::new(0));
for writer_id in 0..20 {
let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
let ops_counter = Arc::clone(&memory_pressure_operations);
tasks.spawn(async move {
for _op_id in 0..3 {
let start_time = Instant::now();
let result = concurrent_clone
.write_operation(|_writer| {
let _large_vec: Vec<u8> = vec![0; 1024 * 1024]; std::thread::sleep(Duration::from_millis(10));
Ok(writer_id)
})
.await;
if result.is_ok() {
ops_counter.fetch_add(1, Ordering::SeqCst);
}
let duration = start_time.elapsed();
if duration > Duration::from_millis(100) {
println!(
"Writer {} operation took {:?} (possible memory pressure)",
writer_id, duration
);
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
});
}
let concurrent_monitor: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
let monitor_task = tokio::spawn(async move {
let mut high_contention_count = 0;
for _ in 0..10 {
tokio::time::sleep(Duration::from_millis(200)).await;
let stats = concurrent_monitor.coordination_stats().await;
if stats.contended_writes > 0 {
high_contention_count += 1;
println!(
"Memory pressure indicator - contended writes: {}",
stats.contended_writes
);
}
}
high_contention_count
});
let test_timeout = Duration::from_secs(120);
let test_result = timeout(test_timeout, async {
while let Some(result) = tasks.join_next().await {
result.expect("Memory pressure test task should not panic");
}
})
.await;
assert!(test_result.is_ok(), "Memory pressure test timed out");
let contention_observations = monitor_task.await.expect("Monitor task failed");
let completed_operations = memory_pressure_operations.load(Ordering::SeqCst);
println!(
"Completed {} operations under memory pressure with {} contention observations",
completed_operations, contention_observations
);
assert!(
completed_operations > 0,
"Should complete some operations even under pressure"
);
}
#[tokio::test]
async fn test_cow_operation_crash_recovery() {
    // Builds an index, performs a few writes, simulates a crash by dropping
    // the concurrent wrapper without graceful shutdown, then re-opens the
    // index from disk and checks that state survived and writes still work.
    let _test_env = TestEnvironment::new("test_cow_operation_crash_recovery");
    let config = ShardexConfig::new()
        .directory_path(_test_env.path())
        .vector_size(64)
        .shard_size(100);

    // `config` is not used again below, so pass it by value — the previous
    // `config.clone()` was a redundant clone.
    let index = ShardexIndex::create(config).expect("Failed to create index");
    let cow_index = CowShardexIndex::new(index);
    let concurrent = ConcurrentShardex::new(cow_index);

    // Seed the index with a handful of successful write operations.
    for i in 0..5 {
        let result = concurrent.write_operation(|_writer| Ok(i)).await;
        assert!(result.is_ok(), "Initial setup operations should succeed");
    }

    // Capture the pre-"crash" shard count for later comparison.
    let initial_shard_count = concurrent
        .read_operation(|index| Ok(index.shard_count()))
        .expect("Should be able to read initial state");

    // Simulate a crash: drop the handle without any shutdown protocol.
    drop(concurrent);

    // Recovery path: re-open the same directory and wrap it again.
    let recovered_index = ShardexIndex::open(_test_env.path()).expect("Should be able to recover index");
    let recovered_cow = CowShardexIndex::new(recovered_index);
    let recovered_concurrent = ConcurrentShardex::new(recovered_cow);

    let recovered_shard_count = recovered_concurrent
        .read_operation(|index| Ok(index.shard_count()))
        .expect("Should be able to read recovered state");
    assert_eq!(
        initial_shard_count, recovered_shard_count,
        "Recovered index should have same shard count as before crash"
    );

    // The recovered index must also accept new writes.
    let post_recovery_result = recovered_concurrent.write_operation(|_writer| Ok(42)).await;
    assert!(
        post_recovery_result.is_ok(),
        "Should be able to perform operations after recovery"
    );
    println!(
        "Successfully recovered from simulated crash. Shard count: initial={}, recovered={}",
        initial_shard_count, recovered_shard_count
    );
}
#[tokio::test]
async fn test_extreme_read_concurrency() {
let _test_env = TestEnvironment::new("test_extreme_read_concurrency");
let concurrent = Arc::new(create_test_concurrent_shardex(&_test_env));
let read_count = Arc::new(AtomicUsize::new(0));
let mut tasks = JoinSet::new();
for reader_id in 0..100 {
let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
let counter = Arc::clone(&read_count);
tasks.spawn(async move {
for _op_id in 0..10 {
let result = concurrent_clone.read_operation(|index| {
let shard_count = index.shard_count();
Ok(shard_count + reader_id)
});
if result.is_ok() {
counter.fetch_add(1, Ordering::SeqCst);
}
let delay_ms = (reader_id % 3) + 1;
tokio::time::sleep(Duration::from_millis(delay_ms as u64)).await;
}
});
}
for writer_id in 0..5 {
let concurrent_clone: Arc<ConcurrentShardex> = Arc::clone(&concurrent);
tasks.spawn(async move {
for _op_id in 0..3 {
let _result = concurrent_clone
.write_operation(|_writer| Ok(writer_id))
.await;
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
}
let start_time = Instant::now();
let test_timeout = Duration::from_secs(60);
let test_result = timeout(test_timeout, async {
while let Some(result) = tasks.join_next().await {
result.expect("High concurrency test task should not panic");
}
})
.await;
let duration = start_time.elapsed();
let total_reads = read_count.load(Ordering::SeqCst);
assert!(test_result.is_ok(), "High concurrency test timed out");
assert_eq!(total_reads, 1000, "Should complete all 1000 read operations");
println!(
"Completed {} reads in {:?} with extreme concurrency",
total_reads, duration
);
let reads_per_second = total_reads as f64 / duration.as_secs_f64();
assert!(
reads_per_second > 100.0,
"Should maintain reasonable throughput under high concurrency: {:.1} reads/sec",
reads_per_second
);
}