Skip to main content

amqp_client_rust/api/
connection.rs

1use std::{collections::{BTreeMap, VecDeque}, future::Future, pin::Pin, sync::{Arc,atomic::{AtomicBool, Ordering}}};
2use dashmap::DashMap;
3use tokio::{sync::{Mutex, mpsc, oneshot}, time::{Duration, sleep, timeout}};
4use tracing::error;
5use crate::{api::{
6    callback::MyChannelCallback, channel::AsyncChannel, utils::{Confirmations, ContentEncoding, DeliveryMode, Handler, Message, PendingCmd, RPCHandler, compress}
7}, errors::{AppError, AppErrorType}};
8use amqprs::{channel::{ConfirmSelectArguments}, connection::{Connection, OpenConnectionArguments}};
9use crate::domain::config::Config;
10use super::callback::MyConnectionCallback;
11#[cfg(feature = "tls")]
12use amqprs::tls::TlsAdaptor;
13use std::error::Error as StdError;
14
15// Command Enum for Actor Communication
16pub enum ConnectionCommand {
17    Publish {
18        exchange_name: String,
19        routing_key: String,
20        body: Vec<u8>,
21        content_type: String,
22        content_encoding: ContentEncoding,
23        delivery_mode: DeliveryMode,
24        expiration: Option<u32>,
25        response: oneshot::Sender<Result<(), AppError>>,
26        confirm: Option<oneshot::Sender<Result<(), AppError>>>,
27    },
28    Subscribe {
29        handler: Handler,
30        routing_key: String,
31        exchange_name: String,
32        exchange_type: String,
33        queue_name: String,
34        response: oneshot::Sender<Result<(), AppError>>,
35        process_timeout: Option<Duration>,
36    },
37    RpcServer {
38        handler: RPCHandler,
39        routing_key: String,
40        exchange_name: String,
41        exchange_type: String,
42        queue_name: String,
43        response: oneshot::Sender<Result<(), AppError>>,
44        response_timeout: Option<Duration>,
45    },
46    RpcClient {
47        exchange_name: String,
48        routing_key: String,
49        body: Vec<u8>,
50        content_type: String,
51        content_encoding: ContentEncoding,
52        response_timeout_millis: u32,
53        delivery_mode: DeliveryMode,
54        expiration: Option<u32>,
55        response: oneshot::Sender<Result<Vec<u8>, AppError>>,
56        confirm: Option<oneshot::Sender<Result<(), AppError>>>,
57    },
58    Close {
59        response: oneshot::Sender<()>,
60    },
61    CheckConnection {
62    },
63    UpdateSecret {
64        new_secret: String,
65        reason: String,
66        response: oneshot::Sender<Result<(), AppError>>,
67    }
68}
69
70// Data structures for backup/restore on reconnection
71struct SubscribeBackup {
72    queue: String,
73    exchange_name: String,
74    exchange_type: String,
75    handler: Handler,
76    routing_key: String,
77    process_timeout: Option<Duration>
78}
79
80struct RPCSubscribeBackup {
81    queue: String,
82    exchange_name: String,
83    exchange_type: String,
84    handler: RPCHandler,
85    routing_key: String,
86    response_timeout: Option<Duration>,
87}
88
89// The Handle exposed to the EventBus
90#[derive(Clone)]
91pub struct AsyncConnection {
92    sender: mpsc::UnboundedSender<ConnectionCommand>,
93    publisher_confirms: Confirmations,
94    is_closing: Arc<AtomicBool>,
95}
96
97impl AsyncConnection {
98    pub fn new(config: Arc<Config>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>) -> Self {
99        let (tx, rx) = mpsc::unbounded_channel();
100
101        let manager = ConnectionManager::new(config, tx.clone(), rx, publisher_confirms, auto_ack, pre_fetch_count);
102        tokio::spawn(async move {
103            manager.run().await;
104        });
105        Self { sender: tx, publisher_confirms, is_closing: Arc::new(AtomicBool::new(false)) }
106    }
107
108    pub async fn publish(
109        &self, exchange_name: &str, routing_key: &str, body: impl Into<Vec<u8>>,
110            content_type: &str, content_encoding: ContentEncoding, command_timeout: Option<Duration>,
111            delivery_mode: DeliveryMode, expiration: Option<u32>
112        ) -> Result<(), AppError> {
113        if self.is_closing.load(Ordering::Acquire) {
114            return Err(AppError::new(
115                Some("Connection is shutting down".to_owned()),
116                None,
117                AppErrorType::InternalError // Or a new ConnectionClosed type
118            ));
119        }
120        let (resp_tx, resp_rx) = oneshot::channel();
121        let body = compress(body, content_encoding)?;
122        if self.publisher_confirms == Confirmations::PublisherConfirms {
123            let confirmation = oneshot::channel();
124        
125            let cmd = ConnectionCommand::Publish {
126                exchange_name: exchange_name.to_string(),
127                routing_key: routing_key.to_string(),
128                body,
129                content_type: content_type.to_string(),
130                content_encoding,
131                delivery_mode,
132                expiration,
133                response: resp_tx,
134                confirm: Some(confirmation.0),
135            };
136            let (_, _) = tokio::try_join!(self.send_command(cmd, resp_rx, command_timeout), async {
137                match timeout(command_timeout.unwrap_or(Duration::from_secs(16)), confirmation.1).await {
138                    Ok(Ok(res)) => res,
139                    Ok(Err(_)) => Err(AppError::new(Some("Confirm channel closed".to_owned()), None, AppErrorType::InternalError)),
140                    Err(_) => Err(AppError::new(Some("Timeout waiting for confirmation".to_owned()), None, AppErrorType::TimeoutError)),
141                }
142            })?;
143            Ok(())
144        } else {
145            let cmd = ConnectionCommand::Publish {
146                exchange_name: exchange_name.to_string(),
147                routing_key: routing_key.to_string(),
148                body,
149                content_type: content_type.to_string(),
150                content_encoding,
151                delivery_mode,
152                expiration,
153                response: resp_tx,
154                confirm: None
155            };
156            self.send_command(cmd, resp_rx, command_timeout).await
157        }
158    }
159
160    pub async fn subscribe(
161        &self,
162        handler: Handler,
163        routing_key: &str,
164        exchange_name: &str,
165        exchange_type: &str,
166        queue_name: &str,
167        process_timeout: Option<Duration>,
168        timeout_duration: Option<Duration>
169    ) -> Result<(), AppError> {
170        if self.is_closing.load(Ordering::Acquire) {
171            return Err(AppError::new(
172                Some("Connection is shutting down".to_string()),
173                None,
174                AppErrorType::InternalError // Or a new ConnectionClosed type
175            ));
176        }
177        let (resp_tx, resp_rx) = oneshot::channel();
178        let cmd = ConnectionCommand::Subscribe {
179            handler,
180            routing_key: routing_key.to_string(),
181            exchange_name: exchange_name.to_string(),
182            exchange_type: exchange_type.to_string(),
183            queue_name: queue_name.to_string(),
184            response: resp_tx,
185            process_timeout,
186        };
187        self.send_command(cmd, resp_rx, timeout_duration).await
188    }
189
190    pub async fn rpc_server(
191        &self,
192        handler: RPCHandler,
193        routing_key: &str,
194        exchange_name: &str,
195        exchange_type: &str,
196        queue_name: &str,
197        response_timeout: Option<Duration>,
198        timeout_duration: Option<Duration>
199    ) -> Result<(), AppError> {
200        if self.is_closing.load(Ordering::Acquire) {
201            return Err(AppError::new(
202                Some("Connection is shutting down".to_string()),
203                None,
204                AppErrorType::InternalError
205            ));
206        }
207        let (resp_tx, resp_rx) = oneshot::channel();
208        let cmd = ConnectionCommand::RpcServer {
209            handler,
210            routing_key: routing_key.to_string(),
211            exchange_name: exchange_name.to_string(),
212            exchange_type: exchange_type.to_string(),
213            queue_name: queue_name.to_string(),
214            response: resp_tx,
215            response_timeout,
216        };
217        self.send_command(cmd, resp_rx, timeout_duration).await
218    }
219
220    pub async fn rpc_client(
221        &self,
222        exchange_name: &str,
223        routing_key: &str,
224        body: impl Into<Vec<u8>>,
225        content_type: &str,
226        content_encoding: ContentEncoding,
227        response_timeout_millis: u32,
228        command_timeout: Option<Duration>,
229        delivery_mode: DeliveryMode,
230        expiration: Option<u32>,
231    ) -> Result<Vec<u8>, AppError> {
232        if self.is_closing.load(Ordering::Acquire) {
233            return Err(AppError::new(
234                Some("Connection is shutting down".to_string()),
235                None,
236                AppErrorType::InternalError
237            ));
238        }
239        let (resp_tx, resp_rx) = oneshot::channel();
240        let body = compress(body.into(), content_encoding)?;
241        if self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
242            let confirmation = oneshot::channel();
243            let cmd = ConnectionCommand::RpcClient {
244                exchange_name: exchange_name.to_string(),
245                routing_key: routing_key.to_string(),
246                body,
247                content_type: content_type.to_string(),
248                content_encoding,
249                response_timeout_millis,
250                delivery_mode,
251                expiration,
252                response: resp_tx,
253                confirm: Some(confirmation.0),
254            };
255            let confirmation = async {
256                match timeout(command_timeout.unwrap_or(Duration::from_secs(16)), confirmation.1).await {
257                    Ok(Ok(res)) => res,
258                    Ok(Err(_)) => Err(AppError::new(Some("Confirm channel closed".to_owned()), None, AppErrorType::InternalError)),
259                    Err(_) => Err(AppError::new(Some("Timeout waiting for confirmation".to_owned()), None, AppErrorType::TimeoutError)),
260                }
261            };
262            let (response, _) = tokio::try_join!(self.send_command(cmd, resp_rx, command_timeout), confirmation)?;
263            Ok(response)
264        } else {
265            let cmd = ConnectionCommand::RpcClient {
266                exchange_name: exchange_name.to_string(),
267                routing_key: routing_key.to_string(),
268                body,
269                content_type: content_type.to_string(),
270                content_encoding,
271                response_timeout_millis,
272                delivery_mode,
273                expiration,
274                response: resp_tx,
275                confirm: None,
276            };
277            self.send_command(cmd, resp_rx, command_timeout).await
278        }
279    }
280
281    pub async fn update_secret(&self, new_secret: &str, reason: &str, command_timeout: Option<Duration>) -> Result<(), AppError> {
282        if self.is_closing.load(Ordering::Acquire) {
283            return Err(AppError::new(
284                Some("Connection is shutting down".to_string()),
285                None,
286                AppErrorType::InternalError
287            ));
288        }
289        let (resp_tx, resp_rx) = oneshot::channel();
290        let cmd = ConnectionCommand::UpdateSecret {
291            new_secret: new_secret.to_string(),
292            reason: reason.to_string(),
293            response: resp_tx,
294        };
295        self.send_command(cmd, resp_rx, command_timeout).await
296    }
297
298    async fn send_command<T>(&self, cmd: ConnectionCommand, rx: oneshot::Receiver<Result<T, AppError>>, command_timeout: Option<Duration>) -> Result<T, AppError> {
299        if self.sender.send(cmd).is_err() {
300            return Err(AppError::new(Some("Connection manager dropped".to_string()), None, AppErrorType::InternalError));
301        }
302        
303        match command_timeout {
304            Some(dur) => match timeout(dur, rx).await {
305                Ok(Ok(res)) => res,
306                Ok(Err(_)) => Err(AppError::new(Some("Response channel closed".to_owned()), None, AppErrorType::InternalError)),
307                Err(_) => Err(AppError::new(Some("Timeout waiting for connection".to_owned()), None, AppErrorType::TimeoutError)),
308            },
309            None => match rx.await {
310                Ok(res) => res,
311                Err(_) => Err(AppError::new(Some("Response channel closed".to_owned()), None, AppErrorType::InternalError)),
312            }
313        }
314    }
315
316    pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>>  {
317        self.is_closing.store(true, Ordering::Release);
318        let (tx, rx) = oneshot::channel();
319        self.sender.send(ConnectionCommand::Close { response: tx })?;
320        rx.await?;
321        Ok(())
322    }
323}
324
325
326struct ConnectionManager {
327    config: Arc<Config>,
328    tx: mpsc::UnboundedSender<ConnectionCommand>,
329    rx: mpsc::UnboundedReceiver<ConnectionCommand>,
330    connection: Option<Connection>,
331    channel: Option<AsyncChannel>,
332    pending_commands: VecDeque<ConnectionCommand>,
333    subscribe_backup: Vec<SubscribeBackup>,
334    rpc_subscribe_backup: Vec<RPCSubscribeBackup>,
335    publisher_confirms: Confirmations,
336    pending_confirmations: BTreeMap<u64, oneshot::Sender<Result<(), AppError>>>,
337    pending_rx: mpsc::UnboundedReceiver<PendingCmd>,
338    pending_tx: mpsc::UnboundedSender<PendingCmd>,
339    message_number: u64,
340    auto_ack: bool,
341    pre_fetch_count: Option<u16>,
342    current_reconnect_delay: u16,
343}
344
345impl ConnectionManager {
346    fn new(config: Arc<Config>, tx: mpsc::UnboundedSender<ConnectionCommand>, rx: mpsc::UnboundedReceiver<ConnectionCommand>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>) -> Self {
347        let (pending_tx, pending_rx) = mpsc::unbounded_channel();
348        Self {
349            config,
350            tx,
351            rx,
352            connection: None,
353            channel: None,
354            pending_commands: VecDeque::new(),
355            subscribe_backup: Vec::new(),
356            rpc_subscribe_backup: Vec::new(),
357            publisher_confirms,
358            pending_confirmations: BTreeMap::new(),
359            pending_rx,
360            pending_tx,
361            message_number: 0,
362            auto_ack,
363            pre_fetch_count,
364            current_reconnect_delay: 1,
365        }
366    }
367
368    async fn run(mut self) {
369        self.connect().await;
370
371        let mut health_check_interval = tokio::time::interval(Duration::from_secs(1));
372        let mut intentional_close = false;
373        loop {
374            tokio::select! {
375                Some(cmd) = self.pending_rx.recv() => {
376                    match cmd {
377                        PendingCmd::Ack((tag, multiple)) => {
378                            if multiple {
379                                while let Some(entry) = self.pending_confirmations.first_entry() {
380                                    if entry.key() > &tag {
381                                        break;
382                                    }
383                                    let confirm = entry.remove(); 
384                                    let _ = confirm.send(Ok(()));
385                                }
386                            } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
387                                let _ = confirm.send(Ok(()));
388                            }
389                        },
390                        PendingCmd::Nack((tag, multiple)) => {
391                            if multiple {
392                                while let Some(entry) = self.pending_confirmations.first_entry() {
393                                    if entry.key() > &tag {
394                                        break; // Stop if we go past the tag
395                                    }
396                                    let confirm = entry.remove(); 
397                                    let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
398                                }
399                            } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
400                                let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
401                            }
402                        },
403                    }
404                }
405                Some(cmd) = self.rx.recv() => {
406                    match cmd {
407                        ConnectionCommand::Close{ response } => {
408                            intentional_close = true;
409                            if let Some(channel) = &self.channel {
410                                channel.dispose().await;
411                            }
412                            if let Some(conn) = &self.connection {
413                                let _ = conn.clone().close().await;
414                            }
415                            
416                            let _ = response.send(());
417                            continue;
418                        },
419                        ConnectionCommand::CheckConnection{} => {
420                            continue;
421                        },
422                        _ => {
423                            if self.is_connected() {
424                                self.process_command(cmd).await;
425                            } else {
426                                self.pending_commands.push_back(cmd);
427                            }
428                        }
429                    }
430                },
431                _ = health_check_interval.tick() => {
432                    if !self.is_connected() && !intentional_close {
433                        sleep(Duration::from_secs(self.current_reconnect_delay as u64 -1)).await;
434                        self.connect().await;
435                        self.current_reconnect_delay = std::cmp::min(self.current_reconnect_delay * 2, 30);
436                    }
437                }
438            }
439        }
440    }
441
442    fn is_connected(&self) -> bool {
443        self.connection.as_ref().is_some_and(|c| c.is_open()) 
444            && self.channel.as_ref().is_some_and(|c| c.channel.is_open())
445    }
446
447    async fn connect(&mut self) {
448        #[cfg(feature = "default")]
449        let mut options = OpenConnectionArguments::new(
450            &self.config.host,
451            self.config.port,
452            &self.config.username,
453            &self.config.password,
454        );
455        options.virtual_host(&self.config.virtual_host);
456        #[cfg(feature = "tls")]
457        if let Some(tls_adaptor) = &self.config.tls_adaptor {
458            options = options.tls_adaptor(
459                tls_adaptor.clone()
460            ).finish();
461        }
462        match Connection::open(&options).await {
463            Ok(conn) => {
464                if let Err(e) = conn.register_callback(MyConnectionCallback{sender: self.tx.clone()}).await {
465                    error!("Failed to register connection callback: {}", e);
466                }
467                self.current_reconnect_delay = 1;
468
469                self.connection = Some(conn.clone());
470                let conn_mutex = Arc::new(Mutex::new(conn.clone()));
471                
472                if let Ok(ch) = conn.open_channel(None).await {
473                    if let Err(e) = ch.register_callback(MyChannelCallback{sender_pending: self.pending_tx.clone()}).await {
474                        error!("Failed to register channel callback: {}", e);
475                    }
476
477                    if self.publisher_confirms == Confirmations::PublisherConfirms || self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
478                        let args = ConfirmSelectArguments::default();
479                        let _ = ch.confirm_select(args).await;
480                    }
481                    self.message_number = 0;
482                    if let Some(latest_channel) = &self.channel && latest_channel.rpc_consumer_started.load(Ordering::SeqCst){
483                        let async_ch = AsyncChannel::new(ch, conn_mutex,latest_channel.rpc_futures.clone(), self.publisher_confirms, self.auto_ack, self.pre_fetch_count);
484                        let _ = async_ch.start_rpc_consumer().await;
485                        self.channel = Some(async_ch);
486                    } else {
487                        self.channel = Some(AsyncChannel::new(ch, conn_mutex, Arc::new(DashMap::new()), self.publisher_confirms, self.auto_ack, self.pre_fetch_count));
488                    }
489                    
490                    self.restore_subscriptions().await;
491                    
492                    while let Some(cmd) = self.pending_commands.pop_front() {
493                        self.process_command(cmd).await;
494                    }
495                }
496            }
497            Err(e) => {
498                error!("Failed to connect: {}", e);
499            }
500        }
501    }
502
503    async fn restore_subscriptions(&mut self) {
504        if let Some(channel) = &mut self.channel {
505            for sub in &self.subscribe_backup {
506                let _ = channel.subscribe(
507                    sub.handler.clone(),
508                    &sub.routing_key,
509                    &sub.exchange_name,
510                    &sub.exchange_type,
511                    &sub.queue,
512                    sub.process_timeout,
513                ).await;
514            }
515            for sub in &self.rpc_subscribe_backup {
516                 let _ = channel.rpc_server(
517                    sub.handler.clone(),
518                    &sub.routing_key,
519                    &sub.exchange_name,
520                    &sub.exchange_type,
521                    &sub.queue,
522                    sub.response_timeout,
523                ).await;
524            }
525        }
526    }
527
528    async fn process_command(&mut self, cmd: ConnectionCommand) {
529        let channel = match &mut self.channel {
530            Some(c) => c,
531            None => {
532                self.pending_commands.push_front(cmd);
533                return;
534            }
535        };
536
537        match cmd {
538            ConnectionCommand::Publish { exchange_name, routing_key, body, content_type, content_encoding, delivery_mode, expiration, response, confirm} => {
539                if let Some(confirm) = confirm {
540                    self.message_number += 1;
541                    self.pending_confirmations.insert(self.message_number, confirm);
542                }
543                let res = channel.publish(&exchange_name, &routing_key, body, &content_type, content_encoding, delivery_mode, expiration).await;
544                let _ = response.send(res);
545            },
546            ConnectionCommand::Subscribe { handler, routing_key, exchange_name, exchange_type, queue_name, response, process_timeout } => {
547                self.subscribe_backup.push(SubscribeBackup {
548                    queue: queue_name.clone(),
549                    exchange_name: exchange_name.clone(),
550                    exchange_type: exchange_type.clone(),
551                    handler: handler.clone(),
552                    routing_key: routing_key.clone(),
553                    process_timeout,
554                });
555                
556                let res = channel.subscribe(handler, &routing_key, &exchange_name, &exchange_type, &queue_name, process_timeout).await;
557                let _ = response.send(res);
558            },
559            ConnectionCommand::RpcServer { handler, routing_key, exchange_name, exchange_type, queue_name, response, response_timeout } => {
560                self.rpc_subscribe_backup.push(RPCSubscribeBackup {
561                    queue: queue_name.clone(),
562                    exchange_name: exchange_name.clone(),
563                    exchange_type: exchange_type.clone(),
564                    handler: handler.clone(),
565                    routing_key: routing_key.clone(),
566                    response_timeout,
567                });
568                let res = channel.rpc_server(handler, &routing_key, &exchange_name, &exchange_type, &queue_name, response_timeout).await;
569                let _ = response.send(res);
570            },
571            ConnectionCommand::RpcClient { exchange_name, routing_key, body,
572                content_type, content_encoding, response_timeout_millis, delivery_mode, expiration, response, confirm } => {
573                if let Some(confirm) = confirm {
574                    self.message_number += 1;
575                    self.pending_confirmations.insert(self.message_number, confirm);
576                    let _ = channel.rpc_client(&exchange_name, &routing_key, body,
577                    &content_type, content_encoding, response_timeout_millis, delivery_mode, expiration, response, self.pending_tx.clone(), Some(self.message_number)).await;
578                } else {
579                    let _ = channel.rpc_client(&exchange_name, &routing_key, body,
580                    &content_type, content_encoding, response_timeout_millis, delivery_mode, expiration, response, self.pending_tx.clone(), None).await;
581                }
582            },
583            ConnectionCommand::UpdateSecret { new_secret, reason, response } => {
584                if let Some(connection) = &mut self.connection {
585                    let _ = response.send(connection.update_secret(new_secret.as_str(), reason.as_str()).await.map_err(AppError::from));
586                } else {
587                    let _ = response.send(Err(AppError::new(Some("connection is to openned".to_owned()), None, AppErrorType::UnexpectedResultError)));
588                }
589            },
590            _ => {
591            }
592        }
593    }
594}