Skip to main content

drasi_core/interface/
session_control.rs

1// Copyright 2026 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18
19use super::IndexError;
20
21/// Backend-implemented lifecycle control for session-scoped transactions.
22///
23/// Implementations manage the begin/commit/rollback lifecycle for atomic
24/// writes across all indexes during a single source-change processing.
25/// Session state lives inside backend implementations, shared via `Arc`.
26#[async_trait]
27pub trait SessionControl: Send + Sync {
28    /// Begin a new session-scoped transaction.
29    async fn begin(&self) -> Result<(), IndexError>;
30
31    /// Commit the current session-scoped transaction.
32    async fn commit(&self) -> Result<(), IndexError>;
33
34    /// Roll back the current session-scoped transaction.
35    ///
36    /// This is synchronous to be safe for use in `Drop` implementations.
37    /// Returns an error if the rollback itself fails (e.g., mutex poisoned).
38    fn rollback(&self) -> Result<(), IndexError>;
39}
40
41/// No-op implementation of `SessionControl`.
42///
43/// All methods are no-ops that return `Ok(())`. Used for in-memory backends
44/// and as the default when no session control is configured.
45pub struct NoOpSessionControl;
46
47#[async_trait]
48impl SessionControl for NoOpSessionControl {
49    async fn begin(&self) -> Result<(), IndexError> {
50        Ok(())
51    }
52
53    async fn commit(&self) -> Result<(), IndexError> {
54        Ok(())
55    }
56
57    fn rollback(&self) -> Result<(), IndexError> {
58        Ok(())
59    }
60}
61
62/// RAII guard for session-scoped transactions.
63///
64/// Created via `SessionGuard::begin()`, which calls `control.begin()`.
65/// On drop, automatically calls `control.rollback()` unless `commit()`
66/// was called first.
67pub struct SessionGuard {
68    control: Arc<dyn SessionControl>,
69    committed: bool,
70}
71
72impl SessionGuard {
73    /// Begin a new session, returning an RAII guard.
74    ///
75    /// Calls `control.begin()` and returns a guard that will automatically
76    /// roll back on drop unless `commit()` is called.
77    pub async fn begin(control: Arc<dyn SessionControl>) -> Result<Self, IndexError> {
78        control.begin().await?;
79        Ok(Self {
80            control,
81            committed: false,
82        })
83    }
84
85    /// Commit the session-scoped transaction.
86    ///
87    /// Marks the guard as committed so that `Drop` will not roll back.
88    /// The committed flag is set only after a successful commit — if
89    /// `control.commit()` fails, `Drop` will still trigger rollback.
90    pub async fn commit(mut self) -> Result<(), IndexError> {
91        self.control.commit().await?;
92        self.committed = true;
93        Ok(())
94    }
95}
96
97impl Drop for SessionGuard {
98    fn drop(&mut self) {
99        if !self.committed {
100            if let Err(e) = self.control.rollback() {
101                log::error!("Session rollback failed: {e}");
102            }
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Canned result for one mock method: `None` means "succeed",
    /// `Some(r)` is returned (once) instead.
    type ResultSlot = Mutex<Option<Result<(), IndexError>>>;

    /// Mock that records every method call and can be told to fail any
    /// individual method once.
    struct MockSessionControl {
        calls: Mutex<Vec<&'static str>>,
        begin_result: ResultSlot,
        commit_result: ResultSlot,
        rollback_result: ResultSlot,
    }

    impl MockSessionControl {
        fn new() -> Self {
            Self {
                calls: Mutex::new(Vec::new()),
                begin_result: Mutex::new(None),
                commit_result: Mutex::new(None),
                rollback_result: Mutex::new(None),
            }
        }

        fn fail_begin(error: IndexError) -> Self {
            Self {
                begin_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        fn fail_commit(error: IndexError) -> Self {
            Self {
                commit_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        fn fail_rollback(error: IndexError) -> Self {
            Self {
                rollback_result: Mutex::new(Some(Err(error))),
                ..Self::new()
            }
        }

        fn record(&self, call: &'static str) {
            self.calls.lock().expect("lock poisoned").push(call);
        }

        fn calls(&self) -> Vec<&'static str> {
            self.calls.lock().expect("lock poisoned").clone()
        }
    }

    /// Takes the canned result out of `slot`, defaulting to success.
    fn take_result(slot: &ResultSlot) -> Result<(), IndexError> {
        slot.lock().expect("lock poisoned").take().unwrap_or(Ok(()))
    }

    #[async_trait]
    impl SessionControl for MockSessionControl {
        async fn begin(&self) -> Result<(), IndexError> {
            self.record("begin");
            take_result(&self.begin_result)
        }

        async fn commit(&self) -> Result<(), IndexError> {
            self.record("commit");
            take_result(&self.commit_result)
        }

        fn rollback(&self) -> Result<(), IndexError> {
            self.record("rollback");
            take_result(&self.rollback_result)
        }
    }

    #[tokio::test]
    async fn begin_calls_control_begin() {
        let mock = Arc::new(MockSessionControl::new());
        let _guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        // Exact match: begin must not trigger any other call while the guard lives.
        assert_eq!(mock.calls(), vec!["begin"]);
    }

    #[tokio::test]
    async fn failed_begin_does_not_roll_back() {
        let mock = Arc::new(MockSessionControl::fail_begin(IndexError::CorruptedData));
        let result = SessionGuard::begin(mock.clone()).await;
        assert!(result.is_err());
        assert_eq!(mock.calls(), vec!["begin"]);
    }

    #[tokio::test]
    async fn commit_suppresses_rollback_on_drop() {
        let mock = Arc::new(MockSessionControl::new());
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        guard.commit().await.expect("commit");
        assert_eq!(mock.calls(), vec!["begin", "commit"]);
    }

    #[tokio::test]
    async fn drop_without_commit_triggers_rollback() {
        let mock = Arc::new(MockSessionControl::new());
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        drop(guard);
        assert_eq!(mock.calls(), vec!["begin", "rollback"]);
    }

    #[tokio::test]
    async fn failed_commit_still_triggers_rollback() {
        let mock = Arc::new(MockSessionControl::fail_commit(IndexError::CorruptedData));
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        let result = guard.commit().await;
        assert!(result.is_err());
        assert_eq!(mock.calls(), vec!["begin", "commit", "rollback"]);
    }

    #[tokio::test]
    async fn rollback_error_in_drop_is_not_fatal() {
        let mock = Arc::new(MockSessionControl::fail_rollback(IndexError::CorruptedData));
        let guard = SessionGuard::begin(mock.clone()).await.expect("begin");
        // Must log and return rather than panic.
        drop(guard);
        assert_eq!(mock.calls(), vec!["begin", "rollback"]);
    }

    #[tokio::test]
    async fn noop_control_round_trip() {
        let guard = SessionGuard::begin(Arc::new(NoOpSessionControl))
            .await
            .expect("begin");
        guard.commit().await.expect("commit");
    }
}