torii-storage-postgres 0.5.2

Postgres storage backend for the torii authentication ecosystem
Documentation
use crate::{PostgresStorage, PostgresUser};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use std::time::Duration;
use torii_core::error::{StorageError, ValidationError};
use torii_core::storage::OAuthStorage;
use torii_core::{OAuthAccount, User, UserId};

/// Builder for [`PostgresOAuthAccount`].
///
/// `user_id`, `provider` and `subject` are required. `created_at` and
/// `updated_at` are optional and default to the moment [`build`](Self::build)
/// is called.
#[derive(Default)]
pub struct PostgresOAuthAccountBuilder {
    user_id: Option<UserId>,
    provider: Option<String>,
    subject: Option<String>,
    created_at: Option<DateTime<Utc>>,
    updated_at: Option<DateTime<Utc>>,
}

impl PostgresOAuthAccountBuilder {
    /// Sets the ID of the user owning the account (required).
    #[must_use]
    pub fn user_id(mut self, user_id: UserId) -> Self {
        self.user_id = Some(user_id);
        self
    }

    /// Sets the OAuth provider name (required).
    #[must_use]
    pub fn provider(mut self, provider: String) -> Self {
        self.provider = Some(provider);
        self
    }

    /// Sets the account's subject identifier at the provider (required).
    #[must_use]
    pub fn subject(mut self, subject: String) -> Self {
        self.subject = Some(subject);
        self
    }

    /// Overrides the creation timestamp (defaults to "now" at build time).
    #[must_use]
    pub fn created_at(mut self, created_at: DateTime<Utc>) -> Self {
        self.created_at = Some(created_at);
        self
    }

    /// Overrides the last-update timestamp (defaults to "now" at build time).
    #[must_use]
    pub fn updated_at(mut self, updated_at: DateTime<Utc>) -> Self {
        self.updated_at = Some(updated_at);
        self
    }

    /// Builds the account.
    ///
    /// The result has `id == None`; the user ID is stored in its string form.
    /// Unset timestamps share a single "now" value, so they are equal.
    ///
    /// # Errors
    ///
    /// Returns a [`ValidationError::MissingField`] (converted into
    /// [`torii_core::Error`]) if `user_id`, `provider` or `subject` was not set.
    /// The fields are checked in that order.
    pub fn build(self) -> Result<PostgresOAuthAccount, torii_core::Error> {
        let now = Utc::now();

        // `ok_or_else` so the error message is only allocated on the failure path.
        let user_id = self
            .user_id
            .ok_or_else(|| ValidationError::MissingField("User ID is required".to_string()))?;
        let provider = self
            .provider
            .ok_or_else(|| ValidationError::MissingField("Provider is required".to_string()))?;
        let subject = self
            .subject
            .ok_or_else(|| ValidationError::MissingField("Subject is required".to_string()))?;

        Ok(PostgresOAuthAccount {
            id: None,
            user_id: user_id.to_string(),
            provider,
            subject,
            created_at: self.created_at.unwrap_or(now),
            updated_at: self.updated_at.unwrap_or(now),
        })
    }
}

#[derive(Debug, Clone, sqlx::FromRow)]
pub struct PostgresOAuthAccount {
    pub id: Option<i64>,
    pub user_id: String,
    pub provider: String,
    pub subject: String,
    pub created_at: DateTime<Utc>,
    pub updated_at: DateTime<Utc>,
}

impl PostgresOAuthAccount {
    pub fn builder() -> PostgresOAuthAccountBuilder {
        PostgresOAuthAccountBuilder::default()
    }

    pub fn new(user_id: UserId, provider: impl Into<String>, subject: impl Into<String>) -> Self {
        PostgresOAuthAccountBuilder::default()
            .user_id(user_id)
            .provider(provider.into())
            .subject(subject.into())
            .build()
            .expect("Default builder should never fail")
    }

    pub fn is_expired(&self, ttl: Duration) -> bool {
        Utc::now() > self.created_at + ttl
    }
}

impl From<PostgresOAuthAccount> for OAuthAccount {
    fn from(oauth_account: PostgresOAuthAccount) -> Self {
        OAuthAccount::builder()
            .user_id(UserId::new(&oauth_account.user_id))
            .provider(oauth_account.provider)
            .subject(oauth_account.subject)
            .created_at(oauth_account.created_at)
            .updated_at(oauth_account.updated_at)
            .build()
            .expect("Default builder should never fail")
    }
}

