Skip to main content

fraiseql_auth/
session_postgres.rs

1//! PostgreSQL-backed [`SessionStore`] implementation.
2use async_trait::async_trait;
3use sqlx::{Row, postgres::PgPool};
4
5use crate::{
6    error::{AuthError, Result},
7    session::{SessionData, SessionStore, TokenPair, generate_refresh_token, hash_token},
8};
9
10/// PostgreSQL-backed session store
11pub struct PostgresSessionStore {
12    db:          PgPool,
13    /// Optional RSA private key for JWT signing (None falls back to HMAC)
14    signing_key: Option<Vec<u8>>,
15}
16
17impl PostgresSessionStore {
18    /// Create a new PostgreSQL session store
19    ///
20    /// # Errors
21    /// Returns error if database connection fails
22    pub const fn new(db: PgPool) -> Self {
23        Self {
24            db,
25            signing_key: None,
26        }
27    }
28
29    /// Create a new PostgreSQL session store with RS256 JWT signing
30    ///
31    /// # Arguments
32    /// * `db` - PostgreSQL connection pool
33    /// * `private_key_pem` - RSA private key in PEM format
34    pub const fn with_rs256_key(db: PgPool, private_key_pem: Vec<u8>) -> Self {
35        Self {
36            db,
37            signing_key: Some(private_key_pem),
38        }
39    }
40
41    /// Initialize the sessions table
42    ///
43    /// This should be called once during server startup to ensure the table exists.
44    ///
45    /// # Errors
46    /// Returns error if table creation fails
47    pub async fn init(&self) -> Result<()> {
48        sqlx::query(
49            r"
50            CREATE TABLE IF NOT EXISTS _system.sessions (
51                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
52                user_id TEXT NOT NULL,
53                refresh_token_hash TEXT NOT NULL UNIQUE,
54                issued_at BIGINT NOT NULL,
55                expires_at BIGINT NOT NULL,
56                created_at TIMESTAMPTZ DEFAULT NOW(),
57                revoked_at TIMESTAMPTZ
58            );
59
60            CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON _system.sessions(user_id);
61            CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON _system.sessions(expires_at);
62            CREATE INDEX IF NOT EXISTS idx_sessions_revoked_at ON _system.sessions(revoked_at);
63            ",
64        )
65        .execute(&self.db)
66        .await
67        .map_err(|e| AuthError::DatabaseError {
68            message: format!("Failed to initialize sessions table: {}", e),
69        })?;
70
71        Ok(())
72    }
73
74    /// Generate a JWT access token with RS256 or HMAC signing
75    ///
76    /// Uses RS256 if a signing key is configured, otherwise falls back to HMAC with a
77    /// deterministic secret derived from the user ID.
78    fn generate_access_token(&self, user_id: &str, expires_in: u64) -> Result<String> {
79        let now = std::time::SystemTime::now()
80            .duration_since(std::time::UNIX_EPOCH)
81            .unwrap_or_default()
82            .as_secs();
83
84        let exp = now + expires_in;
85
86        let mut claims = crate::Claims {
87            sub: user_id.to_string(),
88            iat: now,
89            exp,
90            iss: "fraiseql".to_string(),
91            aud: vec!["fraiseql-api".to_string()],
92            extra: std::collections::HashMap::new(),
93        };
94
95        // Add JTI (JWT ID) for uniqueness
96        claims
97            .extra
98            .insert("jti".to_string(), serde_json::json!(uuid::Uuid::new_v4().to_string()));
99
100        if let Some(private_key) = &self.signing_key {
101            crate::jwt::generate_rs256_token(&claims, private_key)
102        } else {
103            // Fallback: use deterministic HMAC secret (for testing/dev environments)
104            let secret = format!("fraiseql_session_{}", user_id).into_bytes();
105            crate::jwt::generate_hs256_token(&claims, &secret)
106        }
107    }
108}
109
110// Reason: SessionStore is defined with #[async_trait]; all implementations must match
111// its transformed method signatures to satisfy the trait contract
112// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
113#[async_trait]
114impl SessionStore for PostgresSessionStore {
115    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
116        let refresh_token = generate_refresh_token();
117        let refresh_token_hash = hash_token(&refresh_token);
118
119        let now = std::time::SystemTime::now()
120            .duration_since(std::time::UNIX_EPOCH)
121            .unwrap_or_default()
122            .as_secs();
123
124        sqlx::query(
125            r"
126            INSERT INTO _system.sessions
127            (user_id, refresh_token_hash, issued_at, expires_at)
128            VALUES ($1, $2, $3, $4)
129            ",
130        )
131        .bind(user_id)
132        .bind(&refresh_token_hash)
133        .bind(now.cast_signed())
134        .bind(expires_at.cast_signed())
135        .execute(&self.db)
136        .await
137        .map_err(|e| {
138            if e.to_string().contains("duplicate key") {
139                AuthError::SessionError {
140                    message: "Refresh token already exists".to_string(),
141                }
142            } else {
143                AuthError::DatabaseError {
144                    message: format!("Failed to create session: {}", e),
145                }
146            }
147        })?;
148
149        let expires_in = expires_at.saturating_sub(now);
150        let access_token = self.generate_access_token(user_id, expires_in)?;
151
152        Ok(TokenPair {
153            access_token,
154            refresh_token,
155            expires_in,
156        })
157    }
158
159    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
160        let row = sqlx::query(
161            r"
162            SELECT user_id, issued_at, expires_at, refresh_token_hash
163            FROM _system.sessions
164            WHERE refresh_token_hash = $1 AND revoked_at IS NULL
165            ",
166        )
167        .bind(refresh_token_hash)
168        .fetch_optional(&self.db)
169        .await
170        .map_err(|e| AuthError::DatabaseError {
171            message: format!("Failed to get session: {}", e),
172        })?
173        .ok_or(AuthError::TokenNotFound)?;
174
175        let user_id: String = row.get("user_id");
176        let issued_at: i64 = row.get("issued_at");
177        let expires_at: i64 = row.get("expires_at");
178        let refresh_token_hash: String = row.get("refresh_token_hash");
179
180        Ok(SessionData {
181            user_id,
182            issued_at: issued_at.cast_unsigned(),
183            expires_at: expires_at.cast_unsigned(),
184            refresh_token_hash,
185        })
186    }
187
188    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
189        let result = sqlx::query(
190            r"
191            UPDATE _system.sessions
192            SET revoked_at = NOW()
193            WHERE refresh_token_hash = $1 AND revoked_at IS NULL
194            ",
195        )
196        .bind(refresh_token_hash)
197        .execute(&self.db)
198        .await
199        .map_err(|e| AuthError::DatabaseError {
200            message: format!("Failed to revoke session: {}", e),
201        })?;
202
203        if result.rows_affected() == 0 {
204            return Err(AuthError::SessionError {
205                message: "Session not found or already revoked".to_string(),
206            });
207        }
208
209        Ok(())
210    }
211
212    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
213        sqlx::query(
214            r"
215            UPDATE _system.sessions
216            SET revoked_at = NOW()
217            WHERE user_id = $1 AND revoked_at IS NULL
218            ",
219        )
220        .bind(user_id)
221        .execute(&self.db)
222        .await
223        .map_err(|e| AuthError::DatabaseError {
224            message: format!("Failed to revoke all sessions: {}", e),
225        })?;
226
227        Ok(())
228    }
229}
230
#[cfg(test)]
mod tests {
    // These tests drive the JWT helpers with the same claim layout that
    // `PostgresSessionStore::generate_access_token` builds, so no database
    // connection is required.

