use bytes::Bytes;
use seerdb::{DBError, DB};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
use tempfile::TempDir;
#[test]
fn test_concurrent_transaction_conflicts() {
    // Ten threads race read-modify-write transactions against one shared
    // counter key. Every attempt must either commit or report a conflict,
    // and the final counter value must equal the number of successful
    // commits (failed attempts are deliberately NOT retried).
    let temp_dir = TempDir::new().unwrap();
    let db = Arc::new(DB::open(temp_dir.path()).unwrap());
    db.put(b"counter", b"0").unwrap();
    let num_threads = 10;
    let attempts_per_thread = 50;
    let barrier = Arc::new(Barrier::new(num_threads));
    let successful_commits = Arc::new(AtomicUsize::new(0));
    let conflict_count = Arc::new(AtomicUsize::new(0));
    let mut handles = vec![];
    for thread_id in 0..num_threads {
        let db = Arc::clone(&db);
        let barrier = Arc::clone(&barrier);
        let successful_commits = Arc::clone(&successful_commits);
        let conflict_count = Arc::clone(&conflict_count);
        let handle = thread::spawn(move || {
            // Track this thread's own tallies: the original code printed the
            // global atomics here, mislabeling cumulative totals as
            // per-thread results.
            let mut local_success = 0usize;
            let mut local_conflicts = 0usize;
            barrier.wait();
            for _ in 0..attempts_per_thread {
                let mut txn = db.begin_transaction();
                let current = txn.get(b"counter").unwrap();
                let value: i32 = current
                    .map(|b| String::from_utf8_lossy(&b).parse().unwrap_or(0))
                    .unwrap_or(0);
                let new_value = (value + 1).to_string();
                txn.put(b"counter", new_value.as_bytes()).unwrap();
                match txn.commit() {
                    Ok(()) => {
                        local_success += 1;
                        successful_commits.fetch_add(1, Ordering::SeqCst);
                    }
                    Err(DBError::TransactionConflict(_)) => {
                        local_conflicts += 1;
                        conflict_count.fetch_add(1, Ordering::SeqCst);
                    }
                    Err(e) => panic!("Unexpected error: {:?}", e),
                }
            }
            println!(
                "Thread {} finished: {} successful, {} conflicts",
                thread_id, local_success, local_conflicts
            );
        });
        handles.push(handle);
    }
    for handle in handles {
        handle.join().unwrap();
    }
    let total_successful = successful_commits.load(Ordering::SeqCst);
    let total_conflicts = conflict_count.load(Ordering::SeqCst);
    let total_attempts = num_threads * attempts_per_thread;
    println!(
        "Total: {} successful, {} conflicts out of {} attempts",
        total_successful, total_conflicts, total_attempts
    );
    // Every attempt is accounted for exactly once.
    assert_eq!(total_successful + total_conflicts, total_attempts);
    let final_value: i32 = db
        .get(b"counter")
        .unwrap()
        .map(|b| String::from_utf8_lossy(&b).parse().unwrap())
        .unwrap();
    assert_eq!(
        final_value, total_successful as i32,
        "Counter should equal successful commits"
    );
    // With 500 contended attempts on one key, at least one conflict is
    // expected; a zero here would suggest conflict detection is broken.
    assert!(
        total_conflicts > 0,
        "Expected conflicts with concurrent transactions"
    );
    println!(
        "PASS: Counter={}, Successful={}, Conflicts={}",
        final_value, total_successful, total_conflicts
    );
}
#[test]
fn test_concurrent_transactions_no_false_conflicts() {
    // Each thread commits transactions against its own disjoint key space,
    // so conflict detection must never fire: zero false positives.
    let dir = TempDir::new().unwrap();
    let db = Arc::new(DB::open(dir.path()).unwrap());
    let num_threads = 10;
    let ops_per_thread = 100;
    let barrier = Arc::new(Barrier::new(num_threads));
    let conflict_count = Arc::new(AtomicUsize::new(0));
    // Spawn workers via an iterator chain instead of an explicit push loop.
    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let db = Arc::clone(&db);
            let barrier = Arc::clone(&barrier);
            let conflict_count = Arc::clone(&conflict_count);
            thread::spawn(move || {
                barrier.wait();
                for i in 0..ops_per_thread {
                    let key = format!("thread_{}_key_{}", thread_id, i);
                    let mut txn = db.begin_transaction();
                    txn.put(key.as_bytes(), b"value").unwrap();
                    match txn.commit() {
                        Ok(()) => {}
                        Err(DBError::TransactionConflict(_)) => {
                            conflict_count.fetch_add(1, Ordering::SeqCst);
                        }
                        Err(e) => panic!("Unexpected error: {:?}", e),
                    }
                }
            })
        })
        .collect();
    for worker in handles {
        worker.join().unwrap();
    }
    let conflicts = conflict_count.load(Ordering::SeqCst);
    assert_eq!(
        conflicts, 0,
        "No conflicts expected when threads use different keys"
    );
    // Every write from every thread must be visible after the fact.
    for thread_id in 0..num_threads {
        for i in 0..ops_per_thread {
            let key = format!("thread_{}_key_{}", thread_id, i);
            assert!(db.get(key.as_bytes()).unwrap().is_some());
        }
    }
    println!(
        "PASS: {} threads x {} ops = {} total writes, 0 conflicts",
        num_threads,
        ops_per_thread,
        num_threads * ops_per_thread
    );
}
#[test]
fn test_transaction_crash_recovery() {
    // Commit a multi-key transaction, simulate a crash by dropping the DB
    // handle, then reopen from the same directory and verify durability.
    let temp_dir = TempDir::new().unwrap();
    let data_dir = temp_dir.path().to_path_buf();
    {
        let db = DB::open(&data_dir).unwrap();
        let mut txn = db.begin_transaction();
        for i in 1..=3 {
            let key = format!("txn_key_{}", i);
            let value = format!("txn_value_{}", i);
            txn.put(key.as_bytes(), value.as_bytes()).unwrap();
        }
        txn.commit().unwrap();
        drop(db);
    }
    {
        let db = DB::open(&data_dir).unwrap();
        // All three committed keys must have survived the "crash".
        for i in 1..=3 {
            let key = format!("txn_key_{}", i);
            let expected = format!("txn_value_{}", i);
            assert_eq!(
                db.get(key.as_bytes()).unwrap(),
                Some(Bytes::from(expected)),
                "txn_key_{} should survive crash",
                i
            );
        }
        println!("PASS: Transaction data recovered after crash");
    }
}
#[test]
fn test_uncommitted_transaction_not_recovered() {
    // Writes staged in a transaction that is dropped without committing
    // must not reappear after the database is reopened, while directly
    // committed data must persist.
    let temp_dir = TempDir::new().unwrap();
    let data_dir = temp_dir.path().to_path_buf();
    {
        let db = DB::open(&data_dir).unwrap();
        db.put(b"committed_key", b"committed_value").unwrap();
        let mut txn = db.begin_transaction();
        txn.put(b"uncommitted_key", b"uncommitted_value").unwrap();
        // Abandon the transaction, then "crash" by dropping the DB handle.
        drop(txn);
        drop(db);
    }
    let db = DB::open(&data_dir).unwrap();
    assert_eq!(
        db.get(b"committed_key").unwrap(),
        Some(Bytes::from("committed_value")),
        "Committed data should survive"
    );
    assert_eq!(
        db.get(b"uncommitted_key").unwrap(),
        None,
        "Uncommitted transaction data should NOT survive crash"
    );
    println!("PASS: Uncommitted transaction data correctly lost after crash");
}
#[test]
fn test_transaction_snapshot_isolation() {
    // A transaction reads from a snapshot taken at begin_transaction();
    // commit must fail when a key in its read set (key1) was modified
    // externally after the snapshot, and a failed commit must leave the
    // transaction's own writes (key2) unapplied.
    let temp_dir = TempDir::new().unwrap();
    // No threads are spawned in this test, so the Arc wrapper the original
    // carried was dead weight; use the DB handle directly.
    let db = DB::open(temp_dir.path()).unwrap();
    db.put(b"key1", b"initial1").unwrap();
    db.put(b"key2", b"initial2").unwrap();
    let mut txn = db.begin_transaction();
    // Both reads observe the snapshot taken at transaction start.
    assert_eq!(txn.get(b"key1").unwrap(), Some(Bytes::from("initial1")));
    assert_eq!(txn.get(b"key2").unwrap(), Some(Bytes::from("initial2")));
    // External writes land after the transaction's snapshot.
    db.put(b"key1", b"modified1").unwrap();
    db.put(b"key3", b"new_key").unwrap();
    txn.put(b"key2", b"txn_modified").unwrap();
    let result = txn.commit();
    assert!(
        matches!(result, Err(DBError::TransactionConflict(_))),
        "Expected conflict on key1 which was read then modified externally"
    );
    // The external writes stand; the aborted transaction's write does not.
    assert_eq!(db.get(b"key1").unwrap(), Some(Bytes::from("modified1")));
    assert_eq!(db.get(b"key3").unwrap(), Some(Bytes::from("new_key")));
    assert_eq!(db.get(b"key2").unwrap(), Some(Bytes::from("initial2")));
    println!("PASS: Transaction correctly detected conflict from concurrent write");
}
#[test]
fn test_transaction_and_snapshot_coexist() {
    // An explicit snapshot and an open transaction must each keep their own
    // view while the DB is mutated underneath them; the transaction, having
    // read the mutated key, must then fail to commit.
    let temp_dir = TempDir::new().unwrap();
    let db = DB::open(temp_dir.path()).unwrap();
    db.put(b"shared_key", b"v1").unwrap();
    let snapshot = db.snapshot().unwrap();
    let mut txn = db.begin_transaction();
    // Mutate after both the snapshot and the transaction were created.
    db.put(b"shared_key", b"v2").unwrap();
    // The snapshot still observes the pre-mutation value.
    assert_eq!(
        snapshot.get(b"shared_key").unwrap(),
        Some(Bytes::from("v1"))
    );
    let txn_value = txn.get(b"shared_key").unwrap();
    println!("Transaction sees: {:?}", txn_value);
    txn.put(b"other_key", b"other_value").unwrap();
    let result = txn.commit();
    assert!(
        matches!(result, Err(DBError::TransactionConflict(_))),
        "Expected conflict because shared_key was modified after txn start"
    );
    println!("PASS: Transaction and snapshot coexist correctly");
}
#[test]
fn test_write_only_transactions_no_conflict() {
    // Blind writes have an empty read set, so they can never observe a
    // stale value: every write-only transaction must commit, even when all
    // ten threads target the same key (last-writer-wins).
    let temp_dir = TempDir::new().unwrap();
    let db = Arc::new(DB::open(temp_dir.path()).unwrap());
    db.put(b"key", b"initial").unwrap();
    let num_threads = 10;
    let barrier = Arc::new(Barrier::new(num_threads));
    let successful = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..num_threads)
        .map(|thread_id| {
            let db = Arc::clone(&db);
            let barrier = Arc::clone(&barrier);
            let successful = Arc::clone(&successful);
            thread::spawn(move || {
                barrier.wait();
                let mut txn = db.begin_transaction();
                let value = format!("thread_{}", thread_id);
                txn.put(b"key", value.as_bytes()).unwrap();
                if let Err(e) = txn.commit() {
                    panic!("Write-only txn should not fail: {:?}", e);
                }
                successful.fetch_add(1, Ordering::SeqCst);
            })
        })
        .collect();
    for worker in handles {
        worker.join().unwrap();
    }
    assert_eq!(successful.load(Ordering::SeqCst), num_threads);
    println!(
        "PASS: {} write-only transactions all committed (last-writer-wins)",
        num_threads
    );
}
#[test]
fn test_large_transaction_many_keys() {
    // A transaction with a 10_000-key read set and a 100-key write set must
    // commit, apply exactly the written keys, and — newly verified here —
    // leave the read-only keys with their original values.
    let temp_dir = TempDir::new().unwrap();
    let db = DB::open(temp_dir.path()).unwrap();
    let num_keys = 10_000;
    for i in 0..num_keys {
        let key = format!("key_{:06}", i);
        let value = format!("value_{:06}", i);
        db.put(key.as_bytes(), value.as_bytes()).unwrap();
    }
    let mut txn = db.begin_transaction();
    // Read every key so the tracked read set is maximally large.
    for i in 0..num_keys {
        let key = format!("key_{:06}", i);
        let value = txn.get(key.as_bytes()).unwrap();
        assert!(value.is_some(), "Key {} should exist", key);
    }
    assert_eq!(txn.read_count(), num_keys);
    // Overwrite only the first 100 keys inside the transaction.
    for i in 0..100 {
        let key = format!("key_{:06}", i);
        let value = format!("modified_{:06}", i);
        txn.put(key.as_bytes(), value.as_bytes()).unwrap();
    }
    assert_eq!(txn.write_count(), 100);
    txn.commit().unwrap();
    for i in 0..100 {
        let key = format!("key_{:06}", i);
        let expected = format!("modified_{:06}", i);
        assert_eq!(db.get(key.as_bytes()).unwrap(), Some(Bytes::from(expected)));
    }
    // Keys that were only read must be untouched by the commit: the
    // original test never checked this, so a commit that wrote the whole
    // read set would have passed.
    for i in 100..num_keys {
        let key = format!("key_{:06}", i);
        let expected = format!("value_{:06}", i);
        assert_eq!(
            db.get(key.as_bytes()).unwrap(),
            Some(Bytes::from(expected)),
            "Read-only key {} should retain its original value",
            key
        );
    }
    println!(
        "PASS: Large transaction with {} reads and 100 writes committed successfully",
        num_keys
    );
}
#[test]
fn test_partial_key_overlap_conflict() {
    // Only key_b is both read by the transaction and modified externally,
    // so the conflict report must name exactly that one key — key_a (read,
    // unmodified) and key_c (written, never read) must not appear.
    let temp_dir = TempDir::new().unwrap();
    let db = DB::open(temp_dir.path()).unwrap();
    for (key, value) in [(&b"key_a"[..], &b"a"[..]), (b"key_b", b"b"), (b"key_c", b"c")] {
        db.put(key, value).unwrap();
    }
    let mut txn = db.begin_transaction();
    txn.get(b"key_a").unwrap();
    txn.get(b"key_b").unwrap();
    // External write races against the transaction's read of key_b.
    db.put(b"key_b", b"b_modified").unwrap();
    txn.put(b"key_c", b"c_from_txn").unwrap();
    let result = txn.commit();
    assert!(
        matches!(result, Err(DBError::TransactionConflict(ref c)) if c.conflicting_keys.len() == 1),
        "Expected exactly one conflict on key_b"
    );
    if let Err(DBError::TransactionConflict(conflict)) = result {
        assert_eq!(conflict.conflicting_keys[0], Bytes::from("key_b"));
        println!("PASS: Detected conflict on key_b as expected");
    }
}