Skip to main content

amqp_client_rust/api/
connection_manager.rs

1use super::callback::MyConnectionCallback;
2use crate::domain::config::Config;
3use crate::{
4    api::{
5        callback::MyChannelCallback,
6        channel::AsyncChannel,
7        utils::{
8            Confirmations, ContentEncoding, DeliveryMode, Handler, ChannelCmd,
9            QueueOptions, RPCHandler
10        },
11    },
12    errors::{AppError, AppErrorType},
13};
14#[cfg(feature = "tls")]
15use amqprs::tls::TlsAdaptor;
16use amqprs::{
17    channel::{ConfirmSelectArguments},
18    connection::{Connection, OpenConnectionArguments},
19};
20use dashmap::DashMap;
21use std::collections::HashMap;
22use std::error::Error as StdError;
23use std::{
24    collections::{BTreeMap, VecDeque},
25    sync::{
26        Arc,
27        atomic::Ordering,
28    },
29};
30use tokio::{
31    sync::{Mutex, mpsc, oneshot},
32    time::{Duration, sleep},
33};
34use tracing::error;
35
36// Command Enum for Actor Communication
37pub enum ConnectionCommand {
38    Publish {
39        exchange_name: String,
40        routing_key: String,
41        body: Vec<u8>,
42        content_type: String,
43        content_encoding: ContentEncoding,
44        delivery_mode: DeliveryMode,
45        expiration: Option<u32>,
46        response: oneshot::Sender<Result<(), AppError>>,
47        confirm: Option<oneshot::Sender<Result<(), AppError>>>,
48    },
49    Subscribe {
50        handler: Handler,
51        routing_key: String,
52        exchange_name: String,
53        exchange_type: String,
54        queue_name: String,
55        response: oneshot::Sender<Result<(), AppError>>,
56        process_timeout: Option<Duration>,
57        queue_options: QueueOptions,
58    },
59    RpcServer {
60        handler: RPCHandler,
61        routing_key: String,
62        exchange_name: String,
63        exchange_type: String,
64        queue_name: String,
65        response: oneshot::Sender<Result<(), AppError>>,
66        response_timeout: Option<Duration>,
67        queue_options: QueueOptions,
68    },
69    RpcClient {
70        exchange_name: String,
71        routing_key: String,
72        body: Vec<u8>,
73        content_type: String,
74        content_encoding: ContentEncoding,
75        response_timeout_millis: u32,
76        delivery_mode: DeliveryMode,
77        expiration: Option<u32>,
78        response: oneshot::Sender<Result<Vec<u8>, AppError>>,
79        confirm: Option<oneshot::Sender<Result<(), AppError>>>,
80    },
81    Close {
82        response: oneshot::Sender<()>,
83    },
84    CheckConnection {},
85    UpdateSecret {
86        new_secret: String,
87        reason: String,
88        response: oneshot::Sender<Result<(), AppError>>,
89    },
90}
91
92// Data structures for backup/restore on reconnection
93struct SubscribeBackup {
94    exchange_type: String,
95    handler: Handler,
96    process_timeout: Option<Duration>,
97    queue_options: QueueOptions,
98}
99
100struct RPCSubscribeBackup {
101    exchange_type: String,
102    handler: RPCHandler,
103    response_timeout: Option<Duration>,
104    queue_options: QueueOptions,
105}
106
107
/// Actor that owns the AMQP connection and all channels opened on it.
///
/// It is driven by [`ConnectionManager::run`], which consumes commands from `rx` and
/// channel-callback events from `channel_rx`, and which reconnects when the connection or
/// the main channel is no longer open.
pub struct ConnectionManager {
    /// Broker host/port, credentials and virtual host (plus a TLS adaptor with `tls`).
    config: Arc<Config>,
    /// Sender side of the command channel; a clone is handed to the connection callback.
    tx: mpsc::UnboundedSender<ConnectionCommand>,
    /// Incoming commands for this actor.
    rx: mpsc::UnboundedReceiver<ConnectionCommand>,
    /// Current AMQP connection; `None` until the first successful connect.
    connection: Option<Connection>,
    /// Main channel, used for `Publish` and `RpcClient` commands.
    channel: Option<AsyncChannel>,
    /// Commands received while disconnected, replayed after a reconnect.
    pending_commands: VecDeque<ConnectionCommand>,
    /// `Subscribe` registrations keyed by `(queue_name, routing_key, exchange_name)`,
    /// re-applied after every reconnect.
    subscribe_backup: HashMap<(String, String, String), SubscribeBackup>,
    /// `RpcServer` registrations, keyed the same way and re-applied after every reconnect.
    rpc_subscribe_backup: HashMap<(String, String, String), RPCSubscribeBackup>,
    /// Publisher-confirm mode applied to every channel opened by this manager.
    publisher_confirms: Confirmations,
    /// Publisher-confirm waiters keyed by the delivery tag assigned on the main channel.
    /// Ordered so that "multiple" acks/nacks can complete every tag up to a bound.
    pending_confirmations: BTreeMap<u64, oneshot::Sender<Result<(), AppError>>>,
    /// Events (acks, nacks, re-open requests) reported by channel callbacks.
    channel_rx: mpsc::UnboundedReceiver<ChannelCmd>,
    /// Sender handed to every channel callback and `AsyncChannel`.
    channel_tx: mpsc::UnboundedSender<ChannelCmd>,
    /// Last delivery tag assigned to a confirmed publish; reset to 0 on reconnect.
    message_number: u64,
    /// Forwarded to every `AsyncChannel` this manager creates.
    auto_ack: bool,
    /// Forwarded to every `AsyncChannel` this manager creates.
    prefetch_count: Option<u16>,
    /// Reconnect back-off in seconds: grown by the health check in `run` (capped at 30)
    /// and reset to 1 whenever a connection opens.
    current_reconnect_delay: u16,
    /// Dedicated channel plus declared options for each queue used by
    /// `Subscribe`/`RpcServer`.
    queues: HashMap<String, (AsyncChannel, QueueOptions)>,
}

