Skip to main content

alopex_server/
session.rs

1use std::sync::Arc;
2use std::time::{Duration, SystemTime};
3
4use alopex_core::async_runtime::{BoxFuture, BoxStream};
5use alopex_sql::executor::{ExecutionResult, ExecutorError, Row};
6use alopex_sql::storage::erased::ErasedAsyncSqlTransaction;
7use dashmap::DashMap;
8use futures::StreamExt;
9use tokio::sync::mpsc;
10use tokio_stream::wrappers::ReceiverStream;
11use uuid::Uuid;
12
13use crate::error::{Result, ServerError};
14
15/// Session identifier.
16#[derive(Clone, Debug, Eq, Hash, PartialEq, serde::Serialize, serde::Deserialize)]
17pub struct SessionId(Uuid);
18
19impl SessionId {
20    pub fn new() -> Self {
21        Self(Uuid::new_v4())
22    }
23}
24
25impl Default for SessionId {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl std::fmt::Display for SessionId {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        write!(f, "{}", self.0)
34    }
35}
36
37impl std::str::FromStr for SessionId {
38    type Err = uuid::Error;
39
40    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
41        Ok(Self(Uuid::parse_str(s)?))
42    }
43}
44
45/// Session lifecycle state.
46#[derive(Clone, Copy, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
47pub enum SessionState {
48    Idle,
49    InTransaction,
50    Committing,
51    RollingBack,
52}
53
54/// Snapshot of a session for safe sharing.
55#[derive(Clone, Debug, serde::Serialize)]
56pub struct SessionSnapshot {
57    pub id: SessionId,
58    pub has_transaction: bool,
59    pub created_at: SystemTime,
60    pub last_active: SystemTime,
61    pub expires_at: SystemTime,
62    pub state: SessionState,
63}
64
65/// Transaction handle for a session.
66#[derive(Clone)]
67pub struct TxnHandle {
68    inner: Arc<TxnHandleInner>,
69}
70
71struct TxnHandleInner {
72    txn: tokio::sync::Mutex<Option<Box<dyn ErasedAsyncSqlTransaction>>>,
73    created_at: SystemTime,
74}
75
76impl TxnHandle {
77    pub fn new(txn: Box<dyn ErasedAsyncSqlTransaction>) -> Self {
78        Self {
79            inner: Arc::new(TxnHandleInner {
80                txn: tokio::sync::Mutex::new(Some(txn)),
81                created_at: SystemTime::now(),
82            }),
83        }
84    }
85
86    pub fn created_at(&self) -> SystemTime {
87        self.inner.created_at
88    }
89
90    pub fn execute<'a>(
91        &'a self,
92        sql: &'a str,
93    ) -> BoxFuture<'a, alopex_sql::executor::Result<ExecutionResult>> {
94        Box::pin(async move {
95            let mut guard = self.inner.txn.lock().await;
96            let txn = guard
97                .as_mut()
98                .ok_or_else(|| ExecutorError::InvalidOperation {
99                    operation: "execute".into(),
100                    reason: "transaction is closed".into(),
101                })?;
102            txn.execute(sql).await
103        })
104    }
105
106    pub fn query<'a>(&'a self, sql: &'a str) -> BoxStream<'a, alopex_sql::executor::Result<Row>> {
107        let (sender, receiver) = mpsc::channel(32);
108        let sql = sql.to_string();
109        let inner = Arc::clone(&self.inner);
110
111        tokio::spawn(async move {
112            let guard = inner.txn.lock().await;
113            let Some(txn) = guard.as_ref() else {
114                let _ = sender
115                    .send(Err(ExecutorError::InvalidOperation {
116                        operation: "query".into(),
117                        reason: "transaction is closed".into(),
118                    }))
119                    .await;
120                return;
121            };
122            let mut stream = txn.query(&sql);
123            while let Some(item) = stream.next().await {
124                if sender.send(item).await.is_err() {
125                    break;
126                }
127            }
128        });
129
130        Box::pin(ReceiverStream::new(receiver))
131    }
132
133    pub async fn commit(self) -> alopex_sql::executor::Result<()> {
134        let mut guard = self.inner.txn.lock().await;
135        let txn = guard
136            .take()
137            .ok_or_else(|| ExecutorError::InvalidOperation {
138                operation: "commit".into(),
139                reason: "transaction is closed".into(),
140            })?;
141        txn.commit_boxed().await
142    }
143
144    pub async fn rollback(self) -> alopex_sql::executor::Result<()> {
145        let mut guard = self.inner.txn.lock().await;
146        let txn = guard
147            .take()
148            .ok_or_else(|| ExecutorError::InvalidOperation {
149                operation: "rollback".into(),
150                reason: "transaction is closed".into(),
151            })?;
152        txn.rollback_boxed().await
153    }
154}
155
/// Tunable settings for the session manager.
#[derive(Debug, Clone, Copy)]
pub struct SessionConfig {
    /// Lifetime of a session; added to the creation time to compute its
    /// expiry instant.
    pub ttl: Duration,
}
161
162/// Transaction factory for session manager.
163pub type TransactionFactory =
164    Arc<dyn Fn() -> BoxFuture<'static, Result<Box<dyn ErasedAsyncSqlTransaction>>> + Send + Sync>;
165
166/// Session manager for server.
167pub struct SessionManager {
168    sessions: DashMap<SessionId, Session>,
169    config: SessionConfig,
170    txn_factory: TransactionFactory,
171}
172
173struct Session {
174    id: SessionId,
175    txn_handle: Option<TxnHandle>,
176    created_at: SystemTime,
177    last_active: SystemTime,
178    expires_at: SystemTime,
179    state: SessionState,
180}
181
182impl SessionManager {
183    pub fn new(config: SessionConfig, txn_factory: TransactionFactory) -> Self {
184        Self {
185            sessions: DashMap::new(),
186            config,
187            txn_factory,
188        }
189    }
190
191    pub async fn create_session(&self) -> Result<SessionId> {
192        let now = SystemTime::now();
193        let id = SessionId::new();
194        let session = Session {
195            id: id.clone(),
196            txn_handle: None,
197            created_at: now,
198            last_active: now,
199            expires_at: now + self.config.ttl,
200            state: SessionState::Idle,
201        };
202        self.sessions.insert(id.clone(), session);
203        Ok(id)
204    }
205
206    pub async fn get_session(&self, id: &SessionId) -> Result<SessionSnapshot> {
207        let entry = self
208            .sessions
209            .get(id)
210            .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
211        if entry.expires_at <= SystemTime::now() {
212            drop(entry);
213            self.sessions.remove(id);
214            return Err(ServerError::SessionExpired("session expired".into()));
215        }
216        Ok(SessionSnapshot {
217            id: entry.id.clone(),
218            has_transaction: entry.txn_handle.is_some(),
219            created_at: entry.created_at,
220            last_active: entry.last_active,
221            expires_at: entry.expires_at,
222            state: entry.state,
223        })
224    }
225
226    pub async fn begin_transaction(&self, id: &SessionId) -> Result<TxnHandle> {
227        let mut entry = self
228            .sessions
229            .get_mut(id)
230            .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
231        if entry.expires_at <= SystemTime::now() {
232            drop(entry);
233            self.sessions.remove(id);
234            return Err(ServerError::SessionExpired("session expired".into()));
235        }
236        if entry.txn_handle.is_some() {
237            return Err(ServerError::Conflict("transaction already active".into()));
238        }
239        let txn = (self.txn_factory)().await?;
240        let handle = TxnHandle::new(txn);
241        entry.txn_handle = Some(handle.clone());
242        entry.last_active = SystemTime::now();
243        entry.state = SessionState::InTransaction;
244        Ok(handle)
245    }
246
247    pub async fn get_transaction(&self, id: &SessionId) -> Result<TxnHandle> {
248        let mut entry = self
249            .sessions
250            .get_mut(id)
251            .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
252        if entry.expires_at <= SystemTime::now() {
253            drop(entry);
254            self.sessions.remove(id);
255            return Err(ServerError::SessionExpired("session expired".into()));
256        }
257        let handle = entry
258            .txn_handle
259            .clone()
260            .ok_or_else(|| ServerError::BadRequest("transaction not started".into()))?;
261        entry.last_active = SystemTime::now();
262        entry.state = SessionState::InTransaction;
263        Ok(handle)
264    }
265
266    pub async fn execute_in_session(&self, id: &SessionId, sql: &str) -> Result<ExecutionResult> {
267        let handle = {
268            let mut entry = self
269                .sessions
270                .get_mut(id)
271                .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
272            if entry.expires_at <= SystemTime::now() {
273                drop(entry);
274                self.sessions.remove(id);
275                return Err(ServerError::SessionExpired("session expired".into()));
276            }
277            let handle = entry
278                .txn_handle
279                .clone()
280                .ok_or_else(|| ServerError::BadRequest("transaction not started".into()))?;
281            entry.last_active = SystemTime::now();
282            handle
283        };
284
285        handle
286            .execute(sql)
287            .await
288            .map_err(|err| ServerError::Sql(err.into()))
289    }
290
291    pub async fn commit(&self, id: &SessionId) -> Result<()> {
292        let handle = self.take_handle(id, SessionState::Committing)?;
293        handle
294            .commit()
295            .await
296            .map_err(|err| ServerError::Sql(err.into()))?;
297        Ok(())
298    }
299
300    pub async fn rollback(&self, id: &SessionId) -> Result<()> {
301        let handle = self.take_handle(id, SessionState::RollingBack)?;
302        handle
303            .rollback()
304            .await
305            .map_err(|err| ServerError::Sql(err.into()))?;
306        Ok(())
307    }
308
309    pub fn cleanup_expired(&self) {
310        let now = SystemTime::now();
311        let expired: Vec<SessionId> = self
312            .sessions
313            .iter()
314            .filter(|entry| entry.expires_at <= now)
315            .map(|entry| entry.id.clone())
316            .collect();
317        for id in expired {
318            self.sessions.remove(&id);
319        }
320    }
321
322    fn take_handle(&self, id: &SessionId, state: SessionState) -> Result<TxnHandle> {
323        let mut entry = self
324            .sessions
325            .get_mut(id)
326            .ok_or_else(|| ServerError::NotFound("session not found".into()))?;
327        if entry.expires_at <= SystemTime::now() {
328            drop(entry);
329            self.sessions.remove(id);
330            return Err(ServerError::SessionExpired("session expired".into()));
331        }
332        let handle = entry
333            .txn_handle
334            .take()
335            .ok_or_else(|| ServerError::BadRequest("transaction not started".into()))?;
336        entry.state = state;
337        entry.last_active = SystemTime::now();
338        Ok(handle)
339    }
340}