use bytes::Bytes;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::Ordering;
use crate::db::{DBError, Result, DB};
use crate::types::SnapshotHandle;
use crate::wal::{BatchOp, Record};
/// Error payload produced when a transaction fails validation at commit time.
///
/// `Transaction::commit` builds this value from the keys that
/// `validate_read_set` reports as conflicting and wraps it in
/// `DBError::TransactionConflict`.
#[derive(Debug, Clone)]
pub struct TransactionConflict {
/// Keys from the transaction's read set that failed commit-time validation.
pub conflicting_keys: Vec<Bytes>,
}
impl std::fmt::Display for TransactionConflict {
    /// Writes a one-line summary stating how many keys were in conflict.
    ///
    /// The keys themselves are not printed; only their count is reported.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_count = self.conflicting_keys.len();
        write!(f, "transaction conflict on {} key(s)", key_count)
    }
}
// Marker impl: lets `TransactionConflict` be used as a standard error value.
// All trait methods keep their defaults; formatting comes from the `Debug`
// derive and the `Display` impl on this type.
impl std::error::Error for TransactionConflict {}
/// A mutation buffered inside a transaction until it is committed.
///
/// The key is not stored here; it is the map key of the transaction's
/// write buffer, so each key has at most one pending operation.
#[derive(Clone, Debug)]
enum WriteOp {
/// Store the contained value under the key.
Put(Bytes),
/// Remove the key.
Delete,
}
/// A transaction over a [`DB`] that buffers writes locally and validates its
/// reads when committed.
///
/// Writes are kept in `write_buffer` and only reach the database in `commit`.
/// Keys read from the database are remembered in `read_set` so that `commit`
/// can reject the transaction if any of them changed in the meantime.
pub struct Transaction<'db> {
/// Database this transaction reads from and commits to.
db: &'db DB,
/// Sequence number supplied when the transaction was created. Reads are
/// issued through `DB::get_at_seq` with this value, and commit validation
/// treats any read key whose latest sequence is `>= start_seq` as a conflict.
start_seq: u64,
/// Pending operations, keyed by the affected key (last write per key wins).
write_buffer: HashMap<Bytes, WriteOp>,
/// Keys whose values were fetched from the database by `get`.
read_set: HashSet<Bytes>,
/// `true` until the transaction is committed, aborted, or dropped.
active: bool,
/// Never read by this type, hence the `dead_code` allowance.
/// NOTE(review): presumably held only for its lifetime, so the data visible
/// at `start_seq` is not garbage-collected while the transaction is alive;
/// confirm against `SnapshotHandle`.
#[allow(dead_code)]
gc_handle: SnapshotHandle,
}
impl<'db> Transaction<'db> {
pub(crate) fn new(db: &'db DB, start_seq: u64, gc_handle: SnapshotHandle) -> Self {
Self {
db,
start_seq,
write_buffer: HashMap::new(),
read_set: HashSet::new(),
active: true,
gc_handle,
}
}
pub fn get(&mut self, key: impl AsRef<[u8]>) -> Result<Option<Bytes>> {
if !self.active {
return Err(DBError::TransactionAborted);
}
let key_bytes = Bytes::copy_from_slice(key.as_ref());
if let Some(op) = self.write_buffer.get(&key_bytes) {
return Ok(match op {
WriteOp::Put(v) => Some(v.clone()),
WriteOp::Delete => None,
});
}
self.read_set.insert(key_bytes);
self.db.get_at_seq(key.as_ref(), self.start_seq)
}
pub fn put(&mut self, key: impl AsRef<[u8]>, value: impl AsRef<[u8]>) -> Result<()> {
if !self.active {
return Err(DBError::TransactionAborted);
}
let key_bytes = Bytes::copy_from_slice(key.as_ref());
let value_bytes = Bytes::copy_from_slice(value.as_ref());
self.write_buffer
.insert(key_bytes, WriteOp::Put(value_bytes));
Ok(())
}
pub fn delete(&mut self, key: impl AsRef<[u8]>) -> Result<()> {
if !self.active {
return Err(DBError::TransactionAborted);
}
let key_bytes = Bytes::copy_from_slice(key.as_ref());
self.write_buffer.insert(key_bytes, WriteOp::Delete);
Ok(())
}
pub fn commit(mut self) -> Result<()> {
if !self.active {
return Err(DBError::TransactionAborted);
}
self.active = false;
if self.write_buffer.is_empty() && self.read_set.is_empty() {
return Ok(());
}
let wal_ops: Vec<BatchOp> = self
.write_buffer
.iter()
.map(|(k, op)| match op {
WriteOp::Put(v) => BatchOp::Put {
key: k.clone(),
value: v.clone(),
},
WriteOp::Delete => BatchOp::Delete { key: k.clone() },
})
.collect();
let _commit_guard = self.db.commit_lock.lock().expect("commit lock poisoned");
let conflicts = self.validate_read_set()?;
if !conflicts.is_empty() {
return Err(DBError::TransactionConflict(TransactionConflict {
conflicting_keys: conflicts,
}));
}
if wal_ops.is_empty() {
return Ok(());
}
let op_count = wal_ops.len() as u64;
let base_seq = self.db.next_seq.fetch_add(op_count, Ordering::SeqCst);
let batch_record = Record::Batch {
base_seq,
operations: wal_ops,
};
self.db
.pipelined_wal
.put(batch_record, |records| {
self.db.apply_wal_records(records);
})
.map_err(DBError::Wal)?;
Ok(())
}
pub fn abort(mut self) {
self.active = false;
self.write_buffer.clear();
self.read_set.clear();
}
#[must_use]
pub const fn is_active(&self) -> bool {
self.active
}
#[must_use]
pub fn write_count(&self) -> usize {
self.write_buffer.len()
}
#[must_use]
pub fn read_count(&self) -> usize {
self.read_set.len()
}
fn validate_read_set(&self) -> Result<Vec<Bytes>> {
let mut conflicts = Vec::new();
for key in &self.read_set {
if let Some(latest_seq) = self.db.get_latest_seq(key)? {
if latest_seq >= self.start_seq {
conflicts.push(key.clone());
}
}
}
Ok(conflicts)
}
}
// Dropping a transaction only flips its `active` flag. Buffered operations are
// written to the database exclusively by `commit`, so a transaction that goes
// out of scope without being committed leaves the database untouched.
impl Drop for Transaction<'_> {
fn drop(&mut self) {
self.active = false;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Buffered writes are visible inside the transaction, hidden from the
    /// database until commit, and visible to everyone afterwards.
    #[test]
    fn test_transaction_read_write() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();
        store.put(b"key1", b"value1").unwrap();

        let mut tx = store.begin_transaction();
        let existing = tx.get(b"key1").unwrap();
        assert_eq!(existing, Some(Bytes::from_static(b"value1")));

        tx.put(b"key2", b"value2").unwrap();
        let buffered = tx.get(b"key2").unwrap();
        assert_eq!(buffered, Some(Bytes::from_static(b"value2")));

        // Not yet committed: the database must not see the write.
        assert_eq!(store.get(b"key2").unwrap(), None);

        tx.commit().unwrap();
        let committed = store.get(b"key2").unwrap();
        assert_eq!(committed, Some(Bytes::from_static(b"value2")));
    }

    /// A buffered delete hides the key inside the transaction and removes it
    /// from the database only after commit.
    #[test]
    fn test_transaction_delete() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();
        store.put(b"key1", b"value1").unwrap();

        let mut tx = store.begin_transaction();
        tx.delete(b"key1").unwrap();
        assert_eq!(tx.get(b"key1").unwrap(), None);

        let outside_view = store.get(b"key1").unwrap();
        assert_eq!(outside_view, Some(Bytes::from_static(b"value1")));

        tx.commit().unwrap();
        assert_eq!(store.get(b"key1").unwrap(), None);
    }

    /// Aborting discards buffered writes.
    #[test]
    fn test_transaction_abort() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();

        let mut tx = store.begin_transaction();
        tx.put(b"key1", b"value1").unwrap();
        tx.abort();

        assert_eq!(store.get(b"key1").unwrap(), None);
    }

    /// A key that was read and then modified outside the transaction makes
    /// the commit fail with a conflict naming that key.
    #[test]
    fn test_transaction_conflict() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();
        store.put(b"balance", b"100").unwrap();

        let mut tx = store.begin_transaction();
        let _observed = tx.get(b"balance").unwrap();

        // Concurrent writer changes the key after it was read.
        store.put(b"balance", b"50").unwrap();
        tx.put(b"balance", b"200").unwrap();

        let outcome = tx.commit();
        assert!(outcome.is_err());
        if let Err(DBError::TransactionConflict(conflict)) = outcome {
            assert_eq!(conflict.conflicting_keys.len(), 1);
            assert_eq!(
                conflict.conflicting_keys[0],
                Bytes::from_static(b"balance")
            );
        } else {
            panic!("Expected TransactionConflict");
        }
    }

    /// Outside changes to keys the transaction never read do not conflict.
    #[test]
    fn test_transaction_no_conflict_on_unread_keys() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();
        store.put(b"key1", b"value1").unwrap();
        store.put(b"key2", b"value2").unwrap();

        let mut tx = store.begin_transaction();
        let _observed = tx.get(b"key1").unwrap();

        store.put(b"key2", b"new_value").unwrap();
        tx.put(b"key1", b"updated").unwrap();
        tx.commit().unwrap();

        let final_value = store.get(b"key1").unwrap();
        assert_eq!(final_value, Some(Bytes::from_static(b"updated")));
    }

    /// Committing a transaction that did nothing succeeds.
    #[test]
    fn test_transaction_empty_commit() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();

        let tx = store.begin_transaction();
        tx.commit().unwrap();
    }

    /// A transaction that only writes (never reads) a key commits even if the
    /// key was changed concurrently, and its value wins.
    #[test]
    fn test_transaction_write_only_no_conflict() {
        let tmp = tempdir().unwrap();
        let store = DB::open(tmp.path()).unwrap();
        store.put(b"key1", b"value1").unwrap();

        let mut tx = store.begin_transaction();
        tx.put(b"key1", b"new_value").unwrap();
        store.put(b"key1", b"concurrent").unwrap();
        tx.commit().unwrap();

        let final_value = store.get(b"key1").unwrap();
        assert_eq!(final_value, Some(Bytes::from_static(b"new_value")));
    }
}