// openauth-plugins 0.0.4
//
// Official OpenAuth plugin modules.
// Documentation
use openauth_core::crypto::random::generate_random_string;
use openauth_core::db::{
    Create, DbAdapter, DbRecord, DbValue, Delete, FindOne, Update, User, Where,
};
use openauth_core::error::OpenAuthError;
use openauth_core::user::DbUserStore;

/// Model name handed to the adapter for the user queries built in this file.
const USER_MODEL: &str = "user";

/// Typed view of a single two-factor row, as produced from an adapter record.
///
/// The four string fields are required when a record is converted; `verified`
/// is optional.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TwoFactorRecord {
    /// Value of the row's `id` field.
    pub id: String,
    /// Value of the row's `user_id` field.
    pub user_id: String,
    /// Value of the row's `secret` field, kept exactly as stored.
    pub secret: String,
    /// Value of the row's `backup_codes` field, kept exactly as stored
    /// (the encoding of the codes is not decided here).
    pub backup_codes: String,
    /// `Some(flag)` when the row carries a boolean `verified` field, `None`
    /// when that field is absent or null.
    pub verified: Option<bool>,
}

/// Lightweight handle for reading and writing two-factor rows.
///
/// Holds only borrowed data (an adapter reference and a table name), which is
/// why it can be `Copy`.
#[derive(Clone, Copy)]
pub struct TwoFactorStore<'a> {
    /// Adapter through which every query of this store is issued.
    adapter: &'a dyn DbAdapter,
    /// Table (model) name used for every query of this store.
    table: &'a str,
}

impl<'a> TwoFactorStore<'a> {
    pub fn new(adapter: &'a dyn DbAdapter, table: &'a str) -> Self {
        Self { adapter, table }
    }

    pub async fn find_by_user(
        &self,
        user_id: &str,
    ) -> Result<Option<TwoFactorRecord>, OpenAuthError> {
        self.adapter
            .find_one(
                FindOne::new(self.table)
                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned()))),
            )
            .await?
            .map(record_from_db)
            .transpose()
    }

    pub async fn upsert_for_user(
        &self,
        user_id: &str,
        secret: String,
        backup_codes: String,
        verified: bool,
    ) -> Result<TwoFactorRecord, OpenAuthError> {
        if let Some(existing) = self.find_by_user(user_id).await? {
            let Some(record) = self
                .adapter
                .update(
                    Update::new(self.table)
                        .where_clause(Where::new("id", DbValue::String(existing.id)))
                        .data("secret", DbValue::String(secret))
                        .data("backup_codes", DbValue::String(backup_codes))
                        .data("verified", DbValue::Boolean(verified)),
                )
                .await?
            else {
                return Err(OpenAuthError::Adapter(
                    "two factor update failed".to_owned(),
                ));
            };
            return record_from_db(record);
        }

        let record = self
            .adapter
            .create(
                Create::new(self.table)
                    .data("id", DbValue::String(generate_random_string(32)))
                    .data("user_id", DbValue::String(user_id.to_owned()))
                    .data("secret", DbValue::String(secret))
                    .data("backup_codes", DbValue::String(backup_codes))
                    .data("verified", DbValue::Boolean(verified))
                    .force_allow_id(),
            )
            .await?;
        record_from_db(record)
    }

    pub async fn mark_verified(&self, id: &str) -> Result<(), OpenAuthError> {
        self.adapter
            .update(
                Update::new(self.table)
                    .where_clause(Where::new("id", DbValue::String(id.to_owned())))
                    .data("verified", DbValue::Boolean(true)),
            )
            .await?;
        Ok(())
    }

    pub async fn update_backup_codes_if_current(
        &self,
        id: &str,
        current: &str,
        next: String,
    ) -> Result<bool, OpenAuthError> {
        let updated = self
            .adapter
            .update(
                Update::new(self.table)
                    .where_clause(Where::new("id", DbValue::String(id.to_owned())))
                    .where_clause(Where::new(
                        "backup_codes",
                        DbValue::String(current.to_owned()),
                    ))
                    .data("backup_codes", DbValue::String(next)),
            )
            .await?;
        Ok(updated.is_some())
    }

    pub async fn delete_for_user(&self, user_id: &str) -> Result<(), OpenAuthError> {
        self.adapter
            .delete(
                Delete::new(self.table)
                    .where_clause(Where::new("user_id", DbValue::String(user_id.to_owned()))),
            )
            .await
    }
}

