openauth-plugins 0.0.4

Official OpenAuth plugin modules.
Documentation
use std::sync::Arc;

use http::{header, Method, StatusCode};
use openauth_core::api::{
    create_auth_endpoint, parse_request_body, ApiRequest, ApiResponse, AuthEndpointOptions,
    BodyField, BodySchema, JsonSchemaType,
};
use openauth_core::context::{request_state, AuthContext};
use openauth_core::db::{DbAdapter, DbRecord, DbValue};
use openauth_core::error::OpenAuthError;
use openauth_core::plugin::PluginEndpoint;
use openauth_core::session::{CreateSessionInput, DbSessionStore};
use openauth_core::user::DbUserStore;
use serde::{Deserialize, Serialize};
use time::{Duration, OffsetDateTime};

use crate::device_authorization::errors::{oauth_error_response, OAuthDeviceError};
use crate::device_authorization::options::DeviceAuthorizationOptions;
use crate::device_authorization::store::{
    DeviceAuthorizationStatus, DeviceCodeRecord, DeviceCodeStore,
};

/// The only `grant_type` value the device token endpoint accepts: the
/// device-code grant identifier defined by RFC 8628 (OAuth 2.0 Device
/// Authorization Grant).
const DEVICE_CODE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code";

/// Request body of the `/device/token` endpoint.
///
/// The endpoint's body schema declares all three fields as strings.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct DeviceTokenRequest {
    /// Grant type; the endpoint rejects any value other than
    /// `DEVICE_CODE_GRANT_TYPE`.
    pub grant_type: String,
    /// Device code used to look up the stored device authorization record.
    pub device_code: String,
    /// Client identifier; passed to the optional client validator and compared
    /// with the client ID stored on the record, when one is stored.
    pub client_id: String,
}

/// Success response body of the `/device/token` endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct DeviceTokenResponse {
    /// Token of the session created for the user who approved the device.
    pub access_token: String,
    /// Token type; the endpoint always sets this to `"Bearer"`.
    pub token_type: String,
    /// Whole seconds until the created session expires; the endpoint clamps
    /// this to zero so it is never negative.
    pub expires_in: i64,
    /// Scope stored on the device code record, or an empty string when the
    /// record has none.
    pub scope: String,
}

pub fn device_token(options: Arc<DeviceAuthorizationOptions>) -> PluginEndpoint {
    create_auth_endpoint(
        "/device/token",
        Method::POST,
        AuthEndpointOptions::new()
            .operation_id("deviceToken")
            .allowed_media_types(["application/json", "application/x-www-form-urlencoded"])
            .openapi(super::openapi::device_token_operation())
            .body_schema(BodySchema::object([
                BodyField::new("grant_type", JsonSchemaType::String),
                BodyField::new("device_code", JsonSchemaType::String),
                BodyField::new("client_id", JsonSchemaType::String),
            ])),
        move |context, request| {
            let options = Arc::clone(&options);
            Box::pin(async move {
                let body = parse_request_body::<DeviceTokenRequest>(&request)?;
                if body.grant_type != DEVICE_CODE_GRANT_TYPE {
                    return token_oauth_error_response(
                        StatusCode::BAD_REQUEST,
                        OAuthDeviceError::InvalidRequest,
                        "Invalid grant type",
                    );
                }
                if let Some(validate_client) = &options.validate_client {
                    if !(validate_client)(body.client_id.clone()).await? {
                        return token_oauth_error_response(
                            StatusCode::BAD_REQUEST,
                            OAuthDeviceError::InvalidGrant,
                            "Invalid client ID",
                        );
                    }
                }

                let adapter = required_adapter(context)?;
                let store = DeviceCodeStore::new(adapter.as_ref());
                let Some(record) = store.find_by_device_code(&body.device_code).await? else {
                    return token_oauth_error_response(
                        StatusCode::BAD_REQUEST,
                        OAuthDeviceError::InvalidGrant,
                        "Invalid device code",
                    );
                };
                if record
                    .client_id
                    .as_ref()
                    .is_some_and(|client_id| client_id != &body.client_id)
                {
                    return token_oauth_error_response(
                        StatusCode::BAD_REQUEST,
                        OAuthDeviceError::InvalidGrant,
                        "Client ID mismatch",
                    );
                }
                if polling_too_fast(&record) {
                    return token_oauth_error_response(
                        StatusCode::BAD_REQUEST,
                        OAuthDeviceError::SlowDown,
                        "Polling too frequently",
                    );
                }
                store.mark_polled(&record.id).await?;
                if record.expires_at < OffsetDateTime::now_utc() {
                    store.delete(&record.id).await?;
                    return token_oauth_error_response(
                        StatusCode::BAD_REQUEST,
                        OAuthDeviceError::ExpiredToken,
                        "Device code has expired",
                    );
                }

                match record.status {
                    DeviceAuthorizationStatus::Pending => token_oauth_error_response(
                        StatusCode::BAD_REQUEST,
                        OAuthDeviceError::AuthorizationPending,
                        "Authorization pending",
                    ),
                    DeviceAuthorizationStatus::Denied => {
                        store.delete(&record.id).await?;
                        token_oauth_error_response(
                            StatusCode::BAD_REQUEST,
                            OAuthDeviceError::AccessDenied,
                            "Access denied",
                        )
                    }
                    DeviceAuthorizationStatus::Approved => {
                        approved_response(context, adapter.as_ref(), &store, &record, &request)
                            .await
                    }
                }
            })
        },
    )
}

