use proc_macro2::{Ident, TokenStream};
use quote::{ToTokens, format_ident, quote};
use syn::{
Error, FnArg, GenericArgument, Generics, ItemFn, PathArguments, Result, ReturnType, Type,
};
/// Message of the error produced by `result_return_error`, reported whenever a
/// `#[handler]` function's return type is not of the form `Result<TResponse, TError>`.
const RESULT_RETURN_ERROR: &str = "`handler` functions must return `Result<TResponse, TError>`";
/// Expands the `#[handler]` attribute applied to a free function.
///
/// The user's function is kept in the output. It is renamed to
/// `__const_router_{name}_user` and made private. A wrapper and a public `const fn`
/// that reuses the original name, visibility and attributes are generated next to it.
///
/// # Errors
///
/// Returns a spanned [`Error`] if `input` is not a function item, or if the
/// function's shape is not supported by `#[handler]`.
pub(crate) fn expand(input: TokenStream) -> Result<TokenStream> {
    let mut user_function: ItemFn = syn::parse2(input)?;
    let expansion = HandlerExpansion::new(&user_function)?;
    // The original item survives only as an implementation detail; the public
    // surface is the generated `const fn`.
    user_function.vis = syn::Visibility::Inherited;
    user_function.sig.ident = expansion.names.user_fn.clone();
    Ok(expansion.into_tokens(user_function))
}
/// Everything extracted from a `#[handler]` function that is needed to generate
/// its wrapper and its public constructor.
struct HandlerExpansion {
    /// Generated identifiers (public constructor, renamed user fn, wrapper).
    names: HandlerNames,
    /// Visibility of the original function, reused for the public `const fn`.
    vis: syn::Visibility,
    /// Attributes of the original function, re-emitted on the public `const fn`.
    attrs: Vec<syn::Attribute>,
    /// Generic parameters (and where clause) of the original function.
    generics: Generics,
    /// Whether the original function was declared `async`.
    is_async: bool,
    /// Whether the function takes a request argument, and its type if so.
    input: HandlerInput,
    /// Response and error types taken from the `Result` return type.
    output: HandlerOutput,
}
impl HandlerExpansion {
fn new(function: &ItemFn) -> Result<Self> {
validate_function_shape(function)?;
let input = HandlerInput::parse(function)?;
let generics = function.sig.generics.clone();
input.validate_generics(&generics)?;
Ok(Self {
attrs: function.attrs.clone(),
generics,
input,
is_async: function.sig.asyncness.is_some(),
names: HandlerNames::new(&function.sig.ident),
output: HandlerOutput::parse(&function.sig.output)?,
vis: function.vis.clone(),
})
}
fn into_tokens(self, function: ItemFn) -> TokenStream {
let wrapper_body = self.wrapper_body();
let Self {
attrs,
generics,
input,
names,
output,
vis,
..
} = self;
let HandlerNames {
public_fn,
user_fn: _,
wrapper_fn,
} = names;
let HandlerOutput { response, error } = output;
let (_, _, where_clause) = generics.split_for_impl();
let turbofish = generic_turbofish(&generics);
match input {
HandlerInput::WithRequest { request } => {
let request = *request;
quote! {
#function
#[allow(non_snake_case)]
fn #wrapper_fn #generics(
__const_router_request: #request,
) -> ::const_router::BoxFuture<#response, #error>
#where_clause
{
#wrapper_body
}
#(#attrs)*
#vis const fn #public_fn #generics() -> ::const_router::__internal::Handler<#request, #response, #error>
#where_clause
{
::const_router::__internal::new_handler(#wrapper_fn #turbofish)
}
}
}
HandlerInput::WithoutRequest => {
quote! {
#function
#[allow(non_snake_case)]
fn #wrapper_fn<__ConstRouterRequest>(
_request: __ConstRouterRequest,
) -> ::const_router::BoxFuture<#response, #error>
{
#wrapper_body
}
#(#attrs)*
#vis const fn #public_fn<__ConstRouterRequest>() -> ::const_router::__internal::Handler<__ConstRouterRequest, #response, #error>
{
::const_router::__internal::new_handler(#wrapper_fn::<__ConstRouterRequest>)
}
}
}
}
}
fn wrapper_body(&self) -> TokenStream {
let wrapper_call = self.wrapper_call();
if self.is_async {
quote! {
::std::boxed::Box::pin(#wrapper_call)
}
} else {
quote! {
::std::boxed::Box::pin(::std::future::ready(#wrapper_call))
}
}
}
fn wrapper_call(&self) -> TokenStream {
let user_fn = &self.names.user_fn;
match &self.input {
HandlerInput::WithRequest { .. } => {
let turbofish = generic_turbofish(&self.generics);
quote! {
#user_fn #turbofish(__const_router_request)
}
}
HandlerInput::WithoutRequest => {
quote! {
#user_fn()
}
}
}
}
}
/// Identifiers generated for a single handler.
struct HandlerNames {
    /// `__const_router_{name}_wrapper`: the private adapter that returns a boxed future.
    wrapper_fn: Ident,
    /// `__const_router_{name}_user`: the user's function after renaming.
    user_fn: Ident,
    /// The handler's declared name, reused for the public `const fn`.
    public_fn: Ident,
}
impl HandlerNames {
    /// Derives the generated identifiers from the handler's declared name.
    ///
    /// The renamed user function and the wrapper get a `__const_router_` prefix
    /// so they cannot clash with the public constructor, which keeps `public_fn`.
    fn new(public_fn: &Ident) -> Self {
        let user_fn = format_ident!("__const_router_{}_user", public_fn);
        let wrapper_fn = format_ident!("__const_router_{}_wrapper", public_fn);
        Self {
            user_fn,
            wrapper_fn,
            public_fn: public_fn.clone(),
        }
    }
}
/// The two type arguments of a handler's `Result<TResponse, TError>` return type.
struct HandlerOutput {
    /// `TError`: the second type argument.
    error: Type,
    /// `TResponse`: the first type argument.
    response: Type,
}
impl HandlerOutput {
fn parse(output: &ReturnType) -> Result<Self> {
let ReturnType::Type(_, ty) = output else {
return Err(result_return_error(output));
};
let Type::Path(path) = ty.as_ref() else {
return Err(result_return_error(ty));
};
let segment = path
.path
.segments
.last()
.ok_or_else(|| result_return_error(path))?;
if segment.ident != "Result" {
return Err(result_return_error(segment));
}
let PathArguments::AngleBracketed(arguments) = &segment.arguments else {
return Err(result_return_error(segment));
};
let mut args = arguments.args.iter();
match (args.next(), args.next(), args.next()) {
(Some(GenericArgument::Type(response)), Some(GenericArgument::Type(error)), None) => {
Ok(Self {
response: response.clone(),
error: error.clone(),
})
}
_ => Err(result_return_error(arguments)),
}
}
}
/// How the user's handler receives the incoming request.
enum HandlerInput {
    /// The handler declares no parameters. The generated wrapper accepts a
    /// request of any type and ignores it.
    WithoutRequest,
    /// The handler declares exactly one typed parameter; `request` is its type.
    WithRequest { request: Box<Type> },
}
impl HandlerInput {
fn parse(function: &ItemFn) -> Result<Self> {
let mut inputs = function.sig.inputs.iter();
let Some(first) = inputs.next() else {
return Ok(Self::WithoutRequest);
};
if let Some(second) = inputs.next() {
return Err(Error::new_spanned(
second,
"`handler` functions must take zero or one request argument",
));
}
match first {
FnArg::Typed(argument) => Ok(Self::WithRequest {
request: argument.ty.clone(),
}),
FnArg::Receiver(receiver) => Err(Error::new_spanned(
receiver,
"`handler` can only be applied to free functions",
)),
}
}
fn validate_generics(&self, generics: &Generics) -> Result<()> {
if matches!(self, Self::WithoutRequest) && !generics.params.is_empty() {
return Err(Error::new_spanned(
generics,
"`handler` functions without a request argument cannot be generic",
));
}
Ok(())
}
}
fn validate_function_shape(function: &ItemFn) -> Result<()> {
let sig = &function.sig;
reject_if_present(
&sig.constness,
"`handler` cannot be applied to a `const fn`",
)?;
reject_if_present(
&sig.unsafety,
"`handler` cannot be applied to an `unsafe fn`",
)?;
reject_if_present(&sig.abi, "`handler` cannot be applied to an `extern fn`")?;
reject_if_present(
&sig.variadic,
"`handler` cannot be applied to a variadic function",
)
}
fn reject_if_present<T: ToTokens>(value: &Option<T>, message: &'static str) -> Result<()> {
if let Some(value) = value {
return Err(Error::new_spanned(value, message));
}
Ok(())
}
fn result_return_error(tokens: impl ToTokens) -> Error {
Error::new_spanned(tokens, RESULT_RETURN_ERROR)
}
fn generic_turbofish(generics: &Generics) -> TokenStream {
if generics.params.is_empty() {
return TokenStream::new();
}
let args = generics.params.iter().map(generic_argument);
quote! { ::<#(#args),*> }
}
/// Converts a generic parameter declaration into the argument that refers back
/// to it. Bounds, types and defaults are dropped:
/// - a lifetime parameter becomes `'a`;
/// - a type parameter becomes `T`;
/// - a const parameter becomes `N`.
fn generic_argument(param: &syn::GenericParam) -> TokenStream {
    match param {
        syn::GenericParam::Lifetime(param) => param.lifetime.to_token_stream(),
        syn::GenericParam::Type(param) => param.ident.to_token_stream(),
        syn::GenericParam::Const(param) => param.ident.to_token_stream(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Runs the macro on `input` and renders the expansion as a string.
    fn expand_to_string(input: TokenStream) -> String {
        let expansion = expand(input).expect("handler input should expand");
        expansion.to_string()
    }

    /// Runs the macro on `input` and checks that it fails with exactly `expected`.
    fn assert_handler_error(input: TokenStream, expected: &str) {
        let message = expand(input)
            .expect_err("handler input should fail")
            .to_string();
        assert_eq!(message, expected);
    }

    /// Checks, in order, that every entry of `needles` occurs in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(haystack.contains(needle));
        }
    }

    #[test]
    fn expands_sync_handler_with_request() {
        let expanded = expand_to_string(quote! {
            pub fn handler(req: Request) -> Result<Response, Error> {
                let _ = req;
                todo!()
            }
        });
        assert_contains_all(
            &expanded,
            &[
                "__const_router_handler_user",
                "std :: future :: ready",
                "pub const fn handler",
                "new_handler",
            ],
        );
    }

    #[test]
    fn expands_async_handler_without_request() {
        let expanded = expand_to_string(quote! {
            pub(crate) async fn handler() -> std::result::Result<Response, Error> {
                todo!()
            }
        });
        assert_contains_all(
            &expanded,
            &[
                "async fn __const_router_handler_user",
                "Box :: pin",
                "pub (crate) const fn handler",
                "__ConstRouterRequest",
            ],
        );
    }

    #[test]
    fn expands_generic_handler_with_request() {
        let expanded = expand_to_string(quote! {
            fn handler<T, const N: usize>(req: Request) -> Result<Response, Error>
            where
                T: Default,
            {
                let _ = (req, T::default(), N);
                todo!()
            }
        });
        assert_contains_all(
            &expanded,
            &[
                "fn __const_router_handler_wrapper < T , const N : usize >",
                "where T : Default",
                "__const_router_handler_wrapper :: < T , N >",
            ],
        );
    }

    #[test]
    fn rejects_missing_or_non_result_return_types() {
        let inputs = [
            quote! {
                fn handler() {}
            },
            quote! {
                fn handler() -> Response {
                    todo!()
                }
            },
            quote! {
                fn handler() -> Result<Response> {
                    todo!()
                }
            },
        ];
        for input in inputs {
            assert_handler_error(
                input,
                "`handler` functions must return `Result<TResponse, TError>`",
            );
        }
    }

    #[test]
    fn rejects_invalid_function_inputs() {
        let cases = [
            (
                quote! {
                    fn handler(first: Request, second: Request) -> Result<Response, Error> {
                        todo!()
                    }
                },
                "`handler` functions must take zero or one request argument",
            ),
            (
                quote! {
                    fn handler(&self) -> Result<Response, Error> {
                        todo!()
                    }
                },
                "`handler` can only be applied to free functions",
            ),
            (
                quote! {
                    fn handler<T>() -> Result<Response, Error> {
                        todo!()
                    }
                },
                "`handler` functions without a request argument cannot be generic",
            ),
        ];
        for (input, expected) in cases {
            assert_handler_error(input, expected);
        }
    }

    #[test]
    fn rejects_unsupported_function_qualifiers() {
        let cases = [
            (
                quote! {
                    const fn handler() -> Result<Response, Error> {
                        todo!()
                    }
                },
                "`handler` cannot be applied to a `const fn`",
            ),
            (
                quote! {
                    unsafe fn handler() -> Result<Response, Error> {
                        todo!()
                    }
                },
                "`handler` cannot be applied to an `unsafe fn`",
            ),
            (
                quote! {
                    extern "C" fn handler() -> Result<Response, Error> {
                        todo!()
                    }
                },
                "`handler` cannot be applied to an `extern fn`",
            ),
        ];
        for (input, expected) in cases {
            assert_handler_error(input, expected);
        }
    }
}