// ruma-macros 0.18.0
//
// Procedural macros used by the Ruma crates.
// Documentation
use cfg_if::cfg_if;
use proc_macro2::TokenStream;
use quote::quote;

use super::{Body, Headers, MacroKind, StructSuffix, ensure_feature_presence};
use crate::util::{
    PrivateField, RumaCommon, RumaCommonReexport, StructFieldExt, expand_fields_as_list,
};

mod incoming;
mod outgoing;
mod parse;

pub(crate) use self::parse::RequestAttrs;

/// The kind of macro expanded by this module.
///
/// Passed to `Body::expand_serde_struct_definition` and used to name the generated serde helper
/// structs (via `MacroKind::as_struct_ident`).
const KIND: MacroKind = MacroKind::Request;

/// Expand the `#[request]` macro on a struct.
///
/// This uses the `#[derive(Request)]` macro internally.
pub fn expand_request(attrs: RequestAttrs, item: syn::ItemStruct) -> TokenStream {
    let ruma_common = RumaCommon::new();
    let ruma_macros = ruma_common.reexported(RumaCommonReexport::RumaMacros);

    let maybe_feature_error = ensure_feature_presence().map(syn::Error::to_compile_error);

    let error_ty = attrs.error_ty_or_default(&ruma_common);

    cfg_if! {
        // Make the macro expand the internal derives, such that Rust Analyzer's expand macro helper can
        // render their output. Requires a nightly toolchain.
        if #[cfg(feature = "__internal_macro_expand")] {
            use syn::parse_quote;

            let mut derive_input = item.clone();
            derive_input.attrs.push(parse_quote! { #[ruma_api(error = #error_ty)] });
            crate::util::cfg_expand_struct(&mut derive_input);

            let extra_derive = quote! { #ruma_macros::_FakeDeriveRumaApi };
            let ruma_api_attribute = quote! {};
            let request_impls =
                expand_derive_request(derive_input).unwrap_or_else(syn::Error::into_compile_error);
        } else {
            let extra_derive = quote! { #ruma_macros::Request };
            let ruma_api_attribute = quote! { #[ruma_api(error = #error_ty)] };
            let request_impls = quote! {};
        }
    }

    quote! {
        #maybe_feature_error

        #[derive(Clone, Debug, #ruma_common::serde::_FakeDeriveSerde, #extra_derive)]
        #[cfg_attr(not(ruma_unstable_exhaustive_types), non_exhaustive)]
        #ruma_api_attribute
        #item

        #request_impls
    }
}

/// Expand the `#[derive(Request)]` macro.
///
/// Parses `input` into a `Request` and emits two things:
///
/// * its trait implementations;
/// * a `#[cfg(test)]` module, `__request`, holding the generated sanity tests.
///
/// # Errors
///
/// Returns the error produced by `Request::try_from` when `input` cannot be parsed as a request.
pub fn expand_derive_request(input: syn::ItemStruct) -> syn::Result<TokenStream> {
    let request = Request::try_from(input)?;

    let ruma_common = RumaCommon::new();
    let (impls, tests) =
        (request.expand_impls(&ruma_common), request.expand_tests(&ruma_common));

    let output = quote! {
        #impls

        #[allow(deprecated)]
        #[cfg(test)]
        mod __request {
            #tests
        }
    };

    Ok(output)
}

/// A parsed struct representing an API request.
///
/// Built from the input of `#[derive(Request)]` via `Request::try_from` (see
/// `expand_derive_request`). It is then used to generate both the trait implementations and the
/// sanity tests for the endpoint.
struct Request {
    /// The name of the struct.
    ident: syn::Ident,

    /// The generics of the struct.
    generics: syn::Generics,

    /// The HTTP headers.
    headers: Headers,

    /// The path variables.
    path: RequestPath,

    /// The query variables.
    query: RequestQuery,

    /// The body.
    body: Body,

    /// The type used for the `EndpointError` associated type on `OutgoingRequest` and
    /// `IncomingRequest` implementations.
    ///
    /// NOTE(review): presumably parsed from the `#[ruma_api(error = …)]` attribute that
    /// `expand_request` attaches — confirm in the `parse` module.
    error_ty: syn::Type,
}

impl Request {
    /// Expand the code generated for this request outside of tests.
    ///
    /// This covers two things:
    ///
    /// * the serde helper structs for the body and query;
    /// * the outgoing and incoming request implementations.
    ///
    /// The implementations are placed in a private `__request_impls` module that glob-imports
    /// the parent scope. That module is marked `#[allow(deprecated)]`.
    fn expand_impls(&self, ruma_common: &RumaCommon) -> TokenStream {
        let (ruma_macros, serde) = (
            ruma_common.reexported(RumaCommonReexport::RumaMacros),
            ruma_common.reexported(RumaCommonReexport::Serde),
        );

        let body_struct = self.body.expand_serde_struct_definition(KIND, ruma_common);
        let query_struct = self.query.expand_serde_struct_definition(&ruma_macros, &serde);

        let outgoing = self.expand_outgoing(ruma_common);
        let incoming = self.expand_incoming(ruma_common);

        quote! {
            #body_struct
            #query_struct

            #[allow(deprecated)]
            mod __request_impls {
                use super::*;
                #outgoing
                #incoming
            }
        }
    }

    /// Expand the unit tests generated for this request.
    ///
    /// The path-parameter test is always emitted. When the request has body fields, an extra
    /// test additionally asserts that the endpoint's `METHOD` is not `GET`.
    fn expand_tests(&self, ruma_common: &RumaCommon) -> TokenStream {
        let ident = &self.ident;
        let mut tests = self.path.expand_tests(ident, ruma_common);

        // Without body fields there is no constraint on the HTTP method to check.
        if self.body.is_empty() {
            return tests;
        }

        let http = ruma_common.reexported(RumaCommonReexport::Http);
        tests.extend(quote! {
            #[::std::prelude::v1::test]
            fn request_is_not_get() {
                ::std::assert_ne!(
                    <super::#ident as #ruma_common::api::Metadata>::METHOD, #http::Method::GET,
                    "GET endpoints can't have body fields",
                );
            }
        });

        tests
    }
}

/// Request path fields.
///
/// Wraps the request's `#[ruma_api(path)]` fields. `Default` yields an empty list, i.e. a
/// request without path parameters.
#[derive(Default)]
pub struct RequestPath(Vec<syn::Field>);

impl RequestPath {
    /// Generate a `path_parameters` unit test for the request with the given ident.
    ///
    /// The generated test takes the parameter names reported by the request's
    /// `Metadata::PATH_BUILDER`. It asserts they equal the names of the `#[ruma_api(path)]`
    /// fields, as slices, so both the set and the order must match. The test refers to the
    /// request as `super::#ident`, so it must be placed in a child module of the one defining
    /// the request.
    fn expand_tests(&self, ident: &syn::Ident, ruma_common: &RumaCommon) -> TokenStream {
        let path_fields = self.0.iter().map(|f| f.ident().to_string());

        quote! {
            #[::std::prelude::v1::test]
            fn path_parameters() {
                use #ruma_common::api::path_builder::PathBuilder as _;

                let path_params = <super::#ident as #ruma_common::api::Metadata>::PATH_BUILDER._path_parameters();
                let request_path_fields: &[&::std::primitive::str] = &[#(#path_fields),*];
                ::std::assert_eq!(
                    path_params, request_path_fields,
                    "Path parameters must match the `Request`'s `#[ruma_api(path)]` fields"
                );
            }
        }
    }

    /// Generate code for a comma-separated list of field names.
    ///
    /// This delegates entirely to `expand_fields_as_list`. Which attributes (if any) are
    /// carried over to the output is decided by that helper, not here.
    fn expand_fields(&self) -> TokenStream {
        expand_fields_as_list(&self.0)
    }
}

/// Request query fields.
///
/// Defaults to [`RequestQuery::None`].
#[derive(Default)]
// `All` stores a whole `syn::Field` inline while `Fields` only stores a `Vec` handle; that size
// gap is what `clippy::large_enum_variant` flags. NOTE(review): presumably tolerated because only
// one value is built per macro invocation; boxing `All` would touch every construction site.
#[allow(clippy::large_enum_variant)]
enum RequestQuery {
    /// The request doesn't contain a query.
    #[default]
    None,

    /// The fields containing the query parameters.
    Fields(Vec<syn::Field>),

    /// The single field containing the whole query.
    All(syn::Field),
}

impl RequestQuery {
    /// Generate code to define a `struct RequestQuery` used for (de)serializing the query of
    /// request.
    fn expand_serde_struct_definition(
        &self,
        ruma_macros: &TokenStream,
        serde: &TokenStream,
    ) -> Option<TokenStream> {
        let (fields, extra_attrs) = match self {
            Self::None => return None,
            Self::Fields(fields) => (fields.as_slice(), None),
            Self::All(field) => {
                let extra_attrs = quote! { #[serde(transparent)] };
                (std::slice::from_ref(field), Some(extra_attrs))
            }
        };

        let fields = fields.iter().map(PrivateField);
        let ident = KIND.as_struct_ident(StructSuffix::Query);

        Some(quote! {
            /// Data in the request's query string.
            #[cfg(any(feature = "client", feature = "server"))]
            #[derive(Debug, #ruma_macros::_FakeDeriveRumaApi, #ruma_macros::_FakeDeriveSerde)]
            #[cfg_attr(feature = "client", derive(#serde::Serialize))]
            #[cfg_attr(feature = "server", derive(#serde::Deserialize))]
            #extra_attrs
            struct #ident { #( #fields ),* }
        })
    }

    /// Generate code for a comma-separated list of field names.
    ///
    /// Only the `#[cfg]` attributes on the fields are forwarded.
    fn expand_fields(&self) -> Option<TokenStream> {
        let fields = match self {
            Self::None => return None,
            Self::Fields(fields) => fields.as_slice(),
            Self::All(field) => std::slice::from_ref(field),
        };

        Some(expand_fields_as_list(fields))
    }
}