Skip to main content

allowthem_core/
sessions.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use serde::Serialize;
8
9use crate::db::Db;
10use crate::error::AuthError;
11use crate::types::{Email, Session, SessionId, SessionToken, TokenHash, UserId};
12
13/// Configuration for session lifecycle and cookie generation.
14pub struct SessionConfig {
15    /// How long a session lives. Default: 24 hours.
16    pub ttl: Duration,
17    /// Name of the session cookie. Default: `"allowthem_session"`.
18    pub cookie_name: &'static str,
19    /// Whether to set the `Secure` attribute on the cookie.
20    /// Should be `true` in production (HTTPS), `false` in local dev.
21    pub secure: bool,
22}
23
24impl Default for SessionConfig {
25    fn default() -> Self {
26        Self {
27            ttl: Duration::hours(24),
28            cookie_name: "allowthem_session",
29            secure: true,
30        }
31    }
32}
33
/// A session with joined user email, for admin list display.
/// Omits `token_hash` — the admin UI must never expose session tokens.
///
/// Rows are produced by the `list_all_sessions` queries, which select
/// `u.email AS user_email` so the column name matches the field below.
#[derive(Debug, Clone, Serialize, sqlx::FromRow)]
pub struct SessionListEntry {
    /// Primary key of the session row (`allowthem_sessions.id`).
    pub id: SessionId,
    /// Owner of the session (`allowthem_sessions.user_id`).
    pub user_id: UserId,
    /// Email of the owning user, joined from `allowthem_users.email`.
    pub user_email: Email,
    /// Client IP recorded at session creation, if one was supplied.
    pub ip_address: Option<String>,
    /// Client user-agent recorded at session creation, if one was supplied.
    pub user_agent: Option<String>,
    /// When the session stops being valid.
    pub expires_at: DateTime<Utc>,
    /// When the session row was created.
    pub created_at: DateTime<Utc>,
}
46
47/// Parameters for listing sessions in the admin session viewer.
48pub struct ListSessionsParams {
49    pub user_id: Option<UserId>,
50    pub limit: u32,
51    pub offset: u32,
52}
53
54/// Result of a paginated session list.
55pub struct ListSessionsResult {
56    pub sessions: Vec<SessionListEntry>,
57    pub total: u32,
58}
59
60/// Generate a cryptographically random session token.
61///
62/// Fills 32 bytes from the OS random source and encodes them as base64url
63/// without padding, producing a 43-character string.
64pub fn generate_token() -> SessionToken {
65    let mut bytes = [0u8; 32];
66    OsRng
67        .try_fill_bytes(&mut bytes)
68        .expect("OS RNG unavailable");
69    SessionToken::from_encoded(Base64UrlUnpadded::encode_string(&bytes))
70}
71
72/// Hash a session token with SHA-256.
73///
74/// The resulting hex string (64 chars) is what is stored in the database.
75/// The raw token is never persisted.
76pub fn hash_token(token: &SessionToken) -> TokenHash {
77    let digest = Sha256::digest(token.as_str().as_bytes());
78    TokenHash::new_unchecked(format!("{digest:x}"))
79}
80
81impl Db {
82    /// Insert a new session record and return it.
83    ///
84    /// The caller is responsible for hashing the token before calling this function
85    /// via `hash_token()`. The raw token must never be passed here.
86    pub async fn create_session(
87        &self,
88        user_id: UserId,
89        token_hash: TokenHash,
90        ip_address: Option<&str>,
91        user_agent: Option<&str>,
92        expires_at: DateTime<Utc>,
93    ) -> Result<Session, AuthError> {
94        let id = SessionId::new();
95        let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
96        sqlx::query_as::<_, Session>(
97            "INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
98             VALUES (?, ?, ?, ?, ?, ?)
99             RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
100        )
101        .bind(id)
102        .bind(token_hash)
103        .bind(user_id)
104        .bind(ip_address)
105        .bind(user_agent)
106        .bind(expires_at_str)
107        .fetch_one(self.pool())
108        .await
109        .map_err(AuthError::Database)
110    }
111
112    /// Look up a session by raw token.
113    ///
114    /// Hashes the token internally and queries by hash. Expired sessions
115    /// (where `expires_at` is in the past) are excluded. Returns `None`
116    /// when no matching active session is found.
117    pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
118        let hash = hash_token(token);
119        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
120        sqlx::query_as::<_, Session>(
121            "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
122             FROM allowthem_sessions
123             WHERE token_hash = ? AND expires_at > ?",
124        )
125        .bind(hash)
126        .bind(now)
127        .fetch_optional(self.pool())
128        .await
129        .map_err(AuthError::Database)
130    }
131
132    /// Validate a session token and optionally extend it.
133    ///
134    /// Fetches the active session by token hash. If the session is past the
135    /// halfway point of its TTL (`now > expires_at - ttl/2`), it is renewed
136    /// by setting `expires_at = now + ttl`. Returns the session with the
137    /// updated expiry, or `None` if no active session was found.
138    pub async fn validate_session(
139        &self,
140        token: &SessionToken,
141        ttl: Duration,
142    ) -> Result<Option<Session>, AuthError> {
143        let session = match self.lookup_session(token).await? {
144            Some(s) => s,
145            None => return Ok(None),
146        };
147
148        let now = Utc::now();
149        let halfway = session.expires_at - ttl / 2;
150
151        if now > halfway {
152            let new_expires_at = now + ttl;
153            let new_expires_str = new_expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
154            let hash = hash_token(token);
155            sqlx::query("UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ?")
156                .bind(&new_expires_str)
157                .bind(hash)
158                .execute(self.pool())
159                .await
160                .map_err(AuthError::Database)?;
161
162            return Ok(Some(Session {
163                expires_at: new_expires_at,
164                ..session
165            }));
166        }
167
168        Ok(Some(session))
169    }
170
171    /// Delete a single session by raw token.
172    ///
173    /// Returns `true` if a session was found and deleted, `false` if no
174    /// matching session existed.
175    pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
176        let hash = hash_token(token);
177        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
178            .bind(hash)
179            .execute(self.pool())
180            .await
181            .map_err(AuthError::Database)?;
182        Ok(result.rows_affected() > 0)
183    }
184
185    /// Delete all sessions for a user.
186    ///
187    /// Returns the number of sessions that were deleted.
188    pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
189        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
190            .bind(*user_id)
191            .execute(self.pool())
192            .await
193            .map_err(AuthError::Database)?;
194        Ok(result.rows_affected())
195    }
196
197    /// List all active (non-expired) sessions for a user.
198    ///
199    /// Returns sessions ordered by `created_at` descending (newest first).
200    pub async fn list_user_sessions(&self, user_id: UserId) -> Result<Vec<Session>, AuthError> {
201        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
202        sqlx::query_as::<_, Session>(
203            "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at \
204             FROM allowthem_sessions \
205             WHERE user_id = ? AND expires_at > ? \
206             ORDER BY created_at DESC",
207        )
208        .bind(user_id)
209        .bind(now)
210        .fetch_all(self.pool())
211        .await
212        .map_err(AuthError::Database)
213    }
214
215    /// List all active sessions with user email, for admin session viewer.
216    ///
217    /// Joins sessions with users. Filters to non-expired sessions only.
218    /// Optional user_id filter. Two static query variants (no dynamic SQL).
219    pub async fn list_all_sessions(
220        &self,
221        params: ListSessionsParams,
222    ) -> Result<ListSessionsResult, AuthError> {
223        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
224
225        if let Some(user_id) = params.user_id {
226            let total = sqlx::query_scalar::<_, i64>(
227                "SELECT COUNT(*) FROM allowthem_sessions s \
228                 JOIN allowthem_users u ON s.user_id = u.id \
229                 WHERE s.expires_at > ? AND s.user_id = ?",
230            )
231            .bind(&now)
232            .bind(user_id)
233            .fetch_one(self.pool())
234            .await
235            .map_err(AuthError::Database)? as u32;
236
237            let sessions = sqlx::query_as::<_, SessionListEntry>(
238                "SELECT s.id, s.user_id, u.email AS user_email, \
239                 s.ip_address, s.user_agent, s.expires_at, s.created_at \
240                 FROM allowthem_sessions s \
241                 JOIN allowthem_users u ON s.user_id = u.id \
242                 WHERE s.expires_at > ? AND s.user_id = ? \
243                 ORDER BY s.created_at DESC \
244                 LIMIT ? OFFSET ?",
245            )
246            .bind(&now)
247            .bind(user_id)
248            .bind(params.limit)
249            .bind(params.offset)
250            .fetch_all(self.pool())
251            .await
252            .map_err(AuthError::Database)?;
253
254            Ok(ListSessionsResult { sessions, total })
255        } else {
256            let total = sqlx::query_scalar::<_, i64>(
257                "SELECT COUNT(*) FROM allowthem_sessions s \
258                 JOIN allowthem_users u ON s.user_id = u.id \
259                 WHERE s.expires_at > ?",
260            )
261            .bind(&now)
262            .fetch_one(self.pool())
263            .await
264            .map_err(AuthError::Database)? as u32;
265
266            let sessions = sqlx::query_as::<_, SessionListEntry>(
267                "SELECT s.id, s.user_id, u.email AS user_email, \
268                 s.ip_address, s.user_agent, s.expires_at, s.created_at \
269                 FROM allowthem_sessions s \
270                 JOIN allowthem_users u ON s.user_id = u.id \
271                 WHERE s.expires_at > ? \
272                 ORDER BY s.created_at DESC \
273                 LIMIT ? OFFSET ?",
274            )
275            .bind(&now)
276            .bind(params.limit)
277            .bind(params.offset)
278            .fetch_all(self.pool())
279            .await
280            .map_err(AuthError::Database)?;
281
282            Ok(ListSessionsResult { sessions, total })
283        }
284    }
285
286    /// Delete a single session by primary key.
287    ///
288    /// Used by the admin session viewer to revoke individual sessions.
289    /// Returns `true` if a session was found and deleted, `false` if not.
290    pub async fn delete_session_by_id(&self, id: SessionId) -> Result<bool, AuthError> {
291        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE id = ?")
292            .bind(id)
293            .execute(self.pool())
294            .await
295            .map_err(AuthError::Database)?;
296        Ok(result.rows_affected() > 0)
297    }
298}
299
300/// Build a `Set-Cookie` header value for the given session token.
301///
302/// Attributes: `HttpOnly`, `SameSite=Lax`, `Path=/`, `Max-Age` derived from
303/// `config.ttl`. The `Secure` attribute is added only when `config.secure` is
304/// `true`. The `Domain` attribute is omitted when `domain` is empty.
305pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
306    let max_age = config.ttl.num_seconds();
307    let mut cookie = format!(
308        "{}={}; HttpOnly; SameSite=Lax; Path=/; Max-Age={}",
309        config.cookie_name,
310        token.as_str(),
311        max_age,
312    );
313    if !domain.is_empty() {
314        cookie.push_str("; Domain=");
315        cookie.push_str(domain);
316    }
317    if config.secure {
318        cookie.push_str("; Secure");
319    }
320    cookie
321}
322
323/// Build a `Set-Cookie` header value that expires the session cookie.
324///
325/// Returns `Max-Age=0` with the same cookie name, path, domain, and flags
326/// as `session_cookie()` so the browser matches and removes the stored cookie.
327pub fn clear_session_cookie(config: &SessionConfig, domain: &str) -> String {
328    let mut cookie = format!(
329        "{}=; HttpOnly; SameSite=Lax; Path=/; Max-Age=0",
330        config.cookie_name,
331    );
332    if !domain.is_empty() {
333        cookie.push_str("; Domain=");
334        cookie.push_str(domain);
335    }
336    if config.secure {
337        cookie.push_str("; Secure");
338    }
339    cookie
340}
341
342/// Extract the session token from a `Cookie` header value.
343///
344/// Searches the semicolon-separated list of `name=value` pairs for
345/// `cookie_name`. Returns `None` if the cookie is absent.
346pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
347    for pair in cookie_header.split("; ") {
348        if let Some((name, value)) = pair.split_once('=')
349            && name.trim() == cookie_name
350        {
351            return Some(SessionToken::from_encoded(value.trim().to_string()));
352        }
353    }
354    None
355}