Skip to main content

amqp_client_rust/api/
channel.rs

1use crate::{
2    api::{consumers::{BroadRPCClientHandler, BroadRPCHandler, BroadSubscribeHandler, InternalRPCHandler, InternalSubscribeHandler}, utils::{ContentEncoding, DeliveryMode, Handler, Message, PendingCmd, RPCHandler, TopicTrie}},
3    errors::{AppError, AppErrorType},
4};
5use amqprs::{
6    BasicProperties, channel::{
7        BasicCancelArguments, BasicConsumeArguments, BasicPublishArguments, BasicQosArguments, Channel, ConfirmSelectArguments, ExchangeDeclareArguments, QueueBindArguments, QueueDeclareArguments
8    }, connection::Connection
9};
10use arc_swap::ArcSwap;
11use dashmap::DashMap;
12use tracing::error;
13use std::{collections::HashMap, sync::atomic::{AtomicBool, AtomicUsize, Ordering}};
14use std::error::Error as StdError;
15use std::future::Future;
16use std::sync::Arc;
17use tokio::{sync::{Mutex, Notify, OnceCell, RwLock, mpsc::UnboundedSender, oneshot}, time::Duration};
18use uuid::Uuid;
19use crate::api::utils::Confirmations;
20
21
22#[derive(Clone)]
23pub struct AsyncChannel {
24    pub channel: Channel,
25    connection: Arc<Mutex<Connection>>,
26    aux_channel: Arc<OnceCell<Channel>>,
27    aux_queue_name: String,
28    pub rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>,
29    pub rpc_consumer_started: Arc<AtomicBool>,
30    consumers: Arc<DashMap<String, bool>>,
31    subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>>>>,
32    rpc_subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>>>>,
33    //declared_exchanges: Arc<ArcSwap<HashMap<String, ExchangeType>>>,
34    publisher_confirms: Confirmations,
35    auto_ack: bool,
36    pre_fetch_count: Option<u16>,
37    consumer_tags: Arc<RwLock<Vec<String>>>,
38    in_flight: Arc<AtomicUsize>,
39    pub shutdown_notify: Arc<Notify>,
40}
41
42impl AsyncChannel {
43    pub fn new(channel: Channel, connection: Arc<Mutex<Connection>>, rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>) -> Self {
44        Self {
45            channel,
46            connection,
47            aux_channel: Arc::new(OnceCell::new()),
48            aux_queue_name: format!("amqp.{}", Uuid::new_v4()),
49            rpc_futures,
50            rpc_consumer_started: Arc::new(AtomicBool::new(false)),
51            consumers: Arc::new(DashMap::new()),
52            subscribes: Arc::new(RwLock::new(HashMap::new())),
53            rpc_subscribes: Arc::new(RwLock::new(HashMap::new())),
54            //declared_exchanges: Arc::new(ArcSwap::from_pointee(HashMap::new())),
55            publisher_confirms,
56            auto_ack,
57            pre_fetch_count,
58            consumer_tags:  Arc::new(RwLock::new(Vec::new())),
59            in_flight: Arc::new(AtomicUsize::new(0)),
60            shutdown_notify: Arc::new(Notify::new()),
61        }
62    }
63
64    fn generate_consumer_tag(&self) -> String {
65        format!("ctag{}", Uuid::new_v4())
66    }
67
68    pub async fn add_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalSubscribeHandler) {
69        let queue_handlers = {
70            let mut handlers = self.subscribes.write().await;
71            handlers
72                .entry(queue_name.to_owned())
73                    .or_insert_with(|| Arc::new(ArcSwap::from_pointee(TopicTrie::new())))
74                    .clone()
75        };
76        queue_handlers.rcu(|current_map| {
77            let mut new_map = (**current_map).clone();
78            new_map.insert(routing_key, handler.clone());
79            Arc::new(new_map)
80        });
81                
82    }
83
84    pub async fn add_rpc_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalRPCHandler) {
85        let queue_handlers = {
86            let mut rpc_handlers = self.rpc_subscribes.write().await;
87            rpc_handlers
88                .entry(queue_name.to_owned())
89                .or_insert_with(|| Arc::new(ArcSwap::from_pointee(HashMap::new())))
90                .clone()
91        };
92
93        queue_handlers.rcu(|current_map| {
94            let mut new_map = (**current_map).clone();
95            new_map.insert(routing_key.to_owned(), handler.clone());
96            Arc::new(new_map)
97        });
98    }
99
100    pub async fn setup_exchange(&self, exchange_name: &str, exchange_type: &str, durable: bool) -> Result<(), AppError> {
101        let arguments = ExchangeDeclareArguments{
102            exchange: exchange_name.to_string(),
103            exchange_type: exchange_type.to_string(),
104            durable,
105            ..Default::default()
106        };
107        Ok(self.channel.exchange_declare(arguments).await?)
108    }
109
110    pub async fn publish(
111        &self,
112        exchange_name: &str,
113        routing_key: &str,
114        body: impl Into<Vec<u8>>,
115        content_type: &str,
116        content_encoding: ContentEncoding,
117        delivery_mode: DeliveryMode,
118        expiration: Option<u32>,
119    ) -> Result<(), AppError>{
120        let args = BasicPublishArguments{
121            exchange: exchange_name.to_owned(),
122            routing_key: routing_key.to_owned(),
123            mandatory: true,
124            immediate: false
125        };
126        let mut properties = BasicProperties::default();
127        properties.with_content_type(content_type);
128        if content_encoding != ContentEncoding::None {
129            properties.with_content_encoding(content_encoding.as_str());
130        }
131        if let Some(exp) = expiration {
132            properties.with_expiration(&format!("{}", exp));
133        }
134        properties.with_delivery_mode(delivery_mode as u8);
135        Ok(self.channel.basic_publish(properties, body.into(), args).await?)
136    }
137}
138impl AsyncChannel {
139    pub async fn subscribe(
140        &self,
141        handler: Handler,
142        routing_key: &str,
143        exchange_name: &str,
144        exchange_type: &str,
145        queue_name: &str,
146        process_timeout: Option<Duration>,
147    ) -> Result<(), AppError>
148    {
149        self.setup_exchange(exchange_name, exchange_type, true)
150            .await?;
151        /*self.declared_exchanges.rcu(|current_map| {
152            let mut new_map = (**current_map).clone();
153            new_map.insert(exchange_name.to_owned(), match exchange_type {
154                "direct" => ExchangeType::Direct,
155                "fanout" => ExchangeType::Fanout,
156                "topic" => ExchangeType::Topic,
157                _ => return Arc::new(new_map),
158            });
159            Arc::new(new_map)
160        });*/
161        let (queue_name, _, _) = self
162            .channel
163            .queue_declare(QueueDeclareArguments::durable_client_named(queue_name))
164            .await?
165            .ok_or_else(|| AppError::new(Some("Queue declare returned None".to_string()), None, AppErrorType::InternalError))?;
166        self.channel
167            .queue_bind(QueueBindArguments::new(
168                &queue_name,
169                exchange_name,
170                routing_key,
171            ))
172            .await?;
173        
174        self.add_subscribe(&queue_name, routing_key, InternalSubscribeHandler::new(
175            handler,
176            process_timeout,
177        )).await;
178
179        if !self.consumers.contains_key(&queue_name) {
180            let queue_handler = self.subscribes.read().await;
181            let handler = queue_handler.get(&queue_name).unwrap();
182            if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
183                let args = BasicQosArguments::new(0, pre_fetch_count, false);
184                let _ = self.channel.basic_qos(args).await;
185            }
186            self.consumers.insert(queue_name.to_string(), true);
187            let mut args = BasicConsumeArguments::new(&queue_name, &self.generate_consumer_tag());
188            args.manual_ack(!self.auto_ack);
189            let sub_handler = BroadSubscribeHandler::new(Arc::clone(handler), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
190            let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
191            self.consumer_tags.write().await.push(consumer_tag);
192        }
193        Ok(())
194    }
195}
196impl AsyncChannel{
197    pub async fn rpc_server(
198        &self,
199        handler: RPCHandler,
200        routing_key: &str,
201        exchange_name: &str,
202        exchange_type: &str,
203        queue_name: &str,
204        response_timeout: Option<Duration>,
205    ) -> Result<(), AppError>
206    {
207        self.aux_channel.get_or_try_init(|| async {
208            let ch = self.connection.lock().await.open_channel(None).await?;
209            
210            if self.publisher_confirms == Confirmations::RPCServerPublisherConfirms {
211                let args = ConfirmSelectArguments::default();
212                let _ = ch.confirm_select(args).await;
213            }
214            Ok::<Channel, AppError>(ch)
215        }).await?;
216        self.add_rpc_subscribe(queue_name, routing_key, InternalRPCHandler::new(
217            handler,
218            response_timeout,
219        )).await;
220
221        self.setup_exchange(exchange_name, exchange_type, true)
222            .await?;
223        /*self.declared_exchanges.rcu(|current_map| {
224            let mut new_map = (**current_map).clone();
225            new_map.insert(exchange_name.to_owned(), match exchange_type {
226                "direct" => ExchangeType::Direct,
227                "fanout" => ExchangeType::Fanout,
228                "topic" => ExchangeType::Topic,
229                _ => return Arc::new(new_map), // Invalid exchange type, skip updating
230            });
231            Arc::new(new_map)
232        });*/
233        if let Some((queue_name,_,_)) = self.channel.queue_declare(QueueDeclareArguments::durable_client_named(queue_name)).await? {
234            self.channel
235                .queue_bind(QueueBindArguments::new(
236                    &queue_name,
237                    exchange_name,
238                    routing_key,
239                ))
240                .await?;
241            if !self.consumers.contains_key(&queue_name) {
242                let queue_handler = self.rpc_subscribes.read().await;
243                let handler = queue_handler.get(&queue_name).unwrap();
244                let mut args = BasicConsumeArguments::new(&queue_name, &self.generate_consumer_tag());
245                args.manual_ack(!self.auto_ack);
246                self.consumers.insert(queue_name.to_string(), true);
247                let sub_handler = BroadRPCHandler::new(
248                    Arc::clone(&self.aux_channel),
249                    Arc::clone(handler),
250                    self.auto_ack,
251                    self.in_flight.clone(),
252                    self.shutdown_notify.clone(),
253                );
254                drop(queue_handler);
255                if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
256                    let args = BasicQosArguments::new(0, pre_fetch_count, false);
257                    let _ = self.channel.basic_qos(args).await;
258                }
259                let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
260                self.consumer_tags.write().await.push(consumer_tag);
261            }
262        }
263        Ok(())
264    }
265    
266    pub async fn start_rpc_consumer(&self) -> Result<(), AppError> {
267        if !self.rpc_consumer_started.load(std::sync::atomic::Ordering::SeqCst) {
268            {
269                self.aux_channel.get_or_try_init(|| async {
270                let ch = self.connection.lock().await.open_channel(None).await?;
271                if self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
272                    let args = ConfirmSelectArguments::default();
273                    let _ = ch.confirm_select(args).await;
274                }
275                if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
276                    let args = BasicQosArguments::new(0, pre_fetch_count, false);
277                    let _ = ch.basic_qos(args).await;
278                }
279                Ok::<Channel, AppError>(ch)
280                }).await?;
281            }
282            if let Some(channel) = self.aux_channel.get() {
283                let mut queue_declare = QueueDeclareArguments::new(&self.aux_queue_name);
284                queue_declare.auto_delete(true);
285                let (_, _, _) = channel.queue_declare(queue_declare)
286                    .await?
287                    .ok_or_else(|| AppError::new(Some("Queue declare returned None".to_string()), None, AppErrorType::InternalError))?;
288                let rpc_handler = BroadRPCClientHandler::new(Arc::clone(&self.rpc_futures), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
289                let mut args = BasicConsumeArguments::new(&self.aux_queue_name, &self.generate_consumer_tag());
290                args.manual_ack(!self.auto_ack);
291                let consumer_tag = channel.basic_consume(rpc_handler, args).await?;
292                self.consumer_tags.write().await.push(consumer_tag);
293                self.rpc_consumer_started.store(true, std::sync::atomic::Ordering::SeqCst);
294            }
295        }
296        Ok(())
297    }
298
299    pub async fn rpc_client(
300        &self,
301        exchange_name: &str,
302        routing_key: &str,
303        body: impl Into<Vec<u8>>,
304        content_type: &str,
305        content_encoding: ContentEncoding,
306        timeout_millis: u32,
307        delivery_mode: DeliveryMode,
308        expiration: Option<u32>,
309        response: oneshot::Sender<Result<Vec<u8>, AppError>>,
310        clean_message: UnboundedSender<PendingCmd>,
311        message_id: Option<u64>,
312    ) -> Result<(), AppError> 
313    {
314        self.start_rpc_consumer().await?;
315        let (tx, rx) = oneshot::channel();
316        
317        let correlated_id = Uuid::new_v4().to_string();
318        self.rpc_futures.insert(correlated_id.to_owned(), tx);
319        let mut args = BasicPublishArguments::new(exchange_name, routing_key);
320        args.mandatory(true);
321        let mut properties = BasicProperties::default();
322        properties.with_content_type(content_type);
323        if content_encoding != ContentEncoding::None {
324            properties.with_content_encoding(content_encoding.as_str());
325        }
326        properties.with_correlation_id(&correlated_id);
327        properties.with_reply_to(&self.aux_queue_name);
328        properties.with_delivery_mode(delivery_mode as u8);
329        let cn = self.channel.clone();
330        if let Some(exp) = expiration {
331            properties.with_expiration(&format!("{}", exp));
332        }
333        let body = body.into();
334        tokio::spawn(async move {
335            let _ = cn.basic_publish(properties, body, args).await;
336            let message = match tokio::time::timeout(std::time::Duration::from_millis(timeout_millis as u64), rx).await {
337                Ok(Ok(result)) => Ok(result),
338                Ok(Err(_)) => Err(AppError::new(Some("Receiver was dropped".to_string()), None, AppErrorType::InternalError)),
339                Err(_) => Err(AppError::new(Some("Timeout exceeded".to_string()), None, AppErrorType::TimeoutError)),
340            };
341            if let Err(_) = response.send(message) && let Some(id) = message_id {
342                let _ = clean_message.send(PendingCmd::Nack((id, false)));
343            }
344        });
345        Ok(())
346    }
347    pub async fn dispose(&self) {
348        let cn = self.channel.clone();
349        for tag in self.consumer_tags.read().await.iter() {
350            let args = BasicCancelArguments::new(tag);
351            if let Err(e) = cn.basic_cancel(args).await {
352                error!("Failed to cancel consumer {}: {}", tag, e);
353            }
354        }
355        while self.in_flight.load(Ordering::Acquire) > 0 {
356            self.shutdown_notify.notified().await;
357        }
358        if let Err(e) = self.channel.clone().close().await {
359            error!("Failed to close main channel: {}", e);
360        }
361        if let Some(channel) = self.aux_channel.get() {
362            if let Err(e) = channel.clone().close().await {
363                error!("Failed to close aux channel: {}", e);
364            }
365        }
366    }
367}