use crate::group_commit::EventDrivenGroupCommit;
use crate::ssi::SsiManager;
use crate::txn_wal::TxnWal;
use dashmap::DashMap;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use sochdb_core::{Result, SochDBError};
/// Lifecycle state of a [`Transaction`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TxnState {
    /// Open: the state every new transaction starts in.
    Active,
    /// Distinct non-active state; no code visible in this module assigns it.
    Prepared,
    /// No code visible in this module assigns this state.
    Committed,
    /// No code visible in this module assigns this state.
    Aborted,
}

/// In-memory record of one open transaction: its identity plus the writes and
/// read keys buffered so far.
#[derive(Debug)]
pub struct Transaction {
    /// Transaction identifier.
    pub id: u64,
    /// Timestamp assigned when the transaction was created.
    pub start_ts: u64,
    /// Current lifecycle state.
    pub state: TxnState,
    /// Buffered `(key, value)` pairs in the order they were written.
    writes: Vec<(Vec<u8>, Vec<u8>)>,
    /// Keys recorded through [`Transaction::record_read`], in call order.
    reads: Vec<Vec<u8>>,
}

impl Transaction {
    /// Creates an `Active` transaction with empty write and read buffers.
    fn new(id: u64, start_ts: u64) -> Self {
        Self {
            id,
            start_ts,
            state: TxnState::Active,
            writes: Vec::new(),
            reads: Vec::new(),
        }
    }

    /// Appends a `(key, value)` pair to the write buffer.
    ///
    /// Duplicate keys are not collapsed; every call adds one entry.
    pub fn write(&mut self, key: Vec<u8>, value: Vec<u8>) {
        self.writes.push((key, value));
    }

    /// Records that `key` was read by this transaction.
    pub fn record_read(&mut self, key: Vec<u8>) {
        self.reads.push(key);
    }

    /// Returns the buffered writes in insertion order.
    pub fn writes(&self) -> &[(Vec<u8>, Vec<u8>)] {
        &self.writes
    }

