1use proc_macro::TokenStream;
7use quote::{ToTokens, format_ident, quote};
8use syn::{FnArg, ItemFn, ReturnType, Token, parse_macro_input, punctuated::Punctuated};
9
10#[proc_macro_derive(Event, attributes(event))]
33pub fn derive_event(input: TokenStream) -> TokenStream {
34 let input = parse_macro_input!(input as syn::DeriveInput);
35 let name = input.ident;
36 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
37
38 let mut custom_tag = None;
39 for attr in &input.attrs {
40 if !attr.path().is_ident("event") {
41 continue;
42 }
43
44 let meta_list =
45 match attr.parse_args_with(Punctuated::<syn::Meta, Token![,]>::parse_terminated) {
46 Ok(list) => list,
47 Err(e) => return e.to_compile_error().into(),
48 };
49
50 for meta in meta_list {
51 match meta {
52 syn::Meta::NameValue(nv) if nv.path.is_ident("tag") => {
53 let lit_str =
54 match syn::parse2::<syn::LitStr>(nv.value.clone().into_token_stream()) {
55 Ok(lit) => lit,
56 Err(_) => {
57 let msg = "`tag` attribute must be a string literal";
58 return syn::Error::new_spanned(nv.value, msg)
59 .to_compile_error()
60 .into();
61 }
62 };
63
64 if custom_tag.is_some() {
65 let msg = "`tag` specified multiple times";
66 return syn::Error::new_spanned(nv, msg).to_compile_error().into();
67 }
68
69 custom_tag = Some(lit_str);
70 }
71 _ => {
72 let msg = "unknown attribute parameter, expected `tag = \"...\"`";
73 return syn::Error::new_spanned(meta, msg).to_compile_error().into();
74 }
75 }
76 }
77 }
78
79 let tag_expr = if let Some(lit) = custom_tag {
80 quote! { #lit }
81 } else {
82 quote! { concat!(module_path!(), "::", stringify!(#name)) }
83 };
84
85 let expanded = quote! {
86 impl #impl_generics ::ioevent::event::Event for #name #ty_generics #where_clause {
87 const TAG: &'static str = #tag_expr;
88 }
89
90 impl #impl_generics TryFrom<&::ioevent::event::EventData> for #name #ty_generics #where_clause {
91 type Error = ::ioevent::error::TryFromEventError;
92 fn try_from(value: &::ioevent::event::EventData) -> ::core::result::Result<Self, Self::Error> {
93 ::core::result::Result::Ok(value.payload.deserialized()?)
94 }
95 }
96 };
97
98 TokenStream::from(expanded)
99}
100
101#[proc_macro_attribute]
119pub fn subscriber(_attr: TokenStream, item: TokenStream) -> TokenStream {
120 let original_fn = parse_macro_input!(item as ItemFn);
121
122 if original_fn.sig.asyncness.is_none() {
123 return quote! { compile_error!("subscriber macro can only be applied to async functions"); }.into();
124 }
125
126 let params = original_fn.sig.inputs.iter().collect::<Vec<_>>();
127 let (state_param, event_param) = match params.len() {
128 1 => (None, params[0]),
129 2 => (Some(params[0]), params[1]),
130 _ => panic!("Expected 1 or 2 parameters"),
131 };
132
133 let (event_ty, event_name) = match event_param {
134 FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
135 _ => panic!("Event parameter must be a typed parameter"),
136 };
137
138 let state_ty_name = state_param.map(|param| match param {
139 FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
140 _ => panic!("State parameter must be a typed parameter"),
141 });
142
143 let raw_generics = &original_fn.sig.generics.type_params().map(|v|v.clone()).collect::<Vec<_>>();
144
145 let (generics, new_params) = if let Some((state_ty, state_name)) = state_ty_name {
146 let params = quote! {
147 #state_name: &#state_ty,
148 #event_name: &::ioevent::event::EventData
149 };
150 (quote! { <#(#raw_generics),*> }, params)
151 } else {
152 let params = quote! {
153 _state: &::ioevent::state::State<_STATE>,
154 #event_name: &::ioevent::event::EventData
155 };
156 (quote! { <#(#raw_generics),* _STATE> }, params)
157 };
158
159 let event_try_into = quote! {
160 let #event_name: ::core::result::Result<#event_ty, ::ioevent::error::TryFromEventError> = ::std::convert::TryInto::try_into(#event_name);
161 };
162
163 let state_clone = if let Some((_, state_name)) = state_ty_name {
164 quote! {
165 let #state_name = ::std::clone::Clone::clone(#state_name);
166 }
167 } else {
168 quote! {}
169 };
170
171 let return_expr = if matches!(original_fn.sig.output, ReturnType::Default) {
172 Some(quote! { Ok(()) })
173 } else {
174 None
175 };
176
177 let original_stmts = &original_fn.block.stmts;
178
179 let async_block = quote! {
180 async move {
181 let #event_name = #event_name?;
182 #(#original_stmts)*
183 #return_expr
184 }
185 };
186
187 let func_name = &original_fn.sig.ident;
188
189 let mod_name = format_ident!("{}", func_name);
190
191 let vis = &original_fn.vis;
192
193 let mod_block = quote! {
194 #[doc(hidden)]
195 #vis mod #mod_name {
196 use super::*;
197 pub type _Event = #event_ty;
198 }
199 };
200
201 let expanded = quote! {
202 #vis fn #func_name #generics (#new_params) -> ::ioevent::future::SubscribeFutureRet {
203 #event_try_into
204 #state_clone
205 ::std::boxed::Box::pin(#async_block)
206 }
207 #mod_block
208 };
209
210 TokenStream::from(expanded)
211}
212
213#[proc_macro_derive(ProcedureCall, attributes(procedure))]
237pub fn derive_procedure_call(input: TokenStream) -> TokenStream {
238 let input = parse_macro_input!(input as syn::DeriveInput);
239 let name = input.ident;
240 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
241
242 let mut custom_path = None;
243 for attr in &input.attrs {
244 if !attr.path().is_ident("procedure") {
245 continue;
246 }
247
248 let meta_list = match attr.parse_args_with(Punctuated::<syn::Meta, Token![,]>::parse_terminated) {
249 Ok(list) => list,
250 Err(e) => return e.to_compile_error().into(),
251 };
252
253 for meta in meta_list {
254 match meta {
255 syn::Meta::NameValue(nv) if nv.path.is_ident("path") => {
256 let lit_str = match syn::parse2::<syn::LitStr>(nv.value.clone().into_token_stream()) {
257 Ok(lit) => lit,
258 Err(_) => {
259 let msg = "`path` attribute must be a string literal";
260 return syn::Error::new_spanned(nv.value, msg)
261 .to_compile_error()
262 .into();
263 }
264 };
265
266 if custom_path.is_some() {
267 let msg = "`path` specified multiple times";
268 return syn::Error::new_spanned(nv, msg).to_compile_error().into();
269 }
270
271 custom_path = Some(lit_str);
272 }
273 _ => {
274 let msg = "unknown attribute parameter, expected `path = \"...\"`";
275 return syn::Error::new_spanned(meta, msg).to_compile_error().into();
276 }
277 }
278 }
279 }
280
281 let path_expr = if let Some(lit) = custom_path {
282 quote! { #lit }
283 } else {
284 quote! { concat!(module_path!(), "::", stringify!(#name)) }
285 };
286
287 let expanded = quote! {
288 impl #impl_generics ::ioevent::state::ProcedureCall for #name #ty_generics #where_clause {
289 fn path() -> String {
290 #path_expr.to_owned()
291 }
292 }
293
294 impl #impl_generics TryFrom<::ioevent::state::ProcedureCallData> for #name #ty_generics #where_clause {
295 type Error = ::ioevent::error::TryFromEventError;
296 fn try_from(value: ::ioevent::state::ProcedureCallData) -> ::core::result::Result<Self, Self::Error> {
297 ::core::result::Result::Ok(value.payload.deserialized()?)
298 }
299 }
300 };
301
302 TokenStream::from(expanded)
303}
304
305#[proc_macro_attribute]
322pub fn procedure(_attr: TokenStream, item: TokenStream) -> TokenStream {
323 let original_fn = parse_macro_input!(item as ItemFn);
324
325 if original_fn.sig.asyncness.is_none() {
326 return quote! { compile_error!("procedure macro can only be applied to async functions"); }.into();
327 }
328
329 let params = original_fn.sig.inputs.iter().collect::<Vec<_>>();
330 let (state_param, event_param) = match params.len() {
331 1 => (None, params[0]),
332 2 => (Some(params[0]), params[1]),
333 _ => panic!("Expected 1 or 2 parameters"),
334 };
335
336 let (event_ty, event_name) = match event_param {
337 FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
338 _ => panic!("Event parameter must be a typed parameter"),
339 };
340
341 let state_ty_name = state_param.map(|param| match param {
342 FnArg::Typed(pat_type) => (&pat_type.ty, &pat_type.pat),
343 _ => panic!("State parameter must be a typed parameter"),
344 });
345
346 let raw_generics = &original_fn.sig.generics.type_params().map(|v|v.clone()).collect::<Vec<_>>();
347
348 let (generics, new_params) = if let Some((state_ty, state_name)) = state_ty_name {
349 let params = quote! {
350 #state_name: &#state_ty,
351 #event_name: &::ioevent::event::EventData
352 };
353 (quote! { <#(#raw_generics),*> }, params)
354 } else {
355 let params = quote! {
356 _state: &::ioevent::state::State<_STATE>,
357 #event_name: &::ioevent::event::EventData
358 };
359 (quote! { <#(#raw_generics),* _STATE: ::ioevent::state::ProcedureCallWright + ::std::clone::Clone + ::std::marker::Send + ::std::marker::Sync + 'static> }, params)
360 };
361
362 let event_try_into = quote! {
363 let #event_name: ::core::result::Result<::ioevent::state::ProcedureCallData, ::ioevent::error::TryFromEventError> = ::std::convert::TryInto::try_into(#event_name);
364 };
365
366 let state_clone = if let Some((_, state_name)) = state_ty_name {
367 quote! {
368 let #state_name = ::std::clone::Clone::clone(#state_name);
369 }
370 } else {
371 quote! {
372 let _state = ::std::clone::Clone::clone(_state);
373 }
374 };
375
376 let original_stmts = &original_fn.block.stmts;
377
378 let async_block = if let Some((_, state_name)) = state_ty_name {
379 quote! {
380 async move {
381 let #event_name = #event_name?;
382 if <#event_ty as ::ioevent::state::ProcedureCallRequest>::match_self(&#event_name) {
383 let echo = #event_name.echo;
384 let #event_name = <#event_ty as ::std::convert::TryFrom<::ioevent::state::ProcedureCallData>>::try_from(#event_name)?;
385 let response: ::core::result::Result<_, ::ioevent::error::CallSubscribeError> = {
386 #(#original_stmts)*
387 };
388 ::ioevent::state::ProcedureCallExt::resolve::<#event_ty>(&#state_name, echo, &response?).await?;
389 }
390 Ok(())
391 }
392 }
393 } else {
394 quote! {
395 async move {
396 let #event_name = #event_name?;
397 if <#event_ty as ::ioevent::state::ProcedureCallRequest>::match_self(&#event_name) {
398 let echo = #event_name.echo;
399 let #event_name = <#event_ty as ::std::convert::TryFrom<::ioevent::state::ProcedureCallData>>::try_from(#event_name)?;
400 let response: ::core::result::Result<_, ::ioevent::error::CallSubscribeError> = {
401 #(#original_stmts)*
402 };
403 ::ioevent::state::ProcedureCallExt::resolve::<#event_ty>(&_state, echo, &response?).await?;
404 }
405 Ok(())
406 }
407 }
408 };
409
410 let func_name = &original_fn.sig.ident;
411 let mod_name = format_ident!("{}", func_name);
412
413 let vis = &original_fn.vis;
414
415 let mod_block = quote! {
416 #[doc(hidden)]
417 #vis mod #mod_name {
418 use super::*;
419 pub type _Event = ::ioevent::state::ProcedureCallData;
420 }
421 };
422
423 let expanded = quote! {
424 #vis fn #func_name #generics (#new_params) -> ::ioevent::future::SubscribeFutureRet {
425 #event_try_into
426 #state_clone
427 ::std::boxed::Box::pin(#async_block)
428 }
429 #mod_block
430 };
431
432 TokenStream::from(expanded)
433}