Skip to main content

rs_auth_postgres/
oauth_state.rs

1use async_trait::async_trait;
2use sqlx::Row;
3
4use rs_auth_core::error::AuthError;
5use rs_auth_core::store::OAuthStateStore;
6use rs_auth_core::types::{NewOAuthState, OAuthState};
7
8use crate::db::AuthDb;
9
10#[async_trait]
11impl OAuthStateStore for AuthDb {
12    async fn create_oauth_state(&self, state: NewOAuthState) -> Result<OAuthState, AuthError> {
13        sqlx::query(
14            r#"
15            INSERT INTO oauth_states (provider_id, csrf_state, pkce_verifier, expires_at)
16            VALUES ($1, $2, $3, $4)
17            RETURNING id, provider_id, csrf_state, pkce_verifier, expires_at, created_at
18            "#,
19        )
20        .bind(&state.provider_id)
21        .bind(&state.csrf_state)
22        .bind(&state.pkce_verifier)
23        .bind(state.expires_at)
24        .fetch_one(&self.pool)
25        .await
26        .map(|row| OAuthState {
27            id: row.get("id"),
28            provider_id: row.get("provider_id"),
29            csrf_state: row.get("csrf_state"),
30            pkce_verifier: row.get("pkce_verifier"),
31            expires_at: row.get("expires_at"),
32            created_at: row.get("created_at"),
33        })
34        .map_err(|error| AuthError::Store(error.to_string()))
35    }
36
37    async fn find_by_csrf_state(&self, csrf_state: &str) -> Result<Option<OAuthState>, AuthError> {
38        sqlx::query(
39            r#"
40            SELECT id, provider_id, csrf_state, pkce_verifier, expires_at, created_at
41            FROM oauth_states
42            WHERE csrf_state = $1
43            "#,
44        )
45        .bind(csrf_state)
46        .fetch_optional(&self.pool)
47        .await
48        .map(|row| {
49            row.map(|row| OAuthState {
50                id: row.get("id"),
51                provider_id: row.get("provider_id"),
52                csrf_state: row.get("csrf_state"),
53                pkce_verifier: row.get("pkce_verifier"),
54                expires_at: row.get("expires_at"),
55                created_at: row.get("created_at"),
56            })
57        })
58        .map_err(|error| AuthError::Store(error.to_string()))
59    }
60
61    async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
62        sqlx::query(r#"DELETE FROM oauth_states WHERE id = $1"#)
63            .bind(id)
64            .execute(&self.pool)
65            .await
66            .map(|_| ())
67            .map_err(|error| AuthError::Store(error.to_string()))
68    }
69
70    async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
71        sqlx::query(r#"DELETE FROM oauth_states WHERE expires_at < now()"#)
72            .execute(&self.pool)
73            .await
74            .map(|result| result.rows_affected())
75            .map_err(|error| AuthError::Store(error.to_string()))
76    }
77}