    /// Returns the recorded read keys in insertion order.
    ///
    /// Counterpart of [`Transaction::writes`]; without it the read set filled
    /// by [`Transaction::record_read`] could never be observed.
    pub fn reads(&self) -> &[Vec<u8>] {
        &self.reads
    }
}
/// Transaction manager that buffers writes per transaction, sends them
/// through a `TxnWal`, and then hands them to a storage callback.
#[allow(clippy::type_complexity)]
pub struct WalStorageManager {
/// Write-ahead log handle; also exposed to callers through `wal()`.
wal: Arc<TxnWal>,
/// Open transactions keyed by the id returned from `TxnWal::begin_transaction`.
active_txns: RwLock<HashMap<u64, Transaction>>,
/// Logical clock starting at 1; bumped once per `begin_txn` and once per `commit`.
timestamp: AtomicU64,
/// Storage callback, invoked as `apply_fn(key, value)` for every write being applied.
apply_fn: Box<dyn Fn(&[u8], &[u8]) -> Result<()> + Send + Sync>,
}
impl WalStorageManager {
/// Builds a manager whose log is created by `TxnWal::new(wal_path)`.
///
/// `apply_fn` is stored and later called with `(key, value)` for each write
/// that is applied. The logical clock starts at 1.
///
/// # Errors
/// Propagates any error returned by `TxnWal::new`.
pub fn new<P: AsRef<Path>, F>(wal_path: P, apply_fn: F) -> Result<Self>
where
    F: Fn(&[u8], &[u8]) -> Result<()> + Send + Sync + 'static,
{
    let log = TxnWal::new(wal_path)?;
    let manager = Self {
        wal: Arc::new(log),
        active_txns: RwLock::new(HashMap::new()),
        timestamp: AtomicU64::new(1),
        apply_fn: Box::new(apply_fn),
    };
    Ok(manager)
}
/// Starts a transaction and returns its id.
///
/// The id comes from `TxnWal::begin_transaction`; the start timestamp is the
/// current value of the logical clock, which is then advanced by one.
///
/// # Errors
/// Propagates the error from `TxnWal::begin_transaction`.
pub fn begin_txn(&self) -> Result<u64> {
    let id = self.wal.begin_transaction()?;
    let started_at = self.timestamp.fetch_add(1, Ordering::SeqCst);
    self.active_txns
        .write()
        .insert(id, Transaction::new(id, started_at));
    Ok(id)
}
/// Buffers a write inside transaction `txn_id`.
///
/// Nothing reaches the WAL or the storage callback here; that happens in
/// `commit`.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not `Active`.
pub fn write(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<()> {
    let mut active = self.active_txns.write();
    match active.get_mut(&txn_id) {
        None => Err(SochDBError::InvalidArgument(
            "Transaction not found".into(),
        )),
        Some(txn) if txn.state != TxnState::Active => Err(SochDBError::InvalidArgument(
            "Transaction not active".into(),
        )),
        Some(txn) => {
            txn.write(key, value);
            Ok(())
        }
    }
}
/// Logs a write to the WAL and applies it to storage right away, bypassing
/// the transaction's write buffer.
///
/// The transaction is validated under a read lock that is released before
/// the WAL write. `abort` in this type does not undo a write applied this way.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not `Active`;
/// otherwise any error from the WAL write or from the storage callback.
pub fn write_immediate(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<()> {
    {
        let active = self.active_txns.read();
        let state = active
            .get(&txn_id)
            .map(|txn| txn.state)
            .ok_or_else(|| SochDBError::InvalidArgument("Transaction not found".into()))?;
        if state != TxnState::Active {
            return Err(SochDBError::InvalidArgument(
                "Transaction not active".into(),
            ));
        }
    }
    self.wal.write(txn_id, key.clone(), value.clone())?;
    (self.apply_fn)(&key, &value)
}
/// Commits transaction `txn_id` and returns its commit timestamp.
///
/// Order of operations: remove the transaction from the active set, log
/// every buffered write, log the commit, apply every buffered write through
/// the storage callback, then take the next logical timestamp.
///
/// The transaction is removed before validation, so it is gone from the
/// active set even when an error is returned.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not `Active`;
/// otherwise any WAL or storage-callback error.
pub fn commit(&self, txn_id: u64) -> Result<u64> {
    // The write guard is a temporary: the lock is released at the end of
    // this statement, before any WAL I/O.
    let removed = self.active_txns.write().remove(&txn_id);
    let Some(txn) = removed else {
        return Err(SochDBError::InvalidArgument(
            "Transaction not found".into(),
        ));
    };
    if txn.state != TxnState::Active {
        return Err(SochDBError::InvalidArgument(
            "Transaction not active".into(),
        ));
    }

    let buffered = txn.writes();
    for (key, value) in buffered {
        self.wal.write(txn_id, key.clone(), value.clone())?;
    }
    self.wal.commit_transaction(txn_id)?;
    for (key, value) in buffered {
        (self.apply_fn)(key, value)?;
    }
    Ok(self.timestamp.fetch_add(1, Ordering::SeqCst))
}
/// Aborts transaction `txn_id`, discarding its buffered writes.
///
/// The transaction is removed from the active set and an abort is logged
/// through `TxnWal::abort_transaction`. The active-set lock is held only for
/// the removal (as in `commit`), so WAL I/O does not block other
/// transactions.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or is neither `Active`
/// nor `Prepared`; otherwise any error from the WAL.
pub fn abort(&self, txn_id: u64) -> Result<()> {
    // The guard is a temporary and is dropped at the end of this statement.
    let txn = self
        .active_txns
        .write()
        .remove(&txn_id)
        .ok_or_else(|| SochDBError::InvalidArgument("Transaction not found".into()))?;
    if !matches!(txn.state, TxnState::Active | TxnState::Prepared) {
        return Err(SochDBError::InvalidArgument(
            "Transaction cannot be aborted".into(),
        ));
    }
    self.wal.abort_transaction(txn_id)?;
    Ok(())
}
/// Replays the WAL and feeds every returned write to the storage callback.
///
/// `TxnWal::replay_for_recovery` supplies the writes and a transaction
/// count; both are reported back in [`RecoveryStats`].
///
/// # Errors
/// Propagates replay errors and the first storage-callback error.
pub fn recover(&self) -> Result<RecoveryStats> {
    let (replayed, transactions_recovered) = self.wal.replay_for_recovery()?;
    let writes_applied = replayed.len();
    for (key, value) in &replayed {
        (self.apply_fn)(key, value)?;
    }
    Ok(RecoveryStats {
        transactions_recovered,
        writes_applied,
    })
}
/// Calls `TxnWal::write_checkpoint` and then `TxnWal::truncate`.
///
/// The truncate step is skipped if writing the checkpoint fails; either
/// error is propagated to the caller.
pub fn checkpoint(&self) -> Result<()> {
self.wal.write_checkpoint()?;
self.wal.truncate()?;
Ok(())
}
/// Returns the shared handle to the underlying write-ahead log.
pub fn wal(&self) -> &Arc<TxnWal> {
&self.wal
}
/// Returns the current value of the logical clock without advancing it.
pub fn current_timestamp(&self) -> u64 {
self.timestamp.load(Ordering::SeqCst)
}
}
/// Summary returned by the `recover` methods in this module.
#[derive(Debug, Clone, Default)]
pub struct RecoveryStats {
/// Transaction count reported by `TxnWal::replay_for_recovery`.
pub transactions_recovered: usize,
/// Number of replayed writes handed to the storage callback.
pub writes_applied: usize,
}
/// Isolation level requested when starting an MVCC transaction.
///
/// In the code visible in this file only `Serializable` is treated
/// specially (it routes reads, writes, commit and abort through the SSI
/// manager); the other two levels follow the same code path.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
ReadCommitted,
SnapshotIsolation,
Serializable,
}
/// State kept for one open MVCC transaction.
#[derive(Debug)]
pub struct MvccTransaction {
/// Identifier allocated by `MvccTransactionManager::begin`.
pub txn_id: u64,
/// Logical timestamp taken at `begin`; versions are judged visible against it.
pub snapshot_ts: u64,
/// Current status; set to `Active` at `begin`.
pub status: MvccTxnStatus,
/// Keys this transaction has read from the version store.
pub read_set: std::collections::HashSet<Vec<u8>>,
/// Buffered writes; a later write to the same key replaces the earlier one.
pub write_set: HashMap<Vec<u8>, Vec<u8>>,
/// Isolation level chosen at `begin`.
pub isolation_level: IsolationLevel,
}
/// Status of an [`MvccTransaction`].
///
/// Only `Active` is assigned by the code visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MvccTxnStatus {
Active,
// `Committed` carries a `u64` — presumably the commit timestamp; TODO confirm
// against whoever constructs this variant.
Committed(u64), Aborted,
}
/// One version of a value in an MVCC version chain.
#[derive(Debug, Clone)]
pub struct MvccVersion {
    /// Id of the transaction that created this version.
    pub xmin: u64,
    /// Id of the transaction that deleted this version; `0` means not deleted.
    pub xmax: u64,
    /// Timestamp supplied at creation.
    pub created_ts: u64,
    /// Timestamp supplied at deletion; `u64::MAX` until `mark_deleted` is called.
    pub deleted_ts: u64,
    /// The stored value.
    pub value: Vec<u8>,
}

