use crate::{
BasicProperties, Error, Result,
channel_closer::ChannelCloser,
consumer_canceler::ConsumerCanceler,
consumer_status::ConsumerStatus,
error_holder::ErrorHolder,
internal_rpc::InternalRPCHandle,
message::{Delivery, DeliveryResult},
options::BasicConsumeOptions,
types::{ChannelId, PayloadSize},
types::{FieldTable, ShortString},
wakers::Wakers,
};
use flume::{Receiver, Sender};
use futures_core::stream::Stream;
use std::{
fmt,
future::{self, Future},
pin::Pin,
sync::{Arc, Mutex, MutexGuard},
task::{Context, Poll},
};
use tracing::{error, trace};
/// Handler for a consumer's deliveries, installed with `Consumer::set_delegate`.
///
/// While a delegate is set, every `DeliveryResult` the consumer dispatches (a
/// delivery `Ok(Some(_))`, the cancellation marker `Ok(None)`, or an `Err`) is
/// handed to `on_new_delivery` instead of being buffered for the `Stream` impl.
pub trait ConsumerDelegate: Send + Sync {
    /// Builds the future that handles one dispatched `delivery`; the consumer
    /// spawns the returned future through its internal RPC handle.
    fn on_new_delivery(&self, delivery: DeliveryResult)
    -> Pin<Box<dyn Future<Output = ()> + Send>>;
    /// Builds a future that the consumer spawns when it drops its prefetched
    /// (buffered) messages, e.g. on a reset of a consumer without `no_ack`.
    /// The default implementation does nothing.
    fn drop_prefetched_messages(&self) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        Box::pin(future::ready(()))
    }
}
/// Any thread-safe closure `Fn(DeliveryResult) -> impl Future<Output = ()>` can
/// act as a delegate; each call's future is boxed so it can be spawned.
impl<Handler, Fut> ConsumerDelegate for Handler
where
    Handler: Fn(DeliveryResult) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    fn on_new_delivery(
        &self,
        delivery: DeliveryResult,
    ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
        let handling = (self)(delivery);
        Box::pin(handling)
    }
}
/// A consumer subscribed to a queue, readable either as a `Stream` of
/// deliveries or through a `ConsumerDelegate` installed with `set_delegate`.
///
/// Clones share the same `inner` state, since it sits behind an `Arc`.
#[derive(Clone)]
pub struct Consumer {
    /// Tag identifying this consumer; returned by `tag()`.
    consumer_tag: ShortString,
    /// Message being assembled plus the receiving end of the delivery buffer.
    inner: Arc<Mutex<Inner>>,
    /// Holds the registered delegate (if any) and the cancellation state.
    status: ConsumerStatus,
    /// Handle used to spawn delegate futures.
    internal_rpc: InternalRPCHandle,
    /// Optional channel closer, released in `Drop` after the canceler.
    /// NOTE(review): presumably closes the channel when the last reference goes away — confirm.
    channel_closer: Option<Arc<ChannelCloser>>,
    /// Only set on handles created by `external()`; released first in `Drop`.
    consumer_canceler: Option<Arc<ConsumerCanceler>>,
    /// Queue name this consumer was created with.
    queue: ShortString,
    /// Options this consumer was created with (`no_ack` is read by `reset`).
    options: BasicConsumeOptions,
    /// Extra arguments this consumer was created with.
    arguments: FieldTable,
    /// Sending end of the unbounded buffer used while no delegate is registered.
    deliveries_in: Sender<DeliveryResult>,
    /// Wakers registered by `poll_next`; woken after every dispatch.
    wakers: Wakers,
    /// Error recorded by `set_error`.
    error: ErrorHolder,
}
impl Consumer {
    /// Creates a consumer for `queue`, identified by `consumer_tag`.
    ///
    /// Deliveries dispatched while no delegate is registered are buffered in an
    /// unbounded channel whose receiving half is kept in the shared `Inner`.
    pub(crate) fn new(
        consumer_tag: ShortString,
        internal_rpc: InternalRPCHandle,
        channel_closer: Option<Arc<ChannelCloser>>,
        queue: ShortString,
        options: BasicConsumeOptions,
        arguments: FieldTable,
    ) -> Self {
        let (deliveries_in, deliveries_out) = flume::unbounded();
        let shared = Inner::new(consumer_tag.clone(), deliveries_out, internal_rpc.clone());
        Self {
            consumer_tag,
            inner: Arc::new(Mutex::new(shared)),
            status: ConsumerStatus::default(),
            internal_rpc,
            channel_closer,
            consumer_canceler: None,
            queue,
            options,
            arguments,
            deliveries_in,
            wakers: Wakers::default(),
            error: ErrorHolder::default(),
        }
    }

