1use futures::executor::block_on;
2use std::future::Future;
3use std::time::Duration;
4use std::{
5 marker::PhantomData,
6 sync::{
7 atomic::{AtomicBool, AtomicU64, Ordering},
8 Arc,
9 },
10};
11
12use dashmap::DashMap;
13use futures::{future::BoxFuture, FutureExt};
14use tokio::sync::mpsc::channel;
15use tokio::sync::{mpsc, RwLock};
16use tokio::time::sleep;
17use tracing::{error, info, trace, warn};
18
19use rabbitmq_stream_protocol::{message::Message, ResponseCode, ResponseKind};
20
21use crate::client::ClientMessage;
22use crate::MetricsCollector;
23use crate::{client::MessageHandler, RabbitMQStreamResult};
24use crate::{
25 client::{Client, MessageResult},
26 environment::Environment,
27 error::{ClientError, ProducerCloseError, ProducerCreateError, ProducerPublishError},
28};
29
/// Shared concurrent map of publishes that have not been answered yet, keyed by
/// publishing id. Each entry keeps the `ClientMessage` that was queued for sending
/// together with the waiter whose callback is invoked once the broker answers.
type WaiterMap = Arc<DashMap<u64, (ClientMessage, ProducerMessageWaiter)>>;
/// Shared, thread-safe closure that derives a `String` from a message. When a
/// producer has one configured it is handed to `ClientMessage::filter_value_extract`
/// for every outgoing message.
type FilterValueExtractor = Arc<dyn Fn(&Message) -> String + 'static + Send + Sync>;
32
/// Outcome of a single publish, as delivered to confirmation callbacks.
#[derive(Debug)]
pub struct ConfirmationStatus {
    /// Publishing id the broker response refers to.
    publishing_id: u64,
    /// `true` when the id arrived in a publish-confirm, `false` when it arrived in a
    /// publish-error.
    confirmed: bool,
    /// `ResponseCode::Ok` for confirms, otherwise the error code reported for the id.
    status: ResponseCode,
    /// The message that was handed to the producer for this publishing id.
    message: Message,
}
40
41impl ConfirmationStatus {
42 pub fn confirmed(&self) -> bool {
44 self.confirmed
45 }
46
47 pub fn publishing_id(&self) -> u64 {
49 self.publishing_id
50 }
51
52 pub fn status(&self) -> &ResponseCode {
54 &self.status
55 }
56
57 pub fn message(&self) -> &Message {
59 &self.message
60 }
61}
62
/// Shared state behind every `Producer` handle.
pub struct ProducerInternal {
    /// Connection used to publish and to delete the publisher on close.
    client: Arc<Client>,
    /// Name of the stream this producer was built for.
    stream: String,
    /// Publisher id passed to `declare_publisher` / `delete_publisher`.
    producer_id: u8,
    /// Source of publishing ids for messages that do not carry their own.
    publish_sequence: Arc<AtomicU64>,
    /// Publishes that are still waiting for a confirm or an error from the broker.
    waiting_confirmations: WaiterMap,
    /// Set once `close` has been entered, so the close sequence runs at most once.
    closed: Arc<AtomicBool>,
    /// Queue feeding the background batch-send task.
    sender: mpsc::Sender<ClientMessage>,
    /// Optional extractor applied to every outgoing message.
    filter_value_extractor: Option<FilterValueExtractor>,
    /// Optional handler invoked by the confirm handler when the connection closes.
    on_closed: Arc<RwLock<Option<Box<dyn OnClosed + Send + Sync>>>>,
}
74
75impl Drop for ProducerInternal {
76 fn drop(&mut self) {
77 block_on(async {
78 if let Err(e) = self.close().await {
79 error!(error = ?e, "Error closing producer");
80 }
81 });
82 }
83}
84
85impl ProducerInternal {
86 pub async fn close(&self) -> Result<(), ProducerCloseError> {
87 match self
88 .closed
89 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
90 {
91 Ok(false) => {
92 let response = self.client.delete_publisher(self.producer_id).await?;
93 if response.is_ok() {
94 self.client.close().await?;
95 Ok(())
96 } else {
97 Err(ProducerCloseError::Close {
98 status: response.code().clone(),
99 stream: self.stream.clone(),
100 })
101 }
102 }
103 _ => Ok(()), }
105 }
106}
107
/// Handle to a producer. It wraps the shared `ProducerInternal` in an `Arc`; the
/// type parameter is a marker (`NoDedup` or `Dedup`) that selects which public send
/// API is available. The derived `Clone` only applies when `T: Clone`.
#[derive(Clone)]
pub struct Producer<T>(Arc<ProducerInternal>, PhantomData<T>);
111
/// Collects the settings needed to create a `Producer<T>`; consumed by `build`.
pub struct ProducerBuilder<T> {
    /// Environment used to open the producer's client connection.
    pub(crate) environment: Environment,
    /// Optional producer name, passed to `declare_publisher` and used to look up the
    /// last publishing sequence.
    pub(crate) name: Option<String>,
    /// Capacity of the internal send queue and upper bound on messages per publish call.
    pub batch_size: usize,
    /// Marker carrying the dedup mode (`NoDedup` / `Dedup`).
    pub(crate) data: PhantomData<T>,
    /// Optional extractor applied to every outgoing message.
    pub filter_value_extractor: Option<FilterValueExtractor>,
    /// Name passed along when the producer client connection is created.
    pub(crate) client_provided_name: String,
    /// Optional handler invoked when the connection closes.
    pub(crate) on_closed: Option<Box<dyn OnClosed + Send + Sync>>,
    /// Optional heartbeat value applied to the client right after it is created.
    pub(crate) overwrite_heartbeat: Option<u32>,
}
123
/// Marker type for producers built without a name. `Producer<NoDedup>` exposes its
/// send methods on `&self`, and because this marker is `Clone` the handle is too.
#[derive(Clone)]
pub struct NoDedup {}
126
/// Marker type for producers built with a name (see `ProducerBuilder::name`).
/// `Producer<Dedup>` exposes its send methods on `&mut self`; this marker is not
/// `Clone`, so the derived `Clone` of `Producer<T>` does not apply to it.
pub struct Dedup {}
128
129impl<T> ProducerBuilder<T> {
130 pub async fn build(self, stream: &str) -> Result<Producer<T>, ProducerCreateError> {
131 let metrics_collector = self.environment.options.client_options.collector.clone();
136
137 let client = self
138 .environment
139 .create_producer_client(stream, self.client_provided_name.clone())
140 .await?;
141
142 if let Some(heartbeat) = self.overwrite_heartbeat {
143 client.set_heartbeat(heartbeat).await;
144 }
145
146 let mut publish_version = 1;
147
148 if self.filter_value_extractor.is_some() {
149 if client.filtering_supported() {
150 publish_version = 2
151 } else {
152 return Err(ProducerCreateError::FilteringNotSupport);
153 }
154 }
155
156 let on_closed = Arc::new(RwLock::new(self.on_closed));
157
158 let waiting_confirmations: WaiterMap = Arc::new(DashMap::new());
159
160 let confirm_handler = ProducerConfirmHandler {
161 waiting_confirmations: waiting_confirmations.clone(),
162 metrics_collector,
163 on_closed: on_closed.clone(),
164 };
165
166 client.set_handler(confirm_handler).await;
167
168 let producer_id = 1;
169 let response = client
170 .declare_publisher(producer_id, self.name.clone(), stream)
171 .await?;
172
173 let publish_sequence = if let Some(name) = self.name {
174 let sequence = client.query_publisher_sequence(&name, stream).await?;
175
176 let first_sequence = if sequence == 0 { 0 } else { sequence + 1 };
177
178 Arc::new(AtomicU64::new(first_sequence))
179 } else {
180 Arc::new(AtomicU64::new(0))
181 };
182
183 if response.is_ok() {
184 let (sender, receiver) = mpsc::channel(self.batch_size);
185
186 let client = Arc::new(client);
187 let producer = ProducerInternal {
188 producer_id,
189 stream: stream.to_string(),
190 client,
191 publish_sequence,
192 waiting_confirmations,
193 closed: Arc::new(AtomicBool::new(false)),
194 sender,
195 filter_value_extractor: self.filter_value_extractor,
196 on_closed,
197 };
198
199 let internal_producer = Arc::new(producer);
200 schedule_batch_send(
201 self.batch_size,
202 receiver,
203 internal_producer.client.clone(),
204 producer_id,
205 publish_version,
206 );
207 let producer = Producer(internal_producer, PhantomData);
208
209 Ok(producer)
210 } else {
211 Err(ProducerCreateError::Create {
212 stream: stream.to_owned(),
213 status: response.code().clone(),
214 })
215 }
216 }
217
218 pub fn on_closed(mut self, on_closed: Box<dyn OnClosed + Send + Sync>) -> ProducerBuilder<T> {
219 self.on_closed = Some(on_closed);
220 self
221 }
222
223 pub fn batch_size(mut self, batch_size: usize) -> Self {
224 self.batch_size = batch_size;
225 self
226 }
227
228 pub fn overwrite_heartbeat(mut self, heartbeat: u32) -> ProducerBuilder<T> {
230 self.overwrite_heartbeat = Some(heartbeat);
231 self
232 }
233
234 pub fn client_provided_name(mut self, name: &str) -> Self {
235 self.client_provided_name = String::from(name);
236 self
237 }
238
239 pub fn name(mut self, name: &str) -> ProducerBuilder<Dedup> {
240 self.name = Some(name.to_owned());
241 ProducerBuilder {
242 environment: self.environment,
243 name: self.name,
244 batch_size: self.batch_size,
245 data: PhantomData,
246 filter_value_extractor: None,
247 client_provided_name: String::from("rust-stream-producer"),
248 on_closed: self.on_closed,
249 overwrite_heartbeat: None,
250 }
251 }
252
253 pub fn filter_value_extractor(
254 mut self,
255 filter_value_extractor: impl Fn(&Message) -> String + Send + Sync + 'static,
256 ) -> Self {
257 let f = Arc::new(filter_value_extractor);
258 self.filter_value_extractor = Some(f);
259 self
260 }
261
262 pub fn filter_value_extractor_arc(
263 mut self,
264 filter_value_extractor: Option<FilterValueExtractor>,
265 ) -> Self {
266 self.filter_value_extractor = filter_value_extractor;
267 self
268 }
269}
270
271fn schedule_batch_send(
272 batch_size: usize,
273 mut receiver: mpsc::Receiver<ClientMessage>,
274 client: Arc<Client>,
275 producer_id: u8,
276 publish_version: u16,
277) {
278 tokio::task::spawn(async move {
279 let mut buffer = Vec::with_capacity(batch_size);
280 loop {
281 let count = receiver.recv_many(&mut buffer, batch_size).await;
282
283 if count == 0 || buffer.is_empty() {
284 break;
286 }
287
288 let messages: Vec<_> = buffer.drain(..count).collect();
289 match client.publish(producer_id, messages, publish_version).await {
290 Ok(_) => {}
291 Err(e) => {
292 error!("Error publishing batch {:?}", e);
293
294 if matches!(e, ClientError::Io(e) if e.kind() == std::io::ErrorKind::BrokenPipe)
298 {
299 break;
301 }
302 }
303 };
304 }
305
306 info!("Batch send task finished");
307 });
308}
309
310impl Producer<NoDedup> {
311 pub async fn send_with_confirm(
312 &self,
313 message: Message,
314 ) -> Result<ConfirmationStatus, ProducerPublishError> {
315 self.do_send_with_confirm(message).await
316 }
317 pub async fn batch_send_with_confirm(
318 &self,
319 messages: Vec<Message>,
320 ) -> Result<Vec<ConfirmationStatus>, ProducerPublishError> {
321 self.do_batch_send_with_confirm(messages).await
322 }
323 pub async fn batch_send<Fut>(
324 &self,
325 messages: Vec<Message>,
326 cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
327 ) -> Result<(), ProducerPublishError>
328 where
329 Fut: Future<Output = ()> + Send + Sync + 'static,
330 {
331 self.do_batch_send(messages, cb).await
332 }
333
334 pub async fn send<Fut>(
335 &self,
336 message: Message,
337 cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
338 ) -> Result<(), ProducerPublishError>
339 where
340 Fut: Future<Output = ()> + Send + Sync + 'static,
341 {
342 self.do_send(message, cb).await
343 }
344}
345
346impl Producer<Dedup> {
347 pub async fn send_with_confirm(
348 &mut self,
349 message: Message,
350 ) -> Result<ConfirmationStatus, ProducerPublishError> {
351 self.do_send_with_confirm(message).await
352 }
353 pub async fn batch_send_with_confirm(
354 &mut self,
355 messages: Vec<Message>,
356 ) -> Result<Vec<ConfirmationStatus>, ProducerPublishError> {
357 self.do_batch_send_with_confirm(messages).await
358 }
359 pub async fn batch_send<Fut>(
360 &mut self,
361 messages: Vec<Message>,
362 cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
363 ) -> Result<(), ProducerPublishError>
364 where
365 Fut: Future<Output = ()> + Send + Sync + 'static,
366 {
367 self.do_batch_send(messages, cb).await
368 }
369
370 pub async fn send<Fut>(
371 &mut self,
372 message: Message,
373 cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
374 ) -> Result<(), ProducerPublishError>
375 where
376 Fut: Future<Output = ()> + Send + Sync + 'static,
377 {
378 self.do_send(message, cb).await
379 }
380}
381
382impl<T> Producer<T> {
383 async fn do_send_with_confirm(
384 &self,
385 message: Message,
386 ) -> Result<ConfirmationStatus, ProducerPublishError> {
387 let (tx, mut rx) = channel(1);
388 self.internal_send(message, move |status| {
389 let cloned = tx.clone();
390 async move {
391 let _ = cloned.send(status).await;
392 }
393 })
394 .await?;
395
396 let r = tokio::select! {
397 val = rx.recv() => {
398 Ok(val)
399 }
400 _ = sleep(Duration::from_secs(1)) => {
401 Err(ProducerPublishError::Timeout)
402 }
403 }?;
404 r.ok_or_else(|| ProducerPublishError::Confirmation {
405 stream: self.0.stream.clone(),
406 })?
407 .map_err(|err| ClientError::GenericError(Box::new(err)))
408 .map(Ok)?
409 }
410
411 async fn do_batch_send_with_confirm(
412 &self,
413 messages: Vec<Message>,
414 ) -> Result<Vec<ConfirmationStatus>, ProducerPublishError> {
415 let messages_len = messages.len();
416 let (tx, mut rx) = channel(messages_len);
417
418 self.internal_batch_send(messages, move |status| {
419 let cloned = tx.clone();
420 async move {
421 let _ = cloned.send(status).await;
422 }
423 })
424 .await?;
425
426 let mut confirmations = Vec::with_capacity(messages_len);
427
428 while let Some(confirmation) = rx.recv().await {
429 confirmations.push(confirmation?);
430 }
431
432 Ok(confirmations)
433 }
434 async fn do_batch_send<Fut>(
435 &self,
436 messages: Vec<Message>,
437 cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
438 ) -> Result<(), ProducerPublishError>
439 where
440 Fut: Future<Output = ()> + Send + Sync + 'static,
441 {
442 self.internal_batch_send(messages, cb).await?;
443
444 Ok(())
445 }
446
447 async fn do_send<Fut>(
448 &self,
449 message: Message,
450 cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
451 ) -> Result<(), ProducerPublishError>
452 where
453 Fut: Future<Output = ()> + Send + Sync + 'static,
454 {
455 self.internal_send(message, cb).await?;
456 Ok(())
457 }
458
459 async fn internal_send<Fut>(
460 &self,
461 message: Message,
462 cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
463 ) -> Result<(), ProducerPublishError>
464 where
465 Fut: Future<Output = ()> + Send + Sync + 'static,
466 {
467 if self.is_closed() {
468 return Err(ProducerPublishError::Closed);
469 }
470 let publishing_id = match message.publishing_id() {
471 Some(publishing_id) => *publishing_id,
472 None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed),
473 };
474 let mut msg = ClientMessage::new(publishing_id, message.clone(), None);
475
476 if let Some(f) = self.0.filter_value_extractor.as_ref() {
477 msg.filter_value_extract(f.as_ref())
478 }
479
480 let waiter = OnceProducerMessageWaiter::waiter_with_cb(cb, message);
481 self.0.waiting_confirmations.insert(
482 publishing_id,
483 (msg.clone(), ProducerMessageWaiter::Once(waiter)),
484 );
485
486 if let Err(e) = self.0.sender.send(msg).await {
487 if let Err(err) = self.0.close().await {
492 error!(error = ?err, "Failed to close producer after send error");
493 }
494 return Err(ClientError::GenericError(Box::new(e)))?;
495 }
496
497 Ok(())
498 }
499
500 async fn internal_batch_send<Fut>(
501 &self,
502 messages: Vec<Message>,
503 cb: impl Fn(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
504 ) -> Result<(), ProducerPublishError>
505 where
506 Fut: Future<Output = ()> + Send + Sync + 'static,
507 {
508 if self.is_closed() {
509 return Err(ProducerPublishError::Closed);
510 }
511
512 let arc_cb = Arc::new(move |status| cb(status).boxed());
513
514 for message in messages {
515 let waiter =
516 SharedProducerMessageWaiter::waiter_with_arc_cb(arc_cb.clone(), message.clone());
517
518 let publishing_id = match message.publishing_id() {
519 Some(publishing_id) => *publishing_id,
520 None => self.0.publish_sequence.fetch_add(1, Ordering::Relaxed),
521 };
522
523 let mut client_message = ClientMessage::new(publishing_id, message, None);
524 if let Some(f) = self.0.filter_value_extractor.as_ref() {
525 client_message.filter_value_extract(f.as_ref())
526 }
527
528 self.0.waiting_confirmations.insert(
529 publishing_id,
530 (
531 client_message.clone(),
532 ProducerMessageWaiter::Shared(waiter.clone()),
533 ),
534 );
535
536 if let Err(e) = self.0.sender.send(client_message).await {
538 return Err(ClientError::GenericError(Box::new(e)))?;
539 }
540 }
541
542 Ok(())
543 }
544
545 pub fn is_closed(&self) -> bool {
546 self.0.closed.load(Ordering::Relaxed)
547 }
548
549 pub async fn close(self) -> Result<(), ProducerCloseError> {
550 self.0.close().await
551 }
552
553 pub async fn set_on_closed(&self, on_closed: Box<dyn OnClosed + Send + Sync>) {
554 let mut on_closed_lock = self.0.on_closed.write().await;
555 *on_closed_lock = Some(on_closed);
556 }
557}
558
/// Callback interface invoked by the producer's message handler when the
/// connection is reported closed.
#[async_trait::async_trait]
pub trait OnClosed {
    /// Receives the messages that were still waiting for an answer from the broker,
    /// ordered by publishing id.
    async fn on_closed(&self, unconfirmed: Vec<Message>);
}
563
/// Message handler installed on the producer's client; it resolves waiters when
/// publish confirms / errors arrive and notifies `on_closed` when the connection ends.
struct ProducerConfirmHandler {
    /// Waiters shared with the producer, keyed by publishing id.
    waiting_confirmations: WaiterMap,
    /// Collector that is told how many publishing ids each confirm carried.
    metrics_collector: Arc<dyn MetricsCollector>,
    /// Optional handler shared with the producer, invoked when the connection closes.
    on_closed: Arc<RwLock<Option<Box<dyn OnClosed + Send + Sync>>>>,
}
569
570#[async_trait::async_trait]
571impl MessageHandler for ProducerConfirmHandler {
572 async fn handle_message(&self, item: MessageResult) -> RabbitMQStreamResult<()> {
573 match item {
574 Some(Ok(response)) => {
575 match response.kind() {
576 ResponseKind::PublishConfirm(confirm) => {
577 trace!("Got publish_confirm for {:?}", confirm.publishing_ids);
578 let confirm_len = confirm.publishing_ids.len();
579 for publishing_id in &confirm.publishing_ids {
580 let id = *publishing_id;
581
582 let (_, waiter) = match self.waiting_confirmations.remove(publishing_id)
583 {
584 Some((_, confirm_sender)) => confirm_sender,
585 None => todo!(),
586 };
587 match waiter {
588 ProducerMessageWaiter::Once(waiter) => {
589 invoke_handler_once(
590 waiter.cb,
591 id,
592 true,
593 ResponseCode::Ok,
594 waiter.msg,
595 )
596 .await;
597 }
598 ProducerMessageWaiter::Shared(waiter) => {
599 invoke_handler(
600 waiter.cb,
601 id,
602 true,
603 ResponseCode::Ok,
604 waiter.msg,
605 )
606 .await;
607 }
608 }
609 }
610 self.metrics_collector
611 .publish_confirm(confirm_len as u64)
612 .await;
613 }
614 ResponseKind::PublishError(error) => {
615 trace!("Got publish_error {:?}", error);
616 for err in &error.publishing_errors {
617 let code = err.error_code.clone();
618 let id = err.publishing_id;
619
620 let (_, waiter) = match self.waiting_confirmations.remove(&id) {
621 Some((_, confirm_sender)) => confirm_sender,
622 None => todo!(),
623 };
624 match waiter {
625 ProducerMessageWaiter::Once(waiter) => {
626 invoke_handler_once(waiter.cb, id, false, code, waiter.msg)
627 .await;
628 }
629 ProducerMessageWaiter::Shared(waiter) => {
630 invoke_handler(waiter.cb, id, false, code, waiter.msg).await;
631 }
632 }
633 }
634 }
635 _ => {}
636 };
637 }
638 Some(Err(error)) => {
639 trace!(?error);
640 }
642 None => {
643 info!("Connection closed");
644 let on_closed = self.on_closed.read().await;
645 if let Some(on_close) = &*on_closed {
646 let mut unconfirmed: Vec<(u64, Message)> = self
647 .waiting_confirmations
648 .iter()
649 .map(|entry| (*entry.key(), entry.value().0.clone().into_message()))
650 .collect();
651 unconfirmed.sort_by_key(|(id, _)| *id);
652
653 let unconfirmed: Vec<Message> =
654 unconfirmed.into_iter().map(|(_, msg)| msg).collect();
655
656 on_close.on_closed(unconfirmed).await;
657 } else {
658 warn!("No on_closed handler set, unconfirmed messages will be lost.");
659 }
660 }
661 }
662 Ok(())
663 }
664}
665
666async fn invoke_handler(
667 f: ArcConfirmCallback,
668 publishing_id: u64,
669 confirmed: bool,
670 status: ResponseCode,
671 message: Message,
672) {
673 f(Ok(ConfirmationStatus {
674 publishing_id,
675 confirmed,
676 status,
677 message,
678 }))
679 .await;
680}
681async fn invoke_handler_once(
682 f: ConfirmCallback,
683 publishing_id: u64,
684 confirmed: bool,
685 status: ResponseCode,
686 message: Message,
687) {
688 f(Ok(ConfirmationStatus {
689 publishing_id,
690 confirmed,
691 status,
692 message,
693 }))
694 .await;
695}
696
/// Boxed one-shot confirmation callback: consumes itself when called and returns
/// a boxed future to await.
type ConfirmCallback = Box<
    dyn FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> BoxFuture<'static, ()>
        + Send
        + Sync,
