1use rand::RngCore;
2use rand::rngs::OsRng;
3use serde::Serialize;
4use serde_json::Value;
5use sha2::{Digest, Sha256};
6use sqlx::{Pool, Postgres};
7use std::env;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10#[derive(Debug, Clone, sqlx::FromRow)]
11pub struct TokenRecord {
12 pub token: String,
13 pub payload: Value,
14 pub modified_at: i64,
15}
16
/// Tunables for token lifetime and sliding renewal.
#[derive(Debug, Clone)]
pub struct TokenConfig {
    /// Seconds a token remains valid after its `modified_at` timestamp.
    pub ttl_seconds: i64,
    /// A token becomes eligible for renewal once its remaining lifetime is at
    /// most this many seconds; a value `<= 0` disables renewal.
    pub renew_threshold_seconds: i64,
}

impl TokenConfig {
    /// TTL used when `TOKEN_TTL_SECONDS` is unset, unparsable, or not positive.
    const DEFAULT_TTL_SECONDS: i64 = 300;
    /// Threshold used when `TOKEN_RENEW_THRESHOLD_SECONDS` is unset or unparsable.
    const DEFAULT_RENEW_THRESHOLD_SECONDS: i64 = 30;

    /// Loads the configuration from the process environment.
    ///
    /// * `TOKEN_TTL_SECONDS` — defaults to 300. Zero or negative values are
    ///   rejected (they would make every token expire immediately) and the
    ///   default is used instead.
    /// * `TOKEN_RENEW_THRESHOLD_SECONDS` — defaults to 30. Any integer is
    ///   accepted; non-positive values disable renewal.
    ///
    /// Surrounding whitespace in either variable is ignored.
    pub fn load() -> Self {
        let ttl_seconds = Self::env_i64("TOKEN_TTL_SECONDS")
            .filter(|&secs| secs > 0)
            .unwrap_or(Self::DEFAULT_TTL_SECONDS);
        let renew_threshold_seconds = Self::env_i64("TOKEN_RENEW_THRESHOLD_SECONDS")
            .unwrap_or(Self::DEFAULT_RENEW_THRESHOLD_SECONDS);
        Self {
            ttl_seconds,
            renew_threshold_seconds,
        }
    }