    /// Returns a clone of this consumer that additionally owns a
    /// `ConsumerCanceler` for `channel_id`.
    pub(crate) fn external(&self, channel_id: ChannelId) -> Self {
        let mut handle = self.clone();
        let canceler = ConsumerCanceler::new(
            channel_id,
            self.consumer_tag.clone(),
            self.status.clone(),
            self.internal_rpc.clone(),
        );
        handle.consumer_canceler = Some(Arc::new(canceler));
        handle
    }

    /// Returns a clone of this consumer's `ErrorHolder`.
    pub(crate) fn error(&self) -> ErrorHolder {
        self.error.clone()
    }

    /// Returns this consumer's tag.
    pub fn tag(&self) -> ShortString {
        self.consumer_tag.clone()
    }

    /// Returns the queue name this consumer was created with.
    pub fn queue(&self) -> ShortString {
        self.queue.clone()
    }

    /// Returns the options this consumer was created with.
    pub(crate) fn options(&self) -> BasicConsumeOptions {
        self.options
    }

    /// Returns the arguments this consumer was created with.
    pub(crate) fn arguments(&self) -> FieldTable {
        self.arguments.clone()
    }

    /// Registers `delegate` as the handler for this consumer's deliveries.
    ///
    /// Everything already buffered is handed to the delegate first. Both the
    /// inner lock and the status write lock are held for the whole
    /// drain-then-install sequence.
    pub fn set_delegate<D: ConsumerDelegate + 'static>(&self, delegate: D) {
        let mut inner = self.lock_inner();
        let mut status = self.status.write();
        while let Some(buffered) = inner.next_delivery() {
            let handler = delegate.on_new_delivery(buffered);
            self.internal_rpc.spawn_infallible(handler);
        }
        status.set_delegate(Some(Arc::new(delegate)));
    }

    /// Clears the partially-assembled message; unless the consumer was created
    /// with `no_ack`, buffered deliveries are dropped as well.
    pub(crate) fn reset(&self) {
        let no_ack = self.options.no_ack;
        // The inner lock is taken before the delegate is read, as before.
        self.lock_inner().reset(no_ack, self.status.delegate());
    }

    /// Starts assembling `delivery`, replacing any message in progress.
    pub(crate) fn start_new_delivery(&self, delivery: Delivery) {
        let mut inner = self.lock_inner();
        inner.current_message = Some(delivery);
    }

    /// Applies a content header; dispatches the message if its body is empty.
    pub(crate) fn handle_content_header_frame(
        &self,
        size: PayloadSize,
        properties: BasicProperties,
    ) {
        let mut inner = self.lock_inner();
        let completed = inner.handle_content_header_frame(size, properties);
        // `inner` stays locked while the completed delivery is dispatched.
        self.check_new_delivery(completed);
    }

    /// Appends a body frame; dispatches the message once nothing remains.
    pub(crate) fn handle_body_frame(&self, remaining_size: PayloadSize, payload: Vec<u8>) {
        let mut inner = self.lock_inner();
        let completed = inner.handle_body_frame(remaining_size, payload);
        // `inner` stays locked while the completed delivery is dispatched.
        self.check_new_delivery(completed);
    }

    /// Drops every buffered delivery, notifying the delegate if one is set.
    pub(crate) fn drop_prefetched_messages(&self) {
        let mut inner = self.lock_inner();
        inner.drop_prefetched_messages(self.status.delegate());
    }

    /// Marks the consumer as being canceled.
    pub(crate) fn start_cancel(&self) {
        self.status.write().start_cancel();
    }

    /// Dispatches the end-of-stream marker (`Ok(None)`) and marks the consumer canceled.
    pub(crate) fn cancel(&self) {
        trace!(consumer_tag=%self.consumer_tag, "cancel");
        let mut status = self.status.write();
        let delegate = status.delegate();
        self.dispatch(Ok(None), "failed to send cancel to consumer", delegate);
        status.cancel();
    }

    /// Dispatches `error` to the delegate or the stream buffer.
    pub(crate) fn send_error(&self, error: Error) {
        trace!(consumer_tag=%self.consumer_tag, "send_error");
        let delegate = self.status.delegate();
        self.dispatch(Err(error), "failed to send error to consumer", delegate);
    }

    /// Records `error`, dispatches it, then cancels the consumer.
    ///
    /// If an error was already recorded, the failure is logged and the new
    /// error is still dispatched.
    pub(crate) fn set_error(&self, error: Error) {
        trace!(consumer_tag=%self.consumer_tag, "set_error");
        if let Err(cascading) = self.error.set(error.clone()) {
            error!(%cascading, consumer_tag=%self.consumer_tag, "consumer already has an error");
        }
        self.send_error(error);
        self.cancel();
    }

    /// Locks the shared state, recovering the guard from a poisoned mutex
    /// instead of propagating the panic.
    fn lock_inner(&self) -> MutexGuard<'_, Inner> {
        match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Dispatches `delivery` if a message was just completed.
    fn check_new_delivery(&self, delivery: Option<Delivery>) {
        let Some(delivery) = delivery else {
            return;
        };
        let delegate = self.status.delegate();
        self.dispatch(
            Ok(Some(delivery)),
            "failed to send delivery to consumer",
            delegate,
        );
    }

    /// Routes `delivery` to `delegate` when present, otherwise into the stream
    /// buffer (logging `error` if the send fails), then wakes pending pollers.
    fn dispatch(
        &self,
        delivery: DeliveryResult,
        error: &'static str,
        delegate: Option<Arc<dyn ConsumerDelegate>>,
    ) {
        match delegate {
            Some(delegate) => {
                let handler = delegate.on_new_delivery(delivery);
                self.internal_rpc.spawn_infallible(handler);
            }
            None => {
                if let Err(err) = self.deliveries_in.send(delivery) {
                    error!(?err, error);
                }
            }
        }
        self.wakers.wake();
    }
}
impl fmt::Debug for Consumer {
    /// Formats the consumer's tag and, if the status can be read without
    /// blocking, its state.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut debug = f.debug_struct("Consumer");
        // The tag is immutable and stored directly on `Consumer`, so it is read
        // without touching the `inner` mutex; this keeps it in the output even
        // while that mutex is held elsewhere or poisoned.
        debug.field("tag", &self.consumer_tag);
        if let Some(status) = self.status.try_read() {
            debug.field("state", &status.state());
        }
        debug.finish()
    }
}
impl Drop for Consumer {
    /// Releases the canceler before the channel closer. This is the reverse of
    /// their declaration order, which the automatic field drop would use.
    fn drop(&mut self) {
        self.consumer_canceler = None;
        self.channel_closer = None;
    }
}
/// Mutable consumer state shared by all clones of a `Consumer` behind `Arc<Mutex<_>>`.
struct Inner {
    /// Delivery being assembled from header/body frames; taken once complete.
    current_message: Option<Delivery>,
    /// Receiving half of the buffer filled by `Consumer::dispatch` when no delegate is set.
    deliveries_out: Receiver<DeliveryResult>,
    /// Used to spawn the delegate's `drop_prefetched_messages` future.
    internal_rpc: InternalRPCHandle,
    /// Consumer tag, used in trace output.
    tag: ShortString,
}
impl Inner {
    /// Creates the shared state with no message in progress.
    fn new(
        consumer_tag: ShortString,
        deliveries_out: Receiver<DeliveryResult>,
        internal_rpc: InternalRPCHandle,
    ) -> Self {
        Self {
            tag: consumer_tag,
            current_message: None,
            deliveries_out,
            internal_rpc,
        }
    }

