cp_microservice/api/server/
dispatch.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use crate::api::server::input::action::Action;
6use crate::api::server::input::executor::Executor;
7use crate::api::server::input::input::Input;
8use crate::api::server::input::input_data::InputData;
9use crate::api::server::input::input_plugin::InputPlugin;
10use crate::api::server::input::replier::Replier;
11use crate::api::shared::request::Request;
12use crate::api::shared::request_header::RequestHeader;
13use async_channel::Sender;
14use async_trait::async_trait;
15use log::{info, warn};
16use serde_json::{json, Value};
17use tokio::sync::RwLock;
18use tokio::task::JoinHandle;
19use tokio::time::{sleep, timeout};
20use tokio_util::sync::CancellationToken;
21
22use crate::core::error::Error;
23
24pub struct Dispatch<InputImpl: 'static + Input + Send, LogicRequestType: 'static + Send> {
25 inputs: Vec<InputImpl>,
26 actions: Arc<HashMap<String, Action<LogicRequestType>>>,
27 sender: Sender<LogicRequestType>,
28 plugins: Arc<Vec<Arc<dyn InputPlugin + Send + Sync>>>,
29}
30
31impl<InputImpl: 'static + Input + Send, LogicRequestType: 'static + Send>
32 Dispatch<InputImpl, LogicRequestType>
33{
34 pub fn new(
35 inputs: Vec<InputImpl>,
36 actions: HashMap<String, Action<LogicRequestType>>,
37 sender: Sender<LogicRequestType>,
38 plugins: Vec<Arc<dyn InputPlugin + Send + Sync>>,
39 ) -> Dispatch<InputImpl, LogicRequestType> {
40 Dispatch {
41 inputs,
42 actions: Arc::new(actions),
43 sender,
44 plugins: Arc::new(plugins),
45 }
46 }
47
48 pub async fn run(self, cancellation_token: CancellationToken) -> Vec<JoinHandle<()>> {
49 let mut api_handles = Vec::new();
50
51 for input in self.inputs {
52 let actions_pointer: Arc<HashMap<String, Action<LogicRequestType>>> =
53 self.actions.clone();
54 let logic_request_sender = self.sender.clone();
55 let plugins_pointer = self.plugins.clone();
56
57 api_handles.push(tokio::spawn(run_dispatch_input(
58 input,
59 actions_pointer,
60 logic_request_sender,
61 plugins_pointer,
62 cancellation_token.clone(),
63 )));
64 }
65
66 api_handles
67 }
68}
69
70fn get_filtered_out_plugins_for_action<LogicRequestType>(
71 action: &str,
72 actions: &Arc<HashMap<String, Action<LogicRequestType>>>,
73) -> Vec<String> {
74 match actions.get(action) {
75 Some(action) => action.filter_out_plugins(),
76 None => Vec::new(),
77 }
78}
79
80async fn handle_input_data<LogicRequestType: 'static + Send>(
81 input_data: InputData,
82 actions: &Arc<HashMap<String, Action<LogicRequestType>>>,
83 sender: Sender<LogicRequestType>,
84) {
85 let action = input_data.request.header().action();
86
87 match actions.get(action) {
88 Some(action) => {
89 let executor = action.executor();
90 let action_result = executor(input_data.request, sender).await;
91
92 let replier: Replier = input_data.replier;
93 if let Err(error) = replier(json!(action_result)).await {
94 warn!("failed to reply with action_result: {}", error);
95 }
96 }
97 None => {
98 info!("unknown action received: {}", action);
99 }
100 }
101}
102
103async fn run_dispatch_input<InputImpl: 'static + Input + Send, LogicRequestType: 'static + Send>(
104 mut input: InputImpl,
105 actions_pointer: Arc<HashMap<String, Action<LogicRequestType>>>,
106 logic_request_sender: Sender<LogicRequestType>,
107 plugins_pointer: Arc<Vec<Arc<dyn InputPlugin + Send + Sync>>>,
108 cancellation_token: CancellationToken,
109) {
110 loop {
111 if cancellation_token.is_cancelled() {
112 info!("cancellation token is cancelled, api dispatch is stopping");
113
114 break;
115 }
116
117 let result = input.receive().await;
118
119 match result {
120 Ok(mut input_data) => {
121 if plugins_pointer.len() == 0 {
122 handle_input_data::<LogicRequestType>(
123 input_data,
124 &actions_pointer,
125 logic_request_sender.clone(),
126 )
127 .await;
128 } else {
129 let filtered_out_plugins =
130 get_filtered_out_plugins_for_action::<LogicRequestType>(
131 input_data.request.header().action(),
132 &actions_pointer,
133 );
134
135 for (index, plugin) in plugins_pointer.as_slice().iter().enumerate() {
136 if !filtered_out_plugins.contains(&plugin.id().to_string()) {
137 input_data = match plugin.handle_input_data(input_data).await {
138 Ok(input_data) => input_data,
139 Err((input_data, error)) => {
140 let replier = input_data.replier;
141
142 let error_value = match serde_json::to_value(error.clone()) {
143 Ok(error_value) => error_value,
144 Err(error) => {
145 json!(format!("failed to process request: {}", error))
146 }
147 };
148
149 match replier(error_value).await {
150 Ok(_) => (),
151 Err(error) => {
152 warn!("failed to reply when plugin failed: {}", error)
153 }
154 }
155
156 warn!("plugin failed to handle input data: {}", error);
157 break;
158 }
159 };
160 }
161
162 if index == plugins_pointer.len() - 1 {
163 handle_input_data::<LogicRequestType>(
164 input_data,
165 &actions_pointer,
166 logic_request_sender.clone(),
167 )
168 .await;
169
170 break;
171 }
172 }
173 }
174 }
175 Err(error) => {
176 warn!("failed to receive input: {}", error);
177 }
178 }
179 }
180}
181
// NOTE(review): this `#[cfg(test)]` applies only to `LogicRequest`. The helper types
// and impls that follow (`InputTimedImpl`, `InputDummyImpl`, `DummyPlugin`) are still
// compiled into non-test builds. Consider moving them into a `#[cfg(test)] mod tests`.