impl MvccVersion {
    /// Creates a live (not deleted) version written by transaction `xmin`.
    pub fn new(xmin: u64, created_ts: u64, value: Vec<u8>) -> Self {
        Self {
            xmin,
            xmax: 0,
            created_ts,
            deleted_ts: u64::MAX,
            value,
        }
    }

    /// Marks this version as deleted by transaction `xmax` at `deleted_ts`.
    pub fn mark_deleted(&mut self, xmax: u64, deleted_ts: u64) {
        self.xmax = xmax;
        self.deleted_ts = deleted_ts;
    }

    /// Visibility check against a plain `HashMap` of `txn id -> commit ts`.
    ///
    /// See [`MvccVersion::is_visible_dashmap`] for the rules; both methods
    /// apply exactly the same logic.
    pub fn is_visible(
        &self,
        snapshot_ts: u64,
        txn_id: u64,
        committed_txns: &HashMap<u64, u64>,
    ) -> bool {
        self.visible_with(snapshot_ts, txn_id, |id| committed_txns.get(&id).copied())
    }

    /// Decides whether transaction `txn_id`, reading at `snapshot_ts`, sees
    /// this version.
    ///
    /// Rules, in order:
    /// 1. A version created by the reader is visible unless the reader also
    ///    deleted it.
    /// 2. Otherwise the creator must have a commit timestamp strictly below
    ///    `snapshot_ts`.
    /// 3. A version with `xmax == 0` is visible.
    /// 4. A version deleted by the reader is not visible.
    /// 5. A version deleted by someone else stays visible while that deleter
    ///    has no commit record, or committed at or after `snapshot_ts`.
    pub fn is_visible_dashmap(
        &self,
        snapshot_ts: u64,
        txn_id: u64,
        committed_txns: &DashMap<u64, u64>,
    ) -> bool {
        self.visible_with(snapshot_ts, txn_id, |id| {
            committed_txns.get(&id).map(|entry| *entry)
        })
    }

