ab_code_gen/
module.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
use proc_macro2::TokenStream;
use syn::{Error, Ident, Item, ItemMod, Result, Type};

use crate::{ActorProxy, MessageEnum, MessageHandlerMethod};

/// Code-generation model of a single `#[actor]` module.
///
/// Built from the parsed module by [`ActorModule::new`] and turned into tokens by
/// [`ActorModule::generate`].
pub struct ActorModule<'a> {
    /// Name of the struct or enum annotated with `#[actor]`.
    actor_ident: &'a Ident,
    /// Name of the enum annotated with `#[events]`, when the module declares one.
    events: Option<&'a Ident>,
    /// Methods annotated with `#[message_handler]` in the actor's `impl` blocks.
    handler_methods: Vec<MessageHandlerMethod<'a>>,
    /// Generator for the `<Actor>Message` enum.
    message_enum: MessageEnum<'a>,
    /// Generator for the `<Actor>Proxy` type.
    proxy: ActorProxy<'a>,
}

impl ActorModule<'_> {
    pub fn new<'a>(module: &'a ItemMod) -> Result<ActorModule<'a>> {
        let module_items = &module
            .content
            .as_ref()
            .ok_or_else(|| Error::new_spanned(module, "Expected module to have content"))?
            .1;
        // find the actor
        let actors = module_items
            .iter()
            .filter(|item| is_actor(item))
            .collect::<Vec<_>>();

        if actors.len() != 1 {
            return Err(Error::new_spanned(
                module,
                "Expected exactly one actor that is one struct or enum with #[actor] attribute",
            ));
        }
        let actor_id = match actors[0] {
            Item::Struct(it) => &it.ident,
            Item::Enum(it) => &it.ident,
            _ => unreachable!(),
        };

        // find handler methods
        let actor_implementations = module_items
            .iter()
            .filter_map(|item| is_impl_of(item, actor_id))
            .collect::<Vec<_>>();

        let impl_items = actor_implementations
            .iter()
            .flat_map(|item| item.items.iter());
        let mut methods: Vec<MessageHandlerMethod<'a>> = Vec::new();
        for item in impl_items {
            if let syn::ImplItem::Fn(m) = item {
                if m.attrs.iter().any(|a| test_attribute(a, "message_handler")) {
                    methods.push(MessageHandlerMethod::new(m)?);
                }
            }
        }

        // check if there are any events
        let events = extract_events_enum(module_items)?;
        if events.len() > 1 {
            return Err(Error::new_spanned(
                events.last().unwrap(),
                "Expected at most one events enum",
            ));
        }
        let events = events.into_iter().next();
        // build generators
        let msg_generator = MessageEnum::new(quote::format_ident!("{actor_id}Message"), &methods)?;
        let proxy = ActorProxy::new(
            quote::format_ident!("{actor_id}Proxy"),
            msg_generator.name.clone(),
            events,
            &methods,
        );

        Ok(ActorModule {
            actor_ident: actor_id,
            handler_methods: methods.clone(),
            events: events.into_iter().next(),
            message_enum: msg_generator,
            proxy,
        })
    }

    pub fn generate(&self) -> Result<TokenStream> {
        let proxy = &self.proxy.name;
        let message_dispatcher_method = self.generate_dispatcher_method();
        let messages_enum_name = &self.message_enum.name;
        let struct_name = self.actor_ident;
        let (events1, events2, events3) = match self.events.as_ref() {
            Some(events) => (
                quote::quote! { let (event_sender, event_receiver) = tokio::sync::broadcast::channel::<#events>(20);    },
                quote::quote! { events: event_receiver,                                                                 },
                quote::quote! { ,event_sender                                                                           },
            ),
            None => (TokenStream::new(), TokenStream::new(), TokenStream::new()),
        };
        let actor_impl = quote::quote! {
            impl #struct_name {
                pub fn run(self) -> #proxy {
                    let (msg_sender, mut msg_receiver) = tokio::sync::mpsc::channel(20);
                    #events1
                    let (stop_sender, stop_receiver) = tokio::sync::oneshot::channel::<()>();
                    let (task_sender, mut task_receiver) = tokio::sync::mpsc::channel::<Task<#struct_name>>(20);
                    tokio::spawn(async move {
                        let mut actor = self;
                        actor.start(task_sender  #events3 ).await;
                        tokio::select! {
                            _ = actor.select_receivers(&mut msg_receiver, &mut task_receiver) => { log::debug!("(abcgen) all proxies dropped"); }  // all proxies were dropped => shutdown
                            _ = stop_receiver => { log::debug!("(abcgen) stop signal received"); } // stop signal received => shutdown
                        }
                        // we get here when the actor is done
                        actor.shutdown().await;
                    });

                    // build the proxy
                    let proxy = #proxy {
                        message_sender: msg_sender,
                        stop_signal: Some(stop_sender),
                        #events2
                    };

                    proxy
                }

                async fn select_receivers(
                    &mut self,
                    msg_receiver: &mut tokio::sync::mpsc::Receiver<#messages_enum_name>,
                    task_receiver: &mut tokio::sync::mpsc::Receiver<Task<#struct_name>>,
                ) {
                    loop {
                        tokio::select! {
                            msg = msg_receiver.recv() => {
                                match msg {
                                    Some(msg) => { self.dispatch(msg).await; }
                                    None => { break; } // channel closed => shutdown
                                }
                            },
                            task = task_receiver.recv() => {
                                if let Some(task) = task {
                                    task(self).await;
                                }
                            }
                        }
                    }
                }

                #message_dispatcher_method

                /// Helper function to send a task to be invoked in the actor loop
                fn invoke(sender: &tokio::sync::mpsc::Sender<Task<#struct_name>>, task: Task<#struct_name>) -> Result<(), AbcgenError> {
                    sender.try_send(task)
                          .map_err(|e| AbcgenError::ChannelError(Box::new(e)))
                }
                //fn invoke_fn(sender: &tokio::sync::mpsc::Sender<Task<#struct_name>>, f: fn(&mut #struct_name) -> PinnedFuture<()> + Send>) -> Result<(), AbcgenError> {
                //    Self::invoke_task(sender, Box::new(move |actor| f(actor)))
                //}

            }

        };

        let message_enum_code = self.message_enum.generate()?;
        let proxy_code = self.proxy.generate();
        let code = quote::quote! {
            #message_enum_code
            #proxy_code
            #actor_impl
        };

        Ok(code)
    }

    pub fn generate_dispatcher_method(&self) -> TokenStream {
        let message_dispatcher_method = self
            .handler_methods
            .iter()
            .map(|m| self.generate_message_handler_case(m));
        let enum_name = &self.message_enum.name;
        quote::quote! {
            async fn dispatch(&mut self, message: #enum_name) {
                match message {
                    #(#message_dispatcher_method)*
                }
            }
        }
    }
    pub fn generate_message_handler_case(&self, method: &MessageHandlerMethod) -> TokenStream {
        let method_name = method.get_name_snake_case();
        let variant_name = method.get_name_camel_case();
        let enum_name = &self.message_enum.name;

        let method_params_names: Vec<_> = method.get_parameter_names();
        if method.has_return_type() {
            quote::quote! {
                #enum_name::#variant_name { #(#method_params_names),* , respond_to } => {
                    let result = self.#method_name(#(#method_params_names),*).await;
                    respond_to.send(result).unwrap();
                }
            }
        } else {
            quote::quote! {
                #enum_name::#variant_name { #(#method_params_names),* } => {
                    self.#method_name(#(#method_params_names),*).await;
                }
            }
        }
    }
}

