ioevent_macro/
lib.rs

1//! Procedural macros for the I/O event system.
2//!
3//! This crate provides procedural macros for deriving event types and creating event subscribers.
4//! It is part of the `ioevent` ecosystem and should be used as a dependency of `ioevent`.
5
6use proc_macro::TokenStream;
7use quote::{ToTokens, format_ident, quote};
8use syn::{FnArg, ItemFn, ReturnType, Token, parse_macro_input, punctuated::Punctuated};
9
10/// Derives the `Event` trait for a type.
11///
12/// This macro implements the `Event` trait for a type, allowing it to participate in the event system.
13/// It provides serialization and deserialization capabilities for the type.
14///
15/// # Attributes
16///
17/// * `#[event(tag = "custom_tag")]` - Specifies a custom tag for the event type.
18///   If not provided, the tag will be generated from the module path and type name.
19///   
20/// # Requires
21///
22/// * The type must implement the `Serialize` and `Deserialize` traits from the `serde` crate.
23///
24/// # Examples
25///
26/// ```rust
27/// #[derive(Event)]
28/// struct MyEvent {
29///     field: String,
30/// }
31/// ```
32#[proc_macro_derive(Event, attributes(event))]
33pub fn derive_event(input: TokenStream) -> TokenStream {
34    let input = parse_macro_input!(input as syn::DeriveInput);
35    let name = input.ident;
36    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
37
38    let mut custom_tag = None;
39    for attr in &input.attrs {
40        if !attr.path().is_ident("event") {
41            continue;
42        }
43
44        let meta_list =
45            match attr.parse_args_with(Punctuated::<syn::Meta, Token![,]>::parse_terminated) {
46                Ok(list) => list,
47                Err(e) => return e.to_compile_error().into(),
48            };
49
50        for meta in meta_list {
51            match meta {
52                syn::Meta::NameValue(nv) if nv.path.is_ident("tag") => {
53                    let lit_str =
54                        match syn::parse2::<syn::LitStr>(nv.value.clone().into_token_stream()) {
55                            Ok(lit) => lit,
56                            Err(_) => {
57                                let msg = "`tag` attribute must be a string literal";
58                                return syn::Error::new_spanned(nv.value, msg)
59                                    .to_compile_error()
60                                    .into();
61                            }
62                        };
63
64                    if custom_tag.is_some() {
65                        let msg = "`tag` specified multiple times";
66                        return syn::Error::new_spanned(nv, msg).to_compile_error().into();
67                    }
68
69                    custom_tag = Some(lit_str);
70                }
71                _ => {
72                    let msg = "unknown attribute parameter, expected `tag = \"...\"`";
73                    return syn::Error::new_spanned(meta, msg).to_compile_error().into();
74                }
75            }
76        }
77    }
78
79    let tag_expr = if let Some(lit) = custom_tag {
80        quote! { #lit }
81    } else {
82        quote! { concat!(module_path!(), "::", stringify!(#name)) }
83    };
84
85    let expanded = quote! {
86        impl #impl_generics ::ioevent::event::Event for #name #ty_generics #where_clause {
87            const TAG: &'static str = #tag_expr;
88        }
89
90        impl #impl_generics TryFrom<&::ioevent::event::EventData> for #name #ty_generics #where_clause {
91            type Error = ::ioevent::error::TryFromEventError;
92            fn try_from(value: &::ioevent::event::EventData) -> ::core::result::Result<Self, Self::Error> {
93                ::core::result::Result::Ok(value.payload.deserialized()?)
94            }
95        }
96    };
97
98    TokenStream::from(expanded)
99}
100
101/// Creates an event subscriber from an async function.
102///
103/// This macro transforms an async function into an event subscriber that can be registered
104/// with the event system. The function must take either one or two parameters:
105/// * A state parameter (optional)
106/// * An event parameter
107/// * A return value of type `Result` (optional)
108///
109/// # Examples
110///
111/// ```rust
112/// #[subscriber]
113/// async fn handle_event(event: MyEvent) -> Result {
114///     // Handle the event
115///     Ok(())
116/// }
117/// ```
118#[proc_macro_attribute]
119pub fn subscriber(_attr: TokenStream, item: TokenStream) -> TokenStream {
120    let original_fn = parse_macro_input!(item as ItemFn);
121
122    if original_fn.sig.asyncness.is_none() {
123        return quote! { compile_error!("subscriber macro can only be applied to async functions"); }.into();
124    }
125
126    let params = original_fn.sig.inputs.iter().collect::<Vec<_>>();
127    let (state_param, event_param) = match params.len() {
128        1 => (None, params[0]),
129        2 => (Some(params[0]), params[1]),
130        _ => panic!("Expected 1 or 2 parameters"),
131    };
132
133    let (event_ty, event_name) = match event_param {
134        FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
135        _ => panic!("Event parameter must be a typed parameter"),
136    };
137
138    let state_ty_name = state_param.map(|param| match param {
139        FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
140        _ => panic!("State parameter must be a typed parameter"),
141    });
142    
143    let raw_generics = &original_fn.sig.generics.type_params().map(|v|v.clone()).collect::<Vec<_>>();
144
145    let (generics, new_params) = if let Some((state_ty, state_name)) = state_ty_name {
146        let params = quote! {
147            #state_name: &#state_ty,
148            #event_name: &::ioevent::event::EventData
149        };
150        (quote! { <#(#raw_generics),*> }, params)
151    } else {
152        let params = quote! {
153            _state: &::ioevent::state::State<_STATE>,
154            #event_name: &::ioevent::event::EventData
155        };
156        (quote! { <#(#raw_generics),* _STATE> }, params)
157    };
158
159    let event_try_into = quote! {
160        let #event_name: ::core::result::Result<#event_ty, ::ioevent::error::TryFromEventError> = ::std::convert::TryInto::try_into(#event_name);
161    };
162
163    let state_clone = if let Some((_, state_name)) = state_ty_name {
164        quote! {
165            let #state_name = ::std::clone::Clone::clone(#state_name);
166        }
167    } else {
168        quote! {}
169    };
170
171    let return_expr = if matches!(original_fn.sig.output, ReturnType::Default) {
172        Some(quote! { Ok(()) })
173    } else {
174        None
175    };
176
177    let original_stmts = &original_fn.block.stmts;
178
179    let async_block = quote! {
180        async move {
181            let #event_name = #event_name?;
182            #(#original_stmts)*
183            #return_expr
184        }
185    };
186
187    let func_name = &original_fn.sig.ident;
188
189    let mod_name = format_ident!("{}", func_name);
190
191    let vis = &original_fn.vis;
192
193    let mod_block = quote! {
194        #[doc(hidden)]
195        #vis mod #mod_name {
196            use super::*;
197            pub type _Event = #event_ty;
198        }
199    };
200
201    let expanded = quote! {
202        #vis fn #func_name #generics (#new_params) -> ::ioevent::future::SubscribeFutureRet {
203            #event_try_into
204            #state_clone
205            ::std::boxed::Box::pin(#async_block)
206        }
207        #mod_block
208    };
209
210    TokenStream::from(expanded)
211}
212
213/// Derives the `ProcedureCall` trait for a type.
214///
215/// This macro implements the `ProcedureCall` trait for a type, allowing it to be used
216/// in remote procedure calls. It provides serialization and deserialization capabilities
217/// for the type.
218///
219/// # Attributes
220///
221/// * `#[procedure(path = "custom_path")]` - Specifies a custom path for the procedure.
222///   If not provided, the path will be generated from the module path and type name.
223///   
224/// # Requires
225///
226/// * The type must implement the `Serialize` and `Deserialize` traits from the `serde` crate.
227///
228/// # Examples
229///
230/// ```rust
231/// #[derive(ProcedureCall)]
232/// struct MyProcedure {
233///     field: String,
234/// }
235/// ```
236#[proc_macro_derive(ProcedureCall, attributes(procedure))]
237pub fn derive_procedure_call(input: TokenStream) -> TokenStream {
238    let input = parse_macro_input!(input as syn::DeriveInput);
239    let name = input.ident;
240    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
241
242    let mut custom_path = None;
243    for attr in &input.attrs {
244        if !attr.path().is_ident("procedure") {
245            continue;
246        }
247
248        let meta_list = match attr.parse_args_with(Punctuated::<syn::Meta, Token![,]>::parse_terminated) {
249            Ok(list) => list,
250            Err(e) => return e.to_compile_error().into(),
251        };
252
253        for meta in meta_list {
254            match meta {
255                syn::Meta::NameValue(nv) if nv.path.is_ident("path") => {
256                    let lit_str = match syn::parse2::<syn::LitStr>(nv.value.clone().into_token_stream()) {
257                        Ok(lit) => lit,
258                        Err(_) => {
259                            let msg = "`path` attribute must be a string literal";
260                            return syn::Error::new_spanned(nv.value, msg)
261                                .to_compile_error()
262                                .into();
263                        }
264                    };
265
266                    if custom_path.is_some() {
267                        let msg = "`path` specified multiple times";
268                        return syn::Error::new_spanned(nv, msg).to_compile_error().into();
269                    }
270
271                    custom_path = Some(lit_str);
272                }
273                _ => {
274                    let msg = "unknown attribute parameter, expected `path = \"...\"`";
275                    return syn::Error::new_spanned(meta, msg).to_compile_error().into();
276                }
277            }
278        }
279    }
280
281    let path_expr = if let Some(lit) = custom_path {
282        quote! { #lit }
283    } else {
284        quote! { concat!(module_path!(), "::", stringify!(#name)) }
285    };
286
287    let expanded = quote! {
288        impl #impl_generics ::ioevent::state::ProcedureCall for #name #ty_generics #where_clause {
289            fn path() -> String {
290                #path_expr.to_owned()
291            }
292        }
293
294        impl #impl_generics TryFrom<::ioevent::state::ProcedureCallData> for #name #ty_generics #where_clause {
295            type Error = ::ioevent::error::TryFromEventError;
296            fn try_from(value: ::ioevent::state::ProcedureCallData) -> ::core::result::Result<Self, Self::Error> {
297                ::core::result::Result::Ok(value.payload.deserialized()?)
298            }
299        }
300    };
301
302    TokenStream::from(expanded)
303}
304
305/// Creates a procedure handler from an async function.
306///
307/// This macro transforms an async function into a procedure handler that can be registered
308/// with the procedure call system. The function must take either one or two parameters:
309/// * A state parameter (optional)
310/// * A procedure parameter
311///
312/// # Examples
313///
314/// ```rust
315/// #[procedure]
316/// async fn handle_procedure(proc: MyProcedureRequest) -> Result {
317///     // Handle the procedure
318///     Ok(MyProcedureResponse)
319/// }
320/// ```
321#[proc_macro_attribute]
322pub fn procedure(_attr: TokenStream, item: TokenStream) -> TokenStream {
323    let original_fn = parse_macro_input!(item as ItemFn);
324
325    if original_fn.sig.asyncness.is_none() {
326        return quote! { compile_error!("procedure macro can only be applied to async functions"); }.into();
327    }
328    
329    let params = original_fn.sig.inputs.iter().collect::<Vec<_>>();
330    let (state_param, event_param) = match params.len() {
331        1 => (None, params[0]),
332        2 => (Some(params[0]), params[1]),
333        _ => panic!("Expected 1 or 2 parameters"),
334    };
335
336    let (event_ty, event_name) = match event_param {
337        FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
338        _ => panic!("Event parameter must be a typed parameter"),
339    };
340
341    let state_ty_name = state_param.map(|param| match param {
342        FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
343        _ => panic!("State parameter must be a typed parameter"),
344    });
345    
346    let raw_generics = &original_fn.sig.generics.type_params().map(|v|v.clone()).collect::<Vec<_>>();
347
348    let (generics, new_params) = if let Some((state_ty, state_name)) = state_ty_name {
349        let params = quote! {
350            #state_name: &#state_ty,
351            #event_name: &::ioevent::event::EventData
352        };
353        (quote! { <#(#raw_generics),*> }, params)
354    } else {
355        let params = quote! {
356            _state: &::ioevent::state::State<_STATE>,
357            #event_name: &::ioevent::event::EventData
358        };
359        (quote! { <#(#raw_generics),* _STATE: ::ioevent::state::ProcedureCallWright + ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static> }, params)
360    };
361
362    let event_try_into = quote! {
363        let #event_name: ::core::result::Result<::ioevent::state::ProcedureCallData, ::ioevent::error::TryFromEventError> = ::std::convert::TryInto::try_into(#event_name);
364    };
365
366    let state_clone = if let Some((_, state_name)) = state_ty_name {
367        quote! {
368            let #state_name = ::std::clone::Clone::clone(#state_name);
369        }
370    } else {
371        quote! {
372            let _state = ::std::clone::Clone::clone(_state);
373        }
374    };
375
376    let original_stmts = &original_fn.block.stmts;
377
378    let async_block = if let Some((_, state_name)) = state_ty_name {
379        quote! {
380            async move {
381                let #event_name = #event_name?;
382                if <#event_ty as ::ioevent::state::ProcedureCallRequest>::match_self(&#event_name) {
383                    let echo = #event_name.echo;
384                    let #event_name = <#event_ty as ::std::convert::TryFrom<::ioevent::state::ProcedureCallData>>::try_from(#event_name)?;
385                    let response: ::core::result::Result<_, ::ioevent::error::CallSubscribeError> = {
386                        #(#original_stmts)*
387                    };
388                    ::ioevent::state::ProcedureCallExt::resolve::<#event_ty>(&#state_name, echo, &response?).await?;
389                }
390                Ok(())
391            }
392        }
393    } else {
394        quote! {
395            async move {
396                let #event_name = #event_name?;
397                if <#event_ty as ::ioevent::state::ProcedureCallRequest>::match_self(&#event_name) {
398                    let echo = #event_name.echo;
399                    let #event_name = <#event_ty as ::std::convert::TryFrom<::ioevent::state::ProcedureCallData>>::try_from(#event_name)?;
400                    let response: ::core::result::Result<_, ::ioevent::error::CallSubscribeError> = {
401                        #(#original_stmts)*
402                    };
403                    ::ioevent::state::ProcedureCallExt::resolve::<#event_ty>(&_state, echo, &response?).await?;
404                }
405                Ok(())
406            }
407        }
408    };
409
410    let func_name = &original_fn.sig.ident;
411    let mod_name = format_ident!("{}", func_name);
412
413    let vis = &original_fn.vis;
414
415    let mod_block = quote! {
416        #[doc(hidden)]
417        #vis mod #mod_name {
418            use super::*;
419            pub type _Event = ::ioevent::state::ProcedureCallData;
420        }
421    };
422
423    let expanded = quote! {
424        #vis fn #func_name #generics (#new_params) -> ::ioevent::future::SubscribeFutureRet {
425            #event_try_into
426            #state_clone
427            ::std::boxed::Box::pin(#async_block)
428        }
429        #mod_block
430    };
431
432    TokenStream::from(expanded)
433}