use rust_keyvault::{
key::{SecretKey, VersionedKey},
storage::{FileStore, KeyStore, MemoryStore, StorageConfig},
Algorithm, KeyId, KeyMetadata, KeyState,
};
use std::sync::{Arc, Barrier, Mutex};
use std::thread;
use std::time::{Duration, SystemTime};
use tempfile::tempdir;
/// Ten reader threads each retrieve one pre-seeded key 100 times from a
/// `MemoryStore` shared behind `Arc<Mutex<_>>`; every lookup must succeed.
#[test]
fn test_memory_store_concurrent_reads() {
    let shared = Arc::new(Mutex::new(MemoryStore::new()));
    let id = KeyId::generate_base().unwrap();
    let entry = VersionedKey {
        key: SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap(),
        metadata: KeyMetadata {
            id: id.clone(),
            base_id: id.clone(),
            algorithm: Algorithm::ChaCha20Poly1305,
            created_at: SystemTime::now(),
            expires_at: None,
            state: KeyState::Active,
            version: 1,
        },
    };
    // Seed the single key before any reader is spawned.
    {
        let mut guard = shared.lock().unwrap();
        guard.store(entry).unwrap();
    }

    let reader_count = 10;
    let start_line = Arc::new(Barrier::new(reader_count));
    let mut workers = Vec::with_capacity(reader_count);
    for reader in 0..reader_count {
        let shared = Arc::clone(&shared);
        let id = id.clone();
        let start_line = Arc::clone(&start_line);
        workers.push(thread::spawn(move || {
            // Release every reader at once to maximise lock contention.
            start_line.wait();
            for _ in 0..100 {
                // Release the lock before asserting so a failure cannot poison it mid-read.
                let outcome = {
                    let guard = shared.lock().unwrap();
                    guard.retrieve(&id)
                };
                assert!(outcome.is_ok(), "Thread {} failed to read", reader);
            }
        }));
    }
    workers
        .into_iter()
        .for_each(|worker| worker.join().unwrap());
}
/// Ten writer threads each store ten distinct keys into a shared
/// `MemoryStore`; afterwards the store must list exactly one entry per write.
#[test]
fn test_memory_store_concurrent_writes() {
    let shared = Arc::new(Mutex::new(MemoryStore::new()));
    let writer_count = 10;
    let keys_each = 10;
    let start_line = Arc::new(Barrier::new(writer_count));
    let mut workers = Vec::with_capacity(writer_count);
    for writer in 0..writer_count {
        let shared = Arc::clone(&shared);
        let start_line = Arc::clone(&start_line);
        workers.push(thread::spawn(move || {
            start_line.wait();
            for slot in 0..keys_each {
                // Byte 0 = writer index, byte 1 = slot, so every (writer, slot) id is unique.
                let mut raw = [0u8; 16];
                raw[0] = writer as u8;
                raw[1] = slot as u8;
                let id = KeyId::from_bytes(raw);
                let entry = VersionedKey {
                    key: SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap(),
                    metadata: KeyMetadata {
                        id: id.clone(),
                        base_id: id.clone(),
                        algorithm: Algorithm::ChaCha20Poly1305,
                        created_at: SystemTime::now(),
                        expires_at: None,
                        state: KeyState::Active,
                        version: 1,
                    },
                };
                // The guard is a temporary, so the lock is released before the assert.
                let outcome = shared.lock().unwrap().store(entry);
                assert!(
                    outcome.is_ok(),
                    "Thread {} key {} failed to store",
                    writer,
                    slot
                );
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    let expected_count = writer_count * keys_each;
    let actual_count = shared.lock().unwrap().list().unwrap().len();
    assert_eq!(
        actual_count, expected_count,
        "Expected {} keys, found {}",
        expected_count, actual_count
    );
}
/// Readers and writers contend for the same `MemoryStore` at the same time.
///
/// Ten keys are seeded up front. Five reader threads repeatedly retrieve
/// seeded keys while five writer threads each add 50 new, distinct keys.
/// Every operation must succeed, and the final key count must account for
/// every seeded and written key. Without these checks the test could not
/// fail on a broken store.
#[test]
fn test_memory_store_mixed_operations() {
    let store = Arc::new(Mutex::new(MemoryStore::new()));
    let seeded_keys: u8 = 10;
    {
        let mut store_locked = store.lock().unwrap();
        for i in 0..seeded_keys {
            let mut id_bytes = [0u8; 16];
            id_bytes[0] = i;
            let key_id = KeyId::from_bytes(id_bytes);
            let secret_key = SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap();
            let metadata = KeyMetadata {
                id: key_id.clone(),
                base_id: key_id.clone(),
                algorithm: Algorithm::ChaCha20Poly1305,
                created_at: SystemTime::now(),
                expires_at: None,
                state: KeyState::Active,
                version: 1,
            };
            store_locked
                .store(VersionedKey {
                    key: secret_key,
                    metadata,
                })
                .unwrap();
        }
    }
    let num_readers = 5;
    let num_writers = 5;
    let writes_per_writer: usize = 50;
    let barrier = Arc::new(Barrier::new(num_readers + num_writers));
    let mut handles = vec![];
    for i in 0..num_readers {
        let store_clone = Arc::clone(&store);
        let barrier_clone = Arc::clone(&barrier);
        let handle = thread::spawn(move || {
            barrier_clone.wait();
            // Readers only ask for seeded keys, so every lookup must succeed.
            let mut id_bytes = [0u8; 16];
            id_bytes[0] = (i as u8) % seeded_keys;
            let key_id = KeyId::from_bytes(id_bytes);
            for _ in 0..50 {
                let store_locked = store_clone.lock().unwrap();
                let result = store_locked.retrieve(&key_id);
                drop(store_locked);
                assert!(
                    result.is_ok(),
                    "Reader {} failed to retrieve a seeded key",
                    i
                );
                thread::sleep(Duration::from_micros(10));
            }
        });
        handles.push(handle);
    }
    for i in 0..num_writers {
        let store_clone = Arc::clone(&store);
        let barrier_clone = Arc::clone(&barrier);
        let handle = thread::spawn(move || {
            barrier_clone.wait();
            for j in 0..writes_per_writer {
                // Byte 0 starts at 100, so writer ids never collide with the seeded ids (0..10).
                let mut id_bytes = [0u8; 16];
                id_bytes[0] = 100 + i as u8;
                id_bytes[1] = j as u8;
                let key_id = KeyId::from_bytes(id_bytes);
                let secret_key = SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap();
                let metadata = KeyMetadata {
                    id: key_id.clone(),
                    base_id: key_id.clone(),
                    algorithm: Algorithm::ChaCha20Poly1305,
                    created_at: SystemTime::now(),
                    expires_at: None,
                    state: KeyState::Active,
                    version: 1,
                };
                let mut store_locked = store_clone.lock().unwrap();
                let result = store_locked.store(VersionedKey {
                    key: secret_key,
                    metadata,
                });
                drop(store_locked);
                assert!(result.is_ok(), "Writer {} key {} failed to store", i, j);
                thread::sleep(Duration::from_micros(10));
            }
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().unwrap();
    }
    let expected_count = usize::from(seeded_keys) + num_writers * writes_per_writer;
    let actual_count = store.lock().unwrap().list().unwrap().len();
    assert_eq!(
        actual_count, expected_count,
        "Expected {} keys, found {}",
        expected_count, actual_count
    );
}
/// Five threads each store five keys in a `FileStore` rooted in a temporary
/// directory. Each key is read back right after it is written. All 25 keys
/// must be listed once every thread has finished.
#[test]
fn test_file_store_concurrent_operations() {
    let temp_dir = tempdir().unwrap();
    let shared = Arc::new(Mutex::new(
        FileStore::new(temp_dir.path(), StorageConfig::default()).unwrap(),
    ));
    let worker_count = 5;
    let keys_each = 5;
    let start_line = Arc::new(Barrier::new(worker_count));
    let mut workers = Vec::with_capacity(worker_count);
    for worker in 0..worker_count {
        let shared = Arc::clone(&shared);
        let start_line = Arc::clone(&start_line);
        workers.push(thread::spawn(move || {
            start_line.wait();
            for slot in 0..keys_each {
                let mut raw = [0u8; 16];
                raw[0] = worker as u8;
                raw[1] = slot as u8;
                let id = KeyId::from_bytes(raw);
                let entry = VersionedKey {
                    key: SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap(),
                    metadata: KeyMetadata {
                        id: id.clone(),
                        base_id: id.clone(),
                        algorithm: Algorithm::ChaCha20Poly1305,
                        created_at: SystemTime::now(),
                        expires_at: None,
                        state: KeyState::Active,
                        version: 1,
                    },
                };
                // Write and read back under separate lock acquisitions so other
                // threads can interleave between the two steps.
                shared.lock().unwrap().store(entry).unwrap();
                {
                    let guard = shared.lock().unwrap();
                    assert_eq!(guard.retrieve(&id).unwrap().metadata.id, id);
                }
                thread::sleep(Duration::from_micros(100));
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }
    let guard = shared.lock().unwrap();
    assert_eq!(guard.list().unwrap().len(), worker_count * keys_each);
}
/// Twenty threads run store/retrieve/list/delete cycles against one shared
/// `MemoryStore`.
///
/// Each group of four consecutive ops shares a single key id (`op / 4`).
/// The retrieve and delete steps therefore act on the key the same thread
/// just stored; deriving the id from `op` would make them always target
/// missing keys. Byte 0 of every id is the thread index, so no other thread
/// touches a thread's keys. Every step must succeed.
#[test]
fn test_memory_store_stress() {
    let store = Arc::new(Mutex::new(MemoryStore::new()));
    let num_threads = 20;
    let operations_per_thread = 100;
    let barrier = Arc::new(Barrier::new(num_threads));
    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let store_clone = Arc::clone(&store);
            let barrier_clone = Arc::clone(&barrier);
            thread::spawn(move || {
                barrier_clone.wait();
                for op in 0..operations_per_thread {
                    // One key per store/retrieve/list/delete cycle.
                    let slot = op / 4;
                    let mut id_bytes = [0u8; 16];
                    id_bytes[0] = thread_id as u8;
                    id_bytes[1] = (slot % 256) as u8;
                    let key_id = KeyId::from_bytes(id_bytes);
                    // Each lock guard below is a temporary, released before the assertion.
                    match op % 4 {
                        0 => {
                            let secret_key =
                                SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap();
                            let metadata = KeyMetadata {
                                id: key_id.clone(),
                                base_id: key_id.clone(),
                                algorithm: Algorithm::ChaCha20Poly1305,
                                created_at: SystemTime::now(),
                                expires_at: None,
                                state: KeyState::Active,
                                version: 1,
                            };
                            let stored = store_clone
                                .lock()
                                .unwrap()
                                .store(VersionedKey {
                                    key: secret_key,
                                    metadata,
                                })
                                .is_ok();
                            assert!(stored, "Thread {} op {} failed to store", thread_id, op);
                        }
                        1 => {
                            // Stored by this thread on the previous op.
                            let found = store_clone.lock().unwrap().retrieve(&key_id).is_ok();
                            assert!(found, "Thread {} op {} failed to retrieve", thread_id, op);
                        }
                        2 => {
                            let listed = store_clone.lock().unwrap().list().is_ok();
                            assert!(listed, "Thread {} op {} failed to list", thread_id, op);
                        }
                        3 => {
                            let deleted = store_clone.lock().unwrap().delete(&key_id).is_ok();
                            assert!(deleted, "Thread {} op {} failed to delete", thread_id, op);
                        }
                        _ => unreachable!(),
                    }
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    let store_locked = store.lock().unwrap();
    let keys = store_locked.list().unwrap();
    println!("Stress test completed. Final key count: {}", keys.len());
}
/// Ten threads repeatedly store, retrieve, and list keys over an overlapping
/// id range. The test fails if they do not all finish within `timeout`.
///
/// The worker joins run on a watcher thread, and the test waits for them on
/// a channel with a deadline. A real deadlock therefore fails the test
/// instead of hanging it; checking elapsed time after a blocking `join()`
/// never fires, because `join()` never returns.
#[test]
fn test_no_deadlocks() {
    let store = Arc::new(Mutex::new(MemoryStore::new()));
    let num_threads = 10;
    let timeout = Duration::from_secs(5);
    let barrier = Arc::new(Barrier::new(num_threads));
    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let store_clone = Arc::clone(&store);
            let barrier_clone = Arc::clone(&barrier);
            thread::spawn(move || {
                barrier_clone.wait();
                for i in 0..20 {
                    // Ids wrap over 0..10, so threads deliberately contend for the same keys.
                    let mut id_bytes = [0u8; 16];
                    id_bytes[0] = ((thread_id + i) % 10) as u8;
                    let key_id = KeyId::from_bytes(id_bytes);
                    let secret_key = SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap();
                    let metadata = KeyMetadata {
                        id: key_id.clone(),
                        base_id: key_id.clone(),
                        algorithm: Algorithm::ChaCha20Poly1305,
                        created_at: SystemTime::now(),
                        expires_at: None,
                        state: KeyState::Active,
                        version: 1,
                    };
                    {
                        let mut store_locked = store_clone.lock().unwrap();
                        let _ = store_locked.store(VersionedKey {
                            key: secret_key,
                            metadata,
                        });
                    }
                    {
                        let store_locked = store_clone.lock().unwrap();
                        let _ = store_locked.retrieve(&key_id);
                    }
                    {
                        let store_locked = store_clone.lock().unwrap();
                        let _ = store_locked.list();
                    }
                    thread::sleep(Duration::from_micros(50));
                }
            })
        })
        .collect();

    let (done_tx, done_rx) = std::sync::mpsc::channel();
    thread::spawn(move || {
        let outcomes: Vec<_> = handles.into_iter().map(|handle| handle.join()).collect();
        // If the deadline already passed, the receiver is gone; nothing left to report.
        let _ = done_tx.send(outcomes);
    });
    let outcomes = done_rx
        .recv_timeout(timeout)
        .expect("Test took too long - possible deadlock detected");
    for outcome in outcomes {
        if let Err(payload) = outcome {
            // Re-raise a worker's panic so its original message is reported.
            std::panic::resume_unwind(payload);
        }
    }
}
/// Seeds three keys in a `FileStore`. Three threads then export one key
/// each, concurrently, under a password and round-trip the export through
/// JSON serialization.
#[test]
fn test_concurrent_export_import() {
    let temp_dir = tempdir().unwrap();
    let shared = Arc::new(Mutex::new(
        FileStore::new(temp_dir.path(), StorageConfig::default()).unwrap(),
    ));
    {
        let mut guard = shared.lock().unwrap();
        for seed in 0..3u8 {
            let mut raw = [0u8; 16];
            raw[0] = seed;
            let id = KeyId::from_bytes(raw);
            guard
                .store(VersionedKey {
                    key: SecretKey::generate(Algorithm::ChaCha20Poly1305).unwrap(),
                    metadata: KeyMetadata {
                        id: id.clone(),
                        base_id: id.clone(),
                        algorithm: Algorithm::ChaCha20Poly1305,
                        created_at: SystemTime::now(),
                        expires_at: None,
                        state: KeyState::Active,
                        version: 1,
                    },
                })
                .unwrap();
        }
    }

    let exporter_count = 3;
    let start_line = Arc::new(Barrier::new(exporter_count));
    let mut workers = Vec::with_capacity(exporter_count);
    for exporter in 0..exporter_count {
        let shared = Arc::clone(&shared);
        let start_line = Arc::clone(&start_line);
        workers.push(thread::spawn(move || {
            start_line.wait();
            // Exporter N exports the key seeded with byte 0 == N.
            let mut raw = [0u8; 16];
            raw[0] = exporter as u8;
            let id = KeyId::from_bytes(raw);
            // Hold the lock only for the export; the JSON work runs unlocked.
            let exported = {
                let mut guard = shared.lock().unwrap();
                guard.export_key(&id, b"test-password").unwrap()
            };
            let json = exported.to_json().unwrap();
            let _ = rust_keyvault::export::ExportedKey::from_json(&json).unwrap();
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }
}