1use std::{
2 collections::{BTreeMap, VecDeque},
3 fmt,
4 sync::{
5 Arc, Condvar, Mutex, MutexGuard,
6 atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
7 },
8 time::Duration,
9};
10
11use arc_swap::ArcSwap;
12use smallvec::SmallVec;
13
14use crate::stream::{
15 BoxStream, Materializer, NotUsed, Sink, Source, SourceRuntimeHints, StreamCompletion,
16 TerminalSourceHookDyn, TerminalSourceStatus,
17};
18use crate::{StreamError, StreamResult};
19
/// Routing function of a [`PartitionHub`]: receives the current consumer
/// snapshot and an element and returns an `isize` choice.
/// NOTE(review): presumably an index into `PartitionConsumerInfo::consumer_ids`;
/// how out-of-range/negative values are treated is decided in `select_partition`
/// (outside this view) — confirm there.
type Partitioner<T> = Arc<dyn Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync>;

/// Maximum number of items moved into one consumer batch per merge hub drain
/// step, from either the shared queue or a single direct producer turn.
const MERGE_HUB_BATCH_LIMIT: usize = 256;
/// Broadcast batch limit. NOTE(review): not referenced in the visible code; used elsewhere.
const BROADCAST_HUB_BATCH_LIMIT: usize = 256;
/// Cap on the number of items granted by `wait_for_broadcast_capacity` when
/// exactly one consumer lane is active.
const BROADCAST_HUB_SINGLE_CONSUMER_BATCH_LIMIT: usize = 64;
/// Partition batch limits. NOTE(review): not referenced in the visible code; used elsewhere.
const PARTITION_HUB_BATCH_LIMIT: usize = 1024;
const PARTITION_HUB_WIDE_BATCH_LIMIT: usize = 2048;
const PARTITION_HUB_SINGLE_CONSUMER_BATCH_LIMIT: usize = 256;
/// Chunk granularity of a fan-out consumer lane; used to pre-size the lane's
/// chunk deque (`buffer_size / FAN_OUT_CONSUMER_BATCH_LIMIT + 1` slots).
const FAN_OUT_CONSUMER_BATCH_LIMIT: usize = 256;
29
/// Length of a single bounded wait inside the hub polling loops.
///
/// The merge hub terminal drain loop waits on its condvar for at most this
/// long and then re-checks cancellation and materializer shutdown, so the value
/// bounds how late those are noticed when no notification arrives.
fn fan_out_wait_timeout() -> Duration {
    const WAIT_MICROS: u64 = 1_000;
    Duration::from_micros(WAIT_MICROS)
}
33
/// Handle that switches a materialized [`MergeHub`] into draining mode.
///
/// Clones share the same hub: the flags `Arc` and the callback `Arc` are cloned,
/// not the state behind them.
#[derive(Clone)]
pub struct MergeHubDrainingControl {
    /// Draining flags shared with the hub (`MergeHubShared::state`).
    state: Arc<MergeHubState>,
    /// Callback run after the draining flag is set; `new_merge_hub_materialization`
    /// installs one that calls `MergeHubShared::finish_if_draining`.
    on_drain: Arc<dyn Fn() + Send + Sync>,
}
41
42impl fmt::Debug for MergeHubDrainingControl {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.debug_struct("MergeHubDrainingControl").finish()
45 }
46}
47
48impl MergeHubDrainingControl {
49 pub fn drain_and_complete(&self) {
50 let mut state = self.state.lock();
51 state.draining = true;
52 self.state.condvar.notify_all();
53 drop(state);
54 (self.on_drain)();
55 }
56}
57
/// Entry point for merge hubs: one consumer [`Source`] that is fed by any number
/// of dynamically materialized producer [`Sink`]s.
pub struct MergeHub;
61
62impl MergeHub {
63 #[must_use]
65 pub fn source<T: Send + 'static>(
66 per_producer_buffer_size: usize,
67 ) -> Source<T, Sink<T, NotUsed>> {
68 Self::source_with_draining(per_producer_buffer_size)
69 .map_materialized_value(|(sink, _)| sink)
70 }
71
72 #[must_use]
75 pub fn source_with_draining<T: Send + 'static>(
76 per_producer_buffer_size: usize,
77 ) -> Source<T, (Sink<T, NotUsed>, MergeHubDrainingControl)> {
78 assert!(
79 per_producer_buffer_size > 0,
80 "MergeHub per_producer_buffer_size must be greater than zero"
81 );
82 Source::from_terminal_direct_materialized_factory(
83 move |_| {
84 let (state, sink, control) =
85 new_merge_hub_materialization(per_producer_buffer_size);
86 let source = Box::new(MergeHubSourceStream {
87 state: Arc::clone(&state),
88 local: VecDeque::new(),
89 prefer_direct: false,
90 }) as BoxStream<T>;
91 Ok((source, (sink, control)))
92 },
93 move |_| {
94 let (state, sink, control) =
95 new_merge_hub_materialization(per_producer_buffer_size);
96 let hook = Arc::new(MergeHubTerminalHook {
97 state,
98 prefer_direct: AtomicBool::new(false),
99 }) as Arc<dyn TerminalSourceHookDyn<T>>;
100 Ok((hook, (sink, control)))
101 },
102 )
103 }
104}
105
106fn new_merge_hub_materialization<T: Send + 'static>(
107 per_producer_buffer_size: usize,
108) -> (
109 Arc<MergeHubShared<T>>,
110 Sink<T, NotUsed>,
111 MergeHubDrainingControl,
112) {
113 let state = Arc::new(MergeHubShared::<T>::new(per_producer_buffer_size));
114 let sink = merge_hub_sink(Arc::clone(&state));
115 let control = MergeHubDrainingControl {
116 state: Arc::clone(&state.state),
117 on_drain: Arc::new({
118 let state = Arc::clone(&state);
119 move || state.finish_if_draining()
120 }),
121 };
122 (state, sink, control)
123}
124
/// Entry point for broadcast hubs: a [`Sink`] whose materialized value hands out
/// consumer [`Source`]s, each of which receives every element.
pub struct BroadcastHub;
128
129impl BroadcastHub {
130 #[must_use]
138 pub fn sink<T: Clone + Send + 'static>(
139 buffer_size: usize,
140 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
141 Self::sink_starting_after(0, buffer_size)
142 }
143
144 #[must_use]
146 pub fn sink_starting_after<T: Clone + Send + 'static>(
147 start_after_nr_of_consumers: usize,
148 buffer_size: usize,
149 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
150 assert!(
151 buffer_size > 0,
152 "BroadcastHub buffer_size must be greater than zero"
153 );
154 Sink::from_hinted_runner(move |input, materializer, hints| {
155 let state = Arc::new(FanOutHubShared::new(
156 FanOutMode::Broadcast,
157 start_after_nr_of_consumers,
158 buffer_size,
159 None::<Partitioner<T>>,
160 ));
161 let source = BroadcastHubConsumerSource {
162 state: Arc::clone(&state),
163 completion: Arc::new(Mutex::new(None)),
164 };
165 let completion = materializer.spawn_stream(move |cancelled| {
166 FanOutProducer::new(input, state).run(cancelled, hints)
167 });
168 source.attach_completion(completion);
169 Ok(source)
170 })
171 }
172}
173
/// Entry point for partition hubs: a [`Sink`] whose materialized value hands out
/// consumer [`Source`]s. Each element is routed to a consumer chosen by a
/// user-supplied partitioner (see [`PartitionConsumerInfo`]).
pub struct PartitionHub;
177
178impl PartitionHub {
179 #[must_use]
187 pub fn sink<T: Clone + Send + 'static, F>(
188 partitioner: F,
189 start_after_nr_of_consumers: usize,
190 buffer_size: usize,
191 ) -> Sink<T, PartitionHubConsumerSource<T>>
192 where
193 F: Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync + 'static,
194 {
195 assert!(
196 buffer_size > 0,
197 "PartitionHub buffer_size must be greater than zero"
198 );
199 let partitioner = Arc::new(partitioner);
200 Sink::from_hinted_runner(move |input, materializer, hints| {
201 let partitioner = Arc::clone(&partitioner);
202 let state = Arc::new(FanOutHubShared::new(
203 FanOutMode::Partition,
204 start_after_nr_of_consumers,
205 buffer_size,
206 Some(partitioner),
207 ));
208 let source = PartitionHubConsumerSource {
209 state: Arc::clone(&state),
210 completion: Arc::new(Mutex::new(None)),
211 };
212 let completion = materializer.spawn_stream(move |cancelled| {
213 FanOutProducer::new(input, state).run(cancelled, hints)
214 });
215 source.attach_completion(completion);
216 Ok(source)
217 })
218 }
219}
220
/// Materialized value of [`BroadcastHub::sink`]; call `source()` to attach a new
/// consumer to the hub. Clones refer to the same hub.
#[derive(Clone)]
pub struct BroadcastHubConsumerSource<T> {
    /// Hub state shared with the producer stream and every consumer.
    state: Arc<FanOutHubShared<T>>,
    /// Completion handle of the spawned producer stream, stored by
    /// `attach_completion`. NOTE(review): not read in the visible code; it appears
    /// to be held only to keep the handle alive.
    completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
}
228
229impl<T: Clone + Send + 'static> BroadcastHubConsumerSource<T> {
230 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
231 *self
232 .completion
233 .lock()
234 .expect("broadcast hub completion poisoned") = Some(completion);
235 }
236
237 #[must_use]
238 pub fn source(&self) -> Source<T, NotUsed> {
239 let state = Arc::clone(&self.state);
240 Source::from_materialized_factory(move |_| {
241 let lane = state.register_consumer();
242 let stream = Box::new(FanOutConsumerStream {
243 state: Arc::clone(&state),
244 lane,
245 local: None,
246 detached: false,
247 }) as BoxStream<T>;
248 Ok((stream, NotUsed))
249 })
250 }
251}
252
253impl<T: Clone + Send + 'static> fmt::Debug for BroadcastHubConsumerSource<T> {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("BroadcastHubConsumerSource").finish()
256 }
257}
258
/// Materialized value of [`PartitionHub::sink`]; call `source()` to attach a new
/// consumer to the hub. Clones refer to the same hub.
#[derive(Clone)]
pub struct PartitionHubConsumerSource<T> {
    /// Hub state shared with the producer stream and every consumer.
    state: Arc<FanOutHubShared<T>>,
    /// Completion handle of the spawned producer stream, stored by
    /// `attach_completion`. NOTE(review): not read in the visible code; it appears
    /// to be held only to keep the handle alive.
    completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
}
266
267impl<T: Clone + Send + 'static> PartitionHubConsumerSource<T> {
268 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
269 *self
270 .completion
271 .lock()
272 .expect("partition hub completion poisoned") = Some(completion);
273 }
274
275 #[must_use]
276 pub fn source(&self) -> Source<T, NotUsed> {
277 let state = Arc::clone(&self.state);
278 Source::from_materialized_factory(move |_| {
279 let lane = state.register_consumer();
280 let stream = Box::new(FanOutConsumerStream {
281 state: Arc::clone(&state),
282 lane,
283 local: None,
284 detached: false,
285 }) as BoxStream<T>;
286 Ok((stream, NotUsed))
287 })
288 }
289}
290
291impl<T: Clone + Send + 'static> fmt::Debug for PartitionHubConsumerSource<T> {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 f.debug_struct("PartitionHubConsumerSource").finish()
294 }
295}
296
/// Snapshot of the consumers a [`PartitionHub`] partitioner can route to.
#[derive(Clone, Debug)]
pub struct PartitionConsumerInfo {
    /// Consumer ids in snapshot order; indexable via `consumer_id_by_idx`.
    consumer_ids: SmallVec<[u64; 16]>,
    /// `(consumer id, queued item count)` pairs captured when the snapshot was built.
    queue_sizes: SmallVec<[(u64, usize); 16]>,
}
304
305impl PartitionConsumerInfo {
306 #[must_use]
308 pub fn size(&self) -> usize {
309 self.consumer_ids.len()
310 }
311
312 #[must_use]
314 pub fn consumer_ids(&self) -> &[u64] {
315 &self.consumer_ids
316 }
317
318 #[must_use]
320 pub fn consumer_id_by_idx(&self, idx: usize) -> u64 {
321 self.consumer_ids[idx]
322 }
323
324 #[must_use]
326 pub fn queue_size(&self, consumer_id: u64) -> usize {
327 self.queue_sizes
328 .iter()
329 .find_map(|(id, size)| (*id == consumer_id).then_some(*size))
330 .unwrap_or(0)
331 }
332}
333
334fn merge_hub_sink<T: Send + 'static>(state: Arc<MergeHubShared<T>>) -> Sink<T, NotUsed> {
335 Sink::from_raw_hinted_runner(move |input, materializer, hints| {
336 if hints.inline_micro_max_success_items.is_some() {
337 state.register_direct_producer(input)?;
338 return Ok(NotUsed);
339 }
340
341 let input = materializer.checked_stream(input, None);
342 let producer_id = state.register_producer()?;
343 let hub = Arc::clone(&state);
344 let completion = materializer.spawn_stream(move |cancelled| {
345 let mut input = input;
346 loop {
347 if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
348 hub.fail(StreamError::Cancelled);
349 hub.deregister_producer(producer_id);
350 return Err(StreamError::Cancelled);
351 }
352
353 match input.next() {
354 Some(Ok(item)) => hub.push_item(producer_id, item)?,
355 Some(Err(error)) => {
356 hub.fail(error.clone());
357 hub.deregister_producer(producer_id);
358 return Err(error);
359 }
360 None => {
361 hub.deregister_producer(producer_id);
362 return Ok(NotUsed);
363 }
364 }
365 }
366 });
367 state.store_producer_completion(completion);
368 Ok(NotUsed)
369 })
370}
371
/// State shared by one merge hub materialization: the consumer (stream or
/// terminal hook), every producer, and the draining control.
struct MergeHubShared<T> {
    /// Draining flags; the same `Arc` is held by [`MergeHubDrainingControl`].
    state: Arc<MergeHubState>,
    /// Queue, direct producers and bookkeeping, all guarded by one mutex.
    shared: Mutex<MergeHubInner<T>>,
    /// Signalled whenever `shared` changes in a way a waiting producer
    /// (`push_item`) or consumer may care about.
    condvar: Condvar,
    /// Set together with `shared.failed` by `fail`, so `failed_error` can skip
    /// taking the lock when the hub has not failed.
    failed: AtomicBool,
}
378
/// Draining flags of a merge hub, shared between the hub and its
/// [`MergeHubDrainingControl`].
#[derive(Debug)]
struct MergeHubState {
    inner: Mutex<MergeHubFlags>,
    /// Notified by `drain_and_complete`.
    /// NOTE(review): no waiter on this condvar is visible in this section of the file.
    condvar: Condvar,
}
384
/// Flags guarded by `MergeHubState::inner`.
#[derive(Debug, Default)]
struct MergeHubFlags {
    /// Set by `MergeHubDrainingControl::drain_and_complete`. Once set, new
    /// producers are rejected, and the hub completes when no producer is active.
    draining: bool,
}
389
390impl MergeHubState {
391 fn lock(&self) -> MutexGuard<'_, MergeHubFlags> {
392 self.inner.lock().expect("merge hub flags poisoned")
393 }
394}
395
/// Mutable merge hub state guarded by `MergeHubShared::shared`.
struct MergeHubInner<T> {
    /// Items pushed by spawned producers, tagged with the producer id, in arrival order.
    queue: VecDeque<(u64, T)>,
    /// Producers pulled directly by the consumer; served round-robin
    /// (popped from the front, requeued at the back).
    direct_producers: VecDeque<MergeHubDirectProducer<T>>,
    /// Number of items each spawned producer currently has in `queue`.
    queued_per_producer: BTreeMap<u64, usize>,
    /// Completion handles of spawned producer tasks; finished ones are pruned lazily.
    producer_completions: Vec<StreamCompletion<NotUsed>>,
    /// Registered producers (spawned and direct) not yet deregistered.
    active_producers: usize,
    /// Id assigned to the next registered producer.
    next_producer_id: u64,
    /// Set once the consuming side is dropped, cancelled or finished. New
    /// producers are rejected and `push_item` returns `Cancelled`.
    source_closed: bool,
    /// Set when the hub is draining and its last active producer has been deregistered.
    completed: bool,
    /// First failure recorded by `fail`.
    failed: Option<StreamError>,
    /// Per-producer queue bound; also caps the items pulled in one direct-producer turn.
    per_producer_buffer_size: usize,
}
408
/// A producer whose input stream is pulled directly by the merge hub consumer.
struct MergeHubDirectProducer<T> {
    /// Producer id; used by `deregister_producer` to remove it.
    id: u64,
    /// The producer's upstream, pulled for a bounded number of items per turn.
    input: BoxStream<T>,
}
413
/// Terminal-source-hook form of the merge hub consumer.
struct MergeHubTerminalHook<T> {
    state: Arc<MergeHubShared<T>>,
    /// Whether the next drain should serve a direct producer before the queue.
    /// It is cleared after a direct turn and set after a queue drain, so the two
    /// sources alternate.
    prefer_direct: AtomicBool,
}
418
/// Outcome of one drain turn on a direct producer (see `restore_direct_producer`).
enum MergeHubProducerTerminal {
    /// The producer may yield more items; it is requeued.
    Active,
    /// The producer's stream ended; it is deregistered.
    Completed,
    /// The producer's stream yielded an error; the hub is failed with it.
    Failed(StreamError),
}
424
425impl<T: Send + 'static> TerminalSourceHookDyn<T> for MergeHubTerminalHook<T> {
426 fn drain_terminal_batch(
427 &self,
428 materializer: &Materializer,
429 cancelled: &Arc<AtomicBool>,
430 batch: &mut Vec<T>,
431 ) -> StreamResult<TerminalSourceStatus> {
432 self.state
433 .drain_terminal_batch(materializer, cancelled, &self.prefer_direct, batch)
434 }
435
436 fn cancel_terminal(&self) {
437 self.state.close_source();
438 }
439}
440
441impl<T> MergeHubShared<T> {
442 fn new(per_producer_buffer_size: usize) -> Self {
443 Self {
444 state: Arc::new(MergeHubState {
445 inner: Mutex::new(MergeHubFlags::default()),
446 condvar: Condvar::new(),
447 }),
448 shared: Mutex::new(MergeHubInner {
449 queue: VecDeque::new(),
450 direct_producers: VecDeque::new(),
451 queued_per_producer: BTreeMap::new(),
452 producer_completions: Vec::new(),
453 active_producers: 0,
454 next_producer_id: 0,
455 source_closed: false,
456 completed: false,
457 failed: None,
458 per_producer_buffer_size,
459 }),
460 condvar: Condvar::new(),
461 failed: AtomicBool::new(false),
462 }
463 }
464
465 fn register_direct_producer(&self, input: BoxStream<T>) -> StreamResult<()> {
466 let mut inner = self.shared.lock().expect("merge hub poisoned");
467 prune_finished_producer_completions(&mut inner.producer_completions);
468 let flags = self.state.lock();
469 if flags.draining || inner.source_closed || inner.completed {
470 return Err(StreamError::Failed(
471 "merge hub is draining or closed to new producers".to_owned(),
472 ));
473 }
474 if let Some(error) = inner.failed.clone() {
475 return Err(error);
476 }
477 let id = inner.next_producer_id;
478 inner.next_producer_id += 1;
479 inner.active_producers += 1;
480 inner
481 .direct_producers
482 .push_back(MergeHubDirectProducer { id, input });
483 drop(flags);
484 drop(inner);
485 self.condvar.notify_all();
486 Ok(())
487 }
488
489 fn register_producer(&self) -> StreamResult<u64> {
490 let mut inner = self.shared.lock().expect("merge hub poisoned");
491 prune_finished_producer_completions(&mut inner.producer_completions);
492 let flags = self.state.lock();
493 if flags.draining || inner.source_closed || inner.completed {
494 return Err(StreamError::Failed(
495 "merge hub is draining or closed to new producers".to_owned(),
496 ));
497 }
498 if let Some(error) = inner.failed.clone() {
499 return Err(error);
500 }
501 let id = inner.next_producer_id;
502 inner.next_producer_id += 1;
503 inner.active_producers += 1;
504 inner.queued_per_producer.insert(id, 0);
505 Ok(id)
506 }
507
508 fn store_producer_completion(&self, completion: StreamCompletion<NotUsed>) {
509 let mut inner = self.shared.lock().expect("merge hub poisoned");
510 prune_finished_producer_completions(&mut inner.producer_completions);
511 inner.producer_completions.push(completion);
512 }
513
514 fn push_item(&self, producer_id: u64, item: T) -> StreamResult<()> {
515 let mut inner = self.shared.lock().expect("merge hub poisoned");
516 loop {
517 if let Some(error) = inner.failed.clone() {
518 inner.queued_per_producer.remove(&producer_id);
519 return Err(error);
520 }
521 if inner.source_closed {
522 inner.queued_per_producer.remove(&producer_id);
523 return Err(StreamError::Cancelled);
524 }
525 let queued = inner
526 .queued_per_producer
527 .get(&producer_id)
528 .copied()
529 .unwrap_or(0);
530 if queued < inner.per_producer_buffer_size {
531 inner.queue.push_back((producer_id, item));
532 inner.queued_per_producer.insert(producer_id, queued + 1);
533 self.condvar.notify_all();
534 return Ok(());
535 }
536 inner = self
537 .condvar
538 .wait(inner)
539 .expect("merge hub poisoned while waiting");
540 }
541 }
542
543 fn deregister_producer(&self, producer_id: u64) {
544 let mut inner = self.shared.lock().expect("merge hub poisoned");
545 prune_finished_producer_completions(&mut inner.producer_completions);
546 inner.queued_per_producer.remove(&producer_id);
547 inner
548 .direct_producers
549 .retain(|producer| producer.id != producer_id);
550 inner.active_producers = inner.active_producers.saturating_sub(1);
551 if inner.active_producers == 0 {
552 let flags = self.state.lock();
553 if flags.draining {
554 inner.completed = true;
555 }
556 }
557 self.condvar.notify_all();
558 }
559
560 fn pop_direct_producer(
561 &self,
562 inner: &mut MergeHubInner<T>,
563 ) -> Option<MergeHubDirectProducer<T>> {
564 inner.direct_producers.pop_front()
565 }
566
567 fn restore_direct_producer(
568 &self,
569 producer: MergeHubDirectProducer<T>,
570 terminal: MergeHubProducerTerminal,
571 ) -> StreamResult<()> {
572 match terminal {
573 MergeHubProducerTerminal::Active => {
574 let mut inner = self.shared.lock().expect("merge hub poisoned");
575 if let Some(error) = inner.failed.clone() {
576 inner.active_producers = inner.active_producers.saturating_sub(1);
577 return Err(error);
578 }
579 if inner.source_closed {
580 inner.active_producers = inner.active_producers.saturating_sub(1);
581 return Err(StreamError::Cancelled);
582 }
583 inner.direct_producers.push_back(producer);
584 Ok(())
585 }
586 MergeHubProducerTerminal::Completed => {
587 self.deregister_producer(producer.id);
588 Ok(())
589 }
590 MergeHubProducerTerminal::Failed(error) => {
591 self.fail(error.clone());
592 self.deregister_producer(producer.id);
593 Err(error)
594 }
595 }
596 }
597
598 fn fail(&self, error: StreamError) {
599 let mut inner = self.shared.lock().expect("merge hub poisoned");
600 if inner.failed.is_none() {
601 inner.failed = Some(error);
602 self.failed.store(true, Ordering::SeqCst);
603 }
604 self.condvar.notify_all();
605 }
606
607 fn failed_error(&self) -> Option<StreamError> {
608 if !self.failed.load(Ordering::SeqCst) {
609 return None;
610 }
611 self.shared
612 .lock()
613 .expect("merge hub poisoned")
614 .failed
615 .clone()
616 }
617
618 fn finish_if_draining(&self) {
619 let flags = self.state.lock();
620 if !flags.draining {
621 return;
622 }
623 drop(flags);
624
625 let mut inner = self.shared.lock().expect("merge hub poisoned");
626 prune_finished_producer_completions(&mut inner.producer_completions);
627 if inner.active_producers == 0 {
628 inner.completed = true;
629 self.condvar.notify_all();
630 }
631 }
632
633 fn close_source(&self) {
634 let mut inner = self.shared.lock().expect("merge hub poisoned");
635 inner.source_closed = true;
636 self.condvar.notify_all();
637 }
638
639 fn terminal_status(
640 &self,
641 materializer: &Materializer,
642 cancelled: &Arc<AtomicBool>,
643 ) -> StreamResult<()> {
644 if materializer.is_shutdown() {
645 self.close_source();
646 Err(StreamError::AbruptTermination)
647 } else if cancelled.load(Ordering::SeqCst) {
648 self.close_source();
649 Err(StreamError::Cancelled)
650 } else {
651 Ok(())
652 }
653 }
654
655 fn drain_terminal_batch(
656 &self,
657 materializer: &Materializer,
658 cancelled: &Arc<AtomicBool>,
659 prefer_direct: &AtomicBool,
660 batch: &mut Vec<T>,
661 ) -> StreamResult<TerminalSourceStatus> {
662 batch.clear();
663 loop {
664 self.terminal_status(materializer, cancelled)?;
665 let mut inner = self.shared.lock().expect("merge hub poisoned");
666 loop {
667 if let Some(error) = inner.failed.clone() {
668 inner.source_closed = true;
669 return Err(error);
670 }
671
672 let should_drain_direct = !inner.direct_producers.is_empty()
673 && (prefer_direct.load(Ordering::Relaxed) || inner.queue.is_empty());
674 if should_drain_direct {
675 let Some(mut producer) = self.pop_direct_producer(&mut inner) else {
676 continue;
677 };
678 let drain_limit = inner
679 .per_producer_buffer_size
680 .clamp(1, MERGE_HUB_BATCH_LIMIT);
681 drop(inner);
682
683 batch.reserve(drain_limit.saturating_sub(batch.capacity()));
684 let mut terminal = MergeHubProducerTerminal::Active;
685 for _ in 0..drain_limit {
686 match producer.input.next() {
687 Some(Ok(item)) => batch.push(item),
688 Some(Err(error)) => {
689 terminal = MergeHubProducerTerminal::Failed(error);
690 break;
691 }
692 None => {
693 terminal = MergeHubProducerTerminal::Completed;
694 break;
695 }
696 }
697 }
698
699 let restore_result = self.restore_direct_producer(producer, terminal);
700 match restore_result {
701 Ok(()) => {
702 prefer_direct.store(false, Ordering::Relaxed);
703 }
704 Err(error) => {
705 self.close_source();
706 return Err(error);
707 }
708 }
709
710 let completed = self.terminal_completed_after_batch();
711 if !batch.is_empty() {
712 return Ok(if completed {
713 TerminalSourceStatus::Completed
714 } else {
715 TerminalSourceStatus::Active
716 });
717 }
718 if completed {
719 return Ok(TerminalSourceStatus::Completed);
720 }
721 break;
722 }
723
724 if !inner.queue.is_empty() {
725 let drain_n = inner.queue.len().min(MERGE_HUB_BATCH_LIMIT);
726 batch.reserve(drain_n.saturating_sub(batch.capacity()));
727 for _ in 0..drain_n {
728 if let Some((producer_id, item)) = inner.queue.pop_front() {
729 if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
730 *queued = queued.saturating_sub(1);
731 }
732 batch.push(item);
733 }
734 }
735 let completed = inner.completed
736 && inner.queue.is_empty()
737 && inner.direct_producers.is_empty();
738 self.condvar.notify_all();
739 drop(inner);
740
741 prefer_direct.store(true, Ordering::Relaxed);
742 return Ok(if completed {
743 TerminalSourceStatus::Completed
744 } else {
745 TerminalSourceStatus::Active
746 });
747 }
748
749 if inner.completed {
750 inner.source_closed = true;
751 return Ok(TerminalSourceStatus::Completed);
752 }
753
754 let (next_inner, _) = self
755 .condvar
756 .wait_timeout(inner, fan_out_wait_timeout())
757 .expect("merge hub poisoned while waiting");
758 inner = next_inner;
759 if materializer.is_shutdown() {
760 inner.source_closed = true;
761 self.condvar.notify_all();
762 return Err(StreamError::AbruptTermination);
763 }
764 if cancelled.load(Ordering::SeqCst) {
765 inner.source_closed = true;
766 self.condvar.notify_all();
767 return Err(StreamError::Cancelled);
768 }
769 }
770 }
771 }
772
773 fn terminal_completed_after_batch(&self) -> bool {
774 let inner = self.shared.lock().expect("merge hub poisoned");
775 inner.completed && inner.queue.is_empty() && inner.direct_producers.is_empty()
776 }
777}
778
779fn prune_finished_producer_completions(completions: &mut Vec<StreamCompletion<NotUsed>>) {
780 let mut index = 0;
781 while index < completions.len() {
782 if completions[index].try_wait().is_some() {
783 drop(completions.swap_remove(index));
784 } else {
785 index += 1;
786 }
787 }
788}
789
/// Iterator form of the merge hub consumer.
struct MergeHubSourceStream<T> {
    state: Arc<MergeHubShared<T>>,
    /// Items already taken from the hub but not yet yielded.
    local: VecDeque<T>,
    /// Whether the next refill should serve a direct producer before the queue.
    /// It is cleared after a direct turn and set after a queue drain.
    prefer_direct: bool,
}
795
796impl<T> Iterator for MergeHubSourceStream<T> {
797 type Item = StreamResult<T>;
798
799 fn next(&mut self) -> Option<Self::Item> {
800 if let Some(error) = self.state.failed_error() {
801 self.local.clear();
802 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
803 inner.source_closed = true;
804 return Some(Err(error));
805 }
806 if let Some(item) = self.local.pop_front() {
807 return Some(Ok(item));
808 }
809
810 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
811 loop {
812 if let Some(error) = inner.failed.clone() {
813 inner.source_closed = true;
814 return Some(Err(error));
815 }
816 let should_drain_direct = !inner.direct_producers.is_empty()
817 && (self.prefer_direct || inner.queue.is_empty());
818 if should_drain_direct {
819 let Some(mut producer) = self.state.pop_direct_producer(&mut inner) else {
820 continue;
821 };
822 let drain_limit = inner
823 .per_producer_buffer_size
824 .clamp(1, MERGE_HUB_BATCH_LIMIT);
825 drop(inner);
826
827 let mut batch = std::mem::take(&mut self.local);
828 batch.clear();
829 batch.reserve(drain_limit.saturating_sub(batch.capacity()));
830 let mut terminal = MergeHubProducerTerminal::Active;
831 for _ in 0..drain_limit {
832 match producer.input.next() {
833 Some(Ok(item)) => batch.push_back(item),
834 Some(Err(error)) => {
835 terminal = MergeHubProducerTerminal::Failed(error);
836 break;
837 }
838 None => {
839 terminal = MergeHubProducerTerminal::Completed;
840 break;
841 }
842 }
843 }
844
845 let restore_result = self.state.restore_direct_producer(producer, terminal);
846 match restore_result {
847 Ok(()) => {
848 self.prefer_direct = false;
849 if let Some(first) = batch.pop_front() {
850 self.local = batch;
851 return Some(Ok(first));
852 }
853 inner = self.state.shared.lock().expect("merge hub poisoned");
854 continue;
855 }
856 Err(error) => {
857 self.local.clear();
858 return Some(Err(error));
859 }
860 }
861 }
862 if !inner.queue.is_empty() {
863 let drain_n = inner.queue.len().min(MERGE_HUB_BATCH_LIMIT);
864 let mut batch = std::mem::take(&mut self.local);
865 batch.clear();
866 batch.reserve(drain_n.saturating_sub(batch.capacity()));
867 for _ in 0..drain_n {
868 if let Some((producer_id, item)) = inner.queue.pop_front() {
869 if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
870 *queued = queued.saturating_sub(1);
871 }
872 batch.push_back(item);
873 }
874 }
875 self.state.condvar.notify_all();
876 drop(inner);
877 let first = batch
878 .pop_front()
879 .expect("merge hub drained non-empty batch");
880 self.local = batch;
881 self.prefer_direct = true;
882 return Some(Ok(first));
883 }
884 if inner.completed {
885 inner.source_closed = true;
886 return None;
887 }
888 inner = self
889 .state
890 .condvar
891 .wait(inner)
892 .expect("merge hub poisoned while waiting");
893 }
894 }
895}
896
897impl<T> Drop for MergeHubSourceStream<T> {
898 fn drop(&mut self) {
899 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
900 inner.source_closed = true;
901 self.state.condvar.notify_all();
902 }
903}
904
/// Which kind of fan-out hub a [`FanOutHubShared`] implements.
#[derive(Clone, Copy)]
enum FanOutMode {
    /// Every element goes to every consumer ([`BroadcastHub`]).
    Broadcast,
    /// Each element is routed by the partitioner ([`PartitionHub`]).
    Partition,
}
910
/// State shared by a broadcast/partition hub's producer stream and its consumers.
struct FanOutHubShared<T> {
    /// Authoritative consumer set and terminal state.
    registry: Mutex<FanOutRegistry<T>>,
    /// Copy of the registry that the producer reads without locking (via
    /// `ArcSwap::load`); republished by `publish_snapshot_locked`.
    snapshot: ArcSwap<FanOutSnapshot<T>>,
    /// NOTE(review): presumably the lock/condvar pair behind
    /// `wait_for_producer_transition` (defined outside this view) — confirm.
    producer_wait: Mutex<()>,
    producer_condvar: Condvar,
    /// Sampled before the producer decides to wait, then passed to
    /// `wait_for_producer_transition`. NOTE(review): presumably bumped on
    /// consumer progress and topology changes — confirm.
    producer_epoch: AtomicU64,
    /// NOTE(review): presumably bumped on consumer add/remove and compared with
    /// `PartitionTopologyCache::epoch` — confirm.
    topology_epoch: AtomicU64,
    mode: FanOutMode,
    /// Broadcast pushes wait until at least this many lanes are active.
    start_after_nr_of_consumers: usize,
    /// Maximum number of queued items per consumer lane.
    buffer_size: usize,
    /// Routing function; `Some` only for partition hubs.
    partitioner: Option<Partitioner<T>>,
}
923
/// Mutable consumer registry guarded by `FanOutHubShared::registry`.
struct FanOutRegistry<T> {
    /// Registered consumer lanes, keyed by consumer id.
    consumers: BTreeMap<u64, Arc<FanOutConsumerLane<T>>>,
    /// Id handed out by the next `register_consumer` call.
    next_consumer_id: u64,
    /// Once set, new consumers receive an already-terminated lane.
    terminal: Option<FanOutTerminal>,
}
929
/// Immutable view of the registry, published through `FanOutHubShared::snapshot`.
struct FanOutSnapshot<T> {
    consumers: Vec<Arc<FanOutConsumerLane<T>>>,
    /// When set, producer pushes fail with `FanOutTerminal::producer_error`.
    terminal: Option<FanOutTerminal>,
}
934
/// Terminal state of a fan-out hub or of one of its lanes.
#[derive(Clone, Debug)]
enum FanOutTerminal {
    /// Normal completion; producers see it as `StreamError::Cancelled`.
    Completed,
    /// Failure. Setting it on a lane via `set_terminal` discards the lane's buffered chunks.
    Failed(StreamError),
}
940
941impl FanOutTerminal {
942 fn producer_error(&self) -> StreamError {
943 match self {
944 Self::Completed => StreamError::Cancelled,
945 Self::Failed(error) => error.clone(),
946 }
947 }
948}
949
/// Per-consumer buffer of a fan-out hub.
struct FanOutConsumerLane<T> {
    /// Consumer id assigned by `register_consumer`.
    id: u64,
    state: Mutex<FanOutLaneState<T>>,
    /// Notified on these events:
    /// - chunks are pushed;
    /// - the lane is deactivated;
    /// - the lane gets a terminal state.
    condvar: Condvar,
    /// Copy of `state.queued`, so capacity can be checked without locking.
    queued: AtomicUsize,
    /// Cleared by `deactivate` when the consumer is removed. A broadcast push
    /// that touches an inactive lane is abandoned (`try_push_broadcast_batch`).
    active: AtomicBool,
    /// Set when the lane's terminal state is a failure.
    failed: AtomicBool,
}
958
/// Lock-protected contents of a [`FanOutConsumerLane`].
struct FanOutLaneState<T> {
    /// Pushed items; one `Vec` per successful producer push.
    chunks: VecDeque<Vec<T>>,
    /// Number of items across `chunks`. It is incremented by pushes here.
    /// NOTE(review): presumably decremented by the consumer, outside this view.
    queued: usize,
    terminal: Option<FanOutTerminal>,
}
964
965impl<T> FanOutConsumerLane<T> {
966 fn new(id: u64, buffer_size: usize) -> Self {
967 Self {
968 id,
969 state: Mutex::new(FanOutLaneState {
970 chunks: VecDeque::with_capacity((buffer_size / FAN_OUT_CONSUMER_BATCH_LIMIT) + 1),
971 queued: 0,
972 terminal: None,
973 }),
974 condvar: Condvar::new(),
975 queued: AtomicUsize::new(0),
976 active: AtomicBool::new(true),
977 failed: AtomicBool::new(false),
978 }
979 }
980
981 fn terminal(id: u64, terminal: FanOutTerminal) -> Self {
982 let failed = matches!(terminal, FanOutTerminal::Failed(_));
983 Self {
984 id,
985 state: Mutex::new(FanOutLaneState {
986 chunks: VecDeque::new(),
987 queued: 0,
988 terminal: Some(terminal),
989 }),
990 condvar: Condvar::new(),
991 queued: AtomicUsize::new(0),
992 active: AtomicBool::new(false),
993 failed: AtomicBool::new(failed),
994 }
995 }
996
997 fn id(&self) -> u64 {
998 self.id
999 }
1000
1001 fn queued_len(&self) -> usize {
1002 self.queued.load(Ordering::Acquire)
1003 }
1004
1005 fn is_active(&self) -> bool {
1006 self.active.load(Ordering::Acquire)
1007 }
1008
1009 fn deactivate(&self) {
1010 self.active.store(false, Ordering::Release);
1011 self.condvar.notify_all();
1012 }
1013
1014 fn set_terminal(&self, terminal: FanOutTerminal) {
1015 let mut state = self.state.lock().expect("fan-out lane poisoned");
1016 if matches!(terminal, FanOutTerminal::Failed(_)) {
1017 state.chunks.clear();
1018 state.queued = 0;
1019 self.queued.store(0, Ordering::Release);
1020 self.failed.store(true, Ordering::Release);
1021 }
1022 state.terminal = Some(terminal);
1023 drop(state);
1024 self.condvar.notify_all();
1025 }
1026}
1027
1028impl PartitionConsumerInfo {
1029 fn from_cached_lanes<T>(consumer_ids: &[u64], lanes: &[Arc<FanOutConsumerLane<T>>]) -> Self {
1030 Self {
1031 consumer_ids: consumer_ids.iter().copied().collect(),
1032 queue_sizes: lanes
1033 .iter()
1034 .map(|lane| (lane.id(), lane.queued_len()))
1035 .collect(),
1036 }
1037 }
1038}
1039
/// Producer-side cache of the partition topology (consumer ids and their lanes).
/// NOTE(review): presumably rebuilt when `FanOutHubShared::topology_epoch`
/// differs from `epoch`; that logic is outside this view.
struct PartitionTopologyCache<T> {
    /// Topology epoch the cache was built for; `u64::MAX` when empty or cleared.
    epoch: u64,
    consumer_ids: SmallVec<[u64; 16]>,
    lanes: SmallVec<[Arc<FanOutConsumerLane<T>>; 16]>,
}
1045
/// Items grouped by destination lane for one partitioned push.
/// NOTE(review): its users are outside this view.
type PartitionRoutedBatches<T> = SmallVec<[(Arc<FanOutConsumerLane<T>>, VecDeque<T>); 16]>;
1047
1048impl<T> PartitionTopologyCache<T> {
1049 fn new() -> Self {
1050 Self {
1051 epoch: u64::MAX,
1052 consumer_ids: SmallVec::new(),
1053 lanes: SmallVec::new(),
1054 }
1055 }
1056
1057 fn clear(&mut self) {
1058 self.epoch = u64::MAX;
1059 self.consumer_ids.clear();
1060 self.lanes.clear();
1061 }
1062}
1063
1064impl<T> FanOutHubShared<T> {
1065 fn new(
1066 mode: FanOutMode,
1067 start_after_nr_of_consumers: usize,
1068 buffer_size: usize,
1069 partitioner: Option<Partitioner<T>>,
1070 ) -> Self {
1071 Self {
1072 registry: Mutex::new(FanOutRegistry {
1073 consumers: BTreeMap::new(),
1074 next_consumer_id: 0,
1075 terminal: None,
1076 }),
1077 snapshot: ArcSwap::from_pointee(FanOutSnapshot {
1078 consumers: Vec::new(),
1079 terminal: None,
1080 }),
1081 producer_wait: Mutex::new(()),
1082 producer_condvar: Condvar::new(),
1083 producer_epoch: AtomicU64::new(0),
1084 topology_epoch: AtomicU64::new(0),
1085 mode,
1086 start_after_nr_of_consumers,
1087 buffer_size,
1088 partitioner,
1089 }
1090 }
1091
1092 fn register_consumer(&self) -> Arc<FanOutConsumerLane<T>> {
1093 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1094 let id = registry.next_consumer_id;
1095 registry.next_consumer_id += 1;
1096
1097 if let Some(terminal) = registry.terminal.clone() {
1098 return Arc::new(FanOutConsumerLane::terminal(id, terminal));
1099 }
1100
1101 let lane = Arc::new(FanOutConsumerLane::new(id, self.buffer_size));
1102 registry.consumers.insert(id, Arc::clone(&lane));
1103 self.publish_snapshot_locked(®istry);
1104 drop(registry);
1105 self.notify_topology_transition();
1106 lane
1107 }
1108
1109 fn remove_consumer(&self, consumer_id: u64) {
1110 let lane = {
1111 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1112 let lane = registry.consumers.remove(&consumer_id);
1113 if let Some(lane) = &lane {
1114 lane.deactivate();
1115 self.publish_snapshot_locked(®istry);
1116 }
1117 lane
1118 };
1119
1120 if let Some(lane) = lane {
1121 lane.condvar.notify_all();
1122 self.notify_topology_transition();
1123 }
1124 }
1125
1126 fn wait_for_broadcast_capacity(&self, max_items: usize) -> StreamResult<usize> {
1127 loop {
1128 let observed = self.producer_epoch.load(Ordering::Acquire);
1129 let snapshot = self.snapshot.load();
1130 if let Some(terminal) = &snapshot.terminal {
1131 return Err(terminal.producer_error());
1132 }
1133
1134 let lanes = self.active_lanes(&snapshot);
1135 if lanes.len() < self.start_after_nr_of_consumers || lanes.is_empty() {
1136 self.wait_for_producer_transition(observed);
1137 continue;
1138 }
1139
1140 let free = lanes
1141 .iter()
1142 .map(|lane| self.buffer_size.saturating_sub(lane.queued_len()))
1143 .min()
1144 .unwrap_or(0);
1145 if free > 0 {
1146 let max_items = if lanes.len() == 1 {
1147 max_items.min(BROADCAST_HUB_SINGLE_CONSUMER_BATCH_LIMIT)
1148 } else {
1149 max_items
1150 };
1151 return Ok(free.min(max_items).max(1));
1152 }
1153 self.wait_for_producer_transition(observed);
1154 }
1155 }
1156
1157 fn push_broadcast_item(&self, item: T) -> StreamResult<()>
1158 where
1159 T: Clone,
1160 {
1161 let mut batch = VecDeque::new();
1162 batch.push_back(item);
1163 self.push_broadcast_batch(&mut batch)
1164 }
1165
1166 fn push_broadcast_batch(&self, batch: &mut VecDeque<T>) -> StreamResult<()>
1167 where
1168 T: Clone,
1169 {
1170 if batch.is_empty() {
1171 return Ok(());
1172 }
1173
1174 while !batch.is_empty() {
1175 let observed = self.producer_epoch.load(Ordering::Acquire);
1176 let snapshot = self.snapshot.load();
1177 if let Some(terminal) = &snapshot.terminal {
1178 batch.clear();
1179 return Err(terminal.producer_error());
1180 }
1181
1182 let lanes = self.active_lanes(&snapshot);
1183 if lanes.len() < self.start_after_nr_of_consumers || lanes.is_empty() {
1184 self.wait_for_producer_transition(observed);
1185 continue;
1186 }
1187
1188 let free = lanes
1189 .iter()
1190 .map(|lane| self.buffer_size.saturating_sub(lane.queued_len()))
1191 .min()
1192 .unwrap_or(0);
1193 if free == 0 {
1194 self.wait_for_producer_transition(observed);
1195 continue;
1196 }
1197
1198 let take_n = free.min(batch.len());
1199 if !self.try_push_broadcast_batch(&lanes, batch, take_n)? {
1200 self.wait_for_producer_transition(observed);
1201 }
1202 }
1203 Ok(())
1204 }
1205
/// Appends the first `take_n` elements of `batch` to every lane in `lanes`
/// as one chunk, all-or-nothing.
///
/// Every lane is locked (in `lanes` order) and validated before any lane is
/// modified, so either all lanes receive the chunk or none do. All lanes but
/// the last get clones. The last lane takes ownership of the elements, which
/// are drained from the front of `batch`.
///
/// Returns:
/// * `Ok(false)` in any of these cases, leaving `batch` untouched so the
///   caller can wait and retry:
///   - `take_n == 0` or `lanes` is empty;
///   - a lane is no longer active (checked before and after taking its lock);
///   - a lane has fewer than `take_n` free slots.
/// * `Err(_)` when a lane already carries a terminal state; `batch` is left
///   untouched.
/// * `Ok(true)` once the chunk has been queued on every lane.
///
/// NOTE(review): holding several lane locks at once is deadlock-free only if
/// every path that locks more than one lane uses the same order (here the
/// order produced by `active_lanes`). Confirm this for code outside this view.
fn try_push_broadcast_batch(
    &self,
    lanes: &[Arc<FanOutConsumerLane<T>>],
    batch: &mut VecDeque<T>,
    take_n: usize,
) -> StreamResult<bool>
where
    T: Clone,
{
    if take_n == 0 || lanes.is_empty() {
        return Ok(false);
    }

    // Phase 1: lock and validate every lane. Any early return drops the
    // guards collected so far without having modified anything.
    let mut guards = SmallVec::<
        [(
            Arc<FanOutConsumerLane<T>>,
            MutexGuard<'_, FanOutLaneState<T>>,
        ); 16],
    >::new();
    for lane in lanes {
        if !lane.is_active() {
            return Ok(false);
        }
        let guard = lane.state.lock().expect("fan-out lane poisoned");
        // Re-check under the lock: the lane may have been deactivated while
        // we were acquiring it.
        if !lane.is_active() {
            return Ok(false);
        }
        if let Some(terminal) = &guard.terminal {
            return Err(terminal.producer_error());
        }
        if self.buffer_size.saturating_sub(guard.queued) < take_n {
            return Ok(false);
        }
        guards.push((Arc::clone(lane), guard));
    }

    // Defensive: unreachable when `lanes` is non-empty, because every lane
    // either pushed a guard or returned above.
    let lane_count = guards.len();
    if lane_count == 0 {
        return Ok(false);
    }

    // Phase 2: every lane has room, so commit. Clone for all lanes but the
    // last, which takes the drained originals.
    let mut notify_lanes = SmallVec::<[Arc<FanOutConsumerLane<T>>; 16]>::new();
    for (index, (lane, guard)) in guards.iter_mut().enumerate() {
        let chunk = if index + 1 == lane_count {
            batch.drain(..take_n).collect()
        } else {
            batch.iter().take(take_n).cloned().collect()
        };
        guard.chunks.push_back(chunk);
        guard.queued += take_n;
        // Mirror the count into the lane's atomic `queued`, presumably what
        // `queued_len()` reads without locking. TODO confirm.
        lane.queued.store(guard.queued, Ordering::Release);
        notify_lanes.push(Arc::clone(lane));
    }
    // Phase 3: release every lane lock before waking consumers.
    drop(guards);

    for lane in notify_lanes {
        lane.condvar.notify_one();
    }
    Ok(true)
}
1266
1267 fn select_partition(
1268 &self,
1269 item: &T,
1270 cache: &mut PartitionTopologyCache<T>,
1271 ) -> StreamResult<Option<Arc<FanOutConsumerLane<T>>>> {
1272 let Some(partitioner) = &self.partitioner else {
1273 return Err(StreamError::Failed(
1274 "partition hub partitioner missing".to_owned(),
1275 ));
1276 };
1277 let topology_epoch = self.topology_epoch.load(Ordering::Acquire);
1278 if cache.epoch != topology_epoch || cache.lanes.is_empty() {
1279 self.refresh_partition_topology(cache)?;
1280 }
1281
1282 let info = PartitionConsumerInfo::from_cached_lanes(&cache.consumer_ids, &cache.lanes);
1283 let selected = partitioner(&info, item);
1284 if selected < 0 {
1285 return Ok(None);
1286 }
1287 let selected = selected as u64;
1288 let selected_idx = selected as usize;
1289 if cache.consumer_ids.get(selected_idx).copied() == Some(selected) {
1290 return Ok(Some(Arc::clone(&cache.lanes[selected_idx])));
1291 }
1292 cache
1293 .consumer_ids
1294 .iter()
1295 .position(|id| *id == selected)
1296 .map(|idx| Some(Arc::clone(&cache.lanes[idx])))
1297 .ok_or_else(|| {
1298 StreamError::Failed("partition hub selected unknown consumer".to_owned())
1299 })
1300 }
1301
1302 fn refresh_partition_topology(
1303 &self,
1304 cache: &mut PartitionTopologyCache<T>,
1305 ) -> StreamResult<()> {
1306 loop {
1307 let observed = self.producer_epoch.load(Ordering::Acquire);
1308 let topology_epoch = self.topology_epoch.load(Ordering::Acquire);
1309 let snapshot = self.snapshot.load();
1310 if let Some(terminal) = &snapshot.terminal {
1311 cache.clear();
1312 return Err(terminal.producer_error());
1313 }
1314
1315 let lanes = self.active_lanes(&snapshot);
1316 if lanes.len() >= self.start_after_nr_of_consumers && !lanes.is_empty() {
1317 cache.epoch = topology_epoch;
1318 cache.consumer_ids.clear();
1319 cache
1320 .consumer_ids
1321 .extend(lanes.iter().map(|lane| lane.id()));
1322 cache.lanes = lanes;
1323 return Ok(());
1324 }
1325 self.wait_for_producer_transition(observed);
1326 }
1327 }
1328
1329 fn enqueue_partition(&self, selected: Arc<FanOutConsumerLane<T>>, item: T) -> StreamResult<()> {
1330 let mut batch = VecDeque::new();
1331 batch.push_back(item);
1332 self.enqueue_partition_batch(selected, &mut batch)
1333 }
1334
/// Moves every element of `batch` into `selected`'s queue, blocking while the
/// lane is full.
///
/// Each round queues as many elements as currently fit (`buffer_size` minus
/// the lane's `queued` count) as a single chunk, then notifies the lane
/// condvar. If elements remain, it waits on that condvar (bounded by
/// `fan_out_wait_timeout()`) and retries. In partition mode the consumer side
/// notifies the same condvar after popping a chunk.
///
/// # Errors
/// Clears `batch` and returns an error in either case:
/// * the lane is no longer active ("partition hub selected unknown consumer");
/// * the lane already carries a terminal state (its `producer_error()`).
///
/// NOTE(review): there is no cancellation check in this loop. While the lane
/// stays full and is neither terminal nor inactive, the caller keeps waiting
/// here.
fn enqueue_partition_batch(
    &self,
    selected: Arc<FanOutConsumerLane<T>>,
    batch: &mut VecDeque<T>,
) -> StreamResult<()> {
    if batch.is_empty() {
        return Ok(());
    }

    // The lane lock is held for the whole loop except while parked in
    // `wait_timeout`, which releases it.
    let mut state = selected.state.lock().expect("fan-out lane poisoned");
    loop {
        if !selected.is_active() {
            batch.clear();
            return Err(StreamError::Failed(
                "partition hub selected unknown consumer".to_owned(),
            ));
        }
        if let Some(terminal) = &state.terminal {
            batch.clear();
            return Err(terminal.producer_error());
        }
        let free = self.buffer_size.saturating_sub(state.queued);
        if free > 0 {
            let take_n = free.min(batch.len());
            let chunk = batch.drain(..take_n).collect();
            state.chunks.push_back(chunk);
            state.queued += take_n;
            selected.queued.store(state.queued, Ordering::Release);
            selected.condvar.notify_one();
            if batch.is_empty() {
                return Ok(());
            }
            // Otherwise the lane is now full (take_n == free); fall through and wait.
        }
        let (guard, _) = selected
            .condvar
            .wait_timeout(state, fan_out_wait_timeout())
            .expect("fan-out lane poisoned while waiting");
        state = guard;
    }
}
1375
1376 fn complete(&self) {
1377 let lanes = {
1378 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1379 if registry.terminal.is_none() {
1380 registry.terminal = Some(FanOutTerminal::Completed);
1381 self.publish_snapshot_locked(®istry);
1382 }
1383 registry.consumers.values().cloned().collect::<Vec<_>>()
1384 };
1385
1386 for lane in lanes {
1387 lane.set_terminal(FanOutTerminal::Completed);
1388 }
1389 self.notify_topology_transition();
1390 }
1391
1392 fn fail(&self, error: StreamError) {
1393 let terminal = FanOutTerminal::Failed(error);
1394 let lanes = {
1395 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1396 if registry.terminal.is_none() {
1397 registry.terminal = Some(terminal.clone());
1398 self.publish_snapshot_locked(®istry);
1399 }
1400 registry.consumers.values().cloned().collect::<Vec<_>>()
1401 };
1402
1403 for lane in lanes {
1404 lane.set_terminal(terminal.clone());
1405 }
1406 self.notify_topology_transition();
1407 }
1408
1409 fn publish_snapshot_locked(&self, registry: &FanOutRegistry<T>) {
1410 self.snapshot.store(Arc::new(FanOutSnapshot {
1411 consumers: registry.consumers.values().cloned().collect(),
1412 terminal: registry.terminal.clone(),
1413 }));
1414 }
1415
1416 fn active_lanes(
1417 &self,
1418 snapshot: &FanOutSnapshot<T>,
1419 ) -> SmallVec<[Arc<FanOutConsumerLane<T>>; 16]> {
1420 snapshot
1421 .consumers
1422 .iter()
1423 .filter(|lane| lane.is_active())
1424 .cloned()
1425 .collect()
1426 }
1427
1428 fn partition_batch_limit(&self) -> usize {
1429 let snapshot = self.snapshot.load();
1430 let active = snapshot
1431 .consumers
1432 .iter()
1433 .filter(|lane| lane.is_active())
1434 .count();
1435 match active {
1436 0 | 1 => PARTITION_HUB_SINGLE_CONSUMER_BATCH_LIMIT,
1437 2..=7 => PARTITION_HUB_BATCH_LIMIT,
1438 _ => PARTITION_HUB_WIDE_BATCH_LIMIT,
1439 }
1440 }
1441
1442 fn notify_producer_transition(&self) {
1443 self.producer_epoch.fetch_add(1, Ordering::Release);
1444 self.producer_condvar.notify_one();
1445 }
1446
/// Bumps the topology epoch, then wakes the producer via
/// `notify_producer_transition`.
///
/// The epoch is bumped first so that a woken producer re-reads it.
/// `select_partition` compares it with its cached epoch and refreshes the
/// cached lane list on the next element.
fn notify_topology_transition(&self) {
    self.topology_epoch.fetch_add(1, Ordering::Release);
    self.notify_producer_transition();
}
1451
1452 fn wait_for_producer_transition(&self, observed_epoch: u64) {
1453 let guard = self
1454 .producer_wait
1455 .lock()
1456 .expect("fan-out producer wait poisoned");
1457 if self.producer_epoch.load(Ordering::Acquire) == observed_epoch {
1458 let (_guard, _) = self
1459 .producer_condvar
1460 .wait_timeout(guard, fan_out_wait_timeout())
1461 .expect("fan-out producer wait poisoned while waiting");
1462 }
1463 }
1464}
1465
1466fn push_partition_routed<T>(
1467 routed: &mut PartitionRoutedBatches<T>,
1468 lane: Arc<FanOutConsumerLane<T>>,
1469 item: T,
1470) {
1471 for (existing, batch) in routed.iter_mut() {
1472 if existing.id() == lane.id() {
1473 batch.push_back(item);
1474 return;
1475 }
1476 }
1477
1478 let mut batch = VecDeque::new();
1479 batch.push_back(item);
1480 routed.push((lane, batch));
1481}
1482
/// Feeds an upstream stream into a fan-out hub (driven by
/// `FanOutProducer::run`).
struct FanOutProducer<T> {
    // Upstream elements to distribute.
    input: BoxStream<T>,
    // Hub state shared with the consumer streams.
    state: Arc<FanOutHubShared<T>>,
}
1487
1488impl<T> FanOutProducer<T> {
1489 fn new(input: BoxStream<T>, state: Arc<FanOutHubShared<T>>) -> Self {
1490 Self { input, state }
1491 }
1492}
1493
1494impl<T: Send + 'static + Clone> FanOutProducer<T> {
1495 fn run(
1496 mut self,
1497 cancelled: Arc<std::sync::atomic::AtomicBool>,
1498 hints: SourceRuntimeHints,
1499 ) -> StreamResult<NotUsed> {
1500 struct ProducerDropGuard<T> {
1501 state: Arc<FanOutHubShared<T>>,
1502 disarmed: bool,
1503 }
1504
1505 impl<T> ProducerDropGuard<T> {
1506 fn new(state: Arc<FanOutHubShared<T>>) -> Self {
1507 Self {
1508 state,
1509 disarmed: false,
1510 }
1511 }
1512
1513 fn disarm(&mut self) {
1514 self.disarmed = true;
1515 }
1516 }
1517
1518 impl<T> Drop for ProducerDropGuard<T> {
1519 fn drop(&mut self) {
1520 if !self.disarmed && std::thread::panicking() {
1521 self.state.fail(StreamError::Failed(
1522 "fan-out hub producer panicked".to_owned(),
1523 ));
1524 }
1525 }
1526 }
1527
1528 let mut guard = ProducerDropGuard::new(Arc::clone(&self.state));
1529 match self.state.mode {
1530 FanOutMode::Broadcast => {
1531 let batch_producer = hints.inline_micro_max_success_items.is_some();
1532 let mut batch = VecDeque::new();
1533 loop {
1534 if cancelled.load(Ordering::SeqCst) {
1535 self.state.fail(StreamError::Cancelled);
1536 guard.disarm();
1537 return Err(StreamError::Cancelled);
1538 }
1539
1540 if batch_producer {
1541 let capacity = self
1542 .state
1543 .wait_for_broadcast_capacity(BROADCAST_HUB_BATCH_LIMIT)?;
1544 batch.clear();
1545 batch.reserve(capacity.saturating_sub(batch.capacity()));
1546 let mut terminal = None;
1547 for _ in 0..capacity {
1548 if cancelled.load(Ordering::SeqCst) {
1549 terminal = Some(Err(StreamError::Cancelled));
1550 break;
1551 }
1552 match self.input.next() {
1553 Some(Ok(item)) => batch.push_back(item),
1554 Some(Err(error)) => {
1555 terminal = Some(Err(error));
1556 break;
1557 }
1558 None => {
1559 terminal = Some(Ok(()));
1560 break;
1561 }
1562 }
1563 }
1564 if !batch.is_empty() {
1565 self.state.push_broadcast_batch(&mut batch)?;
1566 }
1567 match terminal {
1568 Some(Err(StreamError::Cancelled)) => {
1569 self.state.fail(StreamError::Cancelled);
1570 guard.disarm();
1571 return Err(StreamError::Cancelled);
1572 }
1573 Some(Err(error)) => {
1574 self.state.fail(error.clone());
1575 guard.disarm();
1576 return Err(error);
1577 }
1578 Some(Ok(())) => {
1579 self.state.complete();
1580 guard.disarm();
1581 return Ok(NotUsed);
1582 }
1583 None => continue,
1584 }
1585 }
1586
1587 match self.input.next() {
1588 Some(Ok(item)) => self.state.push_broadcast_item(item)?,
1589 Some(Err(error)) => {
1590 self.state.fail(error.clone());
1591 guard.disarm();
1592 return Err(error);
1593 }
1594 None => {
1595 self.state.complete();
1596 guard.disarm();
1597 return Ok(NotUsed);
1598 }
1599 }
1600 }
1601 }
1602 FanOutMode::Partition => {
1603 let batch_producer = hints.inline_micro_max_success_items.is_some();
1604 let mut topology_cache = PartitionTopologyCache::new();
1605 let mut routed = PartitionRoutedBatches::new();
1606 loop {
1607 if cancelled.load(Ordering::SeqCst) {
1608 self.state.fail(StreamError::Cancelled);
1609 guard.disarm();
1610 return Err(StreamError::Cancelled);
1611 }
1612
1613 if batch_producer {
1614 for (_, batch) in routed.iter_mut() {
1615 batch.clear();
1616 }
1617 let partition_batch_limit = self.state.partition_batch_limit();
1618 let mut terminal = None;
1619 for _ in 0..partition_batch_limit {
1620 if cancelled.load(Ordering::SeqCst) {
1621 terminal = Some(Err(StreamError::Cancelled));
1622 break;
1623 }
1624 match self.input.next() {
1625 Some(Ok(item)) => {
1626 if let Some(selected) =
1627 self.state.select_partition(&item, &mut topology_cache)?
1628 {
1629 push_partition_routed(&mut routed, selected, item);
1630 }
1631 }
1632 Some(Err(error)) => {
1633 terminal = Some(Err(error));
1634 break;
1635 }
1636 None => {
1637 terminal = Some(Ok(()));
1638 break;
1639 }
1640 }
1641 }
1642 for (selected, batch) in routed.iter_mut() {
1643 if !batch.is_empty() {
1644 self.state
1645 .enqueue_partition_batch(Arc::clone(selected), batch)?;
1646 }
1647 }
1648 match terminal {
1649 Some(Err(StreamError::Cancelled)) => {
1650 self.state.fail(StreamError::Cancelled);
1651 guard.disarm();
1652 return Err(StreamError::Cancelled);
1653 }
1654 Some(Err(error)) => {
1655 self.state.fail(error.clone());
1656 guard.disarm();
1657 return Err(error);
1658 }
1659 Some(Ok(())) => {
1660 self.state.complete();
1661 guard.disarm();
1662 return Ok(NotUsed);
1663 }
1664 None => continue,
1665 }
1666 }
1667
1668 match self.input.next() {
1669 Some(Ok(item)) => {
1670 if let Some(selected) =
1671 self.state.select_partition(&item, &mut topology_cache)?
1672 {
1673 self.state.enqueue_partition(selected, item)?;
1674 }
1675 }
1676 Some(Err(error)) => {
1677 self.state.fail(error.clone());
1678 guard.disarm();
1679 return Err(error);
1680 }
1681 None => {
1682 self.state.complete();
1683 guard.disarm();
1684 return Ok(NotUsed);
1685 }
1686 }
1687 }
1688 }
1689 }
1690 }
1691}
1692
/// One consumer's view of a fan-out hub: the shared hub state plus the lane
/// this consumer drains.
struct FanOutConsumerStream<T> {
    // Shared hub state; used to wake the producer and to unregister this lane.
    state: Arc<FanOutHubShared<T>>,
    // The lane this consumer pops chunks from.
    lane: Arc<FanOutConsumerLane<T>>,
    // Remaining elements of the most recently popped chunk, served without locking.
    local: Option<std::vec::IntoIter<T>>,
    // Set once `remove_consumer` has been called for this lane.
    detached: bool,
}
1699
impl<T: Clone + Send + 'static> Iterator for FanOutConsumerStream<T> {
    type Item = StreamResult<T>;

    /// Yields the next element delivered to this consumer's lane.
    ///
    /// Elements of the most recently popped chunk are served from `local`
    /// without locking. Otherwise the lane is locked and one of the following
    /// happens:
    /// * a `Failed` terminal is returned as `Some(Err(_))`, ahead of any chunks
    ///   still queued. It is returned again on every later call, because the
    ///   terminal stays set.
    /// * the front chunk is popped; its first element is returned and the rest
    ///   is kept in `local`.
    /// * a `Completed` terminal ends the stream (`None`), but only once no
    ///   chunks remain.
    /// * otherwise the call waits on the lane condvar (bounded by
    ///   `fan_out_wait_timeout()`) and re-checks.
    fn next(&mut self) -> Option<Self::Item> {
        // If the lane's `failed` flag is set, discard the locally buffered
        // remainder so the failure is reported without delivering it.
        if self.lane.failed.load(Ordering::Acquire) {
            self.local = None;
        } else if let Some(local) = &mut self.local {
            if let Some(item) = local.next() {
                return Some(Ok(item));
            }
            self.local = None;
        }

        let mut state = self.lane.state.lock().expect("fan-out lane poisoned");
        loop {
            if let Some(FanOutTerminal::Failed(error)) = state.terminal.clone() {
                return Some(Err(error));
            }
            if let Some(batch) = state.chunks.pop_front() {
                state.queued = state.queued.saturating_sub(batch.len());
                self.lane.queued.store(state.queued, Ordering::Release);
                // In partition mode the producer may be blocked in
                // `enqueue_partition_batch` waiting on this same condvar.
                if matches!(self.state.mode, FanOutMode::Partition) {
                    self.lane.condvar.notify_one();
                }
                // In every mode, bump the producer epoch so a producer parked
                // in `wait_for_producer_transition` re-evaluates capacity.
                self.state.notify_producer_transition();
                drop(state);
                let mut batch = batch.into_iter();
                // Producers only ever push chunks of at least one element.
                let first = batch.next().expect("fan-out drained non-empty lane batch");
                if batch.len() > 0 {
                    self.local = Some(batch);
                }
                return Some(Ok(first));
            }
            if matches!(state.terminal, Some(FanOutTerminal::Completed)) {
                return None;
            }
            let (guard, _) = self
                .lane
                .condvar
                .wait_timeout(state, fan_out_wait_timeout())
                .expect("fan-out lane poisoned while waiting");
            state = guard;
        }
    }
}
1745
1746impl<T> Drop for FanOutConsumerStream<T> {
1747 fn drop(&mut self) {
1748 if !self.detached {
1749 self.state.remove_consumer(self.lane.id());
1750 self.detached = true;
1751 }
1752 }
1753}
1754
#[cfg(test)]
mod tests {
    // Behavioural tests for the merge, broadcast and partition hubs. Several
    // of them depend on timing: `expect_no_message` windows and the polling
    // helper `wait_for_partition_calls`. The order of requests and sends in
    // each test is therefore significant.
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Sink, Source};
    use std::{
        panic::{self, AssertUnwindSafe},
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        thread,
        time::{Duration, Instant},
    };

    // Two finite producers attach to one merge hub. After
    // `drain_and_complete`, the collected result contains every element; the
    // order across producers is not asserted, hence the sort.
    #[test]
    fn merge_hub_accepts_dynamic_producers_and_drains() {
        let materializer = Materializer::new();
        let ((hub_sink, control), completion) = MergeHub::source_with_draining::<i32>(4)
            .to_mat(Sink::collect(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3]))
            .expect("first producer attaches");
        hub_sink
            .run_with(Source::from_iter([4, 5]))
            .expect("second producer attaches");
        control.drain_and_complete();
        let mut result = completion.wait().expect("merge hub completes");
        result.sort_unstable();
        assert_eq!(result, vec![1, 2, 3, 4, 5]);
    }

    // A fold terminal sees all 1024 elements of a single producer.
    #[test]
    fn merge_hub_direct_terminal_fold_counts_single_finite_producer() {
        let materializer = Materializer::new();
        let ((hub_sink, control), completion) = MergeHub::source_with_draining::<u64>(16)
            .to_mat(Sink::fold(0_u64, |acc, _| acc + 1), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .run_with(Source::from_iter(0_u64..1024))
            .expect("producer attaches");
        control.drain_and_complete();

        assert_eq!(completion.wait().expect("merge hub completes"), 1024);
    }

    // An error from the terminal fold surfaces as the completion result, and
    // afterwards attaching another producer fails.
    #[test]
    fn merge_hub_direct_terminal_fold_result_error_closes_source() {
        let materializer = Materializer::new();
        let ((hub_sink, _control), completion) = MergeHub::source_with_draining::<i32>(8)
            .to_mat(
                Sink::fold_result(0_i32, |acc, item| {
                    if item == 3 {
                        Err(StreamError::Failed("terminal failed".to_owned()))
                    } else {
                        Ok(acc + item)
                    }
                }),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3, 4]))
            .expect("producer attaches");

        assert_eq!(
            completion.wait(),
            Err(StreamError::Failed("terminal failed".to_owned()))
        );
        assert!(hub_sink.run_with(Source::single(5)).is_err());
    }

    // A panic inside the terminal fold completes with `AbruptTermination`,
    // and afterwards attaching another producer fails.
    #[test]
    fn merge_hub_direct_terminal_panic_closes_source() {
        let materializer = Materializer::new();
        let ((hub_sink, _control), completion) = MergeHub::source_with_draining::<i32>(8)
            .to_mat(
                Sink::fold(0_i32, |acc, item| {
                    if item == 3 {
                        panic!("terminal failed");
                    }
                    acc + item
                }),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3, 4]))
            .expect("producer attaches");

        assert_eq!(completion.wait(), Err(StreamError::AbruptTermination));
        assert!(hub_sink.run_with(Source::single(5)).is_err());
    }

    // One producer's error fails the downstream consumer, even though
    // another producer is still attached and healthy.
    #[test]
    fn merge_hub_producer_error_fails_downstream_consumer() {
        let materializer = Materializer::new();
        let (hub_sink, sink) = MergeHub::source::<i32>(4)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        let producer_ok = TestSource::probe::<i32>()
            .to_mat(hub_sink.clone(), Keep::left)
            .run_with_materializer(&materializer)
            .expect("successful producer attaches");
        let producer_fail = TestSource::probe::<i32>()
            .to_mat(hub_sink, Keep::left)
            .run_with_materializer(&materializer)
            .expect("failing producer attaches");

        sink.request(1);
        assert_eq!(producer_ok.expect_request(), 1);
        producer_ok.send_next(1);
        sink.assert_next(1);

        producer_fail.send_error(StreamError::Failed("producer failed".to_owned()));
        sink.request(1);
        assert_eq!(
            sink.expect_error(),
            StreamError::Failed("producer failed".to_owned())
        );
    }

    // With `BroadcastHub::sink(1)`, element 2 reaches sink_a but is held back
    // for sink_b. The publisher is only asked for more once the slower
    // consumer (sink_b) has taken element 2.
    #[test]
    fn broadcast_hub_backpressures_slowest_consumer() {
        let materializer = Materializer::new();
        let (publisher, hub_source) = TestSource::probe::<i32>()
            .to_mat(BroadcastHub::sink(1), Keep::both)
            .run_with_materializer(&materializer)
            .expect("broadcast hub materializes");

        let sink_a = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        let sink_b = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("second consumer materializes");

        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);
        sink_b.assert_next(1);

        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_a.assert_next(2);
        sink_b.expect_no_message(Duration::from_millis(250));

        sink_a.request(1);
        sink_a.expect_no_message(Duration::from_millis(250));

        sink_b.request(1);
        sink_b.assert_next(2);
        assert_eq!(publisher.expect_request(), 1);
    }

    // A consumer attached after element 1 was delivered only sees element 2
    // onwards. Both consumers then observe completion.
    #[test]
    fn broadcast_hub_late_consumer_sees_only_late_elements() {
        let materializer = Materializer::new();
        let (publisher, hub_source) = TestSource::probe::<i32>()
            .to_mat(BroadcastHub::sink(2), Keep::both)
            .run_with_materializer(&materializer)
            .expect("broadcast hub materializes");

        let sink_a = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);

        let sink_b = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("late consumer materializes");
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_a.assert_next(2);
        sink_b.assert_next(2);

        publisher.send_complete();
        sink_a.request(1);
        sink_b.request(1);
        sink_a.expect_complete();
        sink_b.expect_complete();
    }

    // The partitioner sends even items to consumer index 0 and odd items to
    // index 1, so each sink receives exactly its share.
    #[test]
    fn partition_hub_routes_elements_to_selected_consumers() {
        let materializer = Materializer::new();
        let hub = Source::from_iter([0, 1, 2, 3])
            .run_with_materializer(
                PartitionHub::sink(
                    |info, item| {
                        let idx = (*item as usize) % info.size();
                        info.consumer_id_by_idx(idx) as isize
                    },
                    2,
                    8,
                ),
                &materializer,
            )
            .expect("partition hub materializes");

        let sink_a = hub
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        let sink_b = hub
            .source()
            .run_with(TestSink::probe())
            .expect("second consumer materializes");

        sink_a.request(2);
        sink_b.request(2);
        sink_a.assert_next_n([0, 2]);
        sink_b.assert_next_n([1, 3]);
    }

    // The partitioner is called exactly once per element, even when the
    // second element has to wait for room in the lane. The counter stays at
    // 2 after both elements are delivered.
    #[test]
    fn partition_hub_evaluates_stateful_partitioner_once_per_blocked_element() {
        let materializer = Materializer::new();
        let partition_calls = Arc::new(AtomicUsize::new(0));
        let partition_calls_for_hub = Arc::clone(&partition_calls);

        let (publisher, hub) = TestSource::probe::<i32>()
            .to_mat(
                PartitionHub::sink(
                    move |info, _item| {
                        partition_calls_for_hub.fetch_add(1, Ordering::SeqCst);
                        info.consumer_id_by_idx(0) as isize
                    },
                    1,
                    1,
                ),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("partition hub materializes");

        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");

        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        wait_for_partition_calls(&partition_calls, 1);

        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink.expect_no_message(Duration::from_millis(250));
        wait_for_partition_calls(&partition_calls, 2);

        sink.request(1);
        sink.assert_next(1);
        sink.request(1);
        sink.assert_next(2);
        assert_eq!(partition_calls.load(Ordering::SeqCst), 2);
    }

    // After the only consumer is dropped and replaced, the next element is
    // routed to the replacement rather than a stale cached lane.
    #[test]
    fn partition_hub_invalidates_cached_topology_on_consumer_churn() {
        let materializer = Materializer::new();
        let (publisher, hub) = TestSource::probe::<i32>()
            .to_mat(
                PartitionHub::sink(
                    |info, _item| info.consumer_id_by_idx(info.size() - 1) as isize,
                    1,
                    8,
                ),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("partition hub materializes");

        let sink_a = hub
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);
        drop(sink_a);

        let sink_b = hub
            .source()
            .run_with(TestSink::probe())
            .expect("replacement consumer materializes");
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_b.assert_next(2);

        publisher.send_complete();
        sink_b.request(1);
        sink_b.expect_complete();
    }

    // Elements already buffered for a consumer are delivered before the
    // completion signal.
    #[test]
    fn broadcast_hub_drains_local_chunk_before_completion() {
        let materializer = Materializer::new();
        let hub = Source::from_iter([1, 2, 3])
            .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
            .expect("broadcast hub materializes");

        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");

        sink.request(1);
        sink.assert_next(1);
        sink.request(3);
        sink.assert_next_n([2, 3]);
        sink.expect_complete();
    }

    // An upstream panic after the first element surfaces to the consumer as
    // "fan-out hub producer panicked". Element 1 may or may not be delivered
    // before the failure, so both outcomes are accepted.
    #[test]
    fn broadcast_hub_panicking_upstream_fails_consumers() {
        let materializer = Materializer::new();
        let hub = Source::from_fn_iter(|| {
            let mut yielded = false;
            std::iter::from_fn(move || {
                if !yielded {
                    yielded = true;
                    Some(1)
                } else {
                    panic!("boom");
                }
            })
        })
        .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
        .expect("broadcast hub materializes");

        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");

        sink.request(2);
        match panic::catch_unwind(AssertUnwindSafe(|| sink.expect_error())) {
            Ok(error) => assert_eq!(
                error,
                StreamError::Failed("fan-out hub producer panicked".to_owned())
            ),
            Err(payload) => {
                assert_eq!(
                    panic_message(payload),
                    "expected stream error, got next element"
                );
                sink.request(1);
                assert_eq!(
                    sink.expect_error(),
                    StreamError::Failed("fan-out hub producer panicked".to_owned())
                );
            }
        }
    }

    // Extracts the message from a panic payload (`String` or `&'static str`).
    fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
        match payload.downcast::<String>() {
            Ok(message) => *message,
            Err(payload) => match payload.downcast::<&'static str>() {
                Ok(message) => (*message).to_owned(),
                Err(_) => "<non-string panic payload>".to_owned(),
            },
        }
    }

    // Polls `counter` every 5 ms for up to one second until it equals
    // `expected`, then asserts the final value.
    fn wait_for_partition_calls(counter: &AtomicUsize, expected: usize) {
        let deadline = Instant::now() + Duration::from_secs(1);
        while Instant::now() < deadline {
            if counter.load(Ordering::SeqCst) == expected {
                return;
            }
            thread::sleep(Duration::from_millis(5));
        }
        assert_eq!(counter.load(Ordering::SeqCst), expected);
    }
}