drasi_core/interface/
session_control.rs1use std::sync::Arc;
16
17use async_trait::async_trait;
18
19use super::IndexError;
20
21#[async_trait]
27pub trait SessionControl: Send + Sync {
28 async fn begin(&self) -> Result<(), IndexError>;
30
31 async fn commit(&self) -> Result<(), IndexError>;
33
34 fn rollback(&self) -> Result<(), IndexError>;
39}
40
41pub struct NoOpSessionControl;
46
47#[async_trait]
48impl SessionControl for NoOpSessionControl {
49 async fn begin(&self) -> Result<(), IndexError> {
50 Ok(())
51 }
52
53 async fn commit(&self) -> Result<(), IndexError> {
54 Ok(())
55 }
56
57 fn rollback(&self) -> Result<(), IndexError> {
58 Ok(())
59 }
60}
61
62pub struct SessionGuard {
68 control: Arc<dyn SessionControl>,
69 committed: bool,
70}
71
72impl SessionGuard {
73 pub async fn begin(control: Arc<dyn SessionControl>) -> Result<Self, IndexError> {
78 control.begin().await?;
79 Ok(Self {
80 control,
81 committed: false,
82 })
83 }
84
85 pub async fn commit(mut self) -> Result<(), IndexError> {
91 self.control.commit().await?;
92 self.committed = true;
93 Ok(())
94 }
95}
96
97impl Drop for SessionGuard {
98 fn drop(&mut self) {
99 if !self.committed {
100 if let Err(e) = self.control.rollback() {
101 log::error!("Session rollback failed: {e}");
102 }
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test double that records every trait call it receives.
    ///
    /// It can be primed to fail the next `begin` or `commit` call once.
    struct MockSessionControl {
        /// Names of the trait methods invoked, in call order.
        calls: Mutex<Vec<&'static str>>,
        /// Result returned by the next `begin`; `None` means `Ok(())`.
        begin_result: Mutex<Option<Result<(), IndexError>>>,
        /// Result returned by the next `commit`; `None` means `Ok(())`.
        commit_result: Mutex<Option<Result<(), IndexError>>>,
    }

    impl MockSessionControl {
        fn new() -> Self {
            Self {
                calls: Mutex::new(Vec::new()),
                begin_result: Mutex::new(None),
                commit_result: Mutex::new(None),
            }
        }

        fn fail_begin(error: IndexError) -> Self {
            Self {
                begin_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        fn fail_commit(error: IndexError) -> Self {
            Self {
                commit_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        fn record(&self, call: &'static str) {
            self.calls.lock().expect("lock poisoned").push(call);
        }

        fn calls(&self) -> Vec<&'static str> {
            self.calls.lock().expect("lock poisoned").clone()
        }
    }

    #[async_trait]
    impl SessionControl for MockSessionControl {
        async fn begin(&self) -> Result<(), IndexError> {
            self.record("begin");
            // Bind first so the MutexGuard is released before the tail expression.
            let primed = self.begin_result.lock().expect("lock poisoned").take();
            primed.unwrap_or(Ok(()))
        }

        async fn commit(&self) -> Result<(), IndexError> {
            self.record("commit");
            let primed = self.commit_result.lock().expect("lock poisoned").take();
            primed.unwrap_or(Ok(()))
        }

        fn rollback(&self) -> Result<(), IndexError> {
            self.record("rollback");
            Ok(())
        }
    }

    #[tokio::test]
    async fn begin_calls_control_begin() {
        let mock = Arc::new(MockSessionControl::new());
        let _guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        // Exactly one call so far: no commit and no premature rollback.
        assert_eq!(mock.calls(), vec!["begin"]);
    }

    #[tokio::test]
    async fn failed_begin_returns_error_without_rollback() {
        let mock = Arc::new(MockSessionControl::fail_begin(IndexError::CorruptedData));
        let result = SessionGuard::begin(mock.clone()).await;
        assert!(result.is_err());
        // No guard was created, so nothing must be rolled back.
        assert_eq!(mock.calls(), vec!["begin"]);
    }

    #[tokio::test]
    async fn commit_suppresses_rollback_on_drop() {
        let mock = Arc::new(MockSessionControl::new());
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        guard.commit().await.expect("commit");
        assert_eq!(mock.calls(), vec!["begin", "commit"]);
    }

    #[tokio::test]
    async fn drop_without_commit_triggers_rollback() {
        let mock = Arc::new(MockSessionControl::new());
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        drop(guard);
        assert_eq!(mock.calls(), vec!["begin", "rollback"]);
    }

    #[tokio::test]
    async fn failed_commit_still_triggers_rollback() {
        let mock = Arc::new(MockSessionControl::fail_commit(IndexError::CorruptedData));
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        let result = guard.commit().await;
        assert!(result.is_err());
        assert_eq!(mock.calls(), vec!["begin", "commit", "rollback"]);
    }

    #[tokio::test]
    async fn noop_control_begin_and_commit_succeed() {
        let guard = SessionGuard::begin(Arc::new(NoOpSessionControl))
            .await
            .expect("begin");
        guard.commit().await.expect("commit");
    }
}