openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use std::sync::Arc;

use http::{header, Method, StatusCode};
use openauth_core::api::{
    create_auth_endpoint, parse_request_body, AuthEndpointOptions, BodyField, BodySchema,
    JsonSchemaType,
};
use openauth_core::context::AuthContext;
use openauth_core::crypto::random::generate_random_string;
use openauth_core::db::DbAdapter;
use openauth_core::error::OpenAuthError;
use openauth_core::plugin::PluginEndpoint;
use rand::rngs::OsRng;
use rand::RngCore;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use url::Url;

use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
use crate::device_authorization::options::{AsyncDeviceCodeGenerator, DeviceAuthorizationOptions};
use crate::device_authorization::store::{CreateDeviceCodeInput, DeviceCodeStore};

/// Alphabet for default user codes: uppercase letters and digits with `I`, `O`, `0`
/// and `1` left out (presumably to avoid visually confusable characters).
///
/// It has 32 symbols. Because 32 divides 256, `default_user_code` can map each random
/// byte with a plain modulo without bias. Keep the length a power of two (<= 256) if
/// this alphabet is ever edited.
const DEFAULT_USER_CODE_CHARSET: &[u8] = b"ABCDEFGHJKLMNPQRSTUVWXYZ23456789";

/// Body of a `POST /device/code` request.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct DeviceCodeRequest {
    /// Identifier of the client starting the device flow.
    pub client_id: String,
    /// Optional scope string requested by the client, passed through unmodified.
    pub scope: Option<String>,
}

/// Successful response body of `POST /device/code`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct DeviceCodeResponse {
    /// Code the device uses when polling for the authorization result.
    pub device_code: String,
    /// Code the end user enters on the verification page.
    pub user_code: String,
    /// Verification page URL, without the user code.
    pub verification_uri: String,
    /// Verification page URL with `user_code` appended as a query parameter.
    pub verification_uri_complete: String,
    /// Lifetime of the codes, in whole seconds.
    pub expires_in: i64,
    /// Polling interval, in whole seconds.
    pub interval: i64,
}

/// Builds the `POST /device/code` endpoint that starts a device authorization flow.
///
/// The request body (JSON or form-urlencoded) carries a `client_id` and an optional
/// `scope`. The handler performs these steps:
///
/// 1. When `options.validate_client` is configured and returns `false`, it responds
///    with status 400 and an `OAuthDeviceError::InvalidClient` error.
/// 2. It runs the optional `options.on_device_auth_request` hook.
/// 3. It generates a device code and a user code. Configured generators take
///    precedence over the built-in defaults.
/// 4. It resolves the verification URIs and the expiry time.
/// 5. It persists the codes via [`DeviceCodeStore`].
/// 6. It responds with a [`DeviceCodeResponse`] and `Cache-Control: no-store`.
///
/// # Errors
///
/// Propagates body-parsing, client-validation, hook and storage errors. Returns
/// [`OpenAuthError::Adapter`] when no database adapter is configured. Returns
/// [`OpenAuthError::InvalidConfig`] when the configured interval, expiry or
/// verification URI cannot be used. Configuration errors are detected before
/// anything is written to the store.
pub fn device_code(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
    create_auth_endpoint(
        "/device/code",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("deviceCode")
            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
            .openapi(super::openapi::device_code_operation())
            .body_schema(BodySchema::object([
                BodyField::new("client_id", JsonSchemaType::String),
                BodyField::optional("scope", JsonSchemaType::String),
            ])),
        move |context, request| {
            let options = Arc::clone(&options);
            Box::pin(async move {
                let body = parse_request_body::<DeviceCodeRequest>(&request)?;
                if let Some(validate_client) = &options.validate_client {
                    if !(validate_client)(body.client_id.clone()).await? {
                        return oauth_error_response(
                            StatusCode::BAD_REQUEST,
                            OAuthDeviceError::InvalidClient,
                            "Invalid client ID",
                        );
                    }
                }
                if let Some(hook) = &options.on_device_auth_request {
                    (hook)(body.client_id.clone(), body.scope.clone()).await?;
                }

                let adapter = required_adapter(context)?;
                let device_code = generate_code(
                    options.generate_device_code.as_ref(),
                    options.device_code_length,
                    default_device_code,
                )
                .await;
                let user_code = generate_code(
                    options.generate_user_code.as_ref(),
                    options.user_code_length,
                    default_user_code,
                )
                .await;
                // The response reports whole seconds, while the store keeps the polling
                // interval in milliseconds.
                let expires_in = options.expires_in.whole_seconds();
                let interval = options.interval.whole_seconds();
                let polling_interval = i64::try_from(options.interval.whole_milliseconds())
                    .map_err(|_| {
                        OpenAuthError::InvalidConfig(
                            "device authorization interval is too large".to_owned(),
                        )
                    })?;

                // Resolve everything that depends on configuration *before* writing to the
                // store. A misconfigured `verification_uri` or `expires_in` then fails the
                // request without leaving an orphaned, never-returned device code behind.
                let (verification_uri, verification_uri_complete) = build_verification_uris(
                    &options.verification_uri,
                    &context.base_url,
                    &user_code,
                )?;
                // `checked_add` rather than `+`: the `Add` impl panics when the result is
                // out of range, which an oversized `expires_in` would trigger here.
                let expires_at = OffsetDateTime::now_utc()
                    .checked_add(options.expires_in)
                    .ok_or_else(|| {
                        OpenAuthError::InvalidConfig(
                            "device authorization expires_in is too large".to_owned(),
                        )
                    })?;

                DeviceCodeStore::new(adapter.as_ref())
                    .create(CreateDeviceCodeInput {
                        device_code: device_code.clone(),
                        // The stored user code goes through `clean_user_code`, presumably a
                        // normalization step. The response returns the code as generated.
                        user_code: super::clean_user_code(&user_code),
                        expires_at,
                        polling_interval,
                        client_id: body.client_id,
                        scope: body.scope,
                    })
                    .await?;

                let mut response = super::json_response(
                    StatusCode::OK,
                    &DeviceCodeResponse {
                        device_code,
                        user_code,
                        verification_uri,
                        verification_uri_complete,
                        expires_in,
                        interval,
                    },
                )?;
                // The response carries fresh credentials; make sure nothing caches it.
                response.headers_mut().insert(
                    header::CACHE_CONTROL,
                    http::HeaderValue::from_static("no-store"),
                );
                Ok(response)
            })
        },
    )
}