    /// A fresh random value for the `jti` claim.
    fn fresh_jti() -> serde_json::Value {
        serde_json::json!(uuid::Uuid::new_v4().to_string())
    }

    /// Claims for `sub`, issued now, valid for one hour, with a random `jti`.
    fn sample_claims(sub: &str) -> crate::Claims {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        let mut claims = crate::Claims {
            sub:   sub.to_string(),
            iat:   now,
            exp:   now + 3600,
            iss:   "fraiseql".to_string(),
            aud:   vec!["fraiseql-api".to_string()],
            extra: std::collections::HashMap::new(),
        };
        claims.extra.insert("jti".to_string(), fresh_jti());
        claims
    }

    #[test]
    fn test_generate_access_token_creates_valid_jwt() {
        let mut claims = sample_claims("user123");
        // Same secret the HS256 fallback derives for user "user123".
        let secret = b"fraiseql_session_user123";

        let token1 =
            crate::jwt::generate_hs256_token(&claims, secret).expect("Failed to generate token");

        // Only the jti changes between the two tokens, so any difference comes from it.
        claims.extra.insert("jti".to_string(), fresh_jti());

        let token2 =
            crate::jwt::generate_hs256_token(&claims, secret).expect("Failed to generate token");

        assert_ne!(token1, token2);
        // A compact JWT has three dot-separated segments.
        assert_eq!(token1.matches('.').count(), 2);
        assert_eq!(token2.matches('.').count(), 2);
    }

    #[test]
    fn test_generate_access_token_with_rs256_key() {
        let test_key = include_bytes!("../test_data/test_rsa_key.pem");
        let claims = sample_claims("user123");

        let token = crate::jwt::generate_rs256_token(&claims, test_key)
            .expect("Failed to generate RS256 token");

        // A compact JWT has three dot-separated segments.
        assert_eq!(token.matches('.').count(), 2);
    }
}