// rs_auth_postgres/oauth_state.rs

use async_trait::async_trait;
use sqlx::postgres::PgRow;
use sqlx::Row;

use rs_auth_core::error::AuthError;
use rs_auth_core::store::OAuthStateStore;
use rs_auth_core::types::{NewOAuthState, OAuthIntent, OAuthState};

use crate::db::AuthDb;

10#[async_trait]
11impl OAuthStateStore for AuthDb {
12    async fn create_oauth_state(&self, state: NewOAuthState) -> Result<OAuthState, AuthError> {
13        sqlx::query(
14            r#"
15            INSERT INTO oauth_states (provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at)
16            VALUES ($1, $2, $3, $4, $5, $6)
17            RETURNING id, provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at, created_at
18            "#,
19        )
20        .bind(&state.provider_id)
21        .bind(&state.csrf_state)
22        .bind(&state.pkce_verifier)
23        .bind(match state.intent {
24            rs_auth_core::types::OAuthIntent::Login => "login",
25            rs_auth_core::types::OAuthIntent::Link => "link",
26        })
27        .bind(state.link_user_id)
28        .bind(state.expires_at)
29        .fetch_one(&self.pool)
30        .await
31        .map(|row| {
32            let intent_str: String = row.get("intent");
33            OAuthState {
34                id: row.get("id"),
35                provider_id: row.get("provider_id"),
36                csrf_state: row.get("csrf_state"),
37                pkce_verifier: row.get("pkce_verifier"),
38                intent: match intent_str.as_str() {
39                    "link" => rs_auth_core::types::OAuthIntent::Link,
40                    _ => rs_auth_core::types::OAuthIntent::Login,
41                },
42                link_user_id: row.get("link_user_id"),
43                expires_at: row.get("expires_at"),
44                created_at: row.get("created_at"),
45            }
46        })
47        .map_err(|error| AuthError::Store(error.to_string()))
48    }
49
50    async fn find_by_csrf_state(&self, csrf_state: &str) -> Result<Option<OAuthState>, AuthError> {
51        sqlx::query(
52            r#"
53            SELECT id, provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at, created_at
54            FROM oauth_states
55            WHERE csrf_state = $1
56            "#,
57        )
58        .bind(csrf_state)
59        .fetch_optional(&self.pool)
60        .await
61        .map(|row| {
62            row.map(|row| {
63                let intent_str: String = row.get("intent");
64                OAuthState {
65                    id: row.get("id"),
66                    provider_id: row.get("provider_id"),
67                    csrf_state: row.get("csrf_state"),
68                    pkce_verifier: row.get("pkce_verifier"),
69                    intent: match intent_str.as_str() {
70                        "link" => rs_auth_core::types::OAuthIntent::Link,
71                        _ => rs_auth_core::types::OAuthIntent::Login,
72                    },
73                    link_user_id: row.get("link_user_id"),
74                    expires_at: row.get("expires_at"),
75                    created_at: row.get("created_at"),
76                }
77            })
78        })
79        .map_err(|error| AuthError::Store(error.to_string()))
80    }
81
82    async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
83        sqlx::query(r#"DELETE FROM oauth_states WHERE id = $1"#)
84            .bind(id)
85            .execute(&self.pool)
86            .await
87            .map(|_| ())
88            .map_err(|error| AuthError::Store(error.to_string()))
89    }
90
91    async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
92        sqlx::query(r#"DELETE FROM oauth_states WHERE expires_at < now()"#)
93            .execute(&self.pool)
94            .await
95            .map(|result| result.rows_affected())
96            .map_err(|error| AuthError::Store(error.to_string()))
97    }
98}