openauth-plugins 0.0.3

Official OpenAuth plugin modules.
Documentation
use openauth_core::context::AuthContext;
use openauth_core::db::{Create, DbAdapter, DbRecord, DbValue, FindMany, Sort, SortDirection};
use openauth_core::error::OpenAuthError;

use super::{Jwk, JwkAlgorithm, JwtOptions};

/// Model name handed to the database adapter's request builders
/// (`FindMany::new` / `Create::new`) for every JWK query in this module.
const MODEL: &str = "jwks";
/// Fields requested through `.select(..)` on each query in this module.
///
/// These are exactly the fields that `jwk_from_record` reads when turning a
/// database record back into a `Jwk`.
const FIELDS: [&str; 7] = [
    "id",
    "public_key",
    "private_key",
    "created_at",
    "expires_at",
    "alg",
    "crv",
];

pub(crate) async fn get_all_keys(
    context: &AuthContext,
    options: &JwtOptions,
) -> Result<Vec<Jwk>, OpenAuthError> {
    if let Some(get_jwks) = &options.adapter.get_jwks {
        return get_jwks(context).await;
    }
    let adapter = adapter(context)?;
    adapter
        .find_many(FindMany::new(MODEL).select(FIELDS))
        .await?
        .into_iter()
        .map(jwk_from_record)
        .collect()
}

pub(crate) async fn get_latest_key(
    context: &AuthContext,
    options: &JwtOptions,
) -> Result<Option<Jwk>, OpenAuthError> {
    if options.adapter.get_jwks.is_some() {
        return Ok(get_all_keys(context, options)
            .await?
            .into_iter()
            .max_by_key(|key| key.created_at));
    }
    let adapter = adapter(context)?;
    adapter
        .find_many(
            FindMany::new(MODEL)
                .sort_by(Sort::new("created_at", SortDirection::Desc))
                .limit(1)
                .select(FIELDS),
        )
        .await?
        .into_iter()
        .next()
        .map(jwk_from_record)
        .transpose()
}

/// Persists `jwk` and returns the stored key.
///
/// If `options.adapter.create_jwk` is configured, the key is handed to that
/// callback and its result is returned directly. Otherwise a record is
/// created in the `jwks` model through the context's database adapter, using
/// the key's own `id` (`force_allow_id`), and the returned record is decoded
/// back into a `Jwk`.
///
/// Optional values (`expires_at`, `alg`, `crv`) are written as
/// `DbValue::Null` when absent.
///
/// # Errors
///
/// Propagates the callback's error, the missing-adapter error, any error
/// from the adapter's `create`, or a decoding failure of the returned record.
pub(crate) async fn create_jwk(
    context: &AuthContext,
    options: &JwtOptions,
    jwk: Jwk,
) -> Result<Jwk, OpenAuthError> {
    if let Some(custom_create) = &options.adapter.create_jwk {
        return custom_create(context, jwk).await;
    }

    let db = adapter(context)?;

    // Translate the optional fields into nullable column values up front.
    let expires_at = match jwk.expires_at {
        Some(moment) => DbValue::Timestamp(moment),
        None => DbValue::Null,
    };
    let alg = match jwk.alg {
        Some(algorithm) => DbValue::String(algorithm.as_str().to_owned()),
        None => DbValue::Null,
    };
    let crv = match jwk.crv {
        Some(curve) => DbValue::String(curve),
        None => DbValue::Null,
    };

    let insert = Create::new(MODEL)
        .data("id", DbValue::String(jwk.id))
        .data("public_key", DbValue::String(jwk.public_key))
        .data("private_key", DbValue::String(jwk.private_key))
        .data("created_at", DbValue::Timestamp(jwk.created_at))
        .data("expires_at", expires_at)
        .data("alg", alg)
        .data("crv", crv)
        .select(FIELDS)
        .force_allow_id();

    let stored = db.create(insert).await?;
    jwk_from_record(stored)
}

/// Fetches the database adapter from `context`.
///
/// # Errors
///
/// Returns `OpenAuthError::InvalidConfig` when the context has no adapter.
fn adapter(context: &AuthContext) -> Result<std::sync::Arc<dyn DbAdapter>, OpenAuthError> {
    match context.adapter() {
        Some(db) => Ok(db),
        None => Err(OpenAuthError::InvalidConfig(
            "JWT plugin requires a database adapter".to_owned(),
        )),
    }
}

