1use std::sync::Arc;
2use std::time::{Duration, SystemTime};
3
4use alopex_core::async_runtime::{BoxFuture, BoxStream};
5use alopex_sql::executor::{ExecutionResult, ExecutorError, Row};
6use alopex_sql::storage::erased::ErasedAsyncSqlTransaction;
7use dashmap::DashMap;
8use futures::StreamExt;
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use uuid::Uuid;
12
13use crate::error::{Result, ServerError};
14
15#[derive(Clone, Debug, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
17pub struct SessionId(Uuid);
18
19impl SessionId {
20 pub fn new() -> Self {
21 Self(Uuid::new_v4())
22 }
23}
24
25impl Default for SessionId {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl std::fmt::Display for SessionId {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 write!(f, "{}", self.0)
34 }
35}
36
37impl std::str::FromStr for SessionId {
38 type Err = uuid::Error;
39
40 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
41 Ok(Self(Uuid::parse_str(s)?))
42 }
43}
44
/// Lifecycle state of a session as tracked by `SessionManager`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum SessionState {
    /// Initial state assigned by `create_session`.
    Idle,
    /// A transaction handle is attached; set by `begin_transaction` and `get_transaction`.
    InTransaction,
    /// Set when `commit` detaches the session's transaction handle to commit it.
    Committing,
    /// Set when `rollback` detaches the session's transaction handle to roll it back.
    RollingBack,
}
53
/// Point-in-time copy of a session's metadata, as returned by `SessionManager::get_session`.
#[derive(Clone, Debug, serde::Serialize)]
pub struct SessionSnapshot {
    /// Identifier of the session.
    pub id: SessionId,
    /// Whether a transaction handle was attached to the session when the snapshot was taken.
    pub has_transaction: bool,
    /// When the session was created.
    pub created_at: SystemTime,
    /// Last time the session was used by a begin/get/execute/commit/rollback operation.
    pub last_active: SystemTime,
    /// The session is treated as expired once the current time reaches this value.
    pub expires_at: SystemTime,
    /// Lifecycle state at snapshot time.
    pub state: SessionState,
}
64
/// Shared handle to a single open SQL transaction.
///
/// Cloning is cheap (it clones an `Arc`); all clones refer to the same transaction. After
/// `commit` or `rollback` the transaction slot is emptied, and further operations through any
/// clone fail with "transaction is closed".
#[derive(Clone)]
pub struct TxnHandle {
    inner: Arc<TxnHandleInner>,
}