/// Returns the database adapter attached to `context`.
///
/// # Errors
///
/// Returns [`OpenAuthError::Adapter`] when the context has no adapter, since
/// device authorization cannot work without one.
pub(super) fn required_adapter(context: &AuthContext) -> Result<Arc<dyn DbAdapter>, OpenAuthError> {
    let Some(adapter) = context.adapter() else {
        return Err(OpenAuthError::Adapter(
            "device authorization requires a database adapter".to_owned(),
        ));
    };
    Ok(adapter)
}

/// Turns an approved device authorization into a session and returns its token.
///
/// Steps, in order:
///
/// 1. Resolve the user stored on `record`. A missing user ID or an unknown
///    user yields a `500` OAuth `ServerError` response.
/// 2. Create a session that expires `context.session_config.expires_in`
///    seconds from now, seeded with the configured additional-field defaults
///    and, when present, the caller's IP address and user agent. A failure to
///    create the session also yields a `500` OAuth `ServerError` response.
/// 3. When request state is active, register the new session and user on it.
/// 4. Delete the device code record.
/// 5. Answer `200 OK` with a [`DeviceTokenResponse`] and the token cache
///    headers.
///
/// # Errors
///
/// Propagates `OpenAuthError` from the user lookup, the request-state update,
/// the record deletion, and response serialization.
async fn approved_response(
    context: &AuthContext,
    adapter: &dyn DbAdapter,
    store: &DeviceCodeStore<'_>,
    record: &DeviceCodeRecord,
    request: &ApiRequest,
) -> Result<ApiResponse, OpenAuthError> {
    // Shorthand for the `500 Internal Server Error` OAuth errors below.
    fn server_error(description: &str) -> Result<ApiResponse, OpenAuthError> {
        token_oauth_error_response(
            StatusCode::INTERNAL_SERVER_ERROR,
            OAuthDeviceError::ServerError,
            description,
        )
    }

    // An approved record is expected to carry the approving user's ID.
    let user_id = match &record.user_id {
        Some(id) => id,
        None => return server_error("Invalid device code status"),
    };
    let user = match DbUserStore::new(adapter).find_user_by_id(user_id).await? {
        Some(found) => found,
        None => return server_error("User not found"),
    };

    let created_at = OffsetDateTime::now_utc();
    let expires_at = created_at + Duration::seconds(context.session_config.expires_in as i64);
    let input = CreateSessionInput::new(user.id.clone(), expires_at)
        .additional_fields(additional_session_create_values(context));
    let input = match request_ip(request) {
        Some(ip_address) => input.ip_address(ip_address),
        None => input,
    };
    let input = match request_user_agent(request) {
        Some(user_agent) => input.user_agent(user_agent),
        None => input,
    };

    // The underlying error is intentionally not exposed to the client.
    let Ok(session) = DbSessionStore::new(adapter).create_session(input).await else {
        return server_error("Failed to create session");
    };
    if request_state::has_request_state() {
        request_state::set_current_new_session(session.clone(), user)?;
    }
    // The device code is removed once the session exists.
    store.delete(&record.id).await?;

    // Clamp at zero in case the session is already past its expiry.
    let remaining = session.expires_at - OffsetDateTime::now_utc();
    let body = DeviceTokenResponse {
        access_token: session.token,
        token_type: "Bearer".to_owned(),
        expires_in: remaining.whole_seconds().max(0),
        scope: record.scope.clone().unwrap_or_default(),
    };
    let mut response = super::json_response(StatusCode::OK, &body)?;
    add_token_cache_headers(&mut response);
    Ok(response)
}

