1use std::{sync::Arc, time::Duration};
2
3use chrono::Utc;
4use rand::Rng;
5use tokio::{
6 sync::{mpsc, watch, OwnedSemaphorePermit, Semaphore},
7 task::{JoinHandle, JoinSet},
8};
9
10use crate::{
11 AckMode, BatchOutcome, BoxedError, DeliveryControl, DeliveryHandle, EventBusError, Handler,
12 Message, MessageId, PublishOptions, Publisher, Subscriber, SubscriptionConfig,
13};
14
15use super::{
16 ack_flusher::{self, AckRequest},
17 backend::{ClaimedMessage, FetchedEntry, SharedBackend, StreamBackend},
18 delivery::StreamDelivery,
19 observer::{ErrorObserver, ErrorScope},
20 subscription::StreamSubscription,
21};
22
23use crate::HEADER_DEAD_LETTER_REASON;
24
/// Fallback for `StreamBusOptions::publish_batch_parallelism` when it is zero.
const DEFAULT_PUBLISH_BATCH_PARALLELISM: usize = 32;
/// Upper bound on every delay produced by `BackoffState::next`.
const MAX_BACKOFF_CEILING: Duration = Duration::from_millis(5_000);
/// Default `StreamBusOptions::max_payload_bytes`: 4 MiB.
const DEFAULT_MAX_PAYLOAD_BYTES: usize = 4 << 20;
31
32type DeliveryTaskResult = Result<(), EventBusError>;
33
34struct RuntimeState {
41 handler: Arc<dyn Handler>,
42 config: Arc<SubscriptionConfig>,
43 limiter: Arc<Semaphore>,
44 ack_tx: mpsc::Sender<AckRequest>,
45}
46
47impl Clone for RuntimeState {
48 fn clone(&self) -> Self {
49 Self {
50 handler: Arc::clone(&self.handler),
51 config: Arc::clone(&self.config),
52 limiter: Arc::clone(&self.limiter),
53 ack_tx: self.ack_tx.clone(),
54 }
55 }
56}
57
58#[derive(Clone)]
59pub struct StreamBusOptions {
60 pub block_timeout: Duration,
61 pub claim_idle_timeout: Duration,
62 pub claim_scan_batch_size: usize,
63 pub group_start_id: String,
64 pub publish_batch_parallelism: usize,
68 pub ack_batch_size: usize,
70 pub ack_flush_interval: Duration,
74 pub reclaim_interval: Duration,
78 pub max_payload_bytes: usize,
83 pub error_observer: Option<Arc<dyn ErrorObserver>>,
88}
89
90impl std::fmt::Debug for StreamBusOptions {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("StreamBusOptions")
93 .field("block_timeout", &self.block_timeout)
94 .field("claim_idle_timeout", &self.claim_idle_timeout)
95 .field("claim_scan_batch_size", &self.claim_scan_batch_size)
96 .field("group_start_id", &self.group_start_id)
97 .field("publish_batch_parallelism", &self.publish_batch_parallelism)
98 .field("ack_batch_size", &self.ack_batch_size)
99 .field("ack_flush_interval", &self.ack_flush_interval)
100 .field("reclaim_interval", &self.reclaim_interval)
101 .field("max_payload_bytes", &self.max_payload_bytes)
102 .field(
103 "error_observer",
104 &self.error_observer.as_ref().map(|_| "<observer>"),
105 )
106 .finish()
107 }
108}
109
110impl Default for StreamBusOptions {
111 fn default() -> Self {
112 Self {
113 block_timeout: Duration::from_secs(2),
114 claim_idle_timeout: Duration::from_secs(60),
115 claim_scan_batch_size: 64,
116 group_start_id: "$".to_string(),
117 publish_batch_parallelism: DEFAULT_PUBLISH_BATCH_PARALLELISM,
118 ack_batch_size: 64,
119 ack_flush_interval: Duration::from_millis(2),
120 reclaim_interval: Duration::from_millis(500),
121 max_payload_bytes: DEFAULT_MAX_PAYLOAD_BYTES,
122 error_observer: None,
123 }
124 }
125}
126
127impl StreamBusOptions {
128 #[must_use]
131 pub fn new() -> Self {
132 Self::default()
133 }
134
135 #[must_use]
137 pub fn with_block_timeout(mut self, v: Duration) -> Self {
138 self.block_timeout = v;
139 self
140 }
141
142 #[must_use]
144 pub fn with_claim_idle_timeout(mut self, v: Duration) -> Self {
145 self.claim_idle_timeout = v;
146 self
147 }
148
149 #[must_use]
151 pub fn with_claim_scan_batch_size(mut self, v: usize) -> Self {
152 self.claim_scan_batch_size = v;
153 self
154 }
155
156 #[must_use]
158 pub fn with_group_start_id(mut self, v: impl Into<String>) -> Self {
159 self.group_start_id = v.into();
160 self
161 }
162
163 #[must_use]
165 pub fn with_publish_batch_parallelism(mut self, v: usize) -> Self {
166 self.publish_batch_parallelism = v;
167 self
168 }
169
170 #[must_use]
172 pub fn with_ack_batch_size(mut self, v: usize) -> Self {
173 self.ack_batch_size = v;
174 self
175 }
176
177 #[must_use]
179 pub fn with_ack_flush_interval(mut self, v: Duration) -> Self {
180 self.ack_flush_interval = v;
181 self
182 }
183
184 #[must_use]
186 pub fn with_reclaim_interval(mut self, v: Duration) -> Self {
187 self.reclaim_interval = v;
188 self
189 }
190
191 #[must_use]
194 pub fn with_max_payload_bytes(mut self, v: usize) -> Self {
195 self.max_payload_bytes = v;
196 self
197 }
198
199 #[must_use]
201 pub fn with_error_observer(mut self, observer: Arc<dyn ErrorObserver>) -> Self {
202 self.error_observer = Some(observer);
203 self
204 }
205
206 fn normalize(mut self) -> Result<Self, EventBusError> {
207 if self.block_timeout.is_zero() {
208 self.block_timeout = Duration::from_secs(2);
209 }
210
211 if self.claim_idle_timeout.is_zero() {
212 self.claim_idle_timeout = Duration::from_secs(60);
213 }
214
215 if self.claim_scan_batch_size == 0 {
216 self.claim_scan_batch_size = 64;
217 }
218
219 if self.group_start_id.trim().is_empty() {
220 self.group_start_id = "$".to_string();
221 }
222
223 if self.publish_batch_parallelism == 0 {
224 self.publish_batch_parallelism = DEFAULT_PUBLISH_BATCH_PARALLELISM;
225 }
226
227 if self.ack_batch_size == 0 {
228 self.ack_batch_size = 64;
229 }
230
231 if self.ack_flush_interval.is_zero() {
232 self.ack_flush_interval = Duration::from_millis(2);
233 }
234
235 if self.reclaim_interval.is_zero() {
236 self.reclaim_interval = Duration::from_millis(500);
237 }
238
239 Ok(self)
240 }
241}
242
243pub struct StreamBus<B: StreamBackend> {
244 backend: SharedBackend<B>,
245 options: StreamBusOptions,
246}
247
248impl<B: StreamBackend> Clone for StreamBus<B> {
249 fn clone(&self) -> Self {
250 Self {
251 backend: Arc::clone(&self.backend),
252 options: self.options.clone(),
253 }
254 }
255}
256
257impl<B: StreamBackend> StreamBus<B> {
258 pub fn new(
259 backend: SharedBackend<B>,
260 options: StreamBusOptions,
261 ) -> Result<Self, EventBusError> {
262 Ok(Self {
263 backend,
264 options: options.normalize()?,
265 })
266 }
267
268 pub async fn publish(
269 &self,
270 msg: Message,
271 opts: PublishOptions,
272 ) -> Result<MessageId, EventBusError> {
273 <Self as Publisher>::publish(self, msg, opts).await
274 }
275
276 pub async fn publish_batch(
277 &self,
278 msgs: Vec<Message>,
279 opts: PublishOptions,
280 ) -> Result<BatchOutcome, EventBusError> {
281 <Self as Publisher>::publish_batch(self, msgs, opts).await
282 }
283
284 pub async fn subscribe<H>(
288 &self,
289 cfg: SubscriptionConfig,
290 handler: H,
291 ) -> Result<StreamSubscription, EventBusError>
292 where
293 H: Handler + 'static,
294 {
295 self.subscribe_inner(cfg, Arc::new(handler)).await
296 }
297
298 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, message, options), fields(topic = %message.topic)))]
299 async fn publish_inner(
300 &self,
301 message: Message,
302 options: &PublishOptions,
303 ) -> Result<MessageId, EventBusError> {
304 options.validate()?;
305
306 if let Some(delay) = options.delay {
307 tokio::time::sleep(delay).await;
308 }
309
310 let message = Self::prepare_message(message, options, self.options.max_payload_bytes)?;
311
312 let topic = message.topic.clone();
313 let id = self.backend.publish(topic.as_str(), message).await?;
314 Ok(MessageId::new(id))
315 }
316
317 #[cfg_attr(feature = "tracing", tracing::instrument(skip(self, msgs, opts), fields(count = msgs.len())))]
318 async fn publish_batch_impl(
319 &self,
320 msgs: Vec<Message>,
321 opts: PublishOptions,
322 ) -> Result<BatchOutcome, EventBusError> {
323 opts.validate()?;
324
325 if let Some(delay) = opts.delay {
326 tokio::time::sleep(delay).await;
327 }
328
329 let max_payload_bytes = self.options.max_payload_bytes;
330 let prepared: Vec<(usize, Result<Message, EventBusError>)> = msgs
331 .into_iter()
332 .enumerate()
333 .map(|(idx, m)| (idx, Self::prepare_message(m, &opts, max_payload_bytes)))
334 .collect();
335
336 let total = prepared.len();
337 let parallelism = total.clamp(1, self.options.publish_batch_parallelism);
338 let mut iter = prepared.into_iter();
339 let mut tasks: JoinSet<(usize, Result<MessageId, EventBusError>)> = JoinSet::new();
340 let mut results: Vec<Option<Result<MessageId, EventBusError>>> =
341 std::iter::repeat_with(|| None).take(total).collect();
342
343 for _ in 0..parallelism {
347 if let Some((idx, prep)) = iter.next() {
348 results[idx] = Some(Err(EventBusError::Internal(
349 "publish task did not complete".into(),
350 )));
351 let backend = Arc::clone(&self.backend);
352 tasks.spawn(async move {
353 let r = match prep {
354 Err(e) => Err(e),
355 Ok(m) => {
356 let topic = m.topic.clone();
357 backend.publish(topic.as_str(), m).await.map(MessageId::new)
358 }
359 };
360 (idx, r)
361 });
362 }
363 }
364
365 while let Some(joined) = tasks.join_next().await {
366 match joined {
367 Ok((idx, r)) => {
368 if idx < results.len() {
369 results[idx] = Some(r);
370 }
371 }
372 Err(je) => {
373 if let Some(obs) = self.options.error_observer.as_ref() {
377 obs.on_panic(ErrorScope::HandlerPanic, &je.to_string());
378 }
379 }
380 }
381 if let Some((next_idx, prep)) = iter.next() {
382 results[next_idx] = Some(Err(EventBusError::Internal(
383 "publish task did not complete".into(),
384 )));
385 let backend = Arc::clone(&self.backend);
386 tasks.spawn(async move {
387 let r = match prep {
388 Err(e) => Err(e),
389 Ok(m) => {
390 let topic = m.topic.clone();
391 backend.publish(topic.as_str(), m).await.map(MessageId::new)
392 }
393 };
394 (next_idx, r)
395 });
396 }
397 }
398
399 Ok(BatchOutcome {
400 results: results
401 .into_iter()
402 .map(|o| {
403 o.unwrap_or_else(|| {
404 Err(EventBusError::Internal(
405 "publish_batch slot never filled".into(),
406 ))
407 })
408 })
409 .collect(),
410 })
411 }
412
413 fn prepare_message(
414 mut message: Message,
415 options: &PublishOptions,
416 max_payload_bytes: usize,
417 ) -> Result<Message, EventBusError> {
418 if max_payload_bytes > 0 && message.payload.len() > max_payload_bytes {
420 return Err(EventBusError::Validation(format!(
421 "message payload {} bytes exceeds max_payload_bytes {}",
422 message.payload.len(),
423 max_payload_bytes,
424 )));
425 }
426
427 for (key, value) in &options.metadata {
428 message.headers.insert(key.clone(), value.clone());
429 }
430
431 if let Some(idempotency_key) = options.idempotency_key.as_deref() {
432 message.set_idempotency_key(idempotency_key);
433 }
434
435 Ok(message)
436 }
437
438 #[cfg_attr(
439 feature = "tracing",
440 tracing::instrument(
441 skip_all,
442 fields(
443 topic = %runtime.config.topic.as_str(),
444 group = %runtime.config.consumer_group.as_str()
445 )
446 )
447 )]
448 async fn consume_loop(
449 self,
450 mut close_rx: watch::Receiver<bool>,
451 runtime: RuntimeState,
452 mut reclaim_rx: mpsc::Receiver<Vec<FetchedEntry>>,
453 flusher_handle: JoinHandle<()>,
454 reclaim_handle: JoinHandle<()>,
455 ) -> Result<(), EventBusError> {
456 let mut tasks = JoinSet::new();
457 let mut first_delivery_error: Option<EventBusError> = None;
458 let mut backoff = BackoffState::new(runtime.config.retry_backoff);
459 let observer = self.options.error_observer.clone();
460
461 loop {
462 if *close_rx.borrow() {
463 break;
464 }
465
466 drain_completed_tasks(&mut tasks, observer.as_ref(), &mut first_delivery_error)?;
467
468 let max_batch = runtime
472 .config
473 .backpressure
474 .as_ref()
475 .map_or(usize::MAX, |p| p.max_batch_size.max(1));
476 let mut permits: Vec<OwnedSemaphorePermit> = Vec::new();
477 while permits.len() < max_batch {
478 match Arc::clone(&runtime.limiter).try_acquire_owned() {
479 Ok(p) => permits.push(p),
480 Err(_) => break,
481 }
482 }
483
484 if permits.is_empty() {
485 if !wait_for_task_or_close(
486 &mut tasks,
487 &mut close_rx,
488 backoff.peek(),
489 observer.as_ref(),
490 &mut first_delivery_error,
491 )
492 .await
493 {
494 break;
495 }
496 continue;
497 }
498
499 let read_limit = permits.len();
500 let read_future = self.backend.read_new(
501 runtime.config.topic.as_str(),
502 runtime.config.consumer_group.as_str(),
503 runtime.config.consumer_name.as_str(),
504 read_limit,
505 self.options.block_timeout,
506 );
507 tokio::pin!(read_future);
508
509 let mut any_work = false;
510
511 tokio::select! {
515 biased;
516 changed = close_rx.changed() => {
517 if changed.is_ok() && *close_rx.borrow() {
518 break;
519 }
520 continue;
522 }
523 Some(reclaimed) = reclaim_rx.recv() => {
524 if !reclaimed.is_empty() {
525 any_work = true;
526 self.spawn_messages(&mut tasks, reclaimed, &mut permits, &runtime).await?;
527 }
528 }
529 result = &mut read_future => {
530 match result {
531 Ok(messages) if !messages.is_empty() => {
532 any_work = true;
533 self.spawn_messages(&mut tasks, messages, &mut permits, &runtime).await?;
534 }
535 Ok(_) => {}
536 Err(err) => {
537 if let Some(obs) = observer.as_ref() {
538 obs.on_error(ErrorScope::Read, &err);
539 }
540 let sleep_dur = backoff.next();
541 if !sleep_or_close(&mut close_rx, sleep_dur).await {
543 break;
544 }
545 continue;
546 }
547 }
548 }
549 }
550
551 if any_work {
554 backoff.reset();
555 }
556 }
557
558 while let Some(result) = tasks.join_next().await {
560 match result {
561 Ok(Ok(())) => {}
562 Ok(Err(err)) => {
563 first_delivery_error.get_or_insert(err);
564 }
565 Err(err) => {
566 if let Some(obs) = observer.as_ref() {
567 obs.on_panic(ErrorScope::HandlerPanic, &err.to_string());
568 }
569 first_delivery_error.get_or_insert_with(|| {
570 EventBusError::source("delivery task panicked", err)
571 });
572 }
573 }
574 }
575
576 let topic = runtime.config.topic.clone();
579 let group = runtime.config.consumer_group.clone();
580 let consumer = runtime.config.consumer_name.clone();
581
582 drop(runtime);
584 drop(reclaim_rx);
585 let _ = reclaim_handle.await;
586 let _ = flusher_handle.await;
587
588 self.backend
589 .forget_consumer(topic.as_str(), group.as_str(), consumer.as_str())
590 .await;
591
592 if let Some(err) = first_delivery_error {
593 return Err(err);
594 }
595
596 Ok(())
597 }
598
599 async fn spawn_messages(
600 &self,
601 tasks: &mut JoinSet<DeliveryTaskResult>,
602 entries: Vec<FetchedEntry>,
603 permits: &mut Vec<OwnedSemaphorePermit>,
604 runtime: &RuntimeState,
605 ) -> Result<(), EventBusError> {
606 for entry in entries {
612 match entry {
613 FetchedEntry::Decoded(claimed) => {
614 let Some(permit) = permits.pop() else {
615 break;
617 };
618 let bus = self.clone();
619 let config = Arc::clone(&runtime.config);
620 let handler = Arc::clone(&runtime.handler);
621 let ack_tx = runtime.ack_tx.clone();
622 tasks.spawn(async move {
623 bus.process_single_message(config, handler, claimed, permit, ack_tx)
624 .await
625 });
626 }
627 FetchedEntry::Malformed { id, error } => {
628 self.handle_malformed_entry(&runtime.config, id, error)
629 .await;
630 }
631 }
632 }
633 Ok(())
634 }
635
636 async fn handle_malformed_entry(
641 &self,
642 config: &SubscriptionConfig,
643 id: String,
644 error: EventBusError,
645 ) {
646 if let Some(obs) = self.options.error_observer.as_ref() {
647 obs.on_error(ErrorScope::Read, &error);
648 }
649
650 if let Some(dlq) = config.dead_letter_topic.as_ref() {
651 let mut headers = std::collections::HashMap::new();
652 headers.insert(HEADER_DEAD_LETTER_REASON.to_string(), error.to_string());
653 let envelope = Message {
654 uid: format!("malformed-{id}"),
655 topic: dlq.clone(),
656 key: id.clone(),
657 kind: "eventbus.malformed".into(),
658 source: config.topic.as_str().to_string(),
659 occurred_at: Utc::now(),
660 headers,
661 payload: bytes::Bytes::new(),
662 content_type: None,
663 event_version: None,
664 idempotency_key: None,
665 expires_at: None,
666 trace_uid: None,
667 correlation_uid: None,
668 };
669 if let Err(err) = self.backend.publish(dlq.as_str(), envelope).await {
670 if let Some(obs) = self.options.error_observer.as_ref() {
671 obs.on_error(ErrorScope::Read, &err);
672 }
673 }
676 }
677
678 if let Err(err) = self
679 .backend
680 .ack(config.topic.as_str(), config.consumer_group.as_str(), &id)
681 .await
682 {
683 if let Some(obs) = self.options.error_observer.as_ref() {
684 obs.on_error(ErrorScope::AckFlush, &err);
685 }
686 }
687 }
688
689 #[cfg_attr(
690 feature = "tracing",
691 tracing::instrument(
692 skip(self, config, handler, claimed, permit, ack_tx),
693 fields(message_id = %claimed.id)
694 )
695 )]
696 async fn process_single_message(
697 &self,
698 config: Arc<SubscriptionConfig>,
699 handler: Arc<dyn Handler>,
700 claimed: ClaimedMessage,
701 permit: OwnedSemaphorePermit,
702 ack_tx: mpsc::Sender<AckRequest>,
703 ) -> Result<(), EventBusError> {
704 use super::auto_finalize::AutoFinalizeTracker;
705
706 let max_attempt = (config.max_retry as u32).saturating_add(1);
710 let state = claimed.state.with_max_attempt(max_attempt);
711 let max_payload_bytes = self.options.max_payload_bytes;
712 let ack_mode = config.ack_mode;
713
714 if max_payload_bytes > 0 && claimed.message.payload.len() > max_payload_bytes {
717 let oversize_err = EventBusError::Validation(format!(
718 "received payload {} bytes exceeds max_payload_bytes {}",
719 claimed.message.payload.len(),
720 max_payload_bytes,
721 ));
722 let delivery = Box::new(StreamDelivery::new(
723 Arc::clone(&self.backend),
724 ack_tx,
725 claimed.id,
726 claimed.message,
727 state,
728 Arc::clone(&config),
729 permit,
730 ));
731 if config.dead_letter_topic.is_some() {
732 let reason: BoxedError = Box::new(SimpleError(oversize_err.to_string()));
733 return delivery.nack(reason).await;
734 }
735 return Err(oversize_err);
736 }
737
738 let delivery = Box::new(StreamDelivery::new(
739 Arc::clone(&self.backend),
740 ack_tx,
741 claimed.id,
742 claimed.message,
743 state,
744 Arc::clone(&config),
745 permit,
746 ));
747
748 match ack_mode {
749 AckMode::AutoOnReceive => {
750 delivery.pre_ack().await?;
756 let boxed: Box<dyn DeliveryHandle> = delivery;
757 let _ = handler.handle(boxed).await;
762 Ok(())
763 }
764 AckMode::Manual => {
765 let boxed: Box<dyn DeliveryHandle> = delivery;
766 let _ = handler.handle(boxed).await;
767 Ok(())
768 }
769 AckMode::AutoOnHandlerSuccess => {
770 let real: Box<dyn DeliveryHandle> = delivery;
771 let (tracker, proxy) = AutoFinalizeTracker::new(real).await?;
772 let proxy_boxed: Box<dyn DeliveryHandle> = Box::new(proxy);
773 let result = handler.handle(proxy_boxed).await;
774 if let Some(remaining) = tracker.take_remaining() {
776 match &result {
777 Ok(()) => remaining.ack().await?,
778 Err(err) => {
779 let reason: BoxedError = Box::new(SimpleError(err.to_string()));
780 remaining.retry(reason).await?;
781 }
782 }
783 }
784 Ok(())
785 }
786 }
787 }
788}
789
/// Minimal string-backed error, used to pass a textual reason to
/// `nack`/`retry` as a `BoxedError`.
#[derive(Debug)]
struct SimpleError(String);

