1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use syn::{spanned::Spanned, Error, Ident, ImplItem, Item, ItemMod, Result, Type};
4
5use crate::{
6 utils::type_path_from_type, Actor, ActorProxy, Config, MessageEnum, MessageHandlerMethod,
7};
8
9pub struct ActorModule<'a> {
10 pub(crate) actor: Actor<'a>,
11 pub(crate) events: Option<&'a Ident>,
12 pub(crate) message_enum: MessageEnum<'a>,
13 pub(crate) proxy: ActorProxy<'a>,
14}
15
16impl ActorModule<'_> {
17 pub fn new<'a>(module: &'a ItemMod, config: &'a Config) -> Result<ActorModule<'a>> {
18 let module_items = &module
19 .content
20 .as_ref()
21 .ok_or_else(|| Error::new_spanned(module, "Expected module to have content"))?
22 .1;
23 let actors = module_items
25 .iter()
26 .filter(|item| is_actor(item))
27 .collect::<Vec<_>>();
28
29 if actors.len() != 1 {
30 return Err(Error::new_spanned(
31 module,
32 "Expected exactly one actor, that is one struct or enum with #[actor] attribute.",
33 ));
34 }
35 let actor = actors[0];
36 let actor_id = match actors[0] {
37 Item::Struct(it) => &it.ident,
38 Item::Enum(it) => &it.ident,
39 _ => unreachable!(),
40 };
41
42 let actor_implementations_items = module_items
43 .iter()
44 .filter_map(|item| is_impl_of(item, actor_id))
45 .flat_map(|item| item.items.iter())
46 .collect::<Vec<_>>();
47
48 let events = extract_events_enum(module_items)?;
50 if events.len() > 1 {
51 return Err(Error::new_spanned(
52 events.last().unwrap(),
53 "Expected at most one events enum",
54 ));
55 }
56 let events = events.into_iter().next();
57
58 validate_start_method(&actor_implementations_items, &events, actor.span())?;
60 validate_shutdown_method(&actor_implementations_items, actor.span())?;
62
63 let mut methods: Vec<MessageHandlerMethod<'a>> = Vec::new();
65 for item in actor_implementations_items {
66 if let syn::ImplItem::Fn(m) = item {
67 if m.attrs.iter().any(|a| test_attribute(a, "message_handler")) {
68 methods.push(MessageHandlerMethod::new(m)?);
69 }
70 }
71 }
72
73 let convertible_errors = find_convertible_error_types(module_items);
75 let convertible_errors = convertible_errors
76 .iter()
77 .map(|t| type_path_from_type(t).unwrap())
78 .collect::<Vec<_>>();
79
80 let msg_generator = MessageEnum::new(quote::format_ident!("{actor_id}Message"), &methods)?;
82 let proxy = ActorProxy::new(
83 quote::format_ident!("{actor_id}Proxy"),
84 msg_generator.name.clone(),
85 events,
86 &methods,
87 convertible_errors,
88 );
89 let actor = Actor {
91 ident: actor_id,
92 generic_params: None,
93 handler_methods: methods.clone(),
94 msg_chan_size: config.channels_size,
95 task_chan_size: config.channels_size,
96 events_chan_size: config.events_chan_size,
97 };
98 Ok(ActorModule {
99 actor,
100 events: events.into_iter().next(),
101 message_enum: msg_generator,
102 proxy,
103 })
104 }
105
106 pub fn generate(&self) -> Result<TokenStream> {
107 let struct_name = self.actor.ident;
108 let event_sender_alias = match self.events.as_ref() {
109 Some(events) => {
110 quote::quote! { type EventSender = tokio::sync::broadcast::Sender<#events>; }
111 }
112
113 None => TokenStream::new(),
114 };
115
116 let message_enum_code = self.message_enum.generate()?;
117 let proxy_code = self.proxy.generate();
118 let actor = self.actor.generate(self);
119
120 let code = quote::quote! {
121 #event_sender_alias
122 pub type TaskSender = tokio::sync::mpsc::Sender<Task<#struct_name>>;
123 #message_enum_code
124 #actor
125 #proxy_code
126 };
128
129 Ok(code)
130 }
131}
132
133fn find_convertible_error_types(module_items: &Vec<Item>) -> Vec<&Type> {
134 let possible_types = [
135 quote::quote! { From<::abcgen::AbcgenError> }.to_string(),
136 quote::quote! { From<abcgen::AbcgenError> }.to_string(),
137 quote::quote! { From<AbcgenError> }.to_string(),
138 ];
139 let mut res = Vec::new();
140 for item in module_items {
141 if let Item::Impl(impl_item) = item {
142 if let Some((_, t_trait, _)) = impl_item.trait_.as_ref() {
143 if possible_types
144 .iter()
145 .any(|t| *t == t_trait.to_token_stream().to_string())
146 {
147 res.push(impl_item.self_ty.as_ref());
148 }
149 }
150 }
151 }
152 res
153}
154
155fn validate_start_method(
156 actor_implementations_items: &Vec<&ImplItem>,
157 events: &Option<&Ident>,
158 span: Span,
159) -> Result<()> {
160 let start_method = actor_implementations_items
161 .iter()
162 .filter_map(|item| {
163 if let syn::ImplItem::Fn(m) = item {
164 if m.sig.ident == "start" {
165 return Some(m);
166 }
167 }
168 None
169 })
170 .collect::<Vec<_>>();
171
172 let the_error_msg = if events.is_some() {
173 "Expected a start method to be implemented for the actor: `fn start(&mut self, task_sender: TaskSender, event_sender: EventSender)`"
174 } else {
175 "Expected a start method to be implemented for the actor: `fn start(&mut self, task_sender: TaskSender)`"
176 };
177
178 if start_method.len() != 1 {
179 return Err(Error::new(span.clone(), the_error_msg));
180 }
181 let start_method = start_method[0];
183 let the_error = Error::new_spanned(start_method, the_error_msg);
184 let receiver_ok = start_method
186 .sig
187 .receiver()
188 .is_some_and(|r| r.reference.is_some());
189 if !receiver_ok {
190 return Err(the_error);
191 }
192 let other_inputs = start_method.sig.inputs.iter().skip(1).collect::<Vec<_>>();
194 if events.is_some() {
195 if other_inputs.len() != 2 {
196 return Err(the_error);
197 }
198 if !check_argument_type(&other_inputs[1], "EventSender")
199 || !check_argument_type(&other_inputs[0], "TaskSender")
200 {
201 return Err(the_error);
202 }
203 } else {
204 if other_inputs.len() != 1 {
205 return Err(the_error);
206 }
207 if !check_argument_type(&other_inputs[0], "TaskSender") {
208 return Err(the_error);
209 }
210 }
211 Ok(())
212}
213
214fn validate_shutdown_method(actor_implementations: &Vec<&ImplItem>, span: Span) -> Result<()> {
215 let the_method = actor_implementations
216 .iter()
217 .filter_map(|item| {
218 if let syn::ImplItem::Fn(m) = item {
219 if m.sig.ident == "shutdown" {
220 return Some(m);
221 }
222 }
223 None
224 })
225 .collect::<Vec<_>>();
226
227 let the_error_msg =
228 "Expected a shutdown method to be implemented for the actor: `fn shutdown(&mut self)`";
229 if the_method.len() != 1 {
230 return Err(Error::new(span, the_error_msg));
231 }
232 let the_method = the_method[0];
234 let the_error = Error::new_spanned(the_method, the_error_msg);
235 let receiver_ok = the_method
237 .sig
238 .receiver()
239 .is_some_and(|r| r.reference.is_some());
240 if !receiver_ok {
241 return Err(the_error);
242 }
243 let other_inputs = the_method.sig.inputs.iter().skip(1).collect::<Vec<_>>();
245 if !other_inputs.is_empty() {
246 return Err(the_error);
247 }
248 Ok(())
249}
250
251fn check_argument_type(other_inputs: &syn::FnArg, expected_type: &str) -> bool {
252 match other_inputs {
253 syn::FnArg::Typed(t) => {
254 if let Type::Path(tp) = t.ty.as_ref() {
255 if tp.path.segments.last().unwrap().ident != expected_type {
256 false
257 } else {
258 true
259 }
260 } else {
261 false
262 }
263 }
264 _ => false,
265 }
266}
267
268fn test_attribute(attr: &syn::Attribute, expected: &str) -> bool {
269 attr.path().segments.last().unwrap().ident == expected
270}
271
272fn is_actor(item: &Item) -> bool {
273 let attribututes = match item {
274 Item::Struct(it) => &it.attrs,
275 Item::Enum(it) => &it.attrs,
276 _ => return false,
277 };
278 attribututes
279 .iter()
280 .any(|attr| test_attribute(attr, "actor"))
281}
282
283fn is_impl_of<'a>(item: &'a Item, id: &'a Ident) -> Option<&'a syn::ItemImpl> {
284 if let Item::Impl(item_impl) = item {
285 if let Type::Path(tp) = item_impl.self_ty.as_ref() {
286 if tp.path.segments.last().unwrap().ident == *id {
287 return Some(item_impl);
288 }
289 }
290 }
291 None
292}
293
294fn extract_events_enum(items: &[syn::Item]) -> Result<Vec<&Ident>> {
295 let events: Vec<_> = items
296 .iter()
297 .filter_map(|i| {
298 match i {
299 Item::Enum(e) => {
300 if e.attrs.iter().any(|a| test_attribute(a, "events")) {
301 return Some(&e.ident);
302 }
303 }
304 Item::Type(t) => {
305 if t.attrs.iter().any(|a| test_attribute(a, "events")) {
306 return Some(&t.ident);
307 }
308 }
309 Item::Struct(s) => {
310 if s.attrs.iter().any(|a| test_attribute(a, "events")) {
311 return Some(&s.ident);
312 }
313 }
314 _ => {}
315 };
316 None
317 })
318 .collect();
319 Ok(events)
320}