1use sha2::{Digest, Sha256};
7use uuid::Uuid;
8
9use crate::error::{ForgeError, Result};
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
13pub struct TokenPair {
14 pub access_token: String,
15 pub refresh_token: String,
16}
17
18pub fn hash_token(token: &str) -> String {
20 let mut hasher = Sha256::new();
21 hasher.update(token.as_bytes());
22 format!("{:x}", hasher.finalize())
23}
24
25pub fn generate_refresh_token() -> String {
27 let a = Uuid::new_v4();
28 let b = Uuid::new_v4();
29 format!("{}{}", a.simple(), b.simple())
30}
31
32pub async fn issue_token_pair(
37 pool: &sqlx::PgPool,
38 user_id: Uuid,
39 roles: &[&str],
40 access_token_ttl_secs: i64,
41 refresh_token_ttl_days: i64,
42 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
43) -> Result<TokenPair> {
44 issue_token_pair_with_client(
45 pool,
46 user_id,
47 roles,
48 access_token_ttl_secs,
49 refresh_token_ttl_days,
50 None,
51 issue_access_fn,
52 )
53 .await
54}
55
56pub async fn issue_token_pair_with_client(
61 pool: &sqlx::PgPool,
62 user_id: Uuid,
63 roles: &[&str],
64 access_token_ttl_secs: i64,
65 refresh_token_ttl_days: i64,
66 client_id: Option<&str>,
67 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
68) -> Result<TokenPair> {
69 let access_token = issue_access_fn(user_id, roles, access_token_ttl_secs)?;
70
71 let refresh_raw = generate_refresh_token();
72 let refresh_hash = hash_token(&refresh_raw);
73 let expires_at = chrono::Utc::now() + chrono::Duration::days(refresh_token_ttl_days);
74
75 sqlx::query!(
76 "INSERT INTO forge_refresh_tokens (user_id, token_hash, client_id, expires_at) \
77 VALUES ($1, $2, $3, $4)",
78 user_id,
79 &refresh_hash,
80 client_id,
81 expires_at,
82 )
83 .execute(pool)
84 .await
85 .map_err(|e| ForgeError::Internal(format!("Failed to store refresh token: {e}")))?;
86
87 Ok(TokenPair {
88 access_token,
89 refresh_token: refresh_raw,
90 })
91}
92
93pub async fn rotate_refresh_token(
95 pool: &sqlx::PgPool,
96 old_refresh_token: &str,
97 roles: &[&str],
98 access_token_ttl_secs: i64,
99 refresh_token_ttl_days: i64,
100 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
101) -> Result<TokenPair> {
102 rotate_refresh_token_with_client(
103 pool,
104 old_refresh_token,
105 roles,
106 access_token_ttl_secs,
107 refresh_token_ttl_days,
108 None,
109 issue_access_fn,
110 )
111 .await
112}
113
114pub async fn rotate_refresh_token_with_client(
119 pool: &sqlx::PgPool,
120 old_refresh_token: &str,
121 roles: &[&str],
122 access_token_ttl_secs: i64,
123 refresh_token_ttl_days: i64,
124 client_id: Option<&str>,
125 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
126) -> Result<TokenPair> {
127 let hash = hash_token(old_refresh_token);
128
129 let row = if let Some(cid) = client_id {
131 sqlx::query_scalar!(
132 "DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now() AND client_id = $2 RETURNING user_id",
133 hash,
134 cid
135 )
136 .fetch_optional(pool)
137 .await
138 } else {
139 sqlx::query_scalar!(
140 "DELETE FROM forge_refresh_tokens WHERE token_hash = $1 AND expires_at > now()\n RETURNING user_id",
141 hash
142 )
143 .fetch_optional(pool)
144 .await
145 }
146 .map_err(|e| ForgeError::Internal(format!("Failed to rotate refresh token: {e}")))?;
147
148 let user_id = match row {
149 Some(r) => r,
150 None => {
151 return Err(ForgeError::Unauthorized(
152 "Invalid or expired refresh token".into(),
153 ));
154 }
155 };
156
157 issue_token_pair_with_client(
158 pool,
159 user_id,
160 roles,
161 access_token_ttl_secs,
162 refresh_token_ttl_days,
163 client_id,
164 issue_access_fn,
165 )
166 .await
167}
168
169pub async fn revoke_refresh_token(pool: &sqlx::PgPool, refresh_token: &str) -> Result<()> {
171 let hash = hash_token(refresh_token);
172 sqlx::query!(
173 "DELETE FROM forge_refresh_tokens WHERE token_hash = $1",
174 &hash
175 )
176 .execute(pool)
177 .await
178 .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh token: {e}")))?;
179 Ok(())
180}
181
182pub async fn revoke_all_refresh_tokens(pool: &sqlx::PgPool, user_id: Uuid) -> Result<()> {
184 sqlx::query!(
185 "DELETE FROM forge_refresh_tokens WHERE user_id = $1",
186 user_id
187 )
188 .execute(pool)
189 .await
190 .map_err(|e| ForgeError::Internal(format!("Failed to revoke refresh tokens: {e}")))?;
191 Ok(())
192}