use aegis_common::{TransactionId, Result, AegisError};
use parking_lot::RwLock;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
/// Lifecycle phase of a [`Transaction`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransactionState {
    /// Open: reads and writes are accepted.
    Active,
    /// Commit in progress: a commit timestamp has been assigned and the transaction's
    /// versions are being stamped with it.
    Preparing,
    /// Commit finished.
    Committed,
    /// Rolled back.
    Aborted,
}
/// Isolation level a transaction is started with; defaults to `RepeatableRead`.
///
/// NOTE(review): in this module every level reads from the snapshot taken at `begin`;
/// only `Serializable` adds extra validation at commit time.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum IsolationLevel {
    ReadUncommitted,
    ReadCommitted,
    #[default]
    RepeatableRead,
    Serializable,
}
/// Bookkeeping for one transaction tracked by the transaction manager.
#[derive(Debug)]
pub struct Transaction {
    /// Unique transaction id.
    pub id: TransactionId,
    /// Current lifecycle phase.
    pub state: TransactionState,
    /// Isolation level requested at `begin`.
    pub isolation_level: IsolationLevel,
    /// Timestamp drawn when the transaction began (same counter as commit timestamps).
    pub start_timestamp: u64,
    /// Commit timestamp, assigned once committing starts.
    pub commit_timestamp: Option<u64>,
    /// Read view fixed at `begin`.
    pub snapshot: Snapshot,
    /// Keys this transaction has written.
    pub write_set: HashSet<VersionKey>,
    /// Keys recorded as read; consulted by serializable commit validation.
    pub read_set: HashSet<VersionKey>,
    /// Lock requests recorded for this transaction. Locks are released through the lock
    /// manager by transaction id, not through this list.
    pub locks_held: Vec<LockRequest>,
    /// Creation instant, used by `duration()`.
    pub started_at: Instant,
}
impl Transaction {
pub fn new(
id: TransactionId,
isolation_level: IsolationLevel,
start_timestamp: u64,
active_transactions: HashSet<TransactionId>,
) -> Self {
Self {
id,
state: TransactionState::Active,
isolation_level,
start_timestamp,
commit_timestamp: None,
snapshot: Snapshot {
timestamp: start_timestamp,
active_transactions,
},
write_set: HashSet::new(),
read_set: HashSet::new(),
locks_held: Vec::new(),
started_at: Instant::now(),
}
}
pub fn is_active(&self) -> bool {
self.state == TransactionState::Active
}
pub fn duration(&self) -> Duration {
self.started_at.elapsed()
}
pub fn add_to_write_set(&mut self, key: VersionKey) {
self.write_set.insert(key);
}
pub fn add_to_read_set(&mut self, key: VersionKey) {
self.read_set.insert(key);
}
}
/// Read view of a transaction: a timestamp plus the transactions in flight at that moment.
#[derive(Debug, Clone)]
pub struct Snapshot {
    /// Versions committed after this timestamp are not visible.
    pub timestamp: u64,
    /// Transactions whose versions stay invisible regardless of their commit timestamp.
    pub active_transactions: HashSet<TransactionId>,
}
impl Snapshot {
    /// Returns whether `version` belongs to this snapshot's view.
    ///
    /// Only committed versions qualify. They must have committed at or before
    /// `self.timestamp`, and their creator must not be in `active_transactions`.
    /// `Active` and `Aborted` versions are never visible through a snapshot.
    pub fn is_visible(&self, version: &Version) -> bool {
        let commit_ts = match version.state {
            VersionState::Committed(ts) => ts,
            VersionState::Active | VersionState::Aborted => return false,
        };
        let committed_in_time = commit_ts <= self.timestamp;
        let creator_in_flight = self.active_transactions.contains(&version.created_by);
        committed_in_time && !creator_in_flight
    }
}
/// Identifies one row; used as the key of both the version store and the lock table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct VersionKey {
    /// Table the row belongs to.
    pub table_id: u32,
    /// Row within that table.
    pub row_id: u64,
}
/// Commit status of a single row version.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VersionState {
    /// Written by a transaction that has not finished yet.
    Active,
    /// Committed; carries the commit timestamp.
    Committed(u64),
    /// Written by a transaction that rolled back.
    Aborted,
}
/// One version of a row; older versions of the same key hang off `prev_version`.
#[derive(Debug, Clone)]
pub struct Version {
    /// Row this version belongs to.
    pub key: VersionKey,
    /// Transaction that wrote this version.
    pub created_by: TransactionId,
    /// Commit status.
    pub state: VersionState,
    /// Row payload.
    pub data: Vec<u8>,
    /// Next-older version of the same key, if any (newest-first chain).
    pub prev_version: Option<Box<Version>>,
}
impl Version {
pub fn new(key: VersionKey, created_by: TransactionId, data: Vec<u8>) -> Self {
Self {
key,
created_by,
state: VersionState::Active,
data,
prev_version: None,
}
}
pub fn commit(&mut self, commit_timestamp: u64) {
self.state = VersionState::Committed(commit_timestamp);
}
pub fn abort(&mut self) {
self.state = VersionState::Aborted;
}
}
/// Lock modes understood by the lock manager.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LockMode {
    /// Read lock.
    Shared,
    /// Write lock.
    Exclusive,
    /// Intention to take shared locks at a finer granularity.
    IntentShared,
    /// Intention to take exclusive locks at a finer granularity.
    IntentExclusive,
    /// Update lock.
    Update,
}
impl LockMode {
pub fn is_compatible(&self, other: &LockMode) -> bool {
use LockMode::*;
matches!(
(self, other),
(Shared, Shared)
| (Shared, IntentShared)
| (IntentShared, Shared)
| (IntentShared, IntentShared)
| (IntentShared, IntentExclusive)
| (IntentExclusive, IntentShared)
| (IntentExclusive, IntentExclusive)
)
}
}
/// A request by one transaction for a lock on one row.
#[derive(Debug, Clone)]
pub struct LockRequest {
    /// Requesting transaction.
    pub tx_id: TransactionId,
    /// Row to lock.
    pub key: VersionKey,
    /// Requested mode.
    pub mode: LockMode,
    /// `true` on copies stored as holders by the lock manager; callers in this file pass
    /// `false`.
    pub granted: bool,
}
/// Lock state of a single row inside the lock manager.
#[derive(Debug, Default)]
struct LockEntry {
    /// Requests currently granted on the row.
    holders: Vec<LockRequest>,
    /// Requests waiting for the row, oldest first.
    waiters: Vec<LockRequest>,
}
/// Row-level lock table supporting blocking (polling) and non-blocking acquisition.
pub struct LockManager {
    /// Lock state per row key; keys with no holders and no waiters are removed.
    locks: RwLock<HashMap<VersionKey, LockEntry>>,
    /// How long a blocking acquisition keeps retrying before giving up.
    timeout: Duration,
}
impl LockManager {
pub fn new(timeout: Duration) -> Self {
Self {
locks: RwLock::new(HashMap::new()),
timeout,
}
}
pub fn acquire(&self, request: LockRequest) -> Result<()> {
let start = Instant::now();
loop {
{
let mut locks = self.locks.write();
let entry = locks.entry(request.key.clone()).or_default();
let can_grant = entry
.holders
.iter()
.all(|h| h.tx_id == request.tx_id || h.mode.is_compatible(&request.mode));
if can_grant {
entry.holders.push(LockRequest {
granted: true,
..request.clone()
});
return Ok(());
}
if !entry.waiters.iter().any(|w| w.tx_id == request.tx_id) {
entry.waiters.push(request.clone());
}
}
if start.elapsed() > self.timeout {
self.release_waiter(&request);
return Err(AegisError::LockTimeout);
}
std::thread::sleep(Duration::from_millis(1));
}
}
pub fn try_acquire(&self, request: LockRequest) -> Result<bool> {
let mut locks = self.locks.write();
let entry = locks.entry(request.key.clone()).or_default();
let can_grant = entry
.holders
.iter()
.all(|h| h.tx_id == request.tx_id || h.mode.is_compatible(&request.mode));
if can_grant {
entry.holders.push(LockRequest {
granted: true,
..request
});
Ok(true)
} else {
Ok(false)
}
}
pub fn release(&self, tx_id: TransactionId, key: &VersionKey) {
let mut locks = self.locks.write();
if let Some(entry) = locks.get_mut(key) {
entry.holders.retain(|h| h.tx_id != tx_id);
while !entry.waiters.is_empty() {
let waiter = entry.waiters.remove(0);
let can_grant = entry
.holders
.iter()
.all(|h| h.mode.is_compatible(&waiter.mode));
if can_grant {
entry.holders.push(LockRequest {
granted: true,
..waiter
});
} else {
entry.waiters.insert(0, waiter);
break;
}
}
if entry.holders.is_empty() && entry.waiters.is_empty() {
locks.remove(key);
}
}
}
pub fn release_all(&self, tx_id: TransactionId) {
let mut locks = self.locks.write();
let keys: Vec<_> = locks.keys().cloned().collect();
for key in keys {
if let Some(entry) = locks.get_mut(&key) {
entry.holders.retain(|h| h.tx_id != tx_id);
entry.waiters.retain(|w| w.tx_id != tx_id);
if entry.holders.is_empty() && entry.waiters.is_empty() {
locks.remove(&key);
}
}
}
}
fn release_waiter(&self, request: &LockRequest) {
let mut locks = self.locks.write();
if let Some(entry) = locks.get_mut(&request.key) {
entry.waiters.retain(|w| w.tx_id != request.tx_id);
}
}
}
/// Coordinates transactions over a multi-version row store guarded by row-level locks.
pub struct TransactionManager {
    /// Every transaction begun so far, keyed by id.
    /// NOTE(review): nothing in this file removes entries, so this map only grows.
    transactions: RwLock<HashMap<TransactionId, Transaction>>,
    /// Source of transaction ids.
    next_tx_id: AtomicU64,
    /// Counter shared by start and commit timestamps.
    next_timestamp: AtomicU64,
    /// Row locks taken by writes.
    lock_manager: LockManager,
    /// Newest version per key; older versions hang off `prev_version`.
    versions: RwLock<HashMap<VersionKey, Version>>,
}
impl TransactionManager {
pub fn new() -> Self {
Self {
transactions: RwLock::new(HashMap::new()),
next_tx_id: AtomicU64::new(1),
next_timestamp: AtomicU64::new(1),
lock_manager: LockManager::new(Duration::from_secs(30)),
versions: RwLock::new(HashMap::new()),
}
}
pub fn begin(&self, isolation_level: IsolationLevel) -> Result<TransactionId> {
let tx_id = TransactionId(self.next_tx_id.fetch_add(1, Ordering::SeqCst));
let start_ts = self.next_timestamp.fetch_add(1, Ordering::SeqCst);
let active_txs: HashSet<_> = self
.transactions
.read()
.iter()
.filter(|(_, tx)| tx.is_active())
.map(|(id, _)| *id)
.collect();
let transaction = Transaction::new(tx_id, isolation_level, start_ts, active_txs);
self.transactions.write().insert(tx_id, transaction);
Ok(tx_id)
}
pub fn commit(&self, tx_id: TransactionId) -> Result<()> {
let commit_ts = self.next_timestamp.fetch_add(1, Ordering::SeqCst);
{
let mut txs = self.transactions.write();
let tx = txs
.get_mut(&tx_id)
.ok_or_else(|| AegisError::Transaction("Transaction not found".to_string()))?;
if tx.state != TransactionState::Active {
return Err(AegisError::Transaction("Transaction not active".to_string()));
}
if tx.isolation_level == IsolationLevel::Serializable {
self.validate_serializable(tx)?;
}
tx.state = TransactionState::Preparing;
tx.commit_timestamp = Some(commit_ts);
}
{
let mut versions = self.versions.write();
let txs = self.transactions.read();
let tx = txs.get(&tx_id).unwrap();
for key in &tx.write_set {
if let Some(version) = versions.get_mut(key) {
if version.created_by == tx_id {
version.commit(commit_ts);
}
}
}
}
{
let mut txs = self.transactions.write();
if let Some(tx) = txs.get_mut(&tx_id) {
tx.state = TransactionState::Committed;
}
}
self.lock_manager.release_all(tx_id);
Ok(())
}
pub fn abort(&self, tx_id: TransactionId) -> Result<()> {
{
let mut txs = self.transactions.write();
let tx = txs
.get_mut(&tx_id)
.ok_or_else(|| AegisError::Transaction("Transaction not found".to_string()))?;
tx.state = TransactionState::Aborted;
}
{
let mut versions = self.versions.write();
let txs = self.transactions.read();
let tx = txs.get(&tx_id).unwrap();
for key in &tx.write_set {
if let Some(version) = versions.get_mut(key) {
if version.created_by == tx_id {
version.abort();
}
}
}
}
self.lock_manager.release_all(tx_id);
Ok(())
}
pub fn read(&self, tx_id: TransactionId, key: &VersionKey) -> Result<Option<Vec<u8>>> {
let txs = self.transactions.read();
let tx = txs
.get(&tx_id)
.ok_or_else(|| AegisError::Transaction("Transaction not found".to_string()))?;
if !tx.is_active() {
return Err(AegisError::Transaction("Transaction not active".to_string()));
}
let versions = self.versions.read();
if let Some(version) = versions.get(key) {
if tx.snapshot.is_visible(version) {
return Ok(Some(version.data.clone()));
}
let mut current = version.prev_version.as_ref();
while let Some(v) = current {
if tx.snapshot.is_visible(v) {
return Ok(Some(v.data.clone()));
}
current = v.prev_version.as_ref();
}
}
Ok(None)
}
pub fn write(&self, tx_id: TransactionId, key: VersionKey, data: Vec<u8>) -> Result<()> {
{
let mut txs = self.transactions.write();
let tx = txs
.get_mut(&tx_id)
.ok_or_else(|| AegisError::Transaction("Transaction not found".to_string()))?;
if !tx.is_active() {
return Err(AegisError::Transaction("Transaction not active".to_string()));
}
tx.add_to_write_set(key.clone());
}
let lock_request = LockRequest {
tx_id,
key: key.clone(),
mode: LockMode::Exclusive,
granted: false,
};
self.lock_manager.acquire(lock_request)?;
{
let txs = self.transactions.read();
let tx = txs.get(&tx_id).unwrap();
tx.locks_held.len(); }
let mut versions = self.versions.write();
let new_version = Version::new(key.clone(), tx_id, data);
if let Some(existing) = versions.remove(&key) {
let mut new_v = new_version;
new_v.prev_version = Some(Box::new(existing));
versions.insert(key, new_v);
} else {
versions.insert(key, new_version);
}
Ok(())
}
pub fn delete(&self, tx_id: TransactionId, key: &VersionKey) -> Result<()> {
self.write(tx_id, key.clone(), Vec::new())
}
pub fn stats(&self) -> TransactionStats {
let txs = self.transactions.read();
let mut active = 0;
let mut committed = 0;
let mut aborted = 0;
for tx in txs.values() {
match tx.state {
TransactionState::Active | TransactionState::Preparing => active += 1,
TransactionState::Committed => committed += 1,
TransactionState::Aborted => aborted += 1,
}
}
TransactionStats {
active,
committed,
aborted,
total: txs.len(),
}
}
fn validate_serializable(&self, tx: &Transaction) -> Result<()> {
let txs = self.transactions.read();
for other_tx in txs.values() {
if other_tx.id == tx.id {
continue;
}
if other_tx.state != TransactionState::Committed {
continue;
}
if let Some(commit_ts) = other_tx.commit_timestamp {
if commit_ts > tx.start_timestamp {
for read_key in &tx.read_set {
if other_tx.write_set.contains(read_key) {
return Err(AegisError::SerializationFailure);
}
}
}
}
}
Ok(())
}
}
impl Default for TransactionManager {
fn default() -> Self {
Self::new()
}
}
/// Counts of tracked transactions grouped by lifecycle state.
#[derive(Debug, Clone)]
pub struct TransactionStats {
    /// Transactions still in progress (`Active` or `Preparing`).
    pub active: usize,
    /// Transactions that finished committing.
    pub committed: usize,
    /// Transactions that were rolled back.
    pub aborted: usize,
    /// All tracked transactions.
    pub total: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Key of row `row_id` in table 1, the only table these tests touch.
    fn row(row_id: u64) -> VersionKey {
        VersionKey { table_id: 1, row_id }
    }

