//! `OAuthStateStore` implementation for `AuthDb` (rs_auth_postgres/oauth_state.rs).
use async_trait::async_trait;
use sqlx::postgres::PgRow;
use sqlx::Row;

use rs_auth_core::error::AuthError;
use rs_auth_core::store::OAuthStateStore;
use rs_auth_core::types::{NewOAuthState, OAuthState};

use crate::db::AuthDb;

10#[async_trait]
11impl OAuthStateStore for AuthDb {
12 async fn create_oauth_state(&self, state: NewOAuthState) -> Result<OAuthState, AuthError> {
13 sqlx::query(
14 r#"
15 INSERT INTO oauth_states (provider_id, csrf_state, pkce_verifier, expires_at)
16 VALUES ($1, $2, $3, $4)
17 RETURNING id, provider_id, csrf_state, pkce_verifier, expires_at, created_at
18 "#,
19 )
20 .bind(&state.provider_id)
21 .bind(&state.csrf_state)
22 .bind(&state.pkce_verifier)
23 .bind(state.expires_at)
24 .fetch_one(&self.pool)
25 .await
26 .map(|row| OAuthState {
27 id: row.get("id"),
28 provider_id: row.get("provider_id"),
29 csrf_state: row.get("csrf_state"),
30 pkce_verifier: row.get("pkce_verifier"),
31 expires_at: row.get("expires_at"),
32 created_at: row.get("created_at"),
33 })
34 .map_err(|error| AuthError::Store(error.to_string()))
35 }
36
37 async fn find_by_csrf_state(&self, csrf_state: &str) -> Result<Option<OAuthState>, AuthError> {
38 sqlx::query(
39 r#"
40 SELECT id, provider_id, csrf_state, pkce_verifier, expires_at, created_at
41 FROM oauth_states
42 WHERE csrf_state = $1
43 "#,
44 )
45 .bind(csrf_state)
46 .fetch_optional(&self.pool)
47 .await
48 .map(|row| {
49 row.map(|row| OAuthState {
50 id: row.get("id"),
51 provider_id: row.get("provider_id"),
52 csrf_state: row.get("csrf_state"),
53 pkce_verifier: row.get("pkce_verifier"),
54 expires_at: row.get("expires_at"),
55 created_at: row.get("created_at"),
56 })
57 })
58 .map_err(|error| AuthError::Store(error.to_string()))
59 }
60
61 async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
62 sqlx::query(r#"DELETE FROM oauth_states WHERE id = $1"#)
63 .bind(id)
64 .execute(&self.pool)
65 .await
66 .map(|_| ())
67 .map_err(|error| AuthError::Store(error.to_string()))
68 }
69
70 async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
71 sqlx::query(r#"DELETE FROM oauth_states WHERE expires_at < now()"#)
72 .execute(&self.pool)
73 .await
74 .map(|result| result.rows_affected())
75 .map_err(|error| AuthError::Store(error.to_string()))
76 }
77}