amqprs/api/channel/
dispatcher.rs

1use std::collections::{HashMap, VecDeque};
2
3use tokio::{
4    sync::{mpsc, oneshot},
5    task::yield_now,
6    time,
7};
8
9use crate::{
10    api::{callbacks::ChannelCallback, channel::ReturnMessage},
11    channel::GetOkMessage,
12    frame::{CancelOk, CloseChannelOk, ContentBody, FlowOk, Frame, MethodHeader},
13    net::IncomingMessage,
14    BasicProperties, Return,
15};
16#[cfg(feature = "traces")]
17use tracing::{debug, error, info, trace};
18
19use super::{Channel, ConsumerMessage, DispatcherManagementCommand};
20
/// How often the dispatcher sweeps its consumer resources for expired entries.
///
/// Assumption: the best value depends on how many consumers a channel has; for most
/// workloads, scanning for expired consumers every 10 seconds has no noticeable
/// performance impact.
const CONSUMER_PURGE_INTERVAL: time::Duration = time::Duration::from_millis(10_000);
26
/// How long a consumer resource may exist without a registered consumer before it
/// becomes eligible for purging.
///
/// Assumptions:
/// - a consumer is registered right after the `basic.consume`/`basic.consume-ok`
///   handshake completes, which should never take longer than 5 seconds;
/// - once a consumer is cancelled, all in-flight messages for it arrive within 5 seconds.
const CONSUMER_EXPIRY_PERIOD: time::Duration = time::Duration::from_millis(5_000);
32
/// Per-consumer state kept by the channel dispatcher (one per consumer tag).
///
/// Deliveries that arrive before the consumer task has registered its sender are
/// buffered in `fifo` and flushed, in order, once the sender is registered.
struct ConsumerResource {
    /// FIFO buffer of complete deliveries (`deliver + content`) awaiting a registered consumer.
    fifo: VecDeque<ConsumerMessage>,
    /// Sender used to forward a delivery to the consumer task.
    /// The dispatcher task holds the tx half, and the consumer task holds the rx half.
    /// `None` until the consumer registers.
    tx: Option<mpsc::UnboundedSender<ConsumerMessage>>,
    /// Deadline after which this resource may be purged; set on creation and cleared
    /// when a consumer registers its sender.
    expiration: Option<time::Instant>,
}
43
44impl ConsumerResource {
45    fn new() -> Self {
46        Self {
47            fifo: VecDeque::new(),
48            tx: None,
49            expiration: Some(time::Instant::now() + CONSUMER_EXPIRY_PERIOD),
50        }
51    }
52
53    fn register_tx(
54        &mut self,
55        tx: mpsc::UnboundedSender<ConsumerMessage>,
56    ) -> Option<mpsc::UnboundedSender<ConsumerMessage>> {
57        // once consumer's tx half is registered, clear the expiry timer
58        self.expiration.take();
59        self.tx.replace(tx)
60    }
61
62    fn get_tx(&self) -> Option<&mpsc::UnboundedSender<ConsumerMessage>> {
63        self.tx.as_ref()
64    }
65
66    fn get_expiration(&self) -> Option<&time::Instant> {
67        self.expiration.as_ref()
68    }
69
70    fn push_message(&mut self, message: ConsumerMessage) {
71        self.fifo.push_back(message);
72    }
73
74    fn pop_message(&mut self) -> Option<ConsumerMessage> {
75        self.fifo.pop_front()
76    }
77}
78
/// Which content-bearing method the dispatcher is currently assembling.
///
/// Updated when a `basic.deliver`, `basic.return`, `basic.get-ok` or `basic.get-empty`
/// frame arrives, and consulted when the following content header/body frames arrive.
enum State {
    /// No content-bearing method has been received yet.
    Initial,
    /// Assembling a `basic.deliver` for a consumer.
    Deliver,
    /// Assembling a `basic.return` for the channel callback.
    Return,
    /// Assembling a `basic.get-ok` for the `basic.get` requester.
    GetOk,
    /// The last `basic.get` was answered with `get-empty`; no content follows.
    GetEmpty,
}
86
/// Dispatcher for a channel.
///
/// Each channel spawns a dispatcher.
/// It handles channel-level callbacks, incoming messages and registration commands,
/// and it dispatches deliveries to consumers.
pub(crate) struct ChannelDispatcher {
    /// The channel this dispatcher serves; passed to callbacks and used to reply to the server.
    channel: Channel,
    /// Frames for this channel; the dispatcher exits once this receiver is closed.
    dispatcher_rx: mpsc::UnboundedReceiver<IncomingMessage>,
    /// Registration commands (consumers, responders, callback) from the channel API.
    dispatcher_mgmt_rx: mpsc::UnboundedReceiver<DispatcherManagementCommand>,
    /// Per-consumer delivery state, keyed by consumer tag.
    consumer_resources: HashMap<String, ConsumerResource>,
    /// Where the frames of a `basic.get` response (get-ok/get-empty, header, body) are sent.
    get_content_responder: Option<mpsc::UnboundedSender<IncomingMessage>>,
    /// One-shot responders for synchronous replies, keyed by the expected method header.
    responders: HashMap<&'static MethodHeader, oneshot::Sender<IncomingMessage>>,
    /// User callback for server-initiated requests (close, flow, cancel, ack, nack, return).
    callback: Option<Box<dyn ChannelCallback + Send + 'static>>,
    /// Which content-bearing method is currently being assembled.
    state: State,
}
102/////////////////////////////////////////////////////////////////////////////
103impl ChannelDispatcher {
104    pub(crate) fn new(
105        channel: Channel,
106        dispatcher_rx: mpsc::UnboundedReceiver<IncomingMessage>,
107        dispatcher_mgmt_rx: mpsc::UnboundedReceiver<DispatcherManagementCommand>,
108    ) -> Self {
109        Self {
110            channel,
111            dispatcher_rx,
112            dispatcher_mgmt_rx,
113            consumer_resources: HashMap::new(),
114            get_content_responder: None,
115            responders: HashMap::new(),
116            callback: None,
117            state: State::Initial,
118        }
119    }
120
121    /// Return the consumer resource if it always exists, otherwise create new one.
122    fn get_or_new_consumer_resource(&mut self, consumer_tag: &String) -> &mut ConsumerResource {
123        if !self.consumer_resources.contains_key(consumer_tag) {
124            let resource = ConsumerResource::new();
125            self.consumer_resources
126                .insert(consumer_tag.clone(), resource);
127        }
128        self.consumer_resources.get_mut(consumer_tag).unwrap()
129    }
130
131    /// purge expired consumer resource
132    fn purge_consumer_resource(&mut self) {
133        // find all resources that are expired
134        let purge_keys: Vec<String> = self
135            .consumer_resources
136            .iter()
137            .filter_map(|(k, v)| {
138                if let Some(expiration) = v.get_expiration() {
139                    if expiration < &time::Instant::now() {
140                        return Some(k.clone());
141                    }
142                }
143                None
144            })
145            .collect();
146
147        // purge expired resources
148        for key in purge_keys {
149            self.consumer_resources.remove(&key);
150            #[cfg(feature = "traces")]
151            debug!(
152                "purge stale consumer resource {} on channel {}",
153                key, self.channel
154            );
155        }
156    }
157    /// Remove the consumer resource.
158    ///
159    /// Becuase the tx channel will drop, the consumer task will also exit.
160    fn remove_consumer_resource(&mut self, consumer_tag: &String) -> Option<ConsumerResource> {
161        self.consumer_resources.remove(consumer_tag)
162    }
163
164    async fn forward_deliver(&mut self, consumer_message: ConsumerMessage) {
165        let consumer_tag = consumer_message
166            .deliver
167            .as_ref()
168            .unwrap()
169            .consumer_tag()
170            .clone();
171        let consumer = self.get_or_new_consumer_resource(&consumer_tag);
172        match consumer.get_tx() {
173            Some(consumer_tx) => {
174                if (consumer_tx.send(consumer_message)).is_err() {
175                    #[cfg(feature = "traces")]
176                    error!(
177                        "failed to dispatch message to consumer {} on channel {}",
178                        consumer_tag, self.channel
179                    );
180                }
181            }
182            None => {
183                #[cfg(feature = "traces")]
184                debug!("can't find consumer {}, message is buffered", consumer_tag);
185                consumer.push_message(consumer_message);
186                // try to yield for expected consumer registration command,
187                // it might reduceas buffering
188                yield_now().await;
189            }
190        };
191    }
192
193    async fn handle_return(
194        &mut self,
195        ret: Return,
196        basic_properties: BasicProperties,
197        content: Vec<u8>,
198    ) {
199        if let Some(ref mut cb) = self.callback {
200            cb.publish_return(&self.channel, ret, basic_properties, content)
201                .await;
202        } else {
203            #[cfg(feature = "traces")]
204            error!("callback not registered on channel {}", self.channel);
205        }
206    }
207    /// Spawn dispatcher task.
208    pub(in crate::api) async fn spawn(mut self) {
209        tokio::spawn(async move {
210            // aggregation buffer for `deliver + content` messages to a consumer
211            let mut message_buffer = ConsumerMessage {
212                deliver: None,
213                basic_properties: None,
214                content: None,
215                remaining: 0,
216            };
217            // buffer for `return + content` messages due to publish failure.
218            let mut return_buffer = ReturnMessage {
219                ret: None,
220                basic_properties: None,
221                content: None,
222                remaining: 0,
223            };
224            // buffer for `getok + content` messages
225            let mut getok_content_buffer = GetOkMessage {
226                content: None,
227                remaining: 0,
228            };
229
230            #[cfg(feature = "traces")]
231            trace!("starts up dispatcher task of channel {}", self.channel);
232
233            let mut purge_timer = time::interval(CONSUMER_PURGE_INTERVAL);
234            purge_timer.tick().await;
235            // main loop of dispatcher
236            loop {
237                tokio::select! {
238                    biased;
239
240                    // the dispatcher also holds a `Channel` instance, so this
241                    // should never return `None`
242                    command = self.dispatcher_mgmt_rx.recv() => {
243                        // handle command channel error
244                        let cmd = match command {
245                            None => {
246                                unreachable!("dispatcher command channel closed, {}", self.channel);
247                            },
248                            Some(v) => v,
249                        };
250                        // handle command
251                        match cmd {
252                            DispatcherManagementCommand::RegisterContentConsumer(cmd) => {
253                                #[cfg(feature="traces")]
254                                info!("register consumer {}", cmd.consumer_tag);
255                                let consumer = self.get_or_new_consumer_resource(&cmd.consumer_tag);
256                                consumer.register_tx(cmd.consumer_tx);
257                                // forward buffered messages
258                                while !consumer.fifo.is_empty() {
259                                    #[cfg(feature="traces")]
260                                    trace!("consumer {} total buffered messages: {}", cmd.consumer_tag, consumer.fifo.len());
261                                    let msg = consumer.pop_message().unwrap();
262                                    if let Err(_err) = consumer.get_tx().unwrap().send(msg) {
263                                        #[cfg(feature="traces")]
264                                        error!("failed to forward message to consumer {}", cmd.consumer_tag);
265                                    }
266                                }
267                            },
268                            DispatcherManagementCommand::DeregisterContentConsumer(cmd) => {
269                                if let Some(consumer) = self.remove_consumer_resource(&cmd.consumer_tag) {
270                                    #[cfg(feature="traces")]
271                                    info!("deregister consumer {}, total buffered messages: {}",
272                                        cmd.consumer_tag, consumer.fifo.len()
273                                    );
274                                }
275                            },
276                            DispatcherManagementCommand::RegisterGetContentResponder(cmd) => {
277                                self.get_content_responder.replace(cmd.tx);
278                            }
279                            DispatcherManagementCommand::RegisterOneshotResponder(cmd) => {
280                                self.responders.insert(cmd.method_header, cmd.responder);
281                                cmd.acker.send(()).unwrap();
282                            }
283                            DispatcherManagementCommand::RegisterChannelCallback(cmd) => {
284                                self.callback.replace(cmd.callback);
285                                #[cfg(feature="traces")]
286                                debug!("callback registered on channel {}", self.channel);
287                            }
288                        }
289                    }
290                    // only one tx half held by connection handler, once the tx half dorp
291                    // it will return `None`, so exit the dispatcher
292                    message = self.dispatcher_rx.recv() => {
293                        // handle message channel error
294                        let frame = match message {
295                            None => {
296                                // exit
297                                #[cfg(feature="traces")]
298                                debug!("dispatcher mpsc channel closed, channel {}", self.channel);
299                                break;
300                            },
301                            Some(v) => v,
302                        };
303                        // handle frames
304                        match frame {
305                            ////////////////////////////////////////////////
306                            // frames for closing channel
307                            // channel.close-ok response from server
308                            Frame::CloseChannelOk(method_header, close_channel_ok) => {
309                                self.channel.set_is_open(false);
310
311                                match self.responders.remove(method_header) {
312                                    Some(responder) => responder.send(close_channel_ok.into_frame()).unwrap(),
313                                    None => unreachable!("responder must be registered for {} on channel {}",
314                                    close_channel_ok.into_frame(), self.channel),
315                                }
316                                // exit
317                                break;
318                            }
319                            // channel.close request from server
320                            Frame::CloseChannel(_, close_channel) => {
321                                // callback
322                                if let Some(ref mut cb) = self.callback {
323                                    if let Err(err) = cb.close(&self.channel, close_channel).await {
324                                      #[cfg(feature="traces")]
325                                      error!("close callback returns error on channel {}, cause: {}", self.channel, err);
326                                      // exit immediately, no response to server
327                                      break;
328                                    };
329                                } else {
330                                    #[cfg(feature="traces")]
331                                    error!("callback not registered on channel {}", self.channel);
332                                }
333                                self.channel.set_is_open(false);
334
335                                // implictly respond OK to server
336                                self.channel.shared.outgoing_tx
337                                .send((self.channel.channel_id(), CloseChannelOk.into_frame()))
338                                .await.unwrap();
339                                // exit
340                                break;
341                            }
342                            ////////////////////////////////////////////////
343                            // the method frames followed by content frames
344                            Frame::GetEmpty(_, get_empty) => {
345                                self.state = State::GetEmpty;
346
347                                self.get_content_responder.take()
348                                .expect("get responder must be registered")
349                                .send(get_empty.into_frame()).unwrap();
350                            }
351                            Frame::GetOk(_, get_ok) => {
352                                self.state = State::GetOk;
353
354                                self.get_content_responder.as_ref()
355                                .expect("get responder must be registered")
356                                .send(get_ok.into_frame()).unwrap();
357                            }
358                            Frame::Return(_, ret) => {
359                                self.state = State::Return;
360                                return_buffer.ret = Some(ret);
361                            }
362                            Frame::Deliver(_, deliver) => {
363                                self.state = State::Deliver;
364                                message_buffer.deliver = Some(deliver);
365                            }
366                            Frame::ContentHeader(header) => {
367                                match self.state {
368                                    State::Deliver => {
369                                        message_buffer.remaining = header.common.body_size.try_into().unwrap();
370                                        // do not wait for content body frame if content body size is zero
371                                        if message_buffer.remaining == 0 {
372                                            let consumer_message  = ConsumerMessage {
373                                                deliver: message_buffer.deliver.take(),
374                                                basic_properties: Some(header.basic_properties),
375                                                content: Some(Vec::new()),
376                                                remaining: 0,
377                                            };
378                                            self.forward_deliver(consumer_message).await;
379                                        } else {
380                                            message_buffer.basic_properties = Some(header.basic_properties);
381                                            message_buffer.content = Some(Vec::new());
382                                        }
383                                    },
384                                    State::GetOk => {
385                                        getok_content_buffer.remaining = header.common.body_size.try_into().unwrap();
386
387                                        let responder = self.get_content_responder.as_ref().expect("get responder must be registered");
388                                        responder.send(header.into_frame()).unwrap();
389                                        // do not wait for content body frame if content body size is zero
390                                        if getok_content_buffer.remaining  == 0 {
391                                            responder.send(ContentBody::new(Vec::new()).into_frame()).unwrap();
392                                        } else {
393                                            getok_content_buffer.content = Some(Vec::new());
394                                        }
395                                    },
396                                    State::Return => {
397                                        return_buffer.remaining = header.common.body_size.try_into().unwrap();
398
399                                        if return_buffer.remaining == 0 {
400                                            // do not wait for content body frame if content body size is zero
401                                            self.handle_return(return_buffer.ret.take().unwrap(), header.basic_properties, Vec::new()).await;
402                                        } else {
403                                            return_buffer.basic_properties = Some(header.basic_properties);
404                                            return_buffer.content = Some(Vec::new());
405                                        }
406                                    },
407                                    _  => unreachable!("invalid dispatcher state"),
408                                }
409                            }
410                            Frame::ContentBody(body) => {
411                                match self.state {
412                                    State::Deliver => {
413                                        let mut content_buffer = message_buffer.content.take().unwrap();
414                                        content_buffer.extend_from_slice(&body.inner);
415                                        message_buffer.content.replace(content_buffer);
416                                        // calculate remaining size of content body
417                                        message_buffer.remaining = message_buffer.remaining.checked_sub(body.inner.len()).expect("should never overflow");
418
419                                        if message_buffer.remaining == 0 {
420                                            let consumer_message  = ConsumerMessage {
421                                                deliver: message_buffer.deliver.take(),
422                                                basic_properties: message_buffer.basic_properties.take(),
423                                                content: message_buffer.content.take(),
424                                                remaining: message_buffer.remaining,
425                                            };
426                                            self.forward_deliver(consumer_message).await;
427                                        }
428                                    }
429                                    State::GetOk => {
430                                        let mut content_buffer = getok_content_buffer.content.take().unwrap();
431                                        content_buffer.extend_from_slice(&body.inner);
432                                        getok_content_buffer.content.replace(content_buffer);
433                                        getok_content_buffer.remaining = getok_content_buffer.remaining.checked_sub(body.inner.len()).expect("should never overflow");
434                                        if getok_content_buffer.remaining == 0 {
435                                            let content = getok_content_buffer.content.take().unwrap();
436                                            self.get_content_responder.take()
437                                            .expect("get responder must be registered")
438                                            .send(ContentBody::new(content).into_frame()).unwrap();
439                                        }
440                                    },
441                                    State::Return => {
442                                        let mut content_buffer = return_buffer.content.take().unwrap();
443                                        content_buffer.extend_from_slice(&body.inner);
444                                        return_buffer.content.replace(content_buffer);
445                                        return_buffer.remaining = return_buffer.remaining.checked_sub(body.inner.len()).expect("should never overflow");
446
447                                        if return_buffer.remaining == 0 {
448                                            self.handle_return(
449                                                return_buffer.ret.take().unwrap(),
450                                                return_buffer.basic_properties.take().unwrap(),
451                                                return_buffer.content.take().unwrap()).await;
452                                        }
453                                    },
454                                    State::Initial | State::GetEmpty  => unreachable!("invalid dispatcher state on channel {}", self.channel),
455                                }
456                            }
457                            ////////////////////////////////////////////////
458                            // synchronous response frames
459                            Frame::FlowOk(method_header, _)
460                            // | Frame::RequestOk(method_header, _) // Deprecated
461                            | Frame::DeclareOk(method_header, _)
462                            | Frame::DeleteOk(method_header, _)
463                            | Frame::BindOk(method_header, _)
464                            | Frame::UnbindOk(method_header, _)
465                            | Frame::DeclareQueueOk(method_header, _)
466                            | Frame::BindQueueOk(method_header, _)
467                            | Frame::PurgeQueueOk(method_header, _)
468                            | Frame::DeleteQueueOk(method_header, _)
469                            | Frame::UnbindQueueOk(method_header, _)
470                            | Frame::QosOk(method_header, _)
471                            | Frame::ConsumeOk(method_header, _)
472                            | Frame::CancelOk(method_header, _)
473                            | Frame::RecoverOk(method_header, _)
474                            | Frame::SelectOk(method_header, _)
475                            | Frame::TxSelectOk(method_header, _)
476                            | Frame::TxCommitOk(method_header, _)
477                            | Frame::TxRollbackOk(method_header, _) => {
478                                // handle synchronous response
479                                match self.responders.remove(method_header)
480                                {
481                                    Some(responder) => {
482                                        if let Err(response) = responder.send(frame) {
483                                            #[cfg(feature="traces")]
484                                            error!(
485                                                "failed to dispatch {} to channel {}",
486                                                response, self.channel
487                                            );
488                                        }
489                                    }
490                                    None => unreachable!(
491                                        "responder must be registered for {} on channel {}",
492                                        frame, self.channel
493                                    ),
494                                }
495                            }
496                            //////////////////////////////////////////////////////////
497                            // asynchronous request frames
498                            Frame::Flow(_, flow) => {
499                                // callback
500                                if let Some(ref mut cb) = self.callback {
501                                    match cb.flow(&self.channel, flow.active).await {
502                                      Err(err) => {
503                                        #[cfg(feature="traces")]
504                                        error!("flow callback error on channel {}, cause: '{}'.", self.channel, err);
505                                      }
506                                      Ok(active) => {
507                                         // respond to server that we have handled the request
508                                         self.channel.shared.outgoing_tx
509                                         .send((self.channel.channel_id(), FlowOk::new(active).into_frame()))
510                                         .await.unwrap();
511                                      }
512                                    };
513                                } else {
514                                    #[cfg(feature="traces")]
515                                    error!("callback not registered on channel {}", self.channel);
516                                }
517                            }
518                            Frame::Cancel(_, cancel) => {
519                                // callback
520                                if let Some(ref mut cb) = self.callback {
521                                    let consumer_tag = cancel.consumer_tag().clone();
522                                    let no_wait = cancel.no_wait();
523                                    match cb.cancel(&self.channel, cancel).await {
524                                      Err(err) => {
525                                        #[cfg(feature="traces")]
526                                        error!("cancel callback error on channel {}, cause: '{}'.", self.channel, err);
527                                      }
528                                      Ok(_) => {
529                                        self.remove_consumer_resource(&consumer_tag);
530
531                                        // respond to server that we have handled the request
532                                        if !no_wait  {
533                                            self.channel.shared.outgoing_tx
534                                            .send((self.channel.channel_id(), CancelOk::new(consumer_tag.try_into().unwrap()).into_frame()))
535                                            .await.unwrap();
536                                        }
537                                      }
538                                    };
539                                } else {
540                                    #[cfg(feature="traces")]
541                                    error!("callback not registered on channel {}", self.channel);
542                                }
543                            }
544                            // in confirmed mode
545                            Frame::Ack(_, ack) => {
546                                if let Some(ref mut cb) = self.callback {
547                                    cb.publish_ack(&self.channel, ack).await;
548                                } else {
549                                    #[cfg(feature="traces")]
550                                    error!("callback not registered on channel {}", self.channel);
551                                }
552                            }
553                            Frame::Nack(_, nack) => {
554                                if let Some(ref mut cb) = self.callback {
555                                    cb.publish_nack(&self.channel, nack).await;
556                                } else {
557                                    #[cfg(feature="traces")]
558                                    error!("callback not registered on channel {}", self.channel);
559                                }                            }
560                            _ => unreachable!("dispatcher of channel {} receive unexpected frame {}", self.channel, frame),
561                        }
562                    }
563                    // purge stale consumer resource
564                    _ = purge_timer.tick() => {
565                        self.purge_consumer_resource();
566                    }
567                    else => {
568                        break;
569                    }
570                }
571            }
572            self.channel.set_is_open(false);
573
574            #[cfg(feature = "traces")]
575            info!("exit dispatcher of channel {}", self.channel);
576        });
577    }
578}
579
#[cfg(test)]
mod tests {
    use tokio::time;

