// archy_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, DeriveInput, Fields, ItemImpl, ImplItem, ImplItemFn, FnArg, ReturnType, Pat};
4
5/// Marker attribute to opt a method into span propagation.
6/// Use on methods within a `#[service]` impl block.
7///
8/// ```ignore
9/// #[service]
10/// impl MyService {
11///     #[traced]
12///     pub async fn important_operation(&self) -> String {
13///         // This method will propagate the caller's tracing span
14///     }
15/// }
16/// ```
17#[proc_macro_attribute]
18pub fn traced(_attr: TokenStream, item: TokenStream) -> TokenStream {
19    // This is just a marker - the actual processing happens in #[service]
20    item
21}
22
23/// Marker attribute to opt a method out of span propagation.
24/// Use on methods within a `#[service(traced)]` impl block.
25///
26/// ```ignore
27/// #[service(traced)]
28/// impl MyService {
29///     #[untraced]
30///     pub async fn cache_refresh(&self) {
31///         // This method will NOT propagate spans (no overhead)
32///     }
33/// }
34/// ```
35#[proc_macro_attribute]
36pub fn untraced(_attr: TokenStream, item: TokenStream) -> TokenStream {
37    // This is just a marker - the actual processing happens in #[service]
38    item
39}
40
41/// Marker attribute for service startup hook.
42/// Use on a single method within a `#[service]` impl block.
43/// The method runs after service creation, before workers start receiving messages.
44///
45/// **Note:** Cannot call other services - workers aren't running yet.
46///
47/// ```ignore
48/// #[service]
49/// impl CacheService {
50///     #[startup]
51///     async fn load_cache(&self) {
52///         // Runs before any messages are processed
53///         let data = self.db.load_all().await;
54///         *self.cache.write() = data;
55///     }
56/// }
57/// ```
58#[proc_macro_attribute]
59pub fn startup(_attr: TokenStream, item: TokenStream) -> TokenStream {
60    // This is just a marker - the actual processing happens in #[service]
61    item
62}
63
64/// Marker attribute for service shutdown hook.
65/// Use on a single method within a `#[service]` impl block.
66/// The method runs on shutdown, before channels close.
67///
68/// **Note:** Can call other services - workers are still running.
69///
70/// ```ignore
71/// #[service]
72/// impl CacheService {
73///     #[shutdown]
74///     async fn flush_cache(&self) {
75///         // Runs on shutdown, can still call other services
76///         self.db.save_all(&self.cache.read()).await;
77///     }
78/// }
79/// ```
80#[proc_macro_attribute]
81pub fn shutdown(_attr: TokenStream, item: TokenStream) -> TokenStream {
82    // This is just a marker - the actual processing happens in #[service]
83    item
84}
85
86/// Derive macro for Service structs - generates ServiceFactory implementation
87///
88/// ```ignore
89/// #[derive(Service)]
90/// struct PaymentService {
91///     config: Res<Config>,
92///     orders: Client<OrderService>,
93/// }
94/// ```
95///
96/// Generates:
97/// ```ignore
98/// impl ::archy::ServiceFactory for PaymentService {
99///     fn create(app: &::archy::App) -> Self {
100///         PaymentService {
101///             config: app.extract(),
102///             orders: app.extract(),
103///         }
104///     }
105/// }
106/// ```
107#[proc_macro_derive(Service)]
108pub fn derive_service(input: TokenStream) -> TokenStream {
109    let input = parse_macro_input!(input as DeriveInput);
110    let name = &input.ident;
111
112    let fields = match &input.data {
113        syn::Data::Struct(data) => match &data.fields {
114            Fields::Named(fields) => &fields.named,
115            _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs with named fields")
116                .to_compile_error()
117                .into(),
118        },
119        _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs")
120            .to_compile_error()
121            .into(),
122    };
123
124    let field_inits = fields.iter().map(|f| {
125        let field_name = f.ident.as_ref().unwrap();
126        quote! { #field_name: app.extract() }
127    });
128
129    let expanded = quote! {
130        impl ::archy::ServiceFactory for #name {
131            fn create(app: &::archy::App) -> Self {
132                #name {
133                    #(#field_inits),*
134                }
135            }
136        }
137    };
138
139    TokenStream::from(expanded)
140}
141
142/// Attribute macro for Service impl blocks - generates message enum, Service impl, and Client methods
143///
144/// # Basic usage
145/// ```ignore
146/// #[service]
147/// impl PaymentService {
148///     pub async fn process(&self, amount: u32) -> String {
149///         // ...
150///     }
151/// }
152/// ```
153///
154/// # Tracing support
155/// Use `#[service(traced)]` to propagate tracing spans across service calls:
156/// ```ignore
157/// #[service(traced)]
158/// impl PaymentService {
159///     pub async fn process(&self, amount: u32) -> String {
160///         tracing::info!("Processing"); // inherits caller's span context
161///     }
162///
163///     #[untraced]  // opt-out for this method
164///     pub async fn cache_refresh(&self) { ... }
165/// }
166/// ```
167///
168/// For non-traced services, individual methods can opt-in:
169/// ```ignore
170/// #[service]
171/// impl CleanupWorker {
172///     #[traced]  // opt-in just this method
173///     pub async fn handle_request(&self, id: u64) -> String { ... }
174/// }
175/// ```
176///
177/// Generates:
178/// - Message enum with variants for each public async method
179/// - impl Service for PaymentService
180/// - Client methods struct with async methods
181#[proc_macro_attribute]
182pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
183    // Parse the traced option from #[service] or #[service(traced)]
184    let service_traced = if attr.is_empty() {
185        false
186    } else {
187        match syn::parse::<syn::Ident>(attr.clone()) {
188            Ok(ident) if ident == "traced" => true,
189            Ok(ident) => return syn::Error::new(
190                ident.span(),
191                format!("expected `traced`, found `{}`", ident)
192            ).to_compile_error().into(),
193            Err(e) => return e.to_compile_error().into(),
194        }
195    };
196
197    let input = parse_macro_input!(item as ItemImpl);
198    let service_name = match &*input.self_ty {
199        syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
200        _ => return syn::Error::new_spanned(&input.self_ty, "#[service] must be applied to an impl block for a named type")
201            .to_compile_error()
202            .into(),
203    };
204
205    let msg_enum_name = format_ident!("{}Msg", service_name);
206    let methods_struct_name = format_ident!("{}Methods", service_name);
207
208    // Collect public async methods with their tracing status
209    // Also detect #[startup] and #[shutdown] lifecycle hooks
210    let mut methods: Vec<(&ImplItemFn, bool)> = Vec::new();
211    let mut startup_method: Option<&syn::Ident> = None;
212    let mut shutdown_method: Option<&syn::Ident> = None;
213
214    for item in &input.items {
215        if let ImplItem::Fn(method) = item {
216            let is_async = method.sig.asyncness.is_some();
217            let has_self = method.sig.inputs.first().map_or(false, |arg| matches!(arg, FnArg::Receiver(_)));
218
219            // Check for lifecycle attributes
220            let has_startup = has_attribute(&method.attrs, "startup");
221            let has_shutdown = has_attribute(&method.attrs, "shutdown");
222
223            // Validate lifecycle methods
224            if has_startup {
225                if !is_async || !has_self {
226                    return syn::Error::new_spanned(
227                        &method.sig.ident,
228                        "#[startup] method must be async fn(&self)"
229                    ).to_compile_error().into();
230                }
231                if method.sig.inputs.len() > 1 {
232                    return syn::Error::new_spanned(
233                        &method.sig.ident,
234                        "#[startup] method cannot have parameters other than &self"
235                    ).to_compile_error().into();
236                }
237                if startup_method.is_some() {
238                    return syn::Error::new_spanned(
239                        &method.sig.ident,
240                        "only one #[startup] method allowed per service"
241                    ).to_compile_error().into();
242                }
243                startup_method = Some(&method.sig.ident);
244                continue; // Don't add to regular methods
245            }
246
247            if has_shutdown {
248                if !is_async || !has_self {
249                    return syn::Error::new_spanned(
250                        &method.sig.ident,
251                        "#[shutdown] method must be async fn(&self)"
252                    ).to_compile_error().into();
253                }
254                if method.sig.inputs.len() > 1 {
255                    return syn::Error::new_spanned(
256                        &method.sig.ident,
257                        "#[shutdown] method cannot have parameters other than &self"
258                    ).to_compile_error().into();
259                }
260                if shutdown_method.is_some() {
261                    return syn::Error::new_spanned(
262                        &method.sig.ident,
263                        "only one #[shutdown] method allowed per service"
264                    ).to_compile_error().into();
265                }
266                shutdown_method = Some(&method.sig.ident);
267                continue; // Don't add to regular methods
268            }
269
270            let is_pub = matches!(method.vis, syn::Visibility::Public(_));
271
272            if is_pub && is_async && has_self {
273                // Check for #[traced] and #[untraced] attributes on the method
274                let has_traced = has_attribute(&method.attrs, "traced");
275                let has_untraced = has_attribute(&method.attrs, "untraced");
276
277                // Error if both attributes are present
278                if has_traced && has_untraced {
279                    return syn::Error::new_spanned(
280                        &method.sig.ident,
281                        "method cannot have both #[traced] and #[untraced] attributes"
282                    ).to_compile_error().into();
283                }
284
285                // Determine final traced status:
286                // - #[untraced] → not traced
287                // - #[traced] → traced
288                // - neither → use service default
289                let method_traced = if has_untraced {
290                    false
291                } else if has_traced {
292                    true
293                } else {
294                    service_traced
295                };
296
297                methods.push((method, method_traced));
298            }
299        }
300    }
301
302    // Generate message enum variants
303    let msg_variants = methods.iter().map(|(method, traced)| {
304        let method_name = &method.sig.ident;
305        let variant_name = to_pascal_case(&method_name.to_string());
306        let variant_ident = format_ident!("{}", variant_name);
307
308        // Get parameters (skip &self)
309        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
310            if let FnArg::Typed(pat_type) = arg {
311                if let Pat::Ident(pat_ident) = &*pat_type.pat {
312                    let name = &pat_ident.ident;
313                    let ty = &pat_type.ty;
314                    return Some(quote! { #name: #ty });
315                }
316            }
317            None
318        }).collect();
319
320        // Add span field for traced methods
321        let span_field = if *traced {
322            quote! { span: ::archy::tracing::Span, }
323        } else {
324            quote! {}
325        };
326
327        // Fire-and-forget optimization: skip respond field for unit returns
328        if is_unit_return(&method.sig.output) {
329            if params.is_empty() && !*traced {
330                quote! { #variant_ident }
331            } else {
332                quote! { #variant_ident { #(#params,)* #span_field } }
333            }
334        } else {
335            let return_type = match &method.sig.output {
336                ReturnType::Default => quote! { () },
337                ReturnType::Type(_, ty) => quote! { #ty },
338            };
339            quote! {
340                #variant_ident { #(#params,)* #span_field respond: ::archy::tokio::sync::oneshot::Sender<#return_type> }
341            }
342        }
343    });
344
345    // Generate handle match arms
346    let match_arms = methods.iter().map(|(method, traced)| {
347        let method_name = &method.sig.ident;
348        let variant_name = to_pascal_case(&method_name.to_string());
349        let variant_ident = format_ident!("{}", variant_name);
350
351        // Get parameter names (skip &self)
352        let param_names: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
353            if let FnArg::Typed(pat_type) = arg {
354                if let Pat::Ident(pat_ident) = &*pat_type.pat {
355                    return Some(&pat_ident.ident);
356                }
357            }
358            None
359        }).collect();
360
361        let method_call = if param_names.is_empty() {
362            quote! { self.#method_name().await }
363        } else {
364            quote! { self.#method_name(#(#param_names),*).await }
365        };
366
367        // Fire-and-forget optimization: no respond for unit returns
368        if is_unit_return(&method.sig.output) {
369            let span_pattern = if *traced { quote! { span, } } else { quote! {} };
370            let param_pattern = if param_names.is_empty() {
371                quote! { #span_pattern }
372            } else {
373                quote! { #(#param_names,)* #span_pattern }
374            };
375
376            if *traced {
377                quote! {
378                    #msg_enum_name::#variant_ident { #param_pattern } => {
379                        ::archy::tracing::Instrument::instrument(async {
380                            #method_call;
381                        }, span).await
382                    }
383                }
384            } else {
385                quote! {
386                    #msg_enum_name::#variant_ident { #param_pattern } => {
387                        #method_call;
388                    }
389                }
390            }
391        } else {
392            let span_pattern = if *traced { quote! { span, } } else { quote! {} };
393            let param_pattern = if param_names.is_empty() {
394                quote! { #span_pattern respond }
395            } else {
396                quote! { #(#param_names,)* #span_pattern respond }
397            };
398
399            if *traced {
400                quote! {
401                    #msg_enum_name::#variant_ident { #param_pattern } => {
402                        ::archy::tracing::Instrument::instrument(async {
403                            let result = #method_call;
404                            let _ = respond.send(result);
405                        }, span).await
406                    }
407                }
408            } else {
409                quote! {
410                    #msg_enum_name::#variant_ident { #param_pattern } => {
411                        let result = #method_call;
412                        let _ = respond.send(result);
413                    }
414                }
415            }
416        }
417    });
418
419    // Generate client inherent methods (no trait needed!)
420    let client_inherent_methods = methods.iter().map(|(method, traced)| {
421        let method_name = &method.sig.ident;
422        let variant_name = to_pascal_case(&method_name.to_string());
423        let variant_ident = format_ident!("{}", variant_name);
424
425        // Get parameters with types (skip &self)
426        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
427            if let FnArg::Typed(pat_type) = arg {
428                if let Pat::Ident(pat_ident) = &*pat_type.pat {
429                    let name = &pat_ident.ident;
430                    let ty = &pat_type.ty;
431                    return Some((name.clone(), quote! { #ty }));
432                }
433            }
434            None
435        }).collect();
436
437        let param_decls: Vec<_> = params.iter().map(|(name, ty)| quote! { #name: #ty }).collect();
438        let param_names: Vec<_> = params.iter().map(|(name, _)| name).collect();
439
440        // Get return type
441        let return_type = match &method.sig.output {
442            ReturnType::Default => quote! { () },
443            ReturnType::Type(_, ty) => quote! { #ty },
444        };
445
446        // Capture span for traced methods
447        let span_capture = if *traced {
448            quote! { let span = ::archy::tracing::Span::current(); }
449        } else {
450            quote! {}
451        };
452        let span_field = if *traced {
453            quote! { span, }
454        } else {
455            quote! {}
456        };
457
458        // Fire-and-forget optimization for unit returns
459        if is_unit_return(&method.sig.output) {
460            let msg_construction = if param_names.is_empty() && !*traced {
461                quote! { #msg_enum_name::#variant_ident }
462            } else {
463                quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field } }
464            };
465
466            quote! {
467                pub async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
468                    #span_capture
469                    self.sender.send(#msg_construction).await
470                        .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
471                    Ok(())
472                }
473            }
474        } else {
475            let msg_construction = if param_names.is_empty() {
476                quote! { #msg_enum_name::#variant_ident { #span_field respond: tx } }
477            } else {
478                quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field respond: tx } }
479            };
480
481            quote! {
482                pub async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
483                    #span_capture
484                    let (tx, rx) = ::archy::tokio::sync::oneshot::channel();
485                    self.sender.send(#msg_construction).await
486                        .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
487                    rx.await.map_err(|_| ::archy::ServiceError::ServiceDropped)
488                }
489            }
490        }
491    });
492
493    // Generate startup impl if a #[startup] method was found
494    let startup_impl = startup_method.map(|method_name| {
495        quote! {
496            fn startup(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
497                async move { self.#method_name().await }
498            }
499        }
500    });
501
502    // Generate shutdown impl if a #[shutdown] method was found
503    let shutdown_impl = shutdown_method.map(|method_name| {
504        quote! {
505            fn shutdown(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
506                async move { self.#method_name().await }
507            }
508        }
509    });
510
511    let expanded = quote! {
512        // Original impl block (preserved)
513        #input
514
515        // Generated message enum
516        pub enum #msg_enum_name {
517            #(#msg_variants),*
518        }
519
520        // Generated client methods struct
521        #[derive(Clone)]
522        pub struct #methods_struct_name {
523            sender: ::archy::async_channel::Sender<#msg_enum_name>,
524        }
525
526        // Implement ClientMethods trait for dependency injection
527        impl ::archy::ClientMethods<#service_name> for #methods_struct_name {
528            fn from_sender(sender: ::archy::async_channel::Sender<#msg_enum_name>) -> Self {
529                Self { sender }
530            }
531        }
532
533        // Inherent methods
534        impl #methods_struct_name {
535            #(#client_inherent_methods)*
536        }
537
538        // Generated Service implementation
539        impl ::archy::Service for #service_name {
540            type Message = #msg_enum_name;
541            type ClientMethods = #methods_struct_name;
542
543            fn create(app: &::archy::App) -> Self {
544                <Self as ::archy::ServiceFactory>::create(app)
545            }
546
547            #startup_impl
548
549            fn handle(self: ::std::sync::Arc<Self>, msg: Self::Message) -> impl ::std::future::Future<Output = ()> + Send {
550                async move {
551                    match msg {
552                        #(#match_arms)*
553                    }
554                }
555            }
556
557            #shutdown_impl
558        }
559    };
560
561    TokenStream::from(expanded)
562}
563
/// Converts a `snake_case` identifier into `PascalCase` for message-variant names.
///
/// Underscores are dropped. The first character, and each character that follows
/// an underscore, is ASCII-uppercased; everything else is copied unchanged. For
/// example, `process_payment` becomes `ProcessPayment`.
///
/// A raw-identifier prefix (`r#`) is stripped first. `Ident::to_string()` keeps
/// it (`r#type`), and a `#` in the variant name would make `format_ident!` panic.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut result = String::with_capacity(s.len());
    let mut capitalize_next = true;
    for c in s.chars() {
        if c == '_' {
            capitalize_next = true;
        } else if capitalize_next {
            result.push(c.to_ascii_uppercase());
            capitalize_next = false;
        } else {
            result.push(c);
        }
    }
    result
}

580/// Check if an attribute list contains an attribute with the given name
581fn has_attribute(attrs: &[syn::Attribute], name: &str) -> bool {
582    attrs.iter().any(|attr| attr.path().is_ident(name))
583}
584
585/// Check if a return type is unit () - used for fire-and-forget optimization
586fn is_unit_return(output: &ReturnType) -> bool {
587    match output {
588        ReturnType::Default => true,
589        ReturnType::Type(_, ty) => {
590            if let syn::Type::Tuple(tuple) = &**ty {
591                tuple.elems.is_empty()
592            } else {
593                false
594            }
595        }
596    }
597}