1use sha2::{Digest, Sha256};
32use uuid::Uuid;
33
34use crate::error::{ForgeError, Result};
35
36#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub struct TokenPair {
39 pub access_token: String,
40 pub refresh_token: String,
41}
42
43pub fn hash_token(token: &str) -> String {
45 let mut hasher = Sha256::new();
46 hasher.update(token.as_bytes());
47 format!("{:x}", hasher.finalize())
48}
49
50fn generate_refresh_token_for_family(family: Uuid) -> String {
55 let random = Uuid::new_v4();
56 format!("{}.{}", family.simple(), random.simple())
57}
58
59fn extract_family(raw_token: &str) -> Option<Uuid> {
63 let (family_hex, _) = raw_token.split_once('.')?;
64 Uuid::parse_str(family_hex).ok()
65}
66
67pub async fn issue_token_pair(
72 pool: &sqlx::PgPool,
73 user_id: Uuid,
74 roles: &[&str],
75 access_token_ttl_secs: i64,
76 refresh_token_ttl_days: i64,
77 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
78) -> Result<TokenPair> {
79 issue_token_pair_with_client(
80 pool,
81 user_id,
82 roles,
83 access_token_ttl_secs,
84 refresh_token_ttl_days,
85 None,
86 issue_access_fn,
87 )
88 .await
89}
90
91pub async fn issue_token_pair_with_client(
96 pool: &sqlx::PgPool,
97 user_id: Uuid,
98 roles: &[&str],
99 access_token_ttl_secs: i64,
100 refresh_token_ttl_days: i64,
101 client_id: Option<&str>,
102 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
103) -> Result<TokenPair> {
104 let family = Uuid::new_v4();
105 issue_token_in_family(
106 pool,
107 user_id,
108 roles,
109 access_token_ttl_secs,
110 refresh_token_ttl_days,
111 client_id,
112 family,
113 issue_access_fn,
114 )
115 .await
116}
117
118#[allow(clippy::too_many_arguments)]
123async fn issue_token_in_family(
124 pool: &sqlx::PgPool,
125 user_id: Uuid,
126 roles: &[&str],
127 access_token_ttl_secs: i64,
128 refresh_token_ttl_days: i64,
129 client_id: Option<&str>,
130 family: Uuid,
131 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
132) -> Result<TokenPair> {
133 let access_token = issue_access_fn(user_id, roles, access_token_ttl_secs)?;
134
135 let refresh_raw = generate_refresh_token_for_family(family);
136 let refresh_hash = hash_token(&refresh_raw);
137 let expires_at = chrono::Utc::now() + chrono::Duration::days(refresh_token_ttl_days);
138
139 let roles_owned: Vec<String> = roles.iter().map(|s| s.to_string()).collect();
140 sqlx::query!(
141 "INSERT INTO forge_refresh_tokens (user_id, token_hash, client_id, expires_at, token_family, roles) \
142 VALUES ($1, $2, $3, $4, $5, $6)",
143 user_id,
144 &refresh_hash,
145 client_id,
146 expires_at,
147 family,
148 &roles_owned,
149 )
150 .execute(pool)
151 .await
152 .map_err(|e| ForgeError::internal_with("Failed to store refresh token", e))?;
153
154 Ok(TokenPair {
155 access_token,
156 refresh_token: refresh_raw,
157 })
158}
159
160pub async fn rotate_refresh_token(
166 pool: &sqlx::PgPool,
167 old_refresh_token: &str,
168 access_token_ttl_secs: i64,
169 refresh_token_ttl_days: i64,
170 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
171) -> Result<TokenPair> {
172 rotate_refresh_token_with_client(
173 pool,
174 old_refresh_token,
175 access_token_ttl_secs,
176 refresh_token_ttl_days,
177 None,
178 issue_access_fn,
179 )
180 .await
181}
182
183pub async fn rotate_refresh_token_with_client(
197 pool: &sqlx::PgPool,
198 old_refresh_token: &str,
199 access_token_ttl_secs: i64,
200 refresh_token_ttl_days: i64,
201 client_id: Option<&str>,
202 issue_access_fn: impl FnOnce(Uuid, &[&str], i64) -> Result<String>,
203) -> Result<TokenPair> {
204 let hash = hash_token(old_refresh_token);
205
206 struct TokenRow {
213 user_id: Uuid,
214 token_family: Uuid,
215 roles: Vec<String>,
216 }
217
218 let row = if let Some(cid) = client_id {
219 sqlx::query!(
220 "DELETE FROM forge_refresh_tokens \
221 WHERE token_hash = $1 AND expires_at > now() AND client_id = $2 \
222 RETURNING user_id, token_family, roles",
223 hash,
224 cid
225 )
226 .fetch_optional(pool)
227 .await
228 .map(|r| {
229 r.map(|r| TokenRow {
230 user_id: r.user_id,
231 token_family: r.token_family,
232 roles: r.roles,
233 })
234 })
235 } else {
236 sqlx::query!(
237 "DELETE FROM forge_refresh_tokens \
238 WHERE token_hash = $1 AND expires_at > now() AND client_id IS NULL \
239 RETURNING user_id, token_family, roles",
240 hash
241 )
242 .fetch_optional(pool)
243 .await
244 .map(|r| {
245 r.map(|r| TokenRow {
246 user_id: r.user_id,
247 token_family: r.token_family,
248 roles: r.roles,
249 })
250 })
251 }
252 .map_err(|e| ForgeError::internal_with("Failed to rotate refresh token", e))?;
253
254 match row {
255 Some(token) => {
256 let roles_refs: Vec<&str> = token.roles.iter().map(String::as_str).collect();
257 issue_token_in_family(
258 pool,
259 token.user_id,
260 &roles_refs,
261 access_token_ttl_secs,
262 refresh_token_ttl_days,
263 client_id,
264 token.token_family,
265 issue_access_fn,
266 )
267 .await
268 }
269 None => {
270 if let Some(family_id) = extract_family(old_refresh_token) {
275 let deleted = sqlx::query!(
276 "DELETE FROM forge_refresh_tokens WHERE token_family = $1",
277 family_id
278 )
279 .execute(pool)
280 .await
281 .map(|r| r.rows_affected())
282 .unwrap_or(0);
283
284 if deleted > 0 {
285 tracing::warn!(
286 %family_id,
287 revoked = deleted,
288 "Refresh token reuse detected — entire family revoked"
289 );
290 }
291 }
292
293 Err(ForgeError::Unauthorized(
294 "Invalid or expired refresh token".into(),
295 ))
296 }
297 }
298}
299
300pub async fn revoke_refresh_token(pool: &sqlx::PgPool, refresh_token: &str) -> Result<()> {
302 let hash = hash_token(refresh_token);
303 sqlx::query!(
304 "DELETE FROM forge_refresh_tokens WHERE token_hash = $1",
305 &hash
306 )
307 .execute(pool)
308 .await
309 .map_err(|e| ForgeError::internal_with("Failed to revoke refresh token", e))?;
310 Ok(())
311}
312
313pub async fn revoke_all_refresh_tokens(pool: &sqlx::PgPool, user_id: Uuid) -> Result<()> {
315 sqlx::query!(
316 "DELETE FROM forge_refresh_tokens WHERE user_id = $1",
317 user_id
318 )
319 .execute(pool)
320 .await
321 .map_err(|e| ForgeError::internal_with("Failed to revoke refresh tokens", e))?;
322 Ok(())
323}
324
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_refresh_token_for_family_encodes_family() {
        let family = Uuid::new_v4();
        let token = generate_refresh_token_for_family(family);

        assert!(token.contains('.'), "token must contain the dot separator");
        assert_eq!(extract_family(&token), Some(family));
    }

    #[test]
    fn test_extract_family_returns_none_for_legacy_format() {
        // Two simple UUIDs glued together with no separator.
        let legacy: String = [Uuid::new_v4(), Uuid::new_v4()]
            .iter()
            .map(|id| id.simple().to_string())
            .collect();
        assert_eq!(extract_family(&legacy), None);
    }

    #[test]
    fn test_extract_family_returns_none_for_garbage() {
        for garbage in ["not-a-token", ""] {
            assert_eq!(extract_family(garbage), None, "input: {garbage:?}");
        }
    }

    #[test]
    fn test_hash_token_is_deterministic() {
        let token = "some-raw-token-value";
        let first = hash_token(token);
        let second = hash_token(token);
        assert_eq!(first, second);
    }

    #[test]
    fn test_hash_token_differs_for_different_inputs() {
        let (a, b) = (hash_token("token-a"), hash_token("token-b"));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_token_returns_64_char_lowercase_hex() {
        let hash = hash_token("anything");
        assert_eq!(hash.len(), 64, "SHA-256 hex is exactly 64 chars");
        let is_lower_hex = |c: char| c.is_ascii_hexdigit() && !c.is_ascii_uppercase();
        assert!(
            hash.chars().all(is_lower_hex),
            "expected lowercase hex digits, got {hash}"
        );
    }

    #[test]
    fn hash_token_matches_known_sha256_for_empty_string() {
        // Standard SHA-256 test vector for the empty input.
        const EMPTY_SHA256: &str =
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        assert_eq!(hash_token(""), EMPTY_SHA256);
    }

    #[test]
    fn generate_refresh_token_for_family_returns_unique_random_part() {
        let family = Uuid::new_v4();
        let (a, b) = (
            generate_refresh_token_for_family(family),
            generate_refresh_token_for_family(family),
        );
        assert_ne!(a, b);
        for token in [&a, &b] {
            assert_eq!(extract_family(token), Some(family));
        }
    }

    #[test]
    fn generate_refresh_token_for_family_has_expected_shape() {
        let token = generate_refresh_token_for_family(Uuid::new_v4());
        assert_eq!(token.matches('.').count(), 1, "exactly one dot separator");
        let (family_part, random_part) = token.split_once('.').unwrap();
        for part in [family_part, random_part] {
            assert_eq!(part.len(), 32);
            assert!(part.chars().all(|c| c.is_ascii_hexdigit()));
        }
    }

    #[test]
    fn extract_family_returns_none_when_prefix_is_not_a_uuid() {
        assert!(extract_family("notauuid.suffix").is_none());
    }

    #[test]
    fn extract_family_returns_first_segment_uuid_for_multi_dot_tokens() {
        let family = Uuid::new_v4();
        // Everything after the first dot is ignored when recovering the family.
        let weird = family.simple().to_string() + ".a.b.c";
        assert_eq!(extract_family(&weird), Some(family));
    }

    #[test]
    fn token_pair_round_trips_through_json() {
        let original = TokenPair {
            access_token: String::from("header.payload.sig"),
            refresh_token: String::from("fam.rand"),
        };
        let json = serde_json::to_string(&original).unwrap();
        let decoded: TokenPair = serde_json::from_str(&json).unwrap();
        assert_eq!(
            (decoded.access_token.as_str(), decoded.refresh_token.as_str()),
            (original.access_token.as_str(), original.refresh_token.as_str())
        );
    }

    #[test]
    fn hash_token_is_independent_of_token_length() {
        // The digest length is fixed regardless of input size.
        for input in ["x".repeat(10_000), "a".to_owned()] {
            assert_eq!(hash_token(&input).len(), 64);
        }
    }
}