1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, DeriveInput, Fields, ItemImpl, ImplItem, ImplItemFn, FnArg, ReturnType, Pat};
4
5#[proc_macro_attribute]
18pub fn traced(_attr: TokenStream, item: TokenStream) -> TokenStream {
19 item
21}
22
23#[proc_macro_attribute]
36pub fn untraced(_attr: TokenStream, item: TokenStream) -> TokenStream {
37 item
39}
40
41#[proc_macro_attribute]
59pub fn startup(_attr: TokenStream, item: TokenStream) -> TokenStream {
60 item
62}
63
64#[proc_macro_attribute]
81pub fn shutdown(_attr: TokenStream, item: TokenStream) -> TokenStream {
82 item
84}
85
86#[proc_macro_derive(Service)]
126pub fn derive_service(input: TokenStream) -> TokenStream {
127 let input = parse_macro_input!(input as DeriveInput);
128 let name = &input.ident;
129
130 let fields = match &input.data {
131 syn::Data::Struct(data) => match &data.fields {
132 Fields::Named(fields) => &fields.named,
133 _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs with named fields")
134 .to_compile_error()
135 .into(),
136 },
137 _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs")
138 .to_compile_error()
139 .into(),
140 };
141
142 let field_inits = fields.iter().map(|f| {
143 let field_name = f.ident.as_ref().unwrap();
144 let field_ty = &f.ty;
145
146 if is_archy_injectable(field_ty) {
147 quote! { #field_name: app.extract() }
148 } else {
149 quote! { #field_name: ::std::default::Default::default() }
150 }
151 });
152
153 let expanded = quote! {
154 impl ::archy::ServiceFactory for #name {
155 fn create(app: &::archy::App) -> Self {
156 #name {
157 #(#field_inits),*
158 }
159 }
160 }
161 };
162
163 TokenStream::from(expanded)
164}
165
166fn is_archy_injectable(ty: &syn::Type) -> bool {
168 if let syn::Type::Path(type_path) = ty {
169 if let Some(segment) = type_path.path.segments.last() {
170 let ident = segment.ident.to_string();
171 return matches!(ident.as_str(), "Res" | "Client" | "Emit" | "Sub" | "Shutdown");
172 }
173 }
174 false
175}
176
177#[proc_macro_attribute]
217pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
218 let service_traced = if attr.is_empty() {
220 false
221 } else {
222 match syn::parse::<syn::Ident>(attr.clone()) {
223 Ok(ident) if ident == "traced" => true,
224 Ok(ident) => return syn::Error::new(
225 ident.span(),
226 format!("expected `traced`, found `{}`", ident)
227 ).to_compile_error().into(),
228 Err(e) => return e.to_compile_error().into(),
229 }
230 };
231
232 let input = parse_macro_input!(item as ItemImpl);
233 let service_name = match &*input.self_ty {
234 syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
235 _ => return syn::Error::new_spanned(&input.self_ty, "#[service] must be applied to an impl block for a named type")
236 .to_compile_error()
237 .into(),
238 };
239
240 let msg_enum_name = format_ident!("{}Msg", service_name);
241 let methods_struct_name = format_ident!("{}Methods", service_name);
242
243 let mut methods: Vec<(&ImplItemFn, bool)> = Vec::new();
246 let mut startup_method: Option<&syn::Ident> = None;
247 let mut shutdown_method: Option<&syn::Ident> = None;
248
249 for item in &input.items {
250 if let ImplItem::Fn(method) = item {
251 let is_async = method.sig.asyncness.is_some();
252 let has_self = method.sig.inputs.first().map_or(false, |arg| matches!(arg, FnArg::Receiver(_)));
253
254 let has_startup = has_attribute(&method.attrs, "startup");
256 let has_shutdown = has_attribute(&method.attrs, "shutdown");
257
258 if has_startup {
260 if !is_async || !has_self {
261 return syn::Error::new_spanned(
262 &method.sig.ident,
263 "#[startup] method must be async fn(&self)"
264 ).to_compile_error().into();
265 }
266 if method.sig.inputs.len() > 1 {
267 return syn::Error::new_spanned(
268 &method.sig.ident,
269 "#[startup] method cannot have parameters other than &self"
270 ).to_compile_error().into();
271 }
272 if startup_method.is_some() {
273 return syn::Error::new_spanned(
274 &method.sig.ident,
275 "only one #[startup] method allowed per service"
276 ).to_compile_error().into();
277 }
278 startup_method = Some(&method.sig.ident);
279 continue; }
281
282 if has_shutdown {
283 if !is_async || !has_self {
284 return syn::Error::new_spanned(
285 &method.sig.ident,
286 "#[shutdown] method must be async fn(&self)"
287 ).to_compile_error().into();
288 }
289 if method.sig.inputs.len() > 1 {
290 return syn::Error::new_spanned(
291 &method.sig.ident,
292 "#[shutdown] method cannot have parameters other than &self"
293 ).to_compile_error().into();
294 }
295 if shutdown_method.is_some() {
296 return syn::Error::new_spanned(
297 &method.sig.ident,
298 "only one #[shutdown] method allowed per service"
299 ).to_compile_error().into();
300 }
301 shutdown_method = Some(&method.sig.ident);
302 continue; }
304
305 let is_pub = matches!(method.vis, syn::Visibility::Public(_));
306
307 if is_pub && is_async && has_self {
308 let has_traced = has_attribute(&method.attrs, "traced");
310 let has_untraced = has_attribute(&method.attrs, "untraced");
311
312 if has_traced && has_untraced {
314 return syn::Error::new_spanned(
315 &method.sig.ident,
316 "method cannot have both #[traced] and #[untraced] attributes"
317 ).to_compile_error().into();
318 }
319
320 let method_traced = if has_untraced {
325 false
326 } else if has_traced {
327 true
328 } else {
329 service_traced
330 };
331
332 methods.push((method, method_traced));
333 }
334 }
335 }
336
337 let msg_variants = methods.iter().map(|(method, traced)| {
339 let method_name = &method.sig.ident;
340 let variant_name = to_pascal_case(&method_name.to_string());
341 let variant_ident = format_ident!("{}", variant_name);
342
343 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
345 if let FnArg::Typed(pat_type) = arg {
346 if let Pat::Ident(pat_ident) = &*pat_type.pat {
347 let name = &pat_ident.ident;
348 let ty = &pat_type.ty;
349 return Some(quote! { #name: #ty });
350 }
351 }
352 None
353 }).collect();
354
355 let span_field = if *traced {
357 quote! { span: ::archy::tracing::Span, }
358 } else {
359 quote! {}
360 };
361
362 if is_unit_return(&method.sig.output) {
364 if params.is_empty() && !*traced {
365 quote! { #variant_ident }
366 } else {
367 quote! { #variant_ident { #(#params,)* #span_field } }
368 }
369 } else {
370 let return_type = match &method.sig.output {
371 ReturnType::Default => quote! { () },
372 ReturnType::Type(_, ty) => quote! { #ty },
373 };
374 quote! {
375 #variant_ident { #(#params,)* #span_field respond: ::archy::tokio::sync::oneshot::Sender<#return_type> }
376 }
377 }
378 });
379
380 let match_arms = methods.iter().map(|(method, traced)| {
382 let method_name = &method.sig.ident;
383 let variant_name = to_pascal_case(&method_name.to_string());
384 let variant_ident = format_ident!("{}", variant_name);
385
386 let param_names: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
388 if let FnArg::Typed(pat_type) = arg {
389 if let Pat::Ident(pat_ident) = &*pat_type.pat {
390 return Some(&pat_ident.ident);
391 }
392 }
393 None
394 }).collect();
395
396 let method_call = if param_names.is_empty() {
397 quote! { self.#method_name().await }
398 } else {
399 quote! { self.#method_name(#(#param_names),*).await }
400 };
401
402 if is_unit_return(&method.sig.output) {
404 let span_pattern = if *traced { quote! { span, } } else { quote! {} };
405 let param_pattern = if param_names.is_empty() {
406 quote! { #span_pattern }
407 } else {
408 quote! { #(#param_names,)* #span_pattern }
409 };
410
411 if *traced {
412 quote! {
413 #msg_enum_name::#variant_ident { #param_pattern } => {
414 ::archy::tracing::Instrument::instrument(async {
415 #method_call;
416 }, span).await
417 }
418 }
419 } else {
420 quote! {
421 #msg_enum_name::#variant_ident { #param_pattern } => {
422 #method_call;
423 }
424 }
425 }
426 } else {
427 let span_pattern = if *traced { quote! { span, } } else { quote! {} };
428 let param_pattern = if param_names.is_empty() {
429 quote! { #span_pattern respond }
430 } else {
431 quote! { #(#param_names,)* #span_pattern respond }
432 };
433
434 if *traced {
435 quote! {
436 #msg_enum_name::#variant_ident { #param_pattern } => {
437 ::archy::tracing::Instrument::instrument(async {
438 let result = #method_call;
439 let _ = respond.send(result);
440 }, span).await
441 }
442 }
443 } else {
444 quote! {
445 #msg_enum_name::#variant_ident { #param_pattern } => {
446 let result = #method_call;
447 let _ = respond.send(result);
448 }
449 }
450 }
451 }
452 });
453
454 let client_inherent_methods = methods.iter().map(|(method, traced)| {
456 let method_name = &method.sig.ident;
457 let variant_name = to_pascal_case(&method_name.to_string());
458 let variant_ident = format_ident!("{}", variant_name);
459
460 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
462 if let FnArg::Typed(pat_type) = arg {
463 if let Pat::Ident(pat_ident) = &*pat_type.pat {
464 let name = &pat_ident.ident;
465 let ty = &pat_type.ty;
466 return Some((name.clone(), quote! { #ty }));
467 }
468 }
469 None
470 }).collect();
471
472 let param_decls: Vec<_> = params.iter().map(|(name, ty)| quote! { #name: #ty }).collect();
473 let param_names: Vec<_> = params.iter().map(|(name, _)| name).collect();
474
475 let return_type = match &method.sig.output {
477 ReturnType::Default => quote! { () },
478 ReturnType::Type(_, ty) => quote! { #ty },
479 };
480
481 let span_capture = if *traced {
483 quote! { let span = ::archy::tracing::Span::current(); }
484 } else {
485 quote! {}
486 };
487 let span_field = if *traced {
488 quote! { span, }
489 } else {
490 quote! {}
491 };
492
493 if is_unit_return(&method.sig.output) {
495 let msg_construction = if param_names.is_empty() && !*traced {
496 quote! { #msg_enum_name::#variant_ident }
497 } else {
498 quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field } }
499 };
500
501 quote! {
502 pub async fn #method_name(&self, #(#param_decls),*) {
503 #span_capture
504 let _ = self.sender.send(#msg_construction).await;
505 }
506 }
507 } else {
508 let msg_construction = if param_names.is_empty() {
509 quote! { #msg_enum_name::#variant_ident { #span_field respond: tx } }
510 } else {
511 quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field respond: tx } }
512 };
513
514 quote! {
515 pub async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
516 #span_capture
517 let (tx, rx) = ::archy::tokio::sync::oneshot::channel();
518 self.sender.send(#msg_construction).await
519 .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
520 rx.await.map_err(|_| ::archy::ServiceError::ServiceDropped)
521 }
522 }
523 }
524 });
525
526 let startup_impl = startup_method.map(|method_name| {
528 quote! {
529 fn startup(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
530 async move { self.#method_name().await }
531 }
532 }
533 });
534
535 let shutdown_impl = shutdown_method.map(|method_name| {
537 quote! {
538 fn shutdown(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
539 async move { self.#method_name().await }
540 }
541 }
542 });
543
544 let expanded = quote! {
545 #input
547
548 pub enum #msg_enum_name {
550 #(#msg_variants),*
551 }
552
553 #[derive(Clone)]
555 pub struct #methods_struct_name {
556 sender: ::archy::async_channel::Sender<#msg_enum_name>,
557 }
558
559 impl ::archy::ClientMethods<#service_name> for #methods_struct_name {
561 fn from_sender(sender: ::archy::async_channel::Sender<#msg_enum_name>) -> Self {
562 Self { sender }
563 }
564 }
565
566 impl #methods_struct_name {
568 #(#client_inherent_methods)*
569 }
570
571 impl ::archy::Service for #service_name {
573 type Message = #msg_enum_name;
574 type ClientMethods = #methods_struct_name;
575
576 fn create(app: &::archy::App) -> Self {
577 <Self as ::archy::ServiceFactory>::create(app)
578 }
579
580 #startup_impl
581
582 fn handle(self: ::std::sync::Arc<Self>, msg: Self::Message) -> impl ::std::future::Future<Output = ()> + Send {
583 async move {
584 match msg {
585 #(#match_arms)*
586 }
587 }
588 }
589
590 #shutdown_impl
591 }
592 };
593
594 TokenStream::from(expanded)
595}
596
/// Converts a `snake_case` name to `PascalCase`.
///
/// Underscores are removed and the character following each underscore (as
/// well as the very first character) is upper-cased using ASCII rules; all
/// other characters are copied through unchanged. Consecutive, leading or
/// trailing underscores contribute nothing to the result.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for word in s.split('_') {
        let mut letters = word.chars();
        // Empty words come from adjacent/leading/trailing underscores.
        if let Some(initial) = letters.next() {
            pascal.push(initial.to_ascii_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
612
613fn has_attribute(attrs: &[syn::Attribute], name: &str) -> bool {
615 attrs.iter().any(|attr| attr.path().is_ident(name))
616}
617
618fn is_unit_return(output: &ReturnType) -> bool {
620 match output {
621 ReturnType::Default => true,
622 ReturnType::Type(_, ty) => {
623 if let syn::Type::Tuple(tuple) = &**ty {
624 tuple.elems.is_empty()
625 } else {
626 false
627 }
628 }
629 }
630}