Skip to main content

amqp_client_rust/api/
channel.rs

1use crate::{
2    api::{callback::MyChannelCallback, consumers::{BroadRPCClientHandler, BroadRPCHandler, BroadSubscribeHandler, InternalRPCHandler, InternalSubscribeHandler}, utils::{ChannelCmd, ContentEncoding, DeliveryMode, Handler, QueueOptions, RPCHandler, TopicTrie}},
3    errors::{AppError, AppErrorType},
4};
5use amqprs::{
6    BasicProperties, FieldTable, channel::{
7        BasicCancelArguments, BasicConsumeArguments, BasicPublishArguments, BasicQosArguments, Channel, ConfirmSelectArguments, ExchangeDeclareArguments, QueueBindArguments, QueueDeclareArguments
8    }, connection::Connection
9};
10use arc_swap::ArcSwap;
11use dashmap::DashMap;
12use tracing::error;
13use std::{collections::HashMap, sync::atomic::{AtomicBool, AtomicUsize, Ordering}};
14use std::sync::Arc;
15use tokio::{sync::{Mutex, Notify, RwLock, mpsc::{self, UnboundedSender}, oneshot}, time::Duration};
16use uuid::Uuid;
17use crate::api::utils::Confirmations;
18
19
20
/// State shared by the publish, subscribe and RPC operations that run over one
/// AMQP connection.
///
/// The type is `Clone`; the `Arc`-wrapped fields are shared between clones,
/// while `channel`, `aux_channel` and the plain configuration values are copied.
#[derive(Clone)]
pub struct AsyncChannel {
    /// Main channel: used for declaring exchanges/queues, binding, publishing
    /// and for the subscribe / RPC-server consumers.
    pub channel: Channel,
    /// Connection used to open new channels (auxiliary channel, reopen).
    pub connection: Arc<Mutex<Connection>>,
    /// Secondary channel, opened lazily by `rpc_server` (handed to the RPC
    /// handler) or by `start_rpc_consumer` (consumes the reply queue).
    pub aux_channel: Option<Channel>,
    /// Name of the reply queue used by the RPC client (`reply_to` property).
    /// Defaults to `amqp.<uuid>` when not supplied to `new`.
    pub aux_queue_name: String,
    /// Pending RPC calls keyed by correlation id; shared with the reply-queue
    /// consumer handler.
    pub rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>,
    /// Set once the reply-queue consumer has been started successfully.
    pub rpc_consumer_started: Arc<AtomicBool>,
    /// Queue names for which a consumer has already been registered.
    consumers: Arc<DashMap<String, bool>>,
    /// Command sender passed to the channel callback registered on each channel.
    channel_tx: mpsc::UnboundedSender<ChannelCmd>,
    /// Per-queue subscribe handlers, stored in a routing-key trie that is
    /// swapped atomically on update.
    subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>>>>,
    /// Per-queue RPC handlers keyed by routing key, swapped atomically on update.
    rpc_subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>>>>,
    // NOTE(review): exchange-type cache is currently disabled; see the matching
    // commented-out code in `new`, `subscribe` and `rpc_server`.
    //declared_exchanges: Arc<ArcSwap<HashMap<String, ExchangeType>>>,
    /// Which publishing role (if any) should have publisher confirms enabled.
    publisher_confirms: Confirmations,
    /// When `true`, consumers are started without manual acknowledgements.
    auto_ack: bool,
    /// Prefetch count applied through `basic_qos` when manual acks are in use.
    pre_fetch_count: Option<u16>,
    /// Consumer tags returned by every `basic_consume` issued by this object.
    consumer_tags: Arc<RwLock<Vec<String>>>,
    /// Counter shared with the consumer handlers; `dispose` waits for it to
    /// reach zero before closing the channels.
    in_flight: Arc<AtomicUsize>,
    /// Notification shared with the consumer handlers and awaited in `dispose`.
    pub shutdown_notify: Arc<Notify>,
}
41
42impl AsyncChannel {
43    pub fn new(channel: Channel, connection: Arc<Mutex<Connection>>, channel_tx: mpsc::UnboundedSender<ChannelCmd>, rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>, aux_queue_name: Option<String>) -> Self {
44        Self {
45            channel,
46            connection,
47            aux_channel: None,
48            aux_queue_name: aux_queue_name.unwrap_or_else(|| format!("amqp.{}", Uuid::new_v4())),
49            channel_tx,
50            rpc_futures,
51            rpc_consumer_started: Arc::new(AtomicBool::new(false)),
52            consumers: Arc::new(DashMap::new()),
53            subscribes: Arc::new(RwLock::new(HashMap::new())),
54            rpc_subscribes: Arc::new(RwLock::new(HashMap::new())),
55            //declared_exchanges: Arc::new(ArcSwap::from_pointee(HashMap::new())),
56            publisher_confirms,
57            auto_ack,
58            pre_fetch_count,
59            consumer_tags:  Arc::new(RwLock::new(Vec::new())),
60            in_flight: Arc::new(AtomicUsize::new(0)),
61            shutdown_notify: Arc::new(Notify::new()),
62        }
63    }
64
65    fn generate_consumer_tag(&self) -> String {
66        format!("ctag{}", Uuid::new_v4())
67    }
68
69    pub async fn reopen(&mut self, channel_id: u16) -> Result<(), AppError> {
70        if channel_id == self.channel.channel_id() {
71            let new_channel = self.connection.lock().await.open_channel(None).await?;
72            if self.publisher_confirms == Confirmations::PublisherConfirms || self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
73                let args = ConfirmSelectArguments::default();
74                let _ = new_channel.confirm_select(args).await;
75            }
76            self.channel.clone().close().await.ok();
77            self.channel = new_channel;
78            if !self.auto_ack {
79                if let Some(pre_fetch_count) = self.pre_fetch_count {
80                    let args = BasicQosArguments::new(0, pre_fetch_count, false);
81                    let _ = self.channel.basic_qos(args).await;
82                }   
83            }
84            if let Err(e) = self.channel
85                .register_callback(MyChannelCallback {
86                    channel_tx: self.channel_tx.clone(),
87                })
88                .await
89            {
90                error!("Failed to register channel callback: {}", e);
91            }
92            
93        } else if self.aux_channel.is_some() && channel_id == self.aux_channel.as_ref().unwrap().channel_id() {
94            let new_channel = self.connection.lock().await.open_channel(None).await?;
95            if self.publisher_confirms == Confirmations::RPCServerPublisherConfirms {
96                let args = ConfirmSelectArguments::default();
97                let _ = new_channel.confirm_select(args).await;
98            }
99            let _ = self.aux_channel.as_ref().unwrap().clone().close().await;
100            self.aux_channel = Some(new_channel);
101            if !self.auto_ack {
102                if let Some(pre_fetch_count) = self.pre_fetch_count {
103                    let args = BasicQosArguments::new(0, pre_fetch_count, false);
104                    let _ = self.aux_channel.as_ref().unwrap().basic_qos(args).await;
105                }   
106            }
107            if let Err(e) = self.aux_channel.as_ref().unwrap()
108                .register_callback(MyChannelCallback {
109                    channel_tx: self.channel_tx.clone(),
110                })
111                .await
112            {
113                error!("Failed to register channel callback: {}", e);
114            }
115
116            
117        } else {
118            error!("Received reopen for unknown channel id: {}", channel_id);
119        }
120        Ok(())
121    }
122
123    pub async fn add_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalSubscribeHandler) {
124        let queue_handlers = {
125            let mut handlers = self.subscribes.write().await;
126            handlers
127                .entry(queue_name.to_owned())
128                    .or_insert_with(|| Arc::new(ArcSwap::from_pointee(TopicTrie::new())))
129                    .clone()
130        };
131        queue_handlers.rcu(|current_map| {
132            let mut new_map = (**current_map).clone();
133            new_map.insert(routing_key, handler.clone());
134            Arc::new(new_map)
135        });
136                
137    }
138
139    pub async fn add_rpc_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalRPCHandler) {
140        let queue_handlers = {
141            let mut rpc_handlers = self.rpc_subscribes.write().await;
142            rpc_handlers
143                .entry(queue_name.to_owned())
144                .or_insert_with(|| Arc::new(ArcSwap::from_pointee(HashMap::new())))
145                .clone()
146        };
147
148        queue_handlers.rcu(|current_map| {
149            let mut new_map = (**current_map).clone();
150            new_map.insert(routing_key.to_owned(), handler.clone());
151            Arc::new(new_map)
152        });
153    }
154
155    pub async fn queue_bind(&self, queue_name: &str, exchange_name: &str, routing_key: &str) -> Result<(), AppError> {
156        self.channel
157            .queue_bind(QueueBindArguments::new(
158                queue_name,
159                exchange_name,
160                routing_key,
161            ))
162            .await?;
163        Ok(())
164    }
165
166    pub async fn set_qos(&self, prefetch_count: u16) -> Result<(), AppError> {
167        let args = BasicQosArguments::new(0, prefetch_count, false);
168        self.channel.basic_qos(args).await?;
169        Ok(())
170    }
171
172    pub async fn setup_exchange(&self, exchange_name: &str, exchange_type: &str, durable: bool) -> Result<(), AppError> {
173        let arguments = ExchangeDeclareArguments{
174            exchange: exchange_name.to_string(),
175            exchange_type: exchange_type.to_string(),
176            durable,
177            ..Default::default()
178        };
179        Ok(self.channel.exchange_declare(arguments).await?)
180    }
181
182    pub async fn publish(
183        &self,
184        exchange_name: &str,
185        routing_key: &str,
186        body: impl Into<Vec<u8>>,
187        content_type: &str,
188        content_encoding: ContentEncoding,
189        delivery_mode: DeliveryMode,
190        expiration: Option<u32>,
191    ) -> Result<(), AppError>{
192        let args = BasicPublishArguments{
193            exchange: exchange_name.to_owned(),
194            routing_key: routing_key.to_owned(),
195            mandatory: true,
196            immediate: false
197        };
198        let mut properties = BasicProperties::default();
199        properties.with_content_type(content_type);
200        if content_encoding != ContentEncoding::None {
201            properties.with_content_encoding(content_encoding.as_str());
202        }
203        if let Some(exp) = expiration {
204            properties.with_expiration(&format!("{}", exp));
205        }
206        properties.with_delivery_mode(delivery_mode as u8);
207        Ok(self.channel.basic_publish(properties, body.into(), args).await?)
208    }
209
210    pub async fn queue_declare(&self, queue_name: &str, queue_options: &QueueOptions) -> Result<(), AppError> {
211        let queue_args = QueueDeclareArguments::new(queue_name)
212            .auto_delete(queue_options.auto_delete)
213            .durable(queue_options.durable)
214            .exclusive(queue_options.exclusive)
215            .passive(queue_options.no_create)
216            .arguments(queue_options.clone().into())
217            .finish();
218        self.channel.queue_declare(queue_args).await?;
219        Ok(())
220    }
221    
222    pub async fn close(&self) -> Result<(), AppError> {
223        self.channel.clone().close().await?;
224        if let Some(aux_channel) = &self.aux_channel {
225            aux_channel.clone().close().await?;
226        }
227        Ok(())
228    }
229
230    pub async fn subscribe(
231        &self,
232        handler: Handler,
233        routing_key: &str,
234        exchange_name: &str,
235        exchange_type: &str,
236        queue_name: &str,
237        process_timeout: Option<Duration>,
238        queue_options: &QueueOptions
239    ) -> Result<(), AppError>
240    {
241        self.setup_exchange(exchange_name, exchange_type, queue_options.durable)
242            .await?;
243        self.queue_declare(queue_name, queue_options).await?;
244        /*self.declared_exchanges.rcu(|current_map| {
245            let mut new_map = (**current_map).clone();
246            new_map.insert(exchange_name.to_owned(), match exchange_type {
247                "direct" => ExchangeType::Direct,
248                "fanout" => ExchangeType::Fanout,
249                "topic" => ExchangeType::Topic,
250                _ => return Arc::new(new_map),
251            });
252            Arc::new(new_map)
253        });*/
254        
255        self.channel
256        .queue_bind(QueueBindArguments::new(
257            &queue_name,
258            exchange_name,
259            routing_key,
260        ))
261        .await?;
262        
263        self.add_subscribe(&queue_name, routing_key, InternalSubscribeHandler::new(
264            handler,
265            process_timeout,
266        )).await;
267
268        if !self.consumers.contains_key(queue_name) {
269            let queue_handler = self.subscribes.read().await;
270            let handler = queue_handler.get(queue_name).unwrap();
271            if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
272                let args = BasicQosArguments::new(0, pre_fetch_count, false);
273                let _ = self.channel.basic_qos(args).await;
274            }
275            self.consumers.insert(queue_name.to_string(), true);
276            let mut args = BasicConsumeArguments::new(&queue_name, &self.generate_consumer_tag());
277            args.manual_ack(!self.auto_ack);
278            let sub_handler = BroadSubscribeHandler::new(Arc::clone(handler), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
279            let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
280            self.consumer_tags.write().await.push(consumer_tag);
281        }
282        Ok(())
283    }
284    pub async fn unsubscribe(&self, consumer_tag: &str) -> Result<(), AppError> {
285        let args = BasicCancelArguments::new(consumer_tag);   
286        self.channel.basic_cancel(args).await?;
287        Ok(())
288    }
289
290    pub async fn rpc_server(
291        &mut self,
292        handler: RPCHandler,
293        routing_key: &str,
294        exchange_name: &str,
295        exchange_type: &str,
296        queue_name: &str,
297        response_timeout: Option<Duration>,
298        queue_options: &QueueOptions
299    ) -> Result<(), AppError>
300    {
301        if self.aux_channel.is_none() {
302            let ch = self.connection.lock().await.open_channel(None).await?;
303            
304            if self.publisher_confirms == Confirmations::RPCServerPublisherConfirms {
305                let args = ConfirmSelectArguments::default();
306                let _ = ch.confirm_select(args).await;
307            }
308            self.aux_channel = Some(ch);
309        }
310        self.add_rpc_subscribe(queue_name, routing_key, InternalRPCHandler::new(
311            handler,
312            response_timeout,
313        )).await;
314
315        self.setup_exchange(exchange_name, exchange_type, queue_options.durable)
316            .await?;
317        self.queue_declare(queue_name, queue_options).await?;
318        /*self.declared_exchanges.rcu(|current_map| {
319            let mut new_map = (**current_map).clone();
320            new_map.insert(exchange_name.to_owned(), match exchange_type {
321                "direct" => ExchangeType::Direct,
322                "fanout" => ExchangeType::Fanout,
323                "topic" => ExchangeType::Topic,
324                _ => return Arc::new(new_map), // Invalid exchange type, skip updating
325            });
326            Arc::new(new_map)
327        });*/
328        self.channel
329            .queue_bind(QueueBindArguments::new(
330                &queue_name,
331                exchange_name,
332                routing_key,
333            ))
334            .await?;
335        if !self.consumers.contains_key(queue_name) {
336            let queue_handler = self.rpc_subscribes.read().await;
337            let handler = queue_handler.get(queue_name).unwrap();
338            let mut args = BasicConsumeArguments::new(queue_name, &self.generate_consumer_tag());
339            args.manual_ack(!self.auto_ack);
340            self.consumers.insert(queue_name.to_string(), true);
341            let sub_handler = BroadRPCHandler::new(
342                Arc::new(self.aux_channel.as_ref().unwrap().clone()),
343                Arc::clone(handler),
344                self.auto_ack,
345                self.in_flight.clone(),
346                self.shutdown_notify.clone(),
347            );
348            drop(queue_handler);
349            if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
350                let args = BasicQosArguments::new(0, pre_fetch_count, false);
351                let _ = self.channel.basic_qos(args).await;
352            }
353            let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
354            self.consumer_tags.write().await.push(consumer_tag);
355        }
356        Ok(())
357    }
358    
359    pub async fn start_rpc_consumer(&mut self) -> Result<(), AppError> {
360        if !self.rpc_consumer_started.load(std::sync::atomic::Ordering::SeqCst) {
361            {
362                self.aux_channel = Some(async {
363                    let ch = self.connection.lock().await.open_channel(None).await?;
364                    if let Err(e) = ch
365                        .register_callback(MyChannelCallback {
366                            channel_tx: self.channel_tx.clone(),
367                        })
368                        .await
369                    {
370                        error!("Failed to register channel callback: {}", e);
371                    }
372                    if self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
373                        let args = ConfirmSelectArguments::default();
374                        let _ = ch.confirm_select(args).await;
375                    }
376                    if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
377                        let args = BasicQosArguments::new(0, pre_fetch_count, false);
378                        let _ = ch.basic_qos(args).await;
379                    }
380                    Ok::<Channel, AppError>(ch)
381                }.await?);
382            }
383            if let Some(channel) = &self.aux_channel {
384                let mut queue_declare = QueueDeclareArguments::new(&self.aux_queue_name);
385                let mut field_table = FieldTable::new();
386                field_table.insert("x-expires".try_into().unwrap(), amqprs::FieldValue::l(60000));
387                queue_declare.auto_delete(false);
388                queue_declare.exclusive(false);
389                queue_declare.arguments(field_table);
390                let (_, _, _) = channel.queue_declare(queue_declare)
391                    .await?
392                    .ok_or_else(|| AppError::new(Some("Queue declare returned None".to_string()), None, AppErrorType::InternalError))?;
393                let rpc_handler = BroadRPCClientHandler::new(Arc::clone(&self.rpc_futures), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
394                let mut args = BasicConsumeArguments::new(&self.aux_queue_name, &self.generate_consumer_tag());
395                args.manual_ack(!self.auto_ack);
396                let consumer_tag = channel.basic_consume(rpc_handler, args).await?;
397                self.consumer_tags.write().await.push(consumer_tag);
398                self.rpc_consumer_started.store(true, std::sync::atomic::Ordering::SeqCst);
399            }
400        }
401        Ok(())
402    }
403
404    pub async fn rpc_client(
405        &mut self,
406        exchange_name: &str,
407        routing_key: &str,
408        body: impl Into<Vec<u8>>,
409        content_type: &str,
410        content_encoding: ContentEncoding,
411        timeout_millis: u32,
412        delivery_mode: DeliveryMode,
413        expiration: Option<u32>,
414        response: oneshot::Sender<Result<Vec<u8>, AppError>>,
415        clean_message: UnboundedSender<ChannelCmd>,
416        message_id: Option<u64>,
417    ) -> Result<(), AppError> 
418    {
419        self.start_rpc_consumer().await?;
420        let (tx, rx) = oneshot::channel();
421        
422        let correlated_id = Uuid::new_v4().to_string();
423        self.rpc_futures.insert(correlated_id.to_owned(), tx);
424        let mut args = BasicPublishArguments::new(exchange_name, routing_key);
425        args.mandatory(true);
426        let mut properties = BasicProperties::default();
427        properties.with_content_type(content_type);
428        if content_encoding != ContentEncoding::None {
429            properties.with_content_encoding(content_encoding.as_str());
430        }
431        properties.with_correlation_id(&correlated_id);
432        properties.with_reply_to(&self.aux_queue_name);
433        properties.with_delivery_mode(delivery_mode as u8);
434        let cn = self.channel.clone();
435        if let Some(exp) = expiration {
436            properties.with_expiration(&format!("{}", exp));
437        }
438        let body = body.into();
439        tokio::spawn(async move {
440            let _ = cn.basic_publish(properties, body, args).await;
441            let message = match tokio::time::timeout(std::time::Duration::from_millis(timeout_millis as u64), rx).await {
442                Ok(Ok(result)) => Ok(result),
443                Ok(Err(_)) => Err(AppError::new(Some("Receiver was dropped".to_string()), None, AppErrorType::InternalError)),
444                Err(_) => Err(AppError::new(Some("Timeout exceeded".to_string()), None, AppErrorType::TimeoutError)),
445            };
446            if let Err(_) = response.send(message) && let Some(id) = message_id {
447                let _ = clean_message.send(ChannelCmd::PublishNack((id, false)));
448            }
449        });
450        Ok(())
451    }
452    pub async fn dispose(&self) {
453        let cn = self.channel.clone();
454        for tag in self.consumer_tags.read().await.iter() {
455            let args = BasicCancelArguments::new(tag);
456            if let Err(e) = cn.basic_cancel(args).await {
457                error!("Failed to cancel consumer {}: {}", tag, e);
458            }
459        }
460        while self.in_flight.load(Ordering::Acquire) > 0 {
461            self.shutdown_notify.notified().await;
462        }
463        if let Err(e) = self.channel.clone().close().await {
464            error!("Failed to close main channel: {}", e);
465        }
466        if let Some(channel) = &self.aux_channel {
467            if let Err(e) = channel.clone().close().await {
468                error!("Failed to close aux channel: {}", e);
469            }
470        }
471    }
472}