use std::collections::HashMap;
use wasm_dbms_api::prelude::{DbmsError, DbmsResult, QueryError, TransactionId};
use super::Transaction;
/// Registry of the transactions that are currently open, together with the
/// owner recorded for each of them.
///
/// A fresh session (via `Default`) holds no transactions and starts its
/// identifier counter at the `Default` value of `TransactionId`.
#[derive(Debug, Default)]
pub struct TransactionSession {
    /// Open transactions, keyed by their identifier.
    transactions: HashMap<TransactionId, Transaction>,
    /// Owner bytes recorded when each transaction was begun, keyed by the
    /// transaction identifier. Compared byte-for-byte against the caller in
    /// `has_transaction`.
    owners: HashMap<TransactionId, Vec<u8>>,
    /// Identifier that the next call to `begin_transaction` will hand out.
    next_transaction_id: TransactionId,
}
impl TransactionSession {
    /// Opens a new, default-initialised transaction on behalf of `owner`.
    ///
    /// The returned identifier is the current value of the internal counter;
    /// the counter is advanced by one before the transaction is registered,
    /// so consecutive calls hand out consecutive identifiers.
    pub fn begin_transaction(&mut self, owner: Vec<u8>) -> TransactionId {
        let id = self.next_transaction_id;
        self.next_transaction_id += 1;

        let fresh = Transaction::default();
        self.transactions.insert(id, fresh);
        self.owners.insert(id, owner);

        id
    }

    /// Tells whether `transaction_id` is registered with an owner whose bytes
    /// are exactly `caller`.
    ///
    /// Only the owner table is consulted; an unknown identifier or a
    /// different owner both yield `false`.
    pub fn has_transaction(&self, transaction_id: &TransactionId, caller: &[u8]) -> bool {
        matches!(
            self.owners.get(transaction_id),
            Some(recorded) if recorded.as_slice() == caller
        )
    }

    /// Borrows the transaction registered under `transaction_id`.
    ///
    /// No ownership check is performed here.
    ///
    /// # Errors
    ///
    /// Returns `DbmsError::Query(QueryError::TransactionNotFound)` when no
    /// transaction is stored under that identifier.
    pub fn get_transaction(&self, transaction_id: &TransactionId) -> DbmsResult<&Transaction> {
        let Some(found) = self.transactions.get(transaction_id) else {
            return Err(DbmsError::Query(QueryError::TransactionNotFound));
        };
        Ok(found)
    }

    /// Removes the transaction registered under `transaction_id` from the
    /// session and hands it to the caller by value.
    ///
    /// The matching owner entry is dropped as well, but only after the
    /// transaction itself was found.
    ///
    /// # Errors
    ///
    /// Returns `DbmsError::Query(QueryError::TransactionNotFound)` when no
    /// transaction is stored under that identifier; the owner table is left
    /// untouched in that case.
    pub fn take_transaction(&mut self, transaction_id: &TransactionId) -> DbmsResult<Transaction> {
        let Some(taken) = self.transactions.remove(transaction_id) else {
            return Err(DbmsError::Query(QueryError::TransactionNotFound));
        };
        self.owners.remove(transaction_id);
        Ok(taken)
    }

    /// Discards the transaction and the owner entry stored under
    /// `transaction_id`.
    ///
    /// Unknown identifiers are silently ignored.
    pub fn close_transaction(&mut self, transaction_id: &TransactionId) {
        let _ = self.transactions.remove(transaction_id);
        let _ = self.owners.remove(transaction_id);
    }

    /// Mutably borrows the transaction registered under `transaction_id`.
    ///
    /// No ownership check is performed here.
    ///
    /// # Errors
    ///
    /// Returns `DbmsError::Query(QueryError::TransactionNotFound)` when no
    /// transaction is stored under that identifier.
    pub fn get_transaction_mut(
        &mut self,
        transaction_id: &TransactionId,
    ) -> DbmsResult<&mut Transaction> {
        match self.transactions.get_mut(transaction_id) {
            Some(found) => Ok(found),
            None => Err(DbmsError::Query(QueryError::TransactionNotFound)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Owner bytes used as the transaction owner throughout the tests.
    fn alice() -> Vec<u8> {
        vec![1, 2, 3]
    }

    /// Owner bytes of a caller that never owns a transaction.
    fn bob() -> Vec<u8> {
        vec![4, 5, 6]
    }

    /// A begun transaction is visible to its owner only and can be borrowed
    /// mutably.
    #[test]
    fn test_should_begin_transaction() {
        let owner = alice();
        let stranger = bob();
        let mut session = TransactionSession::default();

        let id = session.begin_transaction(owner.clone());

        assert!(session.has_transaction(&id, &owner));
        assert!(!session.has_transaction(&id, &stranger));
        assert!(session.get_transaction_mut(&id).is_ok());
    }

    /// Closing a transaction removes both the transaction and its owner entry.
    #[test]
    fn test_should_close_transaction() {
        let owner = alice();
        let mut session = TransactionSession::default();

        let id = session.begin_transaction(owner.clone());
        assert!(session.has_transaction(&id, &owner));

        session.close_transaction(&id);

        assert!(!session.has_transaction(&id, &owner));
        assert!(session.get_transaction_mut(&id).is_err());
        assert!(!session.owners.contains_key(&id));
        assert!(!session.transactions.contains_key(&id));
    }

    /// Taking a transaction yields it by value and clears it from the session.
    #[test]
    fn test_should_take_transaction() {
        let owner = alice();
        let mut session = TransactionSession::default();
        let id = session.begin_transaction(owner.clone());

        let _taken = session.take_transaction(&id).expect("failed to take tx");

        assert!(!session.has_transaction(&id, &owner));
        assert!(session.get_transaction(&id).is_err());
        assert!(!session.owners.contains_key(&id));
        assert!(!session.transactions.contains_key(&id));
    }

    /// A begun transaction can be borrowed immutably.
    #[test]
    fn test_should_get_transaction() {
        let mut session = TransactionSession::default();
        let id = session.begin_transaction(alice());

        let _found = session.get_transaction(&id).expect("failed to get tx");
    }
}