    use crate::{
        channel::{
            BasicCancelArguments, BasicConsumeArguments, BasicPublishArguments, QueueBindArguments,
            QueueDeclareArguments,
        },
        connection::{Connection, OpenConnectionArguments},
        consumer::DefaultConsumer,
        test_utils::setup_logging,
        BasicProperties,
    };

    use super::{CONSUMER_EXPIRY_PERIOD, CONSUMER_PURGE_INTERVAL};

    /// Number of messages published before the consumer is started.
    const STALE_MESSAGE_COUNT: usize = 100;

    /// Drive the consumer-resource purge path.
    ///
    /// Messages are published before any consumer exists; then a consumer is started and
    /// immediately cancelled (both with `no_wait`), so deliveries racing the cancellation
    /// may leave an unregistered consumer resource behind. The test then sleeps long
    /// enough for the purge timer to drop it.
    ///
    /// Requires an AMQP broker on `localhost:5672` (user `user`, password `bitnami`).
    /// NOTE(review): nothing is asserted; the test only exercises the code path.
    #[tokio::test]
    async fn test_purge_consumer_resource() {
        setup_logging();

        let connection = Connection::open(&OpenConnectionArguments::new(
            "localhost",
            5672,
            "user",
            "bitnami",
        ))
        .await
        .unwrap();

        let exchange_name = "amq.topic";
        let routing_key = "test.purge.consumer";

        let consumer_channel = connection.open_channel(None).await.unwrap();
        let declared = consumer_channel
            .queue_declare(QueueDeclareArguments::default())
            .await
            .unwrap();
        let (queue_name, _, _) = declared.unwrap();
        let bind_args = QueueBindArguments::new(&queue_name, exchange_name, routing_key);
        consumer_channel.queue_bind(bind_args).await.unwrap();

        // publish before consuming so that the messages are
        // delivered as soon as the consumer starts
        let pub_channel = connection.open_channel(None).await.unwrap();
        for _ in 0..STALE_MESSAGE_COUNT {
            let properties = BasicProperties::default();
            let content = String::from("stale message").into_bytes();
            let publish_args = BasicPublishArguments::new(exchange_name, routing_key);
            pub_channel
                .basic_publish(properties, content, publish_args)
                .await
                .unwrap();
        }
        // give the publishes time to complete
        time::sleep(time::Duration::from_secs(1)).await;

        // start the consumer without waiting for consume-ok
        let consumer = DefaultConsumer::new(false);
        let consume_args = BasicConsumeArguments::new(&queue_name, "purge-tester")
            .no_wait(true)
            .finish();
        let consumer_tag = consumer_channel
            .basic_consume(consumer, consume_args)
            .await
            .unwrap();

        // cancel it right away, again without waiting for cancel-ok
        let cancel_args = BasicCancelArguments::new(&consumer_tag)
            .no_wait(true)
            .finish();
        consumer_channel.basic_cancel(cancel_args).await.unwrap();

        // any stale consumer resource should be purged within
        // `CONSUMER_PURGE_INTERVAL + CONSUMER_EXPIRY_PERIOD`
        let purge_deadline = CONSUMER_PURGE_INTERVAL + CONSUMER_EXPIRY_PERIOD;
        time::sleep(purge_deadline).await;
    }
}