    /// Current lifecycle state of `tx_id`; panics if the manager does not know it.
    fn state_of(tm: &TransactionManager, tx_id: TransactionId) -> TransactionState {
        tm.transactions.read()[&tx_id].state
    }

    /// Not-yet-granted lock request for transaction `tx` on `key`.
    fn request(tx: u64, key: &VersionKey, mode: LockMode) -> LockRequest {
        LockRequest {
            tx_id: TransactionId(tx),
            key: key.clone(),
            mode,
            granted: false,
        }
    }

    #[test]
    fn test_transaction_lifecycle() {
        let tm = TransactionManager::new();
        let tx_id = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        assert!(tm.transactions.read()[&tx_id].is_active());
        tm.commit(tx_id).unwrap();
        assert_eq!(state_of(&tm, tx_id), TransactionState::Committed);
    }

    #[test]
    fn test_transaction_abort() {
        let tm = TransactionManager::new();
        let tx_id = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        tm.abort(tx_id).unwrap();
        assert_eq!(state_of(&tm, tx_id), TransactionState::Aborted);
    }

    #[test]
    fn test_mvcc_read_write() {
        let tm = TransactionManager::new();
        let writer = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        let key = row(1);
        tm.write(writer, key.clone(), b"hello".to_vec()).unwrap();
        tm.commit(writer).unwrap();

        let reader = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        assert_eq!(tm.read(reader, &key).unwrap(), Some(b"hello".to_vec()));
        tm.commit(reader).unwrap();
    }