/// Returns the database adapter configured on `context`.
///
/// # Errors
///
/// Returns [`OpenAuthError::Adapter`] when the context has no adapter. Device codes must
/// be persisted, so the flow cannot proceed without one.
fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
    match context.adapter() {
        Some(adapter) => Ok(adapter),
        None => Err(OpenAuthError::Adapter(
            "device authorization requires a database adapter".to_owned(),
        )),
    }
}

/// Produces a code from the configured async `generator`, if any. Otherwise it
/// falls back to `fallback(length)`.
///
/// `length` is consulted only by the fallback. A configured generator decides the
/// shape of its own output.
async fn generate_code(
    generator: Option<&AsyncDeviceCodeGenerator>,
    length: usize,
    fallback: fn(usize) -> String,
) -> String {
    if let Some(custom) = generator {
        custom().await
    } else {
        fallback(length)
    }
}

/// Built-in device-code generator, used by `generate_code` when no custom device-code
/// generator is configured.
///
/// Delegates to `generate_random_string`. The alphabet and exact length semantics are
/// defined by that helper (presumably `length` characters — confirm in openauth_core).
fn default_device_code(length: usize) -> String {
    generate_random_string(length)
}

/// Built-in user-code generator: returns `length` characters drawn from
/// `DEFAULT_USER_CODE_CHARSET`, using bytes read from `OsRng`.
///
/// Each random byte is reduced modulo the alphabet size. The current alphabet has 32
/// symbols, which divides 256 evenly, so every symbol is equally likely.
fn default_user_code(length: usize) -> String {
    let alphabet = DEFAULT_USER_CODE_CHARSET;
    let mut entropy = vec![0_u8; length];
    OsRng.fill_bytes(&mut entropy);
    entropy
        .iter()
        .map(|&raw| char::from(alphabet[usize::from(raw) % alphabet.len()]))
        .collect()
}

/// Resolves the verification URL and its "complete" variant.
///
/// `verification_uri` is taken directly when it parses as an absolute URL. Otherwise it
/// is joined onto `base_url`. The complete variant is the same URL with `user_code`
/// appended as a `user_code` query parameter.
///
/// Returns `(verification_uri, verification_uri_complete)` as strings.
///
/// # Errors
///
/// Returns [`OpenAuthError::InvalidConfig`] carrying the URL parse error when neither
/// interpretation produces a valid URL.
fn build_verification_uris(
    verification_uri: &str,
    base_url: &str,
    user_code: &str,
) -> Result<(String, String), OpenAuthError> {
    let resolved = match Url::parse(verification_uri) {
        Ok(absolute) => Ok(absolute),
        Err(_) => Url::parse(base_url).and_then(|base| base.join(verification_uri)),
    }
    .map_err(|error| OpenAuthError::InvalidConfig(error.to_string()))?;

    let mut with_code = resolved.clone();
    with_code.query_pairs_mut().append_pair("user_code", user_code);

    Ok((resolved.to_string(), with_code.to_string()))
}