Skip to main content

modo/auth/session/
store.rs

1use chrono::{DateTime, Utc};
2use serde::{Deserialize, Serialize};
3
4use crate::db::{ColumnMap, ConnExt, ConnQueryExt, Database, FromRow};
5use crate::error::{Error, Result};
6
7use super::config::SessionConfig;
8use super::meta::SessionMeta;
9use super::token::SessionToken;
10
/// Comma-separated list of the `sessions` columns that every `SELECT` in this
/// file reads back.
///
/// The names match the fields of `SessionRow` one-to-one. The `token_hash`
/// column is not part of this list.
const SESSION_COLUMNS: &str = concat!(
    "id, user_id, ip_address, user_agent, device_name, device_type, ",
    "fingerprint, data, created_at, last_active_at, expires_at"
);
13
14/// A snapshot of a session row as returned from the database.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SessionData {
17    /// Unique session identifier (ULID).
18    pub id: String,
19    /// The authenticated user's identifier.
20    pub user_id: String,
21    /// IP address recorded at login.
22    pub ip_address: String,
23    /// Raw `User-Agent` header recorded at login.
24    pub user_agent: String,
25    /// Human-readable device name derived from the user agent (e.g. `"Chrome on macOS"`).
26    pub device_name: String,
27    /// Device category: `"desktop"`, `"mobile"`, or `"tablet"`.
28    pub device_type: String,
29    /// SHA-256 fingerprint of the browser environment used to detect session hijacking.
30    pub fingerprint: String,
31    /// Arbitrary JSON data attached to the session.
32    pub data: serde_json::Value,
33    /// When the session was created.
34    pub created_at: DateTime<Utc>,
35    /// When the session was last touched.
36    pub last_active_at: DateTime<Utc>,
37    /// When the session expires.
38    pub expires_at: DateTime<Utc>,
39}
40
41/// Low-level SQLite-backed session store.
42///
43/// Wraps a [`Database`] handle and exposes async methods for all session CRUD
44/// operations. Consumed by [`super::middleware::SessionLayer`] and available
45/// to handlers via [`super::extractor::Session`].
46#[derive(Clone)]
47pub struct Store {
48    db: Database,
49    config: SessionConfig,
50}
51
52impl Store {
53    /// Create a store from a [`Database`] handle and session configuration.
54    pub fn new(db: Database, config: SessionConfig) -> Self {
55        Self { db, config }
56    }
57
58    /// Return the session configuration for this store.
59    pub fn config(&self) -> &SessionConfig {
60        &self.config
61    }
62
63    /// Look up an active (non-expired) session by its token hash.
64    ///
65    /// Returns `None` if no matching session exists or the session has expired.
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if the database query fails or the stored data cannot
70    /// be deserialised.
71    pub async fn read_by_token(&self, token: &SessionToken) -> Result<Option<SessionData>> {
72        let hash = token.hash();
73        let now = Utc::now().to_rfc3339();
74        let row: Option<SessionRow> = self
75            .db
76            .conn()
77            .query_optional(
78                &format!(
79                    "SELECT {SESSION_COLUMNS} FROM sessions \
80                     WHERE token_hash = ?1 AND expires_at > ?2"
81                ),
82                libsql::params![hash, now],
83            )
84            .await?;
85
86        row.map(row_to_session_data).transpose()
87    }
88
89    /// Look up a session by its ULID identifier (ignores expiry).
90    ///
91    /// Returns `None` if no session with that ID exists.
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if the database query fails or the stored data cannot
96    /// be deserialised.
97    pub async fn read(&self, id: &str) -> Result<Option<SessionData>> {
98        let row: Option<SessionRow> = self
99            .db
100            .conn()
101            .query_optional(
102                &format!("SELECT {SESSION_COLUMNS} FROM sessions WHERE id = ?1"),
103                libsql::params![id],
104            )
105            .await?;
106
107        row.map(row_to_session_data).transpose()
108    }
109
110    /// List all active (non-expired) sessions for a user, ordered by most recently active.
111    ///
112    /// # Errors
113    ///
114    /// Returns an error if the database query fails or the stored data cannot
115    /// be deserialised.
116    pub async fn list_for_user(&self, user_id: &str) -> Result<Vec<SessionData>> {
117        let now = Utc::now().to_rfc3339();
118        let rows: Vec<SessionRow> = self
119            .db
120            .conn()
121            .query_all(
122                &format!(
123                    "SELECT {SESSION_COLUMNS} FROM sessions \
124                     WHERE user_id = ?1 AND expires_at > ?2 \
125                     ORDER BY last_active_at DESC"
126                ),
127                libsql::params![user_id, now],
128            )
129            .await?;
130
131        rows.into_iter().map(row_to_session_data).collect()
132    }
133
134    /// Create a new session for the given user.
135    ///
136    /// Inserts the session row then trims excess sessions when the
137    /// `max_sessions_per_user` limit is exceeded by evicting the oldest
138    /// session(s).
139    ///
140    /// Returns the newly-created `SessionData` and the raw `SessionToken` that
141    /// must be placed in the cookie.
142    ///
143    /// # Errors
144    ///
145    /// Returns an error if the session data cannot be serialised or the
146    /// database insert/eviction query fails.
147    pub async fn create(
148        &self,
149        meta: &SessionMeta,
150        user_id: &str,
151        data: Option<serde_json::Value>,
152    ) -> Result<(SessionData, SessionToken)> {
153        let id = crate::id::ulid();
154        let token = SessionToken::generate();
155        let token_hash = token.hash();
156        let now = Utc::now();
157        let expires_at = now + chrono::Duration::seconds(self.config.session_ttl_secs as i64);
158        let data_json = data.unwrap_or(serde_json::json!({}));
159        let data_str = serde_json::to_string(&data_json)
160            .map_err(|e| Error::internal(format!("serialize session data: {e}")))?;
161        let now_str = now.to_rfc3339();
162        let expires_str = expires_at.to_rfc3339();
163
164        // Insert session
165        self.db
166            .conn()
167            .execute_raw(
168                "INSERT INTO sessions \
169                 (id, token_hash, user_id, ip_address, user_agent, device_name, device_type, \
170                  fingerprint, data, created_at, last_active_at, expires_at) \
171                 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12)",
172                libsql::params![
173                    id.as_str(),
174                    token_hash.as_str(),
175                    user_id,
176                    meta.ip_address.as_str(),
177                    meta.user_agent.as_str(),
178                    meta.device_name.as_str(),
179                    meta.device_type.as_str(),
180                    meta.fingerprint.as_str(),
181                    data_str.as_str(),
182                    now_str.as_str(),
183                    now_str.as_str(),
184                    expires_str.as_str()
185                ],
186            )
187            .await
188            .map_err(|e| Error::internal(format!("insert session: {e}")))?;
189
190        // Trim excess sessions
191        let max = self.config.max_sessions_per_user as i64;
192        self.db
193            .conn()
194            .execute_raw(
195                "DELETE FROM sessions WHERE id IN (\
196                     SELECT id FROM sessions \
197                     WHERE user_id = ?1 AND expires_at > ?2 \
198                     ORDER BY last_active_at ASC \
199                     LIMIT MAX(0, (SELECT COUNT(*) FROM sessions \
200                                   WHERE user_id = ?3 AND expires_at > ?4) - ?5)\
201                 )",
202                libsql::params![user_id, now_str.as_str(), user_id, now_str.as_str(), max],
203            )
204            .await
205            .map_err(|e| Error::internal(format!("evict excess sessions: {e}")))?;
206
207        let session_data = SessionData {
208            id,
209            user_id: user_id.to_string(),
210            ip_address: meta.ip_address.clone(),
211            user_agent: meta.user_agent.clone(),
212            device_name: meta.device_name.clone(),
213            device_type: meta.device_type.clone(),
214            fingerprint: meta.fingerprint.clone(),
215            data: data_json,
216            created_at: now,
217            last_active_at: now,
218            expires_at,
219        };
220
221        Ok((session_data, token))
222    }
223
224    /// Delete a session by its ULID identifier.
225    ///
226    /// # Errors
227    ///
228    /// Returns an error if the database delete fails.
229    pub async fn destroy(&self, id: &str) -> Result<()> {
230        self.db
231            .conn()
232            .execute_raw("DELETE FROM sessions WHERE id = ?1", libsql::params![id])
233            .await
234            .map_err(|e| Error::internal(format!("destroy session: {e}")))?;
235        Ok(())
236    }
237
238    /// Delete all sessions belonging to a user.
239    ///
240    /// # Errors
241    ///
242    /// Returns an error if the database delete fails.
243    pub async fn destroy_all_for_user(&self, user_id: &str) -> Result<()> {
244        self.db
245            .conn()
246            .execute_raw(
247                "DELETE FROM sessions WHERE user_id = ?1",
248                libsql::params![user_id],
249            )
250            .await
251            .map_err(|e| Error::internal(format!("destroy all sessions for user: {e}")))?;
252        Ok(())
253    }
254
255    /// Delete all sessions for a user except the one with the given ID.
256    ///
257    /// Used to implement "log out other devices".
258    ///
259    /// # Errors
260    ///
261    /// Returns an error if the database delete fails.
262    pub async fn destroy_all_except(&self, user_id: &str, keep_id: &str) -> Result<()> {
263        self.db
264            .conn()
265            .execute_raw(
266                "DELETE FROM sessions WHERE user_id = ?1 AND id != ?2",
267                libsql::params![user_id, keep_id],
268            )
269            .await
270            .map_err(|e| Error::internal(format!("destroy all except: {e}")))?;
271        Ok(())
272    }
273
274    /// Issue a new token for an existing session, invalidating the old one.
275    ///
276    /// Returns the new [`SessionToken`]. The middleware will write this token
277    /// to the session cookie on the response.
278    ///
279    /// # Errors
280    ///
281    /// Returns an error if the database update fails.
282    pub async fn rotate_token(&self, id: &str) -> Result<SessionToken> {
283        let new_token = SessionToken::generate();
284        let new_hash = new_token.hash();
285        self.db
286            .conn()
287            .execute_raw(
288                "UPDATE sessions SET token_hash = ?1 WHERE id = ?2",
289                libsql::params![new_hash, id],
290            )
291            .await
292            .map_err(|e| Error::internal(format!("rotate token: {e}")))?;
293        Ok(new_token)
294    }
295
296    /// Persist the session's JSON data and update `last_active_at` / `expires_at`.
297    ///
298    /// Called by the middleware at the end of a request when the session was
299    /// marked dirty.
300    ///
301    /// # Errors
302    ///
303    /// Returns an error if the session data cannot be serialised or the
304    /// database update fails.
305    pub async fn flush(
306        &self,
307        id: &str,
308        data: &serde_json::Value,
309        now: DateTime<Utc>,
310        expires_at: DateTime<Utc>,
311    ) -> Result<()> {
312        let data_str = serde_json::to_string(data)
313            .map_err(|e| Error::internal(format!("serialize session data: {e}")))?;
314        self.db
315            .conn()
316            .execute_raw(
317                "UPDATE sessions SET data = ?1, last_active_at = ?2, expires_at = ?3 \
318                 WHERE id = ?4",
319                libsql::params![data_str, now.to_rfc3339(), expires_at.to_rfc3339(), id],
320            )
321            .await
322            .map_err(|e| Error::internal(format!("flush session: {e}")))?;
323        Ok(())
324    }
325
326    /// Update `last_active_at` and `expires_at` without changing session data.
327    ///
328    /// Called by the middleware when the touch interval has elapsed but the
329    /// session data is not dirty.
330    ///
331    /// # Errors
332    ///
333    /// Returns an error if the database update fails.
334    pub async fn touch(
335        &self,
336        id: &str,
337        now: DateTime<Utc>,
338        expires_at: DateTime<Utc>,
339    ) -> Result<()> {
340        self.db
341            .conn()
342            .execute_raw(
343                "UPDATE sessions SET last_active_at = ?1, expires_at = ?2 WHERE id = ?3",
344                libsql::params![now.to_rfc3339(), expires_at.to_rfc3339(), id],
345            )
346            .await
347            .map_err(|e| Error::internal(format!("touch session: {e}")))?;
348        Ok(())
349    }
350
351    /// Delete all sessions whose `expires_at` is in the past.
352    ///
353    /// Returns the number of rows deleted. Schedule this periodically (e.g.
354    /// via a cron job) to keep the table small.
355    ///
356    /// # Errors
357    ///
358    /// Returns an error if the database delete fails.
359    pub async fn cleanup_expired(&self) -> Result<u64> {
360        let now = Utc::now().to_rfc3339();
361        let affected = self
362            .db
363            .conn()
364            .execute_raw(
365                "DELETE FROM sessions WHERE expires_at < ?1",
366                libsql::params![now],
367            )
368            .await
369            .map_err(Error::from)?;
370        Ok(affected)
371    }
372}
373
/// Untyped mirror of one `sessions` row, as selected via `SESSION_COLUMNS`.
///
/// Every column is kept as its stored text. `data` is JSON and the three
/// timestamps are RFC 3339 strings. `row_to_session_data` turns this into a
/// [`SessionData`].
struct SessionRow {
    id: String,
    user_id: String,
    ip_address: String,
    user_agent: String,
    device_name: String,
    device_type: String,
    fingerprint: String,
    // Serialised JSON payload.
    data: String,
    // RFC 3339 timestamps, parsed later by `row_to_session_data`.
    created_at: String,
    last_active_at: String,
    expires_at: String,
}
387
388impl FromRow for SessionRow {
389    fn from_row(row: &libsql::Row) -> Result<Self> {
390        let cols = ColumnMap::from_row(row);
391        Ok(Self {
392            id: cols.get(row, "id")?,
393            user_id: cols.get(row, "user_id")?,
394            ip_address: cols.get(row, "ip_address")?,
395            user_agent: cols.get(row, "user_agent")?,
396            device_name: cols.get(row, "device_name")?,
397            device_type: cols.get(row, "device_type")?,
398            fingerprint: cols.get(row, "fingerprint")?,
399            data: cols.get(row, "data")?,
400            created_at: cols.get(row, "created_at")?,
401            last_active_at: cols.get(row, "last_active_at")?,
402            expires_at: cols.get(row, "expires_at")?,
403        })
404    }
405}
406
407fn row_to_session_data(row: SessionRow) -> Result<SessionData> {
408    let data: serde_json::Value = serde_json::from_str(&row.data)
409        .map_err(|e| Error::internal(format!("deserialize session data: {e}")))?;
410    let created_at = DateTime::parse_from_rfc3339(&row.created_at)
411        .map_err(|e| Error::internal(format!("parse created_at: {e}")))?
412        .with_timezone(&Utc);
413    let last_active_at = DateTime::parse_from_rfc3339(&row.last_active_at)
414        .map_err(|e| Error::internal(format!("parse last_active_at: {e}")))?
415        .with_timezone(&Utc);
416    let expires_at = DateTime::parse_from_rfc3339(&row.expires_at)
417        .map_err(|e| Error::internal(format!("parse expires_at: {e}")))?
418        .with_timezone(&Utc);
419
420    Ok(SessionData {
421        id: row.id,
422        user_id: row.user_id,
423        ip_address: row.ip_address,
424        user_agent: row.user_agent,
425        device_name: row.device_name,
426        device_type: row.device_type,
427        fingerprint: row.fingerprint,
428        data,
429        created_at,
430        last_active_at,
431        expires_at,
432    })
433}