1use async_trait::async_trait;
3use sqlx::{Row, postgres::PgPool};
4
5use crate::{
6 error::{AuthError, Result},
7 session::{SessionData, SessionStore, TokenPair, generate_refresh_token, hash_token},
8};
9
/// PostgreSQL-backed [`SessionStore`] that persists refresh-token sessions in the
/// `_system.sessions` table and signs JWT access tokens.
pub struct PostgresSessionStore {
    /// Connection pool used for every session query.
    db: PgPool,
    /// PEM-encoded private key used for RS256 signing; `None` selects the HS256 fallback.
    signing_key: Option<Vec<u8>>,
}
16
17impl PostgresSessionStore {
18 pub const fn new(db: PgPool) -> Self {
23 Self {
24 db,
25 signing_key: None,
26 }
27 }
28
29 pub const fn with_rs256_key(db: PgPool, private_key_pem: Vec<u8>) -> Self {
35 Self {
36 db,
37 signing_key: Some(private_key_pem),
38 }
39 }
40
41 pub async fn init(&self) -> Result<()> {
48 sqlx::query(
49 r"
50 CREATE TABLE IF NOT EXISTS _system.sessions (
51 id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
52 user_id TEXT NOT NULL,
53 refresh_token_hash TEXT NOT NULL UNIQUE,
54 issued_at BIGINT NOT NULL,
55 expires_at BIGINT NOT NULL,
56 created_at TIMESTAMPTZ DEFAULT NOW(),
57 revoked_at TIMESTAMPTZ
58 );
59
60 CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON _system.sessions(user_id);
61 CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON _system.sessions(expires_at);
62 CREATE INDEX IF NOT EXISTS idx_sessions_revoked_at ON _system.sessions(revoked_at);
63 ",
64 )
65 .execute(&self.db)
66 .await
67 .map_err(|e| AuthError::DatabaseError {
68 message: format!("Failed to initialize sessions table: {}", e),
69 })?;
70
71 Ok(())
72 }
73
74 fn generate_access_token(&self, user_id: &str, expires_in: u64) -> Result<String> {
79 let now = std::time::SystemTime::now()
80 .duration_since(std::time::UNIX_EPOCH)
81 .unwrap_or_default()
82 .as_secs();
83
84 let exp = now + expires_in;
85
86 let mut claims = crate::Claims {
87 sub: user_id.to_string(),
88 iat: now,
89 exp,
90 iss: "fraiseql".to_string(),
91 aud: vec!["fraiseql-api".to_string()],
92 extra: std::collections::HashMap::new(),
93 };
94
95 claims
97 .extra
98 .insert("jti".to_string(), serde_json::json!(uuid::Uuid::new_v4().to_string()));
99
100 if let Some(private_key) = &self.signing_key {
101 crate::jwt::generate_rs256_token(&claims, private_key)
102 } else {
103 let secret = format!("fraiseql_session_{}", user_id).into_bytes();
105 crate::jwt::generate_hs256_token(&claims, &secret)
106 }
107 }
108}
109
110#[async_trait]
114impl SessionStore for PostgresSessionStore {
115 async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
116 let refresh_token = generate_refresh_token();
117 let refresh_token_hash = hash_token(&refresh_token);
118
119 let now = std::time::SystemTime::now()
120 .duration_since(std::time::UNIX_EPOCH)
121 .unwrap_or_default()
122 .as_secs();
123
124 sqlx::query(
125 r"
126 INSERT INTO _system.sessions
127 (user_id, refresh_token_hash, issued_at, expires_at)
128 VALUES ($1, $2, $3, $4)
129 ",
130 )
131 .bind(user_id)
132 .bind(&refresh_token_hash)
133 .bind(now.cast_signed())
134 .bind(expires_at.cast_signed())
135 .execute(&self.db)
136 .await
137 .map_err(|e| {
138 if e.to_string().contains("duplicate key") {
139 AuthError::SessionError {
140 message: "Refresh token already exists".to_string(),
141 }
142 } else {
143 AuthError::DatabaseError {
144 message: format!("Failed to create session: {}", e),
145 }
146 }
147 })?;
148
149 let expires_in = expires_at.saturating_sub(now);
150 let access_token = self.generate_access_token(user_id, expires_in)?;
151
152 Ok(TokenPair {
153 access_token,
154 refresh_token,
155 expires_in,
156 })
157 }
158
159 async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
160 let row = sqlx::query(
161 r"
162 SELECT user_id, issued_at, expires_at, refresh_token_hash
163 FROM _system.sessions
164 WHERE refresh_token_hash = $1 AND revoked_at IS NULL
165 ",
166 )
167 .bind(refresh_token_hash)
168 .fetch_optional(&self.db)
169 .await
170 .map_err(|e| AuthError::DatabaseError {
171 message: format!("Failed to get session: {}", e),
172 })?
173 .ok_or(AuthError::TokenNotFound)?;
174
175 let user_id: String = row.get("user_id");
176 let issued_at: i64 = row.get("issued_at");
177 let expires_at: i64 = row.get("expires_at");
178 let refresh_token_hash: String = row.get("refresh_token_hash");
179
180 Ok(SessionData {
181 user_id,
182 issued_at: issued_at.cast_unsigned(),
183 expires_at: expires_at.cast_unsigned(),
184 refresh_token_hash,
185 })
186 }
187
188 async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
189 let result = sqlx::query(
190 r"
191 UPDATE _system.sessions
192 SET revoked_at = NOW()
193 WHERE refresh_token_hash = $1 AND revoked_at IS NULL
194 ",
195 )
196 .bind(refresh_token_hash)
197 .execute(&self.db)
198 .await
199 .map_err(|e| AuthError::DatabaseError {
200 message: format!("Failed to revoke session: {}", e),
201 })?;
202
203 if result.rows_affected() == 0 {
204 return Err(AuthError::SessionError {
205 message: "Session not found or already revoked".to_string(),
206 });
207 }
208
209 Ok(())
210 }
211
212 async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
213 sqlx::query(
214 r"
215 UPDATE _system.sessions
216 SET revoked_at = NOW()
217 WHERE user_id = $1 AND revoked_at IS NULL
218 ",
219 )
220 .bind(user_id)
221 .execute(&self.db)
222 .await
223 .map_err(|e| AuthError::DatabaseError {
224 message: format!("Failed to revoke all sessions: {}", e),
225 })?;
226
227 Ok(())
228 }
229}
230
#[cfg(test)]
mod tests {
    // Building a `PostgresSessionStore` requires a `PgPool`, so these tests reproduce the
    // claim set assembled by `generate_access_token` and call the same `crate::jwt`
    // signing functions directly.

