Skip to main content

rs_netty_macros/
lib.rs

1#![deny(unsafe_code)]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{
6    parse::{Parse, ParseStream},
7    parse_macro_input, AngleBracketedGenericArguments, FnArg, GenericArgument, Ident, ItemFn, Path,
8    PathArguments, ReturnType, Token, Type,
9};
10
11/// Adapts an async function into an `rs_netty::Handler` implementation.
12///
13/// The MVP form expects a user-declared handler type and an async function with
14/// one inbound message argument:
15///
16/// ```ignore
17/// struct Echo;
18///
19/// #[handler(Echo)]
20/// async fn echo(req: Request) -> rs_netty::Result<Response> {
21///     Ok(Response { echoed: req.message })
22/// }
23///
24/// #[handler(PrintResponse, write = Request)]
25/// async fn print_response(res: Response) -> rs_netty::Result<()> {
26///     println!("{}", res.message);
27///     Ok(())
28/// }
29/// ```
30#[proc_macro_attribute]
31pub fn handler(attr: TokenStream, item: TokenStream) -> TokenStream {
32    let args = parse_macro_input!(attr as HandlerArgs);
33    let function = parse_macro_input!(item as ItemFn);
34
35    expand_handler(args, function)
36        .unwrap_or_else(syn::Error::into_compile_error)
37        .into()
38}
39
40struct HandlerArgs {
41    handler_ty: Path,
42    write_ty: Option<Type>,
43}
44
45impl Parse for HandlerArgs {
46    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
47        let handler_ty = input.parse::<Path>()?;
48        let mut write_ty = None;
49
50        if input.peek(Token![,]) {
51            input.parse::<Token![,]>()?;
52            let ident = input.parse::<Ident>()?;
53            if ident != "write" {
54                return Err(syn::Error::new_spanned(ident, "expected `write = Type`"));
55            }
56            input.parse::<Token![=]>()?;
57            write_ty = Some(input.parse::<Type>()?);
58        }
59
60        if !input.is_empty() {
61            return Err(input.error("unexpected tokens in `#[handler]`"));
62        }
63
64        Ok(Self {
65            handler_ty,
66            write_ty,
67        })
68    }
69}
70
71fn expand_handler(args: HandlerArgs, function: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
72    if function.sig.asyncness.is_none() {
73        return Err(syn::Error::new_spanned(
74            function.sig.fn_token,
75            "`#[handler]` can only be used on async functions",
76        ));
77    }
78
79    let handler_ty = args.handler_ty;
80    let signature = handler_signature(&function)?;
81    let input_ty = signature.message_ty;
82    let ok_ty = result_ok_type(&function.sig.output)?;
83    let writes_response = !is_unit_type(&ok_ty);
84    let write_ty = match (args.write_ty, writes_response) {
85        (Some(write_ty), false) => write_ty,
86        (Some(write_ty), true) => {
87            return Err(syn::Error::new_spanned(
88                write_ty,
89                "`write = Type` is only supported for handlers that return Result<()>",
90            ));
91        }
92        (None, true) => ok_ty.clone(),
93        (None, false) => {
94            return Err(syn::Error::new_spanned(
95                &function.sig.output,
96                "`#[handler]` functions that return Result<()> must specify `write = Type`",
97            ));
98        }
99    };
100    let fn_name = &function.sig.ident;
101    let call = if signature.takes_state {
102        quote! { #fn_name(self, msg).await? }
103    } else {
104        quote! { #fn_name(msg).await? }
105    };
106    let tcp_body = if writes_response {
107        quote! {
108            let msg = #call;
109            ctx.write(msg).await
110        }
111    } else {
112        quote! {
113            #call;
114            let _ = ctx;
115            Ok(())
116        }
117    };
118    let datagram_body = if writes_response {
119        quote! {
120            let msg = #call;
121            ctx.write(msg).await
122        }
123    } else {
124        quote! {
125            #call;
126            let _ = ctx;
127            Ok(())
128        }
129    };
130
131    Ok(quote! {
132        #function
133
134        impl ::rs_netty::Handler<#input_ty> for #handler_ty {
135            type Write = #write_ty;
136
137            async fn read(
138                &mut self,
139                ctx: &mut ::rs_netty::Context<Self::Write>,
140                msg: #input_ty,
141            ) -> ::rs_netty::Result<()> {
142                #tcp_body
143            }
144        }
145
146        impl ::rs_netty::DatagramHandler<#input_ty> for #handler_ty {
147            type Write = #write_ty;
148
149            async fn read(
150                &mut self,
151                ctx: &mut ::rs_netty::DatagramContext<Self::Write>,
152                msg: #input_ty,
153            ) -> ::rs_netty::Result<()> {
154                #datagram_body
155            }
156        }
157    })
158}
159
160struct HandlerSignature {
161    takes_state: bool,
162    message_ty: Type,
163}
164
165fn handler_signature(function: &ItemFn) -> syn::Result<HandlerSignature> {
166    let mut inputs = function.sig.inputs.iter();
167    let Some(input) = inputs.next() else {
168        return Err(syn::Error::new_spanned(
169            &function.sig.ident,
170            "`#[handler]` functions must accept a message argument",
171        ));
172    };
173
174    let first = typed_input(input)?;
175    if is_mut_ref_type(first.ty.as_ref()) {
176        let Some(input) = inputs.next() else {
177            return Err(syn::Error::new_spanned(
178                first,
179                "`#[handler]` functions with a state argument must also accept a message argument",
180            ));
181        };
182        let message = typed_input(input)?;
183        if let Some(extra) = inputs.next() {
184            return Err(syn::Error::new_spanned(
185                extra,
186                "`#[handler]` functions can accept at most a state argument and a message argument",
187            ));
188        }
189
190        return Ok(HandlerSignature {
191            takes_state: true,
192            message_ty: (*message.ty).clone(),
193        });
194    }
195
196    if let Some(extra) = inputs.next() {
197        return Err(syn::Error::new_spanned(
198            extra,
199            "`#[handler]` functions can accept at most a state argument and a message argument",
200        ));
201    }
202
203    Ok(HandlerSignature {
204        takes_state: false,
205        message_ty: (*first.ty).clone(),
206    })
207}
208
209fn typed_input(input: &FnArg) -> syn::Result<&syn::PatType> {
210    match input {
211        FnArg::Typed(input) => Ok(input),
212        FnArg::Receiver(receiver) => Err(syn::Error::new_spanned(
213            receiver,
214            "`#[handler]` functions cannot take a self receiver",
215        )),
216    }
217}
218
219fn result_ok_type(output: &ReturnType) -> syn::Result<Type> {
220    let ReturnType::Type(_, ty) = output else {
221        return Err(syn::Error::new_spanned(
222            output,
223            "`#[handler]` functions must return Result<Write>",
224        ));
225    };
226
227    let Type::Path(type_path) = ty.as_ref() else {
228        return Err(syn::Error::new_spanned(
229            ty,
230            "`#[handler]` functions must return Result<Write>",
231        ));
232    };
233
234    let Some(segment) = type_path.path.segments.last() else {
235        return Err(syn::Error::new_spanned(
236            ty,
237            "`#[handler]` functions must return Result<Write>",
238        ));
239    };
240
241    if segment.ident != "Result" {
242        return Err(syn::Error::new_spanned(
243            ty,
244            "`#[handler]` functions must return Result<Write>",
245        ));
246    }
247
248    let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) =
249        &segment.arguments
250    else {
251        return Err(syn::Error::new_spanned(
252            ty,
253            "`#[handler]` functions must return Result<Write>",
254        ));
255    };
256
257    match args.first() {
258        Some(GenericArgument::Type(ok_ty)) => Ok(ok_ty.clone()),
259        _ => Err(syn::Error::new_spanned(
260            args.to_token_stream(),
261            "`#[handler]` functions must return Result<Write>",
262        )),
263    }
264}
265
266fn is_mut_ref_type(ty: &Type) -> bool {
267    matches!(ty, Type::Reference(reference) if reference.mutability.is_some())
268}
269
270fn is_unit_type(ty: &Type) -> bool {
271    matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty())
272}