rabbitmq_stream_client/
producer.rs

1use futures::executor::block_on;
2use std::future::Future;
3use std::time::Duration;
4use std::{
5    marker::PhantomData,
6    sync::{
7        atomic::{AtomicBool, AtomicU64, Ordering},
8        Arc,
9    },
10};
11
12use dashmap::DashMap;
13use futures::{future::BoxFuture, FutureExt};
14use tokio::sync::mpsc::channel;
15use tokio::sync::{mpsc, RwLock};
16use tokio::time::sleep;
17use tracing::{error, info, trace, warn};
18
19use rabbitmq_stream_protocol::{message::Message, ResponseCode, ResponseKind};
20
21use crate::client::ClientMessage;
22use crate::MetricsCollector;
23use crate::{client::MessageHandler, RabbitMQStreamResult};
24use crate::{
25    client::{Client, MessageResult},
26    environment::Environment,
27    error::{ClientError, ProducerCloseError, ProducerCreateError, ProducerPublishError},
28};
29
30type WaiterMap = Arc<DashMap<u64, (ClientMessage, ProducerMessageWaiter)>>;
31type FilterValueExtractor = Arc<dyn Fn(&Message) -> String + 'static + Send + Sync>;
32
33#[derive(Debug)]
34pub struct ConfirmationStatus {
35    publishing_id: u64,
36    confirmed: bool,
37    status: ResponseCode,
38    message: Message,
39}
40
41impl ConfirmationStatus {
42    /// Get a reference to the confirmation status's confirmed.
43    pub fn confirmed(&self) -> bool {
44        self.confirmed
45    }
46
47    /// Get a reference to the confirmation status's publishing id.
48    pub fn publishing_id(&self) -> u64 {
49        self.publishing_id
50    }
51
52    /// Get a reference to the confirmation status's status.
53    pub fn status(&self) -> &ResponseCode {
54        &self.status
55    }
56
57    /// Get a reference to the confirmation status's message.
58    pub fn message(&self) -> &Message {
59        &self.message
60    }
61}
62
63pub struct ProducerInternal {
64    client: Arc<Client>,
65    stream: String,
66    producer_id: u8,
67    publish_sequence: Arc<AtomicU64>,
68    waiting_confirmations: WaiterMap,
69    closed: Arc<AtomicBool>,
70    sender: mpsc::Sender<ClientMessage>,
71    filter_value_extractor: Option<FilterValueExtractor>,
72    on_closed: Arc<RwLock<Option<Box<dyn OnClosed + Send + Sync>>>>,
73}
74
75impl Drop for ProducerInternal {
76    fn drop(&mut self) {
77        block_on(async {
78            if let Err(e) = self.close().await {
79                error!(error = ?e, "Error closing producer");
80            }
81        });
82    }
83}
84
85impl ProducerInternal {
86    pub async fn close(&self) -> Result<(), ProducerCloseError> {
87        match self
88            .closed
89            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
90        {
91            Ok(false) => {
92                let response = self.client.delete_publisher(self.producer_id).await?;
93                if response.is_ok() {
94                    self.client.close().await?;
95                    Ok(())
96                } else {
97                    Err(ProducerCloseError::Close {
98                        status: response.code().clone(),
99                        stream: self.stream.clone(),
100                    })
101                }
102            }
103            _ => Ok(()), // Already closed
104        }
105    }
106}
107
108/// API for publising messages to RabbitMQ stream
109#[derive(Clone)]
110pub struct Producer<T>(Arc<ProducerInternal>, PhantomData<T>);
111
112/// Builder for [`Producer`]
113pub struct ProducerBuilder<T> {
114    pub(crate) environment: Environment,
115    pub(crate) name: Option<String>,
116    pub batch_size: usize,
117    pub(crate) data: PhantomData<T>,
118    pub filter_value_extractor: Option<FilterValueExtractor>,
119    pub(crate) client_provided_name: String,
120    pub(crate) on_closed: Option<Box<dyn OnClosed + Send + Sync>>,
121    pub(crate) overwrite_heartbeat: Option<u32>,
122}
123
/// Marker for producers built without a name.
#[derive(Clone)]
pub struct NoDedup {}

