1use async_trait::async_trait;
3use sqlx::{Row, postgres::PgPool};
4
5use crate::auth::{
6 error::{AuthError, Result},
7 session::{SessionData, SessionStore, TokenPair, generate_refresh_token, hash_token},
8};
9
10pub struct PostgresSessionStore {
12 db: PgPool,
13 signing_key: Option<Vec<u8>>,
15}
16
17impl PostgresSessionStore {
18 pub fn new(db: PgPool) -> Self {
23 Self {
24 db,
25 signing_key: None,
26 }
27 }
28
29 pub fn with_rs256_key(db: PgPool, private_key_pem: Vec<u8>) -> Self {
35 Self {
36 db,
37 signing_key: Some(private_key_pem),
38 }
39 }
40
41 pub async fn init(&self) -> Result<()> {
48 sqlx::query(
49 r"
50 CREATE TABLE IF NOT EXISTS _system.sessions (
51 id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
52 user_id TEXT NOT NULL,
53 refresh_token_hash TEXT NOT NULL UNIQUE,
54 issued_at BIGINT NOT NULL,
55 expires_at BIGINT NOT NULL,
56 created_at TIMESTAMPTZ DEFAULT NOW(),
57 revoked_at TIMESTAMPTZ
58 );
59
60 CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON _system.sessions(user_id);
61 CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON _system.sessions(expires_at);
62 CREATE INDEX IF NOT EXISTS idx_sessions_revoked_at ON _system.sessions(revoked_at);
63 ",
64 )
65 .execute(&self.db)
66 .await
67 .map_err(|e| AuthError::DatabaseError {
68 message: format!("Failed to initialize sessions table: {}", e),
69 })?;
70
71 Ok(())
72 }
73
74 fn generate_access_token(&self, user_id: &str, expires_in: u64) -> Result<String> {
79 let now = std::time::SystemTime::now()
80 .duration_since(std::time::UNIX_EPOCH)
81 .unwrap_or_default()
82 .as_secs();
83
84 let exp = now + expires_in;
85
86 let mut claims = crate::auth::Claims {
87 sub: user_id.to_string(),
88 iat: now,
89 exp,
90 iss: "fraiseql".to_string(),
91 aud: vec!["fraiseql-api".to_string()],
92 extra: std::collections::HashMap::new(),
93 };
94
95 claims
97 .extra
98 .insert("jti".to_string(), serde_json::json!(uuid::Uuid::new_v4().to_string()));
99
100 match &self.signing_key {
101 Some(private_key) => crate::auth::jwt::generate_rs256_token(&claims, private_key),
102 None => {
103 let secret = format!("fraiseql_session_{}", user_id).into_bytes();
105 crate::auth::jwt::generate_hs256_token(&claims, &secret)
106 },
107 }
108 }
109}
110
111#[async_trait]
112impl SessionStore for PostgresSessionStore {
113 async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
114 let refresh_token = generate_refresh_token();
115 let refresh_token_hash = hash_token(&refresh_token);
116
117 let now = std::time::SystemTime::now()
118 .duration_since(std::time::UNIX_EPOCH)
119 .unwrap_or_default()
120 .as_secs();
121
122 sqlx::query(
123 r"
124 INSERT INTO _system.sessions
125 (user_id, refresh_token_hash, issued_at, expires_at)
126 VALUES ($1, $2, $3, $4)
127 ",
128 )
129 .bind(user_id)
130 .bind(&refresh_token_hash)
131 .bind(now as i64)
132 .bind(expires_at as i64)
133 .execute(&self.db)
134 .await
135 .map_err(|e| {
136 if e.to_string().contains("duplicate key") {
137 AuthError::SessionError {
138 message: "Refresh token already exists".to_string(),
139 }
140 } else {
141 AuthError::DatabaseError {
142 message: format!("Failed to create session: {}", e),
143 }
144 }
145 })?;
146
147 let expires_in = expires_at.saturating_sub(now);
148 let access_token = self.generate_access_token(user_id, expires_in)?;
149
150 Ok(TokenPair {
151 access_token,
152 refresh_token,
153 expires_in,
154 })
155 }
156
157 async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
158 let row = sqlx::query(
159 r"
160 SELECT user_id, issued_at, expires_at, refresh_token_hash
161 FROM _system.sessions
162 WHERE refresh_token_hash = $1 AND revoked_at IS NULL
163 ",
164 )
165 .bind(refresh_token_hash)
166 .fetch_optional(&self.db)
167 .await
168 .map_err(|e| AuthError::DatabaseError {
169 message: format!("Failed to get session: {}", e),
170 })?
171 .ok_or(AuthError::TokenNotFound)?;
172
173 let user_id: String = row.get("user_id");
174 let issued_at: i64 = row.get("issued_at");
175 let expires_at: i64 = row.get("expires_at");
176 let refresh_token_hash: String = row.get("refresh_token_hash");
177
178 Ok(SessionData {
179 user_id,
180 issued_at: issued_at as u64,
181 expires_at: expires_at as u64,
182 refresh_token_hash,
183 })
184 }
185
186 async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
187 let result = sqlx::query(
188 r"
189 UPDATE _system.sessions
190 SET revoked_at = NOW()
191 WHERE refresh_token_hash = $1 AND revoked_at IS NULL
192 ",
193 )
194 .bind(refresh_token_hash)
195 .execute(&self.db)
196 .await
197 .map_err(|e| AuthError::DatabaseError {
198 message: format!("Failed to revoke session: {}", e),
199 })?;
200
201 if result.rows_affected() == 0 {
202 return Err(AuthError::SessionError {
203 message: "Session not found or already revoked".to_string(),
204 });
205 }
206
207 Ok(())
208 }
209
210 async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
211 sqlx::query(
212 r"
213 UPDATE _system.sessions
214 SET revoked_at = NOW()
215 WHERE user_id = $1 AND revoked_at IS NULL
216 ",
217 )
218 .bind(user_id)
219 .execute(&self.db)
220 .await
221 .map_err(|e| AuthError::DatabaseError {
222 message: format!("Failed to revoke all sessions: {}", e),
223 })?;
224
225 Ok(())
226 }
227}
228
#[cfg(test)]
mod tests {
    /// Claims shaped like the ones `generate_access_token` builds for
    /// `user123` with a one-hour lifetime, already carrying a random `jti`.
    fn sample_claims() -> crate::auth::Claims {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        let mut claims = crate::auth::Claims {
            sub: "user123".to_string(),
            iat: now,
            exp: now + 3600,
            iss: "fraiseql".to_string(),
            aud: vec!["fraiseql-api".to_string()],
            extra: std::collections::HashMap::new(),
        };
        claims.extra.insert("jti".to_string(), fresh_jti());
        claims
    }

