// amqp_client_rust/api/connection_manager.rs

1use super::callback::MyConnectionCallback;
2use crate::domain::config::Config;
3use crate::{
4    api::{
5        callback::MyChannelCallback,
6        channel::AsyncChannel,
7        utils::{
8            Confirmations, ContentEncoding, DeliveryMode, Handler, ChannelCmd,
9            QueueOptions, RPCHandler
10        },
11    },
12    errors::{AppError, AppErrorType},
13};
14#[cfg(feature = "tls")]
15use amqprs::tls::TlsAdaptor;
16use amqprs::{
17    channel::{ConfirmSelectArguments},
18    connection::{Connection, OpenConnectionArguments},
19};
20use dashmap::DashMap;
21use std::collections::HashMap;
22use std::error::Error as StdError;
23use std::{
24    collections::{BTreeMap, VecDeque},
25    sync::{
26        Arc,
27        atomic::Ordering,
28    },
29};
30use tokio::{
31    sync::{Mutex, mpsc, oneshot},
32    time::{Duration, sleep},
33};
34use tracing::error;
35
36// Command Enum for Actor Communication
37pub enum ConnectionCommand {
38    Publish {
39        exchange_name: String,
40        routing_key: String,
41        body: Vec<u8>,
42        content_type: String,
43        content_encoding: ContentEncoding,
44        delivery_mode: DeliveryMode,
45        expiration: Option<u32>,
46        response: oneshot::Sender<Result<(), AppError>>,
47        confirm: Option<oneshot::Sender<Result<(), AppError>>>,
48    },
49    Subscribe {
50        handler: Handler,
51        routing_key: String,
52        exchange_name: String,
53        exchange_type: String,
54        queue_name: String,
55        response: oneshot::Sender<Result<(), AppError>>,
56        process_timeout: Option<Duration>,
57        queue_options: QueueOptions,
58    },
59    RpcServer {
60        handler: RPCHandler,
61        routing_key: String,
62        exchange_name: String,
63        exchange_type: String,
64        queue_name: String,
65        response: oneshot::Sender<Result<(), AppError>>,
66        response_timeout: Option<Duration>,
67        queue_options: QueueOptions,
68    },
69    RpcClient {
70        exchange_name: String,
71        routing_key: String,
72        body: Vec<u8>,
73        content_type: String,
74        content_encoding: ContentEncoding,
75        response_timeout_millis: u32,
76        delivery_mode: DeliveryMode,
77        expiration: Option<u32>,
78        response: oneshot::Sender<Result<Vec<u8>, AppError>>,
79        confirm: Option<oneshot::Sender<Result<(), AppError>>>,
80    },
81    Close {
82        response: oneshot::Sender<()>,
83    },
84    CheckConnection {},
85    UpdateSecret {
86        new_secret: String,
87        reason: String,
88        response: oneshot::Sender<Result<(), AppError>>,
89    },
90}
91
92// Data structures for backup/restore on reconnection
93struct SubscribeBackup {
94    exchange_type: String,
95    handler: Handler,
96    process_timeout: Option<Duration>,
97    queue_options: QueueOptions,
98}
99
100struct RPCSubscribeBackup {
101    exchange_type: String,
102    handler: RPCHandler,
103    response_timeout: Option<Duration>,
104    queue_options: QueueOptions,
105}
106
107
108pub struct ConnectionManager {
109    config: Arc<Config>,
110    tx: mpsc::UnboundedSender<ConnectionCommand>,
111    rx: mpsc::UnboundedReceiver<ConnectionCommand>,
112    connection: Option<Connection>,
113    channel: Option<AsyncChannel>,
114    pending_commands: VecDeque<ConnectionCommand>,
115    subscribe_backup: HashMap<(String, String, String), SubscribeBackup>,
116    rpc_subscribe_backup: HashMap<(String, String, String), RPCSubscribeBackup>,
117    publisher_confirms: Confirmations,
118    pending_confirmations: BTreeMap<u64, oneshot::Sender<Result<(), AppError>>>,
119    channel_rx: mpsc::UnboundedReceiver<ChannelCmd>,
120    channel_tx: mpsc::UnboundedSender<ChannelCmd>,
121    message_number: u64,
122    auto_ack: bool,
123    prefetch_count: Option<u16>,
124    current_reconnect_delay: u16,
125    queues: HashMap<String, (AsyncChannel, QueueOptions)>,
126}
127
128impl ConnectionManager {
129    pub fn new(
130        config: Arc<Config>,
131        tx: mpsc::UnboundedSender<ConnectionCommand>,
132        rx: mpsc::UnboundedReceiver<ConnectionCommand>,
133        publisher_confirms: Confirmations,
134        auto_ack: bool,
135        prefetch_count: Option<u16>,
136    ) -> Self {
137        let (channel_tx, channel_rx) = mpsc::unbounded_channel();
138        Self {
139            config,
140            tx,
141            rx,
142            connection: None,
143            channel: None,
144            pending_commands: VecDeque::new(),
145            subscribe_backup: HashMap::new(),
146            rpc_subscribe_backup: HashMap::new(),
147            publisher_confirms,
148            pending_confirmations: BTreeMap::new(),
149            channel_rx,
150            channel_tx,
151            message_number: 0,
152            auto_ack,
153            prefetch_count,
154            current_reconnect_delay: 1,
155            queues: HashMap::new(),
156        }
157    }
158
159    pub async fn run(mut self) {
160        self.connect().await;
161
162        let mut health_check_interval = tokio::time::interval(Duration::from_secs(1));
163        let mut intentional_close = false;
164        loop {
165            tokio::select! {
166                Some(cmd) = self.channel_rx.recv() => {
167                    match cmd {
168                        ChannelCmd::PublishAck((tag, multiple)) => {
169                            if multiple {
170                                while let Some(entry) = self.pending_confirmations.first_entry() {
171                                    if entry.key() > &tag {
172                                        break;
173                                    }
174                                    let confirm = entry.remove();
175                                    let _ = confirm.send(Ok(()));
176                                }
177                            } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
178                                let _ = confirm.send(Ok(()));
179                            }
180                        },
181                        ChannelCmd::PublishNack((tag, multiple)) => {
182                            if multiple {
183                                while let Some(entry) = self.pending_confirmations.first_entry() {
184                                    if entry.key() > &tag {
185                                        break;
186                                    }
187                                    let confirm = entry.remove();
188                                    let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
189                                }
190                            } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
191                                let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
192                            }
193                        },
194                        ChannelCmd::ReOpen(channel_id) => {
195                            for channel in self.queues.values_mut() {
196                                let _ = channel.0.reopen(channel_id).await;
197                            }
198                        }
199                    }
200                }
201                Some(cmd) = self.rx.recv() => {
202                    match cmd {
203                        ConnectionCommand::Close{ response } => {
204                            intentional_close = true;
205                            if let Some(channel) = &self.channel {
206                                channel.dispose().await;
207                            }
208                            if let Some(conn) = &self.connection {
209                                let _ = conn.clone().close().await;
210                            }
211
212                            let _ = response.send(());
213                            continue;
214                        },
215                        ConnectionCommand::CheckConnection{} => {
216                            continue;
217                        },
218                        _ => {
219                            if self.is_connected() {
220                                self.process_command(cmd).await;
221                            } else {
222                                self.pending_commands.push_back(cmd);
223                            }
224                        }
225                    }
226                },
227                _ = health_check_interval.tick() => {
228                    if !self.is_connected() && !intentional_close {
229                        sleep(Duration::from_secs(self.current_reconnect_delay as u64 -1)).await;
230                        self.connect().await;
231                        self.current_reconnect_delay = std::cmp::min(self.current_reconnect_delay * 2, 30);
232                    }
233                }
234            }
235        }
236    }
237
238    fn is_connected(&self) -> bool {
239        self.connection.as_ref().is_some_and(|c| c.is_open())
240            && self.channel.as_ref().is_some_and(|c| c.channel.is_open())
241    }
242
243    async fn connect(&mut self) {
244        #[cfg(feature = "default")]
245        let mut options = OpenConnectionArguments::new(
246            &self.config.host,
247            self.config.port,
248            &self.config.username,
249            &self.config.password,
250        );
251        options.virtual_host(&self.config.virtual_host);
252        #[cfg(feature = "tls")]
253        if let Some(tls_adaptor) = &self.config.tls_adaptor {
254            options = options.tls_adaptor(tls_adaptor.clone()).finish();
255        }
256        match Connection::open(&options).await {
257            Ok(conn) => {
258                if let Err(e) = conn
259                    .register_callback(MyConnectionCallback {
260                        sender: self.tx.clone(),
261                    })
262                    .await
263                {
264                    error!("Failed to register connection callback: {}", e);
265                }
266                self.current_reconnect_delay = 1;
267
268                self.connection = Some(conn.clone());
269                let conn_mutex = Arc::new(Mutex::new(conn.clone()));
270
271                if let Ok(ch) = self.open_channel(&conn, conn_mutex.clone(), self.channel.as_ref()).await {
272                    self.channel = Some(ch);
273                }
274                let old_queues = std::mem::take(&mut self.queues);
275
276                for (queue_name, (latest_channel, options)) in old_queues {
277                    if let Ok(ch) = self.open_channel(&conn, conn_mutex.clone(), Some(&latest_channel)).await {
278                        self.queues.insert(queue_name, (ch, options));
279                    } else {
280                        error!("Failed to open channel for queue {} during reconnection", queue_name);
281                    }
282                }
283                
284                self.message_number = 0;
285                self.restore_subscriptions().await;
286
287                while let Some(cmd) = self.pending_commands.pop_front() {
288                    self.process_command(cmd).await;
289                }
290            }
291            Err(e) => {
292                error!("Failed to connect: {}", e);
293            }
294        }
295    }
296
297    async fn open_channel(
298        &self,
299        conn: &Connection,
300        conn_mutex: Arc<Mutex<Connection>>,
301        latest_channel: Option<&AsyncChannel>,
302    ) -> Result<AsyncChannel, AppError> {
303        if let Ok(ch) = conn.open_channel(None).await {
304            if let Err(e) = ch
305                .register_callback(MyChannelCallback {
306                    channel_tx: self.channel_tx.clone(),
307                })
308                .await
309            {
310                error!("Failed to register channel callback: {}", e);
311            }
312
313            if self.publisher_confirms == Confirmations::PublisherConfirms
314                || self.publisher_confirms == Confirmations::RPCClientPublisherConfirms
315            {
316                let args = ConfirmSelectArguments::default();
317                let _ = ch.confirm_select(args).await;
318            }
319            if let Some(latest_channel) = latest_channel
320                && latest_channel.rpc_consumer_started.load(Ordering::SeqCst)
321            {
322                let mut async_ch = AsyncChannel::new(
323                    ch,
324                    conn_mutex,
325                    self.channel_tx.clone(),
326                    latest_channel.rpc_futures.clone(),
327                    self.publisher_confirms,
328                    self.auto_ack,
329                    self.prefetch_count,
330                    Some(latest_channel.aux_queue_name.clone()),
331                );
332                let _ = async_ch.start_rpc_consumer().await;
333                Ok(async_ch)
334            } else {
335                Ok(AsyncChannel::new(
336                    ch,
337                    conn_mutex,
338                    self.channel_tx.clone(),
339                    Arc::new(DashMap::new()),
340                    self.publisher_confirms,
341                    self.auto_ack,
342                    self.prefetch_count,
343                    None
344                ))
345            }
346        } else {
347            Err(AppError::new(
348                Some("No connection available".to_string()),
349                None,
350                AppErrorType::InternalError,
351            ))
352        }
353    }
354
355    async fn restore_subscriptions(&mut self) {
356        for (keys, values) in &self.subscribe_backup {
357            if let Some((isolated_ch, _)) = self.queues.get(&keys.0) {
358                let _ = isolated_ch.subscribe(
359                    values.handler.clone(),
360                    &keys.1,
361                    &keys.2,
362                    &values.exchange_type,
363                    &keys.0,
364                    values.process_timeout,
365                    &values.queue_options,
366                )
367                .await;
368            }
369        }
370        for (keys, values) in &self.rpc_subscribe_backup {
371            if let Some((isolated_ch, queue_options)) = self.queues.get_mut(&keys.0) {
372                let _ = isolated_ch
373                    .rpc_server(
374                        values.handler.clone(),
375                        &keys.1,
376                        &keys.2,
377                        &values.exchange_type,
378                        &keys.0,
379                        values.response_timeout,
380                        queue_options
381                    )
382                    .await;
383                }
384        }
385    }
386
387    async fn process_command(&mut self, cmd: ConnectionCommand) {
388        let channel = match &mut self.channel {
389            Some(c) => c,
390            None => {
391                self.pending_commands.push_front(cmd);
392                return;
393            }
394        };
395
396        match cmd {
397            ConnectionCommand::Publish {
398                exchange_name,
399                routing_key,
400                body,
401                content_type,
402                content_encoding,
403                delivery_mode,
404                expiration,
405                response,
406                confirm,
407            } => {
408                if let Some(confirm) = confirm {
409                    self.message_number += 1;
410                    self.pending_confirmations
411                        .insert(self.message_number, confirm);
412                }
413                let res = channel
414                    .publish(
415                        &exchange_name,
416                        &routing_key,
417                        body,
418                        &content_type,
419                        content_encoding,
420                        delivery_mode,
421                        expiration,
422                    )
423                    .await;
424                let _ = response.send(res);
425            }
426            ConnectionCommand::Subscribe {
427                handler,
428                routing_key,
429                exchange_name,
430                exchange_type,
431                queue_name,
432                response,
433                process_timeout,
434                queue_options,
435            } => {
436                let conn = self.connection.clone().unwrap();
437
438                let existing_queue = self.queues.get(&queue_name).cloned();
439
440                let channel_result = match existing_queue {
441                    Some((ch, args)) if args != queue_options => {
442                        self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
443                        Ok(ch)
444                    }
445                    Some((ch, _)) => Ok(ch),
446                    None => {
447                        match self.open_channel(&conn, Arc::new(Mutex::new(conn.clone())), None).await {
448                            Ok(ch) => {
449                                self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
450                                Ok(ch)
451                            }
452                            Err(e) => Err(AppError::new(
453                                Some(format!("Channel not Openned: Error {}", e)),
454                                None,
455                                AppErrorType::InternalError,
456                            )),
457                        }
458                    }
459                };
460
461                match channel_result {
462                    Ok(ch) => {
463                        let res = ch.subscribe(
464                            handler.clone(),
465                            &routing_key,
466                            &exchange_name,
467                            &exchange_type,
468                            &queue_name,
469                            process_timeout,
470                            &queue_options,
471                        ).await;
472                        if res.is_ok() {
473                            let key = (queue_name.clone(), routing_key.clone(), exchange_name.clone());
474                            self.subscribe_backup.entry(key).or_insert(SubscribeBackup {
475                                exchange_type: exchange_type.clone(),
476                                handler: handler,
477                                process_timeout,
478                                queue_options: queue_options.clone(),
479                            });
480                        }
481                        let _ = response.send(res);
482                    }
483                    Err(err) => {
484                        let _ = response.send(Err(err));
485                    }
486                }
487            }
488            ConnectionCommand::RpcServer {
489                handler,
490                routing_key,
491                exchange_name,
492                exchange_type,
493                queue_name,
494                response,
495                response_timeout,
496                queue_options,
497            } => {
498
499                let conn = self.connection.clone().unwrap();
500
501                let existing_queue = self.queues.get_mut(&queue_name).cloned();
502
503                let channel_result = match existing_queue {
504                    Some((ch, args)) if args != queue_options => {
505                        self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
506                        Ok(ch)
507                    }
508                    Some((ch, _)) => Ok(ch),
509                    None => {
510                        match self.open_channel(&conn, Arc::new(Mutex::new(conn.clone())), None).await {
511                            Ok(ch) => {
512                                self.queues.insert(queue_name.clone(), (ch.clone(), queue_options.clone()));
513                                Ok(ch)
514                            }
515                            Err(e) => Err(AppError::new(
516                                Some(format!("Channel not Openned: Error {}", e)),
517                                None,
518                                AppErrorType::InternalError,
519                            )),
520                        }
521                    }
522                };
523
524                match channel_result {
525                    Ok(mut ch) => {
526                        let res = ch.rpc_server(
527                            handler.clone(),
528                            &routing_key,
529                            &exchange_name,
530                            &exchange_type,
531                            &queue_name,
532                            response_timeout,
533                            &queue_options,
534                        )
535                        .await;
536                        if res.is_ok() {
537                            let key = (queue_name.clone(), routing_key.clone(), exchange_name.clone());
538                            self.rpc_subscribe_backup.entry(key).or_insert(RPCSubscribeBackup {
539                                exchange_type: exchange_type.clone(),
540                                handler: handler,
541                                response_timeout,
542                                queue_options,
543                            });
544                        }
545                        let _ = response.send(res);
546                    }
547                    Err(err) => {
548                        let _ = response.send(Err(err));
549                    }
550                }
551            }
552            ConnectionCommand::RpcClient {
553                exchange_name,
554                routing_key,
555                body,
556                content_type,
557                content_encoding,
558                response_timeout_millis,
559                delivery_mode,
560                expiration,
561                response,
562                confirm,
563            } => {
564                if let Some(confirm) = confirm {
565                    self.message_number += 1;
566                    self.pending_confirmations
567                        .insert(self.message_number, confirm);
568                    let _ = channel
569                        .rpc_client(
570                            &exchange_name,
571                            &routing_key,
572                            body,
573                            &content_type,
574                            content_encoding,
575                            response_timeout_millis,
576                            delivery_mode,
577                            expiration,
578                            response,
579                            self.channel_tx.clone(),
580                            Some(self.message_number),
581                        )
582                        .await;
583                } else {
584                    let _ = channel
585                        .rpc_client(
586                            &exchange_name,
587                            &routing_key,
588                            body,
589                            &content_type,
590                            content_encoding,
591                            response_timeout_millis,
592                            delivery_mode,
593                            expiration,
594                            response,
595                            self.channel_tx.clone(),
596                            None,
597                        )
598                        .await;
599                }
600            }
601            ConnectionCommand::UpdateSecret {
602                new_secret,
603                reason,
604                response,
605            } => {
606                if let Some(connection) = &mut self.connection {
607                    let _ = response.send(
608                        connection
609                            .update_secret(new_secret.as_str(), reason.as_str())
610                            .await
611                            .map_err(AppError::from),
612                    );
613                } else {
614                    let _ = response.send(Err(AppError::new(
615                        Some("connection is to openned".to_owned()),
616                        None,
617                        AppErrorType::UnexpectedResultError,
618                    )));
619                }
620            }
621            _ => {}
622        }
623    }
624}