Skip to main content

amqp_client_rust/api/
consumers.rs

1use amqprs::{
2    channel::{BasicAckArguments, BasicNackArguments, BasicPublishArguments, Channel},
3    consumer::AsyncConsumer,
4    BasicProperties, Deliver,
5};
6use arc_swap::ArcSwap;
7use async_trait::async_trait;
8use tracing::error;
9use std::{collections::HashMap, sync::atomic::{AtomicUsize, Ordering}};
10use std::error::Error as StdError;
11use std::future::Future;
12use std::sync::Arc;
13use tokio::{sync::{Notify, OnceCell, oneshot::Sender}, time::{Duration, timeout}};
14use dashmap::DashMap;
15
16use crate::{api::utils::{ContentEncoding, Handler, Message, RPCHandler, TopicTrie, compress, decompress}, errors::{AppError, AppErrorType}};
17
/// A subscription callback paired with an optional per-message time limit.
///
/// `BroadSubscribeHandler::consume` awaits `handler` under `tokio::time::timeout`
/// when `process_timeout` is `Some`, and without any limit when it is `None`.
#[derive(Clone)]
pub struct InternalSubscribeHandler {
    /// Type-erased callback invoked with each decoded message.
    handler: Handler,
    /// Maximum time one invocation of `handler` may take; `None` disables the limit.
    process_timeout: Option<Duration>,
}
23impl InternalSubscribeHandler {
24    pub fn new<F, Fut>(handler: Arc<F>, process_timeout: Option<Duration>) -> Self
25    where
26        F: Fn(Message) -> Fut + Send + Sync + 'static + ?Sized,
27        Fut: Future<Output = Result<(), Box<dyn StdError + Send + Sync>>> + Send + 'static,
28    {
29        Self {
30            handler: Arc::new(move |body| Box::pin(handler(body))),
31            process_timeout,
32        }
33    }
34}
35
/// An RPC callback paired with an optional per-request time limit.
///
/// `BroadRPCHandler::consume` awaits `handler` under `tokio::time::timeout`
/// when `process_timeout` is `Some`, and without any limit when it is `None`.
#[derive(Clone)]
pub struct InternalRPCHandler {
    /// Type-erased callback that produces the reply message for a request.
    handler: RPCHandler,
    /// Maximum time one invocation of `handler` may take; `None` disables the limit.
    process_timeout: Option<Duration>,
}
41impl InternalRPCHandler {
42    // Added ?Sized to F
43    pub fn new(handler: RPCHandler, process_timeout: Option<Duration>) -> Self
44    {
45        Self {
46            handler: Arc::new(move |body| Box::pin(handler(body))),
47            process_timeout,
48        }
49    }
50}
51
52
53
/// Consumer that fans each delivery out to every subscription handler whose
/// topic pattern matches the delivery's routing key.
pub struct BroadSubscribeHandler {
    /// Swappable trie of handlers, searched by routing key on every delivery.
    handlers: Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>,
    /// When `false`, `consume` sends the ack/nack for each delivery itself.
    auto_ack: bool,
    /// Number of deliveries currently being processed.
    in_flight: Arc<AtomicUsize>,
    /// Notified when a finished delivery brings `in_flight` from 1 down to 0.
    shutdown_notify: Arc<Notify>,
    // response_timeout: i16
}
61
/// Consumer that serves RPC requests: it runs the handler registered for the
/// delivery's routing key and publishes the outcome to the request's `reply_to`.
pub struct BroadRPCHandler {
    /// Channel on which replies are published.
    channel: Arc<Channel>,
    /// Swappable map from routing key to the handler serving it.
    handlers: Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>,
    /// When `false`, `consume` sends the ack/nack for each delivery itself.
    auto_ack: bool,
    /// Number of deliveries currently being processed.
    in_flight: Arc<AtomicUsize>,
    /// Notified when a finished delivery brings `in_flight` from 1 down to 0.
    shutdown_notify: Arc<Notify>,
    // response_timeout: i16
}
/// Consumer for RPC replies: it hands each reply body to the one-shot sender
/// registered under the reply's correlation id.
pub struct BroadRPCClientHandler {
    /// Pending requests, keyed by correlation id; an entry is removed when its reply arrives.
    handlers: Arc<DashMap<String, Sender<Vec<u8>>>>,
    /// When `false`, `consume` sends the ack/nack for each delivery itself.
    auto_ack: bool,
    /// Number of deliveries currently being processed.
    in_flight: Arc<AtomicUsize>,
    /// Notified when a finished delivery brings `in_flight` from 1 down to 0.
    shutdown_notify: Arc<Notify>,
    // response_timeout: i16
}
77
78impl BroadSubscribeHandler {
79    pub fn new(
80        handlers: Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>,
81        auto_ack: bool,
82        in_flight: Arc<AtomicUsize>,
83        shutdown_notify: Arc<Notify>,
84    ) -> Self {
85        Self {
86            handlers,
87            auto_ack,
88            in_flight,
89            shutdown_notify,
90        }
91    }
92}
93impl BroadRPCHandler {
94    pub fn new(
95        channel: Arc<Channel>,
96        handlers: Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>,
97        auto_ack: bool,
98        in_flight: Arc<AtomicUsize>,
99        shutdown_notify: Arc<Notify>,
100    ) -> Self {
101        Self {
102            channel,
103            handlers,
104            auto_ack,
105            in_flight,
106            shutdown_notify,
107        }
108    }
109}
110
111impl BroadRPCClientHandler {
112    pub fn new(handlers: Arc<DashMap<String, Sender<Vec<u8>>>>, auto_ack: bool, in_flight: Arc<AtomicUsize>, shutdown_notify: Arc<Notify>) -> Self {
113        Self { handlers, auto_ack, in_flight, shutdown_notify }
114    }
115}
116
117#[async_trait]
118impl AsyncConsumer for BroadRPCClientHandler {
119    async fn consume(
120        &mut self,
121        channel: &Channel,
122        deliver: Deliver,
123        basic_properties: BasicProperties,
124        content: Vec<u8>,
125    ) {
126        self.in_flight.fetch_add(1, Ordering::AcqRel);
127        if let Some(correlated_id) = basic_properties.correlation_id() {
128            if let Some(sender) = self.handlers.remove(correlated_id) {
129                if let Err(err) = sender.1.send(content) {
130                    error!("The receiver dropped {:?}", err);
131                }
132            }
133            if !self.auto_ack {
134                let delivery_tag = deliver.delivery_tag();
135                let args = BasicAckArguments::new(delivery_tag, false);
136                if let Err(e) = channel.basic_ack(args).await {
137                    error!("Failed to send ack: {}", e);
138                }
139            }
140        } else if !self.auto_ack {
141            let delivery_tag = deliver.delivery_tag();
142            let args = BasicNackArguments::new(delivery_tag, false, false);
143            let _ = channel.basic_nack(args).await;
144        }
145        let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
146        if previous_count == 1 {
147            self.shutdown_notify.notify_one();
148        }
149    }
150}
151
152#[async_trait]
153impl AsyncConsumer for BroadSubscribeHandler {
154    async fn consume(
155        &mut self,
156        channel: &Channel,
157        deliver: Deliver,
158        basic_properties: BasicProperties,
159        content: Vec<u8>,
160    ) {
161        self.in_flight.fetch_add(1, Ordering::AcqRel);
162
163        let routing_key = deliver.routing_key().to_string(); // Own the string
164        let handlers_guard = self.handlers.load().clone();
165        let handlers = handlers_guard.search(&routing_key);
166
167        if handlers.is_empty() {
168            error!("No handler found for routing key {}", routing_key);
169            if !self.auto_ack {
170                let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
171                let _ = channel.basic_nack(args).await;
172            }
173            let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
174            if previous_count == 1 {
175                self.shutdown_notify.notify_one();
176            }
177            return;
178        }
179
180        let channel = channel.clone();
181        let auto_ack = self.auto_ack;
182        let in_flight = Arc::clone(&self.in_flight);
183        let shutdown_notify = Arc::clone(&self.shutdown_notify);
184
185        tokio::spawn(async move {
186            let success = async {
187
188                if handlers.is_empty() {
189                    error!("No handler found for routing key {}", routing_key);
190                    return false;
191                }
192
193                let decompressed_content = match decompress(content, basic_properties.content_encoding().map(|e| e.as_str())) {
194                    Ok(c) => c,
195                    Err(e) => {
196                        error!("Failed to decompress content: {}", e);
197                        return false;
198                    }
199                };
200
201                let futures = handlers.iter().map(|i| {
202                    let content_clone = &decompressed_content; 
203                    let message = Message {
204                        body: Arc::from(&content_clone[..]),
205                        content_type: basic_properties.content_type().map(|s| s.to_string()),
206                    };
207                    
208                    async move {
209                        let res = match i.process_timeout {
210                            Some(dur) => match timeout(dur, (i.handler)(message)).await {
211                                Ok(res) => res,
212                                Err(_) => Err(AppError::new(Some("Response timeout exceed".to_string()), None, AppErrorType::TimeoutError).into()),
213                            },
214                            None => (i.handler)(message).await
215                        };
216
217                        if let Err(ref e) = res {
218                            error!("Handler execution error: {}", e);
219                        }
220                        res
221                    }
222                });
223
224                let results = futures::future::join_all(futures).await;
225
226                results.into_iter().all(|res| res.is_ok())
227            }.await;
228
229            if !auto_ack {
230                if success {
231                    let args = BasicAckArguments::new(deliver.delivery_tag(), false);
232                    if let Err(e) = channel.basic_ack(args).await {
233                        error!("Failed to send ack: {}", e);
234                    }
235                } else {
236                    let args = BasicNackArguments::new(deliver.delivery_tag(), false, false);
237                    if let Err(err) = channel.basic_nack(args).await {
238                        error!("Failed to send nack: {}", err);
239                    }
240                }
241            }
242
243            let previous_count = in_flight.fetch_sub(1, Ordering::AcqRel);
244            if previous_count == 1 {
245                shutdown_notify.notify_one();
246            }
247        });
248    }
249}
250
251#[async_trait]
252impl AsyncConsumer for BroadRPCHandler {
253    async fn consume(
254        &mut self,
255        channel: &Channel,
256        deliver: Deliver,
257        basic_properties: BasicProperties,
258        content: Vec<u8>,
259    ) {
260        self.in_flight.fetch_add(1, Ordering::AcqRel);
261
262        let routing_key = deliver.routing_key().as_str();
263
264        let handlers_guard = self.handlers.load();
265        if let Some(internal_handler) = handlers_guard.get(routing_key) {
266            let (handler, process_timeout) = (Arc::clone(&internal_handler.handler), internal_handler.process_timeout);
267            drop(handlers_guard);
268            let channel = channel.clone();
269            let aux_channel = Arc::clone(&self.channel);
270            let auto_ack = self.auto_ack;
271            let in_flight = Arc::clone(&self.in_flight);
272            let shutdown_notify = Arc::clone(&self.shutdown_notify);
273            tokio::spawn(async move {
274                match decompress(content, basic_properties.content_encoding().map(|e| e.as_str())) {
275                    Ok(decompressed_content) => {
276                        let message = Message {
277                            body: Arc::from(&decompressed_content[..]),
278                            content_type: basic_properties.content_type().map(|s| s.to_string()),
279                        };
280                        let result = async move {
281                            match process_timeout {
282                                Some(dur) => match timeout(dur, (handler)(message)).await {
283                                    Ok(res) => res,
284                                    Err(_) => Err(AppError::new(Some("Response timeout exceed".to_string()), None, AppErrorType::TimeoutError).into()),
285                                },
286                                None => (handler)(message).await
287                            }
288                        }
289                        .await;
290                        match result {
291                            Ok(result) => {
292                                if !auto_ack {
293                                    let args = BasicAckArguments::new(deliver.delivery_tag(), false);
294                                    if let Err(e) = channel.basic_ack(args).await {
295                                        error!("Failed to send ack: {}", e);
296                                    }
297                                }
298                                if let Some(reply_to) = basic_properties.reply_to() {
299                                    let mut content = result.body;
300                                    let mut props = BasicProperties::default();
301                                    if let Some(correlation_id) = basic_properties.correlation_id() {
302                                        props.with_correlation_id(correlation_id);
303                                    }
304                                    if let Some(content_type) = basic_properties.content_type() {
305                                        if let Some(encoding) = ContentEncoding::from_str(content_type) {
306                                            if let Ok(compressed_body) = compress(content.as_ref(), encoding) {
307                                                props.with_content_type(content_type);
308                                                content = compressed_body.into();
309                                            } 
310                                        }
311                                    }
312                                    props.with_message_type("normal");
313                                    let args = BasicPublishArguments::new("", reply_to.as_str());
314                                    if let Err(e) = aux_channel
315                                        .basic_publish(props, content.to_vec(), args)
316                                        .await
317                                    {
318                                        error!("Failed to publish response: {}", e);
319                                    }
320                                } else {
321                                    error!("No reply to");
322                                }
323                            }
324                            Err(err) => {
325                                if !auto_ack {
326                                    let args = BasicNackArguments::new(deliver.delivery_tag(), false, false);
327                                    if let Err(err) = channel.basic_nack(args).await {
328                                        error!("Failed to send nack: {}", err);
329                                    }
330                                }
331                                if let Some(reply_to) = basic_properties.reply_to() {
332                                    let mut props = BasicProperties::default();
333                                    if let Some(correlation_id) = basic_properties.correlation_id() {
334                                        props.with_correlation_id(correlation_id);
335                                    }
336                                    if let Some(content_type) = basic_properties.content_type() {
337                                        props.with_content_type(content_type);
338                                    }
339                                    props.with_message_type("error");
340                                    let args = BasicPublishArguments::new("", reply_to.as_str());
341                                    if let Err(e) = aux_channel
342                                        .basic_publish(props, err.to_string().as_bytes().to_vec(), args)
343                                        .await
344                                    {
345                                        error!("Failed to publish response: {}", e);
346                                    }
347                                }
348                            }
349                        }
350                    },
351                    Err(e) => {
352                        error!("Failed to decompress content: {}", e);
353                        if !auto_ack {
354                            let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
355                            if let Err(err) = channel.basic_nack(args).await {
356                                error!("Failed to send nack: {}", err);
357                            }
358                        }
359                    }
360                }
361                let previous_count = in_flight.fetch_sub(1, Ordering::AcqRel);
362                if previous_count == 1 {
363                    shutdown_notify.notify_one();
364                }
365            });
366        } else {
367            error!("No handler found for routing key {}", routing_key);
368            if !self.auto_ack {
369                let args = BasicNackArguments::new(deliver.delivery_tag(), false, true);
370                if let Err(err) = channel.basic_nack(args).await {
371                    error!("Failed to send nack: {}", err);
372                }
373            }
374            let previous_count = self.in_flight.fetch_sub(1, Ordering::AcqRel);
375            if previous_count == 1 {
376                self.shutdown_notify.notify_one();
377            }
378        }
379    }
380}