1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, DeriveInput, Fields, ItemImpl, ImplItem, ImplItemFn, FnArg, ReturnType, Pat};
4
5#[proc_macro_attribute]
18pub fn traced(_attr: TokenStream, item: TokenStream) -> TokenStream {
19 item
21}
22
23#[proc_macro_attribute]
36pub fn untraced(_attr: TokenStream, item: TokenStream) -> TokenStream {
37 item
39}
40
41#[proc_macro_attribute]
59pub fn startup(_attr: TokenStream, item: TokenStream) -> TokenStream {
60 item
62}
63
64#[proc_macro_attribute]
81pub fn shutdown(_attr: TokenStream, item: TokenStream) -> TokenStream {
82 item
84}
85
86#[proc_macro_derive(Service)]
108pub fn derive_service(input: TokenStream) -> TokenStream {
109 let input = parse_macro_input!(input as DeriveInput);
110 let name = &input.ident;
111
112 let fields = match &input.data {
113 syn::Data::Struct(data) => match &data.fields {
114 Fields::Named(fields) => &fields.named,
115 _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs with named fields")
116 .to_compile_error()
117 .into(),
118 },
119 _ => return syn::Error::new_spanned(&input, "#[derive(Service)] only supports structs")
120 .to_compile_error()
121 .into(),
122 };
123
124 let field_inits = fields.iter().map(|f| {
125 let field_name = f.ident.as_ref().unwrap();
126 quote! { #field_name: app.extract() }
127 });
128
129 let expanded = quote! {
130 impl ::archy::ServiceFactory for #name {
131 fn create(app: &::archy::App) -> Self {
132 #name {
133 #(#field_inits),*
134 }
135 }
136 }
137 };
138
139 TokenStream::from(expanded)
140}
141
142#[proc_macro_attribute]
182pub fn service(attr: TokenStream, item: TokenStream) -> TokenStream {
183 let service_traced = if attr.is_empty() {
185 false
186 } else {
187 match syn::parse::<syn::Ident>(attr.clone()) {
188 Ok(ident) if ident == "traced" => true,
189 Ok(ident) => return syn::Error::new(
190 ident.span(),
191 format!("expected `traced`, found `{}`", ident)
192 ).to_compile_error().into(),
193 Err(e) => return e.to_compile_error().into(),
194 }
195 };
196
197 let input = parse_macro_input!(item as ItemImpl);
198 let service_name = match &*input.self_ty {
199 syn::Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
200 _ => return syn::Error::new_spanned(&input.self_ty, "#[service] must be applied to an impl block for a named type")
201 .to_compile_error()
202 .into(),
203 };
204
205 let msg_enum_name = format_ident!("{}Msg", service_name);
206 let methods_struct_name = format_ident!("{}Methods", service_name);
207
208 let mut methods: Vec<(&ImplItemFn, bool)> = Vec::new();
211 let mut startup_method: Option<&syn::Ident> = None;
212 let mut shutdown_method: Option<&syn::Ident> = None;
213
214 for item in &input.items {
215 if let ImplItem::Fn(method) = item {
216 let is_async = method.sig.asyncness.is_some();
217 let has_self = method.sig.inputs.first().map_or(false, |arg| matches!(arg, FnArg::Receiver(_)));
218
219 let has_startup = has_attribute(&method.attrs, "startup");
221 let has_shutdown = has_attribute(&method.attrs, "shutdown");
222
223 if has_startup {
225 if !is_async || !has_self {
226 return syn::Error::new_spanned(
227 &method.sig.ident,
228 "#[startup] method must be async fn(&self)"
229 ).to_compile_error().into();
230 }
231 if method.sig.inputs.len() > 1 {
232 return syn::Error::new_spanned(
233 &method.sig.ident,
234 "#[startup] method cannot have parameters other than &self"
235 ).to_compile_error().into();
236 }
237 if startup_method.is_some() {
238 return syn::Error::new_spanned(
239 &method.sig.ident,
240 "only one #[startup] method allowed per service"
241 ).to_compile_error().into();
242 }
243 startup_method = Some(&method.sig.ident);
244 continue; }
246
247 if has_shutdown {
248 if !is_async || !has_self {
249 return syn::Error::new_spanned(
250 &method.sig.ident,
251 "#[shutdown] method must be async fn(&self)"
252 ).to_compile_error().into();
253 }
254 if method.sig.inputs.len() > 1 {
255 return syn::Error::new_spanned(
256 &method.sig.ident,
257 "#[shutdown] method cannot have parameters other than &self"
258 ).to_compile_error().into();
259 }
260 if shutdown_method.is_some() {
261 return syn::Error::new_spanned(
262 &method.sig.ident,
263 "only one #[shutdown] method allowed per service"
264 ).to_compile_error().into();
265 }
266 shutdown_method = Some(&method.sig.ident);
267 continue; }
269
270 let is_pub = matches!(method.vis, syn::Visibility::Public(_));
271
272 if is_pub && is_async && has_self {
273 let has_traced = has_attribute(&method.attrs, "traced");
275 let has_untraced = has_attribute(&method.attrs, "untraced");
276
277 if has_traced && has_untraced {
279 return syn::Error::new_spanned(
280 &method.sig.ident,
281 "method cannot have both #[traced] and #[untraced] attributes"
282 ).to_compile_error().into();
283 }
284
285 let method_traced = if has_untraced {
290 false
291 } else if has_traced {
292 true
293 } else {
294 service_traced
295 };
296
297 methods.push((method, method_traced));
298 }
299 }
300 }
301
302 let msg_variants = methods.iter().map(|(method, traced)| {
304 let method_name = &method.sig.ident;
305 let variant_name = to_pascal_case(&method_name.to_string());
306 let variant_ident = format_ident!("{}", variant_name);
307
308 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
310 if let FnArg::Typed(pat_type) = arg {
311 if let Pat::Ident(pat_ident) = &*pat_type.pat {
312 let name = &pat_ident.ident;
313 let ty = &pat_type.ty;
314 return Some(quote! { #name: #ty });
315 }
316 }
317 None
318 }).collect();
319
320 let span_field = if *traced {
322 quote! { span: ::archy::tracing::Span, }
323 } else {
324 quote! {}
325 };
326
327 if is_unit_return(&method.sig.output) {
329 if params.is_empty() && !*traced {
330 quote! { #variant_ident }
331 } else {
332 quote! { #variant_ident { #(#params,)* #span_field } }
333 }
334 } else {
335 let return_type = match &method.sig.output {
336 ReturnType::Default => quote! { () },
337 ReturnType::Type(_, ty) => quote! { #ty },
338 };
339 quote! {
340 #variant_ident { #(#params,)* #span_field respond: ::archy::tokio::sync::oneshot::Sender<#return_type> }
341 }
342 }
343 });
344
345 let match_arms = methods.iter().map(|(method, traced)| {
347 let method_name = &method.sig.ident;
348 let variant_name = to_pascal_case(&method_name.to_string());
349 let variant_ident = format_ident!("{}", variant_name);
350
351 let param_names: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
353 if let FnArg::Typed(pat_type) = arg {
354 if let Pat::Ident(pat_ident) = &*pat_type.pat {
355 return Some(&pat_ident.ident);
356 }
357 }
358 None
359 }).collect();
360
361 let method_call = if param_names.is_empty() {
362 quote! { self.#method_name().await }
363 } else {
364 quote! { self.#method_name(#(#param_names),*).await }
365 };
366
367 if is_unit_return(&method.sig.output) {
369 let span_pattern = if *traced { quote! { span, } } else { quote! {} };
370 let param_pattern = if param_names.is_empty() {
371 quote! { #span_pattern }
372 } else {
373 quote! { #(#param_names,)* #span_pattern }
374 };
375
376 if *traced {
377 quote! {
378 #msg_enum_name::#variant_ident { #param_pattern } => {
379 ::archy::tracing::Instrument::instrument(async {
380 #method_call;
381 }, span).await
382 }
383 }
384 } else {
385 quote! {
386 #msg_enum_name::#variant_ident { #param_pattern } => {
387 #method_call;
388 }
389 }
390 }
391 } else {
392 let span_pattern = if *traced { quote! { span, } } else { quote! {} };
393 let param_pattern = if param_names.is_empty() {
394 quote! { #span_pattern respond }
395 } else {
396 quote! { #(#param_names,)* #span_pattern respond }
397 };
398
399 if *traced {
400 quote! {
401 #msg_enum_name::#variant_ident { #param_pattern } => {
402 ::archy::tracing::Instrument::instrument(async {
403 let result = #method_call;
404 let _ = respond.send(result);
405 }, span).await
406 }
407 }
408 } else {
409 quote! {
410 #msg_enum_name::#variant_ident { #param_pattern } => {
411 let result = #method_call;
412 let _ = respond.send(result);
413 }
414 }
415 }
416 }
417 });
418
419 let client_inherent_methods = methods.iter().map(|(method, traced)| {
421 let method_name = &method.sig.ident;
422 let variant_name = to_pascal_case(&method_name.to_string());
423 let variant_ident = format_ident!("{}", variant_name);
424
425 let params: Vec<_> = method.sig.inputs.iter().skip(1).filter_map(|arg| {
427 if let FnArg::Typed(pat_type) = arg {
428 if let Pat::Ident(pat_ident) = &*pat_type.pat {
429 let name = &pat_ident.ident;
430 let ty = &pat_type.ty;
431 return Some((name.clone(), quote! { #ty }));
432 }
433 }
434 None
435 }).collect();
436
437 let param_decls: Vec<_> = params.iter().map(|(name, ty)| quote! { #name: #ty }).collect();
438 let param_names: Vec<_> = params.iter().map(|(name, _)| name).collect();
439
440 let return_type = match &method.sig.output {
442 ReturnType::Default => quote! { () },
443 ReturnType::Type(_, ty) => quote! { #ty },
444 };
445
446 let span_capture = if *traced {
448 quote! { let span = ::archy::tracing::Span::current(); }
449 } else {
450 quote! {}
451 };
452 let span_field = if *traced {
453 quote! { span, }
454 } else {
455 quote! {}
456 };
457
458 if is_unit_return(&method.sig.output) {
460 let msg_construction = if param_names.is_empty() && !*traced {
461 quote! { #msg_enum_name::#variant_ident }
462 } else {
463 quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field } }
464 };
465
466 quote! {
467 pub async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
468 #span_capture
469 self.sender.send(#msg_construction).await
470 .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
471 Ok(())
472 }
473 }
474 } else {
475 let msg_construction = if param_names.is_empty() {
476 quote! { #msg_enum_name::#variant_ident { #span_field respond: tx } }
477 } else {
478 quote! { #msg_enum_name::#variant_ident { #(#param_names,)* #span_field respond: tx } }
479 };
480
481 quote! {
482 pub async fn #method_name(&self, #(#param_decls),*) -> ::std::result::Result<#return_type, ::archy::ServiceError> {
483 #span_capture
484 let (tx, rx) = ::archy::tokio::sync::oneshot::channel();
485 self.sender.send(#msg_construction).await
486 .map_err(|_| ::archy::ServiceError::ChannelClosed)?;
487 rx.await.map_err(|_| ::archy::ServiceError::ServiceDropped)
488 }
489 }
490 }
491 });
492
493 let startup_impl = startup_method.map(|method_name| {
495 quote! {
496 fn startup(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
497 async move { self.#method_name().await }
498 }
499 }
500 });
501
502 let shutdown_impl = shutdown_method.map(|method_name| {
504 quote! {
505 fn shutdown(self: ::std::sync::Arc<Self>) -> impl ::std::future::Future<Output = ()> + Send {
506 async move { self.#method_name().await }
507 }
508 }
509 });
510
511 let expanded = quote! {
512 #input
514
515 pub enum #msg_enum_name {
517 #(#msg_variants),*
518 }
519
520 #[derive(Clone)]
522 pub struct #methods_struct_name {
523 sender: ::archy::async_channel::Sender<#msg_enum_name>,
524 }
525
526 impl ::archy::ClientMethods<#service_name> for #methods_struct_name {
528 fn from_sender(sender: ::archy::async_channel::Sender<#msg_enum_name>) -> Self {
529 Self { sender }
530 }
531 }
532
533 impl #methods_struct_name {
535 #(#client_inherent_methods)*
536 }
537
538 impl ::archy::Service for #service_name {
540 type Message = #msg_enum_name;
541 type ClientMethods = #methods_struct_name;
542
543 fn create(app: &::archy::App) -> Self {
544 <Self as ::archy::ServiceFactory>::create(app)
545 }
546
547 #startup_impl
548
549 fn handle(self: ::std::sync::Arc<Self>, msg: Self::Message) -> impl ::std::future::Future<Output = ()> + Send {
550 async move {
551 match msg {
552 #(#match_arms)*
553 }
554 }
555 }
556
557 #shutdown_impl
558 }
559 };
560
561 TokenStream::from(expanded)
562}
563
/// Converts a `snake_case` name into `PascalCase`.
///
/// Underscores are removed. The first character of every underscore-separated segment is
/// ASCII-uppercased and the rest is kept verbatim, so `get_user` becomes `GetUser` and
/// `mixedCase_x` becomes `MixedCaseX`. Leading, trailing, and repeated underscores produce
/// no output.
fn to_pascal_case(s: &str) -> String {
    s.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|first| first.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}
579
580fn has_attribute(attrs: &[syn::Attribute], name: &str) -> bool {
582 attrs.iter().any(|attr| attr.path().is_ident(name))
583}
584
585fn is_unit_return(output: &ReturnType) -> bool {
587 match output {
588 ReturnType::Default => true,
589 ReturnType::Type(_, ty) => {
590 if let syn::Type::Tuple(tuple) = &**ty {
591 tuple.elems.is_empty()
592 } else {
593 false
594 }
595 }
596 }
597}