fn jwk_from_record(record: DbRecord) -> Result<Jwk, OpenAuthError> {
    Ok(Jwk {
        id: required_string(&record, "id")?.to_owned(),
        public_key: required_string(&record, "public_key")?.to_owned(),
        private_key: required_string(&record, "private_key")?.to_owned(),
        created_at: required_timestamp(&record, "created_at")?,
        expires_at: optional_timestamp(&record, "expires_at")?,
        alg: optional_string(&record, "alg")?
            .map(parse_algorithm)
            .transpose()?,
        crv: optional_string(&record, "crv")?.map(str::to_owned),
    })
}

/// Maps a stored algorithm name to its `JwkAlgorithm` variant.
///
/// Recognised names are `EdDSA`, `ES256`, `ES512`, `RS256` and `PS256`;
/// matching is exact and case-sensitive.
///
/// # Errors
///
/// Returns `OpenAuthError::Adapter` for any other name.
fn parse_algorithm(value: &str) -> Result<JwkAlgorithm, OpenAuthError> {
    let algorithm = match value {
        "EdDSA" => JwkAlgorithm::EdDsa,
        "ES256" => JwkAlgorithm::Es256,
        "ES512" => JwkAlgorithm::Es512,
        "RS256" => JwkAlgorithm::Rs256,
        "PS256" => JwkAlgorithm::Ps256,
        other => {
            return Err(OpenAuthError::Adapter(format!(
                "unsupported JWK alg `{other}`"
            )));
        }
    };
    Ok(algorithm)
}

/// Reads a mandatory string field from `record`.
///
/// # Errors
///
/// Returns the `missing_field` error when the field is absent and the
/// `invalid_field` error when it holds anything other than a string
/// (including null).
fn required_string<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, OpenAuthError> {
    let stored = record.get(field).ok_or_else(|| missing_field(field))?;
    if let DbValue::String(text) = stored {
        Ok(text)
    } else {
        Err(invalid_field(field, "string"))
    }
}

/// Reads an optional string field from `record`.
///
/// An absent field and an explicit null both yield `Ok(None)`.
///
/// # Errors
///
/// Returns the `invalid_field` error when the field holds a value that is
/// neither a string nor null.
fn optional_string<'a>(
    record: &'a DbRecord,
    field: &str,
) -> Result<Option<&'a str>, OpenAuthError> {
    let text = match record.get(field) {
        Some(DbValue::String(text)) => text,
        Some(DbValue::Null) | None => return Ok(None),
        Some(_) => return Err(invalid_field(field, "string or null")),
    };
    Ok(Some(text))
}

/// Reads a mandatory timestamp field from `record`.
///
/// # Errors
///
/// Returns the `missing_field` error when the field is absent and the
/// `invalid_field` error when it holds anything other than a timestamp
/// (including null).
fn required_timestamp(
    record: &DbRecord,
    field: &str,
) -> Result<time::OffsetDateTime, OpenAuthError> {
    let stored = record.get(field).ok_or_else(|| missing_field(field))?;
    match stored {
        DbValue::Timestamp(moment) => Ok(*moment),
        _ => Err(invalid_field(field, "timestamp")),
    }
}

/// Reads an optional timestamp field from `record`.
///
/// An absent field and an explicit null both yield `Ok(None)`.
///
/// # Errors
///
/// Returns the `invalid_field` error when the field holds a value that is
/// neither a timestamp nor null.
fn optional_timestamp(
    record: &DbRecord,
    field: &str,
) -> Result<Option<time::OffsetDateTime>, OpenAuthError> {
    let moment = match record.get(field) {
        Some(DbValue::Timestamp(moment)) => *moment,
        Some(DbValue::Null) | None => return Ok(None),
        Some(_) => return Err(invalid_field(field, "timestamp or null")),
    };
    Ok(Some(moment))
}

/// Builds the adapter error reported when a jwks record lacks `field`.
fn missing_field(field: &str) -> OpenAuthError {
    let message = format!("jwks record is missing `{field}`");
    OpenAuthError::Adapter(message)
}

/// Builds the adapter error reported when `field` of a jwks record holds a
/// value of the wrong kind; `expected` names the accepted kind(s).
fn invalid_field(field: &str, expected: &str) -> OpenAuthError {
    let message = format!("jwks record field `{field}` must be {expected}");
    OpenAuthError::Adapter(message)
}