    /// A new random token identifier, as a JSON string value.
    fn fresh_jti() -> serde_json::Value {
        serde_json::json!(uuid::Uuid::new_v4().to_string())
    }

    // NOTE: these tests exercise the JWT helpers directly with the same claim
    // layout the store uses; they do not construct a `PostgresSessionStore`,
    // so no database is needed.

    #[test]
    fn test_generate_access_token_creates_valid_jwt() {
        let mut claims = sample_claims();

        // Same secret the store's HS256 fallback derives for `user123`.
        let secret = b"fraiseql_session_user123";
        let token1 = crate::auth::jwt::generate_hs256_token(&claims, secret)
            .expect("Failed to generate token");

        // Only the `jti` changes between the two tokens.
        claims.extra.insert("jti".to_string(), fresh_jti());

        let token2 = crate::auth::jwt::generate_hs256_token(&claims, secret)
            .expect("Failed to generate token");

        // A different `jti` must produce a different token.
        assert_ne!(token1, token2);
        // A compact JWT is three dot-separated segments.
        assert_eq!(token1.matches('.').count(), 2);
        assert_eq!(token2.matches('.').count(), 2);
    }

    #[test]
    fn test_generate_access_token_with_rs256_key() {
        let test_key = include_bytes!("../../test_data/test_rsa_key.pem");

        let claims = sample_claims();

        let token = crate::auth::jwt::generate_rs256_token(&claims, test_key)
            .expect("Failed to generate RS256 token");

        // A compact JWT is three dot-separated segments.
        assert_eq!(token.matches('.').count(), 2);
    }
}