openauth-core 0.0.4

Core types and primitives for OpenAuth.
Documentation
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use http::{header, Request, Response, StatusCode};
use serde_json::Value;

use crate::context::AuthContext;
use crate::error::OpenAuthError;

use super::body::parse_request_body;
use super::error::ApiErrorResponse;
use super::openapi::OpenApiOperation;
use super::schema::BodySchema;

/// Raw request/response body bytes.
pub type Body = Vec<u8>;
/// HTTP request whose body is held as raw bytes.
pub type ApiRequest = Request<Body>;
/// HTTP response whose body is held as raw bytes.
pub type ApiResponse = Response<Body>;
/// Synchronous endpoint handler: a plain function pointer that receives the auth
/// context and the owned request and returns a response or an [`OpenAuthError`].
pub type EndpointHandler = fn(&AuthContext, ApiRequest) -> Result<ApiResponse, OpenAuthError>;
/// Boxed, `Send` future produced by async endpoint handlers. The `'a` lifetime lets the
/// future borrow from the [`AuthContext`] it was given.
pub type EndpointFuture<'a> =
    Pin<Box<dyn Future<Output = Result<ApiResponse, OpenAuthError>> + Send + 'a>>;
/// Shared, thread-safe async endpoint handler. The higher-ranked `for<'a>` bound lets
/// the returned future borrow the context for the duration of each individual call.
pub type AsyncEndpointHandler =
    Arc<dyn for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync>;
/// Boxed, `Send` future produced by endpoint middleware. `Ok(Some(response))` stops the
/// middleware chain with that response; `Ok(None)` lets the next middleware run.
pub type EndpointMiddlewareFuture<'a> =
    Pin<Box<dyn Future<Output = Result<Option<ApiResponse>, OpenAuthError>> + Send + 'a>>;
/// Shared, thread-safe middleware handler. Unlike [`AsyncEndpointHandler`] it only
/// borrows the request, so the same request can be passed to several middlewares.
pub type EndpointMiddlewareHandler = Arc<
    dyn for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a> + Send + Sync,
>;

/// A single middleware step attached to an endpoint.
///
/// When the endpoint's middlewares are run, the first one whose future resolves to
/// `Ok(Some(response))` stops the chain with that response.
#[derive(Clone)]
pub struct EndpointMiddleware {
    /// Shared callback invoked with the auth context and a borrowed request.
    pub handler: EndpointMiddlewareHandler,
}

impl EndpointMiddleware {
    /// Wraps `handler` in an [`Arc`] so the middleware can be cloned cheaply.
    pub fn new<F>(handler: F) -> Self
    where
        F: for<'a> Fn(&'a AuthContext, &'a ApiRequest) -> EndpointMiddlewareFuture<'a>
            + Send
            + Sync
            + 'static,
    {
        // Coerce the concrete closure into the type-erased, shared handler type.
        let shared: EndpointMiddlewareHandler = Arc::new(handler);
        Self { handler: shared }
    }
}

/// Per-endpoint configuration: request validation, middleware and OpenAPI metadata.
///
/// Built with consuming builder methods, e.g.
/// `AuthEndpointOptions::new().operation_id("signIn").server_only()`.
#[derive(Clone, Default)]
pub struct AuthEndpointOptions {
    /// Optional operation identifier for the endpoint.
    pub operation_id: Option<String>,
    /// Accepted request `Content-Type` values; an empty list disables the check.
    pub allowed_media_types: Vec<String>,
    /// Optional schema the parsed JSON request body must satisfy.
    pub body_schema: Option<BodySchema>,
    /// Middlewares, kept in the order they were added.
    pub middlewares: Vec<EndpointMiddleware>,
    /// Optional OpenAPI operation description.
    pub openapi: Option<OpenApiOperation>,
    /// Flag set by [`AuthEndpointOptions::server_only`].
    pub server_only: bool,
    /// Flag set by [`AuthEndpointOptions::hide_from_openapi`].
    pub hide_from_openapi: bool,
}

impl AuthEndpointOptions {
    /// Creates options with every field at its default (no constraints, no middleware).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the operation identifier.
    #[must_use]
    pub fn operation_id(self, operation_id: impl Into<String>) -> Self {
        Self {
            operation_id: Some(operation_id.into()),
            ..self
        }
    }

    /// Replaces the list of accepted request media types.
    #[must_use]
    pub fn allowed_media_types<I, S>(self, media_types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        Self {
            allowed_media_types: media_types.into_iter().map(Into::into).collect(),
            ..self
        }
    }

    /// Sets the schema the request body is validated against.
    #[must_use]
    pub fn body_schema(self, schema: BodySchema) -> Self {
        Self {
            body_schema: Some(schema),
            ..self
        }
    }

    /// Appends a middleware after any previously added ones.
    #[must_use]
    pub fn middleware(self, middleware: EndpointMiddleware) -> Self {
        let mut middlewares = self.middlewares;
        middlewares.push(middleware);
        Self {
            middlewares,
            ..self
        }
    }

    /// Attaches an OpenAPI operation description.
    #[must_use]
    pub fn openapi(self, operation: OpenApiOperation) -> Self {
        Self {
            openapi: Some(operation),
            ..self
        }
    }

    /// Sets the `server_only` flag.
    #[must_use]
    pub fn server_only(self) -> Self {
        Self {
            server_only: true,
            ..self
        }
    }

    /// Sets the `hide_from_openapi` flag.
    #[must_use]
    pub fn hide_from_openapi(self) -> Self {
        Self {
            hide_from_openapi: true,
            ..self
        }
    }
}

