use std::{
collections::HashMap,
pin::Pin,
sync::{
atomic::{
AtomicBool,
Ordering::{Relaxed, SeqCst},
},
Arc,
},
task::{Context, Poll},
};
use rabbitmq_stream_protocol::{
commands::subscribe::OffsetSpecification, message::Message, ResponseKind,
};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tracing::trace;
use crate::error::ConsumerStoreOffsetError;
use crate::{
client::{MessageHandler, MessageResult},
error::{ConsumerCloseError, ConsumerCreateError, ConsumerDeliveryError},
Client, ClientOptions, Environment, MetricsCollector,
};
use futures::{task::AtomicWaker, Stream};
use rand::rngs::StdRng;
use rand::{seq::SliceRandom, SeedableRng};
type FilterPredicate = Option<Arc<dyn Fn(&Message) -> bool + Send + Sync>>;
pub struct Consumer {
name: Option<String>,
receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
internal: Arc<ConsumerInternal>,
}
struct ConsumerInternal {
client: Client,
stream: String,
offset_specification: OffsetSpecification,
subscription_id: u8,
sender: Sender<Result<Delivery, ConsumerDeliveryError>>,
closed: Arc<AtomicBool>,
waker: AtomicWaker,
metrics_collector: Arc<dyn MetricsCollector>,
filter_configuration: Option<FilterConfiguration>,
}
impl ConsumerInternal {
fn is_closed(&self) -> bool {
self.closed.load(Relaxed)
}
}
#[derive(Clone)]
pub struct FilterConfiguration {
filter_values: Vec<String>,
pub predicate: FilterPredicate,
match_unfiltered: bool,
}
impl FilterConfiguration {
pub fn new(filter_values: Vec<String>, match_unfiltered: bool) -> Self {
Self {
filter_values,
match_unfiltered,
predicate: None,
}
}
pub fn post_filter(
mut self,
predicate: impl Fn(&Message) -> bool + 'static + Send + Sync,
) -> FilterConfiguration {
self.predicate = Some(Arc::new(predicate));
self
}
}
pub struct ConsumerBuilder {
pub(crate) consumer_name: Option<String>,
pub(crate) environment: Environment,
pub(crate) offset_specification: OffsetSpecification,
pub(crate) filter_configuration: Option<FilterConfiguration>,
}
impl ConsumerBuilder {
pub async fn build(self, stream: &str) -> Result<Consumer, ConsumerCreateError> {
let mut client = self.environment.create_client().await?;
let collector = self.environment.options.client_options.collector.clone();
if let Some(metadata) = client.metadata(vec![stream.to_string()]).await?.get(stream) {
if let Some(replica) = metadata.replicas.choose(&mut StdRng::from_entropy()) {
tracing::debug!(
"Picked replica {:?} out of possible candidates {:?} for stream {}",
replica,
metadata.replicas,
stream
);
let load_balancer_mode = self.environment.options.client_options.load_balancer_mode;
if load_balancer_mode {
let options = self.environment.options.client_options.clone();
loop {
let temp_client = Client::connect(options.clone()).await?;
let mapping = temp_client.connection_properties().await;
if let Some(advertised_host) = mapping.get("advertised_host") {
if *advertised_host == replica.host.clone() {
client.close().await?;
client = temp_client;
break;
}
}
temp_client.close().await?;
}
} else {
client.close().await?;
client = Client::connect(ClientOptions {
host: replica.host.clone(),
port: replica.port as u16,
..self.environment.options.client_options
})
.await?;
}
}
} else {
return Err(ConsumerCreateError::StreamDoesNotExist {
stream: stream.into(),
});
}
let subscription_id = 1;
let (tx, rx) = channel(10000);
let consumer = Arc::new(ConsumerInternal {
subscription_id,
stream: stream.to_string(),
client: client.clone(),
offset_specification: self.offset_specification.clone(),
sender: tx,
closed: Arc::new(AtomicBool::new(false)),
waker: AtomicWaker::new(),
metrics_collector: collector,
filter_configuration: self.filter_configuration.clone(),
});
let msg_handler = ConsumerMessageHandler(consumer.clone());
client.set_handler(msg_handler).await;
let mut properties = HashMap::new();
if let Some(filter_input) = self.filter_configuration {
if !client.filtering_supported() {
return Err(ConsumerCreateError::FilteringNotSupport);
}
for (index, item) in filter_input.filter_values.iter().enumerate() {
let key = format!("filter.{}", index);
properties.insert(key, item.to_owned());
}
let match_unfiltered_key = "match-unfiltered".to_string();
properties.insert(
match_unfiltered_key,
filter_input.match_unfiltered.to_string(),
);
}
let response = client
.subscribe(
subscription_id,
stream,
self.offset_specification,
1,
properties,
)
.await?;
if response.is_ok() {
Ok(Consumer {
name: self.consumer_name,
receiver: rx,
internal: consumer,
})
} else {
Err(ConsumerCreateError::Create {
stream: stream.to_owned(),
status: response.code().clone(),
})
}
}
pub fn offset(mut self, offset_specification: OffsetSpecification) -> Self {
self.offset_specification = offset_specification;
self
}
pub fn name(mut self, consumer_name: &str) -> Self {
self.consumer_name = Some(String::from(consumer_name));
self
}
pub fn filter_input(mut self, filter_configuration: Option<FilterConfiguration>) -> Self {
self.filter_configuration = filter_configuration;
self
}
}
impl Consumer {
pub fn handle(&self) -> ConsumerHandle {
ConsumerHandle(self.internal.clone())
}
pub fn is_closed(&self) -> bool {
self.internal.is_closed()
}
pub async fn store_offset(&self, offset: u64) -> Result<(), ConsumerStoreOffsetError> {
if let Some(name) = &self.name {
self.internal
.client
.store_offset(name.as_str(), self.internal.stream.as_str(), offset)
.await
.map(Ok)?
} else {
Err(ConsumerStoreOffsetError::NameMissing)
}
}
pub async fn query_offset(&self) -> Result<u64, ConsumerStoreOffsetError> {
if let Some(name) = &self.name {
self.internal
.client
.query_offset(name.clone(), self.internal.stream.as_str())
.await
.map(Ok)?
} else {
Err(ConsumerStoreOffsetError::NameMissing)
}
}
}
impl Stream for Consumer {
type Item = Result<Delivery, ConsumerDeliveryError>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.internal.waker.register(cx.waker());
let poll = Pin::new(&mut self.receiver).poll_recv(cx);
match (self.is_closed(), poll.is_ready()) {
(true, false) => Poll::Ready(None),
_ => poll,
}
}
}
pub struct ConsumerHandle(Arc<ConsumerInternal>);
impl ConsumerHandle {
pub async fn close(self) -> Result<(), ConsumerCloseError> {
self.internal_close().await
}
pub(crate) async fn internal_close(&self) -> Result<(), ConsumerCloseError> {
match self.0.closed.compare_exchange(false, true, SeqCst, SeqCst) {
Ok(false) => {
let response = self.0.client.unsubscribe(self.0.subscription_id).await?;
if response.is_ok() {
self.0.waker.wake();
self.0.client.close().await?;
Ok(())
} else {
Err(ConsumerCloseError::Close {
stream: self.0.stream.clone(),
status: response.code().clone(),
})
}
}
_ => Err(ConsumerCloseError::AlreadyClosed),
}
}
pub async fn is_closed(&self) -> bool {
self.0.is_closed()
}
}
struct ConsumerMessageHandler(Arc<ConsumerInternal>);
#[async_trait::async_trait]
impl MessageHandler for ConsumerMessageHandler {
async fn handle_message(&self, item: MessageResult) -> crate::RabbitMQStreamResult<()> {
match item {
Some(Ok(response)) => {
if let ResponseKind::Deliver(delivery) = response.kind() {
let mut offset = delivery.chunk_first_offset;
let len = delivery.messages.len();
trace!("Got delivery with messages {}", len);
let messages = match &self.0.filter_configuration {
Some(filter_input) => {
if let Some(f) = &filter_input.predicate {
delivery
.messages
.into_iter()
.filter(|message| f(message))
.collect::<Vec<Message>>()
} else {
delivery.messages
}
}
None => delivery.messages,
};
for message in messages {
if let OffsetSpecification::Offset(offset_) = self.0.offset_specification {
if offset_ > offset {
offset += 1;
continue;
}
}
let _ = self
.0
.sender
.send(Ok(Delivery {
stream: self.0.stream.clone(),
subscription_id: self.0.subscription_id,
message,
offset,
}))
.await;
offset += 1;
}
let _ = self.0.client.credit(self.0.subscription_id, 1).await;
self.0.metrics_collector.consume(len as u64).await;
}
}
Some(Err(err)) => {
let _ = self.0.sender.send(Err(err.into())).await;
}
None => {
trace!("Closing consumer");
self.0.closed.store(true, Relaxed);
self.0.waker.wake();
}
}
Ok(())
}
}
#[derive(Debug)]
pub struct Delivery {
stream: String,
subscription_id: u8,
message: Message,
offset: u64,
}
impl Delivery {
pub fn subscription_id(&self) -> u8 {
self.subscription_id
}
pub fn stream(&self) -> &String {
&self.stream
}
pub fn message(&self) -> &Message {
&self.message
}
pub fn offset(&self) -> u64 {
self.offset
}
}