1use amqprs::{
2 channel::{BasicAckArguments, BasicNackArguments, BasicPublishArguments, Channel},
3 consumer::AsyncConsumer,
4 BasicProperties, Deliver,
5};
6use arc_swap::ArcSwap;
7use async_trait::async_trait;
8use tracing::error;
9use std::{collections::HashMap, sync::atomic::{AtomicUsize, Ordering}};
10use std::error::Error as StdError;
11use std::future::Future;
12use std::sync::Arc;
13use tokio::{sync::{Notify, OnceCell, oneshot::Sender}, time::{Duration, timeout}};
14use dashmap::DashMap;
15
16use crate::{api::utils::{ContentEncoding, Handler, Message, RPCHandler, TopicTrie, compress, decompress}, errors::{AppError, AppErrorType}};
17
18#[derive(Clone)]
19pub struct InternalSubscribeHandler {
20 handler: Handler,
21 process_timeout: Option<Duration>,
22}
23impl InternalSubscribeHandler {
24 pub fn new<F, Fut>(handler: Arc<F>, process_timeout: Option<Duration>) -> Self
25 where
26 F: Fn(Message) -> Fut + Send + Sync + 'static + ?Sized,
27 Fut: Future<Output = Result<(), Box<dyn StdError + Send + Sync>>> + Send + 'static,
28 {
29 Self {
30 handler: Arc::new(move |body| Box::pin(handler(body))),
31 process_timeout,
32 }
33 }
34}
35
36#[derive(Clone)]
37pub struct InternalRPCHandler {
38 handler: RPCHandler,
39 process_timeout: Option<Duration>,
40}
41impl InternalRPCHandler {
42 pub fn new(handler: RPCHandler, process_timeout: Option<Duration>) -> Self
44 {
45 Self {
46 handler: Arc::new(move |body| Box::pin(handler(body))),
47 process_timeout,
48 }
49 }
50}
51
52
53
54pub struct BroadSubscribeHandler {
55 handlers: Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>,
56 auto_ack: bool,
57 in_flight: Arc<AtomicUsize>,
58 shutdown_notify: Arc<Notify>,
59 }
61
62pub struct BroadRPCHandler {
63 channel: Arc<Channel>,
64 handlers: Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>,
65 auto_ack: bool,
66 in_flight: Arc<AtomicUsize>,
67 shutdown_notify: Arc<Notify>,
68 }
70pub struct BroadRPCClientHandler {
71 handlers: Arc<DashMap<String, Sender<Vec<u8>>>>,
72 auto_ack: bool,
73 in_flight: Arc<AtomicUsize>,
74 shutdown_notify: Arc<Notify>,
75 }
77
78impl BroadSubscribeHandler {
79 pub fn new(
80 handlers: Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>,
81 auto_ack: bool,
82 in_flight: Arc<AtomicUsize>,
83 shutdown_notify: Arc<Notify>,
84 ) -> Self {
85 Self {
86 handlers,
87 auto_ack,
88 in_flight,
89 shutdown_notify,
90 }
91 }
92}
93impl BroadRPCHandler {
94 pub fn new(
95 channel: Arc<Channel>,
96 handlers: Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>,
97 auto_ack: bool,
98 in_flight: Arc<AtomicUsize>,
99 shutdown_notify: Arc<Notify>,
100 ) -> Self {
101 Self {
102 channel,
103 handlers,
104 auto_ack,
105 in_flight,
106 shutdown_notify,
107 }
108 }
109}
110
111impl BroadRPCClientHandler {
112 pub fn new(handlers: Arc<DashMap<String, Sender<Vec<u8>>>>, auto_ack: bool, in_flight: Arc<AtomicUsize>, shutdown_notify: Arc<Notify>) -> Self {
113 Self { handlers, auto_ack, in_flight, shutdown_notify }
114 }
115}
116
117#[async_trait]
118impl AsyncConsumer for BroadRPCClientHandler {
119 async fn consume(
120 &mut self,
121 channel: &Channel,
122 deliver: Deliver,
123 basic_properties: BasicProperties,
124 content: Vec<u8>,
125 ) {
126 self.in_flight.fetch_add(1, Ordering::AcqRel);
127 if let Some(correlated_id) = basic_properties.correlation_id() {
128 if let Some(sender) = self.handlers.remove(correlated_id) {
129 if let Err(err) = sender.1.send(content) {
130 error!("The receiver dropped {:?}", err);
131 }
132 }
133 if !self.auto_ack {
134 let delivery_tag = deliver.delivery_tag();
135 let args = BasicAckArguments::new(delivery_tag, false);
136 if let Err(e) = channel.basic_ack(args).await {
137 error!("Failed to send ack: {}", e);
138 }
139 }
140 } else if !self.auto_ack {
141 let delivery_tag = deliver.delivery_tag();
142 let args = BasicNackArguments::new(delivery_tag, false, false);
143 let _ = channel.basic_nack(args).await;
144 }
145 let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
146 if previous_count == 1 {
147 self.shutdown_notify.notify_one();
148 }
149 }
150}
151
152#[async_trait]
153impl AsyncConsumer for BroadSubscribeHandler {
154 async fn consume(
155 &mut self,
156 channel: &Channel,
157 deliver: Deliver,
158 basic_properties: BasicProperties,
159 content: Vec<u8>,
160 ) {
161 self.in_flight.fetch_add(1, Ordering::AcqRel);
162
163 let routing_key = deliver.routing_key().to_string(); let handlers_guard = self.handlers.load().clone();
165 let handlers = handlers_guard.search(&routing_key);
166
167 if handlers.is_empty() {
168 error!("No handler found for routing key {}", routing_key);
169 if !self.auto_ack {
170 let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
171 let _ = channel.basic_nack(args).await;
172 }
173 let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
174 if previous_count == 1 {
175 self.shutdown_notify.notify_one();
176 }
177 return;
178 }
179
180 let channel = channel.clone();
181 let auto_ack = self.auto_ack;
182 let in_flight = Arc::clone(&self.in_flight);
183 let shutdown_notify = Arc::clone(&self.shutdown_notify);
184
185 tokio::spawn(async move {
186 let success = async {
187
188 if handlers.is_empty() {
189 error!("No handler found for routing key {}", routing_key);
190 return false;
191 }
192
193 let decompressed_content = match decompress(content, basic_properties.content_encoding().map(|e| e.as_str())) {
194 Ok(c) => c,
195 Err(e) => {
196 error!("Failed to decompress content: {}", e);
197 return false;
198 }
199 };
200
201 let futures = handlers.iter().map(|i| {
202 let content_clone = &decompressed_content;
203 let message = Message {
204 body: Arc::from(&content_clone[..]),
205 content_type: basic_properties.content_type().map(|s| s.to_string()),
206 };
207
208 async move {
209 let res = match i.process_timeout {
210 Some(dur) => match timeout(dur, (i.handler)(message)).await {
211 Ok(res) => res,
212 Err(_) => Err(AppError::new(Some("Response timeout exceed".to_string()), None, AppErrorType::TimeoutError).into()),
213 },
214 None => (i.handler)(message).await
215 };
216
217 if let Err(ref e) = res {
218 error!("Handler execution error: {}", e);
219 }
220 res
221 }
222 });
223
224 let results = futures::future::join_all(futures).await;
225
226 results.into_iter().all(|res| res.is_ok())
227 }.await;
228
229 if !auto_ack {
230 if success {
231 let args = BasicAckArguments::new(deliver.delivery_tag(), false);
232 if let Err(e) = channel.basic_ack(args).await {
233 error!("Failed to send ack: {}", e);
234 }
235 } else {
236 let args = BasicNackArguments::new(deliver.delivery_tag(), false, false);
237 if let Err(err) = channel.basic_nack(args).await {
238 error!("Failed to send nack: {}", err);
239 }
240 }
241 }
242
243 let previous_count = in_flight.fetch_sub(1, Ordering::AcqRel);
244 if previous_count == 1 {
245 shutdown_notify.notify_one();
246 }
247 });
248 }
249}
250
251#[async_trait]
252impl AsyncConsumer for BroadRPCHandler {
253 async fn consume(
254 &mut self,
255 channel: &Channel,
256 deliver: Deliver,
257 basic_properties: BasicProperties,
258 content: Vec<u8>,
259 ) {
260 self.in_flight.fetch_add(1, Ordering::AcqRel);
261
262 let routing_key = deliver.routing_key().as_str();
263
264 let handlers_guard = self.handlers.load();
265 if let Some(internal_handler) = handlers_guard.get(routing_key) {
266 let (handler, process_timeout) = (Arc::clone(&internal_handler.handler), internal_handler.process_timeout);
267 drop(handlers_guard);
268 let channel = channel.clone();
269 let aux_channel = Arc::clone(&self.channel);
270 let auto_ack = self.auto_ack;
271 let in_flight = Arc::clone(&self.in_flight);
272 let shutdown_notify = Arc::clone(&self.shutdown_notify);
273 tokio::spawn(async move {
274 match decompress(content, basic_properties.content_encoding().map(|e| e.as_str())) {
275 Ok(decompressed_content) => {
276 let message = Message {
277 body: Arc::from(&decompressed_content[..]),
278 content_type: basic_properties.content_type().map(|s| s.to_string()),
279 };
280 let result = async move {
281 match process_timeout {
282 Some(dur) => match timeout(dur, (handler)(message)).await {
283 Ok(res) => res,
284 Err(_) => Err(AppError::new(Some("Response timeout exceed".to_string()), None, AppErrorType::TimeoutError).into()),
285 },
286 None => (handler)(message).await
287 }
288 }
289 .await;
290 match result {
291 Ok(result) => {
292 if !auto_ack {
293 let args = BasicAckArguments::new(deliver.delivery_tag(), false);
294 if let Err(e) = channel.basic_ack(args).await {
295 error!("Failed to send ack: {}", e);
296 }
297 }
298 if let Some(reply_to) = basic_properties.reply_to() {
299 let mut content = result.body;
300 let mut props = BasicProperties::default();
301 if let Some(correlation_id) = basic_properties.correlation_id() {
302 props.with_correlation_id(correlation_id);
303 }
304 if let Some(content_type) = basic_properties.content_type() {
305 if let Some(encoding) = ContentEncoding::from_str(content_type) {
306 if let Ok(compressed_body) = compress(content.as_ref(), encoding) {
307 props.with_content_type(content_type);
308 content = compressed_body.into();
309 }
310 }
311 }
312 props.with_message_type("normal");
313 let args = BasicPublishArguments::new("", reply_to.as_str());
314 if let Err(e) = aux_channel
315 .basic_publish(props, content.to_vec(), args)
316 .await
317 {
318 error!("Failed to publish response: {}", e);
319 }
320 } else {
321 error!("No reply to");
322 }
323 }
324 Err(err) => {
325 if !auto_ack {
326 let args = BasicNackArguments::new(deliver.delivery_tag(), false, false);
327 if let Err(err) = channel.basic_nack(args).await {
328 error!("Failed to send nack: {}", err);
329 }
330 }
331 if let Some(reply_to) = basic_properties.reply_to() {
332 let mut props = BasicProperties::default();
333 if let Some(correlation_id) = basic_properties.correlation_id() {
334 props.with_correlation_id(correlation_id);
335 }
336 if let Some(content_type) = basic_properties.content_type() {
337 props.with_content_type(content_type);
338 }
339 props.with_message_type("error");
340 let args = BasicPublishArguments::new("", reply_to.as_str());
341 if let Err(e) = aux_channel
342 .basic_publish(props, err.to_string().as_bytes().to_vec(), args)
343 .await
344 {
345 error!("Failed to publish response: {}", e);
346 }
347 }
348 }
349 }
350 },
351 Err(e) => {
352 error!("Failed to decompress content: {}", e);
353 if !auto_ack {
354 let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
355 if let Err(err) = channel.basic_nack(args).await {
356 error!("Failed to send nack: {}", err);
357 }
358 }
359 }
360 }
361 let previous_count = in_flight.fetch_sub(1, Ordering::AcqRel);
362 if previous_count == 1 {
363 shutdown_notify.notify_one();
364 }
365 });
366 } else {
367 error!("No handler found for routing key {}", routing_key);
368 if !self.auto_ack {
369 let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
370 if let Err(err) = channel.basic_nack(args).await {
371 error!("Failed to send nack: {}", err);
372 }
373 }
374 let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
375 if previous_count == 1 {
376 self.shutdown_notify.notify_one();
377 }
378 }
379 }
380}