amqp_api_server/api/input/
amqp_request_dispatch.rs1use std::sync::atomic::AtomicU16;
2use std::sync::Arc;
3use amqp_api_shared::request_result::RequestResult;
4
5use crate::api::input::amqp_request_replier;
6use crate::api::input::authorizer::Authorizer;
7use async_channel::Sender;
8use state_tracker::state::State;
9use state_tracker::state_tracker_client::StateTrackerClient;
10use futures_util::TryStreamExt;
11use lapin::message::Delivery;
12use lapin::{Channel, Consumer};
13use serde_json::{Map, Value};
14use uuid::Uuid;
15
16use crate::api::input::input_element::InputElement;
17use crate::api::input::request::Request;
18use crate::api::input::sanitizer::sanitize;
19use crate::error::{Error, ErrorKind};
20
21use super::amqp_request_replier::AmqpRequestReplier;
22
23pub struct AmqpRequestDispatch<LogicRequestType> {
24 channel: Arc<Channel>,
25 element: InputElement<LogicRequestType>,
26 authorizer: Arc<Authorizer>,
27 logic_request_sender: Sender<LogicRequestType>,
28 current_concurrent_requests: Arc<AtomicU16>,
29 state_tracker_client: StateTrackerClient,
30}
31
32impl<LogicRequestType: Send + 'static> AmqpRequestDispatch<LogicRequestType> {
33 pub fn new(
34 channel: Arc<Channel>,
35 element: InputElement<LogicRequestType>,
36 authorizer: Arc<Authorizer>,
37 logic_request_sender: Sender<LogicRequestType>,
38 mut state_tracker_client: StateTrackerClient,
39 ) -> AmqpRequestDispatch<LogicRequestType> {
40 state_tracker_client.set_id(element.name().to_string());
41
42 AmqpRequestDispatch {
43 channel,
44 element,
45 authorizer,
46 logic_request_sender,
47 current_concurrent_requests: Arc::new(AtomicU16::new(0)),
48 state_tracker_client
49 }
50 }
51
52 pub async fn run(self) -> Result<(), Error> {
56 let queue = match self
57 .channel
58 .queue_declare(
59 self.element.name(),
60 *self.element.config().queue_consumer().queue().declare().options(),
61 self.element
62 .config()
63 .queue_consumer()
64 .queue()
65 .declare()
66 .arguments()
67 .clone(),
68 )
69 .await
70 {
71 Ok(queue) => queue,
72 Err(error) => {
73 return Err(Error::new(
74 ErrorKind::AmqpFailure,
75 format!("failed to declare queue: {}", error),
76 ));
77 }
78 };
79
80 match self
81 .channel
82 .basic_qos(
83 self.element.config().queue_consumer().qos().prefetch_count(),
84 *self.element.config().queue_consumer().qos().options(),
85 )
86 .await
87 {
88 Ok(()) => (),
89 Err(error) => {
90 return Err(Error::new(
91 ErrorKind::AmqpFailure,
92 format!("failure basic qos: {}", error),
93 ));
94 }
95 }
96
97 let mut consumer = self.try_get_consumer(queue.name().as_str()).await?;
98
99 let reject_options = *self.element.config().queue_consumer().reject();
100 let acknowledge_options = *self.element.config().queue_consumer().acknowledge();
101 let max_concurrent_requests = self.element.config().max_concurrent_requests();
102
103 loop {
104 if self
105 .current_concurrent_requests
106 .load(std::sync::atomic::Ordering::Relaxed)
107 >= max_concurrent_requests
108 {
109 continue;
110 }
111
112 let state_tracker_client = self.state_tracker_client.clone();
113
114 let delivery = match consumer.try_next().await {
115 Ok(optional_delivery) => match optional_delivery {
116 Some(delivery) => delivery,
117 None => {
118 log::info!("consumer got an empty delivery");
119 continue;
120 }
121 },
122 Err(error) => {
123 let error_message = format!("consumer got an error: {}", error);
124
125 match state_tracker_client.send_state(State::Error(error_message.clone())).await {
126 Ok(_) => (),
127 Err(error) => log::error!("failed to send error state: {}", error)
128 }
129
130 log::warn!("{}", error_message);
131 continue;
132 }
133 };
134
135 let channel = self.channel.clone();
136
137 let request_replier: Option<AmqpRequestReplier> =
138 amqp_request_replier::try_generate_replier(&channel, &delivery);
139
140 let request = match self.prepare_request(&delivery).await {
141 Ok(request) => request,
142 Err(error) => {
143 if let Some(request_replier) = request_replier {
144 match request_replier
145 .reply(RequestResult::Err(error.clone().into()))
146 .await
147 {
148 Ok(_) => (),
149 Err(error) => {
150 log::warn!("failed to reply: {}", error);
151 }
152 }
153 }
154
155 log::info!("failed to prepare request: {}", error);
156 continue;
157 }
158 };
159
160 let request_handler = self.element.request_handler();
161
162 let logic_request_sender = self.logic_request_sender.clone();
163
164 self.current_concurrent_requests
165 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
166
167 let current_concurrent_requests = self.current_concurrent_requests.clone();
168
169 tokio::spawn(async move {
170 let result = request_handler(request, logic_request_sender).await;
171 let mut state = State::Valid;
172
173 match &result {
174 RequestResult::Ok(_) => {
175 if let Err(error) = delivery.ack(acknowledge_options).await {
176 let error_message = format!("failed to acknowledge delivery: {}", error);
177 log::error!("{}", error_message);
178
179 state = State::Error(error_message)
180 }
181 }
182 RequestResult::Err(error) => {
183 log::info!("failed to handle request: {}", error);
184
185 match delivery.reject(reject_options).await {
186 Ok(_) => (),
187 Err(error) => {
188 let error_message = format!("failed to reject delivery: {}", error);
189 log::error!("{}", error_message);
190
191 state = State::Error(error_message)
192 }
193 }
194 }
195 }
196
197 match state_tracker_client.send_state(state).await {
198 Ok(_) => (),
199 Err(error) => log::warn!("failed to send state: {}", error)
200 }
201
202 if let Some(amqp_request_replier) =
203 amqp_request_replier::try_generate_replier(&channel, &delivery)
204 {
205 match amqp_request_replier.reply(result).await {
206 Ok(_) => (),
207 Err(error) => {
208 log::info!("failed to reply: {}", error);
209 }
210 }
211 }
212
213 current_concurrent_requests.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
214 });
215 }
216 }
217
218 async fn try_get_consumer(&self, queue_name: &str) -> Result<Consumer, Error> {
219 let consumer_tag = format!("{}#{}", queue_name, Uuid::new_v4());
220 let consumer = match self
221 .channel
222 .basic_consume(
223 queue_name,
224 consumer_tag.as_str(),
225 *self.element.config().queue_consumer().consume().options(),
226 self.element
227 .config()
228 .queue_consumer()
229 .consume()
230 .arguments()
231 .clone(),
232 )
233 .await
234 {
235 Ok(consumer) => consumer,
236 Err(error) => {
237 return Err(Error::new(
238 ErrorKind::AmqpFailure,
239 format!("failure basic consume: {}", error),
240 ));
241 }
242 };
243
244 Ok(consumer)
245 }
246
247 async fn prepare_request(&self, delivery: &Delivery) -> Result<Request, Error> {
248 let reject_options = *self.element.config().queue_consumer().reject();
249
250 let request_data = match std::str::from_utf8(delivery.data.as_slice()) {
251 Ok(request_data) => request_data,
252 Err(error) => {
253 return match delivery.reject(reject_options).await {
254 Ok(_) => Err(Error::new(
255 ErrorKind::MalformedRequest,
256 format!("delivery is not an utf8 string: {}", error),
257 )),
258 Err(error) => Err(Error::new(
259 ErrorKind::AmqpFailure,
260 format!("failed to reject delivery: {}", error),
261 )),
262 };
263 }
264 };
265
266 let raw_request = match serde_json::from_str::<Map<String, Value>>(request_data) {
267 Ok(raw_request) => raw_request,
268 Err(error) => {
269 return match delivery.reject(reject_options).await {
270 Ok(()) => Err(Error::new(
271 ErrorKind::MalformedRequest,
272 format!("delivery is not a json object: {}", error),
273 )),
274 Err(error) => Err(Error::new(
275 ErrorKind::AmqpFailure,
276 format!("failed to reject delivery: {}", error),
277 )),
278 };
279 }
280 };
281
282 let mut request = match sanitize(raw_request, self.element.actions()) {
283 Ok(request) => request,
284 Err(error) => {
285 return match delivery.reject(reject_options).await {
286 Ok(()) => Err(Error::new(
287 ErrorKind::MalformedRequest,
288 format!("request sanitization failure: {}", error),
289 )),
290 Err(error) => Err(Error::new(
291 ErrorKind::AmqpFailure,
292 format!("failed to reject delivery: {}", error),
293 )),
294 };
295 }
296 };
297
298 request = match self.authorizer.authorize(request) {
299 Ok(request) => request,
300 Err(error) => {
301 return match delivery.reject(reject_options).await {
302 Ok(()) => Err(Error::new(
303 ErrorKind::MalformedRequest,
304 format!("request sanitization failure: {}", error),
305 )),
306 Err(error) => Err(Error::new(
307 ErrorKind::AmqpFailure,
308 format!("failed to reject delivery: {}", error),
309 )),
310 };
311 }
312 };
313
314 Ok(request)
315 }
316}