Skip to main content

allowthem_core/
sessions.rs

1use base64ct::{Base64UrlUnpadded, Encoding};
2use chrono::{DateTime, Duration, Utc};
3use rand::TryRngCore;
4use rand::rngs::OsRng;
5use sha2::{Digest, Sha256};
6
7use crate::db::Db;
8use crate::error::AuthError;
9use crate::types::{Session, SessionId, SessionToken, TokenHash, UserId};
10
11/// Configuration for session lifecycle and cookie generation.
12pub struct SessionConfig {
13    /// How long a session lives. Default: 24 hours.
14    pub ttl: Duration,
15    /// Name of the session cookie. Default: `"allowthem_session"`.
16    pub cookie_name: &'static str,
17    /// Whether to set the `Secure` attribute on the cookie.
18    /// Should be `true` in production (HTTPS), `false` in local dev.
19    pub secure: bool,
20}
21
22impl Default for SessionConfig {
23    fn default() -> Self {
24        Self {
25            ttl: Duration::hours(24),
26            cookie_name: "allowthem_session",
27            secure: true,
28        }
29    }
30}
31
32/// Generate a cryptographically random session token.
33///
34/// Fills 32 bytes from the OS random source and encodes them as base64url
35/// without padding, producing a 43-character string.
36pub fn generate_token() -> SessionToken {
37    let mut bytes = [0u8; 32];
38    OsRng
39        .try_fill_bytes(&mut bytes)
40        .expect("OS RNG unavailable");
41    SessionToken::from_encoded(Base64UrlUnpadded::encode_string(&bytes))
42}
43
44/// Hash a session token with SHA-256.
45///
46/// The resulting hex string (64 chars) is what is stored in the database.
47/// The raw token is never persisted.
48pub fn hash_token(token: &SessionToken) -> TokenHash {
49    let digest = Sha256::digest(token.as_str().as_bytes());
50    TokenHash::new_unchecked(format!("{digest:x}"))
51}
52
53impl Db {
54    /// Insert a new session record and return it.
55    ///
56    /// The caller is responsible for hashing the token before calling this function
57    /// via `hash_token()`. The raw token must never be passed here.
58    pub async fn create_session(
59        &self,
60        user_id: UserId,
61        token_hash: TokenHash,
62        ip_address: Option<&str>,
63        user_agent: Option<&str>,
64        expires_at: DateTime<Utc>,
65    ) -> Result<Session, AuthError> {
66        let id = SessionId::new();
67        let expires_at_str = expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
68        sqlx::query_as::<_, Session>(
69            "INSERT INTO allowthem_sessions (id, token_hash, user_id, ip_address, user_agent, expires_at)
70             VALUES (?, ?, ?, ?, ?, ?)
71             RETURNING id, token_hash, user_id, ip_address, user_agent, expires_at, created_at",
72        )
73        .bind(id)
74        .bind(token_hash)
75        .bind(user_id)
76        .bind(ip_address)
77        .bind(user_agent)
78        .bind(expires_at_str)
79        .fetch_one(self.pool())
80        .await
81        .map_err(AuthError::Database)
82    }
83
84    /// Look up a session by raw token.
85    ///
86    /// Hashes the token internally and queries by hash. Expired sessions
87    /// (where `expires_at` is in the past) are excluded. Returns `None`
88    /// when no matching active session is found.
89    pub async fn lookup_session(&self, token: &SessionToken) -> Result<Option<Session>, AuthError> {
90        let hash = hash_token(token);
91        let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
92        sqlx::query_as::<_, Session>(
93            "SELECT id, token_hash, user_id, ip_address, user_agent, expires_at, created_at
94             FROM allowthem_sessions
95             WHERE token_hash = ? AND expires_at > ?",
96        )
97        .bind(hash)
98        .bind(now)
99        .fetch_optional(self.pool())
100        .await
101        .map_err(AuthError::Database)
102    }
103
104    /// Validate a session token and optionally extend it.
105    ///
106    /// Fetches the active session by token hash. If the session is past the
107    /// halfway point of its TTL (`now > expires_at - ttl/2`), it is renewed
108    /// by setting `expires_at = now + ttl`. Returns the session with the
109    /// updated expiry, or `None` if no active session was found.
110    pub async fn validate_session(
111        &self,
112        token: &SessionToken,
113        ttl: Duration,
114    ) -> Result<Option<Session>, AuthError> {
115        let session = match self.lookup_session(token).await? {
116            Some(s) => s,
117            None => return Ok(None),
118        };
119
120        let now = Utc::now();
121        let halfway = session.expires_at - ttl / 2;
122
123        if now > halfway {
124            let new_expires_at = now + ttl;
125            let new_expires_str = new_expires_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
126            let hash = hash_token(token);
127            sqlx::query("UPDATE allowthem_sessions SET expires_at = ? WHERE token_hash = ?")
128                .bind(&new_expires_str)
129                .bind(hash)
130                .execute(self.pool())
131                .await
132                .map_err(AuthError::Database)?;
133
134            return Ok(Some(Session {
135                expires_at: new_expires_at,
136                ..session
137            }));
138        }
139
140        Ok(Some(session))
141    }
142
143    /// Delete a single session by raw token.
144    ///
145    /// Returns `true` if a session was found and deleted, `false` if no
146    /// matching session existed.
147    pub async fn delete_session(&self, token: &SessionToken) -> Result<bool, AuthError> {
148        let hash = hash_token(token);
149        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE token_hash = ?")
150            .bind(hash)
151            .execute(self.pool())
152            .await
153            .map_err(AuthError::Database)?;
154        Ok(result.rows_affected() > 0)
155    }
156
157    /// Delete all sessions for a user.
158    ///
159    /// Returns the number of sessions that were deleted.
160    pub async fn delete_user_sessions(&self, user_id: &UserId) -> Result<u64, AuthError> {
161        let result = sqlx::query("DELETE FROM allowthem_sessions WHERE user_id = ?")
162            .bind(*user_id)
163            .execute(self.pool())
164            .await
165            .map_err(AuthError::Database)?;
166        Ok(result.rows_affected())
167    }
168}
169
170/// Build a `Set-Cookie` header value for the given session token.
171///
172/// Attributes: `HttpOnly`, `SameSite=Lax`, `Path=/`, `Max-Age` derived from
173/// `config.ttl`. The `Secure` attribute is added only when `config.secure` is
174/// `true`. The `Domain` attribute is omitted when `domain` is empty.
175pub fn session_cookie(token: &SessionToken, config: &SessionConfig, domain: &str) -> String {
176    let max_age = config.ttl.num_seconds();
177    let mut cookie = format!(
178        "{}={}; HttpOnly; SameSite=Lax; Path=/; Max-Age={}",
179        config.cookie_name,
180        token.as_str(),
181        max_age,
182    );
183    if !domain.is_empty() {
184        cookie.push_str("; Domain=");
185        cookie.push_str(domain);
186    }
187    if config.secure {
188        cookie.push_str("; Secure");
189    }
190    cookie
191}
192
193/// Extract the session token from a `Cookie` header value.
194///
195/// Searches the semicolon-separated list of `name=value` pairs for
196/// `cookie_name`. Returns `None` if the cookie is absent.
197pub fn parse_session_cookie(cookie_header: &str, cookie_name: &str) -> Option<SessionToken> {
198    for pair in cookie_header.split("; ") {
199        if let Some((name, value)) = pair.split_once('=')
200            && name.trim() == cookie_name
201        {
202            return Some(SessionToken::from_encoded(value.trim().to_string()));
203        }
204    }
205    None
206}