    /// Shared visibility logic; `commit_ts_of` maps a transaction id to its
    /// commit timestamp, or `None` when no commit record exists.
    fn visible_with(
        &self,
        snapshot_ts: u64,
        txn_id: u64,
        commit_ts_of: impl Fn(u64) -> Option<u64>,
    ) -> bool {
        if self.xmin == txn_id {
            return self.xmax != txn_id;
        }
        let creator_committed_before_snapshot =
            matches!(commit_ts_of(self.xmin), Some(ts) if ts < snapshot_ts);
        if !creator_committed_before_snapshot {
            return false;
        }
        if self.xmax == 0 {
            return true;
        }
        if self.xmax == txn_id {
            return false;
        }
        match commit_ts_of(self.xmax) {
            Some(deleter_ts) => deleter_ts >= snapshot_ts,
            None => true,
        }
    }
}
#[derive(Debug, Default)]
pub struct MvccVersionChain {
versions: Vec<MvccVersion>,
}
impl MvccVersionChain {
pub fn add(&mut self, version: MvccVersion) {
self.versions.insert(0, version);
}
pub fn get_visible(
&self,
snapshot_ts: u64,
txn_id: u64,
committed: &DashMap<u64, u64>,
) -> Option<&Vec<u8>> {
for v in &self.versions {
if v.is_visible_dashmap(snapshot_ts, txn_id, committed) {
return Some(&v.value);
}
}
None
}
pub fn get_visible_legacy(
&self,
snapshot_ts: u64,
txn_id: u64,
committed: &HashMap<u64, u64>,
) -> Option<&Vec<u8>> {
for v in &self.versions {
if v.is_visible(snapshot_ts, txn_id, committed) {
return Some(&v.value);
}
}
None
}
pub fn delete(&mut self, xmax: u64, deleted_ts: u64) -> bool {
if let Some(v) = self.versions.first_mut()
&& v.xmax == 0
{
v.mark_deleted(xmax, deleted_ts);
return true;
}
false
}
pub fn gc(&mut self, min_visible_ts: u64) -> usize {
let old_len = self.versions.len();
if old_len <= 1 {
return 0;
}
self.versions.retain(|v| v.deleted_ts >= min_visible_ts);
if self.versions.is_empty() {
return old_len;
}
old_len - self.versions.len()
}
}
/// MVCC transaction manager: per-transaction snapshots over in-memory
/// version chains, with writes logged to a `TxnWal` and committed through an
/// `EventDrivenGroupCommit`.
pub struct MvccTransactionManager {
/// Write-ahead log; a clone of this handle is captured by the group-commit callback.
wal: Arc<TxnWal>,
/// Source of transaction ids, starting at 1.
next_txn_id: AtomicU64,
/// Logical clock, starting at 1; supplies snapshot, version and delete timestamps.
timestamp: AtomicU64,
/// Open transactions keyed by id.
active_txns: RwLock<HashMap<u64, MvccTransaction>>,
/// Commit records (`txn id -> timestamp`) consulted by the visibility checks.
committed_txns: DashMap<u64, u64>,
/// Version chain per key.
versions: DashMap<Vec<u8>, MvccVersionChain>,
/// Conflict tracker used for `IsolationLevel::Serializable` transactions.
ssi_manager: SsiManager,
/// Batches WAL commit records for transactions submitted from `commit`.
group_commit: EventDrivenGroupCommit,
/// Smallest `snapshot_ts` among open transactions, or `u64::MAX` when none are open.
min_snapshot_ts: AtomicU64,
/// Storage callback, invoked as `apply_fn(key, value)` for committed and recovered writes.
#[allow(clippy::type_complexity)]
apply_fn: Box<dyn Fn(&[u8], &[u8]) -> Result<()> + Send + Sync>,
}
impl MvccTransactionManager {
pub fn new<P: AsRef<Path>, F>(wal_path: P, apply_fn: F) -> Result<Self>
where
F: Fn(&[u8], &[u8]) -> Result<()> + Send + Sync + 'static,
{
let wal = Arc::new(TxnWal::new(wal_path)?);
let wal_for_gc = wal.clone();
let group_commit = EventDrivenGroupCommit::new(move |txn_ids: &[u64]| {
for &txn_id in txn_ids {
wal_for_gc
.commit_transaction(txn_id)
.map_err(|e| e.to_string())?;
}
let commit_ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_micros() as u64;
Ok(commit_ts)
});
Ok(Self {
wal,
next_txn_id: AtomicU64::new(1),
timestamp: AtomicU64::new(1),
active_txns: RwLock::new(HashMap::new()),
committed_txns: DashMap::new(),
versions: DashMap::new(),
ssi_manager: SsiManager::new(),
group_commit,
min_snapshot_ts: AtomicU64::new(u64::MAX),
apply_fn: Box::new(apply_fn),
})
}
/// Starts a transaction at `isolation_level` and returns its id.
///
/// The snapshot timestamp is the current value of the logical clock, which
/// is then advanced. The transaction is registered as active and the
/// minimum-snapshot watermark is refreshed.
///
/// # Errors
/// Propagates a failure of `TxnWal::begin_transaction`. Previously that
/// error was discarded, leaving a transaction registered that the WAL had
/// never been told about.
pub fn begin(&self, isolation_level: IsolationLevel) -> Result<u64> {
    let txn_id = self.next_txn_id.fetch_add(1, Ordering::SeqCst);
    let snapshot_ts = self.timestamp.fetch_add(1, Ordering::SeqCst);

    // NOTE(review): the id handed back by the WAL is not used; this manager
    // keys everything (including later WAL calls) by `next_txn_id`. Confirm
    // the two id sequences are meant to line up.
    let _wal_txn_id = self.wal.begin_transaction()?;

    let txn = MvccTransaction {
        txn_id,
        snapshot_ts,
        status: MvccTxnStatus::Active,
        read_set: std::collections::HashSet::new(),
        write_set: HashMap::new(),
        isolation_level,
    };
    self.active_txns.write().insert(txn_id, txn);
    self.update_min_snapshot();

    if isolation_level == IsolationLevel::Serializable {
        // NOTE(review): an SSI registration failure is still ignored here,
        // as in the original code — confirm this best-effort is intended.
        self.ssi_manager.begin().ok();
    }
    Ok(txn_id)
}
/// Starts a transaction at `IsolationLevel::SnapshotIsolation`.
pub fn begin_default(&self) -> Result<u64> {
self.begin(IsolationLevel::SnapshotIsolation)
}
/// Reads `key` as seen by transaction `txn_id`.
///
/// A value buffered in the transaction's own write set wins and is returned
/// without touching the read set. Otherwise the key is added to the read
/// set, reported to the SSI manager for `Serializable` transactions, and
/// resolved against the key's version chain at the transaction's snapshot.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not active;
/// `Internal` if the SSI manager rejects the read.
pub fn read(&self, txn_id: u64, key: &[u8]) -> Result<Option<Vec<u8>>> {
    // Everything that needs the transaction entry happens inside this
    // scope, so the lock is released before SSI and version lookups.
    let (snapshot_ts, isolation) = {
        let mut active = self.active_txns.write();
        let txn = active
            .get_mut(&txn_id)
            .ok_or_else(|| SochDBError::InvalidArgument("Transaction not found".into()))?;
        if txn.status != MvccTxnStatus::Active {
            return Err(SochDBError::InvalidArgument(
                "Transaction not active".into(),
            ));
        }
        if let Some(own_write) = txn.write_set.get(key) {
            return Ok(Some(own_write.clone()));
        }
        txn.read_set.insert(key.to_vec());
        (txn.snapshot_ts, txn.isolation_level)
    };

    if isolation == IsolationLevel::Serializable {
        self.ssi_manager
            .record_read(txn_id, key)
            .map_err(|e| SochDBError::Internal(format!("SSI conflict: {}", e.message)))?;
    }

    let visible = self.versions.get(key).and_then(|chain| {
        chain
            .get_visible(snapshot_ts, txn_id, &self.committed_txns)
            .cloned()
    });
    Ok(visible)
}
/// Buffers a write in transaction `txn_id`'s write set.
///
/// For `Serializable` transactions the write is first reported to the SSI
/// manager; if that fails the write set is left unchanged. A later write to
/// the same key replaces the earlier value.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not active;
/// `Internal` if the SSI manager rejects the write.
pub fn write(&self, txn_id: u64, key: Vec<u8>, value: Vec<u8>) -> Result<()> {
    let mut active = self.active_txns.write();
    let Some(txn) = active.get_mut(&txn_id) else {
        return Err(SochDBError::InvalidArgument(
            "Transaction not found".into(),
        ));
    };
    if txn.status != MvccTxnStatus::Active {
        return Err(SochDBError::InvalidArgument(
            "Transaction not active".into(),
        ));
    }
    let needs_ssi = txn.isolation_level == IsolationLevel::Serializable;
    if needs_ssi {
        self.ssi_manager
            .record_write(txn_id, &key)
            .map_err(|e| SochDBError::Internal(format!("SSI conflict: {}", e.message)))?;
    }
    txn.write_set.insert(key, value);
    Ok(())
}
pub fn commit(&self, txn_id: u64) -> Result<u64> {
let txn = {
let mut txns = self.active_txns.write();
txns.remove(&txn_id)
.ok_or_else(|| SochDBError::InvalidArgument("Transaction not found".into()))?
};
if txn.status != MvccTxnStatus::Active {
return Err(SochDBError::InvalidArgument(
"Transaction not active".into(),
));
}
if txn.isolation_level == IsolationLevel::Serializable {
self.ssi_manager
.commit(txn_id)
.map_err(|e| SochDBError::Internal(format!("SSI conflict: {}", e.message)))?;
}
for (key, value) in &txn.write_set {
self.wal.write(txn_id, key.clone(), value.clone())?;
}
let commit_ts = self
.group_commit
.submit_and_wait(txn_id)
.map_err(|e| SochDBError::Internal(format!("Group commit error: {}", e)))?;
let apply_ts = self.timestamp.fetch_add(1, Ordering::SeqCst);
for (key, value) in &txn.write_set {
self.versions
.entry(key.clone())
.or_default()
.add(MvccVersion::new(txn_id, apply_ts, value.clone()));
}
for (key, value) in &txn.write_set {
(self.apply_fn)(key, value)?;
}
self.committed_txns.insert(txn_id, commit_ts);
self.update_min_snapshot();
Ok(commit_ts)
}
/// Aborts transaction `txn_id`, discarding its buffered write set.
///
/// The transaction is removed from the active set up front, so the
/// in-memory cleanup (SSI state, minimum-snapshot watermark) is performed
/// even when logging the abort to the WAL fails; the WAL error is still
/// returned to the caller.
///
/// Version-chain entries marked through `delete` are not touched here.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not active; otherwise
/// any error from `TxnWal::abort_transaction`.
pub fn abort(&self, txn_id: u64) -> Result<()> {
    let txn = self
        .active_txns
        .write()
        .remove(&txn_id)
        .ok_or_else(|| SochDBError::InvalidArgument("Transaction not found".into()))?;
    if txn.status != MvccTxnStatus::Active {
        self.update_min_snapshot();
        return Err(SochDBError::InvalidArgument(
            "Transaction not active".into(),
        ));
    }

    // Defer the `?` so a WAL failure cannot skip the cleanup below.
    let wal_outcome = self.wal.abort_transaction(txn_id);
    if txn.isolation_level == IsolationLevel::Serializable {
        self.ssi_manager.abort(txn_id);
    }
    self.update_min_snapshot();

    wal_outcome?;
    Ok(())
}
/// Marks the newest version of `key` as deleted by transaction `txn_id`.
///
/// Returns `Ok(true)` when a version was marked, `Ok(false)` when the key
/// has no chain or its newest version already carries a deleter.
///
/// The mark is applied to the shared version chain immediately and is not
/// part of the transaction's write set: this method writes nothing to the
/// WAL, does not call the storage callback, and consumes one logical
/// timestamp even when nothing is marked.
///
/// # Errors
/// `InvalidArgument` if the transaction is unknown or not active.
pub fn delete(&self, txn_id: u64, key: &[u8]) -> Result<bool> {
    // The read guard is a temporary; only the `Copy` status leaves it.
    let status = self
        .active_txns
        .read()
        .get(&txn_id)
        .map(|txn| txn.status)
        .ok_or_else(|| SochDBError::InvalidArgument("Transaction not found".into()))?;
    if status != MvccTxnStatus::Active {
        return Err(SochDBError::InvalidArgument(
            "Transaction not active".into(),
        ));
    }

    let deleted_ts = self.timestamp.fetch_add(1, Ordering::SeqCst);
    let marked = match self.versions.get_mut(key) {
        Some(mut chain) => chain.delete(txn_id, deleted_ts),
        None => false,
    };
    Ok(marked)
}
/// Garbage-collects old versions and commit records; returns the number of
/// versions removed plus whatever `SsiManager::gc` reports.
///
/// The horizon is the minimum snapshot of all open transactions
/// (`u64::MAX` when none are open).
///
/// A commit record is dropped only when it is older than the horizon *and*
/// no retained version still names that transaction as creator (`xmin`) or
/// deleter (`xmax`). Visibility checks look both ids up in `committed_txns`;
/// dropping a record that is still referenced would make a live version
/// invisible, or make a deleted one reappear.
///
/// NOTE(review): a commit that installs its versions between the chain scan
/// and the record pruning below is not covered by the reference set; confirm
/// whether `gc` is expected to run concurrently with `commit`.
pub fn gc(&self) -> usize {
    let min_ts = self.min_snapshot_ts.load(Ordering::SeqCst);
    let mut total_gc = 0;
    let mut referenced = std::collections::HashSet::new();

    for mut entry in self.versions.iter_mut() {
        let chain = entry.value_mut();
        total_gc += chain.gc(min_ts);
        // Remember every transaction the surviving versions still depend on.
        for version in chain.versions.iter() {
            referenced.insert(version.xmin);
            if version.xmax != 0 {
                referenced.insert(version.xmax);
            }
        }
    }

    self.committed_txns
        .retain(|id, ts| *ts >= min_ts || referenced.contains(id));
    total_gc += self.ssi_manager.gc(min_ts);
    total_gc
}
/// Recomputes the smallest snapshot timestamp among open transactions and
/// stores it in `min_snapshot_ts`; `u64::MAX` is stored when none are open.
fn update_min_snapshot(&self) {
    // The read guard stays alive until after the store, as before.
    let active = self.active_txns.read();
    let oldest = active
        .values()
        .fold(u64::MAX, |lowest, txn| lowest.min(txn.snapshot_ts));
    self.min_snapshot_ts.store(oldest, Ordering::SeqCst);
}
/// Replays the WAL and feeds every returned write to the storage callback.
///
/// Only the storage callback is driven here; the in-memory version chains
/// and commit records of this manager are not touched.
///
/// # Errors
/// Propagates replay errors and the first storage-callback error.
pub fn recover(&self) -> Result<RecoveryStats> {
    let (replayed, transactions_recovered) = self.wal.replay_for_recovery()?;
    let writes_applied = replayed.len();
    for (key, value) in &replayed {
        (self.apply_fn)(key, value)?;
    }
    Ok(RecoveryStats {
        transactions_recovered,
        writes_applied,
    })
}
/// Returns the current value of the logical clock without advancing it.
pub fn current_timestamp(&self) -> u64 {
self.timestamp.load(Ordering::SeqCst)
}
/// Returns the number of transactions currently in the active set.
pub fn active_count(&self) -> usize {
self.active_txns.read().len()
}
}
/// Queue of transactions waiting for a shared flush, with a batch-size
/// target that adapts to the observed arrival rate and fsync latency.
pub struct GroupCommitBuffer {
/// Commits queued since the last `take_pending`.
pending: RwLock<Vec<PendingCommit>>,
/// Upper bound applied to the flush threshold.
max_pending: usize,
/// Time since the last flush, in microseconds, after which `add` asks for a flush.
max_wait_us: u64,
/// Wall-clock time (microseconds since the Unix epoch) of the last `take_pending`; 0 initially.
last_flush: AtomicU64,
/// Smoothed arrival rate in commits per second, stored multiplied by 1000.
arrival_rate_ema: AtomicU64,
/// Wall-clock time (microseconds) of the most recent `add`; 0 before the first one.
last_arrival: AtomicU64,
/// Smoothed fsync latency in microseconds.
fsync_latency_us: AtomicU64,
/// Batch-size target last computed by `recompute_batch_size`.
adaptive_batch_size: AtomicU64,
}
/// A transaction queued in a [`GroupCommitBuffer`].
#[derive(Debug, Clone)]
pub struct PendingCommit {
/// Id of the queued transaction.
pub txn_id: u64,
/// Wall-clock enqueue time in microseconds since the Unix epoch.
pub enqueue_time_us: u64,
}
impl GroupCommitBuffer {
/// Creates an empty buffer.
///
/// `max_pending` caps the flush threshold and sizes the initial queue
/// allocation; `max_wait_us` is the flush deadline in microseconds. The
/// adaptive state starts from fixed seed values until measurements arrive.
pub fn new(max_pending: usize, max_wait_us: u64) -> Self {
    // 100 commits/s in the x1000 fixed-point encoding of `arrival_rate_ema`.
    const SEED_ARRIVAL_RATE_MILLI: u64 = 100_000;
    const SEED_FSYNC_LATENCY_US: u64 = 5000;
    const SEED_BATCH_SIZE: u64 = 10;

    Self {
        pending: RwLock::new(Vec::with_capacity(max_pending)),
        max_pending,
        max_wait_us,
        last_flush: AtomicU64::new(0),
        arrival_rate_ema: AtomicU64::new(SEED_ARRIVAL_RATE_MILLI),
        last_arrival: AtomicU64::new(0),
        fsync_latency_us: AtomicU64::new(SEED_FSYNC_LATENCY_US),
        adaptive_batch_size: AtomicU64::new(SEED_BATCH_SIZE),
    }
}
/// Creates a buffer like [`GroupCommitBuffer::new`], but seeds the fsync
/// latency estimate with `fsync_latency_us` and recomputes the batch-size
/// target from it.
pub fn with_fsync_latency(max_pending: usize, max_wait_us: u64, fsync_latency_us: u64) -> Self {
    let tuned = Self::new(max_pending, max_wait_us);
    tuned
        .fsync_latency_us
        .store(fsync_latency_us, Ordering::Relaxed);
    tuned.recompute_batch_size();
    tuned
}
/// Returns the wall-clock time in microseconds since the Unix epoch.
///
/// Yields 0 instead of panicking when the system clock reports a time
/// before the epoch.
fn now_us() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_micros() as u64)
}
/// Folds the gap since the previous arrival into the arrival-rate average.
///
/// The instantaneous rate is `1_000_000_000 / gap_us`, i.e. commits per
/// second multiplied by 1000, blended 9:1 with the stored average. The
/// first arrival and zero-length gaps leave the average unchanged.
fn update_arrival_rate(&self) {
    let now = Self::now_us();
    let previous = self.last_arrival.swap(now, Ordering::Relaxed);
    if previous == 0 {
        return;
    }
    let gap_us = now.saturating_sub(previous);
    if gap_us == 0 {
        return;
    }
    let sample = 1_000_000_000 / gap_us;
    let smoothed = (self.arrival_rate_ema.load(Ordering::Relaxed) * 9 + sample) / 10;
    self.arrival_rate_ema.store(smoothed, Ordering::Relaxed);
}
/// Recomputes the batch-size target from the current estimates.
///
/// Uses `n = sqrt(2 * L_fsync * lambda / c_wait)` with `lambda` in commits
/// per second, `L_fsync` in seconds and `c_wait` fixed at 1. The result is
/// limited to `1..=max_pending`; the upper bound is never allowed below 1
/// because `f64::clamp` panics when its minimum exceeds its maximum, which
/// is what a `max_pending` of 0 used to trigger.
fn recompute_batch_size(&self) {
    let lambda = self.arrival_rate_ema.load(Ordering::Relaxed) as f64 / 1000.0;
    let l_fsync_s = self.fsync_latency_us.load(Ordering::Relaxed) as f64 / 1_000_000.0;
    let c_wait = 1.0;

    let n_opt = (2.0 * l_fsync_s * lambda / c_wait).sqrt();
    let upper_bound = self.max_pending.max(1) as f64;
    let batch_size = n_opt.clamp(1.0, upper_bound) as u64;
    self.adaptive_batch_size
        .store(batch_size, Ordering::Relaxed);
}
/// Queues `txn_id` and reports whether the caller should flush now.
///
/// Returns `true` when the queue has reached the batch-size target (at
/// least 1, capped by `max_pending`) or when more than `max_wait_us`
/// microseconds have passed since the last flush.
///
/// The elapsed time is computed with a saturating subtraction. `now` is
/// sampled before the queue lock is taken, so a concurrent `take_pending`
/// (or a wall-clock step backwards) can leave `last_flush` later than
/// `now`; a plain subtraction would then underflow.
pub fn add(&self, txn_id: u64) -> bool {
    self.update_arrival_rate();
    let now = Self::now_us();

    let mut pending = self.pending.write();
    pending.push(PendingCommit {
        txn_id,
        enqueue_time_us: now,
    });

    let adaptive_size = self.adaptive_batch_size.load(Ordering::Relaxed) as usize;
    let target_size = adaptive_size.max(1).min(self.max_pending);
    if pending.len() >= target_size {
        return true;
    }

    let last = self.last_flush.load(Ordering::Relaxed);
    now.saturating_sub(last) > self.max_wait_us
}
/// Removes and returns every queued commit in arrival order.
///
/// Also stamps `last_flush` with the current time and recomputes the
/// batch-size target, all while the queue lock is held.
pub fn take_pending(&self) -> Vec<PendingCommit> {
    let mut queue = self.pending.write();
    let drained = std::mem::take(&mut *queue);
    self.last_flush.store(Self::now_us(), Ordering::Relaxed);
    self.recompute_batch_size();
    drained
}
/// Blends a measured fsync latency (microseconds) into the running estimate
/// with weight 1/5, then recomputes the batch-size target.
pub fn record_fsync_latency(&self, latency_us: u64) {
    let previous = self.fsync_latency_us.load(Ordering::Relaxed);
    let blended = (previous * 4 + latency_us) / 5;
    self.fsync_latency_us.store(blended, Ordering::Relaxed);
    self.recompute_batch_size();
}
/// Returns the batch-size target most recently computed.
pub fn current_batch_size(&self) -> usize {
self.adaptive_batch_size.load(Ordering::Relaxed) as usize
}
/// Returns the smoothed arrival rate, undoing the x1000 scaling used for storage.
pub fn current_arrival_rate(&self) -> f64 {
self.arrival_rate_ema.load(Ordering::Relaxed) as f64 / 1000.0
}
/// Returns a point-in-time snapshot of the buffer's adaptive state.
///
/// Each field is read independently, so the snapshot is not atomic across
/// fields.
pub fn stats(&self) -> GroupCommitStats {
    let adaptive_batch_size = self.adaptive_batch_size.load(Ordering::Relaxed) as usize;
    let arrival_rate = self.current_arrival_rate();
    let fsync_latency_us = self.fsync_latency_us.load(Ordering::Relaxed);
    let pending_count = self.pending.read().len();
    GroupCommitStats {
        adaptive_batch_size,
        arrival_rate,
        fsync_latency_us,
        pending_count,
    }
}
}
/// Snapshot produced by `GroupCommitBuffer::stats`.
#[derive(Debug, Clone)]
pub struct GroupCommitStats {
/// Current batch-size target.
pub adaptive_batch_size: usize,
/// Smoothed arrival rate as returned by `current_arrival_rate`.
pub arrival_rate: f64,
/// Smoothed fsync latency in microseconds.
pub fsync_latency_us: u64,
/// Number of commits queued at the time of the snapshot.
pub pending_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use tempfile::tempdir;