/// State shared by all clones of a `TxnHandle`.
struct TxnHandleInner {
    // `None` once the transaction has been committed or rolled back. An async (tokio) mutex
    // is used because the guard is held across `.await` while statements run.
    txn: tokio::sync::Mutex<Option<Box<dyn ErasedAsyncSqlTransaction>>>,
    // Wall-clock time at which the handle was constructed.
    created_at: SystemTime,
}
75
76impl TxnHandle {
77 pub fn new(txn: Box<dyn ErasedAsyncSqlTransaction>) -> Self {
78 Self {
79 inner: Arc::new(TxnHandleInner {
80 txn: tokio::sync::Mutex::new(Some(txn)),
81 created_at: SystemTime::now(),
82 }),
83 }
84 }
85
86 pub fn created_at(&self) -> SystemTime {
87 self.inner.created_at
88 }
89
90 pub fn execute<'a>(
91 &'a self,
92 sql: &'a str,
93 ) -> BoxFuture<'a, alopex_sql::executor::Result<ExecutionResult>> {
94 Box::pin(async move {
95 let mut guard = self.inner.txn.lock().await;
96 let txn = guard
97 .as_mut()
98 .ok_or_else(|| ExecutorError::InvalidOperation {
99 operation: "execute".into(),
100 reason: "transaction is closed".into(),
101 })?;
102 txn.execute(sql).await
103 })
104 }
105
106 pub fn query<'a>(&'a self, sql: &'a str) -> BoxStream<'a, alopex_sql::executor::Result<Row>> {
107 let (sender, receiver) = mpsc::channel(32);
108 let sql = sql.to_string();
109 let inner = Arc::clone(&self.inner);
110
111 tokio::spawn(async move {
112 let guard = inner.txn.lock().await;
113 let Some(txn) = guard.as_ref() else {
114 let _ = sender
115 .send(Err(ExecutorError::InvalidOperation {
116 operation: "query".into(),
117 reason: "transaction is closed".into(),
118 }))
119 .await;
120 return;
121 };
122 let mut stream = txn.query(&sql);
123 while let Some(item) = stream.next().await {
124 if sender.send(item).await.is_err() {
125 break;
126 }
127 }
128 });
129
130 Box::pin(ReceiverStream::new(receiver))
131 }
132
133 pub async fn commit(self) -> alopex_sql::executor::Result<()> {
134 let mut guard = self.inner.txn.lock().await;
135 let txn = guard
136 .take()
137 .ok_or_else(|| ExecutorError::InvalidOperation {
138 operation: "commit".into(),
139 reason: "transaction is closed".into(),
140 })?;
141 txn.commit_boxed().await
142 }
143
144 pub async fn rollback(self) -> alopex_sql::executor::Result<()> {
145 let mut guard = self.inner.txn.lock().await;
146 let txn = guard
147 .take()
148 .ok_or_else(|| ExecutorError::InvalidOperation {
149 operation: "rollback".into(),
150 reason: "transaction is closed".into(),
151 })?;
152 txn.rollback_boxed().await
153 }
154}
155
/// Tunables for `SessionManager`.
#[derive(Clone, Copy, Debug)]
pub struct SessionConfig {
    /// Session lifetime: a session's `expires_at` is its creation time plus this duration.
    pub ttl: Duration,
}
161
/// Asynchronous constructor for new transactions.
///
/// `SessionManager::begin_transaction` calls it to open each transaction it begins; an error
/// it returns is propagated to that caller.
pub type TransactionFactory =
    Arc<dyn Fn() -> BoxFuture<'static, Result<Box<dyn ErasedAsyncSqlTransaction>>> + Send + Sync>;
165
/// In-memory registry of sessions and their optional active transactions.
///
/// Sessions live in a concurrent `DashMap` keyed by id. Expired sessions are removed lazily
/// when accessed, or in bulk by `cleanup_expired`.
pub struct SessionManager {
    sessions: DashMap<SessionId, Session>,
    // Provides the `ttl` used to compute each session's expiry.
    config: SessionConfig,
    // Opens a new transaction for `begin_transaction`.
    txn_factory: TransactionFactory,
}

