poem-openapi-derive 2.0.23

Macros for poem-openapi
Documentation
use darling::{
    ast::{Data, Style},
    util::{Ignored, SpannedValue},
    FromDeriveInput, FromMeta,
};
use http::header::HeaderName;
use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;
use syn::{Attribute, DeriveInput, Error, Path};

use crate::{
    error::GeneratorResult,
    utils::{get_crate_name, get_description, optional_literal},
};

/// Authentication scheme chosen with `#[oai(type = "...")]`.
///
/// The variant decides both how the scheme is registered in the OpenAPI
/// registry and which credential extractor the generated `from_request` uses.
#[derive(FromMeta, Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum AuthType {
    /// `type = "api_key"` — a named key read from a query parameter, header or cookie.
    #[darling(rename = "api_key")]
    ApiKey,

    /// `type = "basic"` — HTTP `basic` authentication.
    #[darling(rename = "basic")]
    Basic,

    /// `type = "bearer"` — HTTP `bearer` authentication.
    #[darling(rename = "bearer")]
    Bearer,

    /// `type = "oauth2"` — OAuth2; requires `flows(...)`, extracted as a bearer token.
    #[darling(rename = "oauth2")]
    OAuth2,

    /// `type = "openid_connect"` — requires `openid_connect_url`, extracted as a bearer token.
    #[darling(rename = "openid_connect")]
    OpenIdConnect,
}

/// One OAuth2 flow parsed from e.g. `flows(implicit(authorization_url = "..."))`.
///
/// Which URLs are mandatory depends on the flow kind; that is checked by
/// `OAuthFlows::validate`, not here.
#[derive(FromMeta)]
struct OAuthFlow {
    /// Authorization endpoint URL.
    #[darling(default)]
    authorization_url: Option<String>,
    /// Token endpoint URL.
    #[darling(default)]
    token_url: Option<String>,
    /// Refresh endpoint URL.
    #[darling(default)]
    refresh_url: Option<String>,
    /// Path to a type implementing `OAuthScopes`; its `meta()` lists the scopes.
    #[darling(default)]
    scopes: Option<Path>,
}

impl OAuthFlow {
    /// Builds the tokens of a `registry::MetaOAuthFlow { .. }` expression for this flow.
    ///
    /// The three URLs are converted with `optional_literal`. The scopes come from
    /// `<Scopes as OAuthScopes>::meta()` when a scopes type was given, and are an
    /// empty vector otherwise. This currently never returns `Err`.
    fn generate_meta(&self, crate_name: &TokenStream) -> GeneratorResult<TokenStream> {
        let (authorization_url, token_url, refresh_url) = (
            optional_literal(&self.authorization_url),
            optional_literal(&self.token_url),
            optional_literal(&self.refresh_url),
        );
        let scopes = self.scopes.as_ref().map_or_else(
            || quote!(::std::vec![]),
            |scopes_ty| quote!(<#scopes_ty as #crate_name::OAuthScopes>::meta()),
        );

        let flow_meta = quote! {
            #crate_name::registry::MetaOAuthFlow {
                authorization_url: #authorization_url,
                token_url: #token_url,
                refresh_url: #refresh_url,
                scopes: #scopes,
            }
        };
        Ok(flow_meta)
    }
}

/// The `flows(...)` argument of an `oauth2` security scheme.
///
/// Each supported OAuth2 flow is optional on its own. `OAuthFlows::validate`
/// requires at least one of them to be present.
#[derive(FromMeta)]
struct OAuthFlows {
    /// `implicit(...)` flow.
    #[darling(default)]
    implicit: Option<OAuthFlow>,
    /// `password(...)` flow.
    #[darling(default)]
    password: Option<OAuthFlow>,
    /// `client_credentials(...)` flow.
    #[darling(default)]
    client_credentials: Option<OAuthFlow>,
    /// `authorization_code(...)` flow.
    #[darling(default)]
    authorization_code: Option<OAuthFlow>,
}