    /// Forgets the message being assembled. Unless `no_ack` is set, buffered
    /// deliveries are dropped first.
    fn reset(&mut self, no_ack: bool, delegate: Option<Arc<dyn ConsumerDelegate>>) {
        let drop_prefetched = !no_ack;
        if drop_prefetched {
            self.drop_prefetched_messages(delegate);
        }
        self.current_message = None;
    }

    /// Pops the next buffered result without blocking.
    fn next_delivery(&mut self) -> Option<DeliveryResult> {
        match self.deliveries_out.try_recv() {
            Ok(delivery) => Some(delivery),
            Err(_) => None,
        }
    }

    /// Stores the header's properties on the message in progress. The message
    /// is complete immediately when the announced body `size` is zero.
    fn handle_content_header_frame(
        &mut self,
        size: PayloadSize,
        properties: BasicProperties,
    ) -> Option<Delivery> {
        if let Some(pending) = &mut self.current_message {
            pending.properties = properties;
        }
        self.check_new_delivery_complete(size == 0)
    }

    /// Appends `payload` to the message in progress. The message is complete
    /// when `remaining_size` reaches zero.
    fn handle_body_frame(
        &mut self,
        remaining_size: PayloadSize,
        payload: Vec<u8>,
    ) -> Option<Delivery> {
        if let Some(pending) = &mut self.current_message {
            pending.receive_content(payload);
        }
        self.check_new_delivery_complete(remaining_size == 0)
    }

    /// When `complete`, hands out the message in progress (if any).
    fn check_new_delivery_complete(&mut self, complete: bool) -> Option<Delivery> {
        if complete {
            let finished = self.current_message.take()?;
            trace!(consumer_tag=%self.tag, "new_delivery");
            Some(finished)
        } else {
            None
        }
    }

