Skip to main content

forge_core/auth/
tokens.rs

1//! Refresh token management.
2//!
3//! Provides token pair issuance (access + refresh), rotation, and revocation.
4//! Refresh tokens are stored as SHA-256 hashes in `forge_refresh_tokens`.
5
6use sha2::{Digest, Sha256};
7use uuid::Uuid;
8
9use crate::error::{ForgeError, Result};
10
11/// An access token + refresh token pair.
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct TokenPair {
14    pub access_token: String,
15    pub refresh_token: String,
16}
17
18/// SHA-256 hash a raw token string for storage.
19pub fn hash_token(token: &str) -> String {
20    let mut hasher = Sha256::new();
21    hasher.update(token.as_bytes());
22    format!("{:x}", hasher.finalize())
23}
24
25/// Generate a cryptographically random refresh token string.
26pub fn generate_refresh_token() -> String {
27    let a = Uuid::new_v4();
28    let b = Uuid::new_v4();
29    format!("{}{}", a.simple(), b.simple())
30}
31
32/// Issue a token pair: sign an access JWT and store a refresh token.
33///
34/// `issue_access_fn` is called to sign the access token (wraps `ctx.issue_token`).
35/// `client_id` binds the refresh token to an OAuth client (pass `None` for non-OAuth usage).
36pub async fn issue_token_pair(
37    pool: &sqlx::PgPool,
38    user_id: Uuid,
39    roles: &[&str],
40    access_token_ttl_secs: i64,
41    refresh_token_ttl_days: i64,
42    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
43) -> Result<TokenPair> {
44    issue_token_pair_with_client(
45        pool,
46        user_id,
47        roles,
48        access_token_ttl_secs,
49        refresh_token_ttl_days,
50        None,
51        issue_access_fn,
52    )
53    .await
54}
55
56/// Issue a token pair with optional OAuth client binding.
57///
58/// When `client_id` is `Some`, the refresh token is bound to that client
59/// and can only be rotated by presenting the same client_id.
60pub async fn issue_token_pair_with_client(
61    pool: &sqlx::PgPool,
62    user_id: Uuid,
63    roles: &[&str],
64    access_token_ttl_secs: i64,
65    refresh_token_ttl_days: i64,
66    client_id: Option<&str>,
67    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
68) -> Result<TokenPair> {
69    let access_token = issue_access_fn(user_id, roles, access_token_ttl_secs)?;
70
71    let refresh_raw = generate_refresh_token();
72    let refresh_hash = hash_token(&refresh_raw);
73    let expires_at = chrono::Utc::now() + chrono::Duration::days(refresh_token_ttl_days);
74
75    sqlx::query!(
76        "INSERT INTO forge_refresh_tokens (user_id, token_hash, client_id, expires_at) \
77         VALUES ($1, $2, $3, $4)",
78        user_id,
79        &refresh_hash,
80        client_id,
81        expires_at,
82    )
83    .execute(pool)
84    .await
85    .map_err(|e| ForgeError::Internal(format!("Failed to store refresh token: {e}")))?;
86
87    Ok(TokenPair {
88        access_token,
89        refresh_token: refresh_raw,
90    })
91}
92
93/// Rotate a refresh token: validate expiry, delete the old one, issue a new pair.
94pub async fn rotate_refresh_token(
95    pool: &sqlx::PgPool,
96    old_refresh_token: &str,
97    roles: &[&str],
98    access_token_ttl_secs: i64,
99    refresh_token_ttl_days: i64,
100    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
101) -> Result<TokenPair> {
102    rotate_refresh_token_with_client(
103        pool,
104        old_refresh_token,
105        roles,
106        access_token_ttl_secs,
107        refresh_token_ttl_days,
108        None,
109        issue_access_fn,
110    )
111    .await
112}
113
114/// Rotate a refresh token with OAuth client binding validation.
115///
116/// When `client_id` is `Some`, the token must be bound to that client.
117/// The new token is also bound to the same client.
118pub async fn rotate_refresh_token_with_client(
119    pool: &sqlx::PgPool,
120    old_refresh_token: &str,
121    roles: &[&str],
122    access_token_ttl_secs: i64,
123    refresh_token_ttl_days: i64,
124    client_id: Option<&str>,
125    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
126) -> Result<TokenPair> {
127    let hash = hash_token(old_refresh_token);
128
129    // Atomically delete only non-expired tokens, matching client_id binding.
130    // When client_id is provided, require exact match. When omitted, only
131    // allow rotation of tokens that were NOT bound to any client (prevents
132    // an attacker from bypassing client binding by omitting client_id).
133    let row = if let Some(cid) = client_id {
134        sqlx::query_scalar!(
135            "DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now() AND client_id = $2 RETURNING user_id",
136            hash,
137            cid
138        )
139        .fetch_optional(pool)
140        .await
141    } else {
142        sqlx::query_scalar!(
143            "DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now() AND client_id IS NULL RETURNING user_id",
144            hash
145        )
146        .fetch_optional(pool)
147        .await
148    }
149    .map_err(|e| ForgeError::Internal(format!("Failed to rotate refresh token: {e}")))?;
150
151    let user_id = match row {
152        Some(r) => r,
153        None => {
154            return Err(ForgeError::Unauthorized(
155                "Invalid or expired refresh token".into(),
156            ));
157        }
158    };
159
160    issue_token_pair_with_client(
161        pool,
162        user_id,
163        roles,
164        access_token_ttl_secs,
165        refresh_token_ttl_days,
166        client_id,
167        issue_access_fn,
168    )
169    .await
170}
171
172/// Revoke a specific refresh token.
173pub async fn revoke_refresh_token(pool: &sqlx::PgPool, refresh_token: &str) -> Result<()> {
174    let hash = hash_token(refresh_token);
175    sqlx::query!(
176        "DELETE FROM forge_refresh_tokens WHERE token_hash = $1",
177        &hash
178    )
179    .execute(pool)
180    .await
181    .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh token: {e}")))?;
182    Ok(())
183}
184
185/// Revoke all refresh tokens for a user.
186pub async fn revoke_all_refresh_tokens(pool: &sqlx::PgPool, user_id: Uuid) -> Result<()> {
187    sqlx::query!(
188        "DELETE FROM forge_refresh_tokens WHERE user_id = $1",
189        user_id
190    )
191    .execute(pool)
192    .await
193    .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh tokens: {e}")))?;
194    Ok(())
195}