1use std::{collections::{BTreeMap, VecDeque}, future::Future, pin::Pin, sync::{Arc,atomic::{AtomicBool, Ordering}}};
2use dashmap::DashMap;
3use tokio::{sync::{Mutex, mpsc, oneshot}, time::{Duration, sleep, timeout}};
4use tracing::error;
5use crate::{api::{
6 callback::MyChannelCallback, channel::AsyncChannel, utils::{Confirmations, ContentEncoding, DeliveryMode, Handler, Message, PendingCmd, RPCHandler, compress}
7}, errors::{AppError, AppErrorType}};
8use amqprs::{channel::{ConfirmSelectArguments}, connection::{Connection, OpenConnectionArguments}};
9use crate::domain::config::Config;
10use super::callback::MyConnectionCallback;
11#[cfg(feature = "tls")]
12use amqprs::tls::TlsAdaptor;
13use std::error::Error as StdError;
14
15pub enum ConnectionCommand {
17 Publish {
18 exchange_name: String,
19 routing_key: String,
20 body: Vec<u8>,
21 content_type: String,
22 content_encoding: ContentEncoding,
23 delivery_mode: DeliveryMode,
24 expiration: Option<u32>,
25 response: oneshot::Sender<Result<(), AppError>>,
26 confirm: Option<oneshot::Sender<Result<(), AppError>>>,
27 },
28 Subscribe {
29 handler: Handler,
30 routing_key: String,
31 exchange_name: String,
32 exchange_type: String,
33 queue_name: String,
34 response: oneshot::Sender<Result<(), AppError>>,
35 process_timeout: Option<Duration>,
36 },
37 RpcServer {
38 handler: RPCHandler,
39 routing_key: String,
40 exchange_name: String,
41 exchange_type: String,
42 queue_name: String,
43 response: oneshot::Sender<Result<(), AppError>>,
44 response_timeout: Option<Duration>,
45 },
46 RpcClient {
47 exchange_name: String,
48 routing_key: String,
49 body: Vec<u8>,
50 content_type: String,
51 content_encoding: ContentEncoding,
52 response_timeout_millis: u32,
53 delivery_mode: DeliveryMode,
54 expiration: Option<u32>,
55 response: oneshot::Sender<Result<Vec<u8>, AppError>>,
56 confirm: Option<oneshot::Sender<Result<(), AppError>>>,
57 },
58 Close {
59 response: oneshot::Sender<()>,
60 },
61 CheckConnection {
62 },
63 UpdateSecret {
64 new_secret: String,
65 reason: String,
66 response: oneshot::Sender<Result<(), AppError>>,
67 }
68}
69
70struct SubscribeBackup {
72 queue: String,
73 exchange_name: String,
74 exchange_type: String,
75 handler: Handler,
76 routing_key: String,
77 process_timeout: Option<Duration>
78}
79
80struct RPCSubscribeBackup {
81 queue: String,
82 exchange_name: String,
83 exchange_type: String,
84 handler: RPCHandler,
85 routing_key: String,
86 response_timeout: Option<Duration>,
87}
88
89#[derive(Clone)]
91pub struct AsyncConnection {
92 sender: mpsc::UnboundedSender<ConnectionCommand>,
93 publisher_confirms: Confirmations,
94 is_closing: Arc<AtomicBool>,
95}
96
97impl AsyncConnection {
98 pub fn new(config: Arc<Config>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>) -> Self {
99 let (tx, rx) = mpsc::unbounded_channel();
100
101 let manager = ConnectionManager::new(config, tx.clone(), rx, publisher_confirms, auto_ack, pre_fetch_count);
102 tokio::spawn(async move {
103 manager.run().await;
104 });
105 Self { sender: tx, publisher_confirms, is_closing: Arc::new(AtomicBool::new(false)) }
106 }
107
108 pub async fn publish(
109 &self, exchange_name: &str, routing_key: &str, body: impl Into<Vec<u8>>,
110 content_type: &str, content_encoding: ContentEncoding, command_timeout: Option<Duration>,
111 delivery_mode: DeliveryMode, expiration: Option<u32>
112 ) -> Result<(), AppError> {
113 if self.is_closing.load(Ordering::Acquire) {
114 return Err(AppError::new(
115 Some("Connection is shutting down".to_owned()),
116 None,
117 AppErrorType::InternalError ));
119 }
120 let (resp_tx, resp_rx) = oneshot::channel();
121 let body = compress(body, content_encoding)?;
122 if self.publisher_confirms == Confirmations::PublisherConfirms {
123 let confirmation = oneshot::channel();
124
125 let cmd = ConnectionCommand::Publish {
126 exchange_name: exchange_name.to_string(),
127 routing_key: routing_key.to_string(),
128 body,
129 content_type: content_type.to_string(),
130 content_encoding,
131 delivery_mode,
132 expiration,
133 response: resp_tx,
134 confirm: Some(confirmation.0),
135 };
136 let (_, _) = tokio::try_join!(self.send_command(cmd, resp_rx, command_timeout), async {
137 match timeout(command_timeout.unwrap_or(Duration::from_secs(16)), confirmation.1).await {
138 Ok(Ok(res)) => res,
139 Ok(Err(_)) => Err(AppError::new(Some("Confirm channel closed".to_owned()), None, AppErrorType::InternalError)),
140 Err(_) => Err(AppError::new(Some("Timeout waiting for confirmation".to_owned()), None, AppErrorType::TimeoutError)),
141 }
142 })?;
143 Ok(())
144 } else {
145 let cmd = ConnectionCommand::Publish {
146 exchange_name: exchange_name.to_string(),
147 routing_key: routing_key.to_string(),
148 body,
149 content_type: content_type.to_string(),
150 content_encoding,
151 delivery_mode,
152 expiration,
153 response: resp_tx,
154 confirm: None
155 };
156 self.send_command(cmd, resp_rx, command_timeout).await
157 }
158 }
159
160 pub async fn subscribe(
161 &self,
162 handler: Handler,
163 routing_key: &str,
164 exchange_name: &str,
165 exchange_type: &str,
166 queue_name: &str,
167 process_timeout: Option<Duration>,
168 timeout_duration: Option<Duration>
169 ) -> Result<(), AppError> {
170 if self.is_closing.load(Ordering::Acquire) {
171 return Err(AppError::new(
172 Some("Connection is shutting down".to_string()),
173 None,
174 AppErrorType::InternalError ));
176 }
177 let (resp_tx, resp_rx) = oneshot::channel();
178 let cmd = ConnectionCommand::Subscribe {
179 handler,
180 routing_key: routing_key.to_string(),
181 exchange_name: exchange_name.to_string(),
182 exchange_type: exchange_type.to_string(),
183 queue_name: queue_name.to_string(),
184 response: resp_tx,
185 process_timeout,
186 };
187 self.send_command(cmd, resp_rx, timeout_duration).await
188 }
189
190 pub async fn rpc_server(
191 &self,
192 handler: RPCHandler,
193 routing_key: &str,
194 exchange_name: &str,
195 exchange_type: &str,
196 queue_name: &str,
197 response_timeout: Option<Duration>,
198 timeout_duration: Option<Duration>
199 ) -> Result<(), AppError> {
200 if self.is_closing.load(Ordering::Acquire) {
201 return Err(AppError::new(
202 Some("Connection is shutting down".to_string()),
203 None,
204 AppErrorType::InternalError
205 ));
206 }
207 let (resp_tx, resp_rx) = oneshot::channel();
208 let cmd = ConnectionCommand::RpcServer {
209 handler,
210 routing_key: routing_key.to_string(),
211 exchange_name: exchange_name.to_string(),
212 exchange_type: exchange_type.to_string(),
213 queue_name: queue_name.to_string(),
214 response: resp_tx,
215 response_timeout,
216 };
217 self.send_command(cmd, resp_rx, timeout_duration).await
218 }
219
220 pub async fn rpc_client(
221 &self,
222 exchange_name: &str,
223 routing_key: &str,
224 body: impl Into<Vec<u8>>,
225 content_type: &str,
226 content_encoding: ContentEncoding,
227 response_timeout_millis: u32,
228 command_timeout: Option<Duration>,
229 delivery_mode: DeliveryMode,
230 expiration: Option<u32>,
231 ) -> Result<Vec<u8>, AppError> {
232 if self.is_closing.load(Ordering::Acquire) {
233 return Err(AppError::new(
234 Some("Connection is shutting down".to_string()),
235 None,
236 AppErrorType::InternalError
237 ));
238 }
239 let (resp_tx, resp_rx) = oneshot::channel();
240 let body = compress(body.into(), content_encoding)?;
241 if self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
242 let confirmation = oneshot::channel();
243 let cmd = ConnectionCommand::RpcClient {
244 exchange_name: exchange_name.to_string(),
245 routing_key: routing_key.to_string(),
246 body,
247 content_type: content_type.to_string(),
248 content_encoding,
249 response_timeout_millis,
250 delivery_mode,
251 expiration,
252 response: resp_tx,
253 confirm: Some(confirmation.0),
254 };
255 let confirmation = async {
256 match timeout(command_timeout.unwrap_or(Duration::from_secs(16)), confirmation.1).await {
257 Ok(Ok(res)) => res,
258 Ok(Err(_)) => Err(AppError::new(Some("Confirm channel closed".to_owned()), None, AppErrorType::InternalError)),
259 Err(_) => Err(AppError::new(Some("Timeout waiting for confirmation".to_owned()), None, AppErrorType::TimeoutError)),
260 }
261 };
262 let (response, _) = tokio::try_join!(self.send_command(cmd, resp_rx, command_timeout), confirmation)?;
263 Ok(response)
264 } else {
265 let cmd = ConnectionCommand::RpcClient {
266 exchange_name: exchange_name.to_string(),
267 routing_key: routing_key.to_string(),
268 body,
269 content_type: content_type.to_string(),
270 content_encoding,
271 response_timeout_millis,
272 delivery_mode,
273 expiration,
274 response: resp_tx,
275 confirm: None,
276 };
277 self.send_command(cmd, resp_rx, command_timeout).await
278 }
279 }
280
281 pub async fn update_secret(&self, new_secret: &str, reason: &str, command_timeout: Option<Duration>) -> Result<(), AppError> {
282 if self.is_closing.load(Ordering::Acquire) {
283 return Err(AppError::new(
284 Some("Connection is shutting down".to_string()),
285 None,
286 AppErrorType::InternalError
287 ));
288 }
289 let (resp_tx, resp_rx) = oneshot::channel();
290 let cmd = ConnectionCommand::UpdateSecret {
291 new_secret: new_secret.to_string(),
292 reason: reason.to_string(),
293 response: resp_tx,
294 };
295 self.send_command(cmd, resp_rx, command_timeout).await
296 }
297
298 async fn send_command<T>(&self, cmd: ConnectionCommand, rx: oneshot::Receiver<Result<T, AppError>>, command_timeout: Option<Duration>) -> Result<T, AppError> {
299 if self.sender.send(cmd).is_err() {
300 return Err(AppError::new(Some("Connection manager dropped".to_string()), None, AppErrorType::InternalError));
301 }
302
303 match command_timeout {
304 Some(dur) => match timeout(dur, rx).await {
305 Ok(Ok(res)) => res,
306 Ok(Err(_)) => Err(AppError::new(Some("Response channel closed".to_owned()), None, AppErrorType::InternalError)),
307 Err(_) => Err(AppError::new(Some("Timeout waiting for connection".to_owned()), None, AppErrorType::TimeoutError)),
308 },
309 None => match rx.await {
310 Ok(res) => res,
311 Err(_) => Err(AppError::new(Some("Response channel closed".to_owned()), None, AppErrorType::InternalError)),
312 }
313 }
314 }
315
316 pub async fn close(&self) -> Result<(), Box<dyn std::error::Error>> {
317 self.is_closing.store(true, Ordering::Release);
318 let (tx, rx) = oneshot::channel();
319 self.sender.send(ConnectionCommand::Close { response: tx })?;
320 rx.await?;
321 Ok(())
322 }
323}
324
325
326struct ConnectionManager {
327 config: Arc<Config>,
328 tx: mpsc::UnboundedSender<ConnectionCommand>,
329 rx: mpsc::UnboundedReceiver<ConnectionCommand>,
330 connection: Option<Connection>,
331 channel: Option<AsyncChannel>,
332 pending_commands: VecDeque<ConnectionCommand>,
333 subscribe_backup: Vec<SubscribeBackup>,
334 rpc_subscribe_backup: Vec<RPCSubscribeBackup>,
335 publisher_confirms: Confirmations,
336 pending_confirmations: BTreeMap<u64, oneshot::Sender<Result<(), AppError>>>,
337 pending_rx: mpsc::UnboundedReceiver<PendingCmd>,
338 pending_tx: mpsc::UnboundedSender<PendingCmd>,
339 message_number: u64,
340 auto_ack: bool,
341 pre_fetch_count: Option<u16>,
342 current_reconnect_delay: u16,
343}
344
345impl ConnectionManager {
346 fn new(config: Arc<Config>, tx: mpsc::UnboundedSender<ConnectionCommand>, rx: mpsc::UnboundedReceiver<ConnectionCommand>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>) -> Self {
347 let (pending_tx, pending_rx) = mpsc::unbounded_channel();
348 Self {
349 config,
350 tx,
351 rx,
352 connection: None,
353 channel: None,
354 pending_commands: VecDeque::new(),
355 subscribe_backup: Vec::new(),
356 rpc_subscribe_backup: Vec::new(),
357 publisher_confirms,
358 pending_confirmations: BTreeMap::new(),
359 pending_rx,
360 pending_tx,
361 message_number: 0,
362 auto_ack,
363 pre_fetch_count,
364 current_reconnect_delay: 1,
365 }
366 }
367
368 async fn run(mut self) {
369 self.connect().await;
370
371 let mut health_check_interval = tokio::time::interval(Duration::from_secs(1));
372 let mut intentional_close = false;
373 loop {
374 tokio::select! {
375 Some(cmd) = self.pending_rx.recv() => {
376 match cmd {
377 PendingCmd::Ack((tag, multiple)) => {
378 if multiple {
379 while let Some(entry) = self.pending_confirmations.first_entry() {
380 if entry.key() > &tag {
381 break;
382 }
383 let confirm = entry.remove();
384 let _ = confirm.send(Ok(()));
385 }
386 } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
387 let _ = confirm.send(Ok(()));
388 }
389 },
390 PendingCmd::Nack((tag, multiple)) => {
391 if multiple {
392 while let Some(entry) = self.pending_confirmations.first_entry() {
393 if entry.key() > &tag {
394 break; }
396 let confirm = entry.remove();
397 let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
398 }
399 } else if let Some(confirm) = self.pending_confirmations.remove(&tag) {
400 let _ = confirm.send(Err(AppError { message: None, description: None, error_type: AppErrorType::NackError }));
401 }
402 },
403 }
404 }
405 Some(cmd) = self.rx.recv() => {
406 match cmd {
407 ConnectionCommand::Close{ response } => {
408 intentional_close = true;
409 if let Some(channel) = &self.channel {
410 channel.dispose().await;
411 }
412 if let Some(conn) = &self.connection {
413 let _ = conn.clone().close().await;
414 }
415
416 let _ = response.send(());
417 continue;
418 },
419 ConnectionCommand::CheckConnection{} => {
420 continue;
421 },
422 _ => {
423 if self.is_connected() {
424 self.process_command(cmd).await;
425 } else {
426 self.pending_commands.push_back(cmd);
427 }
428 }
429 }
430 },
431 _ = health_check_interval.tick() => {
432 if !self.is_connected() && !intentional_close {
433 sleep(Duration::from_secs(self.current_reconnect_delay as u64 -1)).await;
434 self.connect().await;
435 self.current_reconnect_delay = std::cmp::min(self.current_reconnect_delay * 2, 30);
436 }
437 }
438 }
439 }
440 }
441
442 fn is_connected(&self) -> bool {
443 self.connection.as_ref().is_some_and(|c| c.is_open())
444 && self.channel.as_ref().is_some_and(|c| c.channel.is_open())
445 }
446
447 async fn connect(&mut self) {
448 #[cfg(feature = "default")]
449 let mut options = OpenConnectionArguments::new(
450 &self.config.host,
451 self.config.port,
452 &self.config.username,
453 &self.config.password,
454 );
455 options.virtual_host(&self.config.virtual_host);
456 #[cfg(feature = "tls")]
457 if let Some(tls_adaptor) = &self.config.tls_adaptor {
458 options = options.tls_adaptor(
459 tls_adaptor.clone()
460 ).finish();
461 }
462 match Connection::open(&options).await {
463 Ok(conn) => {
464 if let Err(e) = conn.register_callback(MyConnectionCallback{sender: self.tx.clone()}).await {
465 error!("Failed to register connection callback: {}", e);
466 }
467 self.current_reconnect_delay = 1;
468
469 self.connection = Some(conn.clone());
470 let conn_mutex = Arc::new(Mutex::new(conn.clone()));
471
472 if let Ok(ch) = conn.open_channel(None).await {
473 if let Err(e) = ch.register_callback(MyChannelCallback{sender_pending: self.pending_tx.clone()}).await {
474 error!("Failed to register channel callback: {}", e);
475 }
476
477 if self.publisher_confirms == Confirmations::PublisherConfirms || self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
478 let args = ConfirmSelectArguments::default();
479 let _ = ch.confirm_select(args).await;
480 }
481 self.message_number = 0;
482 if let Some(latest_channel) = &self.channel && latest_channel.rpc_consumer_started.load(Ordering::SeqCst){
483 let async_ch = AsyncChannel::new(ch, conn_mutex,latest_channel.rpc_futures.clone(), self.publisher_confirms, self.auto_ack, self.pre_fetch_count);
484 let _ = async_ch.start_rpc_consumer().await;
485 self.channel = Some(async_ch);
486 } else {
487 self.channel = Some(AsyncChannel::new(ch, conn_mutex, Arc::new(DashMap::new()), self.publisher_confirms, self.auto_ack, self.pre_fetch_count));
488 }
489
490 self.restore_subscriptions().await;
491
492 while let Some(cmd) = self.pending_commands.pop_front() {
493 self.process_command(cmd).await;
494 }
495 }
496 }
497 Err(e) => {
498 error!("Failed to connect: {}", e);
499 }
500 }
501 }
502
503 async fn restore_subscriptions(&mut self) {
504 if let Some(channel) = &mut self.channel {
505 for sub in &self.subscribe_backup {
506 let _ = channel.subscribe(
507 sub.handler.clone(),
508 &sub.routing_key,
509 &sub.exchange_name,
510 &sub.exchange_type,
511 &sub.queue,
512 sub.process_timeout,
513 ).await;
514 }
515 for sub in &self.rpc_subscribe_backup {
516 let _ = channel.rpc_server(
517 sub.handler.clone(),
518 &sub.routing_key,
519 &sub.exchange_name,
520 &sub.exchange_type,
521 &sub.queue,
522 sub.response_timeout,
523 ).await;
524 }
525 }
526 }
527
528 async fn process_command(&mut self, cmd: ConnectionCommand) {
529 let channel = match &mut self.channel {
530 Some(c) => c,
531 None => {
532 self.pending_commands.push_front(cmd);
533 return;
534 }
535 };
536
537 match cmd {
538 ConnectionCommand::Publish { exchange_name, routing_key, body, content_type, content_encoding, delivery_mode, expiration, response, confirm} => {
539 if let Some(confirm) = confirm {
540 self.message_number += 1;
541 self.pending_confirmations.insert(self.message_number, confirm);
542 }
543 let res = channel.publish(&exchange_name, &routing_key, body, &content_type, content_encoding, delivery_mode, expiration).await;
544 let _ = response.send(res);
545 },
546 ConnectionCommand::Subscribe { handler, routing_key, exchange_name, exchange_type, queue_name, response, process_timeout } => {
547 self.subscribe_backup.push(SubscribeBackup {
548 queue: queue_name.clone(),
549 exchange_name: exchange_name.clone(),
550 exchange_type: exchange_type.clone(),
551 handler: handler.clone(),
552 routing_key: routing_key.clone(),
553 process_timeout,
554 });
555
556 let res = channel.subscribe(handler, &routing_key, &exchange_name, &exchange_type, &queue_name, process_timeout).await;
557 let _ = response.send(res);
558 },
559 ConnectionCommand::RpcServer { handler, routing_key, exchange_name, exchange_type, queue_name, response, response_timeout } => {
560 self.rpc_subscribe_backup.push(RPCSubscribeBackup {
561 queue: queue_name.clone(),
562 exchange_name: exchange_name.clone(),
563 exchange_type: exchange_type.clone(),
564 handler: handler.clone(),
565 routing_key: routing_key.clone(),
566 response_timeout,
567 });
568 let res = channel.rpc_server(handler, &routing_key, &exchange_name, &exchange_type, &queue_name, response_timeout).await;
569 let _ = response.send(res);
570 },
571 ConnectionCommand::RpcClient { exchange_name, routing_key, body,
572 content_type, content_encoding, response_timeout_millis, delivery_mode, expiration, response, confirm } => {
573 if let Some(confirm) = confirm {
574 self.message_number += 1;
575 self.pending_confirmations.insert(self.message_number, confirm);
576 let _ = channel.rpc_client(&exchange_name, &routing_key, body,
577 &content_type, content_encoding, response_timeout_millis, delivery_mode, expiration, response, self.pending_tx.clone(), Some(self.message_number)).await;
578 } else {
579 let _ = channel.rpc_client(&exchange_name, &routing_key, body,
580 &content_type, content_encoding, response_timeout_millis, delivery_mode, expiration, response, self.pending_tx.clone(), None).await;
581 }
582 },
583 ConnectionCommand::UpdateSecret { new_secret, reason, response } => {
584 if let Some(connection) = &mut self.connection {
585 let _ = response.send(connection.update_secret(new_secret.as_str(), reason.as_str()).await.map_err(AppError::from));
586 } else {
587 let _ = response.send(Err(AppError::new(Some("connection is to openned".to_owned()), None, AppErrorType::UnexpectedResultError)));
588 }
589 },
590 _ => {
591 }
592 }
593 }
594}