Skip to main content

forge_core/auth/
tokens.rs

//! Refresh token management.
//!
//! Provides token pair issuance (access + refresh), rotation, and revocation.
//! Refresh tokens are stored as SHA-256 hashes in `forge_refresh_tokens`.
5
6use sha2::{Digest, Sha256};
7use uuid::Uuid;
8
9use crate::error::{ForgeError, Result};
10
11/// An access token + refresh token pair.
12#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct TokenPair {
14    pub access_token: String,
15    pub refresh_token: String,
16}
17
18/// SHA-256 hash a raw token string for storage.
19pub fn hash_token(token: &str) -> String {
20    let mut hasher = Sha256::new();
21    hasher.update(token.as_bytes());
22    format!("{:x}", hasher.finalize())
23}
24
25/// Generate a cryptographically random refresh token string.
26pub fn generate_refresh_token() -> String {
27    let a = Uuid::new_v4();
28    let b = Uuid::new_v4();
29    format!("{}{}", a.simple(), b.simple())
30}
31
32/// Issue a token pair: sign an access JWT and store a refresh token.
33///
34/// `issue_access_fn` is called to sign the access token (wraps `ctx.issue_token`).
35/// `client_id` binds the refresh token to an OAuth client (pass `None` for non-OAuth usage).
36pub async fn issue_token_pair(
37    pool: &sqlx::PgPool,
38    user_id: Uuid,
39    roles: &[&str],
40    access_token_ttl_secs: i64,
41    refresh_token_ttl_days: i64,
42    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
43) -> Result<TokenPair> {
44    issue_token_pair_with_client(
45        pool,
46        user_id,
47        roles,
48        access_token_ttl_secs,
49        refresh_token_ttl_days,
50        None,
51        issue_access_fn,
52    )
53    .await
54}
55
56/// Issue a token pair with optional OAuth client binding.
57///
58/// When `client_id` is `Some`, the refresh token is bound to that client
59/// and can only be rotated by presenting the same client_id.
60pub async fn issue_token_pair_with_client(
61    pool: &sqlx::PgPool,
62    user_id: Uuid,
63    roles: &[&str],
64    access_token_ttl_secs: i64,
65    refresh_token_ttl_days: i64,
66    client_id: Option<&str>,
67    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
68) -> Result<TokenPair> {
69    let access_token = issue_access_fn(user_id, roles, access_token_ttl_secs)?;
70
71    let refresh_raw = generate_refresh_token();
72    let refresh_hash = hash_token(&refresh_raw);
73    let expires_at = chrono::Utc::now() + chrono::Duration::days(refresh_token_ttl_days);
74
75    sqlx::query(
76        "INSERT INTO forge_refresh_tokens (user_id, token_hash, client_id, expires_at) \
77         VALUES ($1, $2, $3, $4)",
78    )
79    .bind(user_id)
80    .bind(&refresh_hash)
81    .bind(client_id)
82    .bind(expires_at)
83    .execute(pool)
84    .await
85    .map_err(|e| ForgeError::Internal(format!("Failed to store refresh token: {e}")))?;
86
87    Ok(TokenPair {
88        access_token,
89        refresh_token: refresh_raw,
90    })
91}
92
93/// Rotate a refresh token: validate expiry, delete the old one, issue a new pair.
94pub async fn rotate_refresh_token(
95    pool: &sqlx::PgPool,
96    old_refresh_token: &str,
97    roles: &[&str],
98    access_token_ttl_secs: i64,
99    refresh_token_ttl_days: i64,
100    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
101) -> Result<TokenPair> {
102    rotate_refresh_token_with_client(
103        pool,
104        old_refresh_token,
105        roles,
106        access_token_ttl_secs,
107        refresh_token_ttl_days,
108        None,
109        issue_access_fn,
110    )
111    .await
112}
113
114/// Rotate a refresh token with OAuth client binding validation.
115///
116/// When `client_id` is `Some`, the token must be bound to that client.
117/// The new token is also bound to the same client.
118pub async fn rotate_refresh_token_with_client(
119    pool: &sqlx::PgPool,
120    old_refresh_token: &str,
121    roles: &[&str],
122    access_token_ttl_secs: i64,
123    refresh_token_ttl_days: i64,
124    client_id: Option<&str>,
125    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
126) -> Result<TokenPair> {
127    let hash = hash_token(old_refresh_token);
128
129    // Atomically delete only non-expired tokens, optionally matching client_id.
130    let row = if let Some(cid) = client_id {
131        sqlx::query_scalar!(
132            "DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now() AND client_id = $2 RETURNING user_id",
133            hash,
134            cid
135        )
136        .fetch_optional(pool)
137        .await
138    } else {
139        sqlx::query_scalar!(
140            "DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now()\n         RETURNING user_id",
141            hash
142        )
143        .fetch_optional(pool)
144        .await
145    }
146    .map_err(|e| ForgeError::Internal(format!("Failed to rotate refresh token: {e}")))?;
147
148    let user_id = match row {
149        Some(r) => r,
150        None => {
151            return Err(ForgeError::Unauthorized(
152                "Invalid or expired refresh token".into(),
153            ));
154        }
155    };
156
157    issue_token_pair_with_client(
158        pool,
159        user_id,
160        roles,
161        access_token_ttl_secs,
162        refresh_token_ttl_days,
163        client_id,
164        issue_access_fn,
165    )
166    .await
167}
168
169/// Revoke a specific refresh token.
170pub async fn revoke_refresh_token(pool: &sqlx::PgPool, refresh_token: &str) -> Result<()> {
171    let hash = hash_token(refresh_token);
172    sqlx::query("DELETE FROM forge_refresh_tokens WHERE token_hash = $1")
173        .bind(&hash)
174        .execute(pool)
175        .await
176        .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh token: {e}")))?;
177    Ok(())
178}
179
180/// Revoke all refresh tokens for a user.
181pub async fn revoke_all_refresh_tokens(pool: &sqlx::PgPool, user_id: Uuid) -> Result<()> {
182    sqlx::query("DELETE FROM forge_refresh_tokens WHERE user_id = $1")
183        .bind(user_id)
184        .execute(pool)
185        .await
186        .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh tokens: {e}")))?;
187    Ok(())
188}