1use crate::{
2 api::{consumers::{BroadRPCClientHandler, BroadRPCHandler, BroadSubscribeHandler, InternalRPCHandler, InternalSubscribeHandler}, utils::{ContentEncoding, DeliveryMode, Handler, Message, PendingCmd, RPCHandler, TopicTrie}},
3 errors::{AppError, AppErrorType},
4};
5use amqprs::{
6 BasicProperties, channel::{
7 BasicCancelArguments, BasicConsumeArguments, BasicPublishArguments, BasicQosArguments, Channel, ConfirmSelectArguments, ExchangeDeclareArguments, QueueBindArguments, QueueDeclareArguments
8 }, connection::Connection
9};
10use arc_swap::ArcSwap;
11use dashmap::DashMap;
12use tracing::error;
13use std::{collections::HashMap, sync::atomic::{AtomicBool, AtomicUsize, Ordering}};
14use std::error::Error as StdError;
15use std::future::Future;
16use std::sync::Arc;
17use tokio::{sync::{Mutex, Notify, OnceCell, RwLock, mpsc::UnboundedSender, oneshot}, time::Duration};
18use uuid::Uuid;
19use crate::api::utils::Confirmations;
20
21
22#[derive(Clone)]
23pub struct AsyncChannel {
24 pub channel: Channel,
25 connection: Arc<Mutex<Connection>>,
26 aux_channel: Arc<OnceCell<Channel>>,
27 aux_queue_name: String,
28 pub rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>,
29 pub rpc_consumer_started: Arc<AtomicBool>,
30 consumers: Arc<DashMap<String, bool>>,
31 subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<TopicTrie<InternalSubscribeHandler>>>>>>,
32 rpc_subscribes: Arc<RwLock<HashMap<String, Arc<ArcSwap<HashMap<String, InternalRPCHandler>>>>>>,
33 publisher_confirms: Confirmations,
35 auto_ack: bool,
36 pre_fetch_count: Option<u16>,
37 consumer_tags: Arc<RwLock<Vec<String>>>,
38 in_flight: Arc<AtomicUsize>,
39 pub shutdown_notify: Arc<Notify>,
40}
41
42impl AsyncChannel {
43 pub fn new(channel: Channel, connection: Arc<Mutex<Connection>>, rpc_futures: Arc<DashMap<String, oneshot::Sender<Vec<u8>>>>, publisher_confirms: Confirmations, auto_ack: bool, pre_fetch_count: Option<u16>) -> Self {
44 Self {
45 channel,
46 connection,
47 aux_channel: Arc::new(OnceCell::new()),
48 aux_queue_name: format!("amqp.{}", Uuid::new_v4()),
49 rpc_futures,
50 rpc_consumer_started: Arc::new(AtomicBool::new(false)),
51 consumers: Arc::new(DashMap::new()),
52 subscribes: Arc::new(RwLock::new(HashMap::new())),
53 rpc_subscribes: Arc::new(RwLock::new(HashMap::new())),
54 publisher_confirms,
56 auto_ack,
57 pre_fetch_count,
58 consumer_tags: Arc::new(RwLock::new(Vec::new())),
59 in_flight: Arc::new(AtomicUsize::new(0)),
60 shutdown_notify: Arc::new(Notify::new()),
61 }
62 }
63
64 fn generate_consumer_tag(&self) -> String {
65 format!("ctag{}", Uuid::new_v4())
66 }
67
68 pub async fn add_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalSubscribeHandler) {
69 let queue_handlers = {
70 let mut handlers = self.subscribes.write().await;
71 handlers
72 .entry(queue_name.to_owned())
73 .or_insert_with(|| Arc::new(ArcSwap::from_pointee(TopicTrie::new())))
74 .clone()
75 };
76 queue_handlers.rcu(|current_map| {
77 let mut new_map = (**current_map).clone();
78 new_map.insert(routing_key, handler.clone());
79 Arc::new(new_map)
80 });
81
82 }
83
84 pub async fn add_rpc_subscribe(&self, queue_name: &str, routing_key: &str, handler: InternalRPCHandler) {
85 let queue_handlers = {
86 let mut rpc_handlers = self.rpc_subscribes.write().await;
87 rpc_handlers
88 .entry(queue_name.to_owned())
89 .or_insert_with(|| Arc::new(ArcSwap::from_pointee(HashMap::new())))
90 .clone()
91 };
92
93 queue_handlers.rcu(|current_map| {
94 let mut new_map = (**current_map).clone();
95 new_map.insert(routing_key.to_owned(), handler.clone());
96 Arc::new(new_map)
97 });
98 }
99
100 pub async fn setup_exchange(&self, exchange_name: &str, exchange_type: &str, durable: bool) -> Result<(), AppError> {
101 let arguments = ExchangeDeclareArguments{
102 exchange: exchange_name.to_string(),
103 exchange_type: exchange_type.to_string(),
104 durable,
105 ..Default::default()
106 };
107 Ok(self.channel.exchange_declare(arguments).await?)
108 }
109
110 pub async fn publish(
111 &self,
112 exchange_name: &str,
113 routing_key: &str,
114 body: impl Into<Vec<u8>>,
115 content_type: &str,
116 content_encoding: ContentEncoding,
117 delivery_mode: DeliveryMode,
118 expiration: Option<u32>,
119 ) -> Result<(), AppError>{
120 let args = BasicPublishArguments{
121 exchange: exchange_name.to_owned(),
122 routing_key: routing_key.to_owned(),
123 mandatory: true,
124 immediate: false
125 };
126 let mut properties = BasicProperties::default();
127 properties.with_content_type(content_type);
128 if content_encoding != ContentEncoding::None {
129 properties.with_content_encoding(content_encoding.as_str());
130 }
131 if let Some(exp) = expiration {
132 properties.with_expiration(&format!("{}", exp));
133 }
134 properties.with_delivery_mode(delivery_mode as u8);
135 Ok(self.channel.basic_publish(properties, body.into(), args).await?)
136 }
137}
138impl AsyncChannel {
139 pub async fn subscribe(
140 &self,
141 handler: Handler,
142 routing_key: &str,
143 exchange_name: &str,
144 exchange_type: &str,
145 queue_name: &str,
146 process_timeout: Option<Duration>,
147 ) -> Result<(), AppError>
148 {
149 self.setup_exchange(exchange_name, exchange_type, true)
150 .await?;
151 let (queue_name, _, _) = self
162 .channel
163 .queue_declare(QueueDeclareArguments::durable_client_named(queue_name))
164 .await?
165 .ok_or_else(|| AppError::new(Some("Queue declare returned None".to_string()), None, AppErrorType::InternalError))?;
166 self.channel
167 .queue_bind(QueueBindArguments::new(
168 &queue_name,
169 exchange_name,
170 routing_key,
171 ))
172 .await?;
173
174 self.add_subscribe(&queue_name, routing_key, InternalSubscribeHandler::new(
175 handler,
176 process_timeout,
177 )).await;
178
179 if !self.consumers.contains_key(&queue_name) {
180 let queue_handler = self.subscribes.read().await;
181 let handler = queue_handler.get(&queue_name).unwrap();
182 if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
183 let args = BasicQosArguments::new(0, pre_fetch_count, false);
184 let _ = self.channel.basic_qos(args).await;
185 }
186 self.consumers.insert(queue_name.to_string(), true);
187 let mut args = BasicConsumeArguments::new(&queue_name, &self.generate_consumer_tag());
188 args.manual_ack(!self.auto_ack);
189 let sub_handler = BroadSubscribeHandler::new(Arc::clone(handler), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
190 let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
191 self.consumer_tags.write().await.push(consumer_tag);
192 }
193 Ok(())
194 }
195}
196impl AsyncChannel{
197 pub async fn rpc_server(
198 &self,
199 handler: RPCHandler,
200 routing_key: &str,
201 exchange_name: &str,
202 exchange_type: &str,
203 queue_name: &str,
204 response_timeout: Option<Duration>,
205 ) -> Result<(), AppError>
206 {
207 self.aux_channel.get_or_try_init(|| async {
208 let ch = self.connection.lock().await.open_channel(None).await?;
209
210 if self.publisher_confirms == Confirmations::RPCServerPublisherConfirms {
211 let args = ConfirmSelectArguments::default();
212 let _ = ch.confirm_select(args).await;
213 }
214 Ok::<Channel, AppError>(ch)
215 }).await?;
216 self.add_rpc_subscribe(queue_name, routing_key, InternalRPCHandler::new(
217 handler,
218 response_timeout,
219 )).await;
220
221 self.setup_exchange(exchange_name, exchange_type, true)
222 .await?;
223 if let Some((queue_name,_,_)) = self.channel.queue_declare(QueueDeclareArguments::durable_client_named(queue_name)).await? {
234 self.channel
235 .queue_bind(QueueBindArguments::new(
236 &queue_name,
237 exchange_name,
238 routing_key,
239 ))
240 .await?;
241 if !self.consumers.contains_key(&queue_name) {
242 let queue_handler = self.rpc_subscribes.read().await;
243 let handler = queue_handler.get(&queue_name).unwrap();
244 let mut args = BasicConsumeArguments::new(&queue_name, &self.generate_consumer_tag());
245 args.manual_ack(!self.auto_ack);
246 self.consumers.insert(queue_name.to_string(), true);
247 let sub_handler = BroadRPCHandler::new(
248 Arc::clone(&self.aux_channel),
249 Arc::clone(handler),
250 self.auto_ack,
251 self.in_flight.clone(),
252 self.shutdown_notify.clone(),
253 );
254 drop(queue_handler);
255 if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
256 let args = BasicQosArguments::new(0, pre_fetch_count, false);
257 let _ = self.channel.basic_qos(args).await;
258 }
259 let consumer_tag = self.channel.basic_consume(sub_handler, args).await?;
260 self.consumer_tags.write().await.push(consumer_tag);
261 }
262 }
263 Ok(())
264 }
265
266 pub async fn start_rpc_consumer(&self) -> Result<(), AppError> {
267 if !self.rpc_consumer_started.load(std::sync::atomic::Ordering::SeqCst) {
268 {
269 self.aux_channel.get_or_try_init(|| async {
270 let ch = self.connection.lock().await.open_channel(None).await?;
271 if self.publisher_confirms == Confirmations::RPCClientPublisherConfirms {
272 let args = ConfirmSelectArguments::default();
273 let _ = ch.confirm_select(args).await;
274 }
275 if !self.auto_ack && let Some(pre_fetch_count) = self.pre_fetch_count {
276 let args = BasicQosArguments::new(0, pre_fetch_count, false);
277 let _ = ch.basic_qos(args).await;
278 }
279 Ok::<Channel, AppError>(ch)
280 }).await?;
281 }
282 if let Some(channel) = self.aux_channel.get() {
283 let mut queue_declare = QueueDeclareArguments::new(&self.aux_queue_name);
284 queue_declare.auto_delete(true);
285 let (_, _, _) = channel.queue_declare(queue_declare)
286 .await?
287 .ok_or_else(|| AppError::new(Some("Queue declare returned None".to_string()), None, AppErrorType::InternalError))?;
288 let rpc_handler = BroadRPCClientHandler::new(Arc::clone(&self.rpc_futures), self.auto_ack, self.in_flight.clone(), self.shutdown_notify.clone());
289 let mut args = BasicConsumeArguments::new(&self.aux_queue_name, &self.generate_consumer_tag());
290 args.manual_ack(!self.auto_ack);
291 let consumer_tag = channel.basic_consume(rpc_handler, args).await?;
292 self.consumer_tags.write().await.push(consumer_tag);
293 self.rpc_consumer_started.store(true, std::sync::atomic::Ordering::SeqCst);
294 }
295 }
296 Ok(())
297 }
298
299 pub async fn rpc_client(
300 &self,
301 exchange_name: &str,
302 routing_key: &str,
303 body: impl Into<Vec<u8>>,
304 content_type: &str,
305 content_encoding: ContentEncoding,
306 timeout_millis: u32,
307 delivery_mode: DeliveryMode,
308 expiration: Option<u32>,
309 response: oneshot::Sender<Result<Vec<u8>, AppError>>,
310 clean_message: UnboundedSender<PendingCmd>,
311 message_id: Option<u64>,
312 ) -> Result<(), AppError>
313 {
314 self.start_rpc_consumer().await?;
315 let (tx, rx) = oneshot::channel();
316
317 let correlated_id = Uuid::new_v4().to_string();
318 self.rpc_futures.insert(correlated_id.to_owned(), tx);
319 let mut args = BasicPublishArguments::new(exchange_name, routing_key);
320 args.mandatory(true);
321 let mut properties = BasicProperties::default();
322 properties.with_content_type(content_type);
323 if content_encoding != ContentEncoding::None {
324 properties.with_content_encoding(content_encoding.as_str());
325 }
326 properties.with_correlation_id(&correlated_id);
327 properties.with_reply_to(&self.aux_queue_name);
328 properties.with_delivery_mode(delivery_mode as u8);
329 let cn = self.channel.clone();
330 if let Some(exp) = expiration {
331 properties.with_expiration(&format!("{}", exp));
332 }
333 let body = body.into();
334 tokio::spawn(async move {
335 let _ = cn.basic_publish(properties, body, args).await;
336 let message = match tokio::time::timeout(std::time::Duration::from_millis(timeout_millis as u64), rx).await {
337 Ok(Ok(result)) => Ok(result),
338 Ok(Err(_)) => Err(AppError::new(Some("Receiver was dropped".to_string()), None, AppErrorType::InternalError)),
339 Err(_) => Err(AppError::new(Some("Timeout exceeded".to_string()), None, AppErrorType::TimeoutError)),
340 };
341 if let Err(_) = response.send(message) && let Some(id) = message_id {
342 let _ = clean_message.send(PendingCmd::Nack((id, false)));
343 }
344 });
345 Ok(())
346 }
347 pub async fn dispose(&self) {
348 let cn = self.channel.clone();
349 for tag in self.consumer_tags.read().await.iter() {
350 let args = BasicCancelArguments::new(tag);
351 if let Err(e) = cn.basic_cancel(args).await {
352 error!("Failed to cancel consumer {}: {}", tag, e);
353 }
354 }
355 while self.in_flight.load(Ordering::Acquire) > 0 {
356 self.shutdown_notify.notified().await;
357 }
358 if let Err(e) = self.channel.clone().close().await {
359 error!("Failed to close main channel: {}", e);
360 }
361 if let Some(channel) = self.aux_channel.get() {
362 if let Err(e) = channel.clone().close().await {
363 error!("Failed to close aux channel: {}", e);
364 }
365 }
366 }
367}