rs-auth-postgres 0.1.2

Postgres persistence and migrations for rs-auth.
Documentation
use async_trait::async_trait;
use rs_auth_core::error::AuthError;
use rs_auth_core::store::AccountStore;
use rs_auth_core::types::{Account, NewAccount};
use sqlx::Row;
use time::OffsetDateTime;

use crate::db::AuthDb;

#[async_trait]
impl AccountStore for AuthDb {
    async fn create_account(&self, account: NewAccount) -> Result<Account, AuthError> {
        let row = sqlx::query(
            r#"
            INSERT INTO accounts (user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            RETURNING id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
            "#,
        )
        .bind(account.user_id)
        .bind(&account.provider_id)
        .bind(&account.account_id)
        .bind(&account.access_token)
        .bind(&account.refresh_token)
        .bind(account.access_token_expires_at)
        .bind(&account.scope)
        .fetch_one(&self.pool)
        .await
        .map_err(|e| AuthError::Store(e.to_string()))?;

        Ok(Account {
            id: row.get("id"),
            user_id: row.get("user_id"),
            provider_id: row.get("provider_id"),
            account_id: row.get("account_id"),
            access_token: row.get("access_token"),
            refresh_token: row.get("refresh_token"),
            access_token_expires_at: row.get("access_token_expires_at"),
            scope: row.get("scope"),
            created_at: row.get("created_at"),
            updated_at: row.get("updated_at"),
        })
    }

    async fn find_by_provider(
        &self,
        provider_id: &str,
        account_id: &str,
    ) -> Result<Option<Account>, AuthError> {
        let row = sqlx::query(
            r#"
            SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
            FROM accounts
            WHERE provider_id = $1 AND account_id = $2
            "#,
        )
        .bind(provider_id)
        .bind(account_id)
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| AuthError::Store(e.to_string()))?;

        Ok(row.map(|row| Account {
            id: row.get("id"),
            user_id: row.get("user_id"),
            provider_id: row.get("provider_id"),
            account_id: row.get("account_id"),
            access_token: row.get("access_token"),
            refresh_token: row.get("refresh_token"),
            access_token_expires_at: row.get("access_token_expires_at"),
            scope: row.get("scope"),
            created_at: row.get("created_at"),
            updated_at: row.get("updated_at"),
        }))
    }

    async fn find_by_user_id(&self, user_id: i64) -> Result<Vec<Account>, AuthError> {
        let rows = sqlx::query(
            r#"
            SELECT id, user_id, provider_id, account_id, access_token, refresh_token, access_token_expires_at, scope, created_at, updated_at
            FROM accounts
            WHERE user_id = $1
            ORDER BY created_at DESC
            "#,
        )
        .bind(user_id)
        .fetch_all(&self.pool)
        .await
        .map_err(|e| AuthError::Store(e.to_string()))?;

        Ok(rows
            .into_iter()
            .map(|row| Account {
                id: row.get("id"),
                user_id: row.get("user_id"),
                provider_id: row.get("provider_id"),
                account_id: row.get("account_id"),
                access_token: row.get("access_token"),
                refresh_token: row.get("refresh_token"),
                access_token_expires_at: row.get("access_token_expires_at"),
                scope: row.get("scope"),
                created_at: row.get("created_at"),
                updated_at: row.get("updated_at"),
            })
            .collect())
    }

    async fn delete_account(&self, id: i64) -> Result<(), AuthError> {
        sqlx::query(r#"DELETE FROM accounts WHERE id = $1"#)
            .bind(id)
            .execute(&self.pool)
            .await
            .map_err(|e| AuthError::Store(e.to_string()))?;
        Ok(())
    }

    async fn update_account(
        &self,
        id: i64,
        access_token: Option<String>,
        refresh_token: Option<String>,
        access_token_expires_at: Option<OffsetDateTime>,
        scope: Option<String>,
    ) -> Result<(), AuthError> {
        sqlx::query(
            r#"
            UPDATE accounts
            SET access_token = $2, refresh_token = $3,
                access_token_expires_at = $4, scope = $5, updated_at = now()
            WHERE id = $1
            "#,
        )
        .bind(id)
        .bind(access_token)
        .bind(refresh_token)
        .bind(access_token_expires_at)
        .bind(scope)
        .execute(&self.pool)
        .await
        .map_err(|e| AuthError::Store(e.to_string()))?;
        Ok(())
    }
}