impl From<OAuthAccount> for PostgresOAuthAccount {
    fn from(oauth_account: OAuthAccount) -> Self {
        PostgresOAuthAccount::builder()
            .user_id(oauth_account.user_id)
            .provider(oauth_account.provider)
            .subject(oauth_account.subject)
            .created_at(oauth_account.created_at)
            .updated_at(oauth_account.updated_at)
            .build()
            .expect("Default builder should never fail")
    }
}

#[async_trait]
impl OAuthStorage for PostgresStorage {
    async fn create_oauth_account(
        &self,
        provider: &str,
        subject: &str,
        user_id: &UserId,
    ) -> Result<OAuthAccount, torii_core::Error> {
        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
            .bind(user_id.as_str())
            .bind(provider)
            .bind(subject)
            .bind(Utc::now())
            .bind(Utc::now())
            .execute(&self.pool)
            .await
            .map_err(|e| {
                tracing::error!(error = %e, "Failed to create oauth account");
                StorageError::Database("Failed to create oauth account".to_string())
            })?;

        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
            r#"
            SELECT id, user_id, provider, subject, created_at, updated_at
            FROM oauth_accounts
            WHERE user_id = $1
            "#,
        )
        .bind(user_id.as_str())
        .fetch_one(&self.pool)
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "Failed to get oauth account");
            StorageError::Database("Failed to get oauth account".to_string())
        })?;

        Ok(oauth_account.into())
    }

    async fn store_pkce_verifier(
        &self,
        csrf_state: &str,
        pkce_verifier: &str,
        expires_in: chrono::Duration,
    ) -> Result<(), torii_core::Error> {
        sqlx::query(
            "INSERT INTO oauth_state (csrf_state, pkce_verifier, expires_at) VALUES ($1, $2, $3) RETURNING value",
        )
        .bind(csrf_state)
        .bind(pkce_verifier)
        .bind(Utc::now() + expires_in)
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "Failed to store pkce verifier");
            StorageError::Database("Failed to store pkce verifier".to_string())
        })?;

        Ok(())
    }

    async fn get_pkce_verifier(
        &self,
        csrf_state: &str,
    ) -> Result<Option<String>, torii_core::Error> {
        let pkce_verifier =
            sqlx::query_scalar("SELECT pkce_verifier FROM oauth_state WHERE csrf_state = $1")
                .bind(csrf_state)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| {
                    tracing::error!(error = %e, "Failed to get pkce verifier");
                    StorageError::Database("Failed to get pkce verifier".to_string())
                })?;

        Ok(pkce_verifier)
    }

    async fn get_oauth_account_by_provider_and_subject(
        &self,
        provider: &str,
        subject: &str,
    ) -> Result<Option<OAuthAccount>, torii_core::Error> {
        let oauth_account = sqlx::query_as::<_, PostgresOAuthAccount>(
            r#"
            SELECT id, user_id, provider, subject, created_at, updated_at
            FROM oauth_accounts
            WHERE provider = $1 AND subject = $2
            "#,
        )
        .bind(provider)
        .bind(subject)
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "Failed to get oauth account");
            StorageError::Database("Failed to get oauth account".to_string())
        })?;

        if let Some(oauth_account) = oauth_account {
            Ok(Some(oauth_account.into()))
        } else {
            Ok(None)
        }
    }

    async fn get_user_by_provider_and_subject(
        &self,
        provider: &str,
        subject: &str,
    ) -> Result<Option<User>, torii_core::Error> {
        let user = sqlx::query_as::<_, PostgresUser>(
            r#"
            SELECT id, email, name, email_verified_at, created_at, updated_at
            FROM users
            WHERE provider = $1 AND subject = $2
            "#,
        )
        .bind(provider)
        .bind(subject)
        .fetch_optional(&self.pool)
        .await
        .map_err(|e| {
            tracing::error!(error = %e, "Failed to get user by provider and subject");
            StorageError::Database("Failed to get user by provider and subject".to_string())
        })?;

        if let Some(user) = user {
            Ok(Some(user.into()))
        } else {
            Ok(None)
        }
    }

    async fn link_oauth_account(
        &self,
        user_id: &UserId,
        provider: &str,
        subject: &str,
    ) -> Result<(), torii_core::Error> {
        sqlx::query("INSERT INTO oauth_accounts (user_id, provider, subject, created_at, updated_at) VALUES ($1, $2, $3, $4, $5)")
            .bind(user_id.as_str())
            .bind(provider)
            .bind(subject)
            .bind(Utc::now())
            .bind(Utc::now())
            .execute(&self.pool)
            .await
            .map_err(|e| {
                tracing::error!(error = %e, "Failed to link oauth account");
                StorageError::Database("Failed to link oauth account".to_string())
            })?;

        Ok(())
    }
}