use crate::engine::Database;
use crate::error::DbxResult;
use crate::sql::builder::{Execute, Query, QueryOne, QueryOptional, QueryScalar};
use std::collections::HashMap;
use std::marker::PhantomData;
/// Marker trait for the compile-time lifecycle state of a [`Transaction`].
///
/// The state types carry no data; they only select which `impl` block (and
/// therefore which methods) is available on a given transaction value.
pub trait TxState {}

/// State marker for a transaction that can still buffer writes, commit, or
/// roll back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Active;

/// State marker for a transaction whose `commit` call returned successfully.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Committed;

/// State marker for a transaction that was rolled back.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RolledBack;

impl TxState for Active {}
impl TxState for Committed {}
impl TxState for RolledBack {}
/// One buffered write recorded by an active transaction and replayed on commit.
#[derive(Debug, Clone)]
enum TxOp {
/// Write a single pair: `(table, key, value)`.
Insert(String, Vec<u8>, Vec<u8>),
/// Remove a single key: `(table, key)`.
Delete(String, Vec<u8>),
/// Write many pairs into one table: `(table, rows)`, each row being `(key, value)`.
Batch(String, Vec<(Vec<u8>, Vec<u8>)>),
}
/// A write-buffering transaction over a borrowed [`Database`], tagged with its
/// lifecycle state `S`.
///
/// Writes are recorded in memory and are only sent to the database when the
/// transaction is committed. Which methods are available depends on `S`.
pub struct Transaction<'a, S: TxState> {
/// Database this transaction reads from and commits into.
db: &'a Database,
/// Buffered writes in the order they were issued; commit replays them in that order.
ops: Vec<TxOp>,
/// Overlay of pending writes, keyed by table name and then by row key.
/// `Some(value)` is a pending write and `None` a pending delete; `get` checks
/// this overlay before falling back to the database.
local_buffer: HashMap<String, HashMap<Vec<u8>, Option<Vec<u8>>>>,
/// Zero-sized marker carrying the state type `S`.
_state: PhantomData<S>,
}
impl Database {
    /// Starts a new transaction in the [`Active`] state.
    ///
    /// The transaction borrows this database and starts with no buffered
    /// operations and an empty overlay. This constructor always returns `Ok`.
    pub fn begin(&self) -> DbxResult<Transaction<'_, Active>> {
        let tx = Transaction {
            _state: PhantomData,
            local_buffer: HashMap::default(),
            ops: Vec::default(),
            db: self,
        };
        Ok(tx)
    }
}
impl<'a> Transaction<'a, Active> {
pub fn query<T: crate::api::FromRow>(&self, sql: impl Into<String>) -> Query<'_, T> {
self.db.query(sql)
}
pub fn query_one<T: crate::api::FromRow>(&self, sql: impl Into<String>) -> QueryOne<'_, T> {
self.db.query_one(sql)
}
pub fn query_optional<T: crate::api::FromRow>(
&self,
sql: impl Into<String>,
) -> QueryOptional<'_, T> {
self.db.query_optional(sql)
}
pub fn query_scalar<T: crate::api::FromScalar>(
&self,
sql: impl Into<String>,
) -> QueryScalar<'_, T> {
self.db.query_scalar(sql)
}
pub fn execute(&self, sql: impl Into<String>) -> Execute<'_> {
self.db.execute(sql)
}
pub fn insert(&mut self, table: &str, key: &[u8], value: &[u8]) -> DbxResult<()> {
self.ops.push(TxOp::Insert(
table.to_string(),
key.to_vec(),
value.to_vec(),
));
self.local_buffer
.entry(table.to_string())
.or_default()
.insert(key.to_vec(), Some(value.to_vec()));
Ok(())
}
pub fn insert_batch(&mut self, table: &str, rows: Vec<(Vec<u8>, Vec<u8>)>) -> DbxResult<()> {
self.ops.push(TxOp::Batch(table.to_string(), rows.clone()));
let table_buf = self.local_buffer.entry(table.to_string()).or_default();
for (key, value) in rows {
table_buf.insert(key, Some(value));
}
Ok(())
}
pub fn delete(&mut self, table: &str, key: &[u8]) -> DbxResult<bool> {
self.ops.push(TxOp::Delete(table.to_string(), key.to_vec()));
self.local_buffer
.entry(table.to_string())
.or_default()
.insert(key.to_vec(), None);
Ok(true)
}
pub fn get(&self, table: &str, key: &[u8]) -> DbxResult<Option<Vec<u8>>> {
if let Some(table_buf) = self.local_buffer.get(table)
&& let Some(value_opt) = table_buf.get(key)
{
return Ok(value_opt.clone()); }
self.db.get(table, key)
}
pub fn pending_ops(&self) -> usize {
self.ops.len()
}
pub fn commit(self) -> DbxResult<Transaction<'a, Committed>> {
let commit_ts = self.db.allocate_commit_ts();
for op in &self.ops {
match op {
TxOp::Insert(table, key, value) => {
self.db
.insert_versioned(table, key, Some(value), commit_ts)?;
self.db.insert(table, key, value)?;
}
TxOp::Delete(table, key) => {
self.db.insert_versioned(table, key, None, commit_ts)?;
self.db.delete(table, key)?;
}
TxOp::Batch(table, rows) => {
for (key, value) in rows {
self.db
.insert_versioned(table, key, Some(value), commit_ts)?;
self.db.insert(table, key, value)?;
}
}
}
}
Ok(Transaction {
db: self.db,
ops: Vec::new(),
local_buffer: HashMap::new(),
_state: PhantomData,
})
}
pub fn rollback(self) -> DbxResult<Transaction<'a, RolledBack>> {
Ok(Transaction {
db: self.db,
ops: Vec::new(),
local_buffer: HashMap::new(),
_state: PhantomData,
})
}
}
impl Transaction<'_, Committed> {
    /// Always returns `true`: a value of this type only exists in the
    /// [`Committed`] state.
    pub fn is_committed(&self) -> bool {
        true
    }
}
impl Transaction<'_, RolledBack> {
    /// Always returns `true`: a value of this type only exists in the
    /// [`RolledBack`] state.
    pub fn is_rolled_back(&self) -> bool {
        true
    }
}
/// Exposes transaction creation through the `DatabaseTransaction` trait.
impl crate::traits::DatabaseTransaction for Database {
/// Starts an active transaction by forwarding to the inherent `Database::begin`.
fn begin(&self) -> DbxResult<Transaction<'_, Active>> {
// Inherent associated functions take precedence over trait methods in path
// resolution, so this calls the inherent `begin` rather than recursing into
// this trait method.
Database::begin(self)
}
}
#[cfg(test)]
mod tests {
    use crate::engine::Database;

