1use crate::{
2 api::{callback::MyChannelCallback, consumers::{BroadRPCClientHandler, BroadRPCHandler, BroadSubscribeHandler, InternalRPCHandler, InternalSubscribeHandler}, utils::{ChannelCmd, ContentEncoding, DeliveryMode, Handler, QueueOptions, RPCHandler, TopicTrie}},
3 errors::{AppError, AppErrorType},
4};
5use amqprs::{
6 BasicProperties, FieldTable, channel::{
7 BasicCancelArguments, BasicConsumeArguments, BasicPublishArguments, BasicQosArguments, Channel, ConfirmSelectArguments, ExchangeDeclareArguments, QueueBindArguments, QueueDeclareArguments
8 }, connection::Connection
9};
10use arc_swap::ArcSwap;
11use dashmap::DashMap;
12use tracing::error;
13use std::{collections::HashMap, sync::atomic::{AtomicBool, AtomicUsize, Ordering}};
14use std::sync::Arc;
15use tokio::{sync::{Mutex, Notify, RwLock, mpsc::{self, UnboundedSender}, oneshot}, time::Duration};
16use uuid::Uuid;
17use crate::api::utils::Confirmations;
18
19
20
21#[derive(Clone)]
22pub struct AsyncChannel {
23 pub channel: Channel,
24 pub connection: Arc<Mutex<Connection>>,
25 pub aux_channel: Option<Channel>,
26 pub aux_queue_name: String,
27 pub rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>,
28 pub rpc_consumer_started: Arc<AtomicBool>,
29 consumers: Arc<DashMap<String, bool>>,
30 channel_tx: mpsc::UnboundedSender<ChannelCmd>,
31 subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>>>>,
32 rpc_subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>>>>,
33 publisher_confirms: Confirmations,
35 auto_ack: bool,
36 pre_fetch_count: Option<u16>,
37 consumer_tags: Arc<RwLock<Vec<String>>>,
38 in_flight: Arc<AtomicUsize>,
39 pub shutdown_notify: Arc<Notify>,
40}
41
42impl AsyncChannel {
43 pub fn new(channel: Channel, connection: Arc<Mutex<Connection>>, channel_tx: mpsc::UnboundedSender<ChannelCmd>, rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>, aux_queue_name: Option<String>) -> Self {
44 Self {
45 channel,
46 connection,
47 aux_channel: None,
48 aux_queue_name: aux_queue_name.unwrap_or_else(|| format!("amqp.{}", Uuid::new_v4())),
49 channel_tx,
50 rpc_futures,
51 rpc_consumer_started: Arc::new(AtomicBool::new(false)),
52 consumers: Arc::new(DashMap::new()),
53 subscribes: Arc::new(RwLock::new(HashMap::new())),
54 rpc_subscribes: Arc::new(RwLock::new(HashMap::new())),
55 publisher_confirms,
57 auto_ack,
58 pre_fetch_count,
59 consumer_tags: Arc::new(RwLock::new(Vec::new())),
60 in_flight: Arc::new(AtomicUsize::new(0)),
61 shutdown_notify: Arc::new(Notify::new()),
62 }
63 }
64
65 fn generate_consumer_tag(&self) -> String {
66 format!("ctag{}", Uuid::new_v4())
67 }
68
69 pub async fn reopen(&mut self, channel_id: u16) -> Result<(), AppError> {
70 if channel_id == self.channel.channel_id() {
71 let new_channel = self.connection.lock().await.open_channel(None).await?;
72 if self.publisher_confirms == Confirmations::PublisherConfirms || self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
73 let args = ConfirmSelectArguments::default();
74 let _ = new_channel.confirm_select(args).await;
75 }
76 self.channel.clone().close().await.ok();
77 self.channel = new_channel;
78 if !self.auto_ack {
79 if let Some(pre_fetch_count) = self.pre_fetch_count {
80 let args = BasicQosArguments::new(0, pre_fetch_count, false);
81 let _ = self.channel.basic_qos(args).await;
82 }
83 }
84 if let Err(e) = self.channel
85 .register_callback(MyChannelCallback {
86 channel_tx: self.channel_tx.clone(),
87 })
88 .await
89 {
90 error!("Failed to register channel callback: {}", e);
91 }
92
93 } else if self.aux_channel.is_some() && channel_id == self.aux_channel.as_ref().unwrap().channel_id() {
94 let new_channel = self.connection.lock().await.open_channel(None).await?;
95 if self.publisher_confirms == Confirmations::RPCServerPublisherConfirms {
96 let args = ConfirmSelectArguments::default();
97 let _ = new_channel.confirm_select(args).await;
98 }
99 let _ = self.aux_channel.as_ref().unwrap().clone().close().await;
100 self.aux_channel = Some(new_channel);
101 if !self.auto_ack {
102 if let Some(pre_fetch_count) = self.pre_fetch_count {
103 let args = BasicQosArguments::new(0, pre_fetch_count, false);
104 let _ = self.aux_channel.as_ref().unwrap().basic_qos(args).await;
105 }
106 }
107 if let Err(e) = self.aux_channel.as_ref().unwrap()
108 .register_callback(MyChannelCallback {
109 channel_tx: self.channel_tx.clone(),
110 })
111 .await
112 {
113 error!("Failed to register channel callback: {}", e);
114 }
115
116
117 } else {
118 error!("Received reopen for unknown channel id: {}", channel_id);
119 }
120 Ok(())
121 }
122
123 pub async fn add_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalSubscribeHandler) {
124 let queue_handlers = {
125 let mut handlers = self.subscribes.write().await;
126 handlers
127 .entry(queue_name.to_owned())
128 .or_insert_with(|| Arc::new(ArcSwap::from_pointee(TopicTrie::new())))
129 .clone()
130 };
131 queue_handlers.rcu(|current_map| {
132 let mut new_map = (**current_map).clone();
133 new_map.insert(routing_key, handler.clone());
134 Arc::new(new_map)
135 });
136
137 }
138
139 pub async fn add_rpc_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalRPCHandler) {
140 let queue_handlers = {
141 let mut rpc_handlers = self.rpc_subscribes.write().await;
142 rpc_handlers
143 .entry(queue_name.to_owned())
144 .or_insert_with(|| Arc::new(ArcSwap::from_pointee(HashMap::new())))
145 .clone()
146 };
147
148 queue_handlers.rcu(|current_map| {
149 let mut new_map = (**current_map).clone();
150 new_map.insert(routing_key.to_owned(), handler.clone());
151 Arc::new(new_map)
152 });
153 }
154
155 pub async fn queue_bind(&self, queue_name: &str, exchange_name: &str, routing_key: &str) -> Result<(), AppError> {
156 self.channel
157 .queue_bind(QueueBindArguments::new(
158 queue_name,
159 exchange_name,
160 routing_key,
161 ))
162 .await?;
163 Ok(())
164 }
165
166 pub async fn set_qos(&self, prefetch_count: u16) -> Result<(), AppError> {
167 let args = BasicQosArguments::new(0, prefetch_count, false);
168 self.channel.basic_qos(args).await?;
169 Ok(())
170 }
171
172 pub async fn setup_exchange(&self, exchange_name: &str, exchange_type: &str, durable: bool) -> Result<(), AppError> {
173 let arguments = ExchangeDeclareArguments{
174 exchange: exchange_name.to_string(),
175 exchange_type: exchange_type.to_string(),
176 durable,
177 ..Default::default()
178 };
179 Ok(self.channel.exchange_declare(arguments).await?)
180 }
181
182 pub async fn publish(
183 &self,
184 exchange_name: &str,
185 routing_key: &str,
186 body: impl Into<Vec<u8>>,
187 content_type: &str,
188 content_encoding: ContentEncoding,
189 delivery_mode: DeliveryMode,
190 expiration: Option<u32>,
191 ) -> Result<(), AppError>{
192 let args = BasicPublishArguments{
193 exchange: exchange_name.to_owned(),
194 routing_key: routing_key.to_owned(),
195 mandatory: true,
196 immediate: false
197 };
198 let mut properties = BasicProperties::default();
199 properties.with_content_type(content_type);
200 if content_encoding != ContentEncoding::None {
201 properties.with_content_encoding(content_encoding.as_str());
202 }
203 if let Some(exp) = expiration {
204 properties.with_expiration(&format!("{}", exp));
205 }
206 properties.with_delivery_mode(delivery_mode as u8);
207 Ok(self.channel.basic_publish(properties, body.into(), args).await?)
208 }
209
210 pub async fn queue_declare(&self, queue_name: &str, queue_options: &QueueOptions) -> Result<(), AppError> {
211 let queue_args = QueueDeclareArguments::new(queue_name)
212 .auto_delete(queue_options.auto_delete)
213 .durable(queue_options.durable)
214 .exclusive(queue_options.exclusive)
215 .passive(queue_options.no_create)
216 .arguments(queue_options.clone().into())
217 .finish();
218 self.channel.queue_declare(queue_args).await?;
219 Ok(())
220 }
221
222 pub async fn close(&self) -> Result<(), AppError> {
223 self.channel.clone().close().await?;
224 if let Some(aux_channel) = &self.aux_channel {
225 aux_channel.clone().close().await?;
226 }
227 Ok(())
228 }
229
230 pub async fn subscribe(
231 &self,
232 handler: Handler,
233 routing_key: &str,
234 exchange_name: &str,
235 exchange_type: &str,
236 queue_name: &str,
237 process_timeout: Option<Duration>,
238 queue_options: &QueueOptions
239 ) -> Result<(), AppError>
240 {
241 self.setup_exchange(exchange_name, exchange_type, queue_options.durable)
242 .await?;
243 self.queue_declare(queue_name, queue_options).await?;
244 self.channel
256 .queue_bind(QueueBindArguments::new(
257 &queue_name,
258 exchange_name,
259 routing_key,
260 ))
261 .await?;
262
263 self.add_subscribe(&queue_name, routing_key, InternalSubscribeHandler::new(
264 handler,
265 process_timeout,
266 )).await;
267
268 if !self.consumers.contains_key(queue_name) {
269 let queue_handler = self.subscribes.read().await;
270 let handler = queue_handler.get(queue_name).unwrap();
271 if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
272 let args = BasicQosArguments::new(0, pre_fetch_count, false);
273 let _ = self.channel.basic_qos(args).await;
274 }
275 self.consumers.insert(queue_name.to_string(), true);
276 let mut args = BasicConsumeArguments::new(&queue_name, &self.generate_consumer_tag());
277 args.manual_ack(!self.auto_ack);
278 let sub_handler = BroadSubscribeHandler::new(Arc::clone(handler), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
279 let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
280 self.consumer_tags.write().await.push(consumer_tag);
281 }
282 Ok(())
283 }
284 pub async fn unsubscribe(&self, consumer_tag: &str) -> Result<(), AppError> {
285 let args = BasicCancelArguments::new(consumer_tag);
286 self.channel.basic_cancel(args).await?;
287 Ok(())
288 }
289
290 pub async fn rpc_server(
291 &mut self,
292 handler: RPCHandler,
293 routing_key: &str,
294 exchange_name: &str,
295 exchange_type: &str,
296 queue_name: &str,
297 response_timeout: Option<Duration>,
298 queue_options: &QueueOptions
299 ) -> Result<(), AppError>
300 {
301 if self.aux_channel.is_none() {
302 let ch = self.connection.lock().await.open_channel(None).await?;
303
304 if self.publisher_confirms == Confirmations::RPCServerPublisherConfirms {
305 let args = ConfirmSelectArguments::default();
306 let _ = ch.confirm_select(args).await;
307 }
308 self.aux_channel = Some(ch);
309 }
310 self.add_rpc_subscribe(queue_name, routing_key, InternalRPCHandler::new(
311 handler,
312 response_timeout,
313 )).await;
314
315 self.setup_exchange(exchange_name, exchange_type, queue_options.durable)
316 .await?;
317 self.queue_declare(queue_name, queue_options).await?;
318 self.channel
329 .queue_bind(QueueBindArguments::new(
330 &queue_name,
331 exchange_name,
332 routing_key,
333 ))
334 .await?;
335 if !self.consumers.contains_key(queue_name) {
336 let queue_handler = self.rpc_subscribes.read().await;
337 let handler = queue_handler.get(queue_name).unwrap();
338 let mut args = BasicConsumeArguments::new(queue_name, &self.generate_consumer_tag());
339 args.manual_ack(!self.auto_ack);
340 self.consumers.insert(queue_name.to_string(), true);
341 let sub_handler = BroadRPCHandler::new(
342 Arc::new(self.aux_channel.as_ref().unwrap().clone()),
343 Arc::clone(handler),
344 self.auto_ack,
345 self.in_flight.clone(),
346 self.shutdown_notify.clone(),
347 );
348 drop(queue_handler);
349 if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
350 let args = BasicQosArguments::new(0, pre_fetch_count, false);
351 let _ = self.channel.basic_qos(args).await;
352 }
353 let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
354 self.consumer_tags.write().await.push(consumer_tag);
355 }
356 Ok(())
357 }
358
359 pub async fn start_rpc_consumer(&mut self) -> Result<(), AppError> {
360 if !self.rpc_consumer_started.load(std::sync::atomic::Ordering::SeqCst) {
361 {
362 self.aux_channel = Some(async {
363 let ch = self.connection.lock().await.open_channel(None).await?;
364 if let Err(e) = ch
365 .register_callback(MyChannelCallback {
366 channel_tx: self.channel_tx.clone(),
367 })
368 .await
369 {
370 error!("Failed to register channel callback: {}", e);
371 }
372 if self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
373 let args = ConfirmSelectArguments::default();
374 let _ = ch.confirm_select(args).await;
375 }
376 if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
377 let args = BasicQosArguments::new(0, pre_fetch_count, false);
378 let _ = ch.basic_qos(args).await;
379 }
380 Ok::<Channel, AppError>(ch)
381 }.await?);
382 }
383 if let Some(channel) = &self.aux_channel {
384 let mut queue_declare = QueueDeclareArguments::new(&self.aux_queue_name);
385 let mut field_table = FieldTable::new();
386 field_table.insert("x-expires".try_into().unwrap(), amqprs::FieldValue::l(60000));
387 queue_declare.auto_delete(false);
388 queue_declare.exclusive(false);
389 queue_declare.arguments(field_table);
390 let (_, _, _) = channel.queue_declare(queue_declare)
391 .await?
392 .ok_or_else(|| AppError::new(Some("Queue declare returned None".to_string()), None, AppErrorType::InternalError))?;
393 let rpc_handler = BroadRPCClientHandler::new(Arc::clone(&self.rpc_futures), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
394 let mut args = BasicConsumeArguments::new(&self.aux_queue_name, &self.generate_consumer_tag());
395 args.manual_ack(!self.auto_ack);
396 let consumer_tag = channel.basic_consume(rpc_handler, args).await?;
397 self.consumer_tags.write().await.push(consumer_tag);
398 self.rpc_consumer_started.store(true, std::sync::atomic::Ordering::SeqCst);
399 }
400 }
401 Ok(())
402 }
403
404 pub async fn rpc_client(
405 &mut self,
406 exchange_name: &str,
407 routing_key: &str,
408 body: impl Into<Vec<u8>>,
409 content_type: &str,
410 content_encoding: ContentEncoding,
411 timeout_millis: u32,
412 delivery_mode: DeliveryMode,
413 expiration: Option<u32>,
414 response: oneshot::Sender<Result<Vec<u8>, AppError>>,
415 clean_message: UnboundedSender<ChannelCmd>,
416 message_id: Option<u64>,
417 ) -> Result<(), AppError>
418 {
419 self.start_rpc_consumer().await?;
420 let (tx, rx) = oneshot::channel();
421
422 let correlated_id = Uuid::new_v4().to_string();
423 self.rpc_futures.insert(correlated_id.to_owned(), tx);
424 let mut args = BasicPublishArguments::new(exchange_name, routing_key);
425 args.mandatory(true);
426 let mut properties = BasicProperties::default();
427 properties.with_content_type(content_type);
428 if content_encoding != ContentEncoding::None {
429 properties.with_content_encoding(content_encoding.as_str());
430 }
431 properties.with_correlation_id(&correlated_id);
432 properties.with_reply_to(&self.aux_queue_name);
433 properties.with_delivery_mode(delivery_mode as u8);
434 let cn = self.channel.clone();
435 if let Some(exp) = expiration {
436 properties.with_expiration(&format!("{}", exp));
437 }
438 let body = body.into();
439 tokio::spawn(async move {
440 let _ = cn.basic_publish(properties, body, args).await;
441 let message = match tokio::time::timeout(std::time::Duration::from_millis(timeout_millis as u64), rx).await {
442 Ok(Ok(result)) => Ok(result),
443 Ok(Err(_)) => Err(AppError::new(Some("Receiver was dropped".to_string()), None, AppErrorType::InternalError)),
444 Err(_) => Err(AppError::new(Some("Timeout exceeded".to_string()), None, AppErrorType::TimeoutError)),
445 };
446 if let Err(_) = response.send(message) && let Some(id) = message_id {
447 let _ = clean_message.send(ChannelCmd::PublishNack((id, false)));
448 }
449 });
450 Ok(())
451 }
452 pub async fn dispose(&self) {
453 let cn = self.channel.clone();
454 for tag in self.consumer_tags.read().await.iter() {
455 let args = BasicCancelArguments::new(tag);
456 if let Err(e) = cn.basic_cancel(args).await {
457 error!("Failed to cancel consumer {}: {}", tag, e);
458 }
459 }
460 while self.in_flight.load(Ordering::Acquire) > 0 {
461 self.shutdown_notify.notified().await;
462 }
463 if let Err(e) = self.channel.clone().close().await {
464 error!("Failed to close main channel: {}", e);
465 }
466 if let Some(channel) = &self.aux_channel {
467 if let Err(e) = channel.clone().close().await {
468 error!("Failed to close aux channel: {}", e);
469 }
470 }
471 }
472}