/// Whether an endpoint's handler is synchronous or asynchronous.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EndpointKind {
    /// Synchronous handler (presumably an [`AuthEndpoint`] with an [`EndpointHandler`];
    /// confirm at the construction site).
    Sync,
    /// Asynchronous handler (presumably an [`AsyncAuthEndpoint`]; confirm at the
    /// construction site).
    Async,
}

/// Summary metadata describing an endpoint: its path, HTTP method, handler kind and a
/// subset of its options.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointInfo {
    /// Endpoint path.
    pub path: String,
    /// HTTP method of the endpoint.
    pub method: http::Method,
    /// Whether the endpoint's handler is sync or async.
    pub kind: EndpointKind,
    /// Operation identifier. Presumably copied from `AuthEndpointOptions::operation_id`;
    /// confirm at the construction site.
    pub operation_id: Option<String>,
    /// Accepted request media types. Presumably copied from
    /// `AuthEndpointOptions::allowed_media_types`; confirm at the construction site.
    pub allowed_media_types: Vec<String>,
}

/// Endpoint backed by a synchronous [`EndpointHandler`] function pointer.
///
/// Unlike [`AsyncAuthEndpoint`], it carries no [`AuthEndpointOptions`].
#[derive(Clone)]
pub struct AuthEndpoint {
    /// Endpoint path.
    pub path: String,
    /// HTTP method of the endpoint.
    pub method: http::Method,
    /// Function pointer invoked with the auth context and the owned request.
    pub handler: EndpointHandler,
}

/// Endpoint backed by a shared async handler, together with its per-endpoint options
/// (media types, body schema, middlewares, OpenAPI metadata and flags).
#[derive(Clone)]
pub struct AsyncAuthEndpoint {
    /// Endpoint path.
    pub path: String,
    /// HTTP method of the endpoint.
    pub method: http::Method,
    /// Shared async handler; cloning the endpoint only bumps the `Arc` refcount.
    pub handler: AsyncEndpointHandler,
    /// Validation, middleware and documentation options for this endpoint.
    pub options: AuthEndpointOptions,
}

impl AsyncAuthEndpoint {
    pub fn new<F>(path: impl Into<String>, method: http::Method, handler: F) -> Self
    where
        F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
    {
        Self {
            path: path.into(),
            method,
            handler: Arc::new(handler),
            options: AuthEndpointOptions::default(),
        }
    }
}

pub fn create_auth_endpoint<F>(
    path: impl Into<String>,
    method: http::Method,
    options: AuthEndpointOptions,
    handler: F,
) -> AsyncAuthEndpoint
where
    F: for<'a> Fn(&'a AuthContext, ApiRequest) -> EndpointFuture<'a> + Send + Sync + 'static,
{
    AsyncAuthEndpoint {
        path: path.into(),
        method,
        handler: Arc::new(handler),
        options,
    }
}

pub(super) fn validate_async_endpoint_request(
    endpoint: &AsyncAuthEndpoint,
    request: &ApiRequest,
) -> Result<Option<ApiResponse>, OpenAuthError> {
    if endpoint.options.allowed_media_types.is_empty() && endpoint.options.body_schema.is_none() {
        return Ok(None);
    }

    let content_type = request
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.split(';').next())
        .map(str::trim)
        .filter(|value| !value.is_empty());

    if !endpoint.options.allowed_media_types.is_empty() {
        let Some(content_type) = content_type else {
            return invalid_request_response(
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                "UNSUPPORTED_MEDIA_TYPE",
                "Missing Content-Type",
            )
            .map(Some);
        };
        if !endpoint
            .options
            .allowed_media_types
            .iter()
            .any(|allowed| allowed.eq_ignore_ascii_case(content_type))
        {
            return invalid_request_response(
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                "UNSUPPORTED_MEDIA_TYPE",
                "Unsupported Content-Type",
            )
            .map(Some);
        }
    }

    if let Some(schema) = &endpoint.options.body_schema {
        let body = match parse_request_body::<Value>(request) {
            Ok(body) => body,
            Err(error) => {
                return invalid_request_response(
                    StatusCode::BAD_REQUEST,
                    "INVALID_REQUEST_BODY",
                    &error.to_string(),
                )
                .map(Some);
            }
        };
        if let Err(message) = schema.validate(&body) {
            return invalid_request_response(
                StatusCode::BAD_REQUEST,
                "INVALID_REQUEST_BODY",
                &message,
            )
            .map(Some);
        }
    }

    Ok(None)
}

/// Runs the endpoint's middlewares in the order they were registered.
///
/// The first middleware that resolves to `Ok(Some(response))` stops the chain, and its
/// response is returned. An `Err` from any middleware is propagated immediately.
/// Returns `Ok(None)` when every middleware passes or none are configured.
pub(super) async fn run_endpoint_middlewares(
    context: &AuthContext,
    endpoint: &AsyncAuthEndpoint,
    request: &ApiRequest,
) -> Result<Option<ApiResponse>, OpenAuthError> {
    for step in endpoint.options.middlewares.iter() {
        let short_circuit = (step.handler)(context, request).await?;
        if short_circuit.is_some() {
            return Ok(short_circuit);
        }
    }
    Ok(None)
}

fn invalid_request_response(
    status: StatusCode,
    code: &str,
    message: &str,
) -> Result<ApiResponse, OpenAuthError> {
    let body = serde_json::to_vec(&ApiErrorResponse {
        code: code.to_owned(),
        message: message.to_owned(),
        original_message: None,
    })
    .map_err(|error| OpenAuthError::Api(error.to_string()))?;

    Response::builder()
        .status(status)
        .header(header::CONTENT_TYPE, "application/json")
        .body(body)
        .map_err(|error| OpenAuthError::Api(error.to_string()))
}