1#![deny(unsafe_code)]
2
3use proc_macro::TokenStream;
4use quote::{quote, ToTokens};
5use syn::{
6 parse::{Parse, ParseStream},
7 parse_macro_input, AngleBracketedGenericArguments, FnArg, GenericArgument, Ident, ItemFn, Path,
8 PathArguments, ReturnType, Token, Type,
9};
10
11#[proc_macro_attribute]
31pub fn handler(attr: TokenStream, item: TokenStream) -> TokenStream {
32 let args = parse_macro_input!(attr as HandlerArgs);
33 let function = parse_macro_input!(item as ItemFn);
34
35 expand_handler(args, function)
36 .unwrap_or_else(syn::Error::into_compile_error)
37 .into()
38}
39
40struct HandlerArgs {
41 handler_ty: Path,
42 write_ty: Option<Type>,
43}
44
45impl Parse for HandlerArgs {
46 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
47 let handler_ty = input.parse::<Path>()?;
48 let mut write_ty = None;
49
50 if input.peek(Token![,]) {
51 input.parse::<Token![,]>()?;
52 let ident = input.parse::<Ident>()?;
53 if ident != "write" {
54 return Err(syn::Error::new_spanned(ident, "expected `write = Type`"));
55 }
56 input.parse::<Token![=]>()?;
57 write_ty = Some(input.parse::<Type>()?);
58 }
59
60 if !input.is_empty() {
61 return Err(input.error("unexpected tokens in `#[handler]`"));
62 }
63
64 Ok(Self {
65 handler_ty,
66 write_ty,
67 })
68 }
69}
70
71fn expand_handler(args: HandlerArgs, function: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
72 if function.sig.asyncness.is_none() {
73 return Err(syn::Error::new_spanned(
74 function.sig.fn_token,
75 "`#[handler]` can only be used on async functions",
76 ));
77 }
78
79 let handler_ty = args.handler_ty;
80 let signature = handler_signature(&function)?;
81 let input_ty = signature.message_ty;
82 let ok_ty = result_ok_type(&function.sig.output)?;
83 let writes_response = !is_unit_type(&ok_ty);
84 let write_ty = match (args.write_ty, writes_response) {
85 (Some(write_ty), false) => write_ty,
86 (Some(write_ty), true) => {
87 return Err(syn::Error::new_spanned(
88 write_ty,
89 "`write = Type` is only supported for handlers that return Result<()>",
90 ));
91 }
92 (None, true) => ok_ty.clone(),
93 (None, false) => {
94 return Err(syn::Error::new_spanned(
95 &function.sig.output,
96 "`#[handler]` functions that return Result<()> must specify `write = Type`",
97 ));
98 }
99 };
100 let fn_name = &function.sig.ident;
101 let call = if signature.takes_state {
102 quote! { #fn_name(self, msg).await? }
103 } else {
104 quote! { #fn_name(msg).await? }
105 };
106 let tcp_body = if writes_response {
107 quote! {
108 let msg = #call;
109 ctx.write(msg).await
110 }
111 } else {
112 quote! {
113 #call;
114 let _ = ctx;
115 Ok(())
116 }
117 };
118 let datagram_body = if writes_response {
119 quote! {
120 let msg = #call;
121 ctx.write(msg).await
122 }
123 } else {
124 quote! {
125 #call;
126 let _ = ctx;
127 Ok(())
128 }
129 };
130
131 Ok(quote! {
132 #function
133
134 impl ::rs_netty::Handler<#input_ty> for #handler_ty {
135 type Write = #write_ty;
136
137 async fn read(
138 &mut self,
139 ctx: &mut ::rs_netty::Context<Self::Write>,
140 msg: #input_ty,
141 ) -> ::rs_netty::Result<()> {
142 #tcp_body
143 }
144 }
145
146 impl ::rs_netty::DatagramHandler<#input_ty> for #handler_ty {
147 type Write = #write_ty;
148
149 async fn read(
150 &mut self,
151 ctx: &mut ::rs_netty::DatagramContext<Self::Write>,
152 msg: #input_ty,
153 ) -> ::rs_netty::Result<()> {
154 #datagram_body
155 }
156 }
157 })
158}
159
160struct HandlerSignature {
161 takes_state: bool,
162 message_ty: Type,
163}
164
165fn handler_signature(function: &ItemFn) -> syn::Result<HandlerSignature> {
166 let mut inputs = function.sig.inputs.iter();
167 let Some(input) = inputs.next() else {
168 return Err(syn::Error::new_spanned(
169 &function.sig.ident,
170 "`#[handler]` functions must accept a message argument",
171 ));
172 };
173
174 let first = typed_input(input)?;
175 if is_mut_ref_type(first.ty.as_ref()) {
176 let Some(input) = inputs.next() else {
177 return Err(syn::Error::new_spanned(
178 first,
179 "`#[handler]` functions with a state argument must also accept a message argument",
180 ));
181 };
182 let message = typed_input(input)?;
183 if let Some(extra) = inputs.next() {
184 return Err(syn::Error::new_spanned(
185 extra,
186 "`#[handler]` functions can accept at most a state argument and a message argument",
187 ));
188 }
189
190 return Ok(HandlerSignature {
191 takes_state: true,
192 message_ty: (*message.ty).clone(),
193 });
194 }
195
196 if let Some(extra) = inputs.next() {
197 return Err(syn::Error::new_spanned(
198 extra,
199 "`#[handler]` functions can accept at most a state argument and a message argument",
200 ));
201 }
202
203 Ok(HandlerSignature {
204 takes_state: false,
205 message_ty: (*first.ty).clone(),
206 })
207}
208
209fn typed_input(input: &FnArg) -> syn::Result<&syn::PatType> {
210 match input {
211 FnArg::Typed(input) => Ok(input),
212 FnArg::Receiver(receiver) => Err(syn::Error::new_spanned(
213 receiver,
214 "`#[handler]` functions cannot take a self receiver",
215 )),
216 }
217}
218
219fn result_ok_type(output: &ReturnType) -> syn::Result<Type> {
220 let ReturnType::Type(_, ty) = output else {
221 return Err(syn::Error::new_spanned(
222 output,
223 "`#[handler]` functions must return Result<Write>",
224 ));
225 };
226
227 let Type::Path(type_path) = ty.as_ref() else {
228 return Err(syn::Error::new_spanned(
229 ty,
230 "`#[handler]` functions must return Result<Write>",
231 ));
232 };
233
234 let Some(segment) = type_path.path.segments.last() else {
235 return Err(syn::Error::new_spanned(
236 ty,
237 "`#[handler]` functions must return Result<Write>",
238 ));
239 };
240
241 if segment.ident != "Result" {
242 return Err(syn::Error::new_spanned(
243 ty,
244 "`#[handler]` functions must return Result<Write>",
245 ));
246 }
247
248 let PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) =
249 &segment.arguments
250 else {
251 return Err(syn::Error::new_spanned(
252 ty,
253 "`#[handler]` functions must return Result<Write>",
254 ));
255 };
256
257 match args.first() {
258 Some(GenericArgument::Type(ok_ty)) => Ok(ok_ty.clone()),
259 _ => Err(syn::Error::new_spanned(
260 args.to_token_stream(),
261 "`#[handler]` functions must return Result<Write>",
262 )),
263 }
264}
265
266fn is_mut_ref_type(ty: &Type) -> bool {
267 matches!(ty, Type::Reference(reference) if reference.mutability.is_some())
268}
269
270fn is_unit_type(ty: &Type) -> bool {
271 matches!(ty, Type::Tuple(tuple) if tuple.elems.is_empty())
272}