1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 DeriveInput, FnArg, ItemFn, Pat, PatType, PathArguments, ReturnType, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
8pub fn ask_handler(_args: TokenStream, item: TokenStream) -> TokenStream {
9 let input = parse_macro_input!(item as ItemFn);
10
11 let sig = &input.sig;
12 let block = &input.block;
13
14 let fn_name = &sig.ident;
15 let inputs = &sig.inputs;
16 let output = &sig.output;
17
18 let mut actor_ty = None;
19 let mut msg_ty = None;
20
21 for arg in inputs {
22 match arg {
23 FnArg::Receiver(receiver) => actor_ty = Some(receiver.ty.clone()),
24 FnArg::Typed(PatType { pat, ty, .. }) => {
25 if let Pat::Ident(pat_ident) = pat.as_ref() {
26 let ident = pat_ident.ident.to_string();
27 if ident.as_str() == "msg" {
28 msg_ty = Some(ty.clone())
29 }
30 }
31 }
32 }
33 }
34
35 let actor_ty = actor_ty.expect("Missing self: &Actor argument");
36 let msg_ty = msg_ty.expect("Missing msg argument");
37
38 let clean_actor_ty = strip_reference(&actor_ty);
39 let clean_msg_ty = strip_reference(&msg_ty);
40
41 let (resp_ty, err_ty) = extract_result_types(output);
42
43 let expanded = quote! {
44 #[async_trait::async_trait]
45 impl ascolt::handler::AskHandlerTrait<#clean_msg_ty, #resp_ty, #err_ty> for #clean_actor_ty {
46 async fn #fn_name(
47 self: #actor_ty,
48 msg: #msg_ty,
49 ) -> Result<#resp_ty, #err_ty> {
50 #block
51 }
52 }
53 };
54
55 TokenStream::from(expanded)
56}
57
58#[proc_macro_attribute]
59pub fn tell_handler(_args: TokenStream, item: TokenStream) -> TokenStream {
60 let input = parse_macro_input!(item as ItemFn);
61
62 let sig = &input.sig;
63 let block = &input.block;
64
65 let fn_name = &sig.ident;
66 let inputs = &sig.inputs;
67 let output = &sig.output;
68
69 let mut actor_ty = None;
70 let mut msg_ty = None;
71
72 for arg in inputs {
73 match arg {
74 FnArg::Receiver(receiver) => actor_ty = Some(receiver.ty.clone()),
75 FnArg::Typed(PatType { pat, ty, .. }) => {
76 if let Pat::Ident(pat_ident) = pat.as_ref() {
77 let ident = pat_ident.ident.to_string();
78 if ident.as_str() == "msg" {
79 msg_ty = Some(ty.clone())
80 }
81 }
82 }
83 }
84 }
85
86 let actor_ty = actor_ty.expect("Missing self: &Actor argument");
87 let msg_ty = msg_ty.expect("Missing msg argument");
88
89 let clean_actor_ty = strip_reference(&actor_ty);
90 let clean_msg_ty = strip_reference(&msg_ty);
91
92 let (_, err_ty) = extract_result_types(output);
93
94 let expanded = quote! {
95 #[async_trait::async_trait]
96 impl ascolt::handler::TellHandlerTrait<#clean_msg_ty, #err_ty> for #clean_actor_ty {
97 async fn #fn_name(
98 self: #actor_ty,
99 msg: #msg_ty,
100 ) -> Result<(), #err_ty> {
101 #block
102 }
103 }
104 };
105
106 TokenStream::from(expanded)
107}
108
109fn strip_reference(ty: &syn::Type) -> &syn::Type {
110 match ty {
111 syn::Type::Reference(r) => strip_reference(&r.elem),
112 _ => ty,
113 }
114}
115
116fn extract_result_types(
117 output: &ReturnType,
118) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
119 match output {
120 ReturnType::Type(_, ty) => {
121 let type_path = match ty.as_ref() {
122 Type::Path(tp) => tp,
123 _ => panic!("Expected a path type (e.g. Result<T, E>)"),
124 };
125
126 let seg = type_path
127 .path
128 .segments
129 .first()
130 .expect("Expected a Result return type");
131
132 if seg.ident != "Result" {
133 panic!("Return type must be Result<T, E>");
134 }
135
136 let args = match &seg.arguments {
137 PathArguments::AngleBracketed(args) => args,
138 _ => panic!("Expected Result<T, E> with angle-bracketed args"),
139 };
140
141 let mut args_iter = args.args.iter();
142 let resp = args_iter
143 .next()
144 .expect("Missing success type in Result<T, E>");
145 let err = args_iter
146 .next()
147 .expect("Missing error type in Result<T, E>");
148
149 (quote!(#resp), quote!(#err))
150 }
151 _ => panic!("Expected function to have a return type"),
152 }
153}
154
155#[proc_macro_derive(Actor, attributes(actor))]
156pub fn derive_actor(input: TokenStream) -> TokenStream {
157 let input = parse_macro_input!(input as DeriveInput);
158 let name = input.ident;
159
160 let mut error_ty = None;
161 for attr in input.attrs.iter().filter(|a| a.path().is_ident("actor")) {
162 attr.parse_nested_meta(|meta| {
163 if meta.path.is_ident("error") {
164 let value: syn::Type = meta.value()?.parse()?;
165 error_ty = Some(value);
166 Ok(())
167 } else {
168 Err(meta.error("unsupported attribute"))
169 }
170 })
171 .unwrap();
172 }
173
174 let error_ty = error_ty.expect("missing #[actor(error = ...)]");
175
176 let expanded = quote! {
177 impl ascolt::ActorTrait<#error_ty> for #name {}
178 };
179
180 TokenStream::from(expanded)
181}