rs_auth_postgres/
oauth_state.rs1use async_trait::async_trait;
2use sqlx::Row;
3
4use rs_auth_core::error::AuthError;
5use rs_auth_core::store::OAuthStateStore;
6use rs_auth_core::types::{NewOAuthState, OAuthState};
7
8use crate::db::AuthDb;
9
10#[async_trait]
11impl OAuthStateStore for AuthDb {
12 async fn create_oauth_state(&self, state: NewOAuthState) -> Result<OAuthState, AuthError> {
13 sqlx::query(
14 r#"
15 INSERT INTO oauth_states (provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at)
16 VALUES ($1, $2, $3, $4, $5, $6)
17 RETURNING id, provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at, created_at
18 "#,
19 )
20 .bind(&state.provider_id)
21 .bind(&state.csrf_state)
22 .bind(&state.pkce_verifier)
23 .bind(match state.intent {
24 rs_auth_core::types::OAuthIntent::Login => "login",
25 rs_auth_core::types::OAuthIntent::Link => "link",
26 })
27 .bind(state.link_user_id)
28 .bind(state.expires_at)
29 .fetch_one(&self.pool)
30 .await
31 .map(|row| {
32 let intent_str: String = row.get("intent");
33 OAuthState {
34 id: row.get("id"),
35 provider_id: row.get("provider_id"),
36 csrf_state: row.get("csrf_state"),
37 pkce_verifier: row.get("pkce_verifier"),
38 intent: match intent_str.as_str() {
39 "link" => rs_auth_core::types::OAuthIntent::Link,
40 _ => rs_auth_core::types::OAuthIntent::Login,
41 },
42 link_user_id: row.get("link_user_id"),
43 expires_at: row.get("expires_at"),
44 created_at: row.get("created_at"),
45 }
46 })
47 .map_err(|error| AuthError::Store(error.to_string()))
48 }
49
50 async fn find_by_csrf_state(&self, csrf_state: &str) -> Result<Option<OAuthState>, AuthError> {
51 sqlx::query(
52 r#"
53 SELECT id, provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at, created_at
54 FROM oauth_states
55 WHERE csrf_state = $1
56 "#,
57 )
58 .bind(csrf_state)
59 .fetch_optional(&self.pool)
60 .await
61 .map(|row| {
62 row.map(|row| {
63 let intent_str: String = row.get("intent");
64 OAuthState {
65 id: row.get("id"),
66 provider_id: row.get("provider_id"),
67 csrf_state: row.get("csrf_state"),
68 pkce_verifier: row.get("pkce_verifier"),
69 intent: match intent_str.as_str() {
70 "link" => rs_auth_core::types::OAuthIntent::Link,
71 _ => rs_auth_core::types::OAuthIntent::Login,
72 },
73 link_user_id: row.get("link_user_id"),
74 expires_at: row.get("expires_at"),
75 created_at: row.get("created_at"),
76 }
77 })
78 })
79 .map_err(|error| AuthError::Store(error.to_string()))
80 }
81
82 async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
83 sqlx::query(r#"DELETE FROM oauth_states WHERE id = $1"#)
84 .bind(id)
85 .execute(&self.pool)
86 .await
87 .map(|_| ())
88 .map_err(|error| AuthError::Store(error.to_string()))
89 }
90
91 async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
92 sqlx::query(r#"DELETE FROM oauth_states WHERE expires_at < now()"#)
93 .execute(&self.pool)
94 .await
95 .map(|result| result.rows_affected())
96 .map_err(|error| AuthError::Store(error.to_string()))
97 }
98}