/// Fetches the raw adapter record of the user whose `id` equals `user_id`.
///
/// The record is returned untouched; `Ok(None)` means the adapter found no
/// matching user.
///
/// # Errors
///
/// Returns whatever error the adapter reports.
pub async fn find_user_raw(
    adapter: &dyn DbAdapter,
    user_id: &str,
) -> Result<Option<DbRecord>, OpenAuthError> {
    let by_id = Where::new("id", DbValue::String(user_id.to_owned()));
    let query = FindOne::new(USER_MODEL).where_clause(by_id);

    adapter.find_one(query).await
}

/// Reports whether the user's record has `two_factor_enabled` set to `true`.
///
/// Yields `false` when the user is not found, when the field is missing, or
/// when the field holds anything other than a boolean.
///
/// # Errors
///
/// Propagates errors from the user lookup.
pub async fn user_two_factor_enabled(
    adapter: &dyn DbAdapter,
    user_id: &str,
) -> Result<bool, OpenAuthError> {
    let enabled = match find_user_raw(adapter, user_id).await? {
        Some(user) => matches!(
            user.get("two_factor_enabled"),
            Some(DbValue::Boolean(true))
        ),
        None => false,
    };

    Ok(enabled)
}

/// Writes the `two_factor_enabled` flag on a user and returns the user as
/// read back afterwards.
///
/// The record returned by the update itself is discarded; the user is
/// re-read through `DbUserStore` to obtain a typed `User`.
///
/// # Errors
///
/// Propagates adapter and store errors, and returns
/// `OpenAuthError::Adapter` when the user cannot be found after the update.
pub async fn update_user_two_factor_enabled(
    adapter: &dyn DbAdapter,
    user_id: &str,
    enabled: bool,
) -> Result<User, OpenAuthError> {
    let query = Update::new(USER_MODEL)
        .where_clause(Where::new("id", DbValue::String(user_id.to_owned())))
        .data("two_factor_enabled", DbValue::Boolean(enabled));
    adapter.update(query).await?;

    match DbUserStore::new(adapter).find_user_by_id(user_id).await? {
        Some(user) => Ok(user),
        None => Err(OpenAuthError::Adapter("user not found".to_owned())),
    }
}

/// Returns the `password` value of the user's credential account, if any.
///
/// Yields `Ok(None)` when no credential account is found or when the account
/// has no `password` set. Judging by the function name, the stored value is
/// presumably a password hash — confirm against the account model.
///
/// # Errors
///
/// Propagates errors from the credential-account lookup.
pub async fn credential_password_hash(
    adapter: &dyn DbAdapter,
    user_id: &str,
) -> Result<Option<String>, OpenAuthError> {
    let account = DbUserStore::new(adapter)
        .find_credential_account(user_id)
        .await?;

    Ok(account.and_then(|found| found.password))
}

/// Converts a raw adapter record into a [`TwoFactorRecord`].
///
/// `id`, `user_id`, `secret` and `backup_codes` must be present as strings;
/// `verified` may be absent or null. Fields are read in that order, so the
/// first offending field determines the error.
fn record_from_db(record: DbRecord) -> Result<TwoFactorRecord, OpenAuthError> {
    let id = required_string(&record, "id")?.to_owned();
    let user_id = required_string(&record, "user_id")?.to_owned();
    let secret = required_string(&record, "secret")?.to_owned();
    let backup_codes = required_string(&record, "backup_codes")?.to_owned();
    let verified = optional_bool(&record, "verified")?;

    Ok(TwoFactorRecord {
        id,
        user_id,
        secret,
        backup_codes,
        verified,
    })
}

/// Borrows the string stored under `field` in `record`.
///
/// # Errors
///
/// Returns `OpenAuthError::Adapter` when the field is missing or holds a
/// non-string value; the two cases carry different messages.
fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
    let Some(value) = record.get(field) else {
        return Err(OpenAuthError::Adapter(format!(
            "two factor record is missing `{field}`"
        )));
    };

    match value {
        DbValue::String(text) => Ok(text.as_str()),
        _ => Err(OpenAuthError::Adapter(format!(
            "two factor field `{field}` must be string"
        ))),
    }
}

/// Reads an optional boolean stored under `field` in `record`.
///
/// A missing field and an explicit null both yield `Ok(None)`.
///
/// # Errors
///
/// Returns `OpenAuthError::Adapter` when the field holds a value that is
/// neither a boolean nor null.
fn optional_bool(record: &DbRecord, field: &str) -> Result<Option<bool>, OpenAuthError> {
    let Some(value) = record.get(field) else {
        return Ok(None);
    };

    match value {
        DbValue::Boolean(flag) => Ok(Some(*flag)),
        DbValue::Null => Ok(None),
        _ => Err(OpenAuthError::Adapter(format!(
            "two factor field `{field}` must be boolean"
        ))),
    }
}