ab_code_gen/
actor_proxy.rs

1use proc_macro2::TokenStream;
2use syn::{Ident, ReturnType, Type, TypePath};
3
4use crate::{utils::type_path_from_generic_argument, MessageHandlerMethod};
5
6pub struct ActorProxy<'a> {
7    pub name: Ident,
8    methods: Vec<MessageHandlerMethod<'a>>,
9    message_enum: Ident,
10    events_enum: Option<&'a Ident>,
11    convertible_errors: Vec<&'a TypePath>,
12}
13
14impl ActorProxy<'_> {
15    pub fn new<'a>(
16        name: Ident,
17        message_enum: Ident,
18        events_enum: Option<&'a Ident>,
19        methods: &[MessageHandlerMethod<'a>],
20        convertible_errors: Vec<&'a TypePath>,
21    ) -> ActorProxy<'a> {
22        ActorProxy {
23            name,
24            methods: methods.to_vec(),
25            message_enum,
26            events_enum,
27            convertible_errors,
28        }
29    }
30
31    pub fn generate(&self) -> TokenStream {
32        let struct_name = &self.name;
33        let message_enum_name = &self.message_enum;
34        let methods = self
35            .methods
36            .iter()
37            .map(|m| self.generate_method(m))
38            .collect::<Vec<_>>();
39        let events_enum: Vec<_> = self.events_enum.iter().collect();
40
41        // -- def --
42        let struct_def: TokenStream = quote::quote! {
43            pub struct #struct_name {
44                message_sender: tokio::sync::mpsc::Sender<#message_enum_name>,
45                #(events: tokio::sync::broadcast::Sender<#events_enum>,)*
46                stop_signal: std::option::Option<tokio::sync::oneshot::Sender<()>>,
47            }
48        };
49
50        // -- impl --
51        let struct_impl = quote::quote! {
52            impl #struct_name {
53                pub fn is_running(&self) -> bool {
54                    match self.stop_signal.as_ref() {
55                        Some(s) => !s.is_closed(),
56                        None => false,
57                    }
58                }
59
60                pub fn stop(&mut self) -> Result<(), AbcgenError> {
61                    match self.stop_signal.take() {
62                        Some(tx) => tx.send(()).map_err(|_e: ()| AbcgenError::ActorShutDown),
63                        None => Err(AbcgenError::ActorShutDown),
64                    }
65                }
66
67                pub async fn stop_and_wait(&mut self) -> Result<(), AbcgenError> {
68                    self.stop()?;
69                    while self.is_running() {
70                        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
71                    }
72                    Ok(())
73                }
74
75                #(pub fn get_events(&self) -> tokio::sync::broadcast::Receiver<#events_enum> {
76                    self.events.subscribe()
77                })*
78
79                //---- message sender methods ----
80                #(#methods)*
81
82            }
83        };
84
85        quote::quote! {
86            #struct_def
87            #struct_impl
88        }
89    }
90
91    fn generate_method(&self, handler: &MessageHandlerMethod) -> TokenStream {
92        let fn_name = handler.get_name_snake_case();
93        let msg_name = handler.get_name_camel_case();
94        let message_enum_name = &self.message_enum;
95        let parameters = handler
96            .parameters
97            .iter()
98            .map(|(name, ty)| {
99                quote::quote! {
100                    #name: #ty
101                }
102            })
103            .collect::<Vec<_>>();
104
105        let parameters_names = handler
106            .parameters
107            .iter()
108            .map(|(name, _)| name)
109            .collect::<Vec<_>>();
110
111        if let ReturnType::Type(_, ref return_type) = handler.original.sig.output {
112            // if the return type is Result<U,V> and V implements From<AbcgenError> then instead of
113            // returning Result<Result<U,V>,AbcgenError> we return Result<U,V> directly
114            let mut can_be_converted = false;
115            if let Type::Path(ref path) = **return_type {
116                can_be_converted = self.check_if_can_be_converted(path);
117            }
118            if can_be_converted {
119                quote::quote! {
120                    pub async fn #fn_name(&self, #(#parameters),*) -> #return_type {
121                        let (tx, rx) = tokio::sync::oneshot::channel();
122                        let msg = #message_enum_name::#msg_name { #(#parameters_names,)* respond_to: tx };
123                        let send_res = self.message_sender.send(msg).await;
124                        match send_res {
125                            Ok(_) => rx.await.unwrap_or_else(|e| Err(AbcgenError::ActorShutDown.into())),
126                            Err(e) => Err(AbcgenError::ActorShutDown.into()),
127                        }
128                    }
129                }
130            } else {
131                quote::quote! {
132                    pub async fn #fn_name(&self, #(#parameters),*) -> Result<#return_type, AbcgenError> {
133                        let (tx, rx) = tokio::sync::oneshot::channel();
134                        let msg = #message_enum_name::#msg_name { #(#parameters_names,)* respond_to: tx };
135                        let send_res = self.message_sender.send(msg).await;
136                        match send_res {
137                            Ok(_) => rx.await.map_err(|e| AbcgenError::ActorShutDown),
138                            Err(e) => Err(AbcgenError::ActorShutDown),
139                        }
140                    }
141                }
142            }
143        } else {
144            quote::quote! {
145                pub async fn #fn_name(&self, #(#parameters),*) -> Result<(), AbcgenError> {
146                    let msg = #message_enum_name::#msg_name { #(#parameters_names),* };
147                    let send_res = self.message_sender.send(msg).await.map_err(|e| AbcgenError::ActorShutDown );
148                    send_res
149                }
150            }
151        }
152    }
153
154    fn check_if_can_be_converted(&self, path: &TypePath) -> bool {
155        // path her is something like Result<U,V>
156        let mut can_be_converted = false;
157        if let Some(segment) = path.path.segments.last() {
158            if segment.ident == "Result" {
159                match &segment.arguments {
160                    syn::PathArguments::None => {}
161                    syn::PathArguments::AngleBracketed(args) => {
162                        let err_type = if args.args.len() == 2 {
163                            type_path_from_generic_argument(&args.args[1])
164                        } else if args.args.len() == 1 {
165                            type_path_from_generic_argument(&args.args[0])
166                        } else {
167                            None
168                        };
169                        if let Some(err_type) = err_type {
170                            for ty in self.convertible_errors.iter() {
171                                if crate::utils::compare_type_path(&err_type, ty) {
172                                    can_be_converted = true;
173                                    break;
174                                }
175                            }
176                        }
177                    }
178                    syn::PathArguments::Parenthesized(_) => {}
179                }
180            }
181        }
182        can_be_converted
183    }
184}