cp_microservice/api/server/dispatch.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::api::server::input::action::Action;
6use crate::api::server::input::executor::Executor;
7use crate::api::server::input::input::Input;
8use crate::api::server::input::input_data::InputData;
9use crate::api::server::input::input_plugin::InputPlugin;
10use crate::api::server::input::replier::Replier;
11use crate::api::shared::request::Request;
12use crate::api::shared::request_header::RequestHeader;
13use async_channel::Sender;
14use async_trait::async_trait;
15use log::{info, warn};
16use serde_json::{json, Value};
17use tokio::sync::RwLock;
18use tokio::task::JoinHandle;
19use tokio::time::{sleep, timeout};
20use tokio_util::sync::CancellationToken;
21
22use crate::core::error::Error;
23
24pub struct Dispatch<InputImpl: 'static + Input + Send, LogicRequestType: 'static + Send> {
25    inputs: Vec<InputImpl>,
26    actions: Arc<HashMap<String, Action<LogicRequestType>>>,
27    sender: Sender<LogicRequestType>,
28    plugins: Arc<Vec<Arc<dyn InputPlugin + Send + Sync>>>,
29}
30
31impl<InputImpl: 'static + Input + Send, LogicRequestType: 'static + Send>
32    Dispatch<InputImpl, LogicRequestType>
33{
34    pub fn new(
35        inputs: Vec<InputImpl>,
36        actions: HashMap<String, Action<LogicRequestType>>,
37        sender: Sender<LogicRequestType>,
38        plugins: Vec<Arc<dyn InputPlugin + Send + Sync>>,
39    ) -> Dispatch<InputImpl, LogicRequestType> {
40        Dispatch {
41            inputs,
42            actions: Arc::new(actions),
43            sender,
44            plugins: Arc::new(plugins),
45        }
46    }
47
48    pub async fn run(self, cancellation_token: CancellationToken) -> Vec<JoinHandle<()>> {
49        let mut api_handles = Vec::new();
50
51        for input in self.inputs {
52            let actions_pointer: Arc<HashMap<String, Action<LogicRequestType>>> =
53                self.actions.clone();
54            let logic_request_sender = self.sender.clone();
55            let plugins_pointer = self.plugins.clone();
56
57            api_handles.push(tokio::spawn(run_dispatch_input(
58                input,
59                actions_pointer,
60                logic_request_sender,
61                plugins_pointer,
62                cancellation_token.clone(),
63            )));
64        }
65
66        api_handles
67    }
68}
69
70fn get_filtered_out_plugins_for_action<LogicRequestType>(
71    action: &str,
72    actions: &Arc<HashMap<String, Action<LogicRequestType>>>,
73) -> Vec<String> {
74    match actions.get(action) {
75        Some(action) => action.filter_out_plugins(),
76        None => Vec::new(),
77    }
78}
79
80async fn handle_input_data<LogicRequestType: 'static + Send>(
81    input_data: InputData,
82    actions: &Arc<HashMap<String, Action<LogicRequestType>>>,
83    sender: Sender<LogicRequestType>,
84) {
85    let action = input_data.request.header().action();
86
87    match actions.get(action) {
88        Some(action) => {
89            let executor = action.executor();
90            let action_result = executor(input_data.request, sender).await;
91
92            let replier: Replier = input_data.replier;
93            if let Err(error) = replier(json!(action_result)).await {
94                warn!("failed to reply with action_result: {}", error);
95            }
96        }
97        None => {
98            info!("unknown action received: {}", action);
99        }
100    }
101}
102
103async fn run_dispatch_input<InputImpl: 'static + Input + Send, LogicRequestType: 'static + Send>(
104    mut input: InputImpl,
105    actions_pointer: Arc<HashMap<String, Action<LogicRequestType>>>,
106    logic_request_sender: Sender<LogicRequestType>,
107    plugins_pointer: Arc<Vec<Arc<dyn InputPlugin + Send + Sync>>>,
108    cancellation_token: CancellationToken,
109) {
110    loop {
111        if cancellation_token.is_cancelled() {
112            info!("cancellation token is cancelled, api dispatch is stopping");
113
114            break;
115        }
116
117        let result = input.receive().await;
118
119        match result {
120            Ok(mut input_data) => {
121                if plugins_pointer.len() == 0 {
122                    handle_input_data::<LogicRequestType>(
123                        input_data,
124                        &actions_pointer,
125                        logic_request_sender.clone(),
126                    )
127                    .await;
128                } else {
129                    let filtered_out_plugins =
130                        get_filtered_out_plugins_for_action::<LogicRequestType>(
131                            input_data.request.header().action(),
132                            &actions_pointer,
133                        );
134
135                    for (index, plugin) in plugins_pointer.as_slice().iter().enumerate() {
136                        if !filtered_out_plugins.contains(&plugin.id().to_string()) {
137                            input_data = match plugin.handle_input_data(input_data).await {
138                                Ok(input_data) => input_data,
139                                Err((input_data, error)) => {
140                                    let replier = input_data.replier;
141
142                                    let error_value = match serde_json::to_value(error.clone()) {
143                                        Ok(error_value) => error_value,
144                                        Err(error) => {
145                                            json!(format!("failed to process request: {}", error))
146                                        }
147                                    };
148
149                                    match replier(error_value).await {
150                                        Ok(_) => (),
151                                        Err(error) => {
152                                            warn!("failed to reply when plugin failed: {}", error)
153                                        }
154                                    }
155
156                                    warn!("plugin failed to handle input data: {}", error);
157                                    break;
158                                }
159                            };
160                        }
161
162                        if index == plugins_pointer.len() - 1 {
163                            handle_input_data::<LogicRequestType>(
164                                input_data,
165                                &actions_pointer,
166                                logic_request_sender.clone(),
167                            )
168                            .await;
169
170                            break;
171                        }
172                    }
173                }
174            }
175            Err(error) => {
176                warn!("failed to receive input: {}", error);
177            }
178        }
179    }
180}
181
// Placeholder logic-request type used as `LogicRequestType` by the tests in this file.
// NOTE(review): `#[cfg(test)]` applies only to this one item; the test helpers that
// follow are not gated and are therefore compiled into non-test builds too.
#[cfg(test)]
pub struct LogicRequest {}
184
185pub struct InputTimedImpl {
186    sleep_duration: Duration,
187    sender: tokio::sync::mpsc::Sender<()>,
188}
189
190impl InputTimedImpl {
191    pub fn new(sleep_duration: Duration, sender: tokio::sync::mpsc::Sender<()>) -> InputTimedImpl {
192        InputTimedImpl {
193            sleep_duration,
194            sender,
195        }
196    }
197}
198
199#[async_trait]
200impl Input for InputTimedImpl {
201    async fn receive(&mut self) -> Result<InputData, Error> {
202        sleep(self.sleep_duration).await;
203        self.sender
204            .send(())
205            .await
206            .expect("failed to send empty message");
207
208        Ok(InputData {
209            request: Request::new(
210                RequestHeader::new("".to_string(), "".to_string()),
211                Value::Null,
212            ),
213            replier: Arc::new(move |value: Value| Box::pin(async { Ok(()) })),
214        })
215    }
216}
217
218#[tokio::test]
219pub async fn handle_multiple_inputs_concurrently() {
220    let sleep_duration: Duration = Duration::from_millis(500u64);
221    let max_execution_duration: Duration = Duration::from_millis(900u64);
222    let expected_inputs: u8 = 2;
223
224    let (sender, mut receiver) = tokio::sync::mpsc::channel::<()>(1024usize);
225    let (logic_request_sender, _) = async_channel::unbounded::<LogicRequest>();
226    let inputs: Vec<InputTimedImpl> = vec![
227        InputTimedImpl::new(sleep_duration, sender.clone()),
228        InputTimedImpl::new(sleep_duration, sender.clone()),
229    ];
230    let dispatch: Dispatch<InputTimedImpl, LogicRequest> =
231        Dispatch::new(inputs, HashMap::new(), logic_request_sender, vec![]);
232
233    tokio::spawn(dispatch.run(CancellationToken::new()));
234
235    timeout(max_execution_duration, async move {
236        let mut count: u8 = 0;
237
238        for _ in 0..expected_inputs {
239            if (receiver.recv().await).is_some() {
240                count += 1;
241            }
242        }
243
244        assert_eq!(expected_inputs, count);
245    })
246    .await
247    .expect("inputs are not being received concurrently");
248}
249
250pub struct InputDummyImpl {
251    has_message_been_sent: RwLock<bool>,
252}
253
254impl Default for InputDummyImpl {
255    fn default() -> Self {
256        InputDummyImpl {
257            has_message_been_sent: RwLock::new(false),
258        }
259    }
260}
261
262#[async_trait]
263impl Input for InputDummyImpl {
264    async fn receive(&mut self) -> Result<InputData, Error> {
265        if *self.has_message_been_sent.try_read().unwrap() {
266            loop {
267                sleep(Duration::MAX).await;
268            }
269        }
270
271        let request = Request::new(
272            RequestHeader::new("".to_string(), "".to_string()),
273            Value::Null,
274        );
275        let replier: Replier = Arc::new(move |value| Box::pin(async { Ok(()) }));
276
277        if !(*self.has_message_been_sent.try_read().unwrap()) {
278            *self.has_message_been_sent.try_write().unwrap() = true;
279        }
280
281        Ok(InputData { request, replier })
282    }
283}
284
285pub struct DummyPlugin {
286    send_value: u8,
287    sender: tokio::sync::mpsc::Sender<u8>,
288}
289
290impl DummyPlugin {
291    pub fn new(send_value: u8, sender: tokio::sync::mpsc::Sender<u8>) -> DummyPlugin {
292        DummyPlugin { send_value, sender }
293    }
294}
295
296#[async_trait]
297impl InputPlugin for DummyPlugin {
298    fn id(&self) -> &str {
299        "dummy"
300    }
301
302    async fn handle_input_data(
303        &self,
304        input_data: InputData,
305    ) -> Result<InputData, (InputData, Error)> {
306        self.sender.send(self.send_value).await.unwrap();
307
308        Ok(input_data)
309    }
310}
311
312#[tokio::test]
313pub async fn execute_specified_plugins_for_each_input() {
314    const EXPECTED_SUM: u8 = 24u8;
315
316    let inputs: Vec<InputDummyImpl> = vec![InputDummyImpl::default(), InputDummyImpl::default()];
317
318    let (sender, _) = async_channel::unbounded();
319
320    let (plugin_sender, mut plugin_receiver) = tokio::sync::mpsc::channel::<u8>(1024usize);
321
322    let plugins: Vec<Arc<dyn InputPlugin + Send + Sync>> = vec![
323        Arc::new(DummyPlugin::new(13u8, plugin_sender.clone())),
324        Arc::new(DummyPlugin::new(11u8, plugin_sender)),
325    ];
326
327    let dispatch: Dispatch<InputDummyImpl, LogicRequest> =
328        Dispatch::new(inputs, HashMap::new(), sender, plugins);
329
330    tokio::spawn(dispatch.run(CancellationToken::new()));
331
332    let mut sum: u8 = 0;
333
334    for _ in 0..2 {
335        sum += timeout(Duration::from_millis(200u64), plugin_receiver.recv())
336            .await
337            .expect("timed out waiting for plugin to send byte")
338            .expect("failed to receive byte from plugin");
339    }
340
341    assert_eq!(EXPECTED_SUM, sum);
342}