    /// Notifies `delegate` (if any) and empties the buffer, invalidating the
    /// acker of every buffered delivery. Cancel markers and errors are discarded.
    fn drop_prefetched_messages(&mut self, delegate: Option<Arc<dyn ConsumerDelegate>>) {
        trace!(consumer_tag=%self.tag, "drop_prefetched_messages");
        if let Some(delegate) = delegate {
            let notification = delegate.drop_prefetched_messages();
            self.internal_rpc.spawn_infallible(notification);
        }
        while let Some(buffered) = self.next_delivery() {
            let Ok(Some(delivery)) = buffered else {
                continue;
            };
            delivery.acker.invalidate();
        }
    }
}
impl Stream for Consumer {
    type Item = Result<Delivery>;

    /// Yields the next buffered result: a delivery, an error sent through
    /// `send_error`, or end-of-stream once the `cancel` marker is reached.
    /// Otherwise returns `Poll::Pending`; the waker is registered on every poll.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        trace!("consumer poll_next");
        self.wakers.register(cx.waker());
        let mut inner = self.lock_inner();
        trace!(
            consumer_tag=%inner.tag,
            "consumer poll; acquired inner lock"
        );
        let Some(next) = inner.next_delivery() else {
            trace!(consumer_tag=%inner.tag, "delivery; status=NotReady");
            return Poll::Pending;
        };
        let item = match next {
            Ok(Some(delivery)) => {
                trace!(
                    consumer_tag=%inner.tag,
                    delivery_tag=?delivery.delivery_tag,
                    "delivery"
                );
                Some(Ok(delivery))
            }
            Ok(None) => {
                trace!(consumer_tag=%inner.tag, "consumer canceled");
                None
            }
            Err(error) => Some(Err(error)),
        };
        Poll::Ready(item)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        ConnectionStatus, ErrorKind, auth::DefaultAuthProvider, frames::Frames,
        heartbeat::Heartbeat, internal_rpc::InternalRPC, runtime, secret_update::SecretUpdate,
        socket_state::SocketState, uri::AMQPUri,
    };
    use futures_lite::stream::StreamExt;
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll, Wake, Waker},
    };

    /// Waker that only counts how many times it has been woken.
    struct Counter(AtomicUsize);

    impl Wake for Counter {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns a fresh wake counter together with a `Waker` that increments it.
    fn counting_waker() -> (Arc<Counter>, Waker) {
        let counter = Arc::new(Counter(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    /// Number of wake-ups recorded by `counter` so far.
    fn wakeups(counter: &Counter) -> usize {
        counter.0.load(Ordering::SeqCst)
    }

    /// Builds a standalone consumer on a default runtime and default URI.
    fn create_consumer(tag: &str, queue: &str) -> Consumer {
        let uri = AMQPUri::default();
        let connection_status = ConnectionStatus::new(&uri);
        let rt = runtime::default_runtime().unwrap();
        let heartbeat = Heartbeat::new(connection_status.clone(), rt.clone());
        let auth_provider = Arc::new(DefaultAuthProvider::new(&uri));
        let secret_update =
            SecretUpdate::new(connection_status.clone(), rt.clone(), auth_provider);
        let socket_state = SocketState::default();
        let internal_rpc = InternalRPC::new(
            rt,
            heartbeat,
            secret_update,
            Frames::default(),
            socket_state.handle(),
        );
        Consumer::new(
            tag.into(),
            internal_rpc.handle(),
            None,
            queue.into(),
            BasicConsumeOptions::default(),
            FieldTable::default(),
        )
    }

    #[test]
    fn stream_on_cancel() {
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let mut consumer = create_consumer("test-consumer", "test");

        // Nothing has been dispatched yet: the stream is pending and registers the waker.
        {
            let mut next = consumer.next();
            assert_eq!(wakeups(&counter), 0);
            assert_eq!(Pin::new(&mut next).poll(&mut cx), Poll::Pending);
        }

        // Cancelling wakes the registered waker once and ends the stream.
        consumer.cancel();
        {
            let mut next = consumer.next();
            assert_eq!(wakeups(&counter), 1);
            assert_eq!(Pin::new(&mut next).poll(&mut cx), Poll::Ready(None));
        }
    }

    #[test]
    fn stream_on_error() {
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let mut consumer = create_consumer("test-consumer", "test");

        // Nothing has been dispatched yet: the stream is pending and registers the waker.
        {
            let mut next = consumer.next();
            assert_eq!(wakeups(&counter), 0);
            assert_eq!(Pin::new(&mut next).poll(&mut cx), Poll::Pending);
        }

        // Setting an error wakes the registered waker once; the error is yielded first.
        consumer.set_error(ErrorKind::ChannelsLimitReached.into());
        {
            let mut next = consumer.next();
            assert_eq!(wakeups(&counter), 1);
            assert_eq!(
                Pin::new(&mut next).poll(&mut cx),
                Poll::Ready(Some(Err(ErrorKind::ChannelsLimitReached.into())))
            );
        }
    }
}