amqp_api_server/api/input/amqp_request_dispatch.rs

1use std::sync::atomic::AtomicU16;
2use std::sync::Arc;
3use amqp_api_shared::request_result::RequestResult;
4
5use crate::api::input::amqp_request_replier;
6use crate::api::input::authorizer::Authorizer;
7use async_channel::Sender;
8use state_tracker::state::State;
9use state_tracker::state_tracker_client::StateTrackerClient;
10use futures_util::TryStreamExt;
11use lapin::message::Delivery;
12use lapin::{Channel, Consumer};
13use serde_json::{Map, Value};
14use uuid::Uuid;
15
16use crate::api::input::input_element::InputElement;
17use crate::api::input::request::Request;
18use crate::api::input::sanitizer::sanitize;
19use crate::error::{Error, ErrorKind};
20
21use super::amqp_request_replier::AmqpRequestReplier;
22
23pub struct AmqpRequestDispatch<LogicRequestType> {
24    channel: Arc<Channel>,
25    element: InputElement<LogicRequestType>,
26    authorizer: Arc<Authorizer>,
27    logic_request_sender: Sender<LogicRequestType>,
28    current_concurrent_requests: Arc<AtomicU16>,
29    state_tracker_client: StateTrackerClient,
30}
31
32impl<LogicRequestType: Send + 'static> AmqpRequestDispatch<LogicRequestType> {
33    pub fn new(
34        channel: Arc<Channel>,
35        element: InputElement<LogicRequestType>,
36        authorizer: Arc<Authorizer>,
37        logic_request_sender: Sender<LogicRequestType>,
38        mut state_tracker_client: StateTrackerClient,
39    ) -> AmqpRequestDispatch<LogicRequestType> {
40        state_tracker_client.set_id(element.name().to_string());
41
42        AmqpRequestDispatch {
43            channel,
44            element,
45            authorizer,
46            logic_request_sender,
47            current_concurrent_requests: Arc::new(AtomicU16::new(0)),
48            state_tracker_client
49        }
50    }
51
52    /// Blocks thread as long as the program is running.
53    /// Deliveries are received, sanitized and authorized before being moved into a
54    /// new task where the request will be handled.
55    pub async fn run(self) -> Result<(), Error> {
56        let queue = match self
57            .channel
58            .queue_declare(
59                self.element.name(),
60                *self.element.config().queue_consumer().queue().declare().options(),
61                self.element
62                    .config()
63                    .queue_consumer()
64                    .queue()
65                    .declare()
66                    .arguments()
67                    .clone(),
68            )
69            .await
70        {
71            Ok(queue) => queue,
72            Err(error) => {
73                return Err(Error::new(
74                    ErrorKind::AmqpFailure,
75                    format!("failed to declare queue: {}", error),
76                ));
77            }
78        };
79
80        match self
81            .channel
82            .basic_qos(
83                self.element.config().queue_consumer().qos().prefetch_count(),
84                *self.element.config().queue_consumer().qos().options(),
85            )
86            .await
87        {
88            Ok(()) => (),
89            Err(error) => {
90                return Err(Error::new(
91                    ErrorKind::AmqpFailure,
92                    format!("failure basic qos: {}", error),
93                ));
94            }
95        }
96
97        let mut consumer = self.try_get_consumer(queue.name().as_str()).await?;
98
99        let reject_options = *self.element.config().queue_consumer().reject();
100        let acknowledge_options = *self.element.config().queue_consumer().acknowledge();
101        let max_concurrent_requests = self.element.config().max_concurrent_requests();
102
103        loop {
104            if self
105                .current_concurrent_requests
106                .load(std::sync::atomic::Ordering::Relaxed)
107                >= max_concurrent_requests
108            {
109                continue;
110            }
111
112            let state_tracker_client = self.state_tracker_client.clone();
113
114            let delivery = match consumer.try_next().await {
115                Ok(optional_delivery) => match optional_delivery {
116                    Some(delivery) => delivery,
117                    None => {
118                        log::info!("consumer got an empty delivery");
119                        continue;
120                    }
121                },
122                Err(error) => {
123                    let error_message = format!("consumer got an error: {}", error);
124
125                    match state_tracker_client.send_state(State::Error(error_message.clone())).await {
126                        Ok(_) => (),
127                        Err(error) => log::error!("failed to send error state: {}", error)
128                    }
129
130                    log::warn!("{}", error_message);
131                    continue;
132                }
133            };
134
135            let channel = self.channel.clone();
136
137            let request_replier: Option<AmqpRequestReplier> =
138                amqp_request_replier::try_generate_replier(&channel, &delivery);
139
140            let request = match self.prepare_request(&delivery).await {
141                Ok(request) => request,
142                Err(error) => {
143                    if let Some(request_replier) = request_replier {
144                        match request_replier
145                            .reply(RequestResult::Err(error.clone().into()))
146                            .await
147                        {
148                            Ok(_) => (),
149                            Err(error) => {
150                                log::warn!("failed to reply: {}", error);
151                            }
152                        }
153                    }
154
155                    log::info!("failed to prepare request: {}", error);
156                    continue;
157                }
158            };
159
160            let request_handler = self.element.request_handler();
161
162            let logic_request_sender = self.logic_request_sender.clone();
163
164            self.current_concurrent_requests
165                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
166
167            let current_concurrent_requests = self.current_concurrent_requests.clone();
168
169            tokio::spawn(async move {
170                let result = request_handler(request, logic_request_sender).await;
171                let mut state = State::Valid;
172
173                match &result {
174                    RequestResult::Ok(_) => {
175                        if let Err(error) = delivery.ack(acknowledge_options).await {
176                            let error_message = format!("failed to acknowledge delivery: {}", error);
177                            log::error!("{}", error_message);
178
179                            state = State::Error(error_message)
180                        }
181                    }
182                    RequestResult::Err(error) => {
183                        log::info!("failed to handle request: {}", error);
184
185                        match delivery.reject(reject_options).await {
186                            Ok(_) => (),
187                            Err(error) => {
188                                let error_message = format!("failed to reject delivery: {}", error);
189                                log::error!("{}", error_message);
190
191                                state = State::Error(error_message)
192                            }
193                        }
194                    }
195                }
196
197                match state_tracker_client.send_state(state).await {
198                    Ok(_) => (),
199                    Err(error) => log::warn!("failed to send state: {}", error)
200                }
201
202                if let Some(amqp_request_replier) =
203                    amqp_request_replier::try_generate_replier(&channel, &delivery)
204                {
205                    match amqp_request_replier.reply(result).await {
206                        Ok(_) => (),
207                        Err(error) => {
208                            log::info!("failed to reply: {}", error);
209                        }
210                    }
211                }
212
213                current_concurrent_requests.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
214            });
215        }
216    }
217
218    async fn try_get_consumer(&self, queue_name: &str) -> Result<Consumer, Error> {
219        let consumer_tag = format!("{}#{}", queue_name, Uuid::new_v4());
220        let consumer = match self
221            .channel
222            .basic_consume(
223                queue_name,
224                consumer_tag.as_str(),
225                *self.element.config().queue_consumer().consume().options(),
226                self.element
227                    .config()
228                    .queue_consumer()
229                    .consume()
230                    .arguments()
231                    .clone(),
232            )
233            .await
234        {
235            Ok(consumer) => consumer,
236            Err(error) => {
237                return Err(Error::new(
238                    ErrorKind::AmqpFailure,
239                    format!("failure basic consume: {}", error),
240                ));
241            }
242        };
243
244        Ok(consumer)
245    }
246
247    async fn prepare_request(&self, delivery: &Delivery) -> Result<Request, Error> {
248        let reject_options = *self.element.config().queue_consumer().reject();
249
250        let request_data = match std::str::from_utf8(delivery.data.as_slice()) {
251            Ok(request_data) => request_data,
252            Err(error) => {
253                return match delivery.reject(reject_options).await {
254                    Ok(_) => Err(Error::new(
255                        ErrorKind::MalformedRequest,
256                        format!("delivery is not an utf8 string: {}", error),
257                    )),
258                    Err(error) => Err(Error::new(
259                        ErrorKind::AmqpFailure,
260                        format!("failed to reject delivery: {}", error),
261                    )),
262                };
263            }
264        };
265
266        let raw_request = match serde_json::from_str::<Map<String, Value>>(request_data) {
267            Ok(raw_request) => raw_request,
268            Err(error) => {
269                return match delivery.reject(reject_options).await {
270                    Ok(()) => Err(Error::new(
271                        ErrorKind::MalformedRequest,
272                        format!("delivery is not a json object: {}", error),
273                    )),
274                    Err(error) => Err(Error::new(
275                        ErrorKind::AmqpFailure,
276                        format!("failed to reject delivery: {}", error),
277                    )),
278                };
279            }
280        };
281
282        let mut request = match sanitize(raw_request, self.element.actions()) {
283            Ok(request) => request,
284            Err(error) => {
285                return match delivery.reject(reject_options).await {
286                    Ok(()) => Err(Error::new(
287                        ErrorKind::MalformedRequest,
288                        format!("request sanitization failure: {}", error),
289                    )),
290                    Err(error) => Err(Error::new(
291                        ErrorKind::AmqpFailure,
292                        format!("failed to reject delivery: {}", error),
293                    )),
294                };
295            }
296        };
297
298        request = match self.authorizer.authorize(request) {
299            Ok(request) => request,
300            Err(error) => {
301                return match delivery.reject(reject_options).await {
302                    Ok(()) => Err(Error::new(
303                        ErrorKind::MalformedRequest,
304                        format!("request sanitization failure: {}", error),
305                    )),
306                    Err(error) => Err(Error::new(
307                        ErrorKind::AmqpFailure,
308                        format!("failed to reject delivery: {}", error),
309                    )),
310                };
311            }
312        };
313
314        Ok(request)
315    }
316}