1use std::{
2 collections::HashMap,
3 pin::Pin,
4 sync::{
5 atomic::{
6 AtomicBool,
7 Ordering::{Relaxed, SeqCst},
8 },
9 Arc,
10 },
11 task::{Context, Poll},
12};
13
14use rabbitmq_stream_protocol::{
15 commands::subscribe::OffsetSpecification, message::Message, ResponseKind,
16};
17
18use core::option::Option::None;
19use futures::FutureExt;
20use std::future::Future;
21use tokio::sync::mpsc::{channel, Receiver, Sender};
22use tracing::trace;
23
24use crate::error::ConsumerStoreOffsetError;
25
26use crate::{
27 client::{MessageHandler, MessageResult},
28 error::{ConsumerCloseError, ConsumerCreateError, ConsumerDeliveryError},
29 Client, Environment, MetricsCollector,
30};
31use futures::{future::BoxFuture, task::AtomicWaker, Stream};
32
33type FilterPredicate = Option<Arc<dyn Fn(&Message) -> bool + Send + Sync>>;
34
35pub type ConsumerUpdateListener =
36 Arc<dyn Fn(u8, MessageContext) -> BoxFuture<'static, OffsetSpecification> + Send + Sync>;
37
38pub struct Consumer {
40 name: Option<String>,
42 receiver: Receiver<Result<Delivery, ConsumerDeliveryError>>,
43 internal: Arc<ConsumerInternal>,
44}
45
46struct ConsumerInternal {
47 name: Option<String>,
48 client: Client,
49 stream: String,
50 offset_specification: OffsetSpecification,
51 subscription_id: u8,
52 sender: Sender<Result<Delivery, ConsumerDeliveryError>>,
53 closed: Arc<AtomicBool>,
54 waker: AtomicWaker,
55 metrics_collector: Arc<dyn MetricsCollector>,
56 filter_configuration: Option<FilterConfiguration>,
57 consumer_update_listener: Option<ConsumerUpdateListener>,
58}
59
60impl ConsumerInternal {
61 fn is_closed(&self) -> bool {
62 self.closed.load(Relaxed)
63 }
64}
65
66#[derive(Clone)]
67pub struct FilterConfiguration {
68 filter_values: Vec<String>,
69 pub predicate: FilterPredicate,
70 match_unfiltered: bool,
71}
72
73impl FilterConfiguration {
74 pub fn new(filter_values: Vec<String>, match_unfiltered: bool) -> Self {
75 Self {
76 filter_values,
77 match_unfiltered,
78 predicate: None,
79 }
80 }
81
82 pub fn post_filter(
83 mut self,
84 predicate: impl Fn(&Message) -> bool + 'static + Send + Sync,
85 ) -> FilterConfiguration {
86 self.predicate = Some(Arc::new(predicate));
87 self
88 }
89}
90
91#[derive(Clone)]
92pub struct MessageContext {
93 name: String,
94 stream: String,
95 client: Client,
96}
97
98impl MessageContext {
99 pub fn name(&self) -> String {
100 self.name.clone()
101 }
102
103 pub fn stream(&self) -> String {
104 self.stream.clone()
105 }
106
107 pub fn client(&self) -> Client {
108 self.client.clone()
109 }
110}
111
112pub struct ConsumerBuilder {
114 pub(crate) consumer_name: Option<String>,
115 pub(crate) environment: Environment,
116 pub(crate) offset_specification: OffsetSpecification,
117 pub(crate) filter_configuration: Option<FilterConfiguration>,
118 pub(crate) consumer_update_listener: Option<ConsumerUpdateListener>,
119 pub(crate) client_provided_name: String,
120 pub(crate) properties: HashMap<String, String>,
121 pub(crate) is_single_active_consumer: bool,
122}
123
124impl ConsumerBuilder {
125 pub async fn build(mut self, stream: &str) -> Result<Consumer, ConsumerCreateError> {
126 if (self.is_single_active_consumer
127 || self.properties.contains_key("single-active-consumer"))
128 && self.consumer_name.is_none()
129 {
130 return Err(ConsumerCreateError::SingleActiveConsumerNotSupported);
131 }
132
133 let collector = self.environment.options.client_options.collector.clone();
134
135 let client = self
136 .environment
137 .create_consumer_client(stream, self.client_provided_name.clone())
138 .await?;
139
140 let subscription_id = 1;
141 let (tx, rx) = channel(10000);
142 let consumer = Arc::new(ConsumerInternal {
143 name: self.consumer_name.clone(),
144 subscription_id,
145 stream: stream.to_string(),
146 client: client.clone(),
147 offset_specification: self.offset_specification.clone(),
148 sender: tx,
149 closed: Arc::new(AtomicBool::new(false)),
150 waker: AtomicWaker::new(),
151 metrics_collector: collector,
152 filter_configuration: self.filter_configuration.clone(),
153 consumer_update_listener: self.consumer_update_listener.clone(),
154 });
155 let msg_handler = ConsumerMessageHandler(consumer.clone());
156 client.set_handler(msg_handler).await;
157
158 if let Some(filter_input) = self.filter_configuration {
159 if !client.filtering_supported() {
160 return Err(ConsumerCreateError::FilteringNotSupport);
161 }
162 for (index, item) in filter_input.filter_values.iter().enumerate() {
163 let key = format!("filter.{}", index);
164 self.properties.insert(key, item.to_owned());
165 }
166
167 let match_unfiltered_key = "match-unfiltered".to_string();
168 self.properties.insert(
169 match_unfiltered_key,
170 filter_input.match_unfiltered.to_string(),
171 );
172 }
173
174 if self.is_single_active_consumer {
175 self.properties
176 .insert("single-active-consumer".to_string(), "true".to_string());
177 self.properties
178 .insert("name".to_string(), self.consumer_name.clone().unwrap());
179 }
180
181 let response = client
182 .subscribe(
183 subscription_id,
184 stream,
185 self.offset_specification,
186 1,
187 self.properties.clone(),
188 )
189 .await?;
190
191 if response.is_ok() {
192 Ok(Consumer {
193 name: self.consumer_name.clone(),
194 receiver: rx,
195 internal: consumer,
196 })
197 } else {
198 Err(ConsumerCreateError::Create {
199 stream: stream.to_owned(),
200 status: response.code().clone(),
201 })
202 }
203 }
204
205 pub fn offset(mut self, offset_specification: OffsetSpecification) -> Self {
206 self.offset_specification = offset_specification;
207 self
208 }
209
210 pub fn client_provided_name(mut self, name: &str) -> Self {
211 self.client_provided_name = String::from(name);
212 self
213 }
214
215 pub fn name(mut self, consumer_name: &str) -> Self {
216 self.consumer_name = Some(String::from(consumer_name));
217 self
218 }
219
220 pub fn name_optional(mut self, consumer_name: Option<String>) -> Self {
221 self.consumer_name = consumer_name;
222 self
223 }
224
225 pub fn enable_single_active_consumer(mut self, is_single_active_consumer: bool) -> Self {
226 self.is_single_active_consumer = is_single_active_consumer;
227 self
228 }
229
230 pub fn filter_input(mut self, filter_configuration: Option<FilterConfiguration>) -> Self {
231 self.filter_configuration = filter_configuration;
232 self
233 }
234
235 pub fn consumer_update<Fut>(
236 mut self,
237 consumer_update_listener: impl Fn(u8, MessageContext) -> Fut + Send + Sync + 'static,
238 ) -> Self
239 where
240 Fut: Future<Output = OffsetSpecification> + Send + Sync + 'static,
241 {
242 let f = Arc::new(move |a, b| consumer_update_listener(a, b).boxed());
243 self.consumer_update_listener = Some(f);
244 self
245 }
246
247 pub fn consumer_update_arc(
248 mut self,
249 consumer_update_listener: Option<crate::consumer::ConsumerUpdateListener>,
250 ) -> Self {
251 self.consumer_update_listener = consumer_update_listener;
252 self
253 }
254
255 pub fn properties(mut self, properties: HashMap<String, String>) -> Self {
256 self.properties = properties;
257 self
258 }
259}
260
261impl Consumer {
262 pub fn handle(&self) -> ConsumerHandle {
264 ConsumerHandle(self.internal.clone())
265 }
266
267 pub fn is_closed(&self) -> bool {
269 self.internal.is_closed()
270 }
271
272 pub async fn store_offset(&self, offset: u64) -> Result<(), ConsumerStoreOffsetError> {
273 if let Some(name) = &self.name {
274 self.internal
275 .client
276 .store_offset(name.as_str(), self.internal.stream.as_str(), offset)
277 .await
278 .map(Ok)?
279 } else {
280 Err(ConsumerStoreOffsetError::NameMissing)
281 }
282 }
283
284 pub async fn query_offset(&self) -> Result<u64, ConsumerStoreOffsetError> {
285 if let Some(name) = &self.name {
286 self.internal
287 .client
288 .query_offset(name.clone(), self.internal.stream.as_str())
289 .await
290 .map(Ok)?
291 } else {
292 Err(ConsumerStoreOffsetError::NameMissing)
293 }
294 }
295}
296
297impl Stream for Consumer {
298 type Item = Result<Delivery, ConsumerDeliveryError>;
299
300 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
301 self.internal.waker.register(cx.waker());
302 let poll = Pin::new(&mut self.receiver).poll_recv(cx);
303 match (self.is_closed(), poll.is_ready()) {
304 (true, false) => Poll::Ready(None),
305 _ => poll,
306 }
307 }
308}
309
310pub struct ConsumerHandle(Arc<ConsumerInternal>);
312
313impl ConsumerHandle {
314 pub async fn close(self) -> Result<(), ConsumerCloseError> {
316 self.internal_close().await
317 }
318
319 pub(crate) async fn internal_close(&self) -> Result<(), ConsumerCloseError> {
320 match self.0.closed.compare_exchange(false, true, SeqCst, SeqCst) {
321 Ok(false) => {
322 let response = self.0.client.unsubscribe(self.0.subscription_id).await?;
323 if response.is_ok() {
324 self.0.waker.wake();
325 self.0.client.close().await?;
326 Ok(())
327 } else {
328 Err(ConsumerCloseError::Close {
329 stream: self.0.stream.clone(),
330 status: response.code().clone(),
331 })
332 }
333 }
334 _ => Err(ConsumerCloseError::AlreadyClosed),
335 }
336 }
337 pub async fn is_closed(&self) -> bool {
339 self.0.is_closed()
340 }
341}
342
343struct ConsumerMessageHandler(Arc<ConsumerInternal>);
344
345#[async_trait::async_trait]
346impl MessageHandler for ConsumerMessageHandler {
347 async fn handle_message(&self, item: MessageResult) -> crate::RabbitMQStreamResult<()> {
348 match item {
349 Some(Ok(response)) => {
350 if let ResponseKind::Deliver(delivery) = response.kind_ref() {
351 let mut offset = delivery.chunk_first_offset;
352
353 let len = delivery.messages.len();
354 let d = delivery.clone();
355 trace!("Got delivery with messages {}", len);
356
357 let messages = match &self.0.filter_configuration {
359 Some(filter_input) => {
360 if let Some(f) = &filter_input.predicate {
361 d.messages
362 .into_iter()
363 .filter(|message| f(message))
364 .collect::<Vec<Message>>()
365 } else {
366 d.messages
367 }
368 }
369
370 None => d.messages,
371 };
372
373 for message in messages {
374 if let OffsetSpecification::Offset(offset_) = self.0.offset_specification {
375 if offset_ > offset {
376 offset += 1;
377 continue;
378 }
379 }
380 let _ = self
381 .0
382 .sender
383 .send(Ok(Delivery {
384 name: self.0.name.clone(),
385 stream: self.0.stream.clone(),
386 subscription_id: self.0.subscription_id,
387 message,
388 offset,
389 }))
390 .await;
391 offset += 1;
392 }
393
394 let _ = self.0.client.credit(self.0.subscription_id, 1).await;
396 self.0.metrics_collector.consume(len as u64).await;
397 } else if let ResponseKind::ConsumerUpdate(consumer_update) = response.kind_ref() {
398 trace!("Received a ConsumerUpdate message");
399 if self.0.consumer_update_listener.is_none() {
402 trace!("User defined callback is not provided");
403 let offset_specification = OffsetSpecification::Next;
404 let _ = self
405 .0
406 .client
407 .consumer_update(
408 consumer_update.get_correlation_id(),
409 offset_specification,
410 )
411 .await;
412 } else {
413 let is_active = consumer_update.is_active();
415 let message_context = MessageContext {
416 name: self.0.name.clone().unwrap(),
417 stream: self.0.stream.clone(),
418 client: self.0.client.clone(),
419 };
420 let consumer_update_listener_callback =
421 self.0.consumer_update_listener.clone().unwrap();
422 let offset_specification =
423 consumer_update_listener_callback(is_active, message_context).await;
424 let _ = self
425 .0
426 .client
427 .consumer_update(
428 consumer_update.get_correlation_id(),
429 offset_specification,
430 )
431 .await;
432 }
433 }
434 }
435 Some(Err(err)) => {
436 let _ = self.0.sender.send(Err(err.into())).await;
437 }
438 None => {
439 trace!("Closing consumer");
440 self.0.closed.store(true, Relaxed);
441 self.0.waker.wake();
442 }
443 }
444 Ok(())
445 }
446}
447
448#[derive(Debug)]
450pub struct Delivery {
451 name: Option<String>,
452 stream: String,
453 subscription_id: u8,
454 message: Message,
455 offset: u64,
456}
457
458impl Delivery {
459 pub fn subscription_id(&self) -> u8 {
461 self.subscription_id
462 }
463
464 pub fn stream(&self) -> &String {
466 &self.stream
467 }
468
469 pub fn message(&self) -> &Message {
471 &self.message
472 }
473
474 pub fn offset(&self) -> u64 {
476 self.offset
477 }
478
479 pub fn consumer_name(&self) -> Option<String> {
480 self.name.clone()
481 }
482}