/// Empty logic-request type used to instantiate `Dispatch` in this file's tests.
#[cfg(test)]
pub struct LogicRequest {
}
184
185pub struct InputTimedImpl {
186 sleep_duration: Duration,
187 sender: tokio::sync::mpsc::Sender<()>,
188}
189
190impl InputTimedImpl {
191 pub fn new(sleep_duration: Duration, sender: tokio::sync::mpsc::Sender<()>) -> InputTimedImpl {
192 InputTimedImpl {
193 sleep_duration,
194 sender,
195 }
196 }
197}
198
199#[async_trait]
200impl Input for InputTimedImpl {
201 async fn receive(&mut self) -> Result<InputData, Error> {
202 sleep(self.sleep_duration).await;
203 self.sender
204 .send(())
205 .await
206 .expect("failed to send empty message");
207
208 Ok(InputData {
209 request: Request::new(
210 RequestHeader::new("".to_string(), "".to_string()),
211 Value::Null,
212 ),
213 replier: Arc::new(move |value: Value| Box::pin(async { Ok(()) })),
214 })
215 }
216}
217
218#[tokio::test]
219pub async fn handle_multiple_inputs_concurrently() {
220 let sleep_duration: Duration = Duration::from_millis(500u64);
221 let max_execution_duration: Duration = Duration::from_millis(900u64);
222 let expected_inputs: u8 = 2;
223
224 let (sender, mut receiver) = tokio::sync::mpsc::channel::<()>(1024usize);
225 let (logic_request_sender, _) = async_channel::unbounded::<LogicRequest>();
226 let inputs: Vec<InputTimedImpl> = vec![
227 InputTimedImpl::new(sleep_duration, sender.clone()),
228 InputTimedImpl::new(sleep_duration, sender.clone()),
229 ];
230 let dispatch: Dispatch<InputTimedImpl, LogicRequest> =
231 Dispatch::new(inputs, HashMap::new(), logic_request_sender, vec![]);
232
233 tokio::spawn(dispatch.run(CancellationToken::new()));
234
235 timeout(max_execution_duration, async move {
236 let mut count: u8 = 0;
237
238 for _ in 0..expected_inputs {
239 if (receiver.recv().await).is_some() {
240 count += 1;
241 }
242 }
243
244 assert_eq!(expected_inputs, count);
245 })
246 .await
247 .expect("inputs are not being received concurrently");
248}
249
250pub struct InputDummyImpl {
251 has_message_been_sent: RwLock<bool>,
252}
253
254impl Default for InputDummyImpl {
255 fn default() -> Self {
256 InputDummyImpl {
257 has_message_been_sent: RwLock::new(false),
258 }
259 }
260}
261
262#[async_trait]
263impl Input for InputDummyImpl {
264 async fn receive(&mut self) -> Result<InputData, Error> {
265 if *self.has_message_been_sent.try_read().unwrap() {
266 loop {
267 sleep(Duration::MAX).await;
268 }
269 }
270
271 let request = Request::new(
272 RequestHeader::new("".to_string(), "".to_string()),
273 Value::Null,
274 );
275 let replier: Replier = Arc::new(move |value| Box::pin(async { Ok(()) }));
276
277 if !(*self.has_message_been_sent.try_read().unwrap()) {
278 *self.has_message_been_sent.try_write().unwrap() = true;
279 }
280
281 Ok(InputData { request, replier })
282 }
283}
284
285pub struct DummyPlugin {
286 send_value: u8,
287 sender: tokio::sync::mpsc::Sender<u8>,
288}
289
290impl DummyPlugin {
291 pub fn new(send_value: u8, sender: tokio::sync::mpsc::Sender<u8>) -> DummyPlugin {
292 DummyPlugin { send_value, sender }
293 }
294}
295
296#[async_trait]
297impl InputPlugin for DummyPlugin {
298 fn id(&self) -> &str {
299 "dummy"
300 }
301
302 async fn handle_input_data(
303 &self,
304 input_data: InputData,
305 ) -> Result<InputData, (InputData, Error)> {
306 self.sender.send(self.send_value).await.unwrap();
307
308 Ok(input_data)
309 }
310}
311
312#[tokio::test]
313pub async fn execute_specified_plugins_for_each_input() {
314 const EXPECTED_SUM: u8 = 24u8;
315
316 let inputs: Vec<InputDummyImpl> = vec![InputDummyImpl::default(), InputDummyImpl::default()];
317
318 let (sender, _) = async_channel::unbounded();
319
320 let (plugin_sender, mut plugin_receiver) = tokio::sync::mpsc::channel::<u8>(1024usize);
321
322 let plugins: Vec<Arc<dyn InputPlugin + Send + Sync>> = vec![
323 Arc::new(DummyPlugin::new(13u8, plugin_sender.clone())),
324 Arc::new(DummyPlugin::new(11u8, plugin_sender)),
325 ];
326
327 let dispatch: Dispatch<InputDummyImpl, LogicRequest> =
328 Dispatch::new(inputs, HashMap::new(), sender, plugins);
329
330 tokio::spawn(dispatch.run(CancellationToken::new()));
331
332 let mut sum: u8 = 0;
333
334 for _ in 0..2 {
335 sum += timeout(Duration::from_millis(200u64), plugin_receiver.recv())
336 .await
337 .expect("timed out waiting for plugin to send byte")
338 .expect("failed to receive byte from plugin");
339 }
340
341 assert_eq!(EXPECTED_SUM, sum);
342}