    /// Buffered writes reach the storage callback only at commit, in order.
    #[test]
    fn test_basic_transaction() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let writes = Arc::new(RwLock::new(Vec::new()));
        let writes_clone = writes.clone();
        let manager = WalStorageManager::new(wal_path, move |k, v| {
            writes_clone.write().push((k.to_vec(), v.to_vec()));
            Ok(())
        })
        .unwrap();
        let txn_id = manager.begin_txn().unwrap();
        manager
            .write(txn_id, b"key1".to_vec(), b"value1".to_vec())
            .unwrap();
        manager
            .write(txn_id, b"key2".to_vec(), b"value2".to_vec())
            .unwrap();
        // Nothing is applied before commit.
        assert!(writes.read().is_empty());
        manager.commit(txn_id).unwrap();
        let applied = writes.read();
        assert_eq!(applied.len(), 2);
        assert_eq!(applied[0], (b"key1".to_vec(), b"value1".to_vec()));
        assert_eq!(applied[1], (b"key2".to_vec(), b"value2".to_vec()));
    }

    /// An aborted transaction never reaches the storage callback.
    #[test]
    fn test_abort_transaction() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let writes = Arc::new(RwLock::new(Vec::new()));
        let writes_clone = writes.clone();
        let manager = WalStorageManager::new(wal_path, move |k, v| {
            writes_clone.write().push((k.to_vec(), v.to_vec()));
            Ok(())
        })
        .unwrap();
        let txn_id = manager.begin_txn().unwrap();
        manager
            .write(txn_id, b"key1".to_vec(), b"value1".to_vec())
            .unwrap();
        manager.abort(txn_id).unwrap();
        assert!(writes.read().is_empty());
    }

    /// `write_immediate` applies the write before commit.
    #[test]
    fn test_immediate_write() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("test.wal");
        let write_count = Arc::new(AtomicUsize::new(0));
        let count_clone = write_count.clone();
        let manager = WalStorageManager::new(wal_path, move |_, _| {
            count_clone.fetch_add(1, Ordering::SeqCst);
            Ok(())
        })
        .unwrap();
        let txn_id = manager.begin_txn().unwrap();
        manager
            .write_immediate(txn_id, b"key1".to_vec(), b"value1".to_vec())
            .unwrap();
        assert_eq!(write_count.load(Ordering::SeqCst), 1);
        manager.commit(txn_id).unwrap();
    }

    /// Queued commits come back from `take_pending` in arrival order.
    #[test]
    fn test_group_commit_buffer() {
        let buffer = GroupCommitBuffer::with_fsync_latency(10, 1000, 5000);
        // Stamp `last_flush` so the deadline is measured from now.
        let _ = buffer.take_pending();
        buffer.add(1);
        buffer.add(2);
        buffer.add(3);
        let pending = buffer.take_pending();
        assert_eq!(pending.len(), 3);
        assert_eq!(pending[0].txn_id, 1);
        assert_eq!(pending[1].txn_id, 2);
        assert_eq!(pending[2].txn_id, 3);
    }

    /// The adaptive batch size never drops below one.
    #[test]
    fn test_adaptive_batch_sizing() {
        let buffer = GroupCommitBuffer::with_fsync_latency(100, 10000, 5000);
        for i in 0..50 {
            buffer.add(i);
            std::thread::sleep(std::time::Duration::from_micros(100));
        }
        let stats = buffer.stats();
        assert!(stats.adaptive_batch_size >= 1);
    }

    /// A transaction reads its own buffered write, and commit applies it.
    #[test]
    fn test_mvcc_basic_transaction() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("mvcc_test.wal");
        let writes = Arc::new(RwLock::new(Vec::new()));
        let writes_clone = writes.clone();
        let manager = MvccTransactionManager::new(wal_path, move |k, v| {
            writes_clone.write().push((k.to_vec(), v.to_vec()));
            Ok(())
        })
        .unwrap();
        let txn_id = manager.begin_default().unwrap();
        manager
            .write(txn_id, b"key1".to_vec(), b"value1".to_vec())
            .unwrap();
        let value = manager.read(txn_id, b"key1").unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));
        let commit_ts = manager.commit(txn_id).unwrap();
        assert!(commit_ts > 0);
        assert_eq!(writes.read().len(), 1);
    }

    /// Interleaves a reader with a concurrent writer; only checks that every
    /// call succeeds (the value read by `txn2` is not asserted).
    #[test]
    fn test_mvcc_snapshot_isolation() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("mvcc_si_test.wal");
        let manager = MvccTransactionManager::new(wal_path, |_, _| Ok(())).unwrap();
        let txn1 = manager.begin_default().unwrap();
        manager
            .write(txn1, b"key1".to_vec(), b"v1".to_vec())
            .unwrap();
        manager.commit(txn1).unwrap();
        let txn2 = manager.begin_default().unwrap();
        let txn3 = manager.begin_default().unwrap();
        manager
            .write(txn3, b"key1".to_vec(), b"v3".to_vec())
            .unwrap();
        manager.commit(txn3).unwrap();
        let _value = manager.read(txn2, b"key1").unwrap();
        manager.commit(txn2).unwrap();
    }

    /// An aborted MVCC transaction never reaches the storage callback.
    #[test]
    fn test_mvcc_abort() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("mvcc_abort_test.wal");
        let writes = Arc::new(RwLock::new(Vec::new()));
        let writes_clone = writes.clone();
        let manager = MvccTransactionManager::new(wal_path, move |k, v| {
            writes_clone.write().push((k.to_vec(), v.to_vec()));
            Ok(())
        })
        .unwrap();
        let txn_id = manager.begin_default().unwrap();
        manager
            .write(txn_id, b"key1".to_vec(), b"value1".to_vec())
            .unwrap();
        manager.abort(txn_id).unwrap();
        assert!(writes.read().is_empty());
    }

    /// A snapshot sees the newest version whose creator committed before it.
    #[test]
    fn test_mvcc_version_visibility() {
        let mut chain = MvccVersionChain::default();
        let committed: HashMap<u64, u64> = [(1, 10), (2, 20)].into_iter().collect();
        chain.add(MvccVersion::new(1, 5, b"v1".to_vec()));
        chain.add(MvccVersion::new(2, 15, b"v2".to_vec()));
        // Snapshot 15: txn 2 committed at 20, so only v1 is visible.
        let visible = chain.get_visible_legacy(15, 99, &committed);
        assert_eq!(visible, Some(&b"v1".to_vec()));
        // Snapshot 25: both committed; the newest (v2) wins.
        let visible = chain.get_visible_legacy(25, 99, &committed);
        assert_eq!(visible, Some(&b"v2".to_vec()));
    }

    /// GC drops exactly the versions deleted before the horizon.
    #[test]
    fn test_mvcc_version_gc() {
        let mut chain = MvccVersionChain::default();
        for i in 0..5 {
            let mut version = MvccVersion::new(i, i * 10, vec![i as u8]);
            if i < 4 {
                version.mark_deleted(i + 1, (i + 1) * 10);
            }
            chain.add(version);
        }
        assert_eq!(chain.versions.len(), 5);
        // Versions 0..=3 were deleted at 10, 20, 30 and 40, all below the
        // horizon of 45; version 4 was never deleted and must survive.
        let gc_count = chain.gc(45);
        assert_eq!(gc_count, 4);
        assert_eq!(chain.versions.len(), 1);
    }

    /// Concurrent begin/write/commit from several threads leaves no
    /// transaction behind in the active set.
    #[test]
    fn test_mvcc_concurrent_transactions() {
        let dir = tempdir().unwrap();
        let wal_path = dir.path().join("mvcc_concurrent_test.wal");
        let manager = Arc::new(MvccTransactionManager::new(wal_path, |_, _| Ok(())).unwrap());
        let handles: Vec<_> = (0..4)
            .map(|i| {
                let m = manager.clone();
                std::thread::spawn(move || {
                    let txn = m.begin_default().unwrap();
                    m.write(
                        txn,
                        format!("key{}", i).into_bytes(),
                        format!("value{}", i).into_bytes(),
                    )
                    .unwrap();
                    m.commit(txn).unwrap();
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(manager.active_count(), 0);
    }
}