Skip to main content

allowthem_core/
sessions.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use serde::Serialize;
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::event_sink::AuthEvent;
12use crate::handle::AllowThem;
13use crate::types::{Email, Session, SessionId, SessionToken, TokenHash, UserId};
14
15/// Configuration for session lifecycle and cookie generation.
16pub struct SessionConfig {
17    /// How long a session lives. Default: 24 hours.
18    pub ttl: Duration,
19    /// Name of the session cookie. Default: `"allowthem_session"`.
20    pub cookie_name: &'static str,
21    /// Whether to set the `Secure` attribute on the cookie.
22    /// Should be `true` in production (HTTPS), `false` in local dev.
23    pub secure: bool,
24}
25
26impl Default for SessionConfig {
27    fn default() -> Self {
28        Self {
29            ttl: Duration::hours(24),
30            cookie_name: "allowthem_session",
31            secure: true,
32        }
33    }
34}
35
36/// A session with joined user email, for admin list display.
37/// Omits `token_hash` — the admin UI must never expose session tokens.
38#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
39pub struct SessionListEntry {
40    pub id: SessionId,
41    pub user_id: UserId,
42    pub user_email: Email,
43    pub ip_address: Option<String>,
44    pub user_agent: Option<String>,
45    pub expires_at: DateTime<Utc>,
46    pub created_at: DateTime<Utc>,
47}
48
49/// Parameters for listing sessions in the admin session viewer.
50pub struct ListSessionsParams {
51    pub user_id: Option<UserId>,
52    pub limit: u32,
53    pub offset: u32,
54}
55
56/// Result of a paginated session list.
57pub struct ListSessionsResult {
58    pub sessions: Vec<SessionListEntry>,
59    pub total: u32,
60}
61
62/// Generate a cryptographically random session token.
63///
64/// Fills 32 bytes from the OS random source and encodes them as base64url
65/// without padding, producing a 43-character string.
66pub fn generate_token() -> SessionToken {
67    let mut bytes = [0u8; 32];
68    OsRng
69        .try_fill_bytes(&mut bytes)
70        .expect("OS RNG unavailable");
71    SessionToken::from_encoded(Base64UrlUnpadded::encode_string(&bytes))
72}
73
74/// Hash a session token with SHA-256.
75///
76/// The resulting hex string (64 chars) is what is stored in the database.
77/// The raw token is never persisted.
78pub fn hash_token(token: &SessionToken) -> TokenHash {
79    let digest = Sha256::digest(token.as_str().as_bytes());
80    TokenHash::new_unchecked(format!("{digest:x}"))
81}
82
83impl Db {
84    /// Insert a new session record and return it.
85    ///
86    /// The caller is responsible for hashing the token before calling this function
87    /// via `hash_token()`. The raw token must never be passed here.
88    pub async fn create_session(
89        &self,
90        user_id: UserId,
91        token_hash: TokenHash,
92        ip_address: Option<&str>,
93        user_agent: Option<&str>,
94        expires_at: DateTime<Utc>,
95    ) -> Result<Session, AuthError> {
96        let id = SessionId::new();
97        let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
98        sqlx::query_as::<_, Session>(
99            "INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
100             VALUES (?, ?, ?, ?, ?, ?)
101             RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
102        )
103        .bind(id)
104        .bind(token_hash)
105        .bind(user_id)
106        .bind(ip_address)
107        .bind(user_agent)
108        .bind(expires_at_str)
109        .fetch_one(self.pool())
110        .await
111        .map_err(AuthError::Database)
112    }
113
114    /// Look up a session by raw token.
115    ///
116    /// Hashes the token internally and queries by hash. Expired sessions
117    /// (where `expires_at` is in the past) are excluded. Returns `None`
118    /// when no matching active session is found.
119    pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
120        let hash = hash_token(token);
121        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
122        sqlx::query_as::<_, Session>(
123            "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
124             FROM allowthem_sessions
125             WHERE token_hash = ? AND expires_at > ?",
126        )
127        .bind(hash)
128        .bind(now)
129        .fetch_optional(self.pool())
130        .await
131        .map_err(AuthError::Database)
132    }
133
134    /// Validate a session token and optionally extend it.
135    ///
136    /// Fetches the active session by token hash. If the session is past the
137    /// halfway point of its TTL (`now > expires_at - ttl/2`), it is renewed
138    /// by setting `expires_at = now + ttl`. Returns the session with the
139    /// updated expiry, or `None` if no active session was found.
140    pub async fn validate_session(
141        &self,
142        token: &SessionToken,
143        ttl: Duration,
144    ) -> Result<Option<Session>, AuthError> {
145        let session = match self.lookup_session(token).await? {
146            Some(s) => s,
147            None => return Ok(None),
148        };
149
150        let now = Utc::now();
151        let halfway = session.expires_at - ttl / 2;
152
153        if now > halfway {
154            let new_expires_at = now + ttl;
155            let new_expires_str = new_expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
156            let hash = hash_token(token);
157            sqlx::query("UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ?")
158                .bind(&new_expires_str)
159                .bind(hash)
160                .execute(self.pool())
161                .await
162                .map_err(AuthError::Database)?;
163
164            return Ok(Some(Session {
165                expires_at: new_expires_at,
166                ..session
167            }));
168        }
169
170        Ok(Some(session))
171    }
172
173    /// Delete a single session by raw token.
174    ///
175    /// Returns `true` if a session was found and deleted, `false` if no
176    /// matching session existed.
177    pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
178        let hash = hash_token(token);
179        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
180            .bind(hash)
181            .execute(self.pool())
182            .await
183            .map_err(AuthError::Database)?;
184        Ok(result.rows_affected() > 0)
185    }
186
187    /// Delete all sessions for a user.
188    ///
189    /// Returns the number of sessions that were deleted.
190    pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
191        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
192            .bind(*user_id)
193            .execute(self.pool())
194            .await
195            .map_err(AuthError::Database)?;
196        Ok(result.rows_affected())
197    }
198
199    /// List all active (non-expired) sessions for a user.
200    ///
201    /// Returns sessions ordered by `created_at` descending (newest first).
202    pub async fn list_user_sessions(&self, user_id: UserId) -> Result<Vec<Session>, AuthError> {
203        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
204        sqlx::query_as::<_, Session>(
205            "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at \
206             FROM allowthem_sessions \
207             WHERE user_id = ? AND expires_at > ? \
208             ORDER BY created_at DESC",
209        )
210        .bind(user_id)
211        .bind(now)
212        .fetch_all(self.pool())
213        .await
214        .map_err(AuthError::Database)
215    }
216
217    /// List all active sessions with user email, for admin session viewer.
218    ///
219    /// Joins sessions with users. Filters to non-expired sessions only.
220    /// Optional user_id filter. Two static query variants (no dynamic SQL).
221    pub async fn list_all_sessions(
222        &self,
223        params: ListSessionsParams,
224    ) -> Result<ListSessionsResult, AuthError> {
225        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
226
227        if let Some(user_id) = params.user_id {
228            let total = sqlx::query_scalar::<_, i64>(
229                "SELECT COUNT(*) FROM allowthem_sessions s \
230                 JOIN allowthem_users u ON s.user_id = u.id \
231                 WHERE s.expires_at > ? AND s.user_id = ?",
232            )
233            .bind(&now)
234            .bind(user_id)
235            .fetch_one(self.pool())
236            .await
237            .map_err(AuthError::Database)? as u32;
238
239            let sessions = sqlx::query_as::<_, SessionListEntry>(
240                "SELECT s.id, s.user_id, u.email AS user_email, \
241                 s.ip_address, s.user_agent, s.expires_at, s.created_at \
242                 FROM allowthem_sessions s \
243                 JOIN allowthem_users u ON s.user_id = u.id \
244                 WHERE s.expires_at > ? AND s.user_id = ? \
245                 ORDER BY s.created_at DESC \
246                 LIMIT ? OFFSET ?",
247            )
248            .bind(&now)
249            .bind(user_id)
250            .bind(params.limit)
251            .bind(params.offset)
252            .fetch_all(self.pool())
253            .await
254            .map_err(AuthError::Database)?;
255
256            Ok(ListSessionsResult { sessions, total })
257        } else {
258            let total = sqlx::query_scalar::<_, i64>(
259                "SELECT COUNT(*) FROM allowthem_sessions s \
260                 JOIN allowthem_users u ON s.user_id = u.id \
261                 WHERE s.expires_at > ?",
262            )
263            .bind(&now)
264            .fetch_one(self.pool())
265            .await
266            .map_err(AuthError::Database)? as u32;
267
268            let sessions = sqlx::query_as::<_, SessionListEntry>(
269                "SELECT s.id, s.user_id, u.email AS user_email, \
270                 s.ip_address, s.user_agent, s.expires_at, s.created_at \
271                 FROM allowthem_sessions s \
272                 JOIN allowthem_users u ON s.user_id = u.id \
273                 WHERE s.expires_at > ? \
274                 ORDER BY s.created_at DESC \
275                 LIMIT ? OFFSET ?",
276            )
277            .bind(&now)
278            .bind(params.limit)
279            .bind(params.offset)
280            .fetch_all(self.pool())
281            .await
282            .map_err(AuthError::Database)?;
283
284            Ok(ListSessionsResult { sessions, total })
285        }
286    }
287
288    /// Delete a single session by primary key.
289    ///
290    /// Used by the admin session viewer to revoke individual sessions.
291    /// Returns `true` if a session was found and deleted, `false` if not.
292    pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
293        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE id = ?")
294            .bind(id)
295            .execute(self.pool())
296            .await
297            .map_err(AuthError::Database)?;
298        Ok(result.rows_affected() > 0)
299    }
300}
301
302// ─── AllowThem wrappers ───────────────────────────────────────────────────────
303
304impl AllowThem {
305    /// Delete a session by token and emit `session.destroyed` (scope=single).
306    ///
307    /// Returns `Ok(true)` if a session was found and deleted, `Ok(false)` if
308    /// not. The event is only emitted on `Ok(true)`.
309    pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
310        let deleted = self.db().delete_session(token).await?;
311        if deleted {
312            self.emit_event(AuthEvent::new(
313                "session.destroyed",
314                None,
315                serde_json::json!({ "scope": "single" }),
316            ))
317            .await;
318        }
319        Ok(deleted)
320    }
321
322    /// Delete a session by ID and emit `session.destroyed` (scope=single).
323    ///
324    /// Returns `Ok(true)` if a session was found and deleted, `Ok(false)` if
325    /// not. The event is only emitted on `Ok(true)`.
326    pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
327        let deleted = self.db().delete_session_by_id(id).await?;
328        if deleted {
329            self.emit_event(AuthEvent::new(
330                "session.destroyed",
331                None,
332                serde_json::json!({ "scope": "single" }),
333            ))
334            .await;
335        }
336        Ok(deleted)
337    }
338
339    /// Delete all sessions for a user and emit one `session.destroyed` event.
340    ///
341    /// The event carries `{ user_id, count, scope: "all" }`. Returns the count.
342    pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
343        let count = self.db().delete_user_sessions(user_id).await?;
344        self.emit_event(AuthEvent::new(
345            "session.destroyed",
346            Some(*user_id),
347            serde_json::json!({ "user_id": user_id, "count": count, "scope": "all" }),
348        ))
349        .await;
350        Ok(count)
351    }
352}
353
354/// Build a `Set-Cookie` header value for the given session token.
355///
356/// Attributes: `HttpOnly`, `SameSite=Lax`, `Path=/`, `Max-Age` derived from
357/// `config.ttl`. The `Secure` attribute is added only when `config.secure` is
358/// `true`. The `Domain` attribute is omitted when `domain` is empty.
359pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
360    let max_age = config.ttl.num_seconds();
361    let mut cookie = format!(
362        "{}={}; HttpOnly; SameSite=Lax; Path=/; Max-Age={}",
363        config.cookie_name,
364        token.as_str(),
365        max_age,
366    );
367    if !domain.is_empty() {
368        cookie.push_str("; Domain=");
369        cookie.push_str(domain);
370    }
371    if config.secure {
372        cookie.push_str("; Secure");
373    }
374    cookie
375}
376
377/// Build a `Set-Cookie` header value that expires the session cookie.
378///
379/// Returns `Max-Age=0` with the same cookie name, path, domain, and flags
380/// as `session_cookie()` so the browser matches and removes the stored cookie.
381pub fn clear_session_cookie(config: &SessionConfig, domain: &str) -> String {
382    let mut cookie = format!(
383        "{}=; HttpOnly; SameSite=Lax; Path=/; Max-Age=0",
384        config.cookie_name,
385    );
386    if !domain.is_empty() {
387        cookie.push_str("; Domain=");
388        cookie.push_str(domain);
389    }
390    if config.secure {
391        cookie.push_str("; Secure");
392    }
393    cookie
394}
395
396/// Extract the session token from a `Cookie` header value.
397///
398/// Searches the semicolon-separated list of `name=value` pairs for
399/// `cookie_name`. Returns `None` if the cookie is absent.
400pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
401    for pair in cookie_header.split("; ") {
402        if let Some((name, value)) = pair.split_once('=')
403            && name.trim() == cookie_name
404        {
405            return Some(SessionToken::from_encoded(value.trim().to_string()));
406        }
407    }
408    None
409}