Skip to main content

forge_core/auth/
tokens.rs

//! Refresh token management.
//!
//! Provides token pair issuance (access + refresh), rotation, and revocation.
//! Refresh tokens are stored as SHA-256 hashes in `forge_refresh_tokens`.
//!
//! ## Token format
//!
//! Raw refresh tokens encode their family UUID so chain reuse detection can
//! revoke the whole family even after the specific token row is deleted:
//!
//! ```text
//! <family_uuid_hex>.<random_uuid_hex>
//! ```
//!
//! Only the SHA-256 hash of the full string is stored in the database.
//!
//! ## Chain reuse detection
//!
//! Each token family represents a single login session. During rotation the old
//! token is deleted and a new one with the same `token_family` is inserted
//! atomically. If a previously-rotated (deleted) token is presented:
//!
//! 1. The rotation `DELETE` matches 0 rows.
//! 2. The family UUID is decoded from the raw token value.
//! 3. All tokens sharing that family are revoked immediately.
//! 4. `Unauthorized` is returned to the caller.
//!
//! This terminates the session for both the legitimate user (who holds the
//! current token) and the attacker who replayed the old one.
30
31use sha2::{Digest, Sha256};
32use uuid::Uuid;
33
34use crate::error::{ForgeError, Result};
35
36/// An access token + refresh token pair.
37#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub struct TokenPair {
39    pub access_token: String,
40    pub refresh_token: String,
41}
42
43/// SHA-256 hash a raw token string for storage.
44pub fn hash_token(token: &str) -> String {
45    let mut hasher = Sha256::new();
46    hasher.update(token.as_bytes());
47    format!("{:x}", hasher.finalize())
48}
49
50/// Generate a raw refresh token that encodes its family UUID.
51///
52/// Format: `<family_hex>.<random_hex>` — the dot separator lets
53/// `extract_family` recover the family without a DB lookup.
54fn generate_refresh_token_for_family(family: Uuid) -> String {
55    let random = Uuid::new_v4();
56    format!("{}.{}", family.simple(), random.simple())
57}
58
59/// Extract the family UUID from a raw refresh token.
60///
61/// Returns `None` for legacy tokens that pre-date the family format.
62fn extract_family(raw_token: &str) -> Option<Uuid> {
63    let (family_hex, _) = raw_token.split_once('.')?;
64    Uuid::parse_str(family_hex).ok()
65}
66
67/// Issue a token pair: sign an access JWT and store a refresh token.
68///
69/// `issue_access_fn` is called to sign the access token (wraps `ctx.issue_token`).
70/// `client_id` binds the refresh token to an OAuth client (pass `None` for non-OAuth usage).
71pub async fn issue_token_pair(
72    pool: &sqlx::PgPool,
73    user_id: Uuid,
74    roles: &[&str],
75    access_token_ttl_secs: i64,
76    refresh_token_ttl_days: i64,
77    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
78) -> Result<TokenPair> {
79    issue_token_pair_with_client(
80        pool,
81        user_id,
82        roles,
83        access_token_ttl_secs,
84        refresh_token_ttl_days,
85        None,
86        issue_access_fn,
87    )
88    .await
89}
90
91/// Issue a token pair with optional OAuth client binding.
92///
93/// When `client_id` is `Some`, the refresh token is bound to that client
94/// and can only be rotated by presenting the same client_id.
95pub async fn issue_token_pair_with_client(
96    pool: &sqlx::PgPool,
97    user_id: Uuid,
98    roles: &[&str],
99    access_token_ttl_secs: i64,
100    refresh_token_ttl_days: i64,
101    client_id: Option<&str>,
102    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
103) -> Result<TokenPair> {
104    let family = Uuid::new_v4();
105    issue_token_in_family(
106        pool,
107        user_id,
108        roles,
109        access_token_ttl_secs,
110        refresh_token_ttl_days,
111        client_id,
112        family,
113        issue_access_fn,
114    )
115    .await
116}
117
118/// Internal: insert a new refresh token carrying an existing family ID.
119///
120/// Used both by `issue_token_pair_with_client` (new family) and
121/// `rotate_refresh_token_with_client` (carry family forward).
122#[allow(clippy::too_many_arguments)]
123async fn issue_token_in_family(
124    pool: &sqlx::PgPool,
125    user_id: Uuid,
126    roles: &[&str],
127    access_token_ttl_secs: i64,
128    refresh_token_ttl_days: i64,
129    client_id: Option<&str>,
130    family: Uuid,
131    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
132) -> Result<TokenPair> {
133    let access_token = issue_access_fn(user_id, roles, access_token_ttl_secs)?;
134
135    let refresh_raw = generate_refresh_token_for_family(family);
136    let refresh_hash = hash_token(&refresh_raw);
137    let expires_at = chrono::Utc::now() + chrono::Duration::days(refresh_token_ttl_days);
138
139    let roles_owned: Vec<String> = roles.iter().map(|s| s.to_string()).collect();
140    sqlx::query!(
141        "INSERT INTO forge_refresh_tokens (user_id, token_hash, client_id, expires_at, token_family, roles) \
142         VALUES ($1, $2, $3, $4, $5, $6)",
143        user_id,
144        &refresh_hash,
145        client_id,
146        expires_at,
147        family,
148        &roles_owned,
149    )
150    .execute(pool)
151    .await
152    .map_err(|e| ForgeError::internal_with("Failed to store refresh token", e))?;
153
154    Ok(TokenPair {
155        access_token,
156        refresh_token: refresh_raw,
157    })
158}
159
160/// Rotate a refresh token: validate expiry, delete the old one, issue a new pair.
161///
162/// Roles are carried forward from the old token row — the rotated token has
163/// the exact same role set as the original sign-in. New role grants only take
164/// effect at next sign-in.
165pub async fn rotate_refresh_token(
166    pool: &sqlx::PgPool,
167    old_refresh_token: &str,
168    access_token_ttl_secs: i64,
169    refresh_token_ttl_days: i64,
170    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
171) -> Result<TokenPair> {
172    rotate_refresh_token_with_client(
173        pool,
174        old_refresh_token,
175        access_token_ttl_secs,
176        refresh_token_ttl_days,
177        None,
178        issue_access_fn,
179    )
180    .await
181}
182
183/// Rotate a refresh token with OAuth client binding validation.
184///
185/// When `client_id` is `Some`, the token must be bound to that client.
186/// The new token is issued in the same family as the old one and inherits
187/// its roles.
188///
189/// ## Chain reuse detection
190///
191/// If the DELETE returns 0 rows the token is either invalid, expired, or
192/// already rotated. The family UUID is decoded from the raw token value
193/// (no extra DB read needed). If it parses, all tokens in that family are
194/// revoked immediately — the session is terminated for everyone holding a
195/// token in the chain, cutting off both the attacker and the legitimate user.
196pub async fn rotate_refresh_token_with_client(
197    pool: &sqlx::PgPool,
198    old_refresh_token: &str,
199    access_token_ttl_secs: i64,
200    refresh_token_ttl_days: i64,
201    client_id: Option<&str>,
202    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
203) -> Result<TokenPair> {
204    let hash = hash_token(old_refresh_token);
205
206    // Atomically delete the token if valid, returning the family + roles so
207    // the new token is issued in the same chain with the same role set.
208    //
209    // When client_id is provided, require exact match. When omitted, only
210    // allow rotation of tokens that were NOT bound to any client (prevents
211    // an attacker from bypassing client binding by omitting client_id).
212    struct TokenRow {
213        user_id: Uuid,
214        token_family: Uuid,
215        roles: Vec<String>,
216    }
217
218    let row = if let Some(cid) = client_id {
219        sqlx::query!(
220            "DELETE FROM forge_refresh_tokens \
221             WHERE token_hash = $1 AND expires_at > now() AND client_id = $2 \
222             RETURNING user_id, token_family, roles",
223            hash,
224            cid
225        )
226        .fetch_optional(pool)
227        .await
228        .map(|r| {
229            r.map(|r| TokenRow {
230                user_id: r.user_id,
231                token_family: r.token_family,
232                roles: r.roles,
233            })
234        })
235    } else {
236        sqlx::query!(
237            "DELETE FROM forge_refresh_tokens \
238             WHERE token_hash = $1 AND expires_at > now() AND client_id IS NULL \
239             RETURNING user_id, token_family, roles",
240            hash
241        )
242        .fetch_optional(pool)
243        .await
244        .map(|r| {
245            r.map(|r| TokenRow {
246                user_id: r.user_id,
247                token_family: r.token_family,
248                roles: r.roles,
249            })
250        })
251    }
252    .map_err(|e| ForgeError::internal_with("Failed to rotate refresh token", e))?;
253
254    match row {
255        Some(token) => {
256            let roles_refs: Vec<&str> = token.roles.iter().map(String::as_str).collect();
257            issue_token_in_family(
258                pool,
259                token.user_id,
260                &roles_refs,
261                access_token_ttl_secs,
262                refresh_token_ttl_days,
263                client_id,
264                token.token_family,
265                issue_access_fn,
266            )
267            .await
268        }
269        None => {
270            // Token not found. Decode the family from the raw token value —
271            // if the format matches, this is a previously-rotated token being
272            // replayed (reuse attack). Nuke the whole family to terminate the
273            // session for everyone, then return Unauthorized.
274            if let Some(family_id) = extract_family(old_refresh_token) {
275                let deleted = sqlx::query!(
276                    "DELETE FROM forge_refresh_tokens WHERE token_family = $1",
277                    family_id
278                )
279                .execute(pool)
280                .await
281                .map(|r| r.rows_affected())
282                .unwrap_or(0);
283
284                if deleted > 0 {
285                    tracing::warn!(
286                        %family_id,
287                        revoked = deleted,
288                        "Refresh token reuse detected — entire family revoked"
289                    );
290                }
291            }
292
293            Err(ForgeError::Unauthorized(
294                "Invalid or expired refresh token".into(),
295            ))
296        }
297    }
298}
299
300/// Revoke a specific refresh token.
301pub async fn revoke_refresh_token(pool: &sqlx::PgPool, refresh_token: &str) -> Result<()> {
302    let hash = hash_token(refresh_token);
303    sqlx::query!(
304        "DELETE FROM forge_refresh_tokens WHERE token_hash = $1",
305        &hash
306    )
307    .execute(pool)
308    .await
309    .map_err(|e| ForgeError::internal_with("Failed to revoke refresh token", e))?;
310    Ok(())
311}
312
313/// Revoke all refresh tokens for a user.
314pub async fn revoke_all_refresh_tokens(pool: &sqlx::PgPool, user_id: Uuid) -> Result<()> {
315    sqlx::query!(
316        "DELETE FROM forge_refresh_tokens WHERE user_id = $1",
317        user_id
318    )
319    .execute(pool)
320    .await
321    .map_err(|e| ForgeError::internal_with("Failed to revoke refresh tokens", e))?;
322    Ok(())
323}
324
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    //! Unit tests for the pure helpers (hashing, token format, serde).
    //! Nothing here touches the database.

    use super::*;

    /// True when `s` consists solely of `0-9` / `a-f`.
    fn is_lower_hex(s: &str) -> bool {
        s.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f'))
    }

    #[test]
    fn test_generate_refresh_token_for_family_encodes_family() {
        let family = Uuid::new_v4();
        let token = generate_refresh_token_for_family(family);

        assert!(token.contains('.'), "token must contain the dot separator");
        assert_eq!(extract_family(&token), Some(family));
    }

    #[test]
    fn test_extract_family_returns_none_for_legacy_format() {
        // Legacy tokens were two simple UUIDs glued together with no dot.
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        let legacy = format!("{}{}", first.simple(), second.simple());
        assert_eq!(extract_family(&legacy), None);
    }

    #[test]
    fn test_extract_family_returns_none_for_garbage() {
        for garbage in ["not-a-token", ""] {
            assert_eq!(extract_family(garbage), None);
        }
    }

    #[test]
    fn test_hash_token_is_deterministic() {
        let token = "some-raw-token-value";
        let first = hash_token(token);
        let second = hash_token(token);
        assert_eq!(first, second);
    }

    #[test]
    fn test_hash_token_differs_for_different_inputs() {
        let hash_a = hash_token("token-a");
        let hash_b = hash_token("token-b");
        assert_ne!(hash_a, hash_b);
    }

    #[test]
    fn hash_token_returns_64_char_lowercase_hex() {
        let hash = hash_token("anything");
        assert_eq!(hash.len(), 64, "SHA-256 hex is exactly 64 chars");
        assert!(
            is_lower_hex(&hash),
            "expected lowercase hex digits, got {hash}"
        );
    }

    #[test]
    fn hash_token_matches_known_sha256_for_empty_string() {
        // Pinning the hash for "" guards against accidental algorithm change
        // (would break every refresh token after the swap).
        let expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        let actual = hash_token("");
        assert_eq!(actual, expected);
    }

    #[test]
    fn generate_refresh_token_for_family_returns_unique_random_part() {
        let family = Uuid::new_v4();
        let tokens = [
            generate_refresh_token_for_family(family),
            generate_refresh_token_for_family(family),
        ];

        assert_ne!(tokens[0], tokens[1]);
        for token in &tokens {
            assert_eq!(extract_family(token), Some(family));
        }
    }

    #[test]
    fn generate_refresh_token_for_family_has_expected_shape() {
        let token = generate_refresh_token_for_family(Uuid::new_v4());
        let parts: Vec<&str> = token.split('.').collect();

        assert_eq!(parts.len(), 2, "exactly one dot separator");
        // Each half is a 32-char simple-format UUID.
        for part in &parts {
            assert_eq!(part.len(), 32);
            assert!(part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn extract_family_returns_none_when_prefix_is_not_a_uuid() {
        // Has a dot but the prefix isn't a valid UUID — should not falsely
        // identify it as a family-format token.
        let recovered = extract_family("notauuid.suffix");
        assert_eq!(recovered, None);
    }

    #[test]
    fn extract_family_returns_first_segment_uuid_for_multi_dot_tokens() {
        // `split_once('.')` only splits on the first dot. As long as the first
        // segment parses as a UUID, additional dots after it don't matter.
        let family = Uuid::new_v4();
        let weird = format!("{}.a.b.c", family.simple());
        assert_eq!(extract_family(&weird), Some(family));
    }

    #[test]
    fn token_pair_round_trips_through_json() {
        let original = TokenPair {
            access_token: "header.payload.sig".into(),
            refresh_token: "fam.rand".into(),
        };

        let json = serde_json::to_string(&original).unwrap();
        let decoded: TokenPair = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded.access_token, original.access_token);
        assert_eq!(decoded.refresh_token, original.refresh_token);
    }

    #[test]
    fn hash_token_is_independent_of_token_length() {
        // Tiny and very long inputs both yield 64-char hashes — bounded output.
        let huge = "x".repeat(10_000);
        for input in [huge.as_str(), "a"] {
            assert_eq!(hash_token(input).len(), 64);
        }
    }
}