use serde_json::Value;
use crate::connection::client::DatabaseClient;
use crate::error::{Result, SurqlError};
/// Lifecycle state of a [`Transaction`].
///
/// A transaction starts out [`Active`](TransactionState::Active) and moves to
/// exactly one of the other states when it is committed or rolled back.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TransactionState {
    /// Statements may still be queued; nothing has been committed or discarded.
    Active,
    /// The queued batch was sent and the client reported success.
    Committed,
    /// The queued statements were discarded without being sent.
    RolledBack,
    /// The commit query returned an error.
    Failed,
}
/// A client-side transaction that buffers statements until it is committed.
///
/// Statements are collected in memory and only sent to the database, as one
/// batch, when the transaction is committed. Rolling back simply discards the
/// buffer.
#[derive(Debug)]
pub struct Transaction<'a> {
    /// Client used to send the buffered batch on commit.
    client: &'a DatabaseClient,
    /// Statements queued so far, in the order they were added.
    statements: Vec<String>,
    /// Current lifecycle state of this transaction.
    state: TransactionState,
}
impl<'a> Transaction<'a> {
    /// Starts a new buffered transaction bound to `client`.
    ///
    /// Nothing is sent to the database here: the call only checks that the
    /// client reports being connected and then creates an empty statement
    /// buffer in the [`TransactionState::Active`] state.
    ///
    /// # Errors
    ///
    /// Returns [`SurqlError::Transaction`] when `client.is_connected()` is
    /// `false`.
    // Nothing is awaited in the body; `async` is kept so existing call sites
    // that `.await` this constructor keep compiling.
    #[allow(clippy::unused_async)]
    pub async fn begin(client: &'a DatabaseClient) -> Result<Transaction<'a>> {
        if !client.is_connected() {
            return Err(SurqlError::Transaction {
                reason: "cannot begin transaction: client is not connected".into(),
            });
        }
        Ok(Self {
            client,
            statements: Vec::new(),
            state: TransactionState::Active,
        })
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> TransactionState {
        self.state
    }

    /// Returns `true` while the transaction is in [`TransactionState::Active`].
    pub fn is_active(&self) -> bool {
        self.state == TransactionState::Active
    }

    /// Queues one statement to be sent when the transaction is committed.
    ///
    /// Leading whitespace and the whole trailing run of `;` and whitespace are
    /// removed, because [`commit`](Self::commit) appends its own `;` after
    /// every statement. Input that is blank after this normalization (for
    /// example `""`, `"  "` or `";"`) is not queued, so the committed batch
    /// never contains an empty statement.
    ///
    /// The statement is not run here, so the success value is always
    /// [`Value::Null`].
    ///
    /// # Errors
    ///
    /// Returns [`SurqlError::Transaction`] when the transaction is not active.
    // Nothing is awaited in the body; `async` is kept so existing call sites
    // that `.await` this method keep compiling.
    #[allow(clippy::unused_async)]
    pub async fn execute(&mut self, surql: &str) -> Result<Value> {
        if !self.is_active() {
            return Err(SurqlError::Transaction {
                reason: format!("transaction is not active (state = {:?})", self.state),
            });
        }
        // Strip terminators and whitespace together: trimming them in two
        // separate passes would turn "x ; ;" into "x ; " and leave a stray
        // terminator behind.
        let statement = surql
            .trim_start()
            .trim_end_matches(|c: char| c == ';' || c.is_whitespace());
        if !statement.is_empty() {
            self.statements.push(statement.to_owned());
        }
        Ok(Value::Null)
    }

    /// Sends every queued statement as a single batch and consumes the
    /// transaction.
    ///
    /// The batch is the queued statements, each followed by `;`, wrapped
    /// between `BEGIN TRANSACTION;` and `COMMIT TRANSACTION;`, and is passed to
    /// the client's `query` method in one call. On success the value returned
    /// by that call is handed back unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`SurqlError::Transaction`] when the transaction is not active,
    /// or when the query fails; in the latter case the underlying error is
    /// rendered into the `reason` text.
    pub async fn commit(mut self) -> Result<Value> {
        if !self.is_active() {
            return Err(SurqlError::Transaction {
                reason: format!("cannot commit in state {:?}", self.state),
            });
        }
        let mut surql = String::from("BEGIN TRANSACTION;\n");
        for stmt in &self.statements {
            surql.push_str(stmt);
            surql.push_str(";\n");
        }
        surql.push_str("COMMIT TRANSACTION;\n");
        match self.client.query(&surql).await {
            Ok(results) => {
                self.state = TransactionState::Committed;
                Ok(results)
            }
            Err(err) => {
                self.state = TransactionState::Failed;
                Err(SurqlError::Transaction {
                    reason: format!("commit failed: {err}"),
                })
            }
        }
    }

    /// Discards every queued statement and consumes the transaction.
    ///
    /// Nothing is sent to the database: the statements only ever existed in
    /// the local buffer, so dropping them is the whole rollback.
    ///
    /// # Errors
    ///
    /// Returns [`SurqlError::Transaction`] when the transaction is not active.
    // Nothing is awaited in the body; `async` is kept so existing call sites
    // that `.await` this method keep compiling.
    #[allow(clippy::unused_async)]
    pub async fn rollback(mut self) -> Result<()> {
        if !self.is_active() {
            return Err(SurqlError::Transaction {
                reason: format!("cannot rollback in state {:?}", self.state),
            });
        }
        self.statements.clear();
        self.state = TransactionState::RolledBack;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::connection::config::ConnectionConfig;

    /// A freshly constructed client must be rejected by `Transaction::begin`
    /// with a transaction error.
    #[tokio::test]
    async fn begin_requires_connected_client() {
        let config = ConnectionConfig::default();
        let client = DatabaseClient::new(config).unwrap();
        let outcome = Transaction::begin(&client).await;
        assert!(matches!(outcome, Err(SurqlError::Transaction { .. })));
    }

    /// Spot-checks that the state variants compare unequal to each other.
    #[test]
    fn state_variants_are_distinct() {
        let distinct_pairs = [
            (TransactionState::Active, TransactionState::Committed),
            (TransactionState::RolledBack, TransactionState::Failed),
        ];
        for (left, right) in distinct_pairs {
            assert_ne!(left, right);
        }
    }
}