    /// Buffered inserts are visible through the transaction only, and reach the
    /// database once the transaction commits.
    #[test]
    fn test_begin_commit() {
        let store = Database::open_in_memory().unwrap();
        let mut txn = store.begin().unwrap();
        txn.insert("users", b"u1", b"Alice").unwrap();
        txn.insert("users", b"u2", b"Bob").unwrap();

        // Before commit the database has nothing, but the transaction sees its write.
        assert!(store.get("users", b"u1").unwrap().is_none());
        assert_eq!(
            txn.get("users", b"u1").unwrap().as_deref(),
            Some(b"Alice".as_slice())
        );

        let finished = txn.commit().unwrap();
        assert!(finished.is_committed());

        assert_eq!(
            store.get("users", b"u1").unwrap().as_deref(),
            Some(b"Alice".as_slice())
        );
        assert_eq!(
            store.get("users", b"u2").unwrap().as_deref(),
            Some(b"Bob".as_slice())
        );
    }

    /// Rolling back discards buffered inserts without touching the database.
    #[test]
    fn test_begin_rollback() {
        let store = Database::open_in_memory().unwrap();
        let mut txn = store.begin().unwrap();
        txn.insert("users", b"u1", b"Alice").unwrap();
        txn.insert("users", b"u2", b"Bob").unwrap();

        let finished = txn.rollback().unwrap();
        assert!(finished.is_rolled_back());

        assert!(store.get("users", b"u1").unwrap().is_none());
        assert!(store.get("users", b"u2").unwrap().is_none());
    }

    /// A buffered delete hides the key inside the transaction immediately and
    /// removes it from the database on commit.
    #[test]
    fn test_delete_in_transaction() {
        let store = Database::open_in_memory().unwrap();
        store.insert("users", b"u1", b"Alice").unwrap();
        assert_eq!(
            store.get("users", b"u1").unwrap().as_deref(),
            Some(b"Alice".as_slice())
        );

        let mut txn = store.begin().unwrap();
        txn.delete("users", b"u1").unwrap();

        // The transaction sees the tombstone; the database still has the row.
        assert!(txn.get("users", b"u1").unwrap().is_none());
        assert_eq!(
            store.get("users", b"u1").unwrap().as_deref(),
            Some(b"Alice".as_slice())
        );

        txn.commit().unwrap();
        assert!(store.get("users", b"u1").unwrap().is_none());
    }

    /// Reads through a transaction prefer its own writes and fall back to the
    /// database for keys it has not touched.
    #[test]
    fn test_read_your_writes() {
        let store = Database::open_in_memory().unwrap();
        store.insert("t", b"k1", b"old").unwrap();

        let mut txn = store.begin().unwrap();
        txn.insert("t", b"k1", b"new").unwrap();
        assert_eq!(
            txn.get("t", b"k1").unwrap().as_deref(),
            Some(b"new".as_slice())
        );

        // A write made directly on the database while the transaction is open
        // is visible through the transaction for a key it has not buffered.
        store.insert("t", b"k2", b"main_data").unwrap();
        assert_eq!(
            txn.get("t", b"k2").unwrap().as_deref(),
            Some(b"main_data".as_slice())
        );

        txn.rollback().unwrap();
        assert_eq!(
            store.get("t", b"k1").unwrap().as_deref(),
            Some(b"old".as_slice())
        );
    }

    /// Every buffered insert or delete increases the pending-operation count by one.
    #[test]
    fn test_pending_ops_count() {
        let store = Database::open_in_memory().unwrap();
        let mut txn = store.begin().unwrap();
        assert_eq!(txn.pending_ops(), 0);

        txn.insert("t", b"a", b"1").unwrap();
        assert_eq!(txn.pending_ops(), 1);

        txn.delete("t", b"b").unwrap();
        assert_eq!(txn.pending_ops(), 2);

        txn.insert("t", b"c", b"3").unwrap();
        assert_eq!(txn.pending_ops(), 3);
    }

    /// Committing a transaction with no buffered operations succeeds.
    #[test]
    fn test_empty_transaction_commit() {
        let store = Database::open_in_memory().unwrap();
        let txn = store.begin().unwrap();
        let finished = txn.commit().unwrap();
        assert!(finished.is_committed());
    }

    /// Rolling back a transaction with no buffered operations succeeds.
    #[test]
    fn test_empty_transaction_rollback() {
        let store = Database::open_in_memory().unwrap();
        let txn = store.begin().unwrap();
        let finished = txn.rollback().unwrap();
        assert!(finished.is_rolled_back());
    }
}