use sha2::{Digest, Sha256};
use uuid::Uuid;
use crate::error::{ForgeError, Result};
/// Access/refresh token pair handed back to a client.
///
/// `Debug` is implemented by hand (see below) so the raw secrets are redacted when the pair is
/// logged with `{:?}`; serde (de)serialisation still carries both tokens verbatim.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct TokenPair {
    /// Access token produced by the caller-supplied issuing function.
    pub access_token: String,
    /// Opaque refresh token; only its SHA-256 hash is persisted server-side.
    pub refresh_token: String,
}

impl std::fmt::Debug for TokenPair {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the actual token values: Debug output routinely ends up in logs.
        f.debug_struct("TokenPair")
            .field("access_token", &"<redacted>")
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
/// Returns the lowercase hex SHA-256 digest of `token`'s UTF-8 bytes (64 characters).
///
/// This is the form under which refresh tokens are stored and looked up, so the raw token
/// never has to be persisted.
pub fn hash_token(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    format!("{digest:x}")
}
/// Creates a new opaque refresh token.
///
/// The token is the dash-less lowercase hex forms of two freshly generated v4 UUIDs,
/// concatenated into a 64-character string. Each v4 UUID carries 122 random bits.
pub fn generate_refresh_token() -> String {
    [Uuid::new_v4(), Uuid::new_v4()]
        .iter()
        .map(|id| id.simple().to_string())
        .collect()
}
/// Issues a fresh access/refresh token pair that is not bound to any client.
///
/// Equivalent to [`issue_token_pair_with_client`] with `client_id = None`, so the stored
/// refresh-token row gets a NULL `client_id`.
pub async fn issue_token_pair(
    pool: &sqlx::PgPool,
    user_id: Uuid,
    roles: &[&str],
    access_token_ttl_secs: i64,
    refresh_token_ttl_days: i64,
    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
) -> Result<TokenPair> {
    let no_client: Option<&str> = None;
    issue_token_pair_with_client(
        pool,
        user_id,
        roles,
        access_token_ttl_secs,
        refresh_token_ttl_days,
        no_client,
        issue_access_fn,
    )
    .await
}
/// Issues a fresh access/refresh token pair, optionally bound to `client_id`.
///
/// The access token is produced by `issue_access_fn(user_id, roles, access_token_ttl_secs)`.
/// A new refresh token is generated, and only its SHA-256 hash is inserted into
/// `forge_refresh_tokens` with an expiry of now + `refresh_token_ttl_days` days. The raw refresh
/// token is returned to the caller and is not stored anywhere.
///
/// # Errors
/// - Any error returned by `issue_access_fn` (nothing is written to the database in that case).
/// - `ForgeError::Internal` if the insert fails.
pub async fn issue_token_pair_with_client(
    pool: &sqlx::PgPool,
    user_id: Uuid,
    roles: &[&str],
    access_token_ttl_secs: i64,
    refresh_token_ttl_days: i64,
    client_id: Option<&str>,
    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
) -> Result<TokenPair> {
    let access_token = issue_access_fn(user_id, roles, access_token_ttl_secs)?;

    let refresh_token = generate_refresh_token();
    let stored_hash = hash_token(&refresh_token);
    let expiry = chrono::Utc::now() + chrono::Duration::days(refresh_token_ttl_days);

    sqlx::query!(
        "INSERT INTO forge_refresh_tokens (user_id, token_hash, client_id, expires_at) \
         VALUES ($1, $2, $3, $4)",
        user_id,
        &stored_hash,
        client_id,
        expiry,
    )
    .execute(pool)
    .await
    .map(|_| TokenPair {
        access_token,
        refresh_token,
    })
    .map_err(|e| ForgeError::Internal(format!("Failed to store refresh token: {e}")))
}
/// Rotates a refresh token that was issued without a client binding.
///
/// Equivalent to [`rotate_refresh_token_with_client`] with `client_id = None`: only stored
/// tokens whose `client_id` is NULL are accepted, and the replacement is stored unbound too.
pub async fn rotate_refresh_token(
    pool: &sqlx::PgPool,
    old_refresh_token: &str,
    roles: &[&str],
    access_token_ttl_secs: i64,
    refresh_token_ttl_days: i64,
    issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
) -> Result<TokenPair> {
    let unbound: Option<&str> = None;
    rotate_refresh_token_with_client(
        pool,
        old_refresh_token,
        roles,
        access_token_ttl_secs,
        refresh_token_ttl_days,
        unbound,
        issue_access_fn,
    )
    .await
}
pub async fn rotate_refresh_token_with_client(
pool: &sqlx::PgPool,
old_refresh_token: &str,
roles: &[&str],
access_token_ttl_secs: i64,
refresh_token_ttl_days: i64,
client_id: Option<&str>,
issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
) -> Result<TokenPair> {
let hash = hash_token(old_refresh_token);
let row = if let Some(cid) = client_id {
sqlx::query_scalar!(
"DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now() AND client_id = $2 RETURNING user_id",
hash,
cid
)
.fetch_optional(pool)
.await
} else {
sqlx::query_scalar!(
"DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now() AND client_id IS NULL RETURNING user_id",
hash
)
.fetch_optional(pool)
.await
}
.map_err(|e| ForgeError::Internal(format!("Failed to rotate refresh token: {e}")))?;
let user_id = match row {
Some(r) => r,
None => {
return Err(ForgeError::Unauthorized(
"Invalid or expired refresh token".into(),
));
}
};
issue_token_pair_with_client(
pool,
user_id,
roles,
access_token_ttl_secs,
refresh_token_ttl_days,
client_id,
issue_access_fn,
)
.await
}
/// Deletes the stored row for `refresh_token`, if any.
///
/// The token is matched by its SHA-256 hash. The number of affected rows is ignored, so revoking
/// an unknown or already-revoked token still returns `Ok(())`.
///
/// # Errors
/// `ForgeError::Internal` if the delete query fails.
pub async fn revoke_refresh_token(pool: &sqlx::PgPool, refresh_token: &str) -> Result<()> {
    let token_hash = hash_token(refresh_token);
    sqlx::query!(
        "DELETE FROM forge_refresh_tokens WHERE token_hash = $1",
        &token_hash
    )
    .execute(pool)
    .await
    .map(|_| ())
    .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh token: {e}")))
}
/// Deletes every stored refresh token belonging to `user_id`, regardless of client binding.
///
/// Succeeds even if the user had no tokens.
///
/// # Errors
/// `ForgeError::Internal` if the delete query fails.
pub async fn revoke_all_refresh_tokens(pool: &sqlx::PgPool, user_id: Uuid) -> Result<()> {
    sqlx::query!(
        "DELETE FROM forge_refresh_tokens WHERE user_id = $1",
        user_id
    )
    .execute(pool)
    .await
    .map(|_| ())
    .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh tokens: {e}")))
}