impl std::fmt::Display for SimpleError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let SimpleError(message) = self;
        write!(f, "{message}")
    }
}

impl std::error::Error for SimpleError {}
802
803impl<B: StreamBackend> Publisher for StreamBus<B> {
804 fn publish(
805 &self,
806 msg: Message,
807 opts: PublishOptions,
808 ) -> crate::BoxFuture<'_, Result<MessageId, EventBusError>> {
809 Box::pin(async move { self.publish_inner(msg, &opts).await })
810 }
811
812 fn publish_batch(
813 &self,
814 msgs: Vec<Message>,
815 opts: PublishOptions,
816 ) -> crate::BoxFuture<'_, Result<BatchOutcome, EventBusError>> {
817 Box::pin(async move { self.publish_batch_impl(msgs, opts).await })
818 }
819}
820
821impl<B: StreamBackend> StreamBus<B> {
822 #[cfg_attr(
823 feature = "tracing",
824 tracing::instrument(
825 skip(self, cfg, handler),
826 fields(topic = %cfg.topic, group = %cfg.consumer_group)
827 )
828 )]
829 async fn subscribe_inner(
830 &self,
831 mut cfg: SubscriptionConfig,
832 handler: Arc<dyn Handler>,
833 ) -> Result<StreamSubscription, EventBusError> {
834 if cfg.consumer_name.as_str().trim().is_empty() {
838 cfg.consumer_name = crate::ConsumerName::auto();
840 }
841
842 cfg.normalize_and_validate()?;
843
844 if cfg.balance_mode == Some(crate::ConsumerBalanceMode::FanOut) {
845 return Err(EventBusError::Validation(
846 "FanOut balance mode is not yet supported by StreamBus".into(),
847 ));
848 }
849
850 self.backend
851 .create_group(
852 cfg.topic.as_str(),
853 cfg.consumer_group.as_str(),
854 &self.options.group_start_id,
855 )
856 .await?;
857
858 let (close_tx, close_rx) = watch::channel(false);
859 let limit = cfg.max_in_flight.max(1);
868 let consumer_name = cfg.consumer_name.as_str().to_string();
869
870 let stream = cfg.topic.as_str().to_string();
871 let group = cfg.consumer_group.as_str().to_string();
872 let (ack_tx, flusher_handle) = ack_flusher::spawn(
873 Arc::clone(&self.backend),
874 stream,
875 group,
876 self.options.ack_batch_size,
877 self.options.ack_flush_interval,
878 self.options.error_observer.clone(),
879 );
880
881 let limiter = Arc::new(Semaphore::new(limit));
882
883 let (reclaim_tx, reclaim_rx) = mpsc::channel::<Vec<FetchedEntry>>(4);
885 let reclaim_handle = tokio::spawn({
886 let args = ReclaimLoopArgs {
887 backend: Arc::clone(&self.backend),
888 close_rx: close_rx.clone(),
889 reclaim_tx,
890 topic: cfg.topic.as_str().to_string(),
891 group: cfg.consumer_group.as_str().to_string(),
892 consumer: cfg.consumer_name.as_str().to_string(),
893 claim_idle_timeout: self.options.claim_idle_timeout,
894 claim_scan_batch_size: self.options.claim_scan_batch_size,
895 reclaim_interval: self.options.reclaim_interval,
896 error_observer: self.options.error_observer.clone(),
897 };
898 async move { reclaim_loop(args).await }
899 });
900
901 let runtime = RuntimeState {
902 handler,
903 config: Arc::new(cfg),
904 limiter,
905 ack_tx,
906 };
907
908 let task = tokio::spawn({
909 let bus = self.clone();
910 let runtime = runtime.clone();
911 async move {
912 bus.consume_loop(
913 close_rx,
914 runtime,
915 reclaim_rx,
916 flusher_handle,
917 reclaim_handle,
918 )
919 .await
920 }
921 });
922
923 drop(runtime);
924
925 Ok(StreamSubscription::new(
926 consumer_name,
927 close_tx,
928 task,
929 self.options.error_observer.clone(),
930 ))
931 }
932}
933
934impl<B: StreamBackend> Subscriber for StreamBus<B> {
935 fn subscribe(
936 &self,
937 cfg: SubscriptionConfig,
938 handler: Arc<dyn Handler>,
939 ) -> crate::BoxFuture<'_, Result<Arc<dyn crate::Subscription>, EventBusError>> {
940 Box::pin(async move {
941 let sub = self.subscribe_inner(cfg, handler).await?;
942 Ok(Arc::new(sub) as Arc<dyn crate::Subscription>)
943 })
944 }
945}
946
947struct ReclaimLoopArgs<B: StreamBackend> {
952 backend: SharedBackend<B>,
953 close_rx: watch::Receiver<bool>,
954 reclaim_tx: mpsc::Sender<Vec<FetchedEntry>>,
955 topic: String,
956 group: String,
957 consumer: String,
958 claim_idle_timeout: Duration,
959 claim_scan_batch_size: usize,
960 reclaim_interval: Duration,
961 error_observer: Option<Arc<dyn ErrorObserver>>,
962}
963
964async fn reclaim_loop<B: StreamBackend>(args: ReclaimLoopArgs<B>) {
965 let ReclaimLoopArgs {
966 backend,
967 mut close_rx,
968 reclaim_tx,
969 topic,
970 group,
971 consumer,
972 claim_idle_timeout,
973 claim_scan_batch_size,
974 reclaim_interval,
975 error_observer,
976 } = args;
977
978 let mut backoff = BackoffState::new(Duration::from_millis(100));
979
980 loop {
981 if !sleep_or_close(&mut close_rx, reclaim_interval).await {
982 break;
983 }
984
985 let count = claim_scan_batch_size;
990 match backend
991 .reclaim_idle(&topic, &group, &consumer, claim_idle_timeout, count)
992 .await
993 {
994 Ok(messages) => {
995 if !messages.is_empty() && reclaim_tx.send(messages).await.is_err() {
996 break;
997 }
998 backoff.reset();
999 }
1000 Err(err) => {
1001 if let Some(obs) = error_observer.as_ref() {
1002 obs.on_error(ErrorScope::Reclaim, &err);
1003 }
1004 let dur = backoff.next();
1005 if !sleep_or_close(&mut close_rx, dur).await {
1006 break;
1007 }
1008 }
1009 }
1010 }
1011}
1012
1013fn drain_completed_tasks(
1019 tasks: &mut JoinSet<DeliveryTaskResult>,
1020 observer: Option<&Arc<dyn ErrorObserver>>,
1021 first_delivery_error: &mut Option<EventBusError>,
1022) -> Result<(), EventBusError> {
1023 while let Some(result) = tasks.try_join_next() {
1024 match result {
1025 Ok(Ok(())) => {}
1026 Ok(Err(err)) => {
1027 first_delivery_error.get_or_insert(err);
1028 }
1029 Err(err) => {
1030 if let Some(obs) = observer {
1031 obs.on_panic(ErrorScope::HandlerPanic, &err.to_string());
1032 }
1033 return Err(EventBusError::source("delivery task panicked", err));
1034 }
1035 }
1036 }
1037 Ok(())
1038}
1039
1040struct BackoffState {
1042 base: Duration,
1043 current: Duration,
1044}
1045
1046impl BackoffState {
1047 fn new(base: Duration) -> Self {
1048 let base = if base.is_zero() {
1049 Duration::from_millis(100)
1050 } else {
1051 base
1052 };
1053 Self {
1054 base,
1055 current: base,
1056 }
1057 }
1058
1059 fn peek(&self) -> Duration {
1060 self.base
1061 }
1062
1063 fn next(&mut self) -> Duration {
1064 let dur = self.current;
1065 let next_raw = dur.saturating_mul(2).min(MAX_BACKOFF_CEILING);
1066 self.current = next_raw;
1067
1068 let jitter_nanos = rand::thread_rng().gen_range(0..=dur.as_nanos() as u64);
1069 Duration::from_nanos(jitter_nanos)
1070 .saturating_add(dur / 2)
1071 .min(MAX_BACKOFF_CEILING)
1072 }
1073
1074 fn reset(&mut self) {
1075 self.current = self.base;
1076 }
1077}
1078
1079async fn sleep_or_close(close_rx: &mut watch::Receiver<bool>, duration: Duration) -> bool {
1080 tokio::select! {
1081 changed = close_rx.changed() => {
1082 if changed.is_err() {
1083 false
1084 } else {
1085 !*close_rx.borrow()
1086 }
1087 }
1088 _ = tokio::time::sleep(duration) => true,
1089 }
1090}
1091
1092async fn wait_for_task_or_close(
1093 tasks: &mut JoinSet<DeliveryTaskResult>,
1094 close_rx: &mut watch::Receiver<bool>,
1095 duration: Duration,
1096 observer: Option<&Arc<dyn ErrorObserver>>,
1097 first_delivery_error: &mut Option<EventBusError>,
1098) -> bool {
1099 if tasks.is_empty() {
1100 return sleep_or_close(close_rx, duration).await;
1101 }
1102
1103 tokio::select! {
1104 changed = close_rx.changed() => {
1105 if changed.is_err() {
1106 false
1107 } else {
1108 !*close_rx.borrow()
1109 }
1110 }
1111 result = tasks.join_next() => match result {
1112 Some(Ok(Ok(()))) | None => true,
1113 Some(Ok(Err(err))) => {
1114 first_delivery_error.get_or_insert(err);
1115 true
1116 }
1117 Some(Err(err)) => {
1118 if let Some(obs) = observer {
1119 obs.on_panic(ErrorScope::HandlerPanic, &err.to_string());
1120 }
1121 first_delivery_error.get_or_insert_with(|| {
1122 EventBusError::source("delivery task failed", err)
1123 });
1124 true
1125 }
1126 },
1127 }
1128}
1129
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::{BackoffState, StreamBusOptions, MAX_BACKOFF_CEILING};

    #[test]
    fn zero_duration_options_normalize_to_defaults() {
        let normalized = StreamBusOptions {
            block_timeout: Duration::ZERO,
            claim_idle_timeout: Duration::ZERO,
            claim_scan_batch_size: 0,
            group_start_id: String::new(),
            publish_batch_parallelism: 0,
            ack_batch_size: 0,
            ack_flush_interval: Duration::ZERO,
            reclaim_interval: Duration::ZERO,
            max_payload_bytes: 0,
            error_observer: None,
        }
        .normalize()
        .expect("normalize options");

        assert_eq!(normalized.block_timeout, Duration::from_secs(2));
        assert_eq!(normalized.claim_idle_timeout, Duration::from_secs(60));
        assert_eq!(normalized.claim_scan_batch_size, 64);
        assert_eq!(normalized.group_start_id, "$".to_string());
        assert_eq!(normalized.publish_batch_parallelism, 32);
        assert_eq!(normalized.ack_batch_size, 64);
        assert_eq!(normalized.ack_flush_interval, Duration::from_millis(2));
        assert_eq!(normalized.reclaim_interval, Duration::from_millis(500));
        // `0` disables the payload limit, so normalization must not replace it.
        assert_eq!(normalized.max_payload_bytes, 0);
        assert!(normalized.error_observer.is_none());
    }

    #[test]
    fn whitespace_group_start_id_falls_back_to_dollar() {
        let normalized = StreamBusOptions::new()
            .with_group_start_id("   ")
            .normalize()
            .expect("normalize options");
        assert_eq!(normalized.group_start_id, "$");
    }

    #[test]
    fn explicit_options_survive_normalization() {
        let normalized = StreamBusOptions::new()
            .with_block_timeout(Duration::from_millis(250))
            .with_claim_scan_batch_size(8)
            .with_group_start_id("0")
            .with_publish_batch_parallelism(4)
            .with_max_payload_bytes(1024)
            .normalize()
            .expect("normalize options");

        assert_eq!(normalized.block_timeout, Duration::from_millis(250));
        assert_eq!(normalized.claim_scan_batch_size, 8);
        assert_eq!(normalized.group_start_id, "0");
        assert_eq!(normalized.publish_batch_parallelism, 4);
        assert_eq!(normalized.max_payload_bytes, 1024);
    }

    #[test]
    fn backoff_grows_exponentially_and_caps() {
        let base = Duration::from_millis(100);
        let mut backoff = BackoffState::new(base);
        let mut step = base;
        for _ in 0..20 {
            let dur = backoff.next();
            // Each delay is `step / 2` plus jitter in `0..=step`, capped.
            assert!(dur >= step / 2, "{dur:?} below half of step {step:?}");
            assert!(
                dur <= (step + step / 2).min(MAX_BACKOFF_CEILING),
                "{dur:?} above bound for step {step:?}"
            );
            step = step.saturating_mul(2).min(MAX_BACKOFF_CEILING);
        }
        assert_eq!(step, MAX_BACKOFF_CEILING);

        backoff.reset();
        let first = backoff.next();
        assert!(first >= base / 2 && first <= base + base / 2);
    }

    #[test]
    fn zero_backoff_base_falls_back_to_100ms() {
        let backoff = BackoffState::new(Duration::ZERO);
        assert_eq!(backoff.peek(), Duration::from_millis(100));
    }
}