archy_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, DeriveInput, Fields, ItemImpl, ImplItem, FnArg, ReturnType, Pat};
4
5/// Derive macro for Service structs - generates ServiceFactory implementation
6///
7/// ```ignore
8/// #[derive(Service)]
9/// struct PaymentService {
10///     config: Res<Config>,
11///     orders: Client<OrderService>,
12/// }
13/// ```
14///
15/// Generates:
16/// ```ignore
17/// impl ::archy::ServiceFactory for PaymentService {
18///     fn create(app: &::archy::App) -> Self {
19///         PaymentService {
20///             config: app.extract(),
21///             orders: app.extract(),
22///         }
23///     }
24/// }
25/// ```
26#[proc_macro_derive(Service)]
27pub fn derive_service(input: TokenStream) -> TokenStream {
28    let input = parse_macro_input!(input as DeriveInput);
29    let name = &input.ident;
30
31    let fields = match &input.data {
32        syn::Data::Struct(data) => match &data.fields {
33            Fields::Named(fields) => &fields.named,
34            _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs with named fields")
35                .to_compile_error()
36                .into(),
37        },
38        _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs")
39            .to_compile_error()
40            .into(),
41    };
42
43    let field_inits = fields.iter().map(|f| {
44        let field_name = f.ident.as_ref().unwrap();
45        quote! { #field_name: app.extract() }
46    });
47
48    let expanded = quote! {
49        impl ::archy::ServiceFactory for #name {
50            fn create(app: &::archy::App) -> Self {
51                #name {
52                    #(#field_inits),*
53                }
54            }
55        }
56    };
57
58    TokenStream::from(expanded)
59}
60
61/// Attribute macro for Service impl blocks - generates message enum, Service impl, and Client methods
62///
63/// ```ignore
64/// #[service]
65/// impl PaymentService {
66///     pub async fn process(&self, amount: u32) -> String {
67///         // ...
68///     }
69/// }
70/// ```
71///
72/// Generates:
73/// - Message enum with variants for each public async method
74/// - impl Service for PaymentService
75/// - Client extension trait + impl
76#[proc_macro_attribute]
77pub fn service(_attr: TokenStream, item: TokenStream) -> TokenStream {
78    let input = parse_macro_input!(item as ItemImpl);
79    let service_name = match &*input.self_ty {
80        syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
81        _ => return syn::Error::new_spanned(&input.self_ty, "#[service] must be applied to an impl block for a named type")
82            .to_compile_error()
83            .into(),
84    };
85
86    let msg_enum_name = format_ident!("{}Msg", service_name);
87    let client_trait_name = format_ident!("{}Client", service_name);
88
89    // Collect public async methods
90    let mut methods = Vec::new();
91    for item in &input.items {
92        if let ImplItem::Fn(method) = item {
93            let is_pub = matches!(method.vis, syn::Visibility::Public(_));
94            let is_async = method.sig.asyncness.is_some();
95            let has_self = method.sig.inputs.first().map_or(false, |arg| matches!(arg, FnArg::Receiver(_)));
96
97            if is_pub && is_async && has_self {
98                methods.push(method);
99            }
100        }
101    }
102
103    // Generate message enum variants
104    let msg_variants = methods.iter().map(|method| {
105        let method_name = &method.sig.ident;
106        let variant_name = to_pascal_case(&method_name.to_string());
107        let variant_ident = format_ident!("{}", variant_name);
108
109        // Get parameters (skip &self)
110        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
111            if let FnArg::Typed(pat_type) = arg {
112                if let Pat::Ident(pat_ident) = &*pat_type.pat {
113                    let name = &pat_ident.ident;
114                    let ty = &pat_type.ty;
115                    return Some(quote! { #name: #ty });
116                }
117            }
118            None
119        }).collect();
120
121        // Fire-and-forget optimization: skip respond field for unit returns
122        if is_unit_return(&method.sig.output) {
123            quote! {
124                #variant_ident { #(#params),* }
125            }
126        } else {
127            let return_type = match &method.sig.output {
128                ReturnType::Default => quote! { () },
129                ReturnType::Type(_, ty) => quote! { #ty },
130            };
131            quote! {
132                #variant_ident { #(#params,)* respond: ::archy::tokio::sync::oneshot::Sender<#return_type> }
133            }
134        }
135    });
136
137    // Generate handle match arms
138    let match_arms = methods.iter().map(|method| {
139        let method_name = &method.sig.ident;
140        let variant_name = to_pascal_case(&method_name.to_string());
141        let variant_ident = format_ident!("{}", variant_name);
142
143        // Get parameter names (skip &self)
144        let param_names: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
145            if let FnArg::Typed(pat_type) = arg {
146                if let Pat::Ident(pat_ident) = &*pat_type.pat {
147                    return Some(&pat_ident.ident);
148                }
149            }
150            None
151        }).collect();
152
153        let method_call = if param_names.is_empty() {
154            quote! { self.#method_name().await }
155        } else {
156            quote! { self.#method_name(#(#param_names),*).await }
157        };
158
159        // Fire-and-forget optimization: no respond for unit returns
160        if is_unit_return(&method.sig.output) {
161            let param_pattern = if param_names.is_empty() {
162                quote! {}
163            } else {
164                quote! { #(#param_names),* }
165            };
166            quote! {
167                #msg_enum_name::#variant_ident { #param_pattern } => {
168                    #method_call;
169                }
170            }
171        } else {
172            let param_pattern = if param_names.is_empty() {
173                quote! { respond }
174            } else {
175                quote! { #(#param_names,)* respond }
176            };
177            quote! {
178                #msg_enum_name::#variant_ident { #param_pattern } => {
179                    let result = #method_call;
180                    let _ = respond.send(result);
181                }
182            }
183        }
184    });
185
186    // Generate client trait methods (using async fn for cleaner syntax)
187    let client_trait_methods = methods.iter().map(|method| {
188        let method_name = &method.sig.ident;
189
190        // Get parameters with types (skip &self)
191        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
192            if let FnArg::Typed(pat_type) = arg {
193                if let Pat::Ident(pat_ident) = &*pat_type.pat {
194                    let name = &pat_ident.ident;
195                    let ty = &pat_type.ty;
196                    return Some(quote! { #name: #ty });
197                }
198            }
199            None
200        }).collect();
201
202        // Get return type wrapped in Result
203        let return_type = match &method.sig.output {
204            ReturnType::Default => quote! { () },
205            ReturnType::Type(_, ty) => quote! { #ty },
206        };
207
208        quote! {
209            async fn #method_name(&self, #(#params),*) -> ::std::result::Result<#return_type, ::archy::ServiceError>;
210        }
211    });
212
213    // Generate client trait impl methods
214    let client_impl_methods = methods.iter().map(|method| {
215        let method_name = &method.sig.ident;
216        let variant_name = to_pascal_case(&method_name.to_string());
217        let variant_ident = format_ident!("{}", variant_name);
218
219        // Get parameters with types (skip &self)
220        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
221            if let FnArg::Typed(pat_type) = arg {
222                if let Pat::Ident(pat_ident) = &*pat_type.pat {
223                    let name = &pat_ident.ident;
224                    let ty = &pat_type.ty;
225                    return Some((name.clone(), quote! { #ty }));
226                }
227            }
228            None
229        }).collect();
230
231        let param_decls: Vec<_> = params.iter().map(|(name, ty)| quote! { #name: #ty }).collect();
232        let param_names: Vec<_> = params.iter().map(|(name, _)| name).collect();
233
234        // Get return type
235        let return_type = match &method.sig.output {
236            ReturnType::Default => quote! { () },
237            ReturnType::Type(_, ty) => quote! { #ty },
238        };
239
240        // Fire-and-forget optimization for unit returns
241        if is_unit_return(&method.sig.output) {
242            let msg_construction = if param_names.is_empty() {
243                quote! { #msg_enum_name::#variant_ident {} }
244            } else {
245                quote! { #msg_enum_name::#variant_ident { #(#param_names),* } }
246            };
247
248            quote! {
249                async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
250                    self.sender.send(#msg_construction).await
251                        .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
252                    Ok(())
253                }
254            }
255        } else {
256            let msg_construction = if param_names.is_empty() {
257                quote! { #msg_enum_name::#variant_ident { respond: tx } }
258            } else {
259                quote! { #msg_enum_name::#variant_ident { #(#param_names,)* respond: tx } }
260            };
261
262            quote! {
263                async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
264                    let (tx, rx) = ::archy::tokio::sync::oneshot::channel();
265                    self.sender.send(#msg_construction).await
266                        .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
267                    rx.await.map_err(|_| ::archy::ServiceError::ServiceDropped)
268                }
269            }
270        }
271    });
272
273    let expanded = quote! {
274        // Original impl block (preserved)
275        #input
276
277        // Generated message enum
278        pub enum #msg_enum_name {
279            #(#msg_variants),*
280        }
281
282        // Generated Service implementation
283        impl ::archy::Service for #service_name {
284            type Message = #msg_enum_name;
285
286            fn create(app: &::archy::App) -> Self {
287                <Self as ::archy::ServiceFactory>::create(app)
288            }
289
290            fn handle(self: ::std::sync::Arc<Self>, msg: Self::Message) -> impl ::std::future::Future<Output = ()> + Send {
291                async move {
292                    match msg {
293                        #(#match_arms)*
294                    }
295                }
296            }
297        }
298
299        // Generated client trait
300        #[allow(async_fn_in_trait)]
301        pub trait #client_trait_name {
302            #(#client_trait_methods)*
303        }
304
305        // Generated client impl
306        impl #client_trait_name for ::archy::Client<#service_name> {
307            #(#client_impl_methods)*
308        }
309    };
310
311    TokenStream::from(expanded)
312}
313
/// Converts a `snake_case` identifier into `PascalCase` for use as an enum variant name.
///
/// Underscores are dropped. The first character, and each character that follows an underscore,
/// is ASCII-uppercased; every other character is copied unchanged. So `get_user_by_id` becomes
/// `GetUserById`.
///
/// A leading `r#` raw-identifier prefix is stripped first. `Ident::to_string()` produces this
/// prefix for names like `r#match`, and `#` cannot appear in the generated identifier.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}
329
330/// Check if a return type is unit () - used for fire-and-forget optimization
331fn is_unit_return(output: &ReturnType) -> bool {
332    match output {
333        ReturnType::Default => true,
334        ReturnType::Type(_, ty) => {
335            if let syn::Type::Tuple(tuple) = &**ty {
336                tuple.elems.is_empty()
337            } else {
338                false
339            }
340        }
341    }
342}