    #[test]
    fn test_snapshot_isolation() {
        let tm = TransactionManager::new();
        let key = row(1);

        let seed = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        tm.write(seed, key.clone(), b"v1".to_vec()).unwrap();
        tm.commit(seed).unwrap();

        // `old_reader` begins before `overwriter` commits, so it must keep seeing v1.
        let old_reader = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        let overwriter = tm.begin(IsolationLevel::RepeatableRead).unwrap();
        tm.write(overwriter, key.clone(), b"v2".to_vec()).unwrap();
        tm.commit(overwriter).unwrap();

        assert_eq!(tm.read(old_reader, &key).unwrap(), Some(b"v1".to_vec()));
        tm.commit(old_reader).unwrap();
    }

    #[test]
    fn test_lock_compatibility() {
        assert!(LockMode::Shared.is_compatible(&LockMode::Shared));
        assert!(!LockMode::Shared.is_compatible(&LockMode::Exclusive));
        assert!(!LockMode::Exclusive.is_compatible(&LockMode::Exclusive));
        assert!(!LockMode::Exclusive.is_compatible(&LockMode::Shared));
    }

    #[test]
    fn test_lock_manager() {
        let lm = LockManager::new(Duration::from_secs(1));
        let key = row(1);

        // Two readers share the row; a writer is refused while they hold it.
        assert!(lm.try_acquire(request(1, &key, LockMode::Shared)).unwrap());
        assert!(lm.try_acquire(request(2, &key, LockMode::Shared)).unwrap());
        assert!(!lm.try_acquire(request(3, &key, LockMode::Exclusive)).unwrap());

        lm.release(TransactionId(1), &key);
        lm.release(TransactionId(2), &key);
        assert!(lm.try_acquire(request(3, &key, LockMode::Exclusive)).unwrap());
    }
}