    /// Seconds since the Unix epoch, or 0 if the system clock is set before it.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Replaces the `jti` claim with a fresh random UUIDv4 string.
    fn set_random_jti(claims: &mut crate::Claims) {
        claims
            .extra
            .insert("jti".to_string(), serde_json::json!(uuid::Uuid::new_v4().to_string()));
    }

    /// Claims for `user123`, issued now, valid for one hour, with a random `jti`.
    fn sample_claims() -> crate::Claims {
        let now = unix_now();
        let mut claims = crate::Claims {
            sub: "user123".to_string(),
            iat: now,
            exp: now + 3600,
            iss: "fraiseql".to_string(),
            aud: vec!["fraiseql-api".to_string()],
            extra: std::collections::HashMap::new(),
        };
        set_random_jti(&mut claims);
        claims
    }

    #[test]
    fn test_generate_access_token_creates_valid_jwt() {
        let mut claims = sample_claims();
        let secret = b"fraiseql_session_user123";

        let token1 =
            crate::jwt::generate_hs256_token(&claims, secret).expect("Failed to generate token");

        // A new `jti` must produce a different token even within the same second.
        set_random_jti(&mut claims);
        let token2 =
            crate::jwt::generate_hs256_token(&claims, secret).expect("Failed to generate token");

        assert_ne!(token1, token2);
        // Compact JWS serialization: header.payload.signature.
        assert_eq!(token1.matches('.').count(), 2);
        assert_eq!(token2.matches('.').count(), 2);
    }

    #[test]
    fn test_generate_access_token_with_rs256_key() {
        let test_key = include_bytes!("../test_data/test_rsa_key.pem");
        let claims = sample_claims();

        let token = crate::jwt::generate_rs256_token(&claims, test_key)
            .expect("Failed to generate RS256 token");

        // Compact JWS serialization: header.payload.signature.
        assert_eq!(token.matches('.').count(), 2);
    }
}