openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
//! Database-backed verification token/value helpers.

use time::OffsetDateTime;

use crate::crypto::random::generate_random_string;
use crate::db::{
    Create, DbAdapter, DbRecord, DbValue, Delete, DeleteMany, FindMany, Sort, SortDirection,
    Update, Verification, Where, WhereOperator,
};
use crate::error::OpenAuthError;

// Model name handed to every adapter query builder in this module.
const VERIFICATION_MODEL: &str = "verification";
// Length argument passed to `generate_random_string` when the caller supplies no id.
const DEFAULT_ID_LENGTH: usize = 32;
// Fields requested via `.select(..)` on create/find queries; these are exactly the
// fields `verification_from_record` reads back out of a record.
const VERIFICATION_FIELDS: [&str; 6] = [
    "id",
    "identifier",
    "value",
    "expires_at",
    "created_at",
    "updated_at",
];

/// Input accepted by `DbVerificationStore::create_verification`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CreateVerificationInput {
    /// Explicit record id; when `None`, `create_verification` generates a random one.
    pub id: Option<String>,
    /// Stored in the `identifier` field, which the find/update/delete helpers filter on.
    pub identifier: String,
    /// Stored in the `value` field.
    pub value: String,
    /// Stored in the `expires_at` field.
    pub expires_at: OffsetDateTime,
}

impl CreateVerificationInput {
    /// Builds an input from the three mandatory fields, leaving `id` unset.
    pub fn new(
        identifier: impl Into<String>,
        value: impl Into<String>,
        expires_at: OffsetDateTime,
    ) -> Self {
        let identifier = identifier.into();
        let value = value.into();

        Self {
            id: None,
            identifier,
            value,
            expires_at,
        }
    }

    /// Returns the input with `id` set to the given value, replacing any earlier one.
    #[must_use]
    pub fn id(self, id: impl Into<String>) -> Self {
        Self {
            id: Some(id.into()),
            ..self
        }
    }
}

/// Optional changes accepted by `DbVerificationStore::update_verification`.
///
/// A field left as `None` is not written by the update.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UpdateVerificationInput {
    /// New content for the `value` field, if it should change.
    pub value: Option<String>,
    /// New content for the `expires_at` field, if it should change.
    pub expires_at: Option<OffsetDateTime>,
}

impl UpdateVerificationInput {
    /// Creates an input that changes nothing; both fields start as `None`.
    pub fn new() -> Self {
        Self {
            value: None,
            expires_at: None,
        }
    }

    /// Returns the input with a replacement `value` recorded.
    #[must_use]
    pub fn value(self, value: impl Into<String>) -> Self {
        Self {
            value: Some(value.into()),
            ..self
        }
    }

    /// Returns the input with a replacement `expires_at` recorded.
    #[must_use]
    pub fn expires_at(self, expires_at: OffsetDateTime) -> Self {
        Self {
            expires_at: Some(expires_at),
            ..self
        }
    }
}

/// Verification helpers that run their queries through a borrowed [`DbAdapter`].
///
/// The store only holds a shared reference, which is why it can be `Copy`.
#[derive(Clone, Copy)]
pub struct DbVerificationStore<'a> {
    // Adapter that every query issued by this store goes through.
    adapter: &'a dyn DbAdapter,
}

impl<'a> DbVerificationStore<'a> {
    pub fn new(adapter: &'a dyn DbAdapter) -> Self {
        Self { adapter }
    }

    pub async fn create_verification(
        &self,
        input: CreateVerificationInput,
    ) -> Result<Verification, OpenAuthError> {
        let now = OffsetDateTime::now_utc();
        let id = input
            .id
            .unwrap_or_else(|| generate_random_string(DEFAULT_ID_LENGTH));

        let record = self
            .adapter
            .create(
                Create::new(VERIFICATION_MODEL)
                    .data("id", DbValue::String(id))
                    .data("identifier", DbValue::String(input.identifier))
                    .data("value", DbValue::String(input.value))
                    .data("expires_at", DbValue::Timestamp(input.expires_at))
                    .data("created_at", DbValue::Timestamp(now))
                    .data("updated_at", DbValue::Timestamp(now))
                    .select(VERIFICATION_FIELDS)
                    .force_allow_id(),
            )
            .await?;

        verification_from_record(record)
    }