>;
702
/// Reference-counted confirmation callback that can be called repeatedly; used to
/// share a single callback between all waiters of a batch.
type ArcConfirmCallback = Arc<
    dyn Fn(Result<ConfirmationStatus, ProducerPublishError>) -> BoxFuture<'static, ()>
        + Send
        + Sync,
>;
708
/// Waiter stored for each publish that has not been answered yet.
enum ProducerMessageWaiter {
    /// Waiter with its own one-shot callback (single sends).
    Once(OnceProducerMessageWaiter),
    /// Waiter whose callback is shared with the other messages of a batch.
    Shared(SharedProducerMessageWaiter),
}
713
/// Waiter for a single send: a one-shot callback plus the message it belongs to.
struct OnceProducerMessageWaiter {
    /// Callback invoked once with the outcome of the publish.
    cb: ConfirmCallback,
    /// Message handed back to the callback inside the `ConfirmationStatus`.
    msg: Message,
}
718impl OnceProducerMessageWaiter {
719 fn waiter_with_cb<Fut>(
720 cb: impl FnOnce(Result<ConfirmationStatus, ProducerPublishError>) -> Fut + Send + Sync + 'static,
721 msg: Message,
722 ) -> Self
723 where
724 Fut: Future<Output = ()> + Send + Sync + 'static,
725 {
726 Self {
727 cb: Box::new(move |confirm_status| cb(confirm_status).boxed()),
728 msg,
729 }
730 }
731}
732
/// Waiter for a batched send: a shared callback plus the message it belongs to.
#[derive(Clone)]
struct SharedProducerMessageWaiter {
    /// Callback shared between all messages of the batch.
    cb: ArcConfirmCallback,
    /// Message handed back to the callback inside the `ConfirmationStatus`.
    msg: Message,
}
738
739impl SharedProducerMessageWaiter {
740 fn waiter_with_arc_cb(confirm_callback: ArcConfirmCallback, msg: Message) -> Self {
741 Self {
742 cb: confirm_callback,
743 msg,
744 }
745 }
746}