use std::sync::Arc;
use async_trait::async_trait;
use super::IndexError;
/// Backend hooks that drive the lifecycle of a session managed by `SessionGuard`.
///
/// `begin` and `commit` are async, while `rollback` is synchronous: in this
/// file `rollback` is invoked from `SessionGuard`'s `Drop` impl, and `drop`
/// cannot `.await`.
///
/// NOTE(review): since `rollback` may therefore run inside `drop` on an async
/// task, implementations presumably should avoid long blocking work — confirm.
#[async_trait]
pub trait SessionControl: Send + Sync {
/// Starts a session. `SessionGuard::begin` calls this and only creates a
/// guard if it returns `Ok`.
async fn begin(&self) -> Result<(), IndexError>;
/// Commits the session. Called by `SessionGuard::commit`; if it returns
/// `Err`, the guard is dropped uncommitted and `rollback` is invoked.
async fn commit(&self) -> Result<(), IndexError>;
/// Rolls the session back. Called synchronously when an uncommitted
/// `SessionGuard` is dropped; an `Err` there is logged, not propagated.
fn rollback(&self) -> Result<(), IndexError>;
}
/// A [`SessionControl`] implementation whose operations all succeed
/// immediately without doing anything.
///
/// Handy, for example, wherever a `SessionGuard` is required but there is no
/// real session backend to drive.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoOpSessionControl;
#[async_trait]
impl SessionControl for NoOpSessionControl {
    // Every hook reports success without touching any backend, so a guard
    // driven by this control can begin, commit, or roll back freely.

    fn rollback(&self) -> Result<(), IndexError> {
        Ok(())
    }

    async fn begin(&self) -> Result<(), IndexError> {
        Ok(())
    }

    async fn commit(&self) -> Result<(), IndexError> {
        Ok(())
    }
}
/// RAII guard for a session started through a [`SessionControl`].
///
/// Create one with `SessionGuard::begin` and finish it with
/// `SessionGuard::commit`. If the guard is dropped without a successful
/// commit, `SessionControl::rollback` is called from its `Drop` impl.
#[must_use = "dropping a SessionGuard without calling `commit` rolls the session back"]
pub struct SessionGuard {
    /// Backend that receives the begin/commit/rollback calls.
    control: Arc<dyn SessionControl>,
    /// Set only after `control.commit()` succeeds; suppresses rollback on drop.
    committed: bool,
}

impl std::fmt::Debug for SessionGuard {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn SessionControl` carries no `Debug` bound, so the backend is omitted.
        f.debug_struct("SessionGuard")
            .field("committed", &self.committed)
            .finish_non_exhaustive()
    }
}
impl SessionGuard {
    /// Starts a session on `control` and returns a guard that tracks it.
    ///
    /// The guard is constructed only after `control.begin()` succeeds, so a
    /// session that failed to start is never rolled back.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `control.begin()`.
    pub async fn begin(control: Arc<dyn SessionControl>) -> Result<Self, IndexError> {
        let started = control.begin().await;
        started.map(|()| Self {
            control,
            committed: false,
        })
    }

    /// Commits the session, consuming the guard.
    ///
    /// On success the guard is marked committed, so dropping it afterwards
    /// does not roll back. On failure the guard is dropped uncommitted when
    /// this function returns, which triggers a rollback.
    ///
    /// # Errors
    ///
    /// Returns the error reported by `control.commit()`.
    pub async fn commit(mut self) -> Result<(), IndexError> {
        let outcome = self.control.commit().await;
        if outcome.is_ok() {
            self.committed = true;
        }
        outcome
    }
}
impl Drop for SessionGuard {
    /// Rolls the session back unless `commit` already succeeded.
    ///
    /// `Drop` has no way to return an error, so a failed rollback is logged
    /// instead of propagated.
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        if let Some(e) = self.control.rollback().err() {
            log::error!("Session rollback failed: {e}");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test double that records every `SessionControl` call in order and can
    /// be primed to make the first `commit` or `rollback` fail.
    struct MockSessionControl {
        calls: Mutex<Vec<&'static str>>,
        commit_result: Mutex<Option<Result<(), IndexError>>>,
        rollback_result: Mutex<Option<Result<(), IndexError>>>,
    }

    impl MockSessionControl {
        fn new() -> Self {
            Self {
                calls: Mutex::new(Vec::new()),
                commit_result: Mutex::new(None),
                rollback_result: Mutex::new(None),
            }
        }

        /// A mock whose first `commit` returns `Err(error)`.
        fn fail_commit(error: IndexError) -> Self {
            Self {
                commit_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        /// A mock whose first `rollback` returns `Err(error)`.
        fn fail_rollback(error: IndexError) -> Self {
            Self {
                rollback_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        fn calls(&self) -> Vec<&'static str> {
            self.calls.lock().expect("lock poisoned").clone()
        }

        fn record(&self, call: &'static str) {
            self.calls.lock().expect("lock poisoned").push(call);
        }
    }

    #[async_trait]
    impl SessionControl for MockSessionControl {
        async fn begin(&self) -> Result<(), IndexError> {
            self.record("begin");
            Ok(())
        }

        async fn commit(&self) -> Result<(), IndexError> {
            self.record("commit");
            // `take` consumes the primed result so only the first call fails.
            self.commit_result
                .lock()
                .expect("lock poisoned")
                .take()
                .unwrap_or(Ok(()))
        }

        fn rollback(&self) -> Result<(), IndexError> {
            self.record("rollback");
            self.rollback_result
                .lock()
                .expect("lock poisoned")
                .take()
                .unwrap_or(Ok(()))
        }
    }

    #[tokio::test]
    async fn begin_calls_control_begin() {
        let mock = Arc::new(MockSessionControl::new());
        let _guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        // While the guard is alive, nothing besides `begin` may have happened.
        assert_eq!(mock.calls(), vec!["begin"]);
    }

    #[tokio::test]
    async fn commit_suppresses_rollback_on_drop() {
        let mock = Arc::new(MockSessionControl::new());
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        guard.commit().await.expect("commit");
        assert_eq!(mock.calls(), vec!["begin", "commit"]);
    }

    #[tokio::test]
    async fn drop_without_commit_triggers_rollback() {
        let mock = Arc::new(MockSessionControl::new());
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        drop(guard);
        assert_eq!(mock.calls(), vec!["begin", "rollback"]);
    }

    #[tokio::test]
    async fn failed_commit_still_triggers_rollback() {
        let mock = Arc::new(MockSessionControl::fail_commit(IndexError::CorruptedData));
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        let result = guard.commit().await;
        assert!(result.is_err());
        assert_eq!(mock.calls(), vec!["begin", "commit", "rollback"]);
    }

    #[tokio::test]
    async fn failed_rollback_in_drop_does_not_panic() {
        let mock = Arc::new(MockSessionControl::fail_rollback(IndexError::CorruptedData));
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        // The rollback error must be swallowed (logged) by `Drop`, not panic.
        drop(guard);
        assert_eq!(mock.calls(), vec!["begin", "rollback"]);
    }

    #[tokio::test]
    async fn noop_control_begins_and_commits() {
        let guard = SessionGuard::begin(Arc::new(NoOpSessionControl))
            .await
            .expect("begin");
        guard.commit().await.expect("commit");
    }
}