ascolt_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    DeriveInput, FnArg, ItemFn, Pat, PatType, PathArguments, ReturnType, Type, parse_macro_input,
5};
6
7#[proc_macro_attribute]
8pub fn ask_handler(_args: TokenStream, item: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(item as ItemFn);
10
11    let sig = &input.sig;
12    let block = &input.block;
13
14    let fn_name = &sig.ident;
15    let inputs = &sig.inputs;
16    let output = &sig.output;
17
18    let mut actor_ty = None;
19    let mut msg_ty = None;
20
21    for arg in inputs {
22        match arg {
23            FnArg::Receiver(receiver) => actor_ty = Some(receiver.ty.clone()),
24            FnArg::Typed(PatType { pat, ty, .. }) => {
25                if let Pat::Ident(pat_ident) = pat.as_ref() {
26                    let ident = pat_ident.ident.to_string();
27                    if ident.as_str() == "msg" {
28                        msg_ty = Some(ty.clone())
29                    }
30                }
31            }
32        }
33    }
34
35    let actor_ty = actor_ty.expect("Missing self: &Actor argument");
36    let msg_ty = msg_ty.expect("Missing msg argument");
37
38    let clean_actor_ty = strip_reference(&actor_ty);
39    let clean_msg_ty = strip_reference(&msg_ty);
40
41    let (resp_ty, err_ty) = extract_result_types(output);
42
43    let expanded = quote! {
44        #[async_trait::async_trait]
45        impl ascolt::handler::AskHandlerTrait<#clean_msg_ty, #resp_ty, #err_ty> for #clean_actor_ty {
46            async fn #fn_name(
47                self: #actor_ty,
48                msg: #msg_ty,
49            ) -> Result<#resp_ty, #err_ty> {
50                #block
51            }
52        }
53    };
54
55    TokenStream::from(expanded)
56}
57
58#[proc_macro_attribute]
59pub fn tell_handler(_args: TokenStream, item: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(item as ItemFn);
61
62    let sig = &input.sig;
63    let block = &input.block;
64
65    let fn_name = &sig.ident;
66    let inputs = &sig.inputs;
67    let output = &sig.output;
68
69    let mut actor_ty = None;
70    let mut msg_ty = None;
71
72    for arg in inputs {
73        match arg {
74            FnArg::Receiver(receiver) => actor_ty = Some(receiver.ty.clone()),
75            FnArg::Typed(PatType { pat, ty, .. }) => {
76                if let Pat::Ident(pat_ident) = pat.as_ref() {
77                    let ident = pat_ident.ident.to_string();
78                    if ident.as_str() == "msg" {
79                        msg_ty = Some(ty.clone())
80                    }
81                }
82            }
83        }
84    }
85
86    let actor_ty = actor_ty.expect("Missing self: &Actor argument");
87    let msg_ty = msg_ty.expect("Missing msg argument");
88
89    let clean_actor_ty = strip_reference(&actor_ty);
90    let clean_msg_ty = strip_reference(&msg_ty);
91
92    let (_, err_ty) = extract_result_types(output);
93
94    let expanded = quote! {
95        #[async_trait::async_trait]
96        impl ascolt::handler::TellHandlerTrait<#clean_msg_ty, #err_ty> for #clean_actor_ty {
97            async fn #fn_name(
98                self: #actor_ty,
99                msg: #msg_ty,
100            ) -> Result<(), #err_ty> {
101                #block
102            }
103        }
104    };
105
106    TokenStream::from(expanded)
107}
108
109fn strip_reference(ty: &syn::Type) -> &syn::Type {
110    match ty {
111        syn::Type::Reference(r) => strip_reference(&r.elem),
112        _ => ty,
113    }
114}
115
116fn extract_result_types(
117    output: &ReturnType,
118) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
119    match output {
120        ReturnType::Type(_, ty) => {
121            let type_path = match ty.as_ref() {
122                Type::Path(tp) => tp,
123                _ => panic!("Expected a path type (e.g. Result<T, E>)"),
124            };
125
126            let seg = type_path
127                .path
128                .segments
129                .first()
130                .expect("Expected a Result return type");
131
132            if seg.ident != "Result" {
133                panic!("Return type must be Result<T, E>");
134            }
135
136            let args = match &seg.arguments {
137                PathArguments::AngleBracketed(args) => args,
138                _ => panic!("Expected Result<T, E> with angle-bracketed args"),
139            };
140
141            let mut args_iter = args.args.iter();
142            let resp = args_iter
143                .next()
144                .expect("Missing success type in Result<T, E>");
145            let err = args_iter
146                .next()
147                .expect("Missing error type in Result<T, E>");
148
149            (quote!(#resp), quote!(#err))
150        }
151        _ => panic!("Expected function to have a return type"),
152    }
153}
154
155#[proc_macro_derive(Actor, attributes(actor))]
156pub fn derive_actor(input: TokenStream) -> TokenStream {
157    let input = parse_macro_input!(input as DeriveInput);
158    let name = input.ident;
159
160    let mut error_ty = None;
161    for attr in input.attrs.iter().filter(|a| a.path().is_ident("actor")) {
162        attr.parse_nested_meta(|meta| {
163            if meta.path.is_ident("error") {
164                let value: syn::Type = meta.value()?.parse()?;
165                error_ty = Some(value);
166                Ok(())
167            } else {
168                Err(meta.error("unsupported attribute"))
169            }
170        })
171        .unwrap();
172    }
173
174    let error_ty = error_ty.expect("missing #[actor(error = ...)]");
175
176    let expanded = quote! {
177        impl ascolt::ActorTrait<#error_ty> for #name {}
178    };
179
180    TokenStream::from(expanded)
181}