1use std::{
2 collections::{BTreeMap, VecDeque},
3 fmt,
4 sync::{
5 Arc, Condvar, Mutex, MutexGuard,
6 atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
7 },
8 time::Duration,
9};
10
11use arc_swap::ArcSwap;
12use smallvec::SmallVec;
13
14use crate::stream::{
15 BoxStream, Materializer, NotUsed, Sink, Source, SourceRuntimeHints, StreamCompletion,
16 TerminalSourceHookDyn, TerminalSourceStatus,
17};
18use crate::{StreamError, StreamResult};
19
/// Shared, thread-safe routing callback: given a [`PartitionConsumerInfo`] and
/// a reference to an element, it returns an `isize`.
///
/// NOTE(review): how the returned value is interpreted (consumer id, "drop"
/// sentinel, ...) is decided by code that is not visible in this part of the
/// file — confirm before relying on it.
type Partitioner<T> = Arc<dyn Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync>;

/// Cap on the number of elements the merge hub moves in one drain step, both
/// when emptying its shared queue and when pulling from a direct producer.
const MERGE_HUB_BATCH_LIMIT: usize = 256;
/// Batch cap for the broadcast hub.
/// NOTE(review): no use is visible in this part of the file — confirm at the use site.
const BROADCAST_HUB_BATCH_LIMIT: usize = 256;
/// Tighter cap applied by `wait_for_broadcast_capacity` when exactly one
/// consumer lane is active.
const BROADCAST_HUB_SINGLE_CONSUMER_BATCH_LIMIT: usize = 64;
/// Batch cap for the partition hub.
/// NOTE(review): no use is visible in this part of the file — confirm at the use site.
const PARTITION_HUB_BATCH_LIMIT: usize = 1024;
/// Presumably the partition-hub cap for wide fan-outs.
/// NOTE(review): no use is visible in this part of the file — confirm at the use site.
const PARTITION_HUB_WIDE_BATCH_LIMIT: usize = 2048;
/// Presumably the partition-hub cap when a single consumer is attached.
/// NOTE(review): no use is visible in this part of the file — confirm at the use site.
const PARTITION_HUB_SINGLE_CONSUMER_BATCH_LIMIT: usize = 256;
/// Divisor used to pre-size a consumer lane's chunk queue
/// (`buffer_size / FAN_OUT_CONSUMER_BATCH_LIMIT + 1` slots).
const FAN_OUT_CONSUMER_BATCH_LIMIT: usize = 256;
29
/// Upper bound for a single timed condvar wait, so that a waiting loop
/// re-checks shutdown and cancellation at least once per millisecond.
fn fan_out_wait_timeout() -> Duration {
    const WAIT_TIMEOUT_MILLIS: u64 = 1;
    Duration::from_millis(WAIT_TIMEOUT_MILLIS)
}
33
/// Cloneable handle used to put a merge hub into draining mode.
///
/// Once draining has been requested, the hub's registration paths reject new
/// producers; see `drain_and_complete`.
#[derive(Clone)]
pub struct MergeHubDrainingControl {
    /// Draining flag (plus its condvar) shared with the hub.
    state: Arc<MergeHubState>,
    /// Callback invoked by `drain_and_complete` after the flag has been set.
    on_drain: Arc<dyn Fn() + Send + Sync>,
}
39
40impl fmt::Debug for MergeHubDrainingControl {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("MergeHubDrainingControl").finish()
43 }
44}
45
46impl MergeHubDrainingControl {
47 pub fn drain_and_complete(&self) {
48 let mut state = self.state.lock();
49 state.draining = true;
50 self.state.condvar.notify_all();
51 drop(state);
52 (self.on_drain)();
53 }
54}
55
56pub struct MergeHub;
57
58impl MergeHub {
59 #[must_use]
61 pub fn source<T: Send + 'static>(
62 per_producer_buffer_size: usize,
63 ) -> Source<T, Sink<T, NotUsed>> {
64 Self::source_with_draining(per_producer_buffer_size)
65 .map_materialized_value(|(sink, _)| sink)
66 }
67
68 #[must_use]
71 pub fn source_with_draining<T: Send + 'static>(
72 per_producer_buffer_size: usize,
73 ) -> Source<T, (Sink<T, NotUsed>, MergeHubDrainingControl)> {
74 assert!(
75 per_producer_buffer_size > 0,
76 "MergeHub per_producer_buffer_size must be greater than zero"
77 );
78 Source::from_terminal_direct_materialized_factory(
79 move |_| {
80 let (state, sink, control) =
81 new_merge_hub_materialization(per_producer_buffer_size);
82 let source = Box::new(MergeHubSourceStream {
83 state: Arc::clone(&state),
84 local: VecDeque::new(),
85 prefer_direct: false,
86 }) as BoxStream<T>;
87 Ok((source, (sink, control)))
88 },
89 move |_| {
90 let (state, sink, control) =
91 new_merge_hub_materialization(per_producer_buffer_size);
92 let hook = Arc::new(MergeHubTerminalHook {
93 state,
94 prefer_direct: AtomicBool::new(false),
95 }) as Arc<dyn TerminalSourceHookDyn<T>>;
96 Ok((hook, (sink, control)))
97 },
98 )
99 }
100}
101
102fn new_merge_hub_materialization<T: Send + 'static>(
103 per_producer_buffer_size: usize,
104) -> (
105 Arc<MergeHubShared<T>>,
106 Sink<T, NotUsed>,
107 MergeHubDrainingControl,
108) {
109 let state = Arc::new(MergeHubShared::<T>::new(per_producer_buffer_size));
110 let sink = merge_hub_sink(Arc::clone(&state));
111 let control = MergeHubDrainingControl {
112 state: Arc::clone(&state.state),
113 on_drain: Arc::new({
114 let state = Arc::clone(&state);
115 move || state.finish_if_draining()
116 }),
117 };
118 (state, sink, control)
119}
120
121pub struct BroadcastHub;
122
123impl BroadcastHub {
124 #[must_use]
132 pub fn sink<T: Clone + Send + 'static>(
133 buffer_size: usize,
134 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
135 Self::sink_starting_after(0, buffer_size)
136 }
137
138 #[must_use]
140 pub fn sink_starting_after<T: Clone + Send + 'static>(
141 start_after_nr_of_consumers: usize,
142 buffer_size: usize,
143 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
144 assert!(
145 buffer_size > 0,
146 "BroadcastHub buffer_size must be greater than zero"
147 );
148 Sink::from_hinted_runner(move |input, materializer, hints| {
149 let state = Arc::new(FanOutHubShared::new(
150 FanOutMode::Broadcast,
151 start_after_nr_of_consumers,
152 buffer_size,
153 None::<Partitioner<T>>,
154 ));
155 let source = BroadcastHubConsumerSource {
156 state: Arc::clone(&state),
157 completion: Arc::new(Mutex::new(None)),
158 };
159 let completion = materializer.spawn_stream(move |cancelled| {
160 FanOutProducer::new(input, state).run(cancelled, hints)
161 });
162 source.attach_completion(completion);
163 Ok(source)
164 })
165 }
166}
167
168pub struct PartitionHub;
169
170impl PartitionHub {
171 #[must_use]
179 pub fn sink<T: Clone + Send + 'static, F>(
180 partitioner: F,
181 start_after_nr_of_consumers: usize,
182 buffer_size: usize,
183 ) -> Sink<T, PartitionHubConsumerSource<T>>
184 where
185 F: Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync + 'static,
186 {
187 assert!(
188 buffer_size > 0,
189 "PartitionHub buffer_size must be greater than zero"
190 );
191 let partitioner = Arc::new(partitioner);
192 Sink::from_hinted_runner(move |input, materializer, hints| {
193 let partitioner = Arc::clone(&partitioner);
194 let state = Arc::new(FanOutHubShared::new(
195 FanOutMode::Partition,
196 start_after_nr_of_consumers,
197 buffer_size,
198 Some(partitioner),
199 ));
200 let source = PartitionHubConsumerSource {
201 state: Arc::clone(&state),
202 completion: Arc::new(Mutex::new(None)),
203 };
204 let completion = materializer.spawn_stream(move |cancelled| {
205 FanOutProducer::new(input, state).run(cancelled, hints)
206 });
207 source.attach_completion(completion);
208 Ok(source)
209 })
210 }
211}
212
/// Materialized value of a broadcast-hub sink; `source()` creates a new
/// consumer stream attached to the hub each time it is materialized.
#[derive(Clone)]
pub struct BroadcastHubConsumerSource<T> {
    /// Hub state shared with the producer loop and every consumer.
    state: Arc<FanOutHubShared<T>>,
    /// Completion handle of the spawned producer loop; `None` until
    /// `attach_completion` stores it.
    completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
}
218
219impl<T: Clone + Send + 'static> BroadcastHubConsumerSource<T> {
220 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
221 *self
222 .completion
223 .lock()
224 .expect("broadcast hub completion poisoned") = Some(completion);
225 }
226
227 #[must_use]
228 pub fn source(&self) -> Source<T, NotUsed> {
229 let state = Arc::clone(&self.state);
230 Source::from_materialized_factory(move |_| {
231 let lane = state.register_consumer();
232 let stream = Box::new(FanOutConsumerStream {
233 state: Arc::clone(&state),
234 lane,
235 local: None,
236 detached: false,
237 }) as BoxStream<T>;
238 Ok((stream, NotUsed))
239 })
240 }
241}
242
243impl<T: Clone + Send + 'static> fmt::Debug for BroadcastHubConsumerSource<T> {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 f.debug_struct("BroadcastHubConsumerSource").finish()
246 }
247}
248
/// Materialized value of a partition-hub sink; `source()` creates a new
/// consumer stream attached to the hub each time it is materialized.
#[derive(Clone)]
pub struct PartitionHubConsumerSource<T> {
    /// Hub state shared with the producer loop and every consumer.
    state: Arc<FanOutHubShared<T>>,
    /// Completion handle of the spawned producer loop; `None` until
    /// `attach_completion` stores it.
    completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
}
254
255impl<T: Clone + Send + 'static> PartitionHubConsumerSource<T> {
256 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
257 *self
258 .completion
259 .lock()
260 .expect("partition hub completion poisoned") = Some(completion);
261 }
262
263 #[must_use]
264 pub fn source(&self) -> Source<T, NotUsed> {
265 let state = Arc::clone(&self.state);
266 Source::from_materialized_factory(move |_| {
267 let lane = state.register_consumer();
268 let stream = Box::new(FanOutConsumerStream {
269 state: Arc::clone(&state),
270 lane,
271 local: None,
272 detached: false,
273 }) as BoxStream<T>;
274 Ok((stream, NotUsed))
275 })
276 }
277}
278
279impl<T: Clone + Send + 'static> fmt::Debug for PartitionHubConsumerSource<T> {
280 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
281 f.debug_struct("PartitionHubConsumerSource").finish()
282 }
283}
284
/// Snapshot of the consumers known to a partition hub, handed to the
/// partitioner callback.
#[derive(Clone, Debug)]
pub struct PartitionConsumerInfo {
    /// Consumer ids, in the order they were supplied when the snapshot was built.
    consumer_ids: SmallVec<[u64; 16]>,
    /// `(consumer id, queued element count)` pairs sampled when the snapshot
    /// was built.
    queue_sizes: SmallVec<[(u64, usize); 16]>,
}
290
291impl PartitionConsumerInfo {
292 #[must_use]
293 pub fn size(&self) -> usize {
294 self.consumer_ids.len()
295 }
296
297 #[must_use]
298 pub fn consumer_ids(&self) -> &[u64] {
299 &self.consumer_ids
300 }
301
302 #[must_use]
303 pub fn consumer_id_by_idx(&self, idx: usize) -> u64 {
304 self.consumer_ids[idx]
305 }
306
307 #[must_use]
308 pub fn queue_size(&self, consumer_id: u64) -> usize {
309 self.queue_sizes
310 .iter()
311 .find_map(|(id, size)| (*id == consumer_id).then_some(*size))
312 .unwrap_or(0)
313 }
314}
315
316fn merge_hub_sink<T: Send + 'static>(state: Arc<MergeHubShared<T>>) -> Sink<T, NotUsed> {
317 Sink::from_raw_hinted_runner(move |input, materializer, hints| {
318 if hints.inline_micro_max_success_items.is_some() {
319 state.register_direct_producer(input)?;
320 return Ok(NotUsed);
321 }
322
323 let input = materializer.checked_stream(input, None);
324 let producer_id = state.register_producer()?;
325 let hub = Arc::clone(&state);
326 let completion = materializer.spawn_stream(move |cancelled| {
327 let mut input = input;
328 loop {
329 if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
330 hub.fail(StreamError::Cancelled);
331 hub.deregister_producer(producer_id);
332 return Err(StreamError::Cancelled);
333 }
334
335 match input.next() {
336 Some(Ok(item)) => hub.push_item(producer_id, item)?,
337 Some(Err(error)) => {
338 hub.fail(error.clone());
339 hub.deregister_producer(producer_id);
340 return Err(error);
341 }
342 None => {
343 hub.deregister_producer(producer_id);
344 return Ok(NotUsed);
345 }
346 }
347 }
348 });
349 state.store_producer_completion(completion);
350 Ok(NotUsed)
351 })
352}
353
/// State shared between the consuming side of a merge hub and all of its
/// producers.
struct MergeHubShared<T> {
    /// Draining flag, also held by `MergeHubDrainingControl`.
    state: Arc<MergeHubState>,
    /// Queue, producer bookkeeping and terminal flags, guarded by one mutex.
    shared: Mutex<MergeHubInner<T>>,
    /// Paired with `shared`: producers wait on it for buffer space and the
    /// consuming side waits on it for data or a terminal state.
    condvar: Condvar,
    /// Set by `fail` once an error is recorded; lets `failed_error` skip the
    /// mutex when nothing has failed.
    failed: AtomicBool,
}
360
/// Flag state shared between a merge hub and its draining control.
#[derive(Debug)]
struct MergeHubState {
    /// The flags themselves; access goes through `MergeHubState::lock`.
    inner: Mutex<MergeHubFlags>,
    /// Notified by `drain_and_complete` after the draining flag is set.
    /// NOTE(review): no waiter on this condvar is visible in this part of the file.
    condvar: Condvar,
}
366
/// Flags guarded by `MergeHubState::inner`.
#[derive(Debug, Default)]
struct MergeHubFlags {
    /// `true` once draining has been requested; registration of new producers
    /// is refused from then on.
    draining: bool,
}
371
372impl MergeHubState {
373 fn lock(&self) -> MutexGuard<'_, MergeHubFlags> {
374 self.inner.lock().expect("merge hub flags poisoned")
375 }
376}
377
/// Mutable merge-hub state guarded by `MergeHubShared::shared`.
struct MergeHubInner<T> {
    /// Buffered elements tagged with the id of the producer that pushed them.
    queue: VecDeque<(u64, T)>,
    /// Producers whose input streams are pulled directly by the consuming side.
    direct_producers: VecDeque<MergeHubDirectProducer<T>>,
    /// Per-producer count of elements currently sitting in `queue`.
    queued_per_producer: BTreeMap<u64, usize>,
    /// Completion handles of spawned producer streams; finished ones are pruned.
    producer_completions: Vec<StreamCompletion<NotUsed>>,
    /// Number of producers registered and not yet deregistered.
    active_producers: usize,
    /// Id handed to the next producer that registers.
    next_producer_id: u64,
    /// Set when the consuming side has gone away or stopped; pushes are then
    /// rejected with `StreamError::Cancelled`.
    source_closed: bool,
    /// Set once the hub is draining and no active producers remain.
    completed: bool,
    /// First error recorded by `fail`, if any.
    failed: Option<StreamError>,
    /// Maximum number of queued elements allowed per producer.
    per_producer_buffer_size: usize,
}
390
/// A producer whose input stream is polled by the hub's consuming side instead
/// of by a spawned producer stream.
struct MergeHubDirectProducer<T> {
    /// Producer id assigned at registration.
    id: u64,
    /// The producer's input stream.
    input: BoxStream<T>,
}
395
/// Batch-draining view of a merge hub that implements `TerminalSourceHookDyn`.
struct MergeHubTerminalHook<T> {
    /// The hub being drained.
    state: Arc<MergeHubShared<T>>,
    /// Fairness toggle: set after a queue drain and cleared after a
    /// direct-producer drain, so the two kinds of producers alternate.
    prefer_direct: AtomicBool,
}
400
/// Outcome of pulling one batch from a direct producer.
enum MergeHubProducerTerminal {
    /// The drain limit was reached before the input ended; the producer can be
    /// pulled again.
    Active,
    /// The input stream returned `None`.
    Completed,
    /// The input stream yielded an error.
    Failed(StreamError),
}
406
407impl<T: Send + 'static> TerminalSourceHookDyn<T> for MergeHubTerminalHook<T> {
408 fn drain_terminal_batch(
409 &self,
410 materializer: &Materializer,
411 cancelled: &Arc<AtomicBool>,
412 batch: &mut Vec<T>,
413 ) -> StreamResult<TerminalSourceStatus> {
414 self.state
415 .drain_terminal_batch(materializer, cancelled, &self.prefer_direct, batch)
416 }
417
418 fn cancel_terminal(&self) {
419 self.state.close_source();
420 }
421}
422
423impl<T> MergeHubShared<T> {
424 fn new(per_producer_buffer_size: usize) -> Self {
425 Self {
426 state: Arc::new(MergeHubState {
427 inner: Mutex::new(MergeHubFlags::default()),
428 condvar: Condvar::new(),
429 }),
430 shared: Mutex::new(MergeHubInner {
431 queue: VecDeque::new(),
432 direct_producers: VecDeque::new(),
433 queued_per_producer: BTreeMap::new(),
434 producer_completions: Vec::new(),
435 active_producers: 0,
436 next_producer_id: 0,
437 source_closed: false,
438 completed: false,
439 failed: None,
440 per_producer_buffer_size,
441 }),
442 condvar: Condvar::new(),
443 failed: AtomicBool::new(false),
444 }
445 }
446
447 fn register_direct_producer(&self, input: BoxStream<T>) -> StreamResult<()> {
448 let mut inner = self.shared.lock().expect("merge hub poisoned");
449 prune_finished_producer_completions(&mut inner.producer_completions);
450 let flags = self.state.lock();
451 if flags.draining || inner.source_closed || inner.completed {
452 return Err(StreamError::Failed(
453 "merge hub is draining or closed to new producers".to_owned(),
454 ));
455 }
456 if let Some(error) = inner.failed.clone() {
457 return Err(error);
458 }
459 let id = inner.next_producer_id;
460 inner.next_producer_id += 1;
461 inner.active_producers += 1;
462 inner
463 .direct_producers
464 .push_back(MergeHubDirectProducer { id, input });
465 drop(flags);
466 drop(inner);
467 self.condvar.notify_all();
468 Ok(())
469 }
470
471 fn register_producer(&self) -> StreamResult<u64> {
472 let mut inner = self.shared.lock().expect("merge hub poisoned");
473 prune_finished_producer_completions(&mut inner.producer_completions);
474 let flags = self.state.lock();
475 if flags.draining || inner.source_closed || inner.completed {
476 return Err(StreamError::Failed(
477 "merge hub is draining or closed to new producers".to_owned(),
478 ));
479 }
480 if let Some(error) = inner.failed.clone() {
481 return Err(error);
482 }
483 let id = inner.next_producer_id;
484 inner.next_producer_id += 1;
485 inner.active_producers += 1;
486 inner.queued_per_producer.insert(id, 0);
487 Ok(id)
488 }
489
490 fn store_producer_completion(&self, completion: StreamCompletion<NotUsed>) {
491 let mut inner = self.shared.lock().expect("merge hub poisoned");
492 prune_finished_producer_completions(&mut inner.producer_completions);
493 inner.producer_completions.push(completion);
494 }
495
496 fn push_item(&self, producer_id: u64, item: T) -> StreamResult<()> {
497 let mut inner = self.shared.lock().expect("merge hub poisoned");
498 loop {
499 if let Some(error) = inner.failed.clone() {
500 inner.queued_per_producer.remove(&producer_id);
501 return Err(error);
502 }
503 if inner.source_closed {
504 inner.queued_per_producer.remove(&producer_id);
505 return Err(StreamError::Cancelled);
506 }
507 let queued = inner
508 .queued_per_producer
509 .get(&producer_id)
510 .copied()
511 .unwrap_or(0);
512 if queued < inner.per_producer_buffer_size {
513 inner.queue.push_back((producer_id, item));
514 inner.queued_per_producer.insert(producer_id, queued + 1);
515 self.condvar.notify_all();
516 return Ok(());
517 }
518 inner = self
519 .condvar
520 .wait(inner)
521 .expect("merge hub poisoned while waiting");
522 }
523 }
524
525 fn deregister_producer(&self, producer_id: u64) {
526 let mut inner = self.shared.lock().expect("merge hub poisoned");
527 prune_finished_producer_completions(&mut inner.producer_completions);
528 inner.queued_per_producer.remove(&producer_id);
529 inner
530 .direct_producers
531 .retain(|producer| producer.id != producer_id);
532 inner.active_producers = inner.active_producers.saturating_sub(1);
533 if inner.active_producers == 0 {
534 let flags = self.state.lock();
535 if flags.draining {
536 inner.completed = true;
537 }
538 }
539 self.condvar.notify_all();
540 }
541
542 fn pop_direct_producer(
543 &self,
544 inner: &mut MergeHubInner<T>,
545 ) -> Option<MergeHubDirectProducer<T>> {
546 inner.direct_producers.pop_front()
547 }
548
549 fn restore_direct_producer(
550 &self,
551 producer: MergeHubDirectProducer<T>,
552 terminal: MergeHubProducerTerminal,
553 ) -> StreamResult<()> {
554 match terminal {
555 MergeHubProducerTerminal::Active => {
556 let mut inner = self.shared.lock().expect("merge hub poisoned");
557 if let Some(error) = inner.failed.clone() {
558 inner.active_producers = inner.active_producers.saturating_sub(1);
559 return Err(error);
560 }
561 if inner.source_closed {
562 inner.active_producers = inner.active_producers.saturating_sub(1);
563 return Err(StreamError::Cancelled);
564 }
565 inner.direct_producers.push_back(producer);
566 Ok(())
567 }
568 MergeHubProducerTerminal::Completed => {
569 self.deregister_producer(producer.id);
570 Ok(())
571 }
572 MergeHubProducerTerminal::Failed(error) => {
573 self.fail(error.clone());
574 self.deregister_producer(producer.id);
575 Err(error)
576 }
577 }
578 }
579
580 fn fail(&self, error: StreamError) {
581 let mut inner = self.shared.lock().expect("merge hub poisoned");
582 if inner.failed.is_none() {
583 inner.failed = Some(error);
584 self.failed.store(true, Ordering::SeqCst);
585 }
586 self.condvar.notify_all();
587 }
588
589 fn failed_error(&self) -> Option<StreamError> {
590 if !self.failed.load(Ordering::SeqCst) {
591 return None;
592 }
593 self.shared
594 .lock()
595 .expect("merge hub poisoned")
596 .failed
597 .clone()
598 }
599
600 fn finish_if_draining(&self) {
601 let flags = self.state.lock();
602 if !flags.draining {
603 return;
604 }
605 drop(flags);
606
607 let mut inner = self.shared.lock().expect("merge hub poisoned");
608 prune_finished_producer_completions(&mut inner.producer_completions);
609 if inner.active_producers == 0 {
610 inner.completed = true;
611 self.condvar.notify_all();
612 }
613 }
614
615 fn close_source(&self) {
616 let mut inner = self.shared.lock().expect("merge hub poisoned");
617 inner.source_closed = true;
618 self.condvar.notify_all();
619 }
620
621 fn terminal_status(
622 &self,
623 materializer: &Materializer,
624 cancelled: &Arc<AtomicBool>,
625 ) -> StreamResult<()> {
626 if materializer.is_shutdown() {
627 self.close_source();
628 Err(StreamError::AbruptTermination)
629 } else if cancelled.load(Ordering::SeqCst) {
630 self.close_source();
631 Err(StreamError::Cancelled)
632 } else {
633 Ok(())
634 }
635 }
636
637 fn drain_terminal_batch(
638 &self,
639 materializer: &Materializer,
640 cancelled: &Arc<AtomicBool>,
641 prefer_direct: &AtomicBool,
642 batch: &mut Vec<T>,
643 ) -> StreamResult<TerminalSourceStatus> {
644 batch.clear();
645 loop {
646 self.terminal_status(materializer, cancelled)?;
647 let mut inner = self.shared.lock().expect("merge hub poisoned");
648 loop {
649 if let Some(error) = inner.failed.clone() {
650 inner.source_closed = true;
651 return Err(error);
652 }
653
654 let should_drain_direct = !inner.direct_producers.is_empty()
655 && (prefer_direct.load(Ordering::Relaxed) || inner.queue.is_empty());
656 if should_drain_direct {
657 let Some(mut producer) = self.pop_direct_producer(&mut inner) else {
658 continue;
659 };
660 let drain_limit = inner
661 .per_producer_buffer_size
662 .clamp(1, MERGE_HUB_BATCH_LIMIT);
663 drop(inner);
664
665 batch.reserve(drain_limit.saturating_sub(batch.capacity()));
666 let mut terminal = MergeHubProducerTerminal::Active;
667 for _ in 0..drain_limit {
668 match producer.input.next() {
669 Some(Ok(item)) => batch.push(item),
670 Some(Err(error)) => {
671 terminal = MergeHubProducerTerminal::Failed(error);
672 break;
673 }
674 None => {
675 terminal = MergeHubProducerTerminal::Completed;
676 break;
677 }
678 }
679 }
680
681 let restore_result = self.restore_direct_producer(producer, terminal);
682 match restore_result {
683 Ok(()) => {
684 prefer_direct.store(false, Ordering::Relaxed);
685 }
686 Err(error) => {
687 self.close_source();
688 return Err(error);
689 }
690 }
691
692 let completed = self.terminal_completed_after_batch();
693 if !batch.is_empty() {
694 return Ok(if completed {
695 TerminalSourceStatus::Completed
696 } else {
697 TerminalSourceStatus::Active
698 });
699 }
700 if completed {
701 return Ok(TerminalSourceStatus::Completed);
702 }
703 break;
704 }
705
706 if !inner.queue.is_empty() {
707 let drain_n = inner.queue.len().min(MERGE_HUB_BATCH_LIMIT);
708 batch.reserve(drain_n.saturating_sub(batch.capacity()));
709 for _ in 0..drain_n {
710 if let Some((producer_id, item)) = inner.queue.pop_front() {
711 if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
712 *queued = queued.saturating_sub(1);
713 }
714 batch.push(item);
715 }
716 }
717 let completed = inner.completed
718 && inner.queue.is_empty()
719 && inner.direct_producers.is_empty();
720 self.condvar.notify_all();
721 drop(inner);
722
723 prefer_direct.store(true, Ordering::Relaxed);
724 return Ok(if completed {
725 TerminalSourceStatus::Completed
726 } else {
727 TerminalSourceStatus::Active
728 });
729 }
730
731 if inner.completed {
732 inner.source_closed = true;
733 return Ok(TerminalSourceStatus::Completed);
734 }
735
736 let (next_inner, _) = self
737 .condvar
738 .wait_timeout(inner, fan_out_wait_timeout())
739 .expect("merge hub poisoned while waiting");
740 inner = next_inner;
741 if materializer.is_shutdown() {
742 inner.source_closed = true;
743 self.condvar.notify_all();
744 return Err(StreamError::AbruptTermination);
745 }
746 if cancelled.load(Ordering::SeqCst) {
747 inner.source_closed = true;
748 self.condvar.notify_all();
749 return Err(StreamError::Cancelled);
750 }
751 }
752 }
753 }
754
755 fn terminal_completed_after_batch(&self) -> bool {
756 let inner = self.shared.lock().expect("merge hub poisoned");
757 inner.completed && inner.queue.is_empty() && inner.direct_producers.is_empty()
758 }
759}
760
761fn prune_finished_producer_completions(completions: &mut Vec<StreamCompletion<NotUsed>>) {
762 let mut index = 0;
763 while index < completions.len() {
764 if completions[index].try_wait().is_some() {
765 drop(completions.swap_remove(index));
766 } else {
767 index += 1;
768 }
769 }
770}
771
/// Iterator-style consuming side of a merge hub.
struct MergeHubSourceStream<T> {
    /// The hub being consumed.
    state: Arc<MergeHubShared<T>>,
    /// Elements already taken out of the hub but not yet yielded.
    local: VecDeque<T>,
    /// Fairness toggle: set after a queue drain and cleared after a
    /// direct-producer drain, so the two kinds of producers alternate.
    prefer_direct: bool,
}
777
778impl<T> Iterator for MergeHubSourceStream<T> {
779 type Item = StreamResult<T>;
780
781 fn next(&mut self) -> Option<Self::Item> {
782 if let Some(error) = self.state.failed_error() {
783 self.local.clear();
784 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
785 inner.source_closed = true;
786 return Some(Err(error));
787 }
788 if let Some(item) = self.local.pop_front() {
789 return Some(Ok(item));
790 }
791
792 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
793 loop {
794 if let Some(error) = inner.failed.clone() {
795 inner.source_closed = true;
796 return Some(Err(error));
797 }
798 let should_drain_direct = !inner.direct_producers.is_empty()
799 && (self.prefer_direct || inner.queue.is_empty());
800 if should_drain_direct {
801 let Some(mut producer) = self.state.pop_direct_producer(&mut inner) else {
802 continue;
803 };
804 let drain_limit = inner
805 .per_producer_buffer_size
806 .clamp(1, MERGE_HUB_BATCH_LIMIT);
807 drop(inner);
808
809 let mut batch = std::mem::take(&mut self.local);
810 batch.clear();
811 batch.reserve(drain_limit.saturating_sub(batch.capacity()));
812 let mut terminal = MergeHubProducerTerminal::Active;
813 for _ in 0..drain_limit {
814 match producer.input.next() {
815 Some(Ok(item)) => batch.push_back(item),
816 Some(Err(error)) => {
817 terminal = MergeHubProducerTerminal::Failed(error);
818 break;
819 }
820 None => {
821 terminal = MergeHubProducerTerminal::Completed;
822 break;
823 }
824 }
825 }
826
827 let restore_result = self.state.restore_direct_producer(producer, terminal);
828 match restore_result {
829 Ok(()) => {
830 self.prefer_direct = false;
831 if let Some(first) = batch.pop_front() {
832 self.local = batch;
833 return Some(Ok(first));
834 }
835 inner = self.state.shared.lock().expect("merge hub poisoned");
836 continue;
837 }
838 Err(error) => {
839 self.local.clear();
840 return Some(Err(error));
841 }
842 }
843 }
844 if !inner.queue.is_empty() {
845 let drain_n = inner.queue.len().min(MERGE_HUB_BATCH_LIMIT);
846 let mut batch = std::mem::take(&mut self.local);
847 batch.clear();
848 batch.reserve(drain_n.saturating_sub(batch.capacity()));
849 for _ in 0..drain_n {
850 if let Some((producer_id, item)) = inner.queue.pop_front() {
851 if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
852 *queued = queued.saturating_sub(1);
853 }
854 batch.push_back(item);
855 }
856 }
857 self.state.condvar.notify_all();
858 drop(inner);
859 let first = batch
860 .pop_front()
861 .expect("merge hub drained non-empty batch");
862 self.local = batch;
863 self.prefer_direct = true;
864 return Some(Ok(first));
865 }
866 if inner.completed {
867 inner.source_closed = true;
868 return None;
869 }
870 inner = self
871 .state
872 .condvar
873 .wait(inner)
874 .expect("merge hub poisoned while waiting");
875 }
876 }
877}
878
879impl<T> Drop for MergeHubSourceStream<T> {
880 fn drop(&mut self) {
881 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
882 inner.source_closed = true;
883 self.state.condvar.notify_all();
884 }
885}
886
/// Which kind of hub a `FanOutHubShared` backs.
/// NOTE(review): the code that branches on this value is not visible in this
/// part of the file.
#[derive(Clone, Copy)]
enum FanOutMode {
    /// Selected by `BroadcastHub`.
    Broadcast,
    /// Selected by `PartitionHub`.
    Partition,
}
892
/// State shared between the producer loop of a broadcast/partition hub and all
/// of its consumers.
struct FanOutHubShared<T> {
    /// Authoritative consumer registry, guarded by a mutex.
    registry: Mutex<FanOutRegistry<T>>,
    /// Snapshot read by the producer paths via `load()` without taking the
    /// registry lock; republished by `publish_snapshot_locked` after registry
    /// changes.
    snapshot: ArcSwap<FanOutSnapshot<T>>,
    /// NOTE(review): presumably the mutex paired with `producer_condvar` for
    /// producer waits — the waiting code is not visible in this part of the file.
    producer_wait: Mutex<()>,
    /// NOTE(review): presumably signalled when producer-relevant state
    /// changes — confirm in `wait_for_producer_transition`.
    producer_condvar: Condvar,
    /// Counter sampled by the producer before it inspects state and passed to
    /// `wait_for_producer_transition`.
    producer_epoch: AtomicU64,
    /// Counter read by `select_partition`.
    /// NOTE(review): presumably bumped on consumer registration/removal — confirm.
    topology_epoch: AtomicU64,
    /// Kind of hub this state backs.
    mode: FanOutMode,
    /// Minimum number of active consumer lanes before the producer pushes.
    start_after_nr_of_consumers: usize,
    /// Maximum number of queued elements per consumer lane.
    buffer_size: usize,
    /// Routing callback; `None` for broadcast hubs.
    partitioner: Option<Partitioner<T>>,
}
905
/// Consumer registry guarded by `FanOutHubShared::registry`.
struct FanOutRegistry<T> {
    /// Registered consumer lanes keyed by consumer id.
    consumers: BTreeMap<u64, Arc<FanOutConsumerLane<T>>>,
    /// Id handed to the next consumer that registers.
    next_consumer_id: u64,
    /// Terminal state of the hub, if any; consumers registering afterwards get
    /// a lane that is already terminal.
    terminal: Option<FanOutTerminal>,
}
911
/// Copy of the registry's contents that the producer paths read through
/// `FanOutHubShared::snapshot`.
struct FanOutSnapshot<T> {
    /// Consumer lanes at the time the snapshot was published.
    consumers: Vec<Arc<FanOutConsumerLane<T>>>,
    /// Terminal state at the time the snapshot was published, if any.
    terminal: Option<FanOutTerminal>,
}
916
/// Terminal state of a fan-out hub or of a single consumer lane.
#[derive(Clone, Debug)]
enum FanOutTerminal {
    /// Finished without an error.
    Completed,
    /// Finished with the given error.
    Failed(StreamError),
}
922
923impl FanOutTerminal {
924 fn producer_error(&self) -> StreamError {
925 match self {
926 Self::Completed => StreamError::Cancelled,
927 Self::Failed(error) => error.clone(),
928 }
929 }
930}
931
/// Per-consumer delivery lane of a fan-out hub.
struct FanOutConsumerLane<T> {
    /// Consumer id assigned at registration.
    id: u64,
    /// Queued chunks and terminal state, guarded by a mutex.
    state: Mutex<FanOutLaneState<T>>,
    /// Notified when chunks are pushed, a terminal state is set, or the lane
    /// is deactivated.
    condvar: Condvar,
    /// Atomic mirror of `FanOutLaneState::queued`, readable without the lock.
    queued: AtomicUsize,
    /// `false` once the lane has been deactivated or was created terminal.
    active: AtomicBool,
    /// `true` once a `Failed` terminal has been set on the lane.
    failed: AtomicBool,
}
940
/// Mutex-guarded part of a consumer lane.
struct FanOutLaneState<T> {
    /// Pending chunks of elements, oldest first.
    chunks: VecDeque<Vec<T>>,
    /// Total number of elements across `chunks`, as maintained by the producer.
    queued: usize,
    /// Terminal state delivered to this lane, if any.
    terminal: Option<FanOutTerminal>,
}
946
947impl<T> FanOutConsumerLane<T> {
948 fn new(id: u64, buffer_size: usize) -> Self {
949 Self {
950 id,
951 state: Mutex::new(FanOutLaneState {
952 chunks: VecDeque::with_capacity((buffer_size / FAN_OUT_CONSUMER_BATCH_LIMIT) + 1),
953 queued: 0,
954 terminal: None,
955 }),
956 condvar: Condvar::new(),
957 queued: AtomicUsize::new(0),
958 active: AtomicBool::new(true),
959 failed: AtomicBool::new(false),
960 }
961 }
962
963 fn terminal(id: u64, terminal: FanOutTerminal) -> Self {
964 let failed = matches!(terminal, FanOutTerminal::Failed(_));
965 Self {
966 id,
967 state: Mutex::new(FanOutLaneState {
968 chunks: VecDeque::new(),
969 queued: 0,
970 terminal: Some(terminal),
971 }),
972 condvar: Condvar::new(),
973 queued: AtomicUsize::new(0),
974 active: AtomicBool::new(false),
975 failed: AtomicBool::new(failed),
976 }
977 }
978
979 fn id(&self) -> u64 {
980 self.id
981 }
982
983 fn queued_len(&self) -> usize {
984 self.queued.load(Ordering::Acquire)
985 }
986
987 fn is_active(&self) -> bool {
988 self.active.load(Ordering::Acquire)
989 }
990
991 fn deactivate(&self) {
992 self.active.store(false, Ordering::Release);
993 self.condvar.notify_all();
994 }
995
996 fn set_terminal(&self, terminal: FanOutTerminal) {
997 let mut state = self.state.lock().expect("fan-out lane poisoned");
998 if matches!(terminal, FanOutTerminal::Failed(_)) {
999 state.chunks.clear();
1000 state.queued = 0;
1001 self.queued.store(0, Ordering::Release);
1002 self.failed.store(true, Ordering::Release);
1003 }
1004 state.terminal = Some(terminal);
1005 drop(state);
1006 self.condvar.notify_all();
1007 }
1008}
1009
1010impl PartitionConsumerInfo {
1011 fn from_cached_lanes<T>(consumer_ids: &[u64], lanes: &[Arc<FanOutConsumerLane<T>>]) -> Self {
1012 Self {
1013 consumer_ids: consumer_ids.iter().copied().collect(),
1014 queue_sizes: lanes
1015 .iter()
1016 .map(|lane| (lane.id(), lane.queued_len()))
1017 .collect(),
1018 }
1019 }
1020}
1021
/// Producer-side cache of the consumer topology used while routing elements.
struct PartitionTopologyCache<T> {
    /// Topology epoch the cached data belongs to; `u64::MAX` when the cache is
    /// new or has been cleared.
    epoch: u64,
    /// Cached consumer ids.
    consumer_ids: SmallVec<[u64; 16]>,
    /// Cached consumer lanes.
    /// NOTE(review): presumably index-aligned with `consumer_ids` — confirm
    /// where the cache is filled.
    lanes: SmallVec<[Arc<FanOutConsumerLane<T>>; 16]>,
}
1027
/// Pairs of a consumer lane and a queue of elements, with inline storage for
/// up to 16 pairs.
/// NOTE(review): presumably the elements routed to each lane by the partition
/// producer — no use is visible in this part of the file.
type PartitionRoutedBatches<T> = SmallVec<[(Arc<FanOutConsumerLane<T>>, VecDeque<T>); 16]>;
1029
1030impl<T> PartitionTopologyCache<T> {
1031 fn new() -> Self {
1032 Self {
1033 epoch: u64::MAX,
1034 consumer_ids: SmallVec::new(),
1035 lanes: SmallVec::new(),
1036 }
1037 }
1038
1039 fn clear(&mut self) {
1040 self.epoch = u64::MAX;
1041 self.consumer_ids.clear();
1042 self.lanes.clear();
1043 }
1044}
1045
1046impl<T> FanOutHubShared<T> {
1047 fn new(
1048 mode: FanOutMode,
1049 start_after_nr_of_consumers: usize,
1050 buffer_size: usize,
1051 partitioner: Option<Partitioner<T>>,
1052 ) -> Self {
1053 Self {
1054 registry: Mutex::new(FanOutRegistry {
1055 consumers: BTreeMap::new(),
1056 next_consumer_id: 0,
1057 terminal: None,
1058 }),
1059 snapshot: ArcSwap::from_pointee(FanOutSnapshot {
1060 consumers: Vec::new(),
1061 terminal: None,
1062 }),
1063 producer_wait: Mutex::new(()),
1064 producer_condvar: Condvar::new(),
1065 producer_epoch: AtomicU64::new(0),
1066 topology_epoch: AtomicU64::new(0),
1067 mode,
1068 start_after_nr_of_consumers,
1069 buffer_size,
1070 partitioner,
1071 }
1072 }
1073
1074 fn register_consumer(&self) -> Arc<FanOutConsumerLane<T>> {
1075 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1076 let id = registry.next_consumer_id;
1077 registry.next_consumer_id += 1;
1078
1079 if let Some(terminal) = registry.terminal.clone() {
1080 return Arc::new(FanOutConsumerLane::terminal(id, terminal));
1081 }
1082
1083 let lane = Arc::new(FanOutConsumerLane::new(id, self.buffer_size));
1084 registry.consumers.insert(id, Arc::clone(&lane));
1085 self.publish_snapshot_locked(®istry);
1086 drop(registry);
1087 self.notify_topology_transition();
1088 lane
1089 }
1090
1091 fn remove_consumer(&self, consumer_id: u64) {
1092 let lane = {
1093 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1094 let lane = registry.consumers.remove(&consumer_id);
1095 if let Some(lane) = &lane {
1096 lane.deactivate();
1097 self.publish_snapshot_locked(®istry);
1098 }
1099 lane
1100 };
1101
1102 if let Some(lane) = lane {
1103 lane.condvar.notify_all();
1104 self.notify_topology_transition();
1105 }
1106 }
1107
1108 fn wait_for_broadcast_capacity(&self, max_items: usize) -> StreamResult<usize> {
1109 loop {
1110 let observed = self.producer_epoch.load(Ordering::Acquire);
1111 let snapshot = self.snapshot.load();
1112 if let Some(terminal) = &snapshot.terminal {
1113 return Err(terminal.producer_error());
1114 }
1115
1116 let lanes = self.active_lanes(&snapshot);
1117 if lanes.len() < self.start_after_nr_of_consumers || lanes.is_empty() {
1118 self.wait_for_producer_transition(observed);
1119 continue;
1120 }
1121
1122 let free = lanes
1123 .iter()
1124 .map(|lane| self.buffer_size.saturating_sub(lane.queued_len()))
1125 .min()
1126 .unwrap_or(0);
1127 if free > 0 {
1128 let max_items = if lanes.len() == 1 {
1129 max_items.min(BROADCAST_HUB_SINGLE_CONSUMER_BATCH_LIMIT)
1130 } else {
1131 max_items
1132 };
1133 return Ok(free.min(max_items).max(1));
1134 }
1135 self.wait_for_producer_transition(observed);
1136 }
1137 }
1138
1139 fn push_broadcast_item(&self, item: T) -> StreamResult<()>
1140 where
1141 T: Clone,
1142 {
1143 let mut batch = VecDeque::new();
1144 batch.push_back(item);
1145 self.push_broadcast_batch(&mut batch)
1146 }
1147
1148 fn push_broadcast_batch(&self, batch: &mut VecDeque<T>) -> StreamResult<()>
1149 where
1150 T: Clone,
1151 {
1152 if batch.is_empty() {
1153 return Ok(());
1154 }
1155
1156 while !batch.is_empty() {
1157 let observed = self.producer_epoch.load(Ordering::Acquire);
1158 let snapshot = self.snapshot.load();
1159 if let Some(terminal) = &snapshot.terminal {
1160 batch.clear();
1161 return Err(terminal.producer_error());
1162 }
1163
1164 let lanes = self.active_lanes(&snapshot);
1165 if lanes.len() < self.start_after_nr_of_consumers || lanes.is_empty() {
1166 self.wait_for_producer_transition(observed);
1167 continue;
1168 }
1169
1170 let free = lanes
1171 .iter()
1172 .map(|lane| self.buffer_size.saturating_sub(lane.queued_len()))
1173 .min()
1174 .unwrap_or(0);
1175 if free == 0 {
1176 self.wait_for_producer_transition(observed);
1177 continue;
1178 }
1179
1180 let take_n = free.min(batch.len());
1181 if !self.try_push_broadcast_batch(&lanes, batch, take_n)? {
1182 self.wait_for_producer_transition(observed);
1183 }
1184 }
1185 Ok(())
1186 }
1187
1188 fn try_push_broadcast_batch(
1189 &self,
1190 lanes: &[Arc<FanOutConsumerLane<T>>],
1191 batch: &mut VecDeque<T>,
1192 take_n: usize,
1193 ) -> StreamResult<bool>
1194 where
1195 T: Clone,
1196 {
1197 if take_n == 0 || lanes.is_empty() {
1198 return Ok(false);
1199 }
1200
1201 let mut guards = SmallVec::<
1202 [(
1203 Arc<FanOutConsumerLane<T>>,
1204 MutexGuard<'_, FanOutLaneState<T>>,
1205 ); 16],
1206 >::new();
1207 for lane in lanes {
1208 if !lane.is_active() {
1209 return Ok(false);
1210 }
1211 let guard = lane.state.lock().expect("fan-out lane poisoned");
1212 if !lane.is_active() {
1213 return Ok(false);
1214 }
1215 if let Some(terminal) = &guard.terminal {
1216 return Err(terminal.producer_error());
1217 }
1218 if self.buffer_size.saturating_sub(guard.queued) < take_n {
1219 return Ok(false);
1220 }
1221 guards.push((Arc::clone(lane), guard));
1222 }
1223
1224 let lane_count = guards.len();
1225 if lane_count == 0 {
1226 return Ok(false);
1227 }
1228
1229 let mut notify_lanes = SmallVec::<[Arc<FanOutConsumerLane<T>>; 16]>::new();
1230 for (index, (lane, guard)) in guards.iter_mut().enumerate() {
1231 let chunk = if index + 1 == lane_count {
1232 batch.drain(..take_n).collect()
1233 } else {
1234 batch.iter().take(take_n).cloned().collect()
1235 };
1236 guard.chunks.push_back(chunk);
1237 guard.queued += take_n;
1238 lane.queued.store(guard.queued, Ordering::Release);
1239 notify_lanes.push(Arc::clone(lane));
1240 }
1241 drop(guards);
1242
1243 for lane in notify_lanes {
1244 lane.condvar.notify_one();
1245 }
1246 Ok(true)
1247 }
1248
1249 fn select_partition(
1250 &self,
1251 item: &T,
1252 cache: &mut PartitionTopologyCache<T>,
1253 ) -> StreamResult<Option<Arc<FanOutConsumerLane<T>>>> {
1254 let Some(partitioner) = &self.partitioner else {
1255 return Err(StreamError::Failed(
1256 "partition hub partitioner missing".to_owned(),
1257 ));
1258 };
1259 let topology_epoch = self.topology_epoch.load(Ordering::Acquire);
1260 if cache.epoch != topology_epoch || cache.lanes.is_empty() {
1261 self.refresh_partition_topology(cache)?;
1262 }
1263
1264 let info = PartitionConsumerInfo::from_cached_lanes(&cache.consumer_ids, &cache.lanes);
1265 let selected = partitioner(&info, item);
1266 if selected < 0 {
1267 return Ok(None);
1268 }
1269 let selected = selected as u64;
1270 let selected_idx = selected as usize;
1271 if cache.consumer_ids.get(selected_idx).copied() == Some(selected) {
1272 return Ok(Some(Arc::clone(&cache.lanes[selected_idx])));
1273 }
1274 cache
1275 .consumer_ids
1276 .iter()
1277 .position(|id| *id == selected)
1278 .map(|idx| Some(Arc::clone(&cache.lanes[idx])))
1279 .ok_or_else(|| {
1280 StreamError::Failed("partition hub selected unknown consumer".to_owned())
1281 })
1282 }
1283
1284 fn refresh_partition_topology(
1285 &self,
1286 cache: &mut PartitionTopologyCache<T>,
1287 ) -> StreamResult<()> {
1288 loop {
1289 let observed = self.producer_epoch.load(Ordering::Acquire);
1290 let topology_epoch = self.topology_epoch.load(Ordering::Acquire);
1291 let snapshot = self.snapshot.load();
1292 if let Some(terminal) = &snapshot.terminal {
1293 cache.clear();
1294 return Err(terminal.producer_error());
1295 }
1296
1297 let lanes = self.active_lanes(&snapshot);
1298 if lanes.len() >= self.start_after_nr_of_consumers && !lanes.is_empty() {
1299 cache.epoch = topology_epoch;
1300 cache.consumer_ids.clear();
1301 cache
1302 .consumer_ids
1303 .extend(lanes.iter().map(|lane| lane.id()));
1304 cache.lanes = lanes;
1305 return Ok(());
1306 }
1307 self.wait_for_producer_transition(observed);
1308 }
1309 }
1310
1311 fn enqueue_partition(&self, selected: Arc<FanOutConsumerLane<T>>, item: T) -> StreamResult<()> {
1312 let mut batch = VecDeque::new();
1313 batch.push_back(item);
1314 self.enqueue_partition_batch(selected, &mut batch)
1315 }
1316
1317 fn enqueue_partition_batch(
1318 &self,
1319 selected: Arc<FanOutConsumerLane<T>>,
1320 batch: &mut VecDeque<T>,
1321 ) -> StreamResult<()> {
1322 if batch.is_empty() {
1323 return Ok(());
1324 }
1325
1326 let mut state = selected.state.lock().expect("fan-out lane poisoned");
1327 loop {
1328 if !selected.is_active() {
1329 batch.clear();
1330 return Err(StreamError::Failed(
1331 "partition hub selected unknown consumer".to_owned(),
1332 ));
1333 }
1334 if let Some(terminal) = &state.terminal {
1335 batch.clear();
1336 return Err(terminal.producer_error());
1337 }
1338 let free = self.buffer_size.saturating_sub(state.queued);
1339 if free > 0 {
1340 let take_n = free.min(batch.len());
1341 let chunk = batch.drain(..take_n).collect();
1342 state.chunks.push_back(chunk);
1343 state.queued += take_n;
1344 selected.queued.store(state.queued, Ordering::Release);
1345 selected.condvar.notify_one();
1346 if batch.is_empty() {
1347 return Ok(());
1348 }
1349 }
1350 let (guard, _) = selected
1351 .condvar
1352 .wait_timeout(state, fan_out_wait_timeout())
1353 .expect("fan-out lane poisoned while waiting");
1354 state = guard;
1355 }
1356 }
1357
1358 fn complete(&self) {
1359 let lanes = {
1360 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1361 if registry.terminal.is_none() {
1362 registry.terminal = Some(FanOutTerminal::Completed);
1363 self.publish_snapshot_locked(®istry);
1364 }
1365 registry.consumers.values().cloned().collect::<Vec<_>>()
1366 };
1367
1368 for lane in lanes {
1369 lane.set_terminal(FanOutTerminal::Completed);
1370 }
1371 self.notify_topology_transition();
1372 }
1373
1374 fn fail(&self, error: StreamError) {
1375 let terminal = FanOutTerminal::Failed(error);
1376 let lanes = {
1377 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1378 if registry.terminal.is_none() {
1379 registry.terminal = Some(terminal.clone());
1380 self.publish_snapshot_locked(®istry);
1381 }
1382 registry.consumers.values().cloned().collect::<Vec<_>>()
1383 };
1384
1385 for lane in lanes {
1386 lane.set_terminal(terminal.clone());
1387 }
1388 self.notify_topology_transition();
1389 }
1390
1391 fn publish_snapshot_locked(&self, registry: &FanOutRegistry<T>) {
1392 self.snapshot.store(Arc::new(FanOutSnapshot {
1393 consumers: registry.consumers.values().cloned().collect(),
1394 terminal: registry.terminal.clone(),
1395 }));
1396 }
1397
1398 fn active_lanes(
1399 &self,
1400 snapshot: &FanOutSnapshot<T>,
1401 ) -> SmallVec<[Arc<FanOutConsumerLane<T>>; 16]> {
1402 snapshot
1403 .consumers
1404 .iter()
1405 .filter(|lane| lane.is_active())
1406 .cloned()
1407 .collect()
1408 }
1409
1410 fn partition_batch_limit(&self) -> usize {
1411 let snapshot = self.snapshot.load();
1412 let active = snapshot
1413 .consumers
1414 .iter()
1415 .filter(|lane| lane.is_active())
1416 .count();
1417 match active {
1418 0 | 1 => PARTITION_HUB_SINGLE_CONSUMER_BATCH_LIMIT,
1419 2..=7 => PARTITION_HUB_BATCH_LIMIT,
1420 _ => PARTITION_HUB_WIDE_BATCH_LIMIT,
1421 }
1422 }
1423
1424 fn notify_producer_transition(&self) {
1425 self.producer_epoch.fetch_add(1, Ordering::Release);
1426 self.producer_condvar.notify_one();
1427 }
1428
1429 fn notify_topology_transition(&self) {
1430 self.topology_epoch.fetch_add(1, Ordering::Release);
1431 self.notify_producer_transition();
1432 }
1433
1434 fn wait_for_producer_transition(&self, observed_epoch: u64) {
1435 let guard = self
1436 .producer_wait
1437 .lock()
1438 .expect("fan-out producer wait poisoned");
1439 if self.producer_epoch.load(Ordering::Acquire) == observed_epoch {
1440 let (_guard, _) = self
1441 .producer_condvar
1442 .wait_timeout(guard, fan_out_wait_timeout())
1443 .expect("fan-out producer wait poisoned while waiting");
1444 }
1445 }
1446}
1447
1448fn push_partition_routed<T>(
1449 routed: &mut PartitionRoutedBatches<T>,
1450 lane: Arc<FanOutConsumerLane<T>>,
1451 item: T,
1452) {
1453 for (existing, batch) in routed.iter_mut() {
1454 if existing.id() == lane.id() {
1455 batch.push_back(item);
1456 return;
1457 }
1458 }
1459
1460 let mut batch = VecDeque::new();
1461 batch.push_back(item);
1462 routed.push((lane, batch));
1463}
1464
/// Producer side of a fan-out hub: pulls elements from an upstream stream and
/// publishes them through the shared hub state.
struct FanOutProducer<T> {
    /// Upstream element stream that `run` pulls from with `next()`.
    input: BoxStream<T>,
    /// Hub state shared with the consumer streams.
    state: Arc<FanOutHubShared<T>>,
}
1469
1470impl<T> FanOutProducer<T> {
1471 fn new(input: BoxStream<T>, state: Arc<FanOutHubShared<T>>) -> Self {
1472 Self { input, state }
1473 }
1474}
1475
1476impl<T: Send + 'static + Clone> FanOutProducer<T> {
1477 fn run(
1478 mut self,
1479 cancelled: Arc<std::sync::atomic::AtomicBool>,
1480 hints: SourceRuntimeHints,
1481 ) -> StreamResult<NotUsed> {
1482 struct ProducerDropGuard<T> {
1483 state: Arc<FanOutHubShared<T>>,
1484 disarmed: bool,
1485 }
1486
1487 impl<T> ProducerDropGuard<T> {
1488 fn new(state: Arc<FanOutHubShared<T>>) -> Self {
1489 Self {
1490 state,
1491 disarmed: false,
1492 }
1493 }
1494
1495 fn disarm(&mut self) {
1496 self.disarmed = true;
1497 }
1498 }
1499
1500 impl<T> Drop for ProducerDropGuard<T> {
1501 fn drop(&mut self) {
1502 if !self.disarmed && std::thread::panicking() {
1503 self.state.fail(StreamError::Failed(
1504 "fan-out hub producer panicked".to_owned(),
1505 ));
1506 }
1507 }
1508 }
1509
1510 let mut guard = ProducerDropGuard::new(Arc::clone(&self.state));
1511 match self.state.mode {
1512 FanOutMode::Broadcast => {
1513 let batch_producer = hints.inline_micro_max_success_items.is_some();
1514 let mut batch = VecDeque::new();
1515 loop {
1516 if cancelled.load(Ordering::SeqCst) {
1517 self.state.fail(StreamError::Cancelled);
1518 guard.disarm();
1519 return Err(StreamError::Cancelled);
1520 }
1521
1522 if batch_producer {
1523 let capacity = self
1524 .state
1525 .wait_for_broadcast_capacity(BROADCAST_HUB_BATCH_LIMIT)?;
1526 batch.clear();
1527 batch.reserve(capacity.saturating_sub(batch.capacity()));
1528 let mut terminal = None;
1529 for _ in 0..capacity {
1530 if cancelled.load(Ordering::SeqCst) {
1531 terminal = Some(Err(StreamError::Cancelled));
1532 break;
1533 }
1534 match self.input.next() {
1535 Some(Ok(item)) => batch.push_back(item),
1536 Some(Err(error)) => {
1537 terminal = Some(Err(error));
1538 break;
1539 }
1540 None => {
1541 terminal = Some(Ok(()));
1542 break;
1543 }
1544 }
1545 }
1546 if !batch.is_empty() {
1547 self.state.push_broadcast_batch(&mut batch)?;
1548 }
1549 match terminal {
1550 Some(Err(StreamError::Cancelled)) => {
1551 self.state.fail(StreamError::Cancelled);
1552 guard.disarm();
1553 return Err(StreamError::Cancelled);
1554 }
1555 Some(Err(error)) => {
1556 self.state.fail(error.clone());
1557 guard.disarm();
1558 return Err(error);
1559 }
1560 Some(Ok(())) => {
1561 self.state.complete();
1562 guard.disarm();
1563 return Ok(NotUsed);
1564 }
1565 None => continue,
1566 }
1567 }
1568
1569 match self.input.next() {
1570 Some(Ok(item)) => self.state.push_broadcast_item(item)?,
1571 Some(Err(error)) => {
1572 self.state.fail(error.clone());
1573 guard.disarm();
1574 return Err(error);
1575 }
1576 None => {
1577 self.state.complete();
1578 guard.disarm();
1579 return Ok(NotUsed);
1580 }
1581 }
1582 }
1583 }
1584 FanOutMode::Partition => {
1585 let batch_producer = hints.inline_micro_max_success_items.is_some();
1586 let mut topology_cache = PartitionTopologyCache::new();
1587 let mut routed = PartitionRoutedBatches::new();
1588 loop {
1589 if cancelled.load(Ordering::SeqCst) {
1590 self.state.fail(StreamError::Cancelled);
1591 guard.disarm();
1592 return Err(StreamError::Cancelled);
1593 }
1594
1595 if batch_producer {
1596 for (_, batch) in routed.iter_mut() {
1597 batch.clear();
1598 }
1599 let partition_batch_limit = self.state.partition_batch_limit();
1600 let mut terminal = None;
1601 for _ in 0..partition_batch_limit {
1602 if cancelled.load(Ordering::SeqCst) {
1603 terminal = Some(Err(StreamError::Cancelled));
1604 break;
1605 }
1606 match self.input.next() {
1607 Some(Ok(item)) => {
1608 if let Some(selected) =
1609 self.state.select_partition(&item, &mut topology_cache)?
1610 {
1611 push_partition_routed(&mut routed, selected, item);
1612 }
1613 }
1614 Some(Err(error)) => {
1615 terminal = Some(Err(error));
1616 break;
1617 }
1618 None => {
1619 terminal = Some(Ok(()));
1620 break;
1621 }
1622 }
1623 }
1624 for (selected, batch) in routed.iter_mut() {
1625 if !batch.is_empty() {
1626 self.state
1627 .enqueue_partition_batch(Arc::clone(selected), batch)?;
1628 }
1629 }
1630 match terminal {
1631 Some(Err(StreamError::Cancelled)) => {
1632 self.state.fail(StreamError::Cancelled);
1633 guard.disarm();
1634 return Err(StreamError::Cancelled);
1635 }
1636 Some(Err(error)) => {
1637 self.state.fail(error.clone());
1638 guard.disarm();
1639 return Err(error);
1640 }
1641 Some(Ok(())) => {
1642 self.state.complete();
1643 guard.disarm();
1644 return Ok(NotUsed);
1645 }
1646 None => continue,
1647 }
1648 }
1649
1650 match self.input.next() {
1651 Some(Ok(item)) => {
1652 if let Some(selected) =
1653 self.state.select_partition(&item, &mut topology_cache)?
1654 {
1655 self.state.enqueue_partition(selected, item)?;
1656 }
1657 }
1658 Some(Err(error)) => {
1659 self.state.fail(error.clone());
1660 guard.disarm();
1661 return Err(error);
1662 }
1663 None => {
1664 self.state.complete();
1665 guard.disarm();
1666 return Ok(NotUsed);
1667 }
1668 }
1669 }
1670 }
1671 }
1672 }
1673}
1674
/// Consumer side of a fan-out hub: an iterator over the elements queued on
/// one consumer lane.
struct FanOutConsumerStream<T> {
    /// Hub state shared with the producer and the other consumers.
    state: Arc<FanOutHubShared<T>>,
    /// The lane this consumer reads from.
    lane: Arc<FanOutConsumerLane<T>>,
    /// Remainder of the most recently dequeued chunk, served without taking
    /// the lane lock.
    local: Option<std::vec::IntoIter<T>>,
    /// Set once the consumer has been removed from the hub, so that `Drop`
    /// does not remove it a second time.
    detached: bool,
}
1681
1682impl<T: Clone + Send + 'static> Iterator for FanOutConsumerStream<T> {
1683 type Item = StreamResult<T>;
1684
1685 fn next(&mut self) -> Option<Self::Item> {
1686 if self.lane.failed.load(Ordering::Acquire) {
1687 self.local = None;
1688 } else if let Some(local) = &mut self.local {
1689 if let Some(item) = local.next() {
1690 return Some(Ok(item));
1691 }
1692 self.local = None;
1693 }
1694
1695 let mut state = self.lane.state.lock().expect("fan-out lane poisoned");
1696 loop {
1697 if let Some(FanOutTerminal::Failed(error)) = state.terminal.clone() {
1698 return Some(Err(error));
1699 }
1700 if let Some(batch) = state.chunks.pop_front() {
1701 state.queued = state.queued.saturating_sub(batch.len());
1702 self.lane.queued.store(state.queued, Ordering::Release);
1703 if matches!(self.state.mode, FanOutMode::Partition) {
1704 self.lane.condvar.notify_one();
1705 }
1706 self.state.notify_producer_transition();
1707 drop(state);
1708 let mut batch = batch.into_iter();
1709 let first = batch.next().expect("fan-out drained non-empty lane batch");
1710 if batch.len() > 0 {
1711 self.local = Some(batch);
1712 }
1713 return Some(Ok(first));
1714 }
1715 if matches!(state.terminal, Some(FanOutTerminal::Completed)) {
1716 return None;
1717 }
1718 let (guard, _) = self
1719 .lane
1720 .condvar
1721 .wait_timeout(state, fan_out_wait_timeout())
1722 .expect("fan-out lane poisoned while waiting");
1723 state = guard;
1724 }
1725 }
1726}
1727
1728impl<T> Drop for FanOutConsumerStream<T> {
1729 fn drop(&mut self) {
1730 if !self.detached {
1731 self.state.remove_consumer(self.lane.id());
1732 self.detached = true;
1733 }
1734 }
1735}
1736
// Tests for the merge, broadcast and partition hubs, driven through the
// testkit probes and a local `Materializer`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Sink, Source};
    use std::{
        panic::{self, AssertUnwindSafe},
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        thread,
        time::{Duration, Instant},
    };

    // Two producers attached one after the other both reach the collecting
    // sink; draining completes the hub. Order across producers is not
    // asserted, hence the sort.
    #[test]
    fn merge_hub_accepts_dynamic_producers_and_drains() {
        let materializer = Materializer::new();
        let ((hub_sink, control), completion) = MergeHub::source_with_draining::<i32>(4)
            .to_mat(Sink::collect(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3]))
            .expect("first producer attaches");
        hub_sink
            .run_with(Source::from_iter([4, 5]))
            .expect("second producer attaches");
        control.drain_and_complete();
        let mut result = completion.wait().expect("merge hub completes");
        result.sort_unstable();
        assert_eq!(result, vec![1, 2, 3, 4, 5]);
    }

    // A single finite producer of 1024 elements is fully counted by a fold
    // sink once the hub is drained.
    #[test]
    fn merge_hub_direct_terminal_fold_counts_single_finite_producer() {
        let materializer = Materializer::new();
        let ((hub_sink, control), completion) = MergeHub::source_with_draining::<u64>(16)
            .to_mat(Sink::fold(0_u64, |acc, _| acc + 1), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .run_with(Source::from_iter(0_u64..1024))
            .expect("producer attaches");
        control.drain_and_complete();

        assert_eq!(completion.wait().expect("merge hub completes"), 1024);
    }

    // When the downstream fold returns an error, the completion reports that
    // error and attaching a further producer fails.
    #[test]
    fn merge_hub_direct_terminal_fold_result_error_closes_source() {
        let materializer = Materializer::new();
        let ((hub_sink, _control), completion) = MergeHub::source_with_draining::<i32>(8)
            .to_mat(
                Sink::fold_result(0_i32, |acc, item| {
                    if item == 3 {
                        Err(StreamError::Failed("terminal failed".to_owned()))
                    } else {
                        Ok(acc + item)
                    }
                }),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3, 4]))
            .expect("producer attaches");

        assert_eq!(
            completion.wait(),
            Err(StreamError::Failed("terminal failed".to_owned()))
        );
        assert!(hub_sink.run_with(Source::single(5)).is_err());
    }

    // When the downstream fold panics, the completion reports
    // `AbruptTermination` and attaching a further producer fails.
    #[test]
    fn merge_hub_direct_terminal_panic_closes_source() {
        let materializer = Materializer::new();
        let ((hub_sink, _control), completion) = MergeHub::source_with_draining::<i32>(8)
            .to_mat(
                Sink::fold(0_i32, |acc, item| {
                    if item == 3 {
                        panic!("terminal failed");
                    }
                    acc + item
                }),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3, 4]))
            .expect("producer attaches");

        assert_eq!(completion.wait(), Err(StreamError::AbruptTermination));
        assert!(hub_sink.run_with(Source::single(5)).is_err());
    }

    // An error sent by one of two attached producers surfaces at the
    // downstream sink.
    #[test]
    fn merge_hub_producer_error_fails_downstream_consumer() {
        let materializer = Materializer::new();
        let (hub_sink, sink) = MergeHub::source::<i32>(4)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");

        let producer_ok = TestSource::probe::<i32>()
            .to_mat(hub_sink.clone(), Keep::left)
            .run_with_materializer(&materializer)
            .expect("successful producer attaches");
        let producer_fail = TestSource::probe::<i32>()
            .to_mat(hub_sink, Keep::left)
            .run_with_materializer(&materializer)
            .expect("failing producer attaches");

        sink.request(1);
        assert_eq!(producer_ok.expect_request(), 1);
        producer_ok.send_next(1);
        sink.assert_next(1);

        producer_fail.send_error(StreamError::Failed("producer failed".to_owned()));
        sink.request(1);
        assert_eq!(
            sink.expect_error(),
            StreamError::Failed("producer failed".to_owned())
        );
    }

    // With the hub created via `BroadcastHub::sink(1)`, a consumer that has
    // not requested the second element holds back the faster consumer until
    // it catches up.
    #[test]
    fn broadcast_hub_backpressures_slowest_consumer() {
        let materializer = Materializer::new();
        let (publisher, hub_source) = TestSource::probe::<i32>()
            .to_mat(BroadcastHub::sink(1), Keep::both)
            .run_with_materializer(&materializer)
            .expect("broadcast hub materializes");

        let sink_a = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        let sink_b = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("second consumer materializes");

        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);
        sink_b.assert_next(1);

        // Only consumer A asks for more; B has no demand and sees nothing.
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_a.assert_next(2);
        sink_b.expect_no_message(Duration::from_millis(250));

        // A cannot advance further while B still has element 2 pending.
        sink_a.request(1);
        sink_a.expect_no_message(Duration::from_millis(250));

        // Once B consumes element 2 the publisher is asked for more.
        sink_b.request(1);
        sink_b.assert_next(2);
        assert_eq!(publisher.expect_request(), 1);
    }

    // A consumer attached after element 1 was delivered only receives the
    // elements published afterwards; both consumers see completion.
    #[test]
    fn broadcast_hub_late_consumer_sees_only_late_elements() {
        let materializer = Materializer::new();
        let (publisher, hub_source) = TestSource::probe::<i32>()
            .to_mat(BroadcastHub::sink(2), Keep::both)
            .run_with_materializer(&materializer)
            .expect("broadcast hub materializes");

        let sink_a = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);

        let sink_b = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("late consumer materializes");
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_a.assert_next(2);
        sink_b.assert_next(2);

        publisher.send_complete();
        sink_a.request(1);
        sink_b.request(1);
        sink_a.expect_complete();
        sink_b.expect_complete();
    }

    // Elements are routed by `item % consumer count`: even values go to the
    // first consumer, odd values to the second.
    #[test]
    fn partition_hub_routes_elements_to_selected_consumers() {
        let materializer = Materializer::new();
        let hub = Source::from_iter([0, 1, 2, 3])
            .run_with_materializer(
                PartitionHub::sink(
                    |info, item| {
                        let idx = (*item as usize) % info.size();
                        info.consumer_id_by_idx(idx) as isize
                    },
                    2,
                    8,
                ),
                &materializer,
            )
            .expect("partition hub materializes");

        let sink_a = hub
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        let sink_b = hub
            .source()
            .run_with(TestSink::probe())
            .expect("second consumer materializes");

        sink_a.request(2);
        sink_b.request(2);
        sink_a.assert_next_n([0, 2]);
        sink_b.assert_next_n([1, 3]);
    }

    // The partitioner is invoked exactly once per element, even when the
    // second element has to wait because the consumer has not yet taken the
    // first one.
    #[test]
    fn partition_hub_evaluates_stateful_partitioner_once_per_blocked_element() {
        let materializer = Materializer::new();
        let partition_calls = Arc::new(AtomicUsize::new(0));
        let partition_calls_for_hub = Arc::clone(&partition_calls);

        let (publisher, hub) = TestSource::probe::<i32>()
            .to_mat(
                PartitionHub::sink(
                    move |info, _item| {
                        partition_calls_for_hub.fetch_add(1, Ordering::SeqCst);
                        info.consumer_id_by_idx(0) as isize
                    },
                    1,
                    1,
                ),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("partition hub materializes");

        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");

        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        wait_for_partition_calls(&partition_calls, 1);

        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink.expect_no_message(Duration::from_millis(250));
        wait_for_partition_calls(&partition_calls, 2);

        sink.request(1);
        sink.assert_next(1);
        sink.request(1);
        sink.assert_next(2);
        assert_eq!(partition_calls.load(Ordering::SeqCst), 2);
    }

    // After the first consumer is dropped and a replacement attaches, the
    // next element is routed to the replacement.
    #[test]
    fn partition_hub_invalidates_cached_topology_on_consumer_churn() {
        let materializer = Materializer::new();
        let (publisher, hub) = TestSource::probe::<i32>()
            .to_mat(
                PartitionHub::sink(
                    |info, _item| info.consumer_id_by_idx(info.size() - 1) as isize,
                    1,
                    8,
                ),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("partition hub materializes");

        let sink_a = hub
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);
        drop(sink_a);

        let sink_b = hub
            .source()
            .run_with(TestSink::probe())
            .expect("replacement consumer materializes");
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_b.assert_next(2);

        publisher.send_complete();
        sink_b.request(1);
        sink_b.expect_complete();
    }

    // All elements of a finite upstream are delivered to the consumer before
    // completion is signalled.
    #[test]
    fn broadcast_hub_drains_local_chunk_before_completion() {
        let materializer = Materializer::new();
        let hub = Source::from_iter([1, 2, 3])
            .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
            .expect("broadcast hub materializes");

        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");

        sink.request(1);
        sink.assert_next(1);
        sink.request(3);
        sink.assert_next_n([2, 3]);
        sink.expect_complete();
    }

    // An upstream that panics after yielding one element fails the consumer
    // with the producer-panicked error. The test accepts either ordering: the
    // error arriving first, or the element arriving first and the error next.
    #[test]
    fn broadcast_hub_panicking_upstream_fails_consumers() {
        let materializer = Materializer::new();
        let hub = Source::from_fn_iter(|| {
            let mut yielded = false;
            std::iter::from_fn(move || {
                if !yielded {
                    yielded = true;
                    Some(1)
                } else {
                    panic!("boom");
                }
            })
        })
        .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
        .expect("broadcast hub materializes");

        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");

        sink.request(2);
        match panic::catch_unwind(AssertUnwindSafe(|| sink.expect_error())) {
            Ok(error) => assert_eq!(
                error,
                StreamError::Failed("fan-out hub producer panicked".to_owned())
            ),
            Err(payload) => {
                // `expect_error` panicked because an element came first.
                assert_eq!(
                    panic_message(payload),
                    "expected stream error, got next element"
                );
                sink.request(1);
                assert_eq!(
                    sink.expect_error(),
                    StreamError::Failed("fan-out hub producer panicked".to_owned())
                );
            }
        }
    }

    // Extracts the message from a panic payload that is a `String` or a
    // `&'static str`; any other payload yields a placeholder.
    fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
        match payload.downcast::<String>() {
            Ok(message) => *message,
            Err(payload) => match payload.downcast::<&'static str>() {
                Ok(message) => (*message).to_owned(),
                Err(_) => "<non-string panic payload>".to_owned(),
            },
        }
    }

    // Polls `counter` every 5 ms for up to one second until it equals
    // `expected`, then asserts equality once more so a timeout fails the test.
    fn wait_for_partition_calls(counter: &AtomicUsize, expected: usize) {
        let deadline = Instant::now() + Duration::from_secs(1);
        while Instant::now() < deadline {
            if counter.load(Ordering::SeqCst) == expected {
                return;
            }
            thread::sleep(Duration::from_millis(5));
        }
        assert_eq!(counter.load(Ordering::SeqCst), expected);
    }
}