/// Internal per-session record stored in `SessionManager::sessions`.
struct Session {
    id: SessionId,
    // Attached transaction, if one has been begun and not yet taken for commit/rollback.
    txn_handle: Option<TxnHandle>,
    created_at: SystemTime,
    // Refreshed by operations that use the session; this does not extend `expires_at`.
    last_active: SystemTime,
    // Set once at creation (`created_at + ttl`); the session is expired once now >= this.
    expires_at: SystemTime,
    state: SessionState,
}
181
182impl SessionManager {
183 pub fn new(config: SessionConfig, txn_factory: TransactionFactory) -> Self {
184 Self {
185 sessions: DashMap::new(),
186 config,
187 txn_factory,
188 }
189 }
190
191 pub async fn create_session(&self) -> Result<SessionId> {
192 let now = SystemTime::now();
193 let id = SessionId::new();
194 let session = Session {
195 id: id.clone(),
196 txn_handle: None,
197 created_at: now,
198 last_active: now,
199 expires_at: now + self.config.ttl,
200 state: SessionState::Idle,
201 };
202 self.sessions.insert(id.clone(), session);
203 Ok(id)
204 }
205
206 pub async fn get_session(&self, id: &SessionId) -> Result<SessionSnapshot> {
207 let entry = self
208 .sessions
209 .get(id)
210 .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
211 if entry.expires_at <= SystemTime::now() {
212 drop(entry);
213 self.sessions.remove(id);
214 return Err(ServerError::SessionExpired("session expired".into()));
215 }
216 Ok(SessionSnapshot {
217 id: entry.id.clone(),
218 has_transaction: entry.txn_handle.is_some(),
219 created_at: entry.created_at,
220 last_active: entry.last_active,
221 expires_at: entry.expires_at,
222 state: entry.state,
223 })
224 }
225
226 pub async fn begin_transaction(&self, id: &SessionId) -> Result<TxnHandle> {
227 let mut entry = self
228 .sessions
229 .get_mut(id)
230 .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
231 if entry.expires_at <= SystemTime::now() {
232 drop(entry);
233 self.sessions.remove(id);
234 return Err(ServerError::SessionExpired("session expired".into()));
235 }
236 if entry.txn_handle.is_some() {
237 return Err(ServerError::Conflict("transaction already active".into()));
238 }
239 let txn = (self.txn_factory)().await?;
240 let handle = TxnHandle::new(txn);
241 entry.txn_handle = Some(handle.clone());
242 entry.last_active = SystemTime::now();
243 entry.state = SessionState::InTransaction;
244 Ok(handle)
245 }
246
247 pub async fn get_transaction(&self, id: &SessionId) -> Result<TxnHandle> {
248 let mut entry = self
249 .sessions
250 .get_mut(id)
251 .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
252 if entry.expires_at <= SystemTime::now() {
253 drop(entry);
254 self.sessions.remove(id);
255 return Err(ServerError::SessionExpired("session expired".into()));
256 }
257 let handle = entry
258 .txn_handle
259 .clone()
260 .ok_or_else(|| ServerError::BadRequest("transaction not started".into()))?;
261 entry.last_active = SystemTime::now();
262 entry.state = SessionState::InTransaction;
263 Ok(handle)
264 }
265
266 pub async fn execute_in_session(&self, id: &SessionId, sql: &str) -> Result<ExecutionResult> {
267 let handle = {
268 let mut entry = self
269 .sessions
270 .get_mut(id)
271 .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
272 if entry.expires_at <= SystemTime::now() {
273 drop(entry);
274 self.sessions.remove(id);
275 return Err(ServerError::SessionExpired("session expired".into()));
276 }
277 let handle = entry
278 .txn_handle
279 .clone()
280 .ok_or_else(|| ServerError::BadRequest("transaction not started".into()))?;
281 entry.last_active = SystemTime::now();
282 handle
283 };
284
285 handle
286 .execute(sql)
287 .await
288 .map_err(|err| ServerError::Sql(err.into()))
289 }
290
291 pub async fn commit(&self, id: &SessionId) -> Result<()> {
292 let handle = self.take_handle(id, SessionState::Committing)?;
293 handle
294 .commit()
295 .await
296 .map_err(|err| ServerError::Sql(err.into()))?;
297 Ok(())
298 }
299
300 pub async fn rollback(&self, id: &SessionId) -> Result<()> {
301 let handle = self.take_handle(id, SessionState::RollingBack)?;
302 handle
303 .rollback()
304 .await
305 .map_err(|err| ServerError::Sql(err.into()))?;
306 Ok(())
307 }
308
309 pub fn cleanup_expired(&self) {
310 let now = SystemTime::now();
311 let expired: Vec<SessionId> = self
312 .sessions
313 .iter()
314 .filter(|entry| entry.expires_at <= now)
315 .map(|entry| entry.id.clone())
316 .collect();
317 for id in expired {
318 self.sessions.remove(&id);
319 }
320 }
321
322 fn take_handle(&self, id: &SessionId, state: SessionState) -> Result<TxnHandle> {
323 let mut entry = self
324 .sessions
325 .get_mut(id)
326 .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
327 if entry.expires_at <= SystemTime::now() {
328 drop(entry);
329 self.sessions.remove(id);
330 return Err(ServerError::SessionExpired("session expired".into()));
331 }
332 let handle = entry
333 .txn_handle
334 .take()
335 .ok_or_else(|| ServerError::BadRequest("transaction not started".into()))?;
336 entry.state = state;
337 entry.last_active = SystemTime::now();
338 Ok(handle)
339 }
340}