rabbitmq_stream_client/
consumer.rs

1use std::{
2    collections::HashMap,
3    pin::Pin,
4    sync::{
5        atomic::{
6            AtomicBool,
7            Ordering::{Relaxed, SeqCst},
8        },
9        Arc,
10    },
11    task::{Context, Poll},
12};
13
14use rabbitmq_stream_protocol::{
15    commands::subscribe::OffsetSpecification, message::Message, ResponseKind,
16};
17
18use core::option::Option::None;
19use futures::FutureExt;
20use std::future::Future;
21use tokio::sync::mpsc::{channel, Receiver, Sender};
22use tracing::trace;
23
24use crate::error::ConsumerStoreOffsetError;
25
26use crate::{
27    client::{MessageHandler, MessageResult},
28    error::{ConsumerCloseError, ConsumerCreateError, ConsumerDeliveryError},
29    Client, Environment, MetricsCollector,
30};
31use futures::{future::BoxFuture, task::AtomicWaker, Stream};
32
33type FilterPredicate = Option<Arc<dyn Fn(&Message) -> bool + Send + Sync>>;
34
35pub type ConsumerUpdateListener =
36    Arc<dyn Fn(u8, MessageContext) -> BoxFuture<'static, OffsetSpecification> + Send + Sync>;
37
38/// API for consuming RabbitMQ stream messages
39pub struct Consumer {
40    // Mandatory in case of manual offset tracking
41    name: Option<String>,
42    receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
43    internal: Arc<ConsumerInternal>,
44}
45
46struct ConsumerInternal {
47    name: Option<String>,
48    client: Client,
49    stream: String,
50    offset_specification: OffsetSpecification,
51    subscription_id: u8,
52    sender: Sender<Result<Delivery, ConsumerDeliveryError>>,
53    closed: Arc<AtomicBool>,
54    waker: AtomicWaker,
55    metrics_collector: Arc<dyn MetricsCollector>,
56    filter_configuration: Option<FilterConfiguration>,
57    consumer_update_listener: Option<ConsumerUpdateListener>,
58}
59
60impl ConsumerInternal {
61    fn is_closed(&self) -> bool {
62        self.closed.load(Relaxed)
63    }
64}
65
66#[derive(Clone)]
67pub struct FilterConfiguration {
68    filter_values: Vec<String>,
69    pub predicate: FilterPredicate,
70    match_unfiltered: bool,
71}
72
73impl FilterConfiguration {
74    pub fn new(filter_values: Vec<String>, match_unfiltered: bool) -> Self {
75        Self {
76            filter_values,
77            match_unfiltered,
78            predicate: None,
79        }
80    }
81
82    pub fn post_filter(
83        mut self,
84        predicate: impl Fn(&Message) -> bool + 'static + Send + Sync,
85    ) -> FilterConfiguration {
86        self.predicate = Some(Arc::new(predicate));
87        self
88    }
89}
90
91#[derive(Clone)]
92pub struct MessageContext {
93    name: String,
94    stream: String,
95    client: Client,
96}
97
98impl MessageContext {
99    pub fn name(&self) -> String {
100        self.name.clone()
101    }
102
103    pub fn stream(&self) -> String {
104        self.stream.clone()
105    }
106
107    pub fn client(&self) -> Client {
108        self.client.clone()
109    }
110}
111
112/// Builder for [`Consumer`]
113pub struct ConsumerBuilder {
114    pub(crate) consumer_name: Option<String>,
115    pub(crate) environment: Environment,
116    pub(crate) offset_specification: OffsetSpecification,
117    pub(crate) filter_configuration: Option<FilterConfiguration>,
118    pub(crate) consumer_update_listener: Option<ConsumerUpdateListener>,
119    pub(crate) client_provided_name: String,
120    pub(crate) properties: HashMap<String, String>,
121    pub(crate) is_single_active_consumer: bool,
122}
123
124impl ConsumerBuilder {
125    pub async fn build(mut self, stream: &str) -> Result<Consumer, ConsumerCreateError> {
126        if (self.is_single_active_consumer
127            || self.properties.contains_key("single-active-consumer"))
128            && self.consumer_name.is_none()
129        {
130            return Err(ConsumerCreateError::SingleActiveConsumerNotSupported);
131        }
132
133        let collector = self.environment.options.client_options.collector.clone();
134
135        let client = self
136            .environment
137            .create_consumer_client(stream, self.client_provided_name.clone())
138            .await?;
139
140        let subscription_id = 1;
141        let (tx, rx) = channel(10000);
142        let consumer = Arc::new(ConsumerInternal {
143            name: self.consumer_name.clone(),
144            subscription_id,
145            stream: stream.to_string(),
146            client: client.clone(),
147            offset_specification: self.offset_specification.clone(),
148            sender: tx,
149            closed: Arc::new(AtomicBool::new(false)),
150            waker: AtomicWaker::new(),
151            metrics_collector: collector,
152            filter_configuration: self.filter_configuration.clone(),
153            consumer_update_listener: self.consumer_update_listener.clone(),
154        });
155        let msg_handler = ConsumerMessageHandler(consumer.clone());
156        client.set_handler(msg_handler).await;
157
158        if let Some(filter_input) = self.filter_configuration {
159            if !client.filtering_supported() {
160                return Err(ConsumerCreateError::FilteringNotSupport);
161            }
162            for (index, item) in filter_input.filter_values.iter().enumerate() {
163                let key = format!("filter.{}", index);
164                self.properties.insert(key, item.to_owned());
165            }
166
167            let match_unfiltered_key = "match-unfiltered".to_string();
168            self.properties.insert(
169                match_unfiltered_key,
170                filter_input.match_unfiltered.to_string(),
171            );
172        }
173
174        if self.is_single_active_consumer {
175            self.properties
176                .insert("single-active-consumer".to_string(), "true".to_string());
177            self.properties
178                .insert("name".to_string(), self.consumer_name.clone().unwrap());
179        }
180
181        let response = client
182            .subscribe(
183                subscription_id,
184                stream,
185                self.offset_specification,
186                1,
187                self.properties.clone(),
188            )
189            .await?;
190
191        if response.is_ok() {
192            Ok(Consumer {
193                name: self.consumer_name.clone(),
194                receiver: rx,
195                internal: consumer,
196            })
197        } else {
198            Err(ConsumerCreateError::Create {
199                stream: stream.to_owned(),
200                status: response.code().clone(),
201            })
202        }
203    }
204
205    pub fn offset(mut self, offset_specification: OffsetSpecification) -> Self {
206        self.offset_specification = offset_specification;
207        self
208    }
209
210    pub fn client_provided_name(mut self, name: &str) -> Self {
211        self.client_provided_name = String::from(name);
212        self
213    }
214
215    pub fn name(mut self, consumer_name: &str) -> Self {
216        self.consumer_name = Some(String::from(consumer_name));
217        self
218    }
219
220    pub fn name_optional(mut self, consumer_name: Option<String>) -> Self {
221        self.consumer_name = consumer_name;
222        self
223    }
224
225    pub fn enable_single_active_consumer(mut self, is_single_active_consumer: bool) -> Self {
226        self.is_single_active_consumer = is_single_active_consumer;
227        self
228    }
229
230    pub fn filter_input(mut self, filter_configuration: Option<FilterConfiguration>) -> Self {
231        self.filter_configuration = filter_configuration;
232        self
233    }
234
235    pub fn consumer_update<Fut>(
236        mut self,
237        consumer_update_listener: impl Fn(u8, MessageContext) -> Fut + Send + Sync + 'static,
238    ) -> Self
239    where
240        Fut: Future<Output = OffsetSpecification> + Send + Sync + 'static,
241    {
242        let f = Arc::new(move |a, b| consumer_update_listener(a, b).boxed());
243        self.consumer_update_listener = Some(f);
244        self
245    }
246
247    pub fn consumer_update_arc(
248        mut self,
249        consumer_update_listener: Option<crate::consumer::ConsumerUpdateListener>,
250    ) -> Self {
251        self.consumer_update_listener = consumer_update_listener;
252        self
253    }
254
255    pub fn properties(mut self, properties: HashMap<String, String>) -> Self {
256        self.properties = properties;
257        self
258    }
259}
260
261impl Consumer {
262    /// Return an handle for current [`Consumer`]
263    pub fn handle(&self) -> ConsumerHandle {
264        ConsumerHandle(self.internal.clone())
265    }
266
267    /// Check if the consumer is closed
268    pub fn is_closed(&self) -> bool {
269        self.internal.is_closed()
270    }
271
272    pub async fn store_offset(&self, offset: u64) -> Result<(), ConsumerStoreOffsetError> {
273        if let Some(name) = &self.name {
274            self.internal
275                .client
276                .store_offset(name.as_str(), self.internal.stream.as_str(), offset)
277                .await
278                .map(Ok)?
279        } else {
280            Err(ConsumerStoreOffsetError::NameMissing)
281        }
282    }
283
284    pub async fn query_offset(&self) -> Result<u64, ConsumerStoreOffsetError> {
285        if let Some(name) = &self.name {
286            self.internal
287                .client
288                .query_offset(name.clone(), self.internal.stream.as_str())
289                .await
290                .map(Ok)?
291        } else {
292            Err(ConsumerStoreOffsetError::NameMissing)
293        }
294    }
295}
296
297impl Stream for Consumer {
298    type Item = Result<Delivery, ConsumerDeliveryError>;
299
300    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301        self.internal.waker.register(cx.waker());
302        let poll = Pin::new(&mut self.receiver).poll_recv(cx);
303        match (self.is_closed(), poll.is_ready()) {
304            (true, false) => Poll::Ready(None),
305            _ => poll,
306        }
307    }
308}
309
310/// Handler API for [`Consumer`]
311pub struct ConsumerHandle(Arc<ConsumerInternal>);
312
313impl ConsumerHandle {
314    /// Close the [`Consumer`] associated to this handle
315    pub async fn close(self) -> Result<(), ConsumerCloseError> {
316        self.internal_close().await
317    }
318
319    pub(crate) async fn internal_close(&self) -> Result<(), ConsumerCloseError> {
320        match self.0.closed.compare_exchange(false, true, SeqCst, SeqCst) {
321            Ok(false) => {
322                let response = self.0.client.unsubscribe(self.0.subscription_id).await?;
323                if response.is_ok() {
324                    self.0.waker.wake();
325                    self.0.client.close().await?;
326                    Ok(())
327                } else {
328                    Err(ConsumerCloseError::Close {
329                        stream: self.0.stream.clone(),
330                        status: response.code().clone(),
331                    })
332                }
333            }
334            _ => Err(ConsumerCloseError::AlreadyClosed),
335        }
336    }
337    /// Check if the consumer is closed
338    pub async fn is_closed(&self) -> bool {
339        self.0.is_closed()
340    }
341}
342
343struct ConsumerMessageHandler(Arc<ConsumerInternal>);
344
345#[async_trait::async_trait]
346impl MessageHandler for ConsumerMessageHandler {
347    async fn handle_message(&self, item: MessageResult) -> crate::RabbitMQStreamResult<()> {
348        match item {
349            Some(Ok(response)) => {
350                if let ResponseKind::Deliver(delivery) = response.kind_ref() {
351                    let mut offset = delivery.chunk_first_offset;
352
353                    let len = delivery.messages.len();
354                    let d = delivery.clone();
355                    trace!("Got delivery with messages {}", len);
356
357                    // // client filter
358                    let messages = match &self.0.filter_configuration {
359                        Some(filter_input) => {
360                            if let Some(f) = &filter_input.predicate {
361                                d.messages
362                                    .into_iter()
363                                    .filter(|message| f(message))
364                                    .collect::<Vec<Message>>()
365                            } else {
366                                d.messages
367                            }
368                        }
369
370                        None => d.messages,
371                    };
372
373                    for message in messages {
374                        if let OffsetSpecification::Offset(offset_) = self.0.offset_specification {
375                            if offset_ > offset {
376                                offset += 1;
377                                continue;
378                            }
379                        }
380                        let _ = self
381                            .0
382                            .sender
383                            .send(Ok(Delivery {
384                                name: self.0.name.clone(),
385                                stream: self.0.stream.clone(),
386                                subscription_id: self.0.subscription_id,
387                                message,
388                                offset,
389                            }))
390                            .await;
391                        offset += 1;
392                    }
393
394                    // TODO handle credit fail
395                    let _ = self.0.client.credit(self.0.subscription_id, 1).await;
396                    self.0.metrics_collector.consume(len as u64).await;
397                } else if let ResponseKind::ConsumerUpdate(consumer_update) = response.kind_ref() {
398                    trace!("Received a ConsumerUpdate message");
399                    // If no callback is provided by the user we will restart from Next by protocol
400                    // We need to respond to the server too
401                    if self.0.consumer_update_listener.is_none() {
402                        trace!("User defined callback is not provided");
403                        let offset_specification = OffsetSpecification::Next;
404                        let _ = self
405                            .0
406                            .client
407                            .consumer_update(
408                                consumer_update.get_correlation_id(),
409                                offset_specification,
410                            )
411                            .await;
412                    } else {
413                        // Otherwise the Offset specification is returned by the user callback
414                        let is_active = consumer_update.is_active();
415                        let message_context = MessageContext {
416                            name: self.0.name.clone().unwrap(),
417                            stream: self.0.stream.clone(),
418                            client: self.0.client.clone(),
419                        };
420                        let consumer_update_listener_callback =
421                            self.0.consumer_update_listener.clone().unwrap();
422                        let offset_specification =
423                            consumer_update_listener_callback(is_active, message_context).await;
424                        let _ = self
425                            .0
426                            .client
427                            .consumer_update(
428                                consumer_update.get_correlation_id(),
429                                offset_specification,
430                            )
431                            .await;
432                    }
433                }
434            }
435            Some(Err(err)) => {
436                let _ = self.0.sender.send(Err(err.into())).await;
437            }
438            None => {
439                trace!("Closing consumer");
440                self.0.closed.store(true, Relaxed);
441                self.0.waker.wake();
442            }
443        }
444        Ok(())
445    }
446}
447
448/// Envelope from incoming message
449#[derive(Debug)]
450pub struct Delivery {
451    name: Option<String>,
452    stream: String,
453    subscription_id: u8,
454    message: Message,
455    offset: u64,
456}
457
458impl Delivery {
459    /// Get a reference to the delivery's subscription id.
460    pub fn subscription_id(&self) -> u8 {
461        self.subscription_id
462    }
463
464    /// Get a reference to the delivery's stream name.
465    pub fn stream(&self) -> &String {
466        &self.stream
467    }
468
469    /// Get a reference to the delivery's message.
470    pub fn message(&self) -> &Message {
471        &self.message
472    }
473
474    /// Get a reference to the delivery's offset.
475    pub fn offset(&self) -> u64 {
476        self.offset
477    }
478
479    pub fn consumer_name(&self) -> Option<String> {
480        self.name.clone()
481    }
482}