128impl ConnectionManager {
129    pub fn new(
130        config: Arc<Config>,
131        tx: mpsc::UnboundedSender<ConnectionCommand>,
132        rx: mpsc::UnboundedReceiver<ConnectionCommand>,
133        publisher_confirms: Confirmations,
134        auto_ack: bool,
135        prefetch_count: Option<u16>,
136    ) -> Self {
137        let (channel_tx, channel_rx) = mpsc::unbounded_channel();
138        Self {
139            config,
140            tx,
141            rx,
142            connection: None,
143            channel: None,
144            pending_commands: VecDeque::new(),
145            subscribe_backup: HashMap::new(),
146            rpc_subscribe_backup: HashMap::new(),
147            publisher_confirms,
148            pending_confirmations: BTreeMap::new(),
149            channel_rx,
150            channel_tx,
151            message_number: 0,
152            auto_ack,
153            prefetch_count,
154            current_reconnect_delay: 1,
155            queues: HashMap::new(),
156        }
157    }
158
159    pub async fn run(mut self) {
160        self.connect().await;
161
162        let mut health_check_interval = tokio::time::interval(Duration::from_secs(1));
163        let mut intentional_close = false;
164        loop {
165            tokio::select! {
166                Some(cmd) = self.channel_rx.recv() => {
167                    match cmd {
168                        ChannelCmd::PublishAck((tag, multiple)) => {
169                            if multiple {
170                                while let Some(entry) = self.pending_confirmations.first_entry() {
171                                    if entry.key() > &tag {
172                                        break;
173                                    }
174                                    let confirm = entry.remove();
175                                    let _ = confirm.send(Ok(()));
176                                }
177                            } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
178                                let _ = confirm.send(Ok(()));
179                            }
180                        },
181                        ChannelCmd::PublishNack((tag, multiple)) => {
182                            if multiple {
183                                while let Some(entry) = self.pending_confirmations.first_entry() {
184                                    if entry.key() > &tag {
185                                        break;
186                                    }
187                                    let confirm = entry.remove();
188                                    let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
189                                }
190                            } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
191                                let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
192                            }
193                        },
194                        ChannelCmd::ReOpen(channel_id) => {
195                            for channel in self.queues.values_mut() {
196                                let _ = channel.0.reopen(channel_id).await;
197                            }
198                        }
199                    }
200                }
201                Some(cmd) = self.rx.recv() => {
202                    match cmd {
203                        ConnectionCommand::Close{ response } => {
204                            intentional_close = true;
205                            let mut dispose_futures = Vec::new();
206    
207                            if let Some(channel) = &self.channel {
208                                dispose_futures.push(channel.dispose());
209                            }
210                            
211                            for (channel, _) in self.queues.values() {
212                                dispose_futures.push(channel.dispose());
213                            }
214                            
215                            futures::future::join_all(dispose_futures).await;
216
217                            if let Some(conn) = &self.connection {
218                                let _ = conn.clone().close().await;
219                            }
220
221                            let _ = response.send(());
222                            continue;
223                        },
224                        ConnectionCommand::CheckConnection{} => {
225                            continue;
226                        },
227                        _ => {
228                            if self.is_connected() {
229                                self.process_command(cmd).await;
230                            } else {
231                                self.pending_commands.push_back(cmd);
232                            }
233                        }
234                    }
235                },
236                _ = health_check_interval.tick() => {
237                    if !self.is_connected() && !intentional_close {
238                        sleep(Duration::from_secs(self.current_reconnect_delay as u64 -1)).await;
239                        self.connect().await;
240                        self.current_reconnect_delay = std::cmp::min(self.current_reconnect_delay * 2, 30);
241                    }
242                }
243            }
244        }
245    }
246
247    fn is_connected(&self) -> bool {
248        self.connection.as_ref().is_some_and(|c| c.is_open())
249            && self.channel.as_ref().is_some_and(|c| c.channel.is_open())
250    }
251
252    async fn connect(&mut self) {
253        #[cfg(feature = "default")]
254        let mut options = OpenConnectionArguments::new(
255            &self.config.host,
256            self.config.port,
257            &self.config.username,
258            &self.config.password,
259        );
260        options.virtual_host(&self.config.virtual_host);
261        #[cfg(feature = "tls")]
262        if let Some(tls_adaptor) = &self.config.tls_adaptor {
263            options = options.tls_adaptor(tls_adaptor.clone()).finish();
264        }
265        match Connection::open(&options).await {
266            Ok(conn) => {
267                if let Err(e) = conn
268                    .register_callback(MyConnectionCallback {
269                        sender: self.tx.clone(),
270                    })
271                    .await
272                {
273                    error!("Failed to register connection callback: {}", e);
274                }
275                self.current_reconnect_delay = 1;
276
277                self.connection = Some(conn.clone());
278                let conn_mutex = Arc::new(Mutex::new(conn.clone()));
279
280                if let Ok(ch) = self.open_channel(&conn, conn_mutex.clone(), self.channel.as_ref()).await {
281                    self.channel = Some(ch);
282                }
283                let old_queues = std::mem::take(&mut self.queues);
284
285                for (queue_name, (latest_channel, options)) in old_queues {
286                    if let Ok(ch) = self.open_channel(&conn, conn_mutex.clone(), Some(&latest_channel)).await {
287                        self.queues.insert(queue_name, (ch, options));
288                    } else {
289                        error!("Failed to open channel for queue {} during reconnection", queue_name);
290                    }
291                }
292                
293                self.message_number = 0;
294                self.restore_subscriptions().await;
295
296                while let Some(cmd) = self.pending_commands.pop_front() {
297                    self.process_command(cmd).await;
298                }
299            }
300            Err(e) => {
301                error!("Failed to connect: {}", e);
302            }
303        }
304    }
305
306    async fn open_channel(
307        &self,
308        conn: &Connection,
309        conn_mutex: Arc<Mutex<Connection>>,
310        latest_channel: Option<&AsyncChannel>,
311    ) -> Result<AsyncChannel, AppError> {
312        if let Ok(ch) = conn.open_channel(None).await {
313            if let Err(e) = ch
314                .register_callback(MyChannelCallback {
315                    channel_tx: self.channel_tx.clone(),
316                })
317                .await
318            {
319                error!("Failed to register channel callback: {}", e);
320            }
321
322            if self.publisher_confirms == Confirmations::PublisherConfirms
323                || self.publisher_confirms == Confirmations::RPCClientPublisherConfirms
324            {
325                let args = ConfirmSelectArguments::default();
326                let _ = ch.confirm_select(args).await;
327            }
328            if let Some(latest_channel) = latest_channel
329                && latest_channel.rpc_consumer_started.load(Ordering::SeqCst)
330            {
331                let mut async_ch = AsyncChannel::new(
332                    ch,
333                    conn_mutex,
334                    self.channel_tx.clone(),
335                    latest_channel.rpc_futures.clone(),
336                    self.publisher_confirms,
337                    self.auto_ack,
338                    self.prefetch_count,
339                    Some(latest_channel.aux_queue_name.clone()),
340                );
341                let _ = async_ch.start_rpc_consumer().await;
342                Ok(async_ch)
343            } else {
344                Ok(AsyncChannel::new(
345                    ch,
346                    conn_mutex,
347                    self.channel_tx.clone(),
348                    Arc::new(DashMap::new()),
349                    self.publisher_confirms,
350                    self.auto_ack,
351                    self.prefetch_count,
352                    None
353                ))
354            }
355        } else {
356            Err(AppError::new(
357                Some("No connection available".to_string()),
358                None,
359                AppErrorType::InternalError,
360            ))
361        }
362    }
363
364    async fn restore_subscriptions(&mut self) {
365        for (keys, values) in &self.subscribe_backup {
366            if let Some((isolated_ch, _)) = self.queues.get(&keys.0) {
367                let _ = isolated_ch.subscribe(
368                    values.handler.clone(),
369                    &keys.1,
370                    &keys.2,
371                    &values.exchange_type,
372                    &keys.0,
373                    values.process_timeout,
374                    &values.queue_options,
375                )
376                .await;
377            }
378        }
379        for (keys, values) in &self.rpc_subscribe_backup {
380            if let Some((isolated_ch, queue_options)) = self.queues.get_mut(&keys.0) {
381                let _ = isolated_ch
382                    .rpc_server(
383                        values.handler.clone(),
384                        &keys.1,
385                        &keys.2,
386                        &values.exchange_type,
387                        &keys.0,
388                        values.response_timeout,
389                        queue_options
390                    )
391                    .await;
392                }
393        }
394    }
395
396    async fn process_command(&mut self, cmd: ConnectionCommand) {
397        let channel = match &mut self.channel {
398            Some(c) => c,
399            None => {
400                self.pending_commands.push_front(cmd);
401                return;
402            }
403        };
404
405        match cmd {
406            ConnectionCommand::Publish {
407                exchange_name,
408                routing_key,
409                body,
410                content_type,
411                content_encoding,
412                delivery_mode,
413                expiration,
414                response,
415                confirm,
416            } => {
417                if let Some(confirm) = confirm {
418                    self.message_number += 1;
419                    self.pending_confirmations
420                        .insert(self.message_number, confirm);
421                }
422                let res = channel
423                    .publish(
424                        &exchange_name,
425                        &routing_key,
426                        body,
427                        &content_type,
428                        content_encoding,
429                        delivery_mode,
430                        expiration,
431                    )
432                    .await;
433                let _ = response.send(res);
434            }
435            ConnectionCommand::Subscribe {
436                handler,
437                routing_key,
438                exchange_name,
439                exchange_type,
440                queue_name,
441                response,
442                process_timeout,
443                queue_options,
444            } => {
445                let conn = self.connection.clone().unwrap();
446
447                let existing_queue = self.queues.get(&queue_name).cloned();
448
449                let channel_result = match existing_queue {
450                    Some((ch, args)) if args != queue_options => {
451                        self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
452                        Ok(ch)
453                    }
454                    Some((ch, _)) => Ok(ch),
455                    None => {
456                        match self.open_channel(&conn, Arc::new(Mutex::new(conn.clone())), None).await {
457                            Ok(ch) => {
458                                self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
459                                Ok(ch)
460                            }
461                            Err(e) => Err(AppError::new(
462                                Some(format!("Channel not Openned: Error {}", e)),
463                                None,
464                                AppErrorType::InternalError,
465                            )),
466                        }
467                    }
468                };
469
470                match channel_result {
471                    Ok(ch) => {
472                        let res = ch.subscribe(
473                            handler.clone(),
474                            &routing_key,
475                            &exchange_name,
476                            &exchange_type,
477                            &queue_name,
478                            process_timeout,
479                            &queue_options,
480                        ).await;
481                        if res.is_ok() {
482                            let key = (queue_name.clone(), routing_key.clone(), exchange_name.clone());
483                            self.subscribe_backup.entry(key).or_insert(SubscribeBackup {
484                                exchange_type: exchange_type.clone(),
485                                handler: handler,
486                                process_timeout,
487                                queue_options: queue_options.clone(),
488                            });
489                        }
490                        let _ = response.send(res);
491                    }
492                    Err(err) => {
493                        let _ = response.send(Err(err));
494                    }
495                }
496            }
497            ConnectionCommand::RpcServer {
498                handler,
499                routing_key,
500                exchange_name,
501                exchange_type,
502                queue_name,
503                response,
504                response_timeout,
505                queue_options,
506            } => {
507
508                let conn = self.connection.clone().unwrap();
509
510                let existing_queue = self.queues.get_mut(&queue_name).cloned();
511
512                let channel_result = match existing_queue {
513                    Some((ch, args)) if args != queue_options => {
514                        self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
515                        Ok(ch)
516                    }
517                    Some((ch, _)) => Ok(ch),
518                    None => {
519                        match self.open_channel(&conn, Arc::new(Mutex::new(conn.clone())), None).await {
520                            Ok(ch) => {
521                                self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
522                                Ok(ch)
523                            }
524                            Err(e) => Err(AppError::new(
525                                Some(format!("Channel not Openned: Error {}", e)),
526                                None,
527                                AppErrorType::InternalError,
528                            )),
529                        }
530                    }
531                };
532
533                match channel_result {
534                    Ok(mut ch) => {
535                        let res = ch.rpc_server(
536                            handler.clone(),
537                            &routing_key,
538                            &exchange_name,
539                            &exchange_type,
540                            &queue_name,
541                            response_timeout,
542                            &queue_options,
543                        )
544                        .await;
545                        if res.is_ok() {
546                            let key = (queue_name.clone(), routing_key.clone(), exchange_name.clone());
547                            self.rpc_subscribe_backup.entry(key).or_insert(RPCSubscribeBackup {
548                                exchange_type: exchange_type.clone(),
549                                handler: handler,
550                                response_timeout,
551                                queue_options,
552                            });
553                        }
554                        let _ = response.send(res);
555                    }
556                    Err(err) => {
557                        let _ = response.send(Err(err));
558                    }
559                }
560            }
561            ConnectionCommand::RpcClient {
562                exchange_name,
563                routing_key,
564                body,
565                content_type,
566                content_encoding,
567                response_timeout_millis,
568                delivery_mode,
569                expiration,
570                response,
571                confirm,
572            } => {
573                if let Some(confirm) = confirm {
574                    self.message_number += 1;
575                    self.pending_confirmations
576                        .insert(self.message_number, confirm);
577                    let _ = channel
578                        .rpc_client(
579                            &exchange_name,
580                            &routing_key,
581                            body,
582                            &content_type,
583                            content_encoding,
584                            response_timeout_millis,
585                            delivery_mode,
586                            expiration,
587                            response,
588                            self.channel_tx.clone(),
589                            Some(self.message_number),
590                        )
591                        .await;
592                } else {
593                    let _ = channel
594                        .rpc_client(
595                            &exchange_name,
596                            &routing_key,
597                            body,
598                            &content_type,
599                            content_encoding,
600                            response_timeout_millis,
601                            delivery_mode,
602                            expiration,
603                            response,
604                            self.channel_tx.clone(),
605                            None,
606                        )
607                        .await;
608                }
609            }
610            ConnectionCommand::UpdateSecret {
611                new_secret,
612                reason,
613                response,
614            } => {
615                if let Some(connection) = &mut self.connection {
616                    let _ = response.send(
617                        connection
618                            .update_secret(new_secret.as_str(), reason.as_str())
619                            .await
620                            .map_err(AppError::from),
621                    );
622                } else {
623                    let _ = response.send(Err(AppError::new(
624                        Some("connection is to openned".to_owned()),
625                        None,
626                        AppErrorType::UnexpectedResultError,
627                    )));
628                }
629            }
630            _ => {}
631        }
632    }
633}