/// Reports whether the device polled again before its polling interval elapsed.
///
/// Returns `false` when the record has no previous poll time or no polling
/// interval. Otherwise the whole milliseconds elapsed since the last poll are
/// compared with `polling_interval`, so the interval is treated as a
/// millisecond count here.
fn polling_too_fast(record: &DeviceCodeRecord) -> bool {
    match (record.last_polled_at, record.polling_interval) {
        (Some(last_polled_at), Some(interval)) => {
            let since_last_poll = OffsetDateTime::now_utc() - last_polled_at;
            since_last_poll.whole_milliseconds() < i128::from(interval)
        }
        _ => false,
    }
}

fn token_oauth_error_response(
    status: StatusCode,
    error: OAuthDeviceError,
    description: &str,
) -> Result<ApiResponse, OpenAuthError> {
    let mut response = oauth_error_response(status, error, description)?;
    add_token_cache_headers(&mut response);
    Ok(response)
}

/// Marks a token endpoint response as non-cacheable.
///
/// Sets `Cache-Control: no-store` and `Pragma: no-cache`, replacing any
/// existing values for those headers. These are the directives RFC 6749
/// section 5.1 requires on responses that carry tokens.
fn add_token_cache_headers(response: &mut ApiResponse) {
    let headers = response.headers_mut();
    headers.insert(header::CACHE_CONTROL, http::HeaderValue::from_static("no-store"));
    headers.insert(header::PRAGMA, http::HeaderValue::from_static("no-cache"));
}

/// Collects the initial values for the configured additional session fields.
///
/// Each field listed in `context.options.session.additional_fields` maps to a
/// clone of its default value, or to [`DbValue::Null`] when it has no default.
fn additional_session_create_values(context: &AuthContext) -> DbRecord {
    let configured_fields = &context.options.session.additional_fields;
    configured_fields
        .iter()
        .map(|(name, field)| {
            let key = name.clone();
            let initial_value = field.default_value.clone().unwrap_or(DbValue::Null);
            (key, initial_value)
        })
        .collect()
}

/// Returns the request's `User-Agent` header as an owned string.
///
/// Yields `None` when the header is missing or `to_str` rejects its value.
fn request_user_agent(request: &ApiRequest) -> Option<String> {
    let raw = request.headers().get(header::USER_AGENT)?;
    let text = raw.to_str().ok()?;
    Some(text.to_owned())
}

/// Extracts the caller's IP address from proxy headers.
///
/// The first comma-separated entry of `x-forwarded-for` is preferred. When
/// that header is missing, unreadable, or its first entry is blank, the
/// `x-real-ip` header is used instead. Both candidates are trimmed, and a
/// blank value is treated as absent, so the result is never an empty or
/// whitespace-only string.
///
/// Returns `None` when neither header provides a usable value.
///
/// NOTE(review): both headers can be set by the client unless a trusted proxy
/// overwrites them — confirm the deployment guarantees this before relying on
/// the value for security decisions.
fn request_ip(request: &ApiRequest) -> Option<String> {
    // Trims a header value and rejects it when nothing is left.
    fn non_empty(value: &str) -> Option<&str> {
        let trimmed = value.trim();
        if trimmed.is_empty() {
            None
        } else {
            Some(trimmed)
        }
    }

    let headers = request.headers();
    let forwarded_client = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(',').next())
        .and_then(non_empty);

    forwarded_client
        .or_else(|| {
            // Apply the same normalization as above, so an empty `x-real-ip`
            // header is not recorded as the session's IP address.
            headers
                .get("x-real-ip")
                .and_then(|value| value.to_str().ok())
                .and_then(non_empty)
        })
        .map(str::to_owned)
}