archy_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, DeriveInput, Fields, ItemImpl, ImplItem, ImplItemFn, FnArg, ReturnType, Pat};
4
5/// Marker attribute to opt a method into span propagation.
6/// Use on methods within a `#[service]` impl block.
7///
8/// ```ignore
9/// #[service]
10/// impl MyService {
11///     #[traced]
12///     pub async fn important_operation(&self) -> String {
13///         // This method will propagate the caller's tracing span
14///     }
15/// }
16/// ```
17#[proc_macro_attribute]
18pub fn traced(_attr: TokenStream, item: TokenStream) -> TokenStream {
19    // This is just a marker - the actual processing happens in #[service]
20    item
21}
22
23/// Marker attribute to opt a method out of span propagation.
24/// Use on methods within a `#[service(traced)]` impl block.
25///
26/// ```ignore
27/// #[service(traced)]
28/// impl MyService {
29///     #[untraced]
30///     pub async fn cache_refresh(&self) {
31///         // This method will NOT propagate spans (no overhead)
32///     }
33/// }
34/// ```
35#[proc_macro_attribute]
36pub fn untraced(_attr: TokenStream, item: TokenStream) -> TokenStream {
37    // This is just a marker - the actual processing happens in #[service]
38    item
39}
40
41/// Marker attribute for service startup hook.
42/// Use on a single method within a `#[service]` impl block.
43/// The method runs after service creation, before workers start receiving messages.
44///
45/// **Note:** Cannot call other services - workers aren't running yet.
46///
47/// ```ignore
48/// #[service]
49/// impl CacheService {
50///     #[startup]
51///     async fn load_cache(&self) {
52///         // Runs before any messages are processed
53///         let data = self.db.load_all().await;
54///         *self.cache.write() = data;
55///     }
56/// }
57/// ```
58#[proc_macro_attribute]
59pub fn startup(_attr: TokenStream, item: TokenStream) -> TokenStream {
60    // This is just a marker - the actual processing happens in #[service]
61    item
62}
63
64/// Marker attribute for service shutdown hook.
65/// Use on a single method within a `#[service]` impl block.
66/// The method runs on shutdown, before channels close.
67///
68/// **Note:** Can call other services - workers are still running.
69///
70/// ```ignore
71/// #[service]
72/// impl CacheService {
73///     #[shutdown]
74///     async fn flush_cache(&self) {
75///         // Runs on shutdown, can still call other services
76///         self.db.save_all(&self.cache.read()).await;
77///     }
78/// }
79/// ```
80#[proc_macro_attribute]
81pub fn shutdown(_attr: TokenStream, item: TokenStream) -> TokenStream {
82    // This is just a marker - the actual processing happens in #[service]
83    item
84}
85
86/// Derive macro for Service structs - generates ServiceFactory implementation
87///
88/// Automatically recognizes archy types (`Res`, `Client`, `Emit`, `Sub`, `Shutdown`)
89/// and injects them via `app.extract()`. All other fields use `Default::default()`.
90///
91/// For custom initialization of state fields, use `#[startup]` in your `#[service]` impl.
92///
93/// ```ignore
94/// #[derive(Service)]
95/// struct JobService {
96///     config: Res<Config>,              // → app.extract()
97///     metrics: Client<MetricsService>,  // → app.extract()
98///     cache: RwLock<HashMap<K, V>>,     // → Default::default()
99///     counter: AtomicU64,               // → Default::default()
100/// }
101///
102/// #[service]
103/// impl JobService {
104///     #[startup]
105///     async fn init(&self) {
106///         // Custom initialization if Default isn't enough
107///         self.counter.store(1, Ordering::SeqCst);
108///     }
109/// }
110/// ```
111///
112/// Generates:
113/// ```ignore
114/// impl ::archy::ServiceFactory for JobService {
115///     fn create(app: &::archy::App) -> Self {
116///         JobService {
117///             config: app.extract(),
118///             metrics: app.extract(),
119///             cache: Default::default(),
120///             counter: Default::default(),
121///         }
122///     }
123/// }
124/// ```
125#[proc_macro_derive(Service)]
126pub fn derive_service(input: TokenStream) -> TokenStream {
127    let input = parse_macro_input!(input as DeriveInput);
128    let name = &input.ident;
129
130    let fields = match &input.data {
131        syn::Data::Struct(data) => match &data.fields {
132            Fields::Named(fields) => &fields.named,
133            _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs with named fields")
134                .to_compile_error()
135                .into(),
136        },
137        _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs")
138            .to_compile_error()
139            .into(),
140    };
141
142    let field_inits = fields.iter().map(|f| {
143        let field_name = f.ident.as_ref().unwrap();
144        let field_ty = &f.ty;
145
146        if is_archy_injectable(field_ty) {
147            quote! { #field_name: app.extract() }
148        } else {
149            quote! { #field_name: ::std::default::Default::default() }
150        }
151    });
152
153    let expanded = quote! {
154        impl ::archy::ServiceFactory for #name {
155            fn create(app: &::archy::App) -> Self {
156                #name {
157                    #(#field_inits),*
158                }
159            }
160        }
161    };
162
163    TokenStream::from(expanded)
164}
165
166/// Check if a type is an archy injectable type (Res, Client, Emit, Sub, Shutdown)
167fn is_archy_injectable(ty: &syn::Type) -> bool {
168    if let syn::Type::Path(type_path) = ty {
169        if let Some(segment) = type_path.path.segments.last() {
170            let ident = segment.ident.to_string();
171            return matches!(ident.as_str(), "Res" | "Client" | "Emit" | "Sub" | "Shutdown");
172        }
173    }
174    false
175}
176
177/// Attribute macro for Service impl blocks - generates message enum, Service impl, and Client methods
178///
179/// # Basic usage
180/// ```ignore
181/// #[service]
182/// impl PaymentService {
183///     pub async fn process(&self, amount: u32) -> String {
184///         // ...
185///     }
186/// }
187/// ```
188///
189/// # Tracing support
190/// Use `#[service(traced)]` to propagate tracing spans across service calls:
191/// ```ignore
192/// #[service(traced)]
193/// impl PaymentService {
194///     pub async fn process(&self, amount: u32) -> String {
195///         tracing::info!("Processing"); // inherits caller's span context
196///     }
197///
198///     #[untraced]  // opt-out for this method
199///     pub async fn cache_refresh(&self) { ... }
200/// }
201/// ```
202///
203/// For non-traced services, individual methods can opt-in:
204/// ```ignore
205/// #[service]
206/// impl CleanupWorker {
207///     #[traced]  // opt-in just this method
208///     pub async fn handle_request(&self, id: u64) -> String { ... }
209/// }
210/// ```
211///
212/// Generates:
213/// - Message enum with variants for each public async method
214/// - impl Service for PaymentService
215/// - Client methods struct with async methods
216#[proc_macro_attribute]
217pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
218    // Parse the traced option from #[service] or #[service(traced)]
219    let service_traced = if attr.is_empty() {
220        false
221    } else {
222        match syn::parse::<syn::Ident>(attr.clone()) {
223            Ok(ident) if ident == "traced" => true,
224            Ok(ident) => return syn::Error::new(
225                ident.span(),
226                format!("expected `traced`, found `{}`", ident)
227            ).to_compile_error().into(),
228            Err(e) => return e.to_compile_error().into(),
229        }
230    };
231
232    let input = parse_macro_input!(item as ItemImpl);
233    let service_name = match &*input.self_ty {
234        syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
235        _ => return syn::Error::new_spanned(&input.self_ty, "#[service] must be applied to an impl block for a named type")
236            .to_compile_error()
237            .into(),
238    };
239
240    let msg_enum_name = format_ident!("{}Msg", service_name);
241    let methods_struct_name = format_ident!("{}Methods", service_name);
242
243    // Collect public async methods with their tracing status
244    // Also detect #[startup] and #[shutdown] lifecycle hooks
245    let mut methods: Vec<(&ImplItemFn, bool)> = Vec::new();
246    let mut startup_method: Option<&syn::Ident> = None;
247    let mut shutdown_method: Option<&syn::Ident> = None;
248
249    for item in &input.items {
250        if let ImplItem::Fn(method) = item {
251            let is_async = method.sig.asyncness.is_some();
252            let has_self = method.sig.inputs.first().map_or(false, |arg| matches!(arg, FnArg::Receiver(_)));
253
254            // Check for lifecycle attributes
255            let has_startup = has_attribute(&method.attrs, "startup");
256            let has_shutdown = has_attribute(&method.attrs, "shutdown");
257
258            // Validate lifecycle methods
259            if has_startup {
260                if !is_async || !has_self {
261                    return syn::Error::new_spanned(
262                        &method.sig.ident,
263                        "#[startup] method must be async fn(&self)"
264                    ).to_compile_error().into();
265                }
266                if method.sig.inputs.len() > 1 {
267                    return syn::Error::new_spanned(
268                        &method.sig.ident,
269                        "#[startup] method cannot have parameters other than &self"
270                    ).to_compile_error().into();
271                }
272                if startup_method.is_some() {
273                    return syn::Error::new_spanned(
274                        &method.sig.ident,
275                        "only one #[startup] method allowed per service"
276                    ).to_compile_error().into();
277                }
278                startup_method = Some(&method.sig.ident);
279                continue; // Don't add to regular methods
280            }
281
282            if has_shutdown {
283                if !is_async || !has_self {
284                    return syn::Error::new_spanned(
285                        &method.sig.ident,
286                        "#[shutdown] method must be async fn(&self)"
287                    ).to_compile_error().into();
288                }
289                if method.sig.inputs.len() > 1 {
290                    return syn::Error::new_spanned(
291                        &method.sig.ident,
292                        "#[shutdown] method cannot have parameters other than &self"
293                    ).to_compile_error().into();
294                }
295                if shutdown_method.is_some() {
296                    return syn::Error::new_spanned(
297                        &method.sig.ident,
298                        "only one #[shutdown] method allowed per service"
299                    ).to_compile_error().into();
300                }
301                shutdown_method = Some(&method.sig.ident);
302                continue; // Don't add to regular methods
303            }
304
305            let is_pub = matches!(method.vis, syn::Visibility::Public(_));
306
307            if is_pub && is_async && has_self {
308                // Check for #[traced] and #[untraced] attributes on the method
309                let has_traced = has_attribute(&method.attrs, "traced");
310                let has_untraced = has_attribute(&method.attrs, "untraced");
311
312                // Error if both attributes are present
313                if has_traced && has_untraced {
314                    return syn::Error::new_spanned(
315                        &method.sig.ident,
316                        "method cannot have both #[traced] and #[untraced] attributes"
317                    ).to_compile_error().into();
318                }
319
320                // Determine final traced status:
321                // - #[untraced] → not traced
322                // - #[traced] → traced
323                // - neither → use service default
324                let method_traced = if has_untraced {
325                    false
326                } else if has_traced {
327                    true
328                } else {
329                    service_traced
330                };
331
332                methods.push((method, method_traced));
333            }
334        }
335    }
336
337    // Generate message enum variants
338    let msg_variants = methods.iter().map(|(method, traced)| {
339        let method_name = &method.sig.ident;
340        let variant_name = to_pascal_case(&method_name.to_string());
341        let variant_ident = format_ident!("{}", variant_name);
342
343        // Get parameters (skip &self)
344        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
345            if let FnArg::Typed(pat_type) = arg {
346                if let Pat::Ident(pat_ident) = &*pat_type.pat {
347                    let name = &pat_ident.ident;
348                    let ty = &pat_type.ty;
349                    return Some(quote! { #name: #ty });
350                }
351            }
352            None
353        }).collect();
354
355        // Add span field for traced methods
356        let span_field = if *traced {
357            quote! { span: ::archy::tracing::Span, }
358        } else {
359            quote! {}
360        };
361
362        // Fire-and-forget optimization: skip respond field for unit returns
363        if is_unit_return(&method.sig.output) {
364            if params.is_empty() && !*traced {
365                quote! { #variant_ident }
366            } else {
367                quote! { #variant_ident { #(#params,)* #span_field } }
368            }
369        } else {
370            let return_type = match &method.sig.output {
371                ReturnType::Default => quote! { () },
372                ReturnType::Type(_, ty) => quote! { #ty },
373            };
374            quote! {
375                #variant_ident { #(#params,)* #span_field respond: ::archy::tokio::sync::oneshot::Sender<#return_type> }
376            }
377        }
378    });
379
380    // Generate handle match arms
381    let match_arms = methods.iter().map(|(method, traced)| {
382        let method_name = &method.sig.ident;
383        let variant_name = to_pascal_case(&method_name.to_string());
384        let variant_ident = format_ident!("{}", variant_name);
385
386        // Get parameter names (skip &self)
387        let param_names: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
388            if let FnArg::Typed(pat_type) = arg {
389                if let Pat::Ident(pat_ident) = &*pat_type.pat {
390                    return Some(&pat_ident.ident);
391                }
392            }
393            None
394        }).collect();
395
396        let method_call = if param_names.is_empty() {
397            quote! { self.#method_name().await }
398        } else {
399            quote! { self.#method_name(#(#param_names),*).await }
400        };
401
402        // Fire-and-forget optimization: no respond for unit returns
403        if is_unit_return(&method.sig.output) {
404            let span_pattern = if *traced { quote! { span, } } else { quote! {} };
405            let param_pattern = if param_names.is_empty() {
406                quote! { #span_pattern }
407            } else {
408                quote! { #(#param_names,)* #span_pattern }
409            };
410
411            if *traced {
412                quote! {
413                    #msg_enum_name::#variant_ident { #param_pattern } => {
414                        ::archy::tracing::Instrument::instrument(async {
415                            #method_call;
416                        }, span).await
417                    }
418                }
419            } else {
420                quote! {
421                    #msg_enum_name::#variant_ident { #param_pattern } => {
422                        #method_call;
423                    }
424                }
425            }
426        } else {
427            let span_pattern = if *traced { quote! { span, } } else { quote! {} };
428            let param_pattern = if param_names.is_empty() {
429                quote! { #span_pattern respond }
430            } else {
431                quote! { #(#param_names,)* #span_pattern respond }
432            };
433
434            if *traced {
435                quote! {
436                    #msg_enum_name::#variant_ident { #param_pattern } => {
437                        ::archy::tracing::Instrument::instrument(async {
438                            let result = #method_call;
439                            let _ = respond.send(result);
440                        }, span).await
441                    }
442                }
443            } else {
444                quote! {
445                    #msg_enum_name::#variant_ident { #param_pattern } => {
446                        let result = #method_call;
447                        let _ = respond.send(result);
448                    }
449                }
450            }
451        }
452    });
453
454    // Generate client inherent methods (no trait needed!)
455    let client_inherent_methods = methods.iter().map(|(method, traced)| {
456        let method_name = &method.sig.ident;
457        let variant_name = to_pascal_case(&method_name.to_string());
458        let variant_ident = format_ident!("{}", variant_name);
459
460        // Get parameters with types (skip &self)
461        let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
462            if let FnArg::Typed(pat_type) = arg {
463                if let Pat::Ident(pat_ident) = &*pat_type.pat {
464                    let name = &pat_ident.ident;
465                    let ty = &pat_type.ty;
466                    return Some((name.clone(), quote! { #ty }));
467                }
468            }
469            None
470        }).collect();
471
472        let param_decls: Vec<_> = params.iter().map(|(name, ty)| quote! { #name: #ty }).collect();
473        let param_names: Vec<_> = params.iter().map(|(name, _)| name).collect();
474
475        // Get return type
476        let return_type = match &method.sig.output {
477            ReturnType::Default => quote! { () },
478            ReturnType::Type(_, ty) => quote! { #ty },
479        };
480
481        // Capture span for traced methods
482        let span_capture = if *traced {
483            quote! { let span = ::archy::tracing::Span::current(); }
484        } else {
485            quote! {}
486        };
487        let span_field = if *traced {
488            quote! { span, }
489        } else {
490            quote! {}
491        };
492
493        // Fire-and-forget optimization for unit returns
494        if is_unit_return(&method.sig.output) {
495            let msg_construction = if param_names.is_empty() && !*traced {
496                quote! { #msg_enum_name::#variant_ident }
497            } else {
498                quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field } }
499            };
500
501            quote! {
502                pub async fn #method_name(&self, #(#param_decls),*) {
503                    #span_capture
504                    let _ = self.sender.send(#msg_construction).await;
505                }
506            }
507        } else {
508            let msg_construction = if param_names.is_empty() {
509                quote! { #msg_enum_name::#variant_ident { #span_field respond: tx } }
510            } else {
511                quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field respond: tx } }
512            };
513
514            quote! {
515                pub async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
516                    #span_capture
517                    let (tx, rx) = ::archy::tokio::sync::oneshot::channel();
518                    self.sender.send(#msg_construction).await
519                        .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
520                    rx.await.map_err(|_| ::archy::ServiceError::ServiceDropped)
521                }
522            }
523        }
524    });
525
526    // Generate startup impl if a #[startup] method was found
527    let startup_impl = startup_method.map(|method_name| {
528        quote! {
529            fn startup(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
530                async move { self.#method_name().await }
531            }
532        }
533    });
534
535    // Generate shutdown impl if a #[shutdown] method was found
536    let shutdown_impl = shutdown_method.map(|method_name| {
537        quote! {
538            fn shutdown(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
539                async move { self.#method_name().await }
540            }
541        }
542    });
543
544    let expanded = quote! {
545        // Original impl block (preserved)
546        #input
547
548        // Generated message enum
549        pub enum #msg_enum_name {
550            #(#msg_variants),*
551        }
552
553        // Generated client methods struct
554        #[derive(Clone)]
555        pub struct #methods_struct_name {
556            sender: ::archy::async_channel::Sender<#msg_enum_name>,
557        }
558
559        // Implement ClientMethods trait for dependency injection
560        impl ::archy::ClientMethods<#service_name> for #methods_struct_name {
561            fn from_sender(sender: ::archy::async_channel::Sender<#msg_enum_name>) -> Self {
562                Self { sender }
563            }
564        }
565
566        // Inherent methods
567        impl #methods_struct_name {
568            #(#client_inherent_methods)*
569        }
570
571        // Generated Service implementation
572        impl ::archy::Service for #service_name {
573            type Message = #msg_enum_name;
574            type ClientMethods = #methods_struct_name;
575
576            fn create(app: &::archy::App) -> Self {
577                <Self as ::archy::ServiceFactory>::create(app)
578            }
579
580            #startup_impl
581
582            fn handle(self: ::std::sync::Arc<Self>, msg: Self::Message) -> impl ::std::future::Future<Output = ()> + Send {
583                async move {
584                    match msg {
585                        #(#match_arms)*
586                    }
587                }
588            }
589
590            #shutdown_impl
591        }
592    };
593
594    TokenStream::from(expanded)
595}
596
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Underscores are dropped. The first character of the string, and the first
/// character after each run of underscores, are upper-cased with ASCII rules. Every
/// other character is copied unchanged, so `cache_refresh` becomes `CacheRefresh`
/// and `a__b_` becomes `AB`.
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .fold(String::with_capacity(s.len()), |mut pascal, word| {
            let mut chars = word.chars();
            if let Some(head) = chars.next() {
                pascal.push(head.to_ascii_uppercase());
                pascal.push_str(chars.as_str());
            }
            pascal
        })
}
612
613/// Check if an attribute list contains an attribute with the given name
614fn has_attribute(attrs: &[syn::Attribute], name: &str) -> bool {
615    attrs.iter().any(|attr| attr.path().is_ident(name))
616}
617
618/// Check if a return type is unit () - used for fire-and-forget optimization
619fn is_unit_return(output: &ReturnType) -> bool {
620    match output {
621        ReturnType::Default => true,
622        ReturnType::Type(_, ty) => {
623            if let syn::Type::Tuple(tuple) = &**ty {
624                tuple.elems.is_empty()
625            } else {
626                false
627            }
628        }
629    }
630}