ab_code_gen/
actor.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3
4use crate::{ActorModule, MessageHandlerMethod};
5
6/// Actor code generator
7pub struct Actor<'a> {
8    pub(crate) ident: &'a syn::Ident,
9    pub(crate) generic_params: Option<&'a syn::Generics>,
10    pub(crate) handler_methods: Vec<MessageHandlerMethod<'a>>,
11    pub(crate) msg_chan_size: usize,
12    pub(crate) task_chan_size: usize,
13    pub(crate) events_chan_size: usize,
14}
15
16impl<'a> Actor<'a> {
17    pub fn generate(&self, module: &ActorModule) -> TokenStream {
18        let Self {
19            ident: struct_name,
20            msg_chan_size,
21            task_chan_size,
22            events_chan_size,
23            generic_params,
24            ..
25        } = self;
26        let proxy_ident = &module.proxy.name;
27        let events = module.events.as_ref();
28        let messages_enum_name = &module.message_enum.name.to_token_stream();
29        let message_dispatcher_method = self.generate_dispatcher_method(messages_enum_name);
30        let generic_params: Vec<_> = generic_params.iter().collect();
31        let proxy = quote! { #proxy_ident #(#generic_params)*};
32
33        let (events1, events2, events3) = match events {
34            Some(events) => (
35                quote::quote! { let (event_sender, _) = tokio::sync::broadcast::channel::<#events>(#events_chan_size);
36                let event_sender_clone = event_sender.clone();                            },
37                quote::quote! { ,event_sender_clone                                                       },
38                quote::quote! { events: event_sender,                                                     },
39            ),
40            None => (TokenStream::new(), TokenStream::new(), TokenStream::new()),
41        };
42
43        quote::quote! {
44            impl #(#generic_params)* #struct_name #(#generic_params)*{
45                pub fn run(self) -> #proxy {
46                    let (msg_sender, mut msg_receiver) = tokio::sync::mpsc::channel(#msg_chan_size);
47                    #events1
48                    let (stop_sender, stop_receiver) = tokio::sync::oneshot::channel::<()>();
49                    let (task_sender, mut task_receiver) = tokio::sync::mpsc::channel::<Task<#struct_name>>(#task_chan_size);
50                    tokio::spawn(async move {
51                        let mut actor = self;
52                        actor.start(task_sender  #events2 ).await;
53                        tokio::select! {
54                            _ = actor.select_receivers(&mut msg_receiver, &mut task_receiver) => { log::debug!("(abcgen) all proxies dropped"); }  // all proxies were dropped => shutdown
55                            _ = stop_receiver => { log::debug!("(abcgen) stop signal received"); } // stop signal received => shutdown
56                        }
57                        // we get here when the actor is done
58                        actor.shutdown().await;
59                    });
60
61                    // build the proxy
62                    let proxy = #proxy {
63                        message_sender: msg_sender,
64                        stop_signal: Some(stop_sender),
65                        #events3
66                    };
67
68                    proxy
69                }
70
71                async fn select_receivers(
72                    &mut self,
73                    msg_receiver: &mut tokio::sync::mpsc::Receiver<#messages_enum_name>,
74                    task_receiver: &mut tokio::sync::mpsc::Receiver<Task<#struct_name>>,
75                ) {
76                    loop {
77                        tokio::select! {
78                            msg = msg_receiver.recv() => {
79                                match msg {
80                                    Some(msg) => { self.dispatch(msg).await; }
81                                    None => { break; } // channel closed => shutdown
82                                }
83                            },
84                            task = task_receiver.recv() => {
85                                if let Some(task) = task {
86                                    task(self).await;
87                                }
88                            }
89                        }
90                    }
91                }
92
93                #message_dispatcher_method
94            }
95
96        }
97    }
98
99    fn generate_dispatcher_method(&self, messages_id: &TokenStream) -> TokenStream {
100        let patterns = self
101            .handler_methods
102            .iter()
103            .map(|m| self.generate_message_handler_case(m, messages_id));
104
105        quote::quote! {
106            async fn dispatch(&mut self, message: #messages_id) {
107                match message {
108                    #(#patterns)*
109                }
110            }
111        }
112    }
113
114    fn generate_message_handler_case(
115        &self,
116        method: &MessageHandlerMethod,
117        messages_id: &TokenStream,
118    ) -> TokenStream {
119        let method_name = method.get_name_snake_case();
120        let variant_name = method.get_name_camel_case();
121
122        let method_params_names: Vec<_> = method.get_parameter_names();
123        if method.has_return_type() {
124            quote::quote! {
125                #messages_id::#variant_name { #(#method_params_names,)* respond_to } => {
126                    let result = self.#method_name(#(#method_params_names),*).await;
127                    respond_to.send(result).unwrap();
128                }
129            }
130        } else {
131            quote::quote! {
132                #messages_id::#variant_name { #(#method_params_names),* } => {
133                    self.#method_name(#(#method_params_names),*).await;
134                }
135            }
136        }
137    }
138}