    /// Reads environment variable `key` and parses it as an `i64`.
    ///
    /// Returns `None` when the variable is unset, not valid Unicode, or not an integer.
    fn env_i64(key: &str) -> Option<i64> {
        env::var(key).ok()?.trim().parse().ok()
    }
}
39
40#[derive(Debug, Clone)]
41pub struct TokenManager<'a> {
42 pool: &'a Pool<Postgres>,
43 config: TokenConfig,
44}
45
46#[derive(Debug)]
47pub enum TokenError {
48 NotFound,
49 Expired,
50 Database(sqlx::Error),
51}
52
53impl From<sqlx::Error> for TokenError {
54 fn from(err: sqlx::Error) -> Self {
55 TokenError::Database(err)
56 }
57}
58
59#[derive(Debug, Serialize)]
60pub struct TokenIssue {
61 pub token: String,
62 pub expires_at: i64,
63}
64
65#[derive(Debug)]
66pub struct TokenValidation {
67 pub record: TokenRecord,
68 pub renewed: bool,
69 pub expires_at: i64,
70}
71
72impl<'a> TokenManager<'a> {
73 pub fn new(pool: &'a Pool<Postgres>) -> Self {
74 let config = TokenConfig::load();
75 Self { pool, config }
76 }
77
78 pub fn ttl(&self) -> i64 {
79 self.config.ttl_seconds
80 }
81
82 fn now_epoch() -> i64 {
83 SystemTime::now()
84 .duration_since(UNIX_EPOCH)
85 .unwrap_or_default()
86 .as_secs() as i64
87 }
88
89 fn generate_token_value(secret: &str, now: i64) -> String {
90 let mut random = [0u8; 32];
91 OsRng.fill_bytes(&mut random);
92
93 let mut hasher = Sha256::new();
94 hasher.update(secret.as_bytes());
95 hasher.update(&random);
96 hasher.update(now.to_be_bytes());
97
98 let digest = hasher.finalize();
99 format!("{:x}", digest)
100 }
101
102 fn hash_token_value(token: &str) -> String {
103 let mut hasher = Sha256::new();
104 hasher.update(token.as_bytes());
105 format!("{:x}", hasher.finalize())
106 }
107
108 async fn insert_token(
109 &self,
110 token: &str,
111 payload: &Value,
112 modified_at: i64,
113 ) -> Result<(), sqlx::Error> {
114 let hashed = Self::hash_token_value(token);
115 sqlx::query("INSERT INTO auth.tokens_cache (token, payload, modified_at) VALUES ($1, $2, $3)")
116 .bind(hashed)
117 .bind(payload)
118 .bind(modified_at)
119 .execute(self.pool)
120 .await?;
121 Ok(())
122 }
123
124 async fn fetch_token(&self, token: &str) -> Result<Option<TokenRecord>, sqlx::Error> {
125 let hashed = Self::hash_token_value(token);
126 sqlx::query_as::<_, TokenRecord>(
127 "SELECT token, payload, modified_at FROM auth.tokens_cache WHERE token = $1",
128 )
129 .bind(hashed)
130 .fetch_optional(self.pool)
131 .await
132 }
133
134 async fn touch_token(
135 &self,
136 token: &str,
137 previous_modified_at: i64,
138 new_modified_at: i64,
139 ) -> Result<Option<TokenRecord>, sqlx::Error> {
140 let hashed = Self::hash_token_value(token);
141 let updated = sqlx::query_as::<_, TokenRecord>(
142 "UPDATE auth.tokens_cache SET modified_at = $1 WHERE token = $2 AND modified_at = $3 RETURNING token, payload, modified_at",
143 )
144 .bind(new_modified_at)
145 .bind(&hashed)
146 .bind(previous_modified_at)
147 .fetch_optional(self.pool)
148 .await?;
149 if updated.is_some() {
150 sqlx::query("UPDATE auth.permissions_cache SET modified_at = $1 WHERE token = $2")
151 .bind(new_modified_at)
152 .bind(&hashed)
153 .execute(self.pool)
154 .await?;
155 }
156 Ok(updated)
157 }
158
159 fn compute_expires_at(&self, modified_at: i64) -> i64 {
160 modified_at + self.config.ttl_seconds
161 }
162
163 pub async fn issue_token(&self, payload: Value) -> Result<TokenIssue, sqlx::Error> {
164 let now = Self::now_epoch();
165 let secret = env::var("JWT_SECRET").unwrap_or_else(|_| "local_secret".to_string());
166 let token = Self::generate_token_value(&secret, now);
167 self.insert_token(&token, &payload, now).await?;
168 Ok(TokenIssue {
169 token,
170 expires_at: self.compute_expires_at(now),
171 })
172 }
173
174 pub async fn delete_token(&self, token: &str) -> Result<bool, sqlx::Error> {
175 let hashed = Self::hash_token_value(token);
176 let rows = sqlx::query("DELETE FROM auth.tokens_cache WHERE token = $1")
177 .bind(hashed)
178 .execute(self.pool)
179 .await?
180 .rows_affected();
181 Ok(rows > 0)
182 }
183
184 pub async fn delete_tokens_for_user(&self, user_id: i32) -> Result<u64, sqlx::Error> {
185 let rows = sqlx::query("DELETE FROM auth.tokens_cache WHERE payload ->> 'user_id' = $1")
186 .bind(user_id.to_string())
187 .execute(self.pool)
188 .await?
189 .rows_affected();
190 Ok(rows)
191 }
192
193 pub async fn cleanup_expired(&self) -> Result<u64, sqlx::Error> {
194 let ttl = self.config.ttl_seconds.max(1);
195 let cutoff = Self::now_epoch() - ttl;
196 let rows = sqlx::query("DELETE FROM auth.tokens_cache WHERE modified_at < $1")
197 .bind(cutoff)
198 .execute(self.pool)
199 .await?
200 .rows_affected();
201 Ok(rows)
202 }
203
204 fn has_expired(&self, modified_at: i64, now: i64) -> bool {
205 now - modified_at > self.config.ttl_seconds
206 }
207
208 fn should_renew(&self, modified_at: i64, now: i64) -> bool {
209 if self.config.renew_threshold_seconds <= 0 {
210 return false;
211 }
212 let expires_at = self.compute_expires_at(modified_at);
213 expires_at - now <= self.config.renew_threshold_seconds
214 }
215
216 pub async fn validate_token(
217 &self,
218 token: &str,
219 renew_if_needed: bool,
220 ) -> Result<TokenValidation, TokenError> {
221 let mut record = match self.fetch_token(token).await? {
222 Some(rec) => rec,
223 None => return Err(TokenError::NotFound),
224 };
225 let now = Self::now_epoch();
226 if self.has_expired(record.modified_at, now) {
227 let _ = self.delete_token(token).await;
228 return Err(TokenError::Expired);
229 }
230
231 let mut renewed = false;
232 if renew_if_needed && self.should_renew(record.modified_at, now) {
233 match self.touch_token(token, record.modified_at, now).await? {
234 Some(updated) => {
235 record = updated;
236 renewed = true;
237 }
238 None => {
239 if let Some(updated) = self.fetch_token(token).await? {
240 if self.has_expired(updated.modified_at, now) {
241 let _ = self.delete_token(token).await;
242 return Err(TokenError::Expired);
243 }
244 record = updated;
245 } else {
246 return Err(TokenError::NotFound);
247 }
248 }
249 }
250 }
251
252 let expires_at = self.compute_expires_at(record.modified_at);
253
254 Ok(TokenValidation {
255 record,
256 renewed,
257 expires_at,
258 })
259 }
260}