ab_code_gen/
module.rs

1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use syn::{spanned::Spanned, Error, Ident, ImplItem, Item, ItemMod, Result, Type};
4
5use crate::{
6    utils::type_path_from_type, Actor, ActorProxy, Config, MessageEnum, MessageHandlerMethod,
7};
8
9pub struct ActorModule<'a> {
10    pub(crate) actor: Actor<'a>,
11    pub(crate) events: Option<&'a Ident>,
12    pub(crate) message_enum: MessageEnum<'a>,
13    pub(crate) proxy: ActorProxy<'a>,
14}
15
16impl ActorModule<'_> {
17    pub fn new<'a>(module: &'a ItemMod, config: &'a Config) -> Result<ActorModule<'a>> {
18        let module_items = &module
19            .content
20            .as_ref()
21            .ok_or_else(|| Error::new_spanned(module, "Expected module to have content"))?
22            .1;
23        // find the actor
24        let actors = module_items
25            .iter()
26            .filter(|item| is_actor(item))
27            .collect::<Vec<_>>();
28
29        if actors.len() != 1 {
30            return Err(Error::new_spanned(
31                module,
32                "Expected exactly one actor, that is one struct or enum with #[actor] attribute.",
33            ));
34        }
35        let actor = actors[0];
36        let actor_id = match actors[0] {
37            Item::Struct(it) => &it.ident,
38            Item::Enum(it) => &it.ident,
39            _ => unreachable!(),
40        };
41
42        let actor_implementations_items = module_items
43            .iter()
44            .filter_map(|item| is_impl_of(item, actor_id))
45            .flat_map(|item| item.items.iter())
46            .collect::<Vec<_>>();
47
48        // check if there are any events
49        let events = extract_events_enum(module_items)?;
50        if events.len() > 1 {
51            return Err(Error::new_spanned(
52                events.last().unwrap(),
53                "Expected at most one events enum",
54            ));
55        }
56        let events = events.into_iter().next();
57
58        // validate the start method
59        validate_start_method(&actor_implementations_items, &events, actor.span())?;
60        // validate the shutdown method
61        validate_shutdown_method(&actor_implementations_items, actor.span())?;
62
63        // find handler methods
64        let mut methods: Vec<MessageHandlerMethod<'a>> = Vec::new();
65        for item in actor_implementations_items {
66            if let syn::ImplItem::Fn(m) = item {
67                if m.attrs.iter().any(|a| test_attribute(a, "message_handler")) {
68                    methods.push(MessageHandlerMethod::new(m)?);
69                }
70            }
71        }
72
73        // todo find all of the implementations of trait From<abcgen::AbcgenError>
74        let convertible_errors = find_convertible_error_types(module_items);
75        let convertible_errors = convertible_errors
76            .iter()
77            .map(|t| type_path_from_type(t).unwrap())
78            .collect::<Vec<_>>();
79
80        // build generators
81        let msg_generator = MessageEnum::new(quote::format_ident!("{actor_id}Message"), &methods)?;
82        let proxy = ActorProxy::new(
83            quote::format_ident!("{actor_id}Proxy"),
84            msg_generator.name.clone(),
85            events,
86            &methods,
87            convertible_errors,
88        );
89        // build actor
90        let actor = Actor {
91            ident: actor_id,
92            generic_params: None,
93            handler_methods: methods.clone(),
94            msg_chan_size: config.channels_size,
95            task_chan_size: config.channels_size,
96            events_chan_size: config.events_chan_size,
97        };
98        Ok(ActorModule {
99            actor,
100            events: events.into_iter().next(),
101            message_enum: msg_generator,
102            proxy,
103        })
104    }
105
106    pub fn generate(&self) -> Result<TokenStream> {
107        let struct_name = self.actor.ident;
108        let event_sender_alias = match self.events.as_ref() {
109            Some(events) => {
110                quote::quote! { type EventSender = tokio::sync::broadcast::Sender<#events>;               }
111            }
112
113            None => TokenStream::new(),
114        };
115
116        let message_enum_code = self.message_enum.generate()?;
117        let proxy_code = self.proxy.generate();
118        let actor = self.actor.generate(self);
119
120        let code = quote::quote! {
121            #event_sender_alias
122            pub type TaskSender = tokio::sync::mpsc::Sender<Task<#struct_name>>;
123            #message_enum_code
124            #actor
125            #proxy_code
126            //#actor_impl
127        };
128
129        Ok(code)
130    }
131}
132
133fn find_convertible_error_types(module_items: &Vec<Item>) -> Vec<&Type> {
134    let possible_types = [
135        quote::quote! { From<::abcgen::AbcgenError> }.to_string(),
136        quote::quote! { From<abcgen::AbcgenError> }.to_string(),
137        quote::quote! { From<AbcgenError> }.to_string(),
138    ];
139    let mut res = Vec::new();
140    for item in module_items {
141        if let Item::Impl(impl_item) = item {
142            if let Some((_, t_trait, _)) = impl_item.trait_.as_ref() {
143                if possible_types
144                    .iter()
145                    .any(|t| *t == t_trait.to_token_stream().to_string())
146                {
147                    res.push(impl_item.self_ty.as_ref());
148                }
149            }
150        }
151    }
152    res
153}
154
155fn validate_start_method(
156    actor_implementations_items: &Vec<&ImplItem>,
157    events: &Option<&Ident>,
158    span: Span,
159) -> Result<()> {
160    let start_method = actor_implementations_items
161        .iter()
162        .filter_map(|item| {
163            if let syn::ImplItem::Fn(m) = item {
164                if m.sig.ident == "start" {
165                    return Some(m);
166                }
167            }
168            None
169        })
170        .collect::<Vec<_>>();
171
172    let the_error_msg = if events.is_some() {
173        "Expected a start method to be implemented for the actor: `fn start(&mut self, task_sender: TaskSender, event_sender: EventSender)`"
174    } else {
175        "Expected a start method to be implemented for the actor: `fn start(&mut self, task_sender: TaskSender)`"
176    };
177
178    if start_method.len() != 1 {
179        return Err(Error::new(span.clone(), the_error_msg));
180    }
181    // start method found
182    let start_method = start_method[0];
183    let the_error = Error::new_spanned(start_method, the_error_msg);
184    // check the receiver input
185    let receiver_ok = start_method
186        .sig
187        .receiver()
188        .is_some_and(|r| r.reference.is_some());
189    if !receiver_ok {
190        return Err(the_error);
191    }
192    // check the number of parameters and their types
193    let other_inputs = start_method.sig.inputs.iter().skip(1).collect::<Vec<_>>();
194    if events.is_some() {
195        if other_inputs.len() != 2 {
196            return Err(the_error);
197        }
198        if !check_argument_type(&other_inputs[1], "EventSender")
199            || !check_argument_type(&other_inputs[0], "TaskSender")
200        {
201            return Err(the_error);
202        }
203    } else {
204        if other_inputs.len() != 1 {
205            return Err(the_error);
206        }
207        if !check_argument_type(&other_inputs[0], "TaskSender") {
208            return Err(the_error);
209        }
210    }
211    Ok(())
212}
213
214fn validate_shutdown_method(actor_implementations: &Vec<&ImplItem>, span: Span) -> Result<()> {
215    let the_method = actor_implementations
216        .iter()
217        .filter_map(|item| {
218            if let syn::ImplItem::Fn(m) = item {
219                if m.sig.ident == "shutdown" {
220                    return Some(m);
221                }
222            }
223            None
224        })
225        .collect::<Vec<_>>();
226
227    let the_error_msg =
228        "Expected a shutdown method to be implemented for the actor: `fn shutdown(&mut self)`";
229    if the_method.len() != 1 {
230        return Err(Error::new(span, the_error_msg));
231    }
232    // start method found
233    let the_method = the_method[0];
234    let the_error = Error::new_spanned(the_method, the_error_msg);
235    // check the receiver input
236    let receiver_ok = the_method
237        .sig
238        .receiver()
239        .is_some_and(|r| r.reference.is_some());
240    if !receiver_ok {
241        return Err(the_error);
242    }
243    // check the number of parameters and their types
244    let other_inputs = the_method.sig.inputs.iter().skip(1).collect::<Vec<_>>();
245    if !other_inputs.is_empty() {
246        return Err(the_error);
247    }
248    Ok(())
249}
250
251fn check_argument_type(other_inputs: &syn::FnArg, expected_type: &str) -> bool {
252    match other_inputs {
253        syn::FnArg::Typed(t) => {
254            if let Type::Path(tp) = t.ty.as_ref() {
255                if tp.path.segments.last().unwrap().ident != expected_type {
256                    false
257                } else {
258                    true
259                }
260            } else {
261                false
262            }
263        }
264        _ => false,
265    }
266}
267
268fn test_attribute(attr: &syn::Attribute, expected: &str) -> bool {
269    attr.path().segments.last().unwrap().ident == expected
270}
271
272fn is_actor(item: &Item) -> bool {
273    let attribututes = match item {
274        Item::Struct(it) => &it.attrs,
275        Item::Enum(it) => &it.attrs,
276        _ => return false,
277    };
278    attribututes
279        .iter()
280        .any(|attr| test_attribute(attr, "actor"))
281}
282
283fn is_impl_of<'a>(item: &'a Item, id: &'a Ident) -> Option<&'a syn::ItemImpl> {
284    if let Item::Impl(item_impl) = item {
285        if let Type::Path(tp) = item_impl.self_ty.as_ref() {
286            if tp.path.segments.last().unwrap().ident == *id {
287                return Some(item_impl);
288            }
289        }
290    }
291    None
292}
293
294fn extract_events_enum(items: &[syn::Item]) -> Result<Vec<&Ident>> {
295    let events: Vec<_> = items
296        .iter()
297        .filter_map(|i| {
298            match i {
299                Item::Enum(e) => {
300                    if e.attrs.iter().any(|a| test_attribute(a, "events")) {
301                        return Some(&e.ident);
302                    }
303                }
304                Item::Type(t) => {
305                    if t.attrs.iter().any(|a| test_attribute(a, "events")) {
306                        return Some(&t.ident);
307                    }
308                }
309                Item::Struct(s) => {
310                    if s.attrs.iter().any(|a| test_attribute(a, "events")) {
311                        return Some(&s.ident);
312                    }
313                }
314                _ => {}
315            };
316            None
317        })
318        .collect();
319    Ok(events)
320}