/// Marker for producers built with `ProducerBuilder::name`; the publishing sequence is
/// resumed from the broker when such a producer is built.
pub struct Dedup {}
128
129impl<T> ProducerBuilder<T> {
130    pub async fn build(self, stream: &str) -> Result<Producer<T>, ProducerCreateError> {
131        // Connect to the user specified node first, then look for the stream leader.
132        // The leader is the recommended node for writing, because writing to a replica will redundantly pass these messages
133        // to the leader anyway - it is the only one capable of writing.
134
135        let metrics_collector = self.environment.options.client_options.collector.clone();
136
137        let client = self
138            .environment
139            .create_producer_client(stream, self.client_provided_name.clone())
140            .await?;
141
142        if let Some(heartbeat) = self.overwrite_heartbeat {
143            client.set_heartbeat(heartbeat).await;
144        }
145
146        let mut publish_version = 1;
147
148        if self.filter_value_extractor.is_some() {
149            if client.filtering_supported() {
150                publish_version = 2
151            } else {
152                return Err(ProducerCreateError::FilteringNotSupport);
153            }
154        }
155
156        let on_closed = Arc::new(RwLock::new(self.on_closed));
157
158        let waiting_confirmations: WaiterMap = Arc::new(DashMap::new());
159
160        let confirm_handler = ProducerConfirmHandler {
161            waiting_confirmations: waiting_confirmations.clone(),
162            metrics_collector,
163            on_closed: on_closed.clone(),
164        };
165
166        client.set_handler(confirm_handler).await;
167
168        let producer_id = 1;
169        let response = client
170            .declare_publisher(producer_id, self.name.clone(), stream)
171            .await?;
172
173        let publish_sequence = if let Some(name) = self.name {
174            let sequence = client.query_publisher_sequence(&name, stream).await?;
175
176            let first_sequence = if sequence == 0 { 0 } else { sequence + 1 };
177
178            Arc::new(AtomicU64::new(first_sequence))
179        } else {
180            Arc::new(AtomicU64::new(0))
181        };
182
183        if response.is_ok() {
184            let (sender, receiver) = mpsc::channel(self.batch_size);
185
186            let client = Arc::new(client);
187            let producer = ProducerInternal {
188                producer_id,
189                stream: stream.to_string(),
190                client,
191                publish_sequence,
192                waiting_confirmations,
193                closed: Arc::new(AtomicBool::new(false)),
194                sender,
195                filter_value_extractor: self.filter_value_extractor,
196                on_closed,
197            };
198
199            let internal_producer = Arc::new(producer);
200            schedule_batch_send(
201                self.batch_size,
202                receiver,
203                internal_producer.client.clone(),
204                producer_id,
205                publish_version,
206            );
207            let producer = Producer(internal_producer, PhantomData);
208
209            Ok(producer)
210        } else {
211            Err(ProducerCreateError::Create {
212                stream: stream.to_owned(),
213                status: response.code().clone(),
214            })
215        }
216    }
217
218    pub fn on_closed(mut self, on_closed: Box<dyn OnClosed + Send + Sync>) -> ProducerBuilder<T> {
219        self.on_closed = Some(on_closed);
220        self
221    }
222
223    pub fn batch_size(mut self, batch_size: usize) -> Self {
224        self.batch_size = batch_size;
225        self
226    }
227
228    /// Don't use this in production, it is only for testing purposes.
229    pub fn overwrite_heartbeat(mut self, heartbeat: u32) -> ProducerBuilder<T> {
230        self.overwrite_heartbeat = Some(heartbeat);
231        self
232    }
233
234    pub fn client_provided_name(mut self, name: &str) -> Self {
235        self.client_provided_name = String::from(name);
236        self
237    }
238
239    pub fn name(mut self, name: &str) -> ProducerBuilder<Dedup> {
240        self.name = Some(name.to_owned());
241        ProducerBuilder {
242            environment: self.environment,
243            name: self.name,
244            batch_size: self.batch_size,
245            data: PhantomData,
246            filter_value_extractor: None,
247            client_provided_name: String::from("rust-stream-producer"),
248            on_closed: self.on_closed,
249            overwrite_heartbeat: None,
250        }
251    }
252
253    pub fn filter_value_extractor(
254        mut self,
255        filter_value_extractor: impl Fn(&Message) -> String + Send + Sync + 'static,
256    ) -> Self {
257        let f = Arc::new(filter_value_extractor);
258        self.filter_value_extractor = Some(f);
259        self
260    }
261
262    pub fn filter_value_extractor_arc(
263        mut self,
264        filter_value_extractor: Option<FilterValueExtractor>,
265    ) -> Self {
266        self.filter_value_extractor = filter_value_extractor;
267        self
268    }
269}
270
271fn schedule_batch_send(
272    batch_size: usize,
273    mut receiver: mpsc::Receiver<ClientMessage>,
274    client: Arc<Client>,
275    producer_id: u8,
276    publish_version: u16,
277) {
278    tokio::task::spawn(async move {
279        let mut buffer = Vec::with_capacity(batch_size);
280        loop {
281            let count = receiver.recv_many(&mut buffer, batch_size).await;
282
283            if count == 0 || buffer.is_empty() {
284                // Channel is closed, exit the loop
285                break;
286            }
287
288            let messages: Vec<_> = buffer.drain(..count).collect();
289            match client.publish(producer_id, messages, publish_version).await {
290                Ok(_) => {}
291                Err(e) => {
292                    error!("Error publishing batch {:?}", e);
293
294                    // If the underlying error is a broken pipe, we can assume the connection is closed
295                    // In fact, BorkenPipe is not recoverable, so we can exit the loop.
296                    // This will close the receiver, so, the next time a send is called, it will return an error.
297                    if matches!(e, ClientError::Io(e) if e.kind() == std::io::ErrorKind::BrokenPipe)
298                    {
299                        // If the error is a broken pipe, we can assume the connection is closed
300                        break;
301                    }
302                }
303            };
304        }
305
306        info!("Batch send task finished");
307    });
308}
309
310impl Producer<NoDedup> {
311    pub async fn send_with_confirm(
312        &self,
313        message: Message,
314    ) -> Result<ConfirmationStatus, ProducerPublishError> {
315        self.do_send_with_confirm(message).await
316    }
317    pub async fn batch_send_with_confirm(
318        &self,
319        messages: Vec<Message>,
320    ) -> Result<Vec<ConfirmationStatus>, ProducerPublishError> {
321        self.do_batch_send_with_confirm(messages).await
322    }
323    pub async fn batch_send<Fut>(
324        &self,
325        messages: Vec<Message>,
326        cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
327    ) -> Result<(), ProducerPublishError>
328    where
329        Fut: Future<Output = ()> + Send + Sync + 'static,
330    {
331        self.do_batch_send(messages, cb).await
332    }
333
334    pub async fn send<Fut>(
335        &self,
336        message: Message,
337        cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
338    ) -> Result<(), ProducerPublishError>
339    where
340        Fut: Future<Output = ()> + Send + Sync + 'static,
341    {
342        self.do_send(message, cb).await
343    }
344}
345
346impl Producer<Dedup> {
347    pub async fn send_with_confirm(
348        &mut self,
349        message: Message,
350    ) -> Result<ConfirmationStatus, ProducerPublishError> {
351        self.do_send_with_confirm(message).await
352    }
353    pub async fn batch_send_with_confirm(
354        &mut self,
355        messages: Vec<Message>,
356    ) -> Result<Vec<ConfirmationStatus>, ProducerPublishError> {
357        self.do_batch_send_with_confirm(messages).await
358    }
359    pub async fn batch_send<Fut>(
360        &mut self,
361        messages: Vec<Message>,
362        cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
363    ) -> Result<(), ProducerPublishError>
364    where
365        Fut: Future<Output = ()> + Send + Sync + 'static,
366    {
367        self.do_batch_send(messages, cb).await
368    }
369
370    pub async fn send<Fut>(
371        &mut self,
372        message: Message,
373        cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
374    ) -> Result<(), ProducerPublishError>
375    where
376        Fut: Future<Output = ()> + Send + Sync + 'static,
377    {
378        self.do_send(message, cb).await
379    }
380}
381
382impl<T> Producer<T> {
383    async fn do_send_with_confirm(
384        &self,
385        message: Message,
386    ) -> Result<ConfirmationStatus, ProducerPublishError> {
387        let (tx, mut rx) = channel(1);
388        self.internal_send(message, move |status| {
389            let cloned = tx.clone();
390            async move {
391                let _ = cloned.send(status).await;
392            }
393        })
394        .await?;
395
396        let r = tokio::select! {
397            val = rx.recv() => {
398                Ok(val)
399            }
400            _ = sleep(Duration::from_secs(1)) => {
401                Err(ProducerPublishError::Timeout)
402            }
403        }?;
404        r.ok_or_else(|| ProducerPublishError::Confirmation {
405            stream: self.0.stream.clone(),
406        })?
407        .map_err(|err| ClientError::GenericError(Box::new(err)))
408        .map(Ok)?
409    }
410
411    async fn do_batch_send_with_confirm(
412        &self,
413        messages: Vec<Message>,
414    ) -> Result<Vec<ConfirmationStatus>, ProducerPublishError> {
415        let messages_len = messages.len();
416        let (tx, mut rx) = channel(messages_len);
417
418        self.internal_batch_send(messages, move |status| {
419            let cloned = tx.clone();
420            async move {
421                let _ = cloned.send(status).await;
422            }
423        })
424        .await?;
425
426        let mut confirmations = Vec::with_capacity(messages_len);
427
428        while let Some(confirmation) = rx.recv().await {
429            confirmations.push(confirmation?);
430        }
431
432        Ok(confirmations)
433    }
434    async fn do_batch_send<Fut>(
435        &self,
436        messages: Vec<Message>,
437        cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
438    ) -> Result<(), ProducerPublishError>
439    where
440        Fut: Future<Output = ()> + Send + Sync + 'static,
441    {
442        self.internal_batch_send(messages, cb).await?;
443
444        Ok(())
445    }
446
447    async fn do_send<Fut>(
448        &self,
449        message: Message,
450        cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
451    ) -> Result<(), ProducerPublishError>
452    where
453        Fut: Future<Output = ()> + Send + Sync + 'static,
454    {
455        self.internal_send(message, cb).await?;
456        Ok(())
457    }
458
459    async fn internal_send<Fut>(
460        &self,
461        message: Message,
462        cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
463    ) -> Result<(), ProducerPublishError>
464    where
465        Fut: Future<Output = ()> + Send + Sync + 'static,
466    {
467        if self.is_closed() {
468            return Err(ProducerPublishError::Closed);
469        }
470        let publishing_id = match message.publishing_id() {
471            Some(publishing_id) => *publishing_id,
472            None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed),
473        };
474        let mut msg = ClientMessage::new(publishing_id, message.clone(), None);
475
476        if let Some(f) = self.0.filter_value_extractor.as_ref() {
477            msg.filter_value_extract(f.as_ref())
478        }
479
480        let waiter = OnceProducerMessageWaiter::waiter_with_cb(cb, message);
481        self.0.waiting_confirmations.insert(
482            publishing_id,
483            (msg.clone(), ProducerMessageWaiter::Once(waiter)),
484        );
485
486        if let Err(e) = self.0.sender.send(msg).await {
487            // `send` fails only when the receiver is closed, which means the TCP connection is broken.
488            // In this case, we forcefully close the producer and return an error.
489            // The current message will not be sent, but it is not lost:
490            // `on_closed` handler will be called, and it can resend the message if needed.
491            if let Err(err) = self.0.close().await {
492                error!(error = ?err, "Failed to close producer after send error");
493            }
494            return Err(ClientError::GenericError(Box::new(e)))?;
495        }
496
497        Ok(())
498    }
499
500    async fn internal_batch_send<Fut>(
501        &self,
502        messages: Vec<Message>,
503        cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
504    ) -> Result<(), ProducerPublishError>
505    where
506        Fut: Future<Output = ()> + Send + Sync + 'static,
507    {
508        if self.is_closed() {
509            return Err(ProducerPublishError::Closed);
510        }
511
512        let arc_cb = Arc::new(move |status| cb(status).boxed());
513
514        for message in messages {
515            let waiter =
516                SharedProducerMessageWaiter::waiter_with_arc_cb(arc_cb.clone(), message.clone());
517
518            let publishing_id = match message.publishing_id() {
519                Some(publishing_id) => *publishing_id,
520                None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed),
521            };
522
523            let mut client_message = ClientMessage::new(publishing_id, message, None);
524            if let Some(f) = self.0.filter_value_extractor.as_ref() {
525                client_message.filter_value_extract(f.as_ref())
526            }
527
528            self.0.waiting_confirmations.insert(
529                publishing_id,
530                (
531                    client_message.clone(),
532                    ProducerMessageWaiter::Shared(waiter.clone()),
533                ),
534            );
535
536            // Queue the message for sending
537            if let Err(e) = self.0.sender.send(client_message).await {
538                return Err(ClientError::GenericError(Box::new(e)))?;
539            }
540        }
541
542        Ok(())
543    }
544
545    pub fn is_closed(&self) -> bool {
546        self.0.closed.load(Ordering::Relaxed)
547    }
548
549    pub async fn close(self) -> Result<(), ProducerCloseError> {
550        self.0.close().await
551    }
552
553    pub async fn set_on_closed(&self, on_closed: Box<dyn OnClosed + Send + Sync>) {
554        let mut on_closed_lock = self.0.on_closed.write().await;
555        *on_closed_lock = Some(on_closed);
556    }
557}
558
559#[async_trait::async_trait]
560pub trait OnClosed {
561    async fn on_closed(&self, unconfirmed: Vec<Message>);
562}
563
564struct ProducerConfirmHandler {
565    waiting_confirmations: WaiterMap,
566    metrics_collector: Arc<dyn MetricsCollector>,
567    on_closed: Arc<RwLock<Option<Box<dyn OnClosed + Send + Sync>>>>,
568}
569
570#[async_trait::async_trait]
571impl MessageHandler for ProducerConfirmHandler {
572    async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
573        match item {
574            Some(Ok(response)) => {
575                match response.kind() {
576                    ResponseKind::PublishConfirm(confirm) => {
577                        trace!("Got publish_confirm for {:?}", confirm.publishing_ids);
578                        let confirm_len = confirm.publishing_ids.len();
579                        for publishing_id in &confirm.publishing_ids {
580                            let id = *publishing_id;
581
582                            let (_, waiter) = match self.waiting_confirmations.remove(publishing_id)
583                            {
584                                Some((_, confirm_sender)) => confirm_sender,
585                                None => todo!(),
586                            };
587                            match waiter {
588                                ProducerMessageWaiter::Once(waiter) => {
589                                    invoke_handler_once(
590                                        waiter.cb,
591                                        id,
592                                        true,
593                                        ResponseCode::Ok,
594                                        waiter.msg,
595                                    )
596                                    .await;
597                                }
598                                ProducerMessageWaiter::Shared(waiter) => {
599                                    invoke_handler(
600                                        waiter.cb,
601                                        id,
602                                        true,
603                                        ResponseCode::Ok,
604                                        waiter.msg,
605                                    )
606                                    .await;
607                                }
608                            }
609                        }
610                        self.metrics_collector
611                            .publish_confirm(confirm_len as u64)
612                            .await;
613                    }
614                    ResponseKind::PublishError(error) => {
615                        trace!("Got publish_error  {:?}", error);
616                        for err in &error.publishing_errors {
617                            let code = err.error_code.clone();
618                            let id = err.publishing_id;
619
620                            let (_, waiter) = match self.waiting_confirmations.remove(&id) {
621                                Some((_, confirm_sender)) => confirm_sender,
622                                None => todo!(),
623                            };
624                            match waiter {
625                                ProducerMessageWaiter::Once(waiter) => {
626                                    invoke_handler_once(waiter.cb, id, false, code, waiter.msg)
627                                        .await;
628                                }
629                                ProducerMessageWaiter::Shared(waiter) => {
630                                    invoke_handler(waiter.cb, id, false, code, waiter.msg).await;
631                                }
632                            }
633                        }
634                    }
635                    _ => {}
636                };
637            }
638            Some(Err(error)) => {
639                trace!(?error);
640                // TODO clean all waiting for confirm
641            }
642            None => {
643                info!("Connection closed");
644                let on_closed = self.on_closed.read().await;
645                if let Some(on_close) = &*on_closed {
646                    let mut unconfirmed: Vec<(u64, Message)> = self
647                        .waiting_confirmations
648                        .iter()
649                        .map(|entry| (*entry.key(), entry.value().0.clone().into_message()))
650                        .collect();
651                    unconfirmed.sort_by_key(|(id, _)| *id);
652
653                    let unconfirmed: Vec<Message> =
654                        unconfirmed.into_iter().map(|(_, msg)| msg).collect();
655
656                    on_close.on_closed(unconfirmed).await;
657                } else {
658                    warn!("No on_closed handler set, unconfirmed messages will be lost.");
659                }
660            }
661        }
662        Ok(())
663    }
664}
665
666async fn invoke_handler(
667    f: ArcConfirmCallback,
668    publishing_id: u64,
669    confirmed: bool,
670    status: ResponseCode,
671    message: Message,
672) {
673    f(Ok(ConfirmationStatus {
674        publishing_id,
675        confirmed,
676        status,
677        message,
678    }))
679    .await;
680}
681async fn invoke_handler_once(
682    f: ConfirmCallback,
683    publishing_id: u64,
684    confirmed: bool,
685    status: ResponseCode,
686    message: Message,
687) {
688    f(Ok(ConfirmationStatus {
689        publishing_id,
690        confirmed,
691        status,
692        message,
693    }))
694    .await;
695}
696
697type ConfirmCallback = Box<
698    dyn FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> BoxFuture<'static, ()>
699        + Send
700        + Sync,
701>;
702
703type ArcConfirmCallback = Arc<
704    dyn Fn(Result<ConfirmationStatus, ProducerPublishError>) -> BoxFuture<'static, ()>
705        + Send
706        + Sync,
707>;
708
709enum ProducerMessageWaiter {
710    Once(OnceProducerMessageWaiter),
711    Shared(SharedProducerMessageWaiter),
712}
713
714struct OnceProducerMessageWaiter {
715    cb: ConfirmCallback,
716    msg: Message,
717}
718impl OnceProducerMessageWaiter {
719    fn waiter_with_cb<Fut>(
720        cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
721        msg: Message,
722    ) -> Self
723    where
724        Fut: Future<Output = ()> + Send + Sync + 'static,
725    {
726        Self {
727            cb: Box::new(move |confirm_status| cb(confirm_status).boxed()),
728            msg,
729        }
730    }
731}
732
733#[derive(Clone)]
734struct SharedProducerMessageWaiter {
735    cb: ArcConfirmCallback,
736    msg: Message,
737}
738
739impl SharedProducerMessageWaiter {
740    fn waiter_with_arc_cb(confirm_callback: ArcConfirmCallback, msg: Message) -> Self {
741        Self {
742            cb: confirm_callback,
743            msg,
744        }
745    }
746}