impl OAuthFlows {
    fn validate(&self, span: Span) -> GeneratorResult<()> {
        if self.implicit.is_none()
            && self.password.is_none()
            && self.authorization_code.is_none()
            && self.client_credentials.is_none()
        {
            return Err(Error::new(
                span,
                r#"At least one OAuth2 flow configuration is required."#,
            )
            .into());
        }

        if let Some(implicit) = &self.implicit {
            if implicit.authorization_url.is_none() {
                return Err(Error::new(
                    span,
                    r#"Missing authorization url. #[oai(authorization_url="...")]"#,
                )
                .into());
            }
        }

        if let Some(password) = &self.password {
            if password.token_url.is_none() {
                return Err(
                    Error::new(span, r#"Missing token url. #[oai(token_url="...")]"#).into(),
                );
            }
        }

        if let Some(client_credentials) = &self.client_credentials {
            if client_credentials.token_url.is_none() {
                return Err(
                    Error::new(span, r#"Missing token url. #[oai(token_url="...")]"#).into(),
                );
            }
        }

        if let Some(authorization_code) = &self.authorization_code {
            if authorization_code.authorization_url.is_none() {
                return Err(Error::new(
                    span,
                    r#"Missing authorization url. #[oai(authorization_url="...")]"#,
                )
                .into());
            }

            if authorization_code.token_url.is_none() {
                return Err(
                    Error::new(span, r#"Missing token url. #[oai(token_url="...")]"#).into(),
                );
            }
        }

        Ok(())
    }

    fn generate_meta(&self, crate_name: &TokenStream) -> GeneratorResult<TokenStream> {
        let implicit = match &self.implicit {
            Some(implicit) => {
                let meta = implicit.generate_meta(crate_name)?;
                quote!(::std::option::Option::Some(#meta))
            }
            None => quote!(::std::option::Option::None),
        };

        let password = match &self.password {
            Some(password) => {
                let meta = password.generate_meta(crate_name)?;
                quote!(::std::option::Option::Some(#meta))
            }
            None => quote!(::std::option::Option::None),
        };

        let client_credentials = match &self.client_credentials {
            Some(client_credentials) => {
                let meta = client_credentials.generate_meta(crate_name)?;
                quote!(::std::option::Option::Some(#meta))
            }
            None => quote!(::std::option::Option::None),
        };

        let authorization_code = match &self.authorization_code {
            Some(authorization_code) => {
                let meta = authorization_code.generate_meta(crate_name)?;
                quote!(::std::option::Option::Some(#meta))
            }
            None => quote!(::std::option::Option::None),
        };

        Ok(quote! {
            #crate_name::registry::MetaOAuthFlows {
                implicit: #implicit,
                password: #password,
                client_credentials: #client_credentials,
                authorization_code: #authorization_code,
            }
        })
    }
}

/// Location of an API key, chosen with `#[oai(in = "...")]`.
#[derive(FromMeta, Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum ApiKeyInType {
    /// `in = "query"` — the key is a query parameter.
    #[darling(rename = "query")]
    Query,

    /// `in = "header"` — the key is a request header.
    #[darling(rename = "header")]
    Header,

    /// `in = "cookie"` — the key is a cookie.
    #[darling(rename = "cookie")]
    Cookie,
}

/// Parsed `#[oai(...)]` arguments of a type deriving a security scheme.
///
/// `doc` attributes of the input are forwarded into `attrs` and later turned
/// into the scheme description.
#[derive(FromDeriveInput)]
#[darling(attributes(oai), forward_attrs(doc))]
struct SecuritySchemeArgs {
    /// Name of the type the derive is applied to.
    ident: Ident,
    /// Shape of the input; must end up being a single-field tuple struct.
    data: Data<Ignored, syn::Type>,
    /// Forwarded `doc` attributes.
    attrs: Vec<Attribute>,

    /// Passed to `get_crate_name` to choose the crate path used in generated code.
    #[darling(default)]
    internal: bool,
    /// Registered scheme name; defaults to the type name.
    #[darling(default)]
    rename: Option<String>,
    /// `type = "..."` — the authentication scheme.
    #[darling(rename = "type")]
    ty: AuthType,
    /// `in = "..."` — API key location (required for `api_key`).
    #[darling(default, rename = "in")]
    key_in: Option<ApiKeyInType>,
    /// API key name (required for `api_key`); spanned for precise diagnostics.
    #[darling(default)]
    key_name: Option<SpannedValue<String>>,
    /// Bearer token format hint.
    #[darling(default)]
    bearer_format: Option<String>,
    /// OAuth2 flows (required for `oauth2`); spanned for diagnostics.
    #[darling(default)]
    flows: Option<SpannedValue<OAuthFlows>>,
    /// OpenID Connect discovery URL (required for `openid_connect`).
    #[darling(default)]
    openid_connect_url: Option<String>,
    /// Path to an async function called as `checker(&req, credential)`. A `None`
    /// result becomes an `AuthorizationError`.
    #[darling(default)]
    checker: Option<Path>,
}

impl SecuritySchemeArgs {
    fn validate(&self) -> GeneratorResult<()> {
        match self.ty {
            AuthType::ApiKey => self.validate_api_key(),
            AuthType::OAuth2 => self.validate_oauth2(),
            AuthType::OpenIdConnect => self.validate_openid_connect(),
            _ => Ok(()),
        }
    }

    fn validate_api_key(&self) -> GeneratorResult<()> {
        match &self.key_name {
            Some(name) => {
                HeaderName::try_from(&**name).map_err(|_| {
                    Error::new(
                        name.span(),
                        format!("`{}` is not a valid header name.", &**name),
                    )
                })?;
            }
            None => {
                return Err(Error::new_spanned(
                    &self.ident,
                    r#"Missing a key name. #[oai(key_name = "...")]"#,
                )
                .into())
            }
        }

        if self.key_in.is_none() {
            return Err(Error::new_spanned(
                &self.ident,
                r#"Missing a input type. #[oai(in = "...")]"#,
            )
            .into());
        }

        Ok(())
    }

    fn validate_oauth2(&self) -> GeneratorResult<()> {
        match &self.flows {
            Some(flows) => flows.validate(flows.span())?,
            None => {
                return Err(Error::new_spanned(
                    &self.ident,
                    r#"Missing an oauth2 flows. #[oai(flows = "...")]"#,
                )
                .into());
            }
        }

        Ok(())
    }

    fn validate_openid_connect(&self) -> GeneratorResult<()> {
        if self.openid_connect_url.is_none() {
            return Err(Error::new_spanned(
                &self.ident,
                r#"Missing open id connect url. #[oai(openid_connect_url = "...")]"#,
            )
            .into());
        }

        Ok(())
    }

    fn generate_register_security_scheme(
        &self,
        crate_name: &TokenStream,
        name: &str,
    ) -> GeneratorResult<TokenStream> {
        let description = get_description(&self.attrs)?;
        let description = optional_literal(&description);

        let key_name = match &self.key_name {
            Some(key_name) => {
                let name = &**key_name;
                quote!(::std::option::Option::Some(#name))
            }
            None => quote!(::std::option::Option::None),
        };
        let key_in = match &self.key_in {
            Some(ApiKeyInType::Query) => quote!(::std::option::Option::Some("query")),
            Some(ApiKeyInType::Header) => quote!(::std::option::Option::Some("header")),
            Some(ApiKeyInType::Cookie) => quote!(::std::option::Option::Some("cookie")),
            None => quote!(::std::option::Option::None),
        };
        let bearer_format = match &self.bearer_format {
            Some(bearer_format) => quote!(::std::option::Option::Some(#bearer_format)),
            None => quote!(::std::option::Option::None),
        };
        let openid_connect_url = match &self.openid_connect_url {
            Some(openid_connect_url) => quote!(::std::option::Option::Some(#openid_connect_url)),
            None => quote!(::std::option::Option::None),
        };

        let ts = match self.ty {
            AuthType::ApiKey => {
                quote! {
                    registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme {
                        ty: "apiKey",
                        description: #description,
                        name: #key_name,
                        key_in: #key_in,
                        scheme: ::std::option::Option::None,
                        bearer_format: ::std::option::Option::None,
                        flows: ::std::option::Option::None,
                        openid_connect_url: ::std::option::Option::None,
                    });
                }
            }
            AuthType::Basic => {
                quote! {
                    registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme {
                        ty: "http",
                        description: #description,
                        name: ::std::option::Option::None,
                        key_in: ::std::option::Option::None,
                        scheme: ::std::option::Option::Some("basic"),
                        bearer_format: #bearer_format,
                        flows: ::std::option::Option::None,
                        openid_connect_url: ::std::option::Option::None,
                    });
                }
            }
            AuthType::Bearer => {
                quote! {
                    registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme {
                        ty: "http",
                        description: #description,
                        name: ::std::option::Option::None,
                        key_in: ::std::option::Option::None,
                        scheme: ::std::option::Option::Some("bearer"),
                        bearer_format: #bearer_format,
                        flows: ::std::option::Option::None,
                        openid_connect_url: ::std::option::Option::None,
                    });
                }
            }
            AuthType::OAuth2 => {
                let flows = self.flows.as_ref().unwrap().generate_meta(crate_name)?;
                quote! {
                    registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme {
                        ty: "oauth2",
                        description: #description,
                        name: ::std::option::Option::None,
                        key_in: ::std::option::Option::None,
                        scheme: ::std::option::Option::None,
                        bearer_format: ::std::option::Option::None,
                        flows: ::std::option::Option::Some(#flows),
                        openid_connect_url: ::std::option::Option::None,
                    });
                }
            }
            AuthType::OpenIdConnect => {
                quote! {
                    registry.create_security_scheme(#name, #crate_name::registry::MetaSecurityScheme {
                        ty: "openIdConnect",
                        description: #description,
                        name: ::std::option::Option::None,
                        key_in: ::std::option::Option::None,
                        scheme: ::std::option::Option::None,
                        bearer_format: ::std::option::Option::None,
                        flows: ::std::option::Option::None,
                        openid_connect_url: #openid_connect_url,
                    });
                }
            }
        };
        Ok(ts)
    }

    fn generate_from_request(&self, crate_name: &TokenStream) -> TokenStream {
        match self.ty {
            AuthType::ApiKey => {
                let key_name = self.key_name.as_ref().unwrap().as_str();
                let param_in = match self.key_in.as_ref().unwrap() {
                    ApiKeyInType::Query => quote!(#crate_name::registry::MetaParamIn::Query),
                    ApiKeyInType::Header => quote!(#crate_name::registry::MetaParamIn::Header),
                    ApiKeyInType::Cookie => quote!(#crate_name::registry::MetaParamIn::Cookie),
                };
                quote!(<#crate_name::auth::ApiKey as #crate_name::auth::ApiKeyAuthorization>::from_request(req, query, #key_name, #param_in))
            }
            AuthType::Basic => {
                quote!(<#crate_name::auth::Basic as #crate_name::auth::BasicAuthorization>::from_request(req))
            }
            AuthType::Bearer => {
                quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req))
            }
            AuthType::OAuth2 => {
                quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req))
            }
            AuthType::OpenIdConnect => {
                quote!(<#crate_name::auth::Bearer as #crate_name::auth::BearerAuthorization>::from_request(req))
            }
        }
    }
}

/// Expands the `SecurityScheme` derive for the given input.
///
/// The input must be a tuple struct with exactly one field (`struct MyScheme(T);`).
/// The generated `from_request` wraps the extracted credential, or the
/// `checker`'s output, with `Self(output)`, which only compiles for such a
/// struct. Named-field and unit structs are rejected here with a clear
/// diagnostic instead of failing inside the macro output.
pub(crate) fn generate(args: DeriveInput) -> GeneratorResult<TokenStream> {
    let args: SecuritySchemeArgs = SecuritySchemeArgs::from_derive_input(&args)?;
    let crate_name = get_crate_name(args.internal);
    let ident = &args.ident;
    // The scheme is registered under `rename` if given, otherwise under the type name.
    let oai_typename = args.rename.clone().unwrap_or_else(|| ident.to_string());
    args.validate()?;

    let fields = match &args.data {
        Data::Struct(fields) => fields,
        _ => {
            return Err(Error::new_spanned(
                ident,
                "SecurityScheme can only be applied to a struct.",
            )
            .into())
        }
    };

    // Both conditions are required. Checking only the field count of tuple
    // structs would let named-field and unit structs through to `Self(output)`.
    if fields.style != Style::Tuple || fields.fields.len() != 1 {
        return Err(Error::new_spanned(
            ident,
            "Only one unnamed field is allowed in the SecurityScheme struct.",
        )
        .into());
    }

    let register_security_scheme =
        args.generate_register_security_scheme(&crate_name, &oai_typename)?;
    let from_request = args.generate_from_request(&crate_name);
    // Optional post-extraction hook: a `None` from the checker becomes an `AuthorizationError`.
    let checker = args.checker.as_ref().map(|path| {
        quote! {
            let output = ::std::option::Option::ok_or(#path(&req, output).await, #crate_name::error::AuthorizationError)?;
        }
    });

    let expanded = quote! {
        #[#crate_name::__private::poem::async_trait]
        impl<'a> #crate_name::ApiExtractor<'a> for #ident {
            const TYPE: #crate_name::ApiExtractorType = #crate_name::ApiExtractorType::SecurityScheme;

            type ParamType = ();
            type ParamRawType = ();

            fn register(registry: &mut #crate_name::registry::Registry) {
                #register_security_scheme
            }

            fn security_scheme() -> ::std::option::Option<&'static str> {
                ::std::option::Option::Some(#oai_typename)
            }

            async fn from_request(
                req: &'a #crate_name::__private::poem::Request,
                body: &mut #crate_name::__private::poem::RequestBody,
                _param_opts: #crate_name::ExtractParamOptions<Self::ParamType>,
            ) -> #crate_name::__private::poem::Result<Self> {
                let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap();
                let output = #from_request?;
                #checker
                ::std::result::Result::Ok(Self(output))
            }
        }
    };

    Ok(expanded)
}