use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::GqlError;
use crate::proto;
/// Bookkeeping kept for one registered transaction.
#[derive(Clone, Debug)]
pub struct TransactionState {
    /// Id of the session that registered (and therefore owns) the transaction.
    pub session_id: String,
    /// Mode supplied when the transaction was registered.
    pub mode: proto::TransactionMode,
}
/// Transaction id -> state of that transaction.
type TransactionMap = HashMap<String, TransactionState>;

/// Tracks which session owns each open transaction.
///
/// Cloning is cheap and yields a handle onto the *same* table, because the
/// map lives behind a shared `Arc`.
#[derive(Clone, Debug)]
pub struct TransactionManager {
    transactions: Arc<RwLock<TransactionMap>>,
}
impl TransactionManager {
    /// Creates a manager with no registered transactions.
    #[must_use]
    pub fn new() -> Self {
        Self {
            transactions: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `transaction_id` as the active transaction of `session_id`.
    ///
    /// The write lock is held across both checks and the insertion, so two
    /// concurrent calls cannot both pass the checks.
    ///
    /// # Errors
    ///
    /// Returns [`GqlError::Transaction`] if `session_id` already owns a
    /// registered transaction, or if `transaction_id` is already registered
    /// (by any session).
    pub async fn register(
        &self,
        transaction_id: &str,
        session_id: &str,
        mode: proto::TransactionMode,
    ) -> Result<(), GqlError> {
        let mut txns = self.transactions.write().await;
        // The map is keyed by transaction id, so finding a session's
        // transaction requires a scan over all entries.
        if txns.values().any(|t| t.session_id == session_id) {
            return Err(GqlError::Transaction(
                "session already has an active transaction".to_owned(),
            ));
        }
        // Without this check, `insert` would silently replace (and thereby
        // hand over) another session's transaction that shares this id.
        if txns.contains_key(transaction_id) {
            return Err(GqlError::Transaction(format!(
                "transaction {transaction_id} already exists"
            )));
        }
        txns.insert(
            transaction_id.to_owned(),
            TransactionState {
                session_id: session_id.to_owned(),
                mode,
            },
        );
        Ok(())
    }

    /// Unregisters `transaction_id` and returns its state.
    ///
    /// Ownership is not checked here; use [`Self::validate`] for that.
    ///
    /// # Errors
    ///
    /// Returns [`GqlError::Transaction`] if `transaction_id` is not registered.
    pub async fn remove(&self, transaction_id: &str) -> Result<TransactionState, GqlError> {
        let mut txns = self.transactions.write().await;
        txns.remove(transaction_id)
            .ok_or_else(|| GqlError::Transaction(format!("transaction {transaction_id} not found")))
    }

    /// Checks that `transaction_id` is registered and owned by `session_id`.
    ///
    /// # Errors
    ///
    /// Returns [`GqlError::Transaction`] if the transaction is unknown or is
    /// owned by a different session.
    pub async fn validate(&self, transaction_id: &str, session_id: &str) -> Result<(), GqlError> {
        let txns = self.transactions.read().await;
        match txns.get(transaction_id) {
            Some(state) if state.session_id == session_id => Ok(()),
            Some(_) => Err(GqlError::Transaction(
                "transaction does not belong to this session".to_owned(),
            )),
            None => Err(GqlError::Transaction(format!(
                "transaction {transaction_id} not found"
            ))),
        }
    }

    /// Unregisters every transaction owned by `session_id`.
    ///
    /// Returns the removed transaction ids in unspecified (hash-map) order;
    /// the vector is empty when the session owned nothing.
    pub async fn remove_for_session(&self, session_id: &str) -> Vec<String> {
        let mut txns = self.transactions.write().await;
        let mut removed = Vec::new();
        // Single pass: drop matching entries and remember their ids.
        txns.retain(|id, state| {
            let owned_by_session = state.session_id == session_id;
            if owned_by_session {
                removed.push(id.clone());
            }
            !owned_by_session
        });
        removed
    }
}
impl Default for TransactionManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn register_and_remove() {
        let tm = TransactionManager::new();
        tm.register("tx1", "sess1", proto::TransactionMode::ReadWrite)
            .await
            .unwrap();
        tm.validate("tx1", "sess1").await.unwrap();
        let state = tm.remove("tx1").await.unwrap();
        assert_eq!(state.session_id, "sess1");
        // Once removed, the transaction neither validates nor removes again.
        assert!(tm.validate("tx1", "sess1").await.is_err());
        assert!(tm.remove("tx1").await.is_err());
    }

    #[tokio::test]
    async fn session_can_begin_again_after_remove() {
        let tm = TransactionManager::new();
        tm.register("tx1", "sess1", proto::TransactionMode::ReadWrite)
            .await
            .unwrap();
        tm.remove("tx1").await.unwrap();
        tm.register("tx2", "sess1", proto::TransactionMode::ReadOnly)
            .await
            .unwrap();
        tm.validate("tx2", "sess1").await.unwrap();
    }

    #[tokio::test]
    async fn double_begin_fails() {
        let tm = TransactionManager::new();
        tm.register("tx1", "sess1", proto::TransactionMode::ReadWrite)
            .await
            .unwrap();
        let result = tm
            .register("tx2", "sess1", proto::TransactionMode::ReadOnly)
            .await;
        assert!(result.is_err());
        // The rejected call must leave the existing transaction intact and
        // must not register the new id.
        tm.validate("tx1", "sess1").await.unwrap();
        assert!(tm.validate("tx2", "sess1").await.is_err());
    }

    #[tokio::test]
    async fn validate_wrong_session() {
        let tm = TransactionManager::new();
        tm.register("tx1", "sess1", proto::TransactionMode::ReadWrite)
            .await
            .unwrap();
        let result = tm.validate("tx1", "sess2").await;
        assert!(result.is_err());
        // The owning session is still accepted.
        tm.validate("tx1", "sess1").await.unwrap();
    }

    #[tokio::test]
    async fn validate_unknown_transaction() {
        let tm = TransactionManager::new();
        assert!(tm.validate("missing", "sess1").await.is_err());
    }

    #[tokio::test]
    async fn remove_for_session() {
        let tm = TransactionManager::new();
        tm.register("tx1", "sess1", proto::TransactionMode::ReadWrite)
            .await
            .unwrap();
        tm.register("tx2", "sess2", proto::TransactionMode::ReadOnly)
            .await
            .unwrap();
        let removed = tm.remove_for_session("sess1").await;
        assert_eq!(removed, vec!["tx1"]);
        let result = tm.validate("tx1", "sess1").await;
        assert!(result.is_err());
        // Other sessions' transactions are untouched.
        tm.validate("tx2", "sess2").await.unwrap();
        // A second sweep of the same session finds nothing.
        assert!(tm.remove_for_session("sess1").await.is_empty());
    }
}