1use amqprs::{
2 channel::{BasicAckArguments, BasicNackArguments, BasicPublishArguments, Channel},
3 consumer::AsyncConsumer,
4 BasicProperties, Deliver,
5};
6use arc_swap::ArcSwap;
7use async_trait::async_trait;
8use tracing::error;
9use std::{collections::HashMap, sync::atomic::{AtomicUsize, Ordering}};
10use std::error::Error as StdError;
11use std::future::Future;
12use std::sync::Arc;
13use tokio::{sync::{Notify, OnceCell, oneshot::Sender}, time::{Duration, timeout}};
14use dashmap::DashMap;
15
16use crate::{api::utils::{ContentEncoding, Handler, Message, RPCHandler, TopicTrie, compress, decompress}, errors::{AppError, AppErrorType}};
17
18#[derive(Clone)]
19pub struct InternalSubscribeHandler {
20 handler: Handler,
21 process_timeout: Option<Duration>,
22}
23impl InternalSubscribeHandler {
24 pub fn new<F, Fut>(handler: Arc<F>, process_timeout: Option<Duration>) -> Self
25 where
26 F: Fn(Message) -> Fut + Send + Sync + 'static + ?Sized,
27 Fut: Future<Output = Result<(), Box<dyn StdError + Send + Sync>>> + Send + 'static,
28 {
29 Self {
30 handler: Arc::new(move |body| Box::pin(handler(body))),
31 process_timeout,
32 }
33 }
34}
35
36#[derive(Clone)]
37pub struct InternalRPCHandler {
38 handler: RPCHandler,
39 process_timeout: Option<Duration>,
40}
41impl InternalRPCHandler {
42 pub fn new(handler: RPCHandler, process_timeout: Option<Duration>) -> Self
44 {
45 Self {
46 handler: Arc::new(move |body| Box::pin(handler(body))),
47 process_timeout,
48 }
49 }
50}
51
52
53
54pub struct BroadSubscribeHandler {
55 handlers: Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>,
56 auto_ack: bool,
57 in_flight: Arc<AtomicUsize>,
58 shutdown_notify: Arc<Notify>,
59 }
61
62pub struct BroadRPCHandler {
63 channel: Arc<OnceCell<Channel>>,
64 handlers: Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>,
65 auto_ack: bool,
66 in_flight: Arc<AtomicUsize>,
67 shutdown_notify: Arc<Notify>,
68 }
70pub struct BroadRPCClientHandler {
71 handlers: Arc<DashMap<String, Sender<Vec<u8>>>>,
72 auto_ack: bool,
73 in_flight: Arc<AtomicUsize>,
74 shutdown_notify: Arc<Notify>,
75 }
77
78impl BroadSubscribeHandler {
79 pub fn new(
80 handlers: Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>,
81 auto_ack: bool,
82 in_flight: Arc<AtomicUsize>,
83 shutdown_notify: Arc<Notify>,
84 ) -> Self {
85 Self {
86 handlers,
87 auto_ack,
88 in_flight,
89 shutdown_notify,
90 }
91 }
92}
93impl BroadRPCHandler {
94 pub fn new(
95 channel: Arc<OnceCell<Channel>>,
96 handlers: Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>,
97 auto_ack: bool,
98 in_flight: Arc<AtomicUsize>,
99 shutdown_notify: Arc<Notify>,
100 ) -> Self {
101 Self {
102 channel,
103 handlers,
104 auto_ack,
105 in_flight,
106 shutdown_notify,
107 }
108 }
109}
110
111impl BroadRPCClientHandler {
112 pub fn new(handlers: Arc<DashMap<String, Sender<Vec<u8>>>>, auto_ack: bool, in_flight: Arc<AtomicUsize>, shutdown_notify: Arc<Notify>) -> Self {
113 Self { handlers, auto_ack, in_flight, shutdown_notify }
114 }
115}
116
117#[async_trait]
118impl AsyncConsumer for BroadRPCClientHandler {
119 async fn consume(
120 &mut self,
121 channel: &Channel,
122 deliver: Deliver,
123 basic_properties: BasicProperties,
124 content: Vec<u8>,
125 ) {
126 self.in_flight.fetch_add(1, Ordering::AcqRel);
127 if let Some(correlated_id) = basic_properties.correlation_id() {
128 if let Some(sender) = self.handlers.remove(correlated_id) {
129 if let Err(err) = sender.1.send(content) {
130 error!("The receiver dropped {:?}", err);
131 }
132 }
133 if !self.auto_ack {
134 let delivery_tag = deliver.delivery_tag();
135 let args = BasicAckArguments::new(delivery_tag, false);
136 if let Err(e) = channel.basic_ack(args).await {
137 error!("Failed to send ack: {}", e);
138 }
139 }
140 } else if !self.auto_ack {
141 let delivery_tag = deliver.delivery_tag();
142 let args = BasicNackArguments::new(delivery_tag, false, false);
143 let _ = channel.basic_nack(args).await;
144 }
145 let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
146 if previous_count == 1 {
147 self.shutdown_notify.notify_one();
148 }
149 }
150}
151
152#[async_trait]
153impl AsyncConsumer for BroadSubscribeHandler {
154 async fn consume(
155 &mut self,
156 channel: &Channel,
157 deliver: Deliver,
158 basic_properties: BasicProperties,
159 content: Vec<u8>,
160 ) {
161 self.in_flight.fetch_add(1, Ordering::AcqRel);
162
163 let routing_key = deliver.routing_key().to_string(); let handlers_guard = self.handlers.load().clone();
165 let handlers = handlers_guard.search(&routing_key);
166
167 if handlers.is_empty() {
168 error!("No handler found for routing key {}", routing_key);
169 if !self.auto_ack {
170 let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
171 let _ = channel.basic_nack(args).await;
172 }
173 let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
174 if previous_count == 1 {
175 self.shutdown_notify.notify_one();
176 }
177 return;
178 }
179
180 let channel = channel.clone();
181 let auto_ack = self.auto_ack;
182 let in_flight = Arc::clone(&self.in_flight);
183 let shutdown_notify = Arc::clone(&self.shutdown_notify);
184
185 tokio::spawn(async move {
186 let success = async {
187
188 if handlers.is_empty() {
189 error!("No handler found for routing key {}", routing_key);
190 return false;
191 }
192
193 let decompressed_content = match decompress(content, basic_properties.content_encoding().map(|e| e.as_str())) {
194 Ok(c) => c,
195 Err(e) => {
196 error!("Failed to decompress content: {}", e);
197 return false;
198 }
199 };
200
201 let futures = handlers.iter().map(|i| {
202 let content_clone = &decompressed_content;
203 let message = Message {
204 body: Arc::from(&content_clone[..]),
205 content_type: basic_properties.content_type().map(|s| s.to_string()),
206 };
207
208 async move {
209 let res = match i.process_timeout {
210 Some(dur) => match timeout(dur, (i.handler)(message)).await {
211 Ok(res) => res,
212 Err(_) => Err(AppError::new(Some("Response timeout exceed".to_string()), None, AppErrorType::TimeoutError).into()),
213 },
214 None => (i.handler)(message).await
215 };
216
217 if let Err(ref e) = res {
218 error!("Handler execution error: {}", e);
219 }
220 res
221 }
222 });
223
224 let results = futures::future::join_all(futures).await;
225
226 results.into_iter().all(|res| res.is_ok())
227 }.await;
228
229 if !auto_ack {
230 if success {
231 let args = BasicAckArguments::new(deliver.delivery_tag(), false);
232 if let Err(e) = channel.basic_ack(args).await {
233 error!("Failed to send ack: {}", e);
234 }
235 } else {
236 let args = BasicNackArguments::new(deliver.delivery_tag(), false, false);
237 if let Err(err) = channel.basic_nack(args).await {
238 error!("Failed to send nack: {}", err);
239 }
240 }
241 }
242
243 let previous_count = in_flight.fetch_sub(1, Ordering::AcqRel);
244 if previous_count == 1 {
245 shutdown_notify.notify_one();
246 }
247 });
248 }
249}
250
251#[async_trait]
252impl AsyncConsumer for BroadRPCHandler {
253 async fn consume(
254 &mut self,
255 channel: &Channel,
256 deliver: Deliver,
257 basic_properties: BasicProperties,
258 content: Vec<u8>,
259 ) {
260 self.in_flight.fetch_add(1, Ordering::AcqRel);
261
262 let routing_key = deliver.routing_key().as_str();
263
264 let handlers_guard = self.handlers.load();
265 if let Some(internal_handler) = handlers_guard.get(routing_key) {
266 let (handler, process_timeout) = (Arc::clone(&internal_handler.handler), internal_handler.process_timeout);
267 drop(handlers_guard);
268 let channel = channel.clone();
269 let aux_channel = Arc::clone(&self.channel);
270 let auto_ack = self.auto_ack;
271 let in_flight = Arc::clone(&self.in_flight);
272 let shutdown_notify = Arc::clone(&self.shutdown_notify);
273 tokio::spawn(async move {
274 match decompress(content, basic_properties.content_encoding().map(|e| e.as_str())) {
275 Ok(decompressed_content) => {
276 let message = Message {
277 body: Arc::from(&decompressed_content[..]),
278 content_type: basic_properties.content_type().map(|s| s.to_string()),
279 };
280 let result = async move {
281 match process_timeout {
282 Some(dur) => match timeout(dur, (handler)(message)).await {
283 Ok(res) => res,
284 Err(_) => Err(AppError::new(Some("Response timeout exceed".to_string()), None, AppErrorType::TimeoutError).into()),
285 },
286 None => (handler)(message).await
287 }
288 }
289 .await;
290 match result {
291 Ok(result) => {
292 if !auto_ack {
293 let args = BasicAckArguments::new(deliver.delivery_tag(), false);
294 if let Err(e) = channel.basic_ack(args).await {
295 error!("Failed to send ack: {}", e);
296 }
297 }
298 if let Some(reply_to) = basic_properties.reply_to() {
299 if let Some(aux_channel) = aux_channel.get() {
300 let mut content = result.body;
301 let mut props = BasicProperties::default();
302 if let Some(correlation_id) = basic_properties.correlation_id() {
303 props.with_correlation_id(correlation_id);
304 }
305 if let Some(content_type) = basic_properties.content_type() {
306 if let Some(encoding) = ContentEncoding::from_str(content_type) {
307 if let Ok(compressed_body) = compress(content.as_ref(), encoding) {
308 props.with_content_type(content_type);
309 content = compressed_body.into();
310 }
311 }
312 }
313 props.with_message_type("normal");
314 let args = BasicPublishArguments::new("", reply_to.as_str());
315 if let Err(e) = aux_channel
316 .basic_publish(props, content.to_vec(), args)
317 .await
318 {
319 error!("Failed to publish response: {}", e);
320 }
321 }
322 } else {
323 error!("No reply to");
324 }
325 }
326 Err(err) => {
327 if !auto_ack {
328 let args = BasicNackArguments::new(deliver.delivery_tag(), false, false);
329 if let Err(err) = channel.basic_nack(args).await {
330 error!("Failed to send nack: {}", err);
331 }
332 }
333 if let Some(reply_to) = basic_properties.reply_to() {
334 let mut props = BasicProperties::default();
335 if let Some(correlation_id) = basic_properties.correlation_id() {
336 props.with_correlation_id(correlation_id);
337 }
338 if let Some(content_type) = basic_properties.content_type() {
339 props.with_content_type(content_type);
340 }
341 props.with_message_type("error");
342 if let Some(aux_channel) = aux_channel.get() {
343 let args = BasicPublishArguments::new("", reply_to.as_str());
344 if let Err(e) = aux_channel
345 .basic_publish(props, err.to_string().as_bytes().to_vec(), args)
346 .await
347 {
348 error!("Failed to publish response: {}", e);
349 }
350 }
351 }
352 }
353 }
354 },
355 Err(e) => {
356 error!("Failed to decompress content: {}", e);
357 if !auto_ack {
358 let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
359 if let Err(err) = channel.basic_nack(args).await {
360 error!("Failed to send nack: {}", err);
361 }
362 }
363 }
364 }
365 let previous_count = in_flight.fetch_sub(1, Ordering::AcqRel);
366 if previous_count == 1 {
367 shutdown_notify.notify_one();
368 }
369 });
370 } else {
371 error!("No handler found for routing key {}", routing_key);
372 if !self.auto_ack {
373 let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
374 if let Err(err) = channel.basic_nack(args).await {
375 error!("Failed to send nack: {}", err);
376 }
377 }
378 let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
379 if previous_count == 1 {
380 self.shutdown_notify.notify_one();
381 }
382 }
383 }
384}