    pub async fn find_verification(
        &self,
        identifier: &str,
    ) -> Result<Option<Verification>, OpenAuthError> {
        self.delete_expired_verifications().await?;

        let Some(record) = self
            .adapter
            .find_many(
                FindMany::new(VERIFICATION_MODEL)
                    .where_clause(identifier_where(identifier))
                    .sort_by(Sort::new("created_at", SortDirection::Desc))
                    .limit(1)
                    .select(VERIFICATION_FIELDS),
            )
            .await?
            .into_iter()
            .next()
        else {
            return Ok(None);
        };

        let verification = verification_from_record(record)?;
        if verification.expires_at <= OffsetDateTime::now_utc() {
            self.delete_expired_verifications().await?;
            return Ok(None);
        }

        Ok(Some(verification))
    }

    pub async fn find_verification_including_expired(
        &self,
        identifier: &str,
    ) -> Result<Option<Verification>, OpenAuthError> {
        self.adapter
            .find_many(
                FindMany::new(VERIFICATION_MODEL)
                    .where_clause(identifier_where(identifier))
                    .sort_by(Sort::new("created_at", SortDirection::Desc))
                    .limit(1)
                    .select(VERIFICATION_FIELDS),
            )
            .await?
            .into_iter()
            .next()
            .map(verification_from_record)
            .transpose()
    }

    pub async fn update_verification(
        &self,
        identifier: &str,
        input: UpdateVerificationInput,
    ) -> Result<Option<Verification>, OpenAuthError> {
        let mut query = Update::new(VERIFICATION_MODEL).where_clause(identifier_where(identifier));

        if let Some(value) = input.value {
            query = query.data("value", DbValue::String(value));
        }
        if let Some(expires_at) = input.expires_at {
            query = query.data("expires_at", DbValue::Timestamp(expires_at));
        }
        query = query.data("updated_at", DbValue::Timestamp(OffsetDateTime::now_utc()));

        self.adapter
            .update(query)
            .await?
            .map(verification_from_record)
            .transpose()
    }

    pub async fn delete_verification(&self, identifier: &str) -> Result<(), OpenAuthError> {
        self.adapter
            .delete(Delete::new(VERIFICATION_MODEL).where_clause(identifier_where(identifier)))
            .await
    }

    pub async fn delete_expired_verifications(&self) -> Result<u64, OpenAuthError> {
        self.adapter
            .delete_many(
                DeleteMany::new(VERIFICATION_MODEL).where_clause(
                    Where::new("expires_at", DbValue::Timestamp(OffsetDateTime::now_utc()))
                        .operator(WhereOperator::Lt),
                ),
            )
            .await
    }
}

/// Builds the filter that matches records whose `identifier` field equals `identifier`.
fn identifier_where(identifier: &str) -> Where {
    let wanted = DbValue::String(String::from(identifier));
    Where::new("identifier", wanted)
}

/// Converts an adapter record into a [`Verification`].
///
/// Fields are read in declaration order, so the first missing or mistyped field
/// determines the error that is returned.
fn verification_from_record(record: DbRecord) -> Result<Verification, OpenAuthError> {
    let id = String::from(required_string(&record, "id")?);
    let identifier = String::from(required_string(&record, "identifier")?);
    let value = String::from(required_string(&record, "value")?);
    let expires_at = required_timestamp(&record, "expires_at")?;
    let created_at = required_timestamp(&record, "created_at")?;
    let updated_at = required_timestamp(&record, "updated_at")?;

    Ok(Verification {
        id,
        identifier,
        value,
        expires_at,
        created_at,
        updated_at,
    })
}

/// Borrows the string stored under `field`.
///
/// Fails with a "missing" error when the field is absent and with a "must be
/// string" error when it holds any other kind of value.
fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
    let Some(found) = record.get(field) else {
        return Err(missing_field(field));
    };
    let DbValue::String(text) = found else {
        return Err(invalid_field(field, "string"));
    };

    Ok(text)
}

/// Copies out the timestamp stored under `field`.
///
/// Fails with a "missing" error when the field is absent and with a "must be
/// timestamp" error when it holds any other kind of value.
fn required_timestamp(record: &DbRecord, field: &str) -> Result<OffsetDateTime, OpenAuthError> {
    let Some(found) = record.get(field) else {
        return Err(missing_field(field));
    };
    let DbValue::Timestamp(moment) = found else {
        return Err(invalid_field(field, "timestamp"));
    };

    Ok(*moment)
}

/// Builds the adapter error reported when a verification record lacks `field`.
fn missing_field(field: &str) -> OpenAuthError {
    let message = format!("verification record is missing `{field}`");
    OpenAuthError::Adapter(message)
}

/// Builds the adapter error reported when `field` holds a value that is not of the `expected` kind.
fn invalid_field(field: &str, expected: &str) -> OpenAuthError {
    let message = format!("verification record field `{field}` must be {expected}");
    OpenAuthError::Adapter(message)
}