Skip to main content

modo_session/
store.rs

1use crate::config::SessionConfig;
2use crate::entity::session::{self, ActiveModel, Column, Entity};
3use crate::meta::SessionMeta;
4use crate::types::{SessionData, SessionId, SessionToken};
5use chrono::{DateTime, Utc};
6use modo::Error;
7use modo::cookies::CookieConfig;
8use modo_db::DbPool;
9use modo_db::sea_orm::{
10    ActiveModelTrait, ColumnTrait, EntityTrait, PaginatorTrait, QueryFilter, QueryOrder,
11    QuerySelect, Set, TransactionTrait,
12};
13
14/// Low-level database-backed session store.
15///
16/// Handles all CRUD operations on the `modo_sessions` table.  Application code
17/// should rarely interact with `SessionStore` directly; use [`crate::SessionManager`]
18/// (the axum extractor) for request-scoped session operations instead.
19///
20/// `SessionStore` is cheaply cloneable and is intended to be registered as a
21/// managed service so it can be injected into background jobs.
22#[derive(Clone)]
23pub struct SessionStore {
24    db: DbPool,
25    config: SessionConfig,
26    cookie_config: CookieConfig,
27}
28
29impl SessionStore {
30    /// Create a new store backed by `db` with the given session and cookie config.
31    pub fn new(db: &DbPool, config: SessionConfig, cookie_config: CookieConfig) -> Self {
32        Self {
33            db: db.clone(),
34            config,
35            cookie_config,
36        }
37    }
38
39    /// Return a reference to the session configuration.
40    pub fn config(&self) -> &SessionConfig {
41        &self.config
42    }
43
44    /// Return a reference to the cookie configuration.
45    pub fn cookie_config(&self) -> &CookieConfig {
46        &self.cookie_config
47    }
48
49    /// Insert a new session for `user_id` and return the persisted [`SessionData`]
50    /// together with the plaintext [`SessionToken`] (to be set in the cookie).
51    ///
52    /// The insert and LRU eviction run inside a single transaction so that a
53    /// failed eviction automatically rolls back the insert.
54    pub async fn create(
55        &self,
56        meta: &SessionMeta,
57        user_id: &str,
58        data: Option<serde_json::Value>,
59    ) -> Result<(SessionData, SessionToken), Error> {
60        let token = SessionToken::generate();
61        let token_hash = token.hash();
62        let now = Utc::now();
63        let expires_at = now + chrono::Duration::seconds(self.config.session_ttl_secs as i64);
64        let data_json = data.unwrap_or(serde_json::json!({}));
65
66        let model = ActiveModel {
67            id: Set(SessionId::new().to_string()),
68            token_hash: Set(token_hash),
69            user_id: Set(user_id.to_string()),
70            ip_address: Set(meta.ip_address.clone()),
71            user_agent: Set(meta.user_agent.clone()),
72            device_name: Set(meta.device_name.clone()),
73            device_type: Set(meta.device_type.clone()),
74            fingerprint: Set(meta.fingerprint.clone()),
75            data: Set(serde_json::to_string(&data_json)
76                .map_err(|e| Error::internal(format!("serialize session data: {e}")))?),
77            created_at: Set(now),
78            last_active_at: Set(now),
79            expires_at: Set(expires_at),
80        };
81
82        // Wrap insert + enforce in a transaction so that on error the insert
83        // is rolled back automatically (DatabaseTransaction::Drop).
84        let txn = self
85            .db
86            .connection()
87            .begin()
88            .await
89            .map_err(|e| Error::internal(format!("begin transaction: {e}")))?;
90
91        let result = model
92            .insert(&txn)
93            .await
94            .map_err(|e| Error::internal(format!("insert session: {e}")))?;
95
96        self.enforce_session_limit_txn(user_id, &txn).await?;
97
98        txn.commit()
99            .await
100            .map_err(|e| Error::internal(format!("commit transaction: {e}")))?;
101
102        Ok((model_to_session_data(&result)?, token))
103    }
104
105    /// Load a session by its ID.  Returns `None` if not found (does not check
106    /// expiry — call [`read_by_token`][Self::read_by_token] for expiry-aware
107    /// lookup).
108    pub async fn read(&self, id: &SessionId) -> Result<Option<SessionData>, Error> {
109        let model = Entity::find_by_id(id.as_str())
110            .one(self.db.connection())
111            .await
112            .map_err(|e| Error::internal(format!("read session: {e}")))?;
113
114        match model {
115            Some(m) => Ok(Some(model_to_session_data(&m)?)),
116            None => Ok(None),
117        }
118    }
119
120    /// Load a non-expired session by plaintext token (hashes it internally).
121    ///
122    /// Returns `None` if no matching, non-expired session is found.
123    pub async fn read_by_token(&self, token: &SessionToken) -> Result<Option<SessionData>, Error> {
124        let hash = token.hash();
125        let model = Entity::find()
126            .filter(Column::TokenHash.eq(&hash))
127            .filter(Column::ExpiresAt.gt(Utc::now()))
128            .one(self.db.connection())
129            .await
130            .map_err(|e| Error::internal(format!("read session by token: {e}")))?;
131
132        match model {
133            Some(m) => Ok(Some(model_to_session_data(&m)?)),
134            None => Ok(None),
135        }
136    }
137
138    /// Delete a session by ID.
139    pub async fn destroy(&self, id: &SessionId) -> Result<(), Error> {
140        Entity::delete_by_id(id.as_str())
141            .exec(self.db.connection())
142            .await
143            .map_err(|e| Error::internal(format!("destroy session: {e}")))?;
144        Ok(())
145    }
146
147    /// Replace the token for a session with a newly generated one and return the
148    /// new plaintext token.  The session ID and all other fields are unchanged.
149    pub async fn rotate_token(&self, id: &SessionId) -> Result<SessionToken, Error> {
150        let new_token = SessionToken::generate();
151        let new_hash = new_token.hash();
152
153        let model = ActiveModel {
154            id: Set(id.as_str().to_string()),
155            token_hash: Set(new_hash),
156            ..Default::default()
157        };
158
159        model
160            .update(self.db.connection())
161            .await
162            .map_err(|e| Error::internal(format!("rotate token: {e}")))?;
163
164        Ok(new_token)
165    }
166
167    /// Update `last_active_at` to now and set a new `expires_at` for a session.
168    pub async fn touch(&self, id: &SessionId, new_expires_at: DateTime<Utc>) -> Result<(), Error> {
169        let model = ActiveModel {
170            id: Set(id.as_str().to_string()),
171            last_active_at: Set(Utc::now()),
172            expires_at: Set(new_expires_at),
173            ..Default::default()
174        };
175
176        model
177            .update(self.db.connection())
178            .await
179            .map_err(|e| Error::internal(format!("touch session: {e}")))?;
180
181        Ok(())
182    }
183
184    /// Replace the JSON payload stored in a session.
185    pub async fn update_data(&self, id: &SessionId, data: serde_json::Value) -> Result<(), Error> {
186        let model = ActiveModel {
187            id: Set(id.as_str().to_string()),
188            data: Set(serde_json::to_string(&data)
189                .map_err(|e| Error::internal(format!("serialize session data: {e}")))?),
190            ..Default::default()
191        };
192
193        model
194            .update(self.db.connection())
195            .await
196            .map_err(|e| Error::internal(format!("update session data: {e}")))?;
197
198        Ok(())
199    }
200
201    /// Delete all sessions belonging to `user_id`.
202    pub async fn destroy_all_for_user(&self, user_id: &str) -> Result<(), Error> {
203        Entity::delete_many()
204            .filter(Column::UserId.eq(user_id))
205            .exec(self.db.connection())
206            .await
207            .map_err(|e| Error::internal(format!("destroy all sessions for user: {e}")))?;
208        Ok(())
209    }
210
211    /// Delete all sessions belonging to `user_id` except the one identified by
212    /// `keep`.
213    pub async fn destroy_all_except(&self, user_id: &str, keep: &SessionId) -> Result<(), Error> {
214        Entity::delete_many()
215            .filter(Column::UserId.eq(user_id))
216            .filter(Column::Id.ne(keep.as_str()))
217            .exec(self.db.connection())
218            .await
219            .map_err(|e| Error::internal(format!("destroy all except: {e}")))?;
220        Ok(())
221    }
222
223    /// Return all non-expired sessions for `user_id`, ordered by most-recently-active
224    /// first.
225    pub async fn list_for_user(&self, user_id: &str) -> Result<Vec<SessionData>, Error> {
226        let models = Entity::find()
227            .filter(Column::UserId.eq(user_id))
228            .filter(Column::ExpiresAt.gt(Utc::now()))
229            .order_by_desc(Column::LastActiveAt)
230            .all(self.db.connection())
231            .await
232            .map_err(|e| Error::internal(format!("list sessions: {e}")))?;
233
234        models.iter().map(model_to_session_data).collect()
235    }
236
237    /// Delete all sessions whose `expires_at` is in the past.
238    ///
239    /// Returns the number of rows deleted.  Called automatically by the
240    /// `cleanup-job` feature's cron job.
241    pub async fn cleanup_expired(&self) -> Result<u64, Error> {
242        let result = Entity::delete_many()
243            .filter(Column::ExpiresAt.lt(Utc::now()))
244            .exec(self.db.connection())
245            .await
246            .map_err(|e| Error::internal(format!("cleanup expired sessions: {e}")))?;
247        Ok(result.rows_affected)
248    }
249
250    /// Enforce session limit within an existing transaction.
251    async fn enforce_session_limit_txn(
252        &self,
253        user_id: &str,
254        txn: &modo_db::sea_orm::DatabaseTransaction,
255    ) -> Result<(), Error> {
256        let now = Utc::now();
257
258        let count = Entity::find()
259            .filter(Column::UserId.eq(user_id))
260            .filter(Column::ExpiresAt.gt(now))
261            .count(txn)
262            .await
263            .map_err(|e| Error::internal(format!("count sessions: {e}")))?;
264
265        if count as usize <= self.config.max_sessions_per_user {
266            return Ok(());
267        }
268
269        let excess = count as usize - self.config.max_sessions_per_user;
270
271        // Find least-recently-used sessions (LRU eviction)
272        let oldest = Entity::find()
273            .filter(Column::UserId.eq(user_id))
274            .filter(Column::ExpiresAt.gt(now))
275            .order_by_asc(Column::LastActiveAt)
276            .limit(excess as u64)
277            .all(txn)
278            .await
279            .map_err(|e| Error::internal(format!("find oldest sessions: {e}")))?;
280
281        let ids: Vec<String> = oldest.into_iter().map(|m| m.id).collect();
282        if !ids.is_empty() {
283            Entity::delete_many()
284                .filter(Column::Id.is_in(ids))
285                .exec(txn)
286                .await
287                .map_err(|e| Error::internal(format!("evict sessions: {e}")))?;
288        }
289
290        Ok(())
291    }
292}
293
294fn model_to_session_data(model: &session::Model) -> Result<SessionData, Error> {
295    let data: serde_json::Value = serde_json::from_str(&model.data)
296        .map_err(|e| Error::internal(format!("deserialize session data: {e}")))?;
297
298    Ok(SessionData {
299        id: SessionId::from_raw(&model.id),
300        token_hash: model.token_hash.clone(),
301        user_id: model.user_id.clone(),
302        ip_address: model.ip_address.clone(),
303        user_agent: model.user_agent.clone(),
304        device_name: model.device_name.clone(),
305        device_type: model.device_type.clone(),
306        fingerprint: model.fingerprint.clone(),
307        data,
308        created_at: model.created_at,
309        last_active_at: model.last_active_at,
310        expires_at: model.expires_at,
311    })
312}