Skip to main content

fraiseql_server/auth/
session_postgres.rs

1// PostgreSQL SessionStore implementation
2use async_trait::async_trait;
3use sqlx::{Row, postgres::PgPool};
4
5use crate::auth::{
6    error::{AuthError, Result},
7    session::{SessionData, SessionStore, TokenPair, generate_refresh_token, hash_token},
8};
9
10/// PostgreSQL-backed session store
11pub struct PostgresSessionStore {
12    db:          PgPool,
13    /// Optional RSA private key for JWT signing (None falls back to HMAC)
14    signing_key: Option<Vec<u8>>,
15}
16
17impl PostgresSessionStore {
18    /// Create a new PostgreSQL session store
19    ///
20    /// # Errors
21    /// Returns error if database connection fails
22    pub fn new(db: PgPool) -> Self {
23        Self {
24            db,
25            signing_key: None,
26        }
27    }
28
29    /// Create a new PostgreSQL session store with RS256 JWT signing
30    ///
31    /// # Arguments
32    /// * `db` - PostgreSQL connection pool
33    /// * `private_key_pem` - RSA private key in PEM format
34    pub fn with_rs256_key(db: PgPool, private_key_pem: Vec<u8>) -> Self {
35        Self {
36            db,
37            signing_key: Some(private_key_pem),
38        }
39    }
40
41    /// Initialize the sessions table
42    ///
43    /// This should be called once during server startup to ensure the table exists.
44    ///
45    /// # Errors
46    /// Returns error if table creation fails
47    pub async fn init(&self) -> Result<()> {
48        sqlx::query(
49            r"
50            CREATE TABLE IF NOT EXISTS _system.sessions (
51                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
52                user_id TEXT NOT NULL,
53                refresh_token_hash TEXT NOT NULL UNIQUE,
54                issued_at BIGINT NOT NULL,
55                expires_at BIGINT NOT NULL,
56                created_at TIMESTAMPTZ DEFAULT NOW(),
57                revoked_at TIMESTAMPTZ
58            );
59
60            CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON _system.sessions(user_id);
61            CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON _system.sessions(expires_at);
62            CREATE INDEX IF NOT EXISTS idx_sessions_revoked_at ON _system.sessions(revoked_at);
63            ",
64        )
65        .execute(&self.db)
66        .await
67        .map_err(|e| AuthError::DatabaseError {
68            message: format!("Failed to initialize sessions table: {}", e),
69        })?;
70
71        Ok(())
72    }
73
74    /// Generate a JWT access token with RS256 or HMAC signing
75    ///
76    /// Uses RS256 if a signing key is configured, otherwise falls back to HMAC with a
77    /// deterministic secret derived from the user ID.
78    fn generate_access_token(&self, user_id: &str, expires_in: u64) -> Result<String> {
79        let now = std::time::SystemTime::now()
80            .duration_since(std::time::UNIX_EPOCH)
81            .unwrap_or_default()
82            .as_secs();
83
84        let exp = now + expires_in;
85
86        let mut claims = crate::auth::Claims {
87            sub: user_id.to_string(),
88            iat: now,
89            exp,
90            iss: "fraiseql".to_string(),
91            aud: vec!["fraiseql-api".to_string()],
92            extra: std::collections::HashMap::new(),
93        };
94
95        // Add JTI (JWT ID) for uniqueness
96        claims
97            .extra
98            .insert("jti".to_string(), serde_json::json!(uuid::Uuid::new_v4().to_string()));
99
100        match &self.signing_key {
101            Some(private_key) => crate::auth::jwt::generate_rs256_token(&claims, private_key),
102            None => {
103                // Fallback: use deterministic HMAC secret (for testing/dev environments)
104                let secret = format!("fraiseql_session_{}", user_id).into_bytes();
105                crate::auth::jwt::generate_hs256_token(&claims, &secret)
106            },
107        }
108    }
109}
110
111#[async_trait]
112impl SessionStore for PostgresSessionStore {
113    async fn create_session(&self, user_id: &str, expires_at: u64) -> Result<TokenPair> {
114        let refresh_token = generate_refresh_token();
115        let refresh_token_hash = hash_token(&refresh_token);
116
117        let now = std::time::SystemTime::now()
118            .duration_since(std::time::UNIX_EPOCH)
119            .unwrap_or_default()
120            .as_secs();
121
122        sqlx::query(
123            r"
124            INSERT INTO _system.sessions
125            (user_id, refresh_token_hash, issued_at, expires_at)
126            VALUES ($1, $2, $3, $4)
127            ",
128        )
129        .bind(user_id)
130        .bind(&refresh_token_hash)
131        .bind(now as i64)
132        .bind(expires_at as i64)
133        .execute(&self.db)
134        .await
135        .map_err(|e| {
136            if e.to_string().contains("duplicate key") {
137                AuthError::SessionError {
138                    message: "Refresh token already exists".to_string(),
139                }
140            } else {
141                AuthError::DatabaseError {
142                    message: format!("Failed to create session: {}", e),
143                }
144            }
145        })?;
146
147        let expires_in = expires_at.saturating_sub(now);
148        let access_token = self.generate_access_token(user_id, expires_in)?;
149
150        Ok(TokenPair {
151            access_token,
152            refresh_token,
153            expires_in,
154        })
155    }
156
157    async fn get_session(&self, refresh_token_hash: &str) -> Result<SessionData> {
158        let row = sqlx::query(
159            r"
160            SELECT user_id, issued_at, expires_at, refresh_token_hash
161            FROM _system.sessions
162            WHERE refresh_token_hash = $1 AND revoked_at IS NULL
163            ",
164        )
165        .bind(refresh_token_hash)
166        .fetch_optional(&self.db)
167        .await
168        .map_err(|e| AuthError::DatabaseError {
169            message: format!("Failed to get session: {}", e),
170        })?
171        .ok_or(AuthError::TokenNotFound)?;
172
173        let user_id: String = row.get("user_id");
174        let issued_at: i64 = row.get("issued_at");
175        let expires_at: i64 = row.get("expires_at");
176        let refresh_token_hash: String = row.get("refresh_token_hash");
177
178        Ok(SessionData {
179            user_id,
180            issued_at: issued_at as u64,
181            expires_at: expires_at as u64,
182            refresh_token_hash,
183        })
184    }
185
186    async fn revoke_session(&self, refresh_token_hash: &str) -> Result<()> {
187        let result = sqlx::query(
188            r"
189            UPDATE _system.sessions
190            SET revoked_at = NOW()
191            WHERE refresh_token_hash = $1 AND revoked_at IS NULL
192            ",
193        )
194        .bind(refresh_token_hash)
195        .execute(&self.db)
196        .await
197        .map_err(|e| AuthError::DatabaseError {
198            message: format!("Failed to revoke session: {}", e),
199        })?;
200
201        if result.rows_affected() == 0 {
202            return Err(AuthError::SessionError {
203                message: "Session not found or already revoked".to_string(),
204            });
205        }
206
207        Ok(())
208    }
209
210    async fn revoke_all_sessions(&self, user_id: &str) -> Result<()> {
211        sqlx::query(
212            r"
213            UPDATE _system.sessions
214            SET revoked_at = NOW()
215            WHERE user_id = $1 AND revoked_at IS NULL
216            ",
217        )
218        .bind(user_id)
219        .execute(&self.db)
220        .await
221        .map_err(|e| AuthError::DatabaseError {
222            message: format!("Failed to revoke all sessions: {}", e),
223        })?;
224
225        Ok(())
226    }
227}
228
#[cfg(test)]
mod tests {
    /// Current UNIX time in whole seconds (0 if the system clock is before the epoch).
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Build claims shaped like those produced by
    /// `PostgresSessionStore::generate_access_token`, with a fresh random `jti`.
    fn sample_claims(sub: &str, lifetime_secs: u64) -> crate::auth::Claims {
        let now = unix_now();
        let mut claims = crate::auth::Claims {
            sub:   sub.to_string(),
            iat:   now,
            exp:   now + lifetime_secs,
            iss:   "fraiseql".to_string(),
            aud:   vec!["fraiseql-api".to_string()],
            extra: std::collections::HashMap::new(),
        };
        claims
            .extra
            .insert("jti".to_string(), serde_json::json!(uuid::Uuid::new_v4().to_string()));
        claims
    }

    #[test]
    fn test_generate_access_token_creates_valid_jwt() {
        // Exercises the HS256 path with the same secret layout as the store's fallback.
        let secret = b"fraiseql_session_user123";

        let token1 = crate::auth::jwt::generate_hs256_token(&sample_claims("user123", 3600), secret)
            .expect("Failed to generate token");
        let token2 = crate::auth::jwt::generate_hs256_token(&sample_claims("user123", 3600), secret)
            .expect("Failed to generate token");

        // Tokens should be different (different JTI)
        assert_ne!(token1, token2);
        // Both should be valid JWT format (three dot-separated parts)
        assert_eq!(token1.matches('.').count(), 2);
        assert_eq!(token2.matches('.').count(), 2);
    }

    #[test]
    fn test_generate_access_token_with_rs256_key() {
        let test_key = include_bytes!("../../test_data/test_rsa_key.pem");

        let token = crate::auth::jwt::generate_rs256_token(&sample_claims("user123", 3600), test_key)
            .expect("Failed to generate RS256 token");

        // Valid JWT should have three parts separated by dots
        assert_eq!(token.matches('.').count(), 2);
    }
}