/// Returns `true` when the last segment of `attr`'s path is `expected`.
///
/// Comparing only the last segment accepts both `#[actor]` and a qualified form
/// such as `#[some_crate::actor]`. An attribute whose path has no segments never
/// matches; it no longer panics.
fn test_attribute(attr: &syn::Attribute, expected: &str) -> bool {
    attr.path()
        .segments
        .last()
        .map_or(false, |segment| segment.ident == expected)
}

fn is_actor(item: &Item) -> bool {
    let attribututes = match item {
        Item::Struct(it) => &it.attrs,
        Item::Enum(it) => &it.attrs,
        _ => return false,
    };
    attribututes
        .iter()
        .any(|attr| test_attribute(attr, "actor"))
}

/// Returns the `impl` block when `item` is an `impl` whose self type is a path
/// ending in `id` (e.g. `impl Foo`, `impl Trait for crate::Foo`).
///
/// The trait part of the `impl`, if any, is not inspected. Returns `None` in three
/// cases: for any other item, for a self type that is not a path (a reference or a
/// tuple, for example), and for a self-type path with no segments.
fn is_impl_of<'a>(item: &'a Item, id: &Ident) -> Option<&'a syn::ItemImpl> {
    let item_impl = match item {
        Item::Impl(item_impl) => item_impl,
        _ => return None,
    };
    let type_path = match item_impl.self_ty.as_ref() {
        Type::Path(type_path) => type_path,
        _ => return None,
    };
    let last_segment = type_path.path.segments.last()?;
    if last_segment.ident == *id {
        Some(item_impl)
    } else {
        None
    }
}

/// Collects the names of all enums in `items` that carry an `#[events]` attribute.
///
/// The names are returned in declaration order. This function never returns `Err`
/// itself; the caller decides how many events enums are acceptable.
fn extract_events_enum(items: &[syn::Item]) -> Result<Vec<&Ident>> {
    let mut found = Vec::new();
    for item in items {
        let Item::Enum(item_enum) = item else {
            continue;
        };
        if item_enum
            .attrs
            .iter()
            .any(|attr| test_attribute(attr, "events"))
        {
            found.push(&item_enum.ident);
        }
    }
    Ok(found)
}