rs-auth-postgres 0.1.2

Postgres persistence and migrations for rs-auth.
Documentation
use async_trait::async_trait;
use sqlx::Row;

use rs_auth_core::error::AuthError;
use rs_auth_core::store::OAuthStateStore;
use rs_auth_core::types::{NewOAuthState, OAuthState};

use crate::db::AuthDb;

#[async_trait]
impl OAuthStateStore for AuthDb {
    async fn create_oauth_state(&self, state: NewOAuthState) -> Result<OAuthState, AuthError> {
        sqlx::query(
            r#"
            INSERT INTO oauth_states (provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at)
            VALUES ($1, $2, $3, $4, $5, $6)
            RETURNING id, provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at, created_at
            "#,
        )
        .bind(&state.provider_id)
        .bind(&state.csrf_state)
        .bind(&state.pkce_verifier)
        .bind(match state.intent {
            rs_auth_core::types::OAuthIntent::Login => "login",
            rs_auth_core::types::OAuthIntent::Link => "link",
        })
        .bind(state.link_user_id)
        .bind(state.expires_at)
        .fetch_one(&self.pool)
        .await
        .map(|row| {
            let intent_str: String = row.get("intent");
            OAuthState {
                id: row.get("id"),
                provider_id: row.get("provider_id"),
                csrf_state: row.get("csrf_state"),
                pkce_verifier: row.get("pkce_verifier"),
                intent: match intent_str.as_str() {
                    "link" => rs_auth_core::types::OAuthIntent::Link,
                    _ => rs_auth_core::types::OAuthIntent::Login,
                },
                link_user_id: row.get("link_user_id"),
                expires_at: row.get("expires_at"),
                created_at: row.get("created_at"),
            }
        })
        .map_err(|error| AuthError::Store(error.to_string()))
    }

    async fn find_by_csrf_state(&self, csrf_state: &str) -> Result<Option<OAuthState>, AuthError> {
        sqlx::query(
            r#"
            SELECT id, provider_id, csrf_state, pkce_verifier, intent, link_user_id, expires_at, created_at
            FROM oauth_states
            WHERE csrf_state = $1
            "#,
        )
        .bind(csrf_state)
        .fetch_optional(&self.pool)
        .await
        .map(|row| {
            row.map(|row| {
                let intent_str: String = row.get("intent");
                OAuthState {
                    id: row.get("id"),
                    provider_id: row.get("provider_id"),
                    csrf_state: row.get("csrf_state"),
                    pkce_verifier: row.get("pkce_verifier"),
                    intent: match intent_str.as_str() {
                        "link" => rs_auth_core::types::OAuthIntent::Link,
                        _ => rs_auth_core::types::OAuthIntent::Login,
                    },
                    link_user_id: row.get("link_user_id"),
                    expires_at: row.get("expires_at"),
                    created_at: row.get("created_at"),
                }
            })
        })
        .map_err(|error| AuthError::Store(error.to_string()))
    }

    async fn delete_oauth_state(&self, id: i64) -> Result<(), AuthError> {
        sqlx::query(r#"DELETE FROM oauth_states WHERE id = $1"#)
            .bind(id)
            .execute(&self.pool)
            .await
            .map(|_| ())
            .map_err(|error| AuthError::Store(error.to_string()))
    }

    async fn delete_expired_oauth_states(&self) -> Result<u64, AuthError> {
        sqlx::query(r#"DELETE FROM oauth_states WHERE expires_at < now()"#)
            .execute(&self.pool)
            .await
            .map(|result| result.rows_affected())
            .map_err(|error| AuthError::Store(error.to_string()))
    }
}