1use std::collections::{HashMap, VecDeque};
2
3use tokio::{
4 sync::{mpsc, oneshot},
5 task::yield_now,
6 time,
7};
8
9use crate::{
10 api::{callbacks::ChannelCallback, channel::ReturnMessage},
11 channel::GetOkMessage,
12 frame::{CancelOk, CloseChannelOk, ContentBody, FlowOk, Frame, MethodHeader},
13 net::IncomingMessage,
14 BasicProperties, Return,
15};
16#[cfg(feature = "traces")]
17use tracing::{debug, error, info, trace};
18
19use super::{Channel, ConsumerMessage, DispatcherManagementCommand};
20
/// Period between two sweeps of the dispatcher for stale consumer resources.
const CONSUMER_PURGE_INTERVAL: time::Duration = time::Duration::from_millis(10_000);

/// Lifetime of a consumer resource that was created without a registered
/// consumer sender; once elapsed, the next sweep discards it together with any
/// messages it buffered.
const CONSUMER_EXPIRY_PERIOD: time::Duration = time::Duration::from_millis(5_000);
32
33struct ConsumerResource {
35 fifo: VecDeque<ConsumerMessage>,
37 tx: Option<mpsc::UnboundedSender<ConsumerMessage>>,
40 expiration: Option<time::Instant>,
42}
43
44impl ConsumerResource {
45 fn new() -> Self {
46 Self {
47 fifo: VecDeque::new(),
48 tx: None,
49 expiration: Some(time::Instant::now() + CONSUMER_EXPIRY_PERIOD),
50 }
51 }
52
53 fn register_tx(
54 &mut self,
55 tx: mpsc::UnboundedSender<ConsumerMessage>,
56 ) -> Option<mpsc::UnboundedSender<ConsumerMessage>> {
57 self.expiration.take();
59 self.tx.replace(tx)
60 }
61
62 fn get_tx(&self) -> Option<&mpsc::UnboundedSender<ConsumerMessage>> {
63 self.tx.as_ref()
64 }
65
66 fn get_expiration(&self) -> Option<&time::Instant> {
67 self.expiration.as_ref()
68 }
69
70 fn push_message(&mut self, message: ConsumerMessage) {
71 self.fifo.push_back(message);
72 }
73
74 fn pop_message(&mut self) -> Option<ConsumerMessage> {
75 self.fifo.pop_front()
76 }
77}
78
/// Which content-bearing method the next content header/body frames belong to.
///
/// Updated by the dispatcher whenever a `Deliver`, `GetOk`, `GetEmpty` or
/// `Return` method frame is received.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// No content-bearing method frame has been received yet.
    Initial,
    /// Content belongs to a delivery destined for a consumer.
    Deliver,
    /// Content belongs to the response of a `basic.get`.
    GetOk,
    /// The last `basic.get` came back empty; no content is expected.
    GetEmpty,
    /// Content belongs to a `Return` frame, handed to the channel callback.
    Return,
}
86
/// Per-channel task state that routes incoming frames to their destinations:
/// consumers, pending request responders, the `basic.get` responder, or the
/// registered channel callback.
pub(crate) struct ChannelDispatcher {
    // The channel this dispatcher serves.
    channel: Channel,
    // Frames received for this channel.
    dispatcher_rx: mpsc::UnboundedReceiver<IncomingMessage>,
    // Registration commands (consumers, responders, callback) from the channel API.
    dispatcher_mgmt_rx: mpsc::UnboundedReceiver<DispatcherManagementCommand>,
    // Per-consumer-tag sender and buffered deliveries.
    consumer_resources: HashMap<String, ConsumerResource>,
    // Receives the frames of an in-flight `basic.get` response, if one is registered.
    get_content_responder: Option<mpsc::UnboundedSender<IncomingMessage>>,
    // One-shot responders keyed by the method header of the reply frame they await.
    responders: HashMap<&'static MethodHeader, oneshot::Sender<IncomingMessage>>,
    // User callback for server-initiated events (close, flow, cancel, ack/nack, return).
    callback: Option<Box<dyn ChannelCallback + Send + 'static>>,
    // Which content-bearing method the next content frames belong to.
    state: State,
}
102impl ChannelDispatcher {
104 pub(crate) fn new(
105 channel: Channel,
106 dispatcher_rx: mpsc::UnboundedReceiver<IncomingMessage>,
107 dispatcher_mgmt_rx: mpsc::UnboundedReceiver<DispatcherManagementCommand>,
108 ) -> Self {
109 Self {
110 channel,
111 dispatcher_rx,
112 dispatcher_mgmt_rx,
113 consumer_resources: HashMap::new(),
114 get_content_responder: None,
115 responders: HashMap::new(),
116 callback: None,
117 state: State::Initial,
118 }
119 }
120
121 fn get_or_new_consumer_resource(&mut self, consumer_tag: &String) -> &mut ConsumerResource {
123 if !self.consumer_resources.contains_key(consumer_tag) {
124 let resource = ConsumerResource::new();
125 self.consumer_resources
126 .insert(consumer_tag.clone(), resource);
127 }
128 self.consumer_resources.get_mut(consumer_tag).unwrap()
129 }
130
131 fn purge_consumer_resource(&mut self) {
133 let purge_keys: Vec<String> = self
135 .consumer_resources
136 .iter()
137 .filter_map(|(k, v)| {
138 if let Some(expiration) = v.get_expiration() {
139 if expiration < &time::Instant::now() {
140 return Some(k.clone());
141 }
142 }
143 None
144 })
145 .collect();
146
147 for key in purge_keys {
149 self.consumer_resources.remove(&key);
150 #[cfg(feature = "traces")]
151 debug!(
152 "purge stale consumer resource {} on channel {}",
153 key, self.channel
154 );
155 }
156 }
157 fn remove_consumer_resource(&mut self, consumer_tag: &String) -> Option<ConsumerResource> {
161 self.consumer_resources.remove(consumer_tag)
162 }
163
164 async fn forward_deliver(&mut self, consumer_message: ConsumerMessage) {
165 let consumer_tag = consumer_message
166 .deliver
167 .as_ref()
168 .unwrap()
169 .consumer_tag()
170 .clone();
171 let consumer = self.get_or_new_consumer_resource(&consumer_tag);
172 match consumer.get_tx() {
173 Some(consumer_tx) => {
174 if (consumer_tx.send(consumer_message)).is_err() {
175 #[cfg(feature = "traces")]
176 error!(
177 "failed to dispatch message to consumer {} on channel {}",
178 consumer_tag, self.channel
179 );
180 }
181 }
182 None => {
183 #[cfg(feature = "traces")]
184 debug!("can't find consumer {}, message is buffered", consumer_tag);
185 consumer.push_message(consumer_message);
186 yield_now().await;
189 }
190 };
191 }
192
193 async fn handle_return(
194 &mut self,
195 ret: Return,
196 basic_properties: BasicProperties,
197 content: Vec<u8>,
198 ) {
199 if let Some(ref mut cb) = self.callback {
200 cb.publish_return(&self.channel, ret, basic_properties, content)
201 .await;
202 } else {
203 #[cfg(feature = "traces")]
204 error!("callback not registered on channel {}", self.channel);
205 }
206 }
207 pub(in crate::api) async fn spawn(mut self) {
209 tokio::spawn(async move {
210 let mut message_buffer = ConsumerMessage {
212 deliver: None,
213 basic_properties: None,
214 content: None,
215 remaining: 0,
216 };
217 let mut return_buffer = ReturnMessage {
219 ret: None,
220 basic_properties: None,
221 content: None,
222 remaining: 0,
223 };
224 let mut getok_content_buffer = GetOkMessage {
226 content: None,
227 remaining: 0,
228 };
229
230 #[cfg(feature = "traces")]
231 trace!("starts up dispatcher task of channel {}", self.channel);
232
233 let mut purge_timer = time::interval(CONSUMER_PURGE_INTERVAL);
234 purge_timer.tick().await;
235 loop {
237 tokio::select! {
238 biased;
239
240 command = self.dispatcher_mgmt_rx.recv() => {
243 let cmd = match command {
245 None => {
246 unreachable!("dispatcher command channel closed, {}", self.channel);
247 },
248 Some(v) => v,
249 };
250 match cmd {
252 DispatcherManagementCommand::RegisterContentConsumer(cmd) => {
253 #[cfg(feature="traces")]
254 info!("register consumer {}", cmd.consumer_tag);
255 let consumer = self.get_or_new_consumer_resource(&cmd.consumer_tag);
256 consumer.register_tx(cmd.consumer_tx);
257 while !consumer.fifo.is_empty() {
259 #[cfg(feature="traces")]
260 trace!("consumer {} total buffered messages: {}", cmd.consumer_tag, consumer.fifo.len());
261 let msg = consumer.pop_message().unwrap();
262 if let Err(_err) = consumer.get_tx().unwrap().send(msg) {
263 #[cfg(feature="traces")]
264 error!("failed to forward message to consumer {}", cmd.consumer_tag);
265 }
266 }
267 },
268 DispatcherManagementCommand::DeregisterContentConsumer(cmd) => {
269 if let Some(consumer) = self.remove_consumer_resource(&cmd.consumer_tag) {
270 #[cfg(feature="traces")]
271 info!("deregister consumer {}, total buffered messages: {}",
272 cmd.consumer_tag, consumer.fifo.len()
273 );
274 }
275 },
276 DispatcherManagementCommand::RegisterGetContentResponder(cmd) => {
277 self.get_content_responder.replace(cmd.tx);
278 }
279 DispatcherManagementCommand::RegisterOneshotResponder(cmd) => {
280 self.responders.insert(cmd.method_header, cmd.responder);
281 cmd.acker.send(()).unwrap();
282 }
283 DispatcherManagementCommand::RegisterChannelCallback(cmd) => {
284 self.callback.replace(cmd.callback);
285 #[cfg(feature="traces")]
286 debug!("callback registered on channel {}", self.channel);
287 }
288 }
289 }
290 message = self.dispatcher_rx.recv() => {
293 let frame = match message {
295 None => {
296 #[cfg(feature="traces")]
298 debug!("dispatcher mpsc channel closed, channel {}", self.channel);
299 break;
300 },
301 Some(v) => v,
302 };
303 match frame {
305 Frame::CloseChannelOk(method_header, close_channel_ok) => {
309 self.channel.set_is_open(false);
310
311 match self.responders.remove(method_header) {
312 Some(responder) => responder.send(close_channel_ok.into_frame()).unwrap(),
313 None => unreachable!("responder must be registered for {} on channel {}",
314 close_channel_ok.into_frame(), self.channel),
315 }
316 break;
318 }
319 Frame::CloseChannel(_, close_channel) => {
321 if let Some(ref mut cb) = self.callback {
323 if let Err(err) = cb.close(&self.channel, close_channel).await {
324 #[cfg(feature="traces")]
325 error!("close callback returns error on channel {}, cause: {}", self.channel, err);
326 break;
328 };
329 } else {
330 #[cfg(feature="traces")]
331 error!("callback not registered on channel {}", self.channel);
332 }
333 self.channel.set_is_open(false);
334
335 self.channel.shared.outgoing_tx
337 .send((self.channel.channel_id(), CloseChannelOk.into_frame()))
338 .await.unwrap();
339 break;
341 }
342 Frame::GetEmpty(_, get_empty) => {
345 self.state = State::GetEmpty;
346
347 self.get_content_responder.take()
348 .expect("get responder must be registered")
349 .send(get_empty.into_frame()).unwrap();
350 }
351 Frame::GetOk(_, get_ok) => {
352 self.state = State::GetOk;
353
354 self.get_content_responder.as_ref()
355 .expect("get responder must be registered")
356 .send(get_ok.into_frame()).unwrap();
357 }
358 Frame::Return(_, ret) => {
359 self.state = State::Return;
360 return_buffer.ret = Some(ret);
361 }
362 Frame::Deliver(_, deliver) => {
363 self.state = State::Deliver;
364 message_buffer.deliver = Some(deliver);
365 }
366 Frame::ContentHeader(header) => {
367 match self.state {
368 State::Deliver => {
369 message_buffer.remaining = header.common.body_size.try_into().unwrap();
370 if message_buffer.remaining == 0 {
372 let consumer_message = ConsumerMessage {
373 deliver: message_buffer.deliver.take(),
374 basic_properties: Some(header.basic_properties),
375 content: Some(Vec::new()),
376 remaining: 0,
377 };
378 self.forward_deliver(consumer_message).await;
379 } else {
380 message_buffer.basic_properties = Some(header.basic_properties);
381 message_buffer.content = Some(Vec::new());
382 }
383 },
384 State::GetOk => {
385 getok_content_buffer.remaining = header.common.body_size.try_into().unwrap();
386
387 let responder = self.get_content_responder.as_ref().expect("get responder must be registered");
388 responder.send(header.into_frame()).unwrap();
389 if getok_content_buffer.remaining == 0 {
391 responder.send(ContentBody::new(Vec::new()).into_frame()).unwrap();
392 } else {
393 getok_content_buffer.content = Some(Vec::new());
394 }
395 },
396 State::Return => {
397 return_buffer.remaining = header.common.body_size.try_into().unwrap();
398
399 if return_buffer.remaining == 0 {
400 self.handle_return(return_buffer.ret.take().unwrap(), header.basic_properties, Vec::new()).await;
402 } else {
403 return_buffer.basic_properties = Some(header.basic_properties);
404 return_buffer.content = Some(Vec::new());
405 }
406 },
407 _ => unreachable!("invalid dispatcher state"),
408 }
409 }
410 Frame::ContentBody(body) => {
411 match self.state {
412 State::Deliver => {
413 let mut content_buffer = message_buffer.content.take().unwrap();
414 content_buffer.extend_from_slice(&body.inner);
415 message_buffer.content.replace(content_buffer);
416 message_buffer.remaining = message_buffer.remaining.checked_sub(body.inner.len()).expect("should never overflow");
418
419 if message_buffer.remaining == 0 {
420 let consumer_message = ConsumerMessage {
421 deliver: message_buffer.deliver.take(),
422 basic_properties: message_buffer.basic_properties.take(),
423 content: message_buffer.content.take(),
424 remaining: message_buffer.remaining,
425 };
426 self.forward_deliver(consumer_message).await;
427 }
428 }
429 State::GetOk => {
430 let mut content_buffer = getok_content_buffer.content.take().unwrap();
431 content_buffer.extend_from_slice(&body.inner);
432 getok_content_buffer.content.replace(content_buffer);
433 getok_content_buffer.remaining = getok_content_buffer.remaining.checked_sub(body.inner.len()).expect("should never overflow");
434 if getok_content_buffer.remaining == 0 {
435 let content = getok_content_buffer.content.take().unwrap();
436 self.get_content_responder.take()
437 .expect("get responder must be registered")
438 .send(ContentBody::new(content).into_frame()).unwrap();
439 }
440 },
441 State::Return => {
442 let mut content_buffer = return_buffer.content.take().unwrap();
443 content_buffer.extend_from_slice(&body.inner);
444 return_buffer.content.replace(content_buffer);
445 return_buffer.remaining = return_buffer.remaining.checked_sub(body.inner.len()).expect("should never overflow");
446
447 if return_buffer.remaining == 0 {
448 self.handle_return(
449 return_buffer.ret.take().unwrap(),
450 return_buffer.basic_properties.take().unwrap(),
451 return_buffer.content.take().unwrap()).await;
452 }
453 },
454 State::Initial | State::GetEmpty => unreachable!("invalid dispatcher state on channel {}", self.channel),
455 }
456 }
457 Frame::FlowOk(method_header, _)
460 | Frame::DeclareOk(method_header, _)
462 | Frame::DeleteOk(method_header, _)
463 | Frame::BindOk(method_header, _)
464 | Frame::UnbindOk(method_header, _)
465 | Frame::DeclareQueueOk(method_header, _)
466 | Frame::BindQueueOk(method_header, _)
467 | Frame::PurgeQueueOk(method_header, _)
468 | Frame::DeleteQueueOk(method_header, _)
469 | Frame::UnbindQueueOk(method_header, _)
470 | Frame::QosOk(method_header, _)
471 | Frame::ConsumeOk(method_header, _)
472 | Frame::CancelOk(method_header, _)
473 | Frame::RecoverOk(method_header, _)
474 | Frame::SelectOk(method_header, _)
475 | Frame::TxSelectOk(method_header, _)
476 | Frame::TxCommitOk(method_header, _)
477 | Frame::TxRollbackOk(method_header, _) => {
478 match self.responders.remove(method_header)
480 {
481 Some(responder) => {
482 if let Err(response) = responder.send(frame) {
483 #[cfg(feature="traces")]
484 error!(
485 "failed to dispatch {} to channel {}",
486 response, self.channel
487 );
488 }
489 }
490 None => unreachable!(
491 "responder must be registered for {} on channel {}",
492 frame, self.channel
493 ),
494 }
495 }
496 Frame::Flow(_, flow) => {
499 if let Some(ref mut cb) = self.callback {
501 match cb.flow(&self.channel, flow.active).await {
502 Err(err) => {
503 #[cfg(feature="traces")]
504 error!("flow callback error on channel {}, cause: '{}'.", self.channel, err);
505 }
506 Ok(active) => {
507 self.channel.shared.outgoing_tx
509 .send((self.channel.channel_id(), FlowOk::new(active).into_frame()))
510 .await.unwrap();
511 }
512 };
513 } else {
514 #[cfg(feature="traces")]
515 error!("callback not registered on channel {}", self.channel);
516 }
517 }
518 Frame::Cancel(_, cancel) => {
519 if let Some(ref mut cb) = self.callback {
521 let consumer_tag = cancel.consumer_tag().clone();
522 let no_wait = cancel.no_wait();
523 match cb.cancel(&self.channel, cancel).await {
524 Err(err) => {
525 #[cfg(feature="traces")]
526 error!("cancel callback error on channel {}, cause: '{}'.", self.channel, err);
527 }
528 Ok(_) => {
529 self.remove_consumer_resource(&consumer_tag);
530
531 if !no_wait {
533 self.channel.shared.outgoing_tx
534 .send((self.channel.channel_id(), CancelOk::new(consumer_tag.try_into().unwrap()).into_frame()))
535 .await.unwrap();
536 }
537 }
538 };
539 } else {
540 #[cfg(feature="traces")]
541 error!("callback not registered on channel {}", self.channel);
542 }
543 }
544 Frame::Ack(_, ack) => {
546 if let Some(ref mut cb) = self.callback {
547 cb.publish_ack(&self.channel, ack).await;
548 } else {
549 #[cfg(feature="traces")]
550 error!("callback not registered on channel {}", self.channel);
551 }
552 }
553 Frame::Nack(_, nack) => {
554 if let Some(ref mut cb) = self.callback {
555 cb.publish_nack(&self.channel, nack).await;
556 } else {
557 #[cfg(feature="traces")]
558 error!("callback not registered on channel {}", self.channel);
559 } }
560 _ => unreachable!("dispatcher of channel {} receive unexpected frame {}", self.channel, frame),
561 }
562 }
563 _ = purge_timer.tick() => {
565 self.purge_consumer_resource();
566 }
567 else => {
568 break;
569 }
570 }
571 }
572 self.channel.set_is_open(false);
573
574 #[cfg(feature = "traces")]
575 info!("exit dispatcher of channel {}", self.channel);
576 });
577 }
578}
579
#[cfg(test)]
mod tests {
    use tokio::time;

    use crate::{
        channel::{
            BasicCancelArguments, BasicConsumeArguments, BasicPublishArguments, QueueBindArguments,
            QueueDeclareArguments,
        },
        connection::{Connection, OpenConnectionArguments},
        consumer::DefaultConsumer,
        test_utils::setup_logging,
        BasicProperties,
    };

    use super::{CONSUMER_EXPIRY_PERIOD, CONSUMER_PURGE_INTERVAL};

    /// Drives the dispatcher's stale-consumer purge against a live broker
    /// (requires one at `localhost:5672`).
    ///
    /// Messages are published to a bound queue before any consumer exists. A
    /// consumer is then started and cancelled straight away, both with
    /// `no_wait`, and the test sleeps long enough for a purge tick past the
    /// expiry period. Deliveries still in flight around the cancel presumably
    /// land in an unregistered consumer resource that the purge should drop.
    /// Nothing is asserted; the purge shows up in the dispatcher's debug log
    /// when the `traces` feature is enabled.
    #[tokio::test]
    async fn test_purge_consumer_resource() {
        const EXCHANGE: &str = "amq.topic";
        const ROUTING_KEY: &str = "test.purge.consumer";
        const MESSAGE_COUNT: usize = 100;

        setup_logging();

        let connection = Connection::open(&OpenConnectionArguments::new(
            "localhost", 5672, "user", "bitnami",
        ))
        .await
        .unwrap();

        let consumer_channel = connection.open_channel(None).await.unwrap();
        let (queue_name, ..) = consumer_channel
            .queue_declare(QueueDeclareArguments::default())
            .await
            .unwrap()
            .unwrap();
        consumer_channel
            .queue_bind(QueueBindArguments::new(&queue_name, EXCHANGE, ROUTING_KEY))
            .await
            .unwrap();

        let pub_channel = connection.open_channel(None).await.unwrap();
        for _ in 0..MESSAGE_COUNT {
            let payload = b"stale message".to_vec();
            pub_channel
                .basic_publish(
                    BasicProperties::default(),
                    payload,
                    BasicPublishArguments::new(EXCHANGE, ROUTING_KEY),
                )
                .await
                .unwrap();
        }
        // Presumably gives the published messages time to reach the queue.
        time::sleep(time::Duration::from_secs(1)).await;

        let consume_args = BasicConsumeArguments::new(&queue_name, "purge-tester")
            .no_wait(true)
            .finish();
        let consumer_tag = consumer_channel
            .basic_consume(DefaultConsumer::new(false), consume_args)
            .await
            .unwrap();

        let cancel_args = BasicCancelArguments::new(&consumer_tag).no_wait(true).finish();
        consumer_channel.basic_cancel(cancel_args).await.unwrap();

        // Long enough for a stale resource to expire and for a purge tick to run.
        time::sleep(CONSUMER_PURGE_INTERVAL + CONSUMER_EXPIRY_PERIOD).await;
    }
}