tower-conneg 1.1.0

Tower middleware for HTTP content negotiation
use serde::Serialize;
use serde::de::DeserializeOwned;
use utoipa::openapi::{
    Content, HeaderBuilder, Ref, RefOr, Required, Schema,
    request_body::{RequestBody, RequestBodyBuilder},
    response::{Response, ResponseBuilder},
};
use utoipa::{PartialSchema, ToSchema};

use crate::core::{Negotiate, NegotiateResponse, ServerConfig};
use crate::format::ErasedFormat;

/// Media types used to document negotiated request and response bodies.
///
/// The list is kept in documentation order; the constructors in this module
/// skip media types that are already present.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct OpenApiFormats {
    /// Media type strings, in the order they are emitted into the document.
    media_types: Vec<String>,
}

impl OpenApiFormats {
    /// Reads media types from a server configuration.
    pub fn from_server_config(config: &ServerConfig) -> Self {
        Self::from_formats(config.formats.iter().map(AsRef::as_ref))
    }

    /// Reads media types from format values.
    pub fn from_formats<'a>(formats: impl IntoIterator<Item = &'a dyn ErasedFormat>) -> Self {
        let mut media_types = Vec::new();
        for format in formats {
            for media_type in format.supported_media_types() {
                push_unique(&mut media_types, media_type.to_string());
            }
        }
        Self { media_types }
    }

    /// Uses explicit media type strings.
    pub fn from_media_types<I, S>(media_types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let mut unique = Vec::new();
        for media_type in media_types {
            push_unique(&mut unique, media_type.into());
        }
        Self {
            media_types: unique,
        }
    }

    /// Selected media types in documentation order.
    pub fn media_types(&self) -> &[String] {
        &self.media_types
    }

    /// Keeps only selected media types.
    #[must_use]
    pub fn only<I, S>(&self, media_types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let selected: Vec<String> = media_types
            .into_iter()
            .map(|media_type| media_type.as_ref().to_owned())
            .collect();
        Self::from_media_types(
            self.media_types
                .iter()
                .filter(|media_type| selected.iter().any(|selected| selected == *media_type))
                .cloned(),
        )
    }

    /// Removes selected media types.
    #[must_use]
    pub fn without<I, S>(&self, media_types: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        let excluded: Vec<String> = media_types
            .into_iter()
            .map(|media_type| media_type.as_ref().to_owned())
            .collect();
        Self::from_media_types(
            self.media_types
                .iter()
                .filter(|media_type| !excluded.iter().any(|excluded| excluded == *media_type))
                .cloned(),
        )
    }

    /// Request body for all selected media types.
    pub fn request_body<T>(&self) -> RequestBody
    where
        T: ToSchema,
    {
        self.request_body_with_content(&Content::new(Some(schema_ref::<T>())), None)
    }

    /// Request body for all selected media types using a schema component name.
    pub fn request_body_ref<S>(&self, schema_name: S) -> RequestBody
    where
        S: Into<String>,
    {
        self.request_body_with_content(
            &Content::new(Some(Ref::from_schema_name(schema_name))),
            None,
        )
    }

    /// Required request body for all selected media types.
    pub fn required_request_body<T>(&self) -> RequestBody
    where
        T: ToSchema,
    {
        self.request_body_with_content(&Content::new(Some(schema_ref::<T>())), Some(Required::True))
    }

    /// Required request body for all selected media types using a schema component name.
    pub fn required_request_body_ref<S>(&self, schema_name: S) -> RequestBody
    where
        S: Into<String>,
    {
        self.request_body_with_content(
            &Content::new(Some(Ref::from_schema_name(schema_name))),
            Some(Required::True),
        )
    }

    /// Request body preserving Utoipa content metadata.
    pub fn request_body_with_content(
        &self,
        content: &Content,
        required: Option<Required>,
    ) -> RequestBody {
        self.media_types
            .iter()
            .fold(
                RequestBodyBuilder::new().required(required),
                |builder, media_type| builder.content(media_type, content.clone()),
            )
            .build()
    }

    /// Response for all selected media types.
    pub fn response<T, D>(&self, description: D) -> Response
    where
        T: ToSchema,
        D: Into<String>,
    {
        self.response_with_content(description, &Content::new(Some(schema_ref::<T>())))
    }

    /// Response for all selected media types using a schema component name.
    pub fn response_ref<S, D>(&self, description: D, schema_name: S) -> Response
    where
        S: Into<String>,
        D: Into<String>,
    {
        self.response_with_content(
            description,
            &Content::new(Some(Ref::from_schema_name(schema_name))),
        )
    }

    /// Response preserving Utoipa content metadata.
    pub fn response_with_content<D>(&self, description: D, content: &Content) -> Response
    where
        D: Into<String>,
    {
        self.media_types
            .iter()
            .fold(
                ResponseBuilder::new().description(description),
                |builder, media_type| builder.content(media_type, content.clone()),
            )
            .build()
    }

    /// 406 response without a response body schema.
    pub fn not_acceptable_response(&self) -> Response {
        ResponseBuilder::new()
            .description("The Accept header does not match any supported response format.")
            .build()
    }

    /// 406 response with a negotiated response body schema.
    pub fn not_acceptable_response_with_body<T>(&self) -> Response
    where
        T: ToSchema,
    {
        self.response::<T, _>("The Accept header does not match any supported response format.")
    }

    /// 415 response without a response body schema.
    pub fn unsupported_media_type_response(&self) -> Response {
        ResponseBuilder::new()
            .description("The request Content-Type is not supported.")
            .build()
    }

    /// 415 response with a negotiated response body schema.
    pub fn unsupported_media_type_response_with_body<T>(&self) -> Response
    where
        T: ToSchema,
    {
        self.response::<T, _>("The request Content-Type is not supported.")
    }

    /// 415 POST response with `Accept-Post`.
    pub fn unsupported_media_type_post_response(&self) -> Response {
        self.with_supported_header(self.unsupported_media_type_response(), "Accept-Post")
    }

    /// 415 POST response with `Accept-Post` and a negotiated body schema.
    pub fn unsupported_media_type_post_response_with_body<T>(&self) -> Response
    where
        T: ToSchema,
    {
        self.with_supported_header(
            self.unsupported_media_type_response_with_body::<T>(),
            "Accept-Post",
        )
    }

    /// 415 PATCH response with `Accept-Patch`.
    pub fn unsupported_media_type_patch_response(&self) -> Response {
        self.with_supported_header(self.unsupported_media_type_response(), "Accept-Patch")
    }

    /// 415 PATCH response with `Accept-Patch` and a negotiated body schema.
    pub fn unsupported_media_type_patch_response_with_body<T>(&self) -> Response
    where
        T: ToSchema,
    {
        self.with_supported_header(
            self.unsupported_media_type_response_with_body::<T>(),
            "Accept-Patch",
        )
    }

    fn with_supported_header(&self, response: Response, name: &str) -> Response {
        let description = if self.media_types.is_empty() {
            "Supported request media types.".to_owned()
        } else {
            format!(
                "Supported request media types: {}.",
                self.media_types.join(", ")
            )
        };
        ResponseBuilder::from(response)
            .header(
                name,
                HeaderBuilder::new().description(Some(description)).build(),
            )
            .build()
    }
}

impl From<&ServerConfig> for OpenApiFormats {
    /// Collects the media types of every format configured on `config`.
    ///
    /// Delegates to the server's own OpenAPI helper.
    fn from(config: &ServerConfig) -> Self {
        config.openapi_formats()
    }
}

impl ServerConfig {
    /// `OpenAPI` helpers for the formats configured on this server.
    ///
    /// Media types are gathered from `self.formats` in configuration order.
    pub fn openapi_formats(&self) -> OpenApiFormats {
        // Calls `from_formats` directly, the same path `from_server_config` takes.
        OpenApiFormats::from_formats(self.formats.iter().map(AsRef::as_ref))
    }
}

impl<T> PartialSchema for Negotiate<T>
where
    T: ToSchema + DeserializeOwned,
{
    /// The extractor wrapper is transparent: it reports exactly `T`'s schema.
    fn schema() -> RefOr<Schema> {
        <T as PartialSchema>::schema()
    }
}

impl<T> ToSchema for Negotiate<T>
where
    T: DeserializeOwned + ToSchema,
{
    fn name() -> std::borrow::Cow<'static, str> {
        T::name()
    }
}

impl<T> PartialSchema for NegotiateResponse<T>
where
    T: ToSchema + Serialize,
{
    /// The response wrapper is transparent: it reports exactly `T`'s schema.
    fn schema() -> RefOr<Schema> {
        <T as PartialSchema>::schema()
    }
}

impl<T> ToSchema for NegotiateResponse<T>
where
    T: Serialize + ToSchema,
{
    fn name() -> std::borrow::Cow<'static, str> {
        T::name()
    }
}

/// Builds a `$ref` to the component schema registered under `T::name()`.
fn schema_ref<T>() -> RefOr<Schema>
where
    T: ToSchema,
{
    let component_name = T::name();
    let reference = Ref::from_schema_name(component_name);
    RefOr::Ref(reference)
}

/// Appends `value` unless an equal string is already in `values`.
///
/// This is a linear scan, so insertion order is preserved and the first
/// occurrence wins.
fn push_unique(values: &mut Vec<String>, value: String) {
    let already_present = values.iter().any(|existing| *existing == value);
    if already_present {
        return;
    }
    values.push(value);
}