1use std::{
2 collections::{BTreeMap, VecDeque},
3 fmt,
4 sync::{
5 Arc, Condvar, Mutex, MutexGuard,
6 atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
7 },
8 time::Duration,
9};
10
11use arc_swap::ArcSwap;
12use smallvec::SmallVec;
13
14use crate::stream::{BoxStream, NotUsed, Sink, Source, SourceRuntimeHints, StreamCompletion};
15use crate::{StreamError, StreamResult};
16
17type Partitioner<T> = Arc<dyn Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync>;
18
/// Upper bound on the number of elements the merge-hub source pulls per batch, both when
/// draining the shared queue and when polling a direct producer.
const MERGE_HUB_BATCH_LIMIT: usize = 256;
/// Batch limit for the broadcast hub.
// NOTE(review): not referenced in the part of the file reviewed here — presumably used by
// the fan-out producer loop; confirm.
const BROADCAST_HUB_BATCH_LIMIT: usize = 256;
/// Cap applied to a broadcast batch while exactly one consumer lane is active.
const BROADCAST_HUB_SINGLE_CONSUMER_BATCH_LIMIT: usize = 64;
/// Partition-hub batch limit used while 2 to 7 consumer lanes are active.
const PARTITION_HUB_BATCH_LIMIT: usize = 1024;
/// Partition-hub batch limit used while 8 or more consumer lanes are active.
const PARTITION_HUB_WIDE_BATCH_LIMIT: usize = 2048;
/// Partition-hub batch limit used while at most one consumer lane is active.
const PARTITION_HUB_SINGLE_CONSUMER_BATCH_LIMIT: usize = 256;
/// Chunk granularity used to size the per-lane chunk queue
/// (`buffer_size / FAN_OUT_CONSUMER_BATCH_LIMIT + 1` slots are pre-allocated).
const FAN_OUT_CONSUMER_BATCH_LIMIT: usize = 256;
26
/// Returns the polling interval (one millisecond) used for bounded condition-variable waits
/// in the fan-out hubs.
///
/// Declared `const` so the value can also be used in constant contexts; callers that invoke
/// it at runtime are unaffected.
const fn fan_out_wait_timeout() -> Duration {
    Duration::from_millis(1)
}
30
31#[derive(Clone)]
32pub struct MergeHubDrainingControl {
33 state: Arc<MergeHubState>,
34 on_drain: Arc<dyn Fn() + Send + Sync>,
35}
36
37impl fmt::Debug for MergeHubDrainingControl {
38 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
39 f.debug_struct("MergeHubDrainingControl").finish()
40 }
41}
42
43impl MergeHubDrainingControl {
44 pub fn drain_and_complete(&self) {
45 let mut state = self.state.lock();
46 state.draining = true;
47 self.state.condvar.notify_all();
48 drop(state);
49 (self.on_drain)();
50 }
51}
52
53pub struct MergeHub;
54
55impl MergeHub {
56 #[must_use]
58 pub fn source<T: Send + 'static>(
59 per_producer_buffer_size: usize,
60 ) -> Source<T, Sink<T, NotUsed>> {
61 Self::source_with_draining(per_producer_buffer_size)
62 .map_materialized_value(|(sink, _)| sink)
63 }
64
65 #[must_use]
68 pub fn source_with_draining<T: Send + 'static>(
69 per_producer_buffer_size: usize,
70 ) -> Source<T, (Sink<T, NotUsed>, MergeHubDrainingControl)> {
71 assert!(
72 per_producer_buffer_size > 0,
73 "MergeHub per_producer_buffer_size must be greater than zero"
74 );
75 Source::from_terminal_batch_materialized_factory(move |_| {
76 let state = Arc::new(MergeHubShared::<T>::new(per_producer_buffer_size));
77 let source = Box::new(MergeHubSourceStream {
78 state: Arc::clone(&state),
79 local: VecDeque::new(),
80 prefer_direct: false,
81 }) as BoxStream<T>;
82 let sink = merge_hub_sink(Arc::clone(&state));
83 let control = MergeHubDrainingControl {
84 state: Arc::clone(&state.state),
85 on_drain: Arc::new({
86 let state = Arc::clone(&state);
87 move || state.finish_if_draining()
88 }),
89 };
90 Ok((source, (sink, control)))
91 })
92 }
93}
94
95pub struct BroadcastHub;
96
97impl BroadcastHub {
98 #[must_use]
106 pub fn sink<T: Clone + Send + 'static>(
107 buffer_size: usize,
108 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
109 Self::sink_starting_after(0, buffer_size)
110 }
111
112 #[must_use]
114 pub fn sink_starting_after<T: Clone + Send + 'static>(
115 start_after_nr_of_consumers: usize,
116 buffer_size: usize,
117 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
118 assert!(
119 buffer_size > 0,
120 "BroadcastHub buffer_size must be greater than zero"
121 );
122 Sink::from_hinted_runner(move |input, materializer, hints| {
123 let state = Arc::new(FanOutHubShared::new(
124 FanOutMode::Broadcast,
125 start_after_nr_of_consumers,
126 buffer_size,
127 None::<Partitioner<T>>,
128 ));
129 let source = BroadcastHubConsumerSource {
130 state: Arc::clone(&state),
131 completion: Arc::new(Mutex::new(None)),
132 };
133 let completion = materializer.spawn_stream(move |cancelled| {
134 FanOutProducer::new(input, state).run(cancelled, hints)
135 });
136 source.attach_completion(completion);
137 Ok(source)
138 })
139 }
140}
141
142pub struct PartitionHub;
143
144impl PartitionHub {
145 #[must_use]
153 pub fn sink<T: Clone + Send + 'static, F>(
154 partitioner: F,
155 start_after_nr_of_consumers: usize,
156 buffer_size: usize,
157 ) -> Sink<T, PartitionHubConsumerSource<T>>
158 where
159 F: Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync + 'static,
160 {
161 assert!(
162 buffer_size > 0,
163 "PartitionHub buffer_size must be greater than zero"
164 );
165 let partitioner = Arc::new(partitioner);
166 Sink::from_hinted_runner(move |input, materializer, hints| {
167 let partitioner = Arc::clone(&partitioner);
168 let state = Arc::new(FanOutHubShared::new(
169 FanOutMode::Partition,
170 start_after_nr_of_consumers,
171 buffer_size,
172 Some(partitioner),
173 ));
174 let source = PartitionHubConsumerSource {
175 state: Arc::clone(&state),
176 completion: Arc::new(Mutex::new(None)),
177 };
178 let completion = materializer.spawn_stream(move |cancelled| {
179 FanOutProducer::new(input, state).run(cancelled, hints)
180 });
181 source.attach_completion(completion);
182 Ok(source)
183 })
184 }
185}
186
187#[derive(Clone)]
188pub struct BroadcastHubConsumerSource<T> {
189 state: Arc<FanOutHubShared<T>>,
190 completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
191}
192
193impl<T: Clone + Send + 'static> BroadcastHubConsumerSource<T> {
194 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
195 *self
196 .completion
197 .lock()
198 .expect("broadcast hub completion poisoned") = Some(completion);
199 }
200
201 #[must_use]
202 pub fn source(&self) -> Source<T, NotUsed> {
203 let state = Arc::clone(&self.state);
204 Source::from_materialized_factory(move |_| {
205 let lane = state.register_consumer();
206 let stream = Box::new(FanOutConsumerStream {
207 state: Arc::clone(&state),
208 lane,
209 local: None,
210 detached: false,
211 }) as BoxStream<T>;
212 Ok((stream, NotUsed))
213 })
214 }
215}
216
217impl<T: Clone + Send + 'static> fmt::Debug for BroadcastHubConsumerSource<T> {
218 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
219 f.debug_struct("BroadcastHubConsumerSource").finish()
220 }
221}
222
223#[derive(Clone)]
224pub struct PartitionHubConsumerSource<T> {
225 state: Arc<FanOutHubShared<T>>,
226 completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
227}
228
229impl<T: Clone + Send + 'static> PartitionHubConsumerSource<T> {
230 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
231 *self
232 .completion
233 .lock()
234 .expect("partition hub completion poisoned") = Some(completion);
235 }
236
237 #[must_use]
238 pub fn source(&self) -> Source<T, NotUsed> {
239 let state = Arc::clone(&self.state);
240 Source::from_materialized_factory(move |_| {
241 let lane = state.register_consumer();
242 let stream = Box::new(FanOutConsumerStream {
243 state: Arc::clone(&state),
244 lane,
245 local: None,
246 detached: false,
247 }) as BoxStream<T>;
248 Ok((stream, NotUsed))
249 })
250 }
251}
252
253impl<T: Clone + Send + 'static> fmt::Debug for PartitionHubConsumerSource<T> {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("PartitionHubConsumerSource").finish()
256 }
257}
258
259#[derive(Clone, Debug)]
260pub struct PartitionConsumerInfo {
261 consumer_ids: SmallVec<[u64; 16]>,
262 queue_sizes: SmallVec<[(u64, usize); 16]>,
263}
264
265impl PartitionConsumerInfo {
266 #[must_use]
267 pub fn size(&self) -> usize {
268 self.consumer_ids.len()
269 }
270
271 #[must_use]
272 pub fn consumer_ids(&self) -> &[u64] {
273 &self.consumer_ids
274 }
275
276 #[must_use]
277 pub fn consumer_id_by_idx(&self, idx: usize) -> u64 {
278 self.consumer_ids[idx]
279 }
280
281 #[must_use]
282 pub fn queue_size(&self, consumer_id: u64) -> usize {
283 self.queue_sizes
284 .iter()
285 .find_map(|(id, size)| (*id == consumer_id).then_some(*size))
286 .unwrap_or(0)
287 }
288}
289
290fn merge_hub_sink<T: Send + 'static>(state: Arc<MergeHubShared<T>>) -> Sink<T, NotUsed> {
291 Sink::from_raw_hinted_runner(move |input, materializer, hints| {
292 if hints.inline_micro_max_success_items.is_some() {
293 state.register_direct_producer(input)?;
294 return Ok(NotUsed);
295 }
296
297 let input = materializer.checked_stream(input, None);
298 let producer_id = state.register_producer()?;
299 let hub = Arc::clone(&state);
300 let completion = materializer.spawn_stream(move |cancelled| {
301 let mut input = input;
302 loop {
303 if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
304 hub.fail(StreamError::Cancelled);
305 hub.deregister_producer(producer_id);
306 return Err(StreamError::Cancelled);
307 }
308
309 match input.next() {
310 Some(Ok(item)) => hub.push_item(producer_id, item)?,
311 Some(Err(error)) => {
312 hub.fail(error.clone());
313 hub.deregister_producer(producer_id);
314 return Err(error);
315 }
316 None => {
317 hub.deregister_producer(producer_id);
318 return Ok(NotUsed);
319 }
320 }
321 }
322 });
323 state.store_producer_completion(completion);
324 Ok(NotUsed)
325 })
326}
327
328struct MergeHubShared<T> {
329 state: Arc<MergeHubState>,
330 shared: Mutex<MergeHubInner<T>>,
331 condvar: Condvar,
332 failed: AtomicBool,
333}
334
335#[derive(Debug)]
336struct MergeHubState {
337 inner: Mutex<MergeHubFlags>,
338 condvar: Condvar,
339}
340
341#[derive(Debug, Default)]
342struct MergeHubFlags {
343 draining: bool,
344}
345
346impl MergeHubState {
347 fn lock(&self) -> MutexGuard<'_, MergeHubFlags> {
348 self.inner.lock().expect("merge hub flags poisoned")
349 }
350}
351
352struct MergeHubInner<T> {
353 queue: VecDeque<(u64, T)>,
354 direct_producers: VecDeque<MergeHubDirectProducer<T>>,
355 queued_per_producer: BTreeMap<u64, usize>,
356 producer_completions: Vec<StreamCompletion<NotUsed>>,
357 active_producers: usize,
358 next_producer_id: u64,
359 source_closed: bool,
360 completed: bool,
361 failed: Option<StreamError>,
362 per_producer_buffer_size: usize,
363}
364
365struct MergeHubDirectProducer<T> {
366 id: u64,
367 input: BoxStream<T>,
368}
369
370enum MergeHubProducerTerminal {
371 Active,
372 Completed,
373 Failed(StreamError),
374}
375
376impl<T> MergeHubShared<T> {
377 fn new(per_producer_buffer_size: usize) -> Self {
378 Self {
379 state: Arc::new(MergeHubState {
380 inner: Mutex::new(MergeHubFlags::default()),
381 condvar: Condvar::new(),
382 }),
383 shared: Mutex::new(MergeHubInner {
384 queue: VecDeque::new(),
385 direct_producers: VecDeque::new(),
386 queued_per_producer: BTreeMap::new(),
387 producer_completions: Vec::new(),
388 active_producers: 0,
389 next_producer_id: 0,
390 source_closed: false,
391 completed: false,
392 failed: None,
393 per_producer_buffer_size,
394 }),
395 condvar: Condvar::new(),
396 failed: AtomicBool::new(false),
397 }
398 }
399
400 fn register_direct_producer(&self, input: BoxStream<T>) -> StreamResult<()> {
401 let mut inner = self.shared.lock().expect("merge hub poisoned");
402 prune_finished_producer_completions(&mut inner.producer_completions);
403 let flags = self.state.lock();
404 if flags.draining || inner.source_closed || inner.completed {
405 return Err(StreamError::Failed(
406 "merge hub is draining or closed to new producers".to_owned(),
407 ));
408 }
409 if let Some(error) = inner.failed.clone() {
410 return Err(error);
411 }
412 let id = inner.next_producer_id;
413 inner.next_producer_id += 1;
414 inner.active_producers += 1;
415 inner
416 .direct_producers
417 .push_back(MergeHubDirectProducer { id, input });
418 drop(flags);
419 drop(inner);
420 self.condvar.notify_all();
421 Ok(())
422 }
423
424 fn register_producer(&self) -> StreamResult<u64> {
425 let mut inner = self.shared.lock().expect("merge hub poisoned");
426 prune_finished_producer_completions(&mut inner.producer_completions);
427 let flags = self.state.lock();
428 if flags.draining || inner.source_closed || inner.completed {
429 return Err(StreamError::Failed(
430 "merge hub is draining or closed to new producers".to_owned(),
431 ));
432 }
433 if let Some(error) = inner.failed.clone() {
434 return Err(error);
435 }
436 let id = inner.next_producer_id;
437 inner.next_producer_id += 1;
438 inner.active_producers += 1;
439 inner.queued_per_producer.insert(id, 0);
440 Ok(id)
441 }
442
443 fn store_producer_completion(&self, completion: StreamCompletion<NotUsed>) {
444 let mut inner = self.shared.lock().expect("merge hub poisoned");
445 prune_finished_producer_completions(&mut inner.producer_completions);
446 inner.producer_completions.push(completion);
447 }
448
449 fn push_item(&self, producer_id: u64, item: T) -> StreamResult<()> {
450 let mut inner = self.shared.lock().expect("merge hub poisoned");
451 loop {
452 if let Some(error) = inner.failed.clone() {
453 inner.queued_per_producer.remove(&producer_id);
454 return Err(error);
455 }
456 if inner.source_closed {
457 inner.queued_per_producer.remove(&producer_id);
458 return Err(StreamError::Cancelled);
459 }
460 let queued = inner
461 .queued_per_producer
462 .get(&producer_id)
463 .copied()
464 .unwrap_or(0);
465 if queued < inner.per_producer_buffer_size {
466 inner.queue.push_back((producer_id, item));
467 inner.queued_per_producer.insert(producer_id, queued + 1);
468 self.condvar.notify_all();
469 return Ok(());
470 }
471 inner = self
472 .condvar
473 .wait(inner)
474 .expect("merge hub poisoned while waiting");
475 }
476 }
477
478 fn deregister_producer(&self, producer_id: u64) {
479 let mut inner = self.shared.lock().expect("merge hub poisoned");
480 prune_finished_producer_completions(&mut inner.producer_completions);
481 inner.queued_per_producer.remove(&producer_id);
482 inner
483 .direct_producers
484 .retain(|producer| producer.id != producer_id);
485 inner.active_producers = inner.active_producers.saturating_sub(1);
486 if inner.active_producers == 0 {
487 let flags = self.state.lock();
488 if flags.draining {
489 inner.completed = true;
490 }
491 }
492 self.condvar.notify_all();
493 }
494
495 fn pop_direct_producer(
496 &self,
497 inner: &mut MergeHubInner<T>,
498 ) -> Option<MergeHubDirectProducer<T>> {
499 inner.direct_producers.pop_front()
500 }
501
502 fn restore_direct_producer(
503 &self,
504 producer: MergeHubDirectProducer<T>,
505 terminal: MergeHubProducerTerminal,
506 ) -> StreamResult<()> {
507 match terminal {
508 MergeHubProducerTerminal::Active => {
509 let mut inner = self.shared.lock().expect("merge hub poisoned");
510 if let Some(error) = inner.failed.clone() {
511 inner.active_producers = inner.active_producers.saturating_sub(1);
512 return Err(error);
513 }
514 if inner.source_closed {
515 inner.active_producers = inner.active_producers.saturating_sub(1);
516 return Err(StreamError::Cancelled);
517 }
518 inner.direct_producers.push_back(producer);
519 Ok(())
520 }
521 MergeHubProducerTerminal::Completed => {
522 self.deregister_producer(producer.id);
523 Ok(())
524 }
525 MergeHubProducerTerminal::Failed(error) => {
526 self.fail(error.clone());
527 self.deregister_producer(producer.id);
528 Err(error)
529 }
530 }
531 }
532
533 fn fail(&self, error: StreamError) {
534 let mut inner = self.shared.lock().expect("merge hub poisoned");
535 if inner.failed.is_none() {
536 inner.failed = Some(error);
537 self.failed.store(true, Ordering::SeqCst);
538 }
539 self.condvar.notify_all();
540 }
541
542 fn failed_error(&self) -> Option<StreamError> {
543 if !self.failed.load(Ordering::SeqCst) {
544 return None;
545 }
546 self.shared
547 .lock()
548 .expect("merge hub poisoned")
549 .failed
550 .clone()
551 }
552
553 fn finish_if_draining(&self) {
554 let flags = self.state.lock();
555 if !flags.draining {
556 return;
557 }
558 drop(flags);
559
560 let mut inner = self.shared.lock().expect("merge hub poisoned");
561 prune_finished_producer_completions(&mut inner.producer_completions);
562 if inner.active_producers == 0 {
563 inner.completed = true;
564 self.condvar.notify_all();
565 }
566 }
567}
568
569fn prune_finished_producer_completions(completions: &mut Vec<StreamCompletion<NotUsed>>) {
570 let mut index = 0;
571 while index < completions.len() {
572 if completions[index].try_wait().is_some() {
573 drop(completions.swap_remove(index));
574 } else {
575 index += 1;
576 }
577 }
578}
579
580struct MergeHubSourceStream<T> {
581 state: Arc<MergeHubShared<T>>,
582 local: VecDeque<T>,
583 prefer_direct: bool,
584}
585
586impl<T> Iterator for MergeHubSourceStream<T> {
587 type Item = StreamResult<T>;
588
589 fn next(&mut self) -> Option<Self::Item> {
590 if let Some(error) = self.state.failed_error() {
591 self.local.clear();
592 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
593 inner.source_closed = true;
594 return Some(Err(error));
595 }
596 if let Some(item) = self.local.pop_front() {
597 return Some(Ok(item));
598 }
599
600 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
601 loop {
602 if let Some(error) = inner.failed.clone() {
603 inner.source_closed = true;
604 return Some(Err(error));
605 }
606 let should_drain_direct = !inner.direct_producers.is_empty()
607 && (self.prefer_direct || inner.queue.is_empty());
608 if should_drain_direct {
609 let Some(mut producer) = self.state.pop_direct_producer(&mut inner) else {
610 continue;
611 };
612 let drain_limit = inner
613 .per_producer_buffer_size
614 .clamp(1, MERGE_HUB_BATCH_LIMIT);
615 drop(inner);
616
617 let mut batch = std::mem::take(&mut self.local);
618 batch.clear();
619 batch.reserve(drain_limit.saturating_sub(batch.capacity()));
620 let mut terminal = MergeHubProducerTerminal::Active;
621 for _ in 0..drain_limit {
622 match producer.input.next() {
623 Some(Ok(item)) => batch.push_back(item),
624 Some(Err(error)) => {
625 terminal = MergeHubProducerTerminal::Failed(error);
626 break;
627 }
628 None => {
629 terminal = MergeHubProducerTerminal::Completed;
630 break;
631 }
632 }
633 }
634
635 let restore_result = self.state.restore_direct_producer(producer, terminal);
636 match restore_result {
637 Ok(()) => {
638 self.prefer_direct = false;
639 if let Some(first) = batch.pop_front() {
640 self.local = batch;
641 return Some(Ok(first));
642 }
643 inner = self.state.shared.lock().expect("merge hub poisoned");
644 continue;
645 }
646 Err(error) => {
647 self.local.clear();
648 return Some(Err(error));
649 }
650 }
651 }
652 if !inner.queue.is_empty() {
653 let drain_n = inner.queue.len().min(MERGE_HUB_BATCH_LIMIT);
654 let mut batch = std::mem::take(&mut self.local);
655 batch.clear();
656 batch.reserve(drain_n.saturating_sub(batch.capacity()));
657 for _ in 0..drain_n {
658 if let Some((producer_id, item)) = inner.queue.pop_front() {
659 if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
660 *queued = queued.saturating_sub(1);
661 }
662 batch.push_back(item);
663 }
664 }
665 self.state.condvar.notify_all();
666 drop(inner);
667 let first = batch
668 .pop_front()
669 .expect("merge hub drained non-empty batch");
670 self.local = batch;
671 self.prefer_direct = true;
672 return Some(Ok(first));
673 }
674 if inner.completed {
675 inner.source_closed = true;
676 return None;
677 }
678 inner = self
679 .state
680 .condvar
681 .wait(inner)
682 .expect("merge hub poisoned while waiting");
683 }
684 }
685}
686
687impl<T> Drop for MergeHubSourceStream<T> {
688 fn drop(&mut self) {
689 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
690 inner.source_closed = true;
691 self.state.condvar.notify_all();
692 }
693}
694
695#[derive(Clone, Copy)]
696enum FanOutMode {
697 Broadcast,
698 Partition,
699}
700
701struct FanOutHubShared<T> {
702 registry: Mutex<FanOutRegistry<T>>,
703 snapshot: ArcSwap<FanOutSnapshot<T>>,
704 producer_wait: Mutex<()>,
705 producer_condvar: Condvar,
706 producer_epoch: AtomicU64,
707 topology_epoch: AtomicU64,
708 mode: FanOutMode,
709 start_after_nr_of_consumers: usize,
710 buffer_size: usize,
711 partitioner: Option<Partitioner<T>>,
712}
713
714struct FanOutRegistry<T> {
715 consumers: BTreeMap<u64, Arc<FanOutConsumerLane<T>>>,
716 next_consumer_id: u64,
717 terminal: Option<FanOutTerminal>,
718}
719
720struct FanOutSnapshot<T> {
721 consumers: Vec<Arc<FanOutConsumerLane<T>>>,
722 terminal: Option<FanOutTerminal>,
723}
724
725#[derive(Clone, Debug)]
726enum FanOutTerminal {
727 Completed,
728 Failed(StreamError),
729}
730
731impl FanOutTerminal {
732 fn producer_error(&self) -> StreamError {
733 match self {
734 Self::Completed => StreamError::Cancelled,
735 Self::Failed(error) => error.clone(),
736 }
737 }
738}
739
740struct FanOutConsumerLane<T> {
741 id: u64,
742 state: Mutex<FanOutLaneState<T>>,
743 condvar: Condvar,
744 queued: AtomicUsize,
745 active: AtomicBool,
746 failed: AtomicBool,
747}
748
749struct FanOutLaneState<T> {
750 chunks: VecDeque<Vec<T>>,
751 queued: usize,
752 terminal: Option<FanOutTerminal>,
753}
754
755impl<T> FanOutConsumerLane<T> {
756 fn new(id: u64, buffer_size: usize) -> Self {
757 Self {
758 id,
759 state: Mutex::new(FanOutLaneState {
760 chunks: VecDeque::with_capacity((buffer_size / FAN_OUT_CONSUMER_BATCH_LIMIT) + 1),
761 queued: 0,
762 terminal: None,
763 }),
764 condvar: Condvar::new(),
765 queued: AtomicUsize::new(0),
766 active: AtomicBool::new(true),
767 failed: AtomicBool::new(false),
768 }
769 }
770
771 fn terminal(id: u64, terminal: FanOutTerminal) -> Self {
772 let failed = matches!(terminal, FanOutTerminal::Failed(_));
773 Self {
774 id,
775 state: Mutex::new(FanOutLaneState {
776 chunks: VecDeque::new(),
777 queued: 0,
778 terminal: Some(terminal),
779 }),
780 condvar: Condvar::new(),
781 queued: AtomicUsize::new(0),
782 active: AtomicBool::new(false),
783 failed: AtomicBool::new(failed),
784 }
785 }
786
787 fn id(&self) -> u64 {
788 self.id
789 }
790
791 fn queued_len(&self) -> usize {
792 self.queued.load(Ordering::Acquire)
793 }
794
795 fn is_active(&self) -> bool {
796 self.active.load(Ordering::Acquire)
797 }
798
799 fn deactivate(&self) {
800 self.active.store(false, Ordering::Release);
801 self.condvar.notify_all();
802 }
803
804 fn set_terminal(&self, terminal: FanOutTerminal) {
805 let mut state = self.state.lock().expect("fan-out lane poisoned");
806 if matches!(terminal, FanOutTerminal::Failed(_)) {
807 state.chunks.clear();
808 state.queued = 0;
809 self.queued.store(0, Ordering::Release);
810 self.failed.store(true, Ordering::Release);
811 }
812 state.terminal = Some(terminal);
813 drop(state);
814 self.condvar.notify_all();
815 }
816}
817
818impl PartitionConsumerInfo {
819 fn from_cached_lanes<T>(consumer_ids: &[u64], lanes: &[Arc<FanOutConsumerLane<T>>]) -> Self {
820 Self {
821 consumer_ids: consumer_ids.iter().copied().collect(),
822 queue_sizes: lanes
823 .iter()
824 .map(|lane| (lane.id(), lane.queued_len()))
825 .collect(),
826 }
827 }
828}
829
830struct PartitionTopologyCache<T> {
831 epoch: u64,
832 consumer_ids: SmallVec<[u64; 16]>,
833 lanes: SmallVec<[Arc<FanOutConsumerLane<T>>; 16]>,
834}
835
836type PartitionRoutedBatches<T> = SmallVec<[(Arc<FanOutConsumerLane<T>>, VecDeque<T>); 16]>;
837
838impl<T> PartitionTopologyCache<T> {
839 fn new() -> Self {
840 Self {
841 epoch: u64::MAX,
842 consumer_ids: SmallVec::new(),
843 lanes: SmallVec::new(),
844 }
845 }
846
847 fn clear(&mut self) {
848 self.epoch = u64::MAX;
849 self.consumer_ids.clear();
850 self.lanes.clear();
851 }
852}
853
854impl<T> FanOutHubShared<T> {
855 fn new(
856 mode: FanOutMode,
857 start_after_nr_of_consumers: usize,
858 buffer_size: usize,
859 partitioner: Option<Partitioner<T>>,
860 ) -> Self {
861 Self {
862 registry: Mutex::new(FanOutRegistry {
863 consumers: BTreeMap::new(),
864 next_consumer_id: 0,
865 terminal: None,
866 }),
867 snapshot: ArcSwap::from_pointee(FanOutSnapshot {
868 consumers: Vec::new(),
869 terminal: None,
870 }),
871 producer_wait: Mutex::new(()),
872 producer_condvar: Condvar::new(),
873 producer_epoch: AtomicU64::new(0),
874 topology_epoch: AtomicU64::new(0),
875 mode,
876 start_after_nr_of_consumers,
877 buffer_size,
878 partitioner,
879 }
880 }
881
882 fn register_consumer(&self) -> Arc<FanOutConsumerLane<T>> {
883 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
884 let id = registry.next_consumer_id;
885 registry.next_consumer_id += 1;
886
887 if let Some(terminal) = registry.terminal.clone() {
888 return Arc::new(FanOutConsumerLane::terminal(id, terminal));
889 }
890
891 let lane = Arc::new(FanOutConsumerLane::new(id, self.buffer_size));
892 registry.consumers.insert(id, Arc::clone(&lane));
893 self.publish_snapshot_locked(®istry);
894 drop(registry);
895 self.notify_topology_transition();
896 lane
897 }
898
899 fn remove_consumer(&self, consumer_id: u64) {
900 let lane = {
901 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
902 let lane = registry.consumers.remove(&consumer_id);
903 if let Some(lane) = &lane {
904 lane.deactivate();
905 self.publish_snapshot_locked(®istry);
906 }
907 lane
908 };
909
910 if let Some(lane) = lane {
911 lane.condvar.notify_all();
912 self.notify_topology_transition();
913 }
914 }
915
916 fn wait_for_broadcast_capacity(&self, max_items: usize) -> StreamResult<usize> {
917 loop {
918 let observed = self.producer_epoch.load(Ordering::Acquire);
919 let snapshot = self.snapshot.load();
920 if let Some(terminal) = &snapshot.terminal {
921 return Err(terminal.producer_error());
922 }
923
924 let lanes = self.active_lanes(&snapshot);
925 if lanes.len() < self.start_after_nr_of_consumers || lanes.is_empty() {
926 self.wait_for_producer_transition(observed);
927 continue;
928 }
929
930 let free = lanes
931 .iter()
932 .map(|lane| self.buffer_size.saturating_sub(lane.queued_len()))
933 .min()
934 .unwrap_or(0);
935 if free > 0 {
936 let max_items = if lanes.len() == 1 {
937 max_items.min(BROADCAST_HUB_SINGLE_CONSUMER_BATCH_LIMIT)
938 } else {
939 max_items
940 };
941 return Ok(free.min(max_items).max(1));
942 }
943 self.wait_for_producer_transition(observed);
944 }
945 }
946
947 fn push_broadcast_item(&self, item: T) -> StreamResult<()>
948 where
949 T: Clone,
950 {
951 let mut batch = VecDeque::new();
952 batch.push_back(item);
953 self.push_broadcast_batch(&mut batch)
954 }
955
956 fn push_broadcast_batch(&self, batch: &mut VecDeque<T>) -> StreamResult<()>
957 where
958 T: Clone,
959 {
960 if batch.is_empty() {
961 return Ok(());
962 }
963
964 while !batch.is_empty() {
965 let observed = self.producer_epoch.load(Ordering::Acquire);
966 let snapshot = self.snapshot.load();
967 if let Some(terminal) = &snapshot.terminal {
968 batch.clear();
969 return Err(terminal.producer_error());
970 }
971
972 let lanes = self.active_lanes(&snapshot);
973 if lanes.len() < self.start_after_nr_of_consumers || lanes.is_empty() {
974 self.wait_for_producer_transition(observed);
975 continue;
976 }
977
978 let free = lanes
979 .iter()
980 .map(|lane| self.buffer_size.saturating_sub(lane.queued_len()))
981 .min()
982 .unwrap_or(0);
983 if free == 0 {
984 self.wait_for_producer_transition(observed);
985 continue;
986 }
987
988 let take_n = free.min(batch.len());
989 if !self.try_push_broadcast_batch(&lanes, batch, take_n)? {
990 self.wait_for_producer_transition(observed);
991 }
992 }
993 Ok(())
994 }
995
996 fn try_push_broadcast_batch(
997 &self,
998 lanes: &[Arc<FanOutConsumerLane<T>>],
999 batch: &mut VecDeque<T>,
1000 take_n: usize,
1001 ) -> StreamResult<bool>
1002 where
1003 T: Clone,
1004 {
1005 if take_n == 0 || lanes.is_empty() {
1006 return Ok(false);
1007 }
1008
1009 let mut guards = SmallVec::<
1010 [(
1011 Arc<FanOutConsumerLane<T>>,
1012 MutexGuard<'_, FanOutLaneState<T>>,
1013 ); 16],
1014 >::new();
1015 for lane in lanes {
1016 if !lane.is_active() {
1017 return Ok(false);
1018 }
1019 let guard = lane.state.lock().expect("fan-out lane poisoned");
1020 if !lane.is_active() {
1021 return Ok(false);
1022 }
1023 if let Some(terminal) = &guard.terminal {
1024 return Err(terminal.producer_error());
1025 }
1026 if self.buffer_size.saturating_sub(guard.queued) < take_n {
1027 return Ok(false);
1028 }
1029 guards.push((Arc::clone(lane), guard));
1030 }
1031
1032 let lane_count = guards.len();
1033 if lane_count == 0 {
1034 return Ok(false);
1035 }
1036
1037 let mut notify_lanes = SmallVec::<[Arc<FanOutConsumerLane<T>>; 16]>::new();
1038 for (index, (lane, guard)) in guards.iter_mut().enumerate() {
1039 let chunk = if index + 1 == lane_count {
1040 batch.drain(..take_n).collect()
1041 } else {
1042 batch.iter().take(take_n).cloned().collect()
1043 };
1044 guard.chunks.push_back(chunk);
1045 guard.queued += take_n;
1046 lane.queued.store(guard.queued, Ordering::Release);
1047 notify_lanes.push(Arc::clone(lane));
1048 }
1049 drop(guards);
1050
1051 for lane in notify_lanes {
1052 lane.condvar.notify_one();
1053 }
1054 Ok(true)
1055 }
1056
1057 fn select_partition(
1058 &self,
1059 item: &T,
1060 cache: &mut PartitionTopologyCache<T>,
1061 ) -> StreamResult<Option<Arc<FanOutConsumerLane<T>>>> {
1062 let Some(partitioner) = &self.partitioner else {
1063 return Err(StreamError::Failed(
1064 "partition hub partitioner missing".to_owned(),
1065 ));
1066 };
1067 let topology_epoch = self.topology_epoch.load(Ordering::Acquire);
1068 if cache.epoch != topology_epoch || cache.lanes.is_empty() {
1069 self.refresh_partition_topology(cache)?;
1070 }
1071
1072 let info = PartitionConsumerInfo::from_cached_lanes(&cache.consumer_ids, &cache.lanes);
1073 let selected = partitioner(&info, item);
1074 if selected < 0 {
1075 return Ok(None);
1076 }
1077 let selected = selected as u64;
1078 let selected_idx = selected as usize;
1079 if cache.consumer_ids.get(selected_idx).copied() == Some(selected) {
1080 return Ok(Some(Arc::clone(&cache.lanes[selected_idx])));
1081 }
1082 cache
1083 .consumer_ids
1084 .iter()
1085 .position(|id| *id == selected)
1086 .map(|idx| Some(Arc::clone(&cache.lanes[idx])))
1087 .ok_or_else(|| {
1088 StreamError::Failed("partition hub selected unknown consumer".to_owned())
1089 })
1090 }
1091
1092 fn refresh_partition_topology(
1093 &self,
1094 cache: &mut PartitionTopologyCache<T>,
1095 ) -> StreamResult<()> {
1096 loop {
1097 let observed = self.producer_epoch.load(Ordering::Acquire);
1098 let topology_epoch = self.topology_epoch.load(Ordering::Acquire);
1099 let snapshot = self.snapshot.load();
1100 if let Some(terminal) = &snapshot.terminal {
1101 cache.clear();
1102 return Err(terminal.producer_error());
1103 }
1104
1105 let lanes = self.active_lanes(&snapshot);
1106 if lanes.len() >= self.start_after_nr_of_consumers && !lanes.is_empty() {
1107 cache.epoch = topology_epoch;
1108 cache.consumer_ids.clear();
1109 cache
1110 .consumer_ids
1111 .extend(lanes.iter().map(|lane| lane.id()));
1112 cache.lanes = lanes;
1113 return Ok(());
1114 }
1115 self.wait_for_producer_transition(observed);
1116 }
1117 }
1118
1119 fn enqueue_partition(&self, selected: Arc<FanOutConsumerLane<T>>, item: T) -> StreamResult<()> {
1120 let mut batch = VecDeque::new();
1121 batch.push_back(item);
1122 self.enqueue_partition_batch(selected, &mut batch)
1123 }
1124
1125 fn enqueue_partition_batch(
1126 &self,
1127 selected: Arc<FanOutConsumerLane<T>>,
1128 batch: &mut VecDeque<T>,
1129 ) -> StreamResult<()> {
1130 if batch.is_empty() {
1131 return Ok(());
1132 }
1133
1134 let mut state = selected.state.lock().expect("fan-out lane poisoned");
1135 loop {
1136 if !selected.is_active() {
1137 batch.clear();
1138 return Err(StreamError::Failed(
1139 "partition hub selected unknown consumer".to_owned(),
1140 ));
1141 }
1142 if let Some(terminal) = &state.terminal {
1143 batch.clear();
1144 return Err(terminal.producer_error());
1145 }
1146 let free = self.buffer_size.saturating_sub(state.queued);
1147 if free > 0 {
1148 let take_n = free.min(batch.len());
1149 let chunk = batch.drain(..take_n).collect();
1150 state.chunks.push_back(chunk);
1151 state.queued += take_n;
1152 selected.queued.store(state.queued, Ordering::Release);
1153 selected.condvar.notify_one();
1154 if batch.is_empty() {
1155 return Ok(());
1156 }
1157 }
1158 let (guard, _) = selected
1159 .condvar
1160 .wait_timeout(state, fan_out_wait_timeout())
1161 .expect("fan-out lane poisoned while waiting");
1162 state = guard;
1163 }
1164 }
1165
1166 fn complete(&self) {
1167 let lanes = {
1168 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1169 if registry.terminal.is_none() {
1170 registry.terminal = Some(FanOutTerminal::Completed);
1171 self.publish_snapshot_locked(®istry);
1172 }
1173 registry.consumers.values().cloned().collect::<Vec<_>>()
1174 };
1175
1176 for lane in lanes {
1177 lane.set_terminal(FanOutTerminal::Completed);
1178 }
1179 self.notify_topology_transition();
1180 }
1181
1182 fn fail(&self, error: StreamError) {
1183 let terminal = FanOutTerminal::Failed(error);
1184 let lanes = {
1185 let mut registry = self.registry.lock().expect("fan-out hub poisoned");
1186 if registry.terminal.is_none() {
1187 registry.terminal = Some(terminal.clone());
1188 self.publish_snapshot_locked(®istry);
1189 }
1190 registry.consumers.values().cloned().collect::<Vec<_>>()
1191 };
1192
1193 for lane in lanes {
1194 lane.set_terminal(terminal.clone());
1195 }
1196 self.notify_topology_transition();
1197 }
1198
1199 fn publish_snapshot_locked(&self, registry: &FanOutRegistry<T>) {
1200 self.snapshot.store(Arc::new(FanOutSnapshot {
1201 consumers: registry.consumers.values().cloned().collect(),
1202 terminal: registry.terminal.clone(),
1203 }));
1204 }
1205
1206 fn active_lanes(
1207 &self,
1208 snapshot: &FanOutSnapshot<T>,
1209 ) -> SmallVec<[Arc<FanOutConsumerLane<T>>; 16]> {
1210 snapshot
1211 .consumers
1212 .iter()
1213 .filter(|lane| lane.is_active())
1214 .cloned()
1215 .collect()
1216 }
1217
1218 fn partition_batch_limit(&self) -> usize {
1219 let snapshot = self.snapshot.load();
1220 let active = snapshot
1221 .consumers
1222 .iter()
1223 .filter(|lane| lane.is_active())
1224 .count();
1225 match active {
1226 0 | 1 => PARTITION_HUB_SINGLE_CONSUMER_BATCH_LIMIT,
1227 2..=7 => PARTITION_HUB_BATCH_LIMIT,
1228 _ => PARTITION_HUB_WIDE_BATCH_LIMIT,
1229 }
1230 }
1231
1232 fn notify_producer_transition(&self) {
1233 self.producer_epoch.fetch_add(1, Ordering::Release);
1234 self.producer_condvar.notify_one();
1235 }
1236
1237 fn notify_topology_transition(&self) {
1238 self.topology_epoch.fetch_add(1, Ordering::Release);
1239 self.notify_producer_transition();
1240 }
1241
1242 fn wait_for_producer_transition(&self, observed_epoch: u64) {
1243 let guard = self
1244 .producer_wait
1245 .lock()
1246 .expect("fan-out producer wait poisoned");
1247 if self.producer_epoch.load(Ordering::Acquire) == observed_epoch {
1248 let (_guard, _) = self
1249 .producer_condvar
1250 .wait_timeout(guard, fan_out_wait_timeout())
1251 .expect("fan-out producer wait poisoned while waiting");
1252 }
1253 }
1254}
1255
1256fn push_partition_routed<T>(
1257 routed: &mut PartitionRoutedBatches<T>,
1258 lane: Arc<FanOutConsumerLane<T>>,
1259 item: T,
1260) {
1261 for (existing, batch) in routed.iter_mut() {
1262 if existing.id() == lane.id() {
1263 batch.push_back(item);
1264 return;
1265 }
1266 }
1267
1268 let mut batch = VecDeque::new();
1269 batch.push_back(item);
1270 routed.push((lane, batch));
1271}
1272
1273struct FanOutProducer<T> {
1274 input: BoxStream<T>,
1275 state: Arc<FanOutHubShared<T>>,
1276}
1277
1278impl<T> FanOutProducer<T> {
1279 fn new(input: BoxStream<T>, state: Arc<FanOutHubShared<T>>) -> Self {
1280 Self { input, state }
1281 }
1282}
1283
1284impl<T: Send + 'static + Clone> FanOutProducer<T> {
1285 fn run(
1286 mut self,
1287 cancelled: Arc<std::sync::atomic::AtomicBool>,
1288 hints: SourceRuntimeHints,
1289 ) -> StreamResult<NotUsed> {
1290 struct ProducerDropGuard<T> {
1291 state: Arc<FanOutHubShared<T>>,
1292 disarmed: bool,
1293 }
1294
1295 impl<T> ProducerDropGuard<T> {
1296 fn new(state: Arc<FanOutHubShared<T>>) -> Self {
1297 Self {
1298 state,
1299 disarmed: false,
1300 }
1301 }
1302
1303 fn disarm(&mut self) {
1304 self.disarmed = true;
1305 }
1306 }
1307
1308 impl<T> Drop for ProducerDropGuard<T> {
1309 fn drop(&mut self) {
1310 if !self.disarmed && std::thread::panicking() {
1311 self.state.fail(StreamError::Failed(
1312 "fan-out hub producer panicked".to_owned(),
1313 ));
1314 }
1315 }
1316 }
1317
1318 let mut guard = ProducerDropGuard::new(Arc::clone(&self.state));
1319 match self.state.mode {
1320 FanOutMode::Broadcast => {
1321 let batch_producer = hints.inline_micro_max_success_items.is_some();
1322 let mut batch = VecDeque::new();
1323 loop {
1324 if cancelled.load(Ordering::SeqCst) {
1325 self.state.fail(StreamError::Cancelled);
1326 guard.disarm();
1327 return Err(StreamError::Cancelled);
1328 }
1329
1330 if batch_producer {
1331 let capacity = self
1332 .state
1333 .wait_for_broadcast_capacity(BROADCAST_HUB_BATCH_LIMIT)?;
1334 batch.clear();
1335 batch.reserve(capacity.saturating_sub(batch.capacity()));
1336 let mut terminal = None;
1337 for _ in 0..capacity {
1338 if cancelled.load(Ordering::SeqCst) {
1339 terminal = Some(Err(StreamError::Cancelled));
1340 break;
1341 }
1342 match self.input.next() {
1343 Some(Ok(item)) => batch.push_back(item),
1344 Some(Err(error)) => {
1345 terminal = Some(Err(error));
1346 break;
1347 }
1348 None => {
1349 terminal = Some(Ok(()));
1350 break;
1351 }
1352 }
1353 }
1354 if !batch.is_empty() {
1355 self.state.push_broadcast_batch(&mut batch)?;
1356 }
1357 match terminal {
1358 Some(Err(StreamError::Cancelled)) => {
1359 self.state.fail(StreamError::Cancelled);
1360 guard.disarm();
1361 return Err(StreamError::Cancelled);
1362 }
1363 Some(Err(error)) => {
1364 self.state.fail(error.clone());
1365 guard.disarm();
1366 return Err(error);
1367 }
1368 Some(Ok(())) => {
1369 self.state.complete();
1370 guard.disarm();
1371 return Ok(NotUsed);
1372 }
1373 None => continue,
1374 }
1375 }
1376
1377 match self.input.next() {
1378 Some(Ok(item)) => self.state.push_broadcast_item(item)?,
1379 Some(Err(error)) => {
1380 self.state.fail(error.clone());
1381 guard.disarm();
1382 return Err(error);
1383 }
1384 None => {
1385 self.state.complete();
1386 guard.disarm();
1387 return Ok(NotUsed);
1388 }
1389 }
1390 }
1391 }
1392 FanOutMode::Partition => {
1393 let batch_producer = hints.inline_micro_max_success_items.is_some();
1394 let mut topology_cache = PartitionTopologyCache::new();
1395 let mut routed = PartitionRoutedBatches::new();
1396 loop {
1397 if cancelled.load(Ordering::SeqCst) {
1398 self.state.fail(StreamError::Cancelled);
1399 guard.disarm();
1400 return Err(StreamError::Cancelled);
1401 }
1402
1403 if batch_producer {
1404 for (_, batch) in routed.iter_mut() {
1405 batch.clear();
1406 }
1407 let partition_batch_limit = self.state.partition_batch_limit();
1408 let mut terminal = None;
1409 for _ in 0..partition_batch_limit {
1410 if cancelled.load(Ordering::SeqCst) {
1411 terminal = Some(Err(StreamError::Cancelled));
1412 break;
1413 }
1414 match self.input.next() {
1415 Some(Ok(item)) => {
1416 if let Some(selected) =
1417 self.state.select_partition(&item, &mut topology_cache)?
1418 {
1419 push_partition_routed(&mut routed, selected, item);
1420 }
1421 }
1422 Some(Err(error)) => {
1423 terminal = Some(Err(error));
1424 break;
1425 }
1426 None => {
1427 terminal = Some(Ok(()));
1428 break;
1429 }
1430 }
1431 }
1432 for (selected, batch) in routed.iter_mut() {
1433 if !batch.is_empty() {
1434 self.state
1435 .enqueue_partition_batch(Arc::clone(selected), batch)?;
1436 }
1437 }
1438 match terminal {
1439 Some(Err(StreamError::Cancelled)) => {
1440 self.state.fail(StreamError::Cancelled);
1441 guard.disarm();
1442 return Err(StreamError::Cancelled);
1443 }
1444 Some(Err(error)) => {
1445 self.state.fail(error.clone());
1446 guard.disarm();
1447 return Err(error);
1448 }
1449 Some(Ok(())) => {
1450 self.state.complete();
1451 guard.disarm();
1452 return Ok(NotUsed);
1453 }
1454 None => continue,
1455 }
1456 }
1457
1458 match self.input.next() {
1459 Some(Ok(item)) => {
1460 if let Some(selected) =
1461 self.state.select_partition(&item, &mut topology_cache)?
1462 {
1463 self.state.enqueue_partition(selected, item)?;
1464 }
1465 }
1466 Some(Err(error)) => {
1467 self.state.fail(error.clone());
1468 guard.disarm();
1469 return Err(error);
1470 }
1471 None => {
1472 self.state.complete();
1473 guard.disarm();
1474 return Ok(NotUsed);
1475 }
1476 }
1477 }
1478 }
1479 }
1480 }
1481}
1482
1483struct FanOutConsumerStream<T> {
1484 state: Arc<FanOutHubShared<T>>,
1485 lane: Arc<FanOutConsumerLane<T>>,
1486 local: Option<std::vec::IntoIter<T>>,
1487 detached: bool,
1488}
1489
1490impl<T: Clone + Send + 'static> Iterator for FanOutConsumerStream<T> {
1491 type Item = StreamResult<T>;
1492
1493 fn next(&mut self) -> Option<Self::Item> {
1494 if self.lane.failed.load(Ordering::Acquire) {
1495 self.local = None;
1496 } else if let Some(local) = &mut self.local {
1497 if let Some(item) = local.next() {
1498 return Some(Ok(item));
1499 }
1500 self.local = None;
1501 }
1502
1503 let mut state = self.lane.state.lock().expect("fan-out lane poisoned");
1504 loop {
1505 if let Some(FanOutTerminal::Failed(error)) = state.terminal.clone() {
1506 return Some(Err(error));
1507 }
1508 if let Some(batch) = state.chunks.pop_front() {
1509 state.queued = state.queued.saturating_sub(batch.len());
1510 self.lane.queued.store(state.queued, Ordering::Release);
1511 if matches!(self.state.mode, FanOutMode::Partition) {
1512 self.lane.condvar.notify_one();
1513 }
1514 self.state.notify_producer_transition();
1515 drop(state);
1516 let mut batch = batch.into_iter();
1517 let first = batch.next().expect("fan-out drained non-empty lane batch");
1518 if batch.len() > 0 {
1519 self.local = Some(batch);
1520 }
1521 return Some(Ok(first));
1522 }
1523 if matches!(state.terminal, Some(FanOutTerminal::Completed)) {
1524 return None;
1525 }
1526 let (guard, _) = self
1527 .lane
1528 .condvar
1529 .wait_timeout(state, fan_out_wait_timeout())
1530 .expect("fan-out lane poisoned while waiting");
1531 state = guard;
1532 }
1533 }
1534}
1535
1536impl<T> Drop for FanOutConsumerStream<T> {
1537 fn drop(&mut self) {
1538 if !self.detached {
1539 self.state.remove_consumer(self.lane.id());
1540 self.detached = true;
1541 }
1542 }
1543}
1544
1545#[cfg(test)]
1546mod tests {
1547 use super::*;
1548 use crate::testkit::{TestSink, TestSource};
1549 use crate::{Keep, Materializer, Sink, Source};
1550 use std::{
1551 panic::{self, AssertUnwindSafe},
1552 sync::{
1553 Arc,
1554 atomic::{AtomicUsize, Ordering},
1555 },
1556 thread,
1557 time::{Duration, Instant},
1558 };
1559
/// Two producers attached to the same merge hub sink both deliver their elements, and
/// `drain_and_complete` lets the downstream collection finish.
#[test]
fn merge_hub_accepts_dynamic_producers_and_drains() {
    let materializer = Materializer::new();
    let graph = MergeHub::source_with_draining::<i32>(4).to_mat(Sink::collect(), Keep::both);
    let ((producer_sink, draining), collected) = graph
        .run_with_materializer(&materializer)
        .expect("merge hub materializes");

    let first_attachment = producer_sink.clone();
    first_attachment
        .run_with(Source::from_iter([1, 2, 3]))
        .expect("first producer attaches");
    producer_sink
        .run_with(Source::from_iter([4, 5]))
        .expect("second producer attaches");

    draining.drain_and_complete();

    // Interleaving between producers is not asserted, so compare the sorted output.
    let mut merged = collected.wait().expect("merge hub completes");
    merged.sort_unstable();
    assert_eq!(merged, vec![1, 2, 3, 4, 5]);
}
1580
/// An error sent by one merge hub producer surfaces as the same error downstream, even though
/// another producer delivered successfully before.
#[test]
fn merge_hub_producer_error_fails_downstream_consumer() {
    let materializer = Materializer::new();
    let (hub_sink, downstream) = MergeHub::source::<i32>(4)
        .to_mat(TestSink::probe(), Keep::both)
        .run_with_materializer(&materializer)
        .expect("merge hub materializes");

    let healthy_producer = TestSource::probe::<i32>()
        .to_mat(hub_sink.clone(), Keep::left)
        .run_with_materializer(&materializer)
        .expect("successful producer attaches");
    let failing_producer = TestSource::probe::<i32>()
        .to_mat(hub_sink, Keep::left)
        .run_with_materializer(&materializer)
        .expect("failing producer attaches");

    // One element flows through from the healthy producer first.
    downstream.request(1);
    assert_eq!(healthy_producer.expect_request(), 1);
    healthy_producer.send_next(1);
    downstream.assert_next(1);

    // Then the second producer fails and the failure must reach the consumer.
    failing_producer.send_error(StreamError::Failed("producer failed".to_owned()));
    downstream.request(1);
    let observed = downstream.expect_error();
    assert_eq!(observed, StreamError::Failed("producer failed".to_owned()));
}
1610
/// With a buffer of one, the broadcast hub stops pulling from upstream while the slower of
/// two consumers has not taken the buffered element yet.
#[test]
fn broadcast_hub_backpressures_slowest_consumer() {
    let materializer = Materializer::new();
    let (upstream, hub_source) = TestSource::probe::<i32>()
        .to_mat(BroadcastHub::sink(1), Keep::both)
        .run_with_materializer(&materializer)
        .expect("broadcast hub materializes");

    let fast = hub_source
        .source()
        .run_with(TestSink::probe())
        .expect("first consumer materializes");
    let slow = hub_source
        .source()
        .run_with(TestSink::probe())
        .expect("second consumer materializes");

    // Both consumers ask for one element and both receive it.
    fast.request(1);
    slow.request(1);
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(1);
    fast.assert_next(1);
    slow.assert_next(1);

    // Only the fast consumer asks again; the slow one must not see anything.
    fast.request(1);
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(2);
    fast.assert_next(2);
    slow.expect_no_message(Duration::from_millis(250));

    // The fast consumer is now held back by the slow one.
    fast.request(1);
    fast.expect_no_message(Duration::from_millis(250));

    // Once the slow consumer catches up, upstream demand resumes.
    slow.request(1);
    slow.assert_next(2);
    assert_eq!(upstream.expect_request(), 1);
}
1648
/// A consumer that attaches after an element was broadcast only receives elements published
/// afterwards, and both consumers observe completion.
#[test]
fn broadcast_hub_late_consumer_sees_only_late_elements() {
    let materializer = Materializer::new();
    let (upstream, hub_source) = TestSource::probe::<i32>()
        .to_mat(BroadcastHub::sink(2), Keep::both)
        .run_with_materializer(&materializer)
        .expect("broadcast hub materializes");

    let early = hub_source
        .source()
        .run_with(TestSink::probe())
        .expect("first consumer materializes");
    early.request(1);
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(1);
    early.assert_next(1);

    // The late consumer joins after element 1 has been consumed.
    let late = hub_source
        .source()
        .run_with(TestSink::probe())
        .expect("late consumer materializes");
    early.request(1);
    late.request(1);
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(2);
    early.assert_next(2);
    late.assert_next(2);

    upstream.send_complete();
    early.request(1);
    late.request(1);
    early.expect_complete();
    late.expect_complete();
}
1683
/// A modulo partitioner over two consumers sends even elements to the first consumer and odd
/// elements to the second.
#[test]
fn partition_hub_routes_elements_to_selected_consumers() {
    let materializer = Materializer::new();
    let hub = Source::from_iter([0, 1, 2, 3])
        .run_with_materializer(
            PartitionHub::sink(
                |info, item| {
                    let idx = (*item as usize) % info.size();
                    info.consumer_id_by_idx(idx) as isize
                },
                2,
                8,
            ),
            &materializer,
        )
        .expect("partition hub materializes");

    let evens = hub
        .source()
        .run_with(TestSink::probe())
        .expect("first consumer materializes");
    let odds = hub
        .source()
        .run_with(TestSink::probe())
        .expect("second consumer materializes");

    evens.request(2);
    odds.request(2);
    evens.assert_next_n([0, 2]);
    odds.assert_next_n([1, 3]);
}
1715
/// The partitioner is invoked exactly once per element, even when an element has to wait for
/// buffer space in its target lane.
#[test]
fn partition_hub_evaluates_stateful_partitioner_once_per_blocked_element() {
    let materializer = Materializer::new();
    let partition_calls = Arc::new(AtomicUsize::new(0));
    let calls_seen_by_hub = Arc::clone(&partition_calls);

    let (upstream, hub) = TestSource::probe::<i32>()
        .to_mat(
            PartitionHub::sink(
                move |info, _item| {
                    calls_seen_by_hub.fetch_add(1, Ordering::SeqCst);
                    info.consumer_id_by_idx(0) as isize
                },
                1,
                1,
            ),
            Keep::both,
        )
        .run_with_materializer(&materializer)
        .expect("partition hub materializes");

    let consumer = hub
        .source()
        .run_with(TestSink::probe())
        .expect("consumer materializes");

    // First element: routed once and buffered.
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(1);
    wait_for_partition_calls(&partition_calls, 1);

    // Second element: routed once while the consumer has not requested anything.
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(2);
    consumer.expect_no_message(Duration::from_millis(250));
    wait_for_partition_calls(&partition_calls, 2);

    // Draining both elements must not trigger further partitioner calls.
    consumer.request(1);
    consumer.assert_next(1);
    consumer.request(1);
    consumer.assert_next(2);
    assert_eq!(partition_calls.load(Ordering::SeqCst), 2);
}
1757
/// After the only consumer is dropped and a replacement attaches, elements are routed to the
/// replacement and completion reaches it.
#[test]
fn partition_hub_invalidates_cached_topology_on_consumer_churn() {
    let materializer = Materializer::new();
    let (upstream, hub) = TestSource::probe::<i32>()
        .to_mat(
            PartitionHub::sink(
                |info, _item| info.consumer_id_by_idx(info.size() - 1) as isize,
                1,
                8,
            ),
            Keep::both,
        )
        .run_with_materializer(&materializer)
        .expect("partition hub materializes");

    let original = hub
        .source()
        .run_with(TestSink::probe())
        .expect("first consumer materializes");
    original.request(1);
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(1);
    original.assert_next(1);
    // Dropping the probe detaches the first consumer from the hub.
    drop(original);

    let replacement = hub
        .source()
        .run_with(TestSink::probe())
        .expect("replacement consumer materializes");
    replacement.request(1);
    assert_eq!(upstream.expect_request(), 1);
    upstream.send_next(2);
    replacement.assert_next(2);

    upstream.send_complete();
    replacement.request(1);
    replacement.expect_complete();
}
1796
/// All elements of a finite upstream are delivered to the consumer before it observes
/// completion, even when they are requested in two steps.
#[test]
fn broadcast_hub_drains_local_chunk_before_completion() {
    let materializer = Materializer::new();
    let hub_sink = BroadcastHub::sink_starting_after(1, 8);
    let hub = Source::from_iter([1, 2, 3])
        .run_with_materializer(hub_sink, &materializer)
        .expect("broadcast hub materializes");

    let consumer = hub
        .source()
        .run_with(TestSink::probe())
        .expect("consumer materializes");

    consumer.request(1);
    consumer.assert_next(1);
    // Demand exceeds the two remaining elements, so completion can follow them.
    consumer.request(3);
    consumer.assert_next_n([2, 3]);
    consumer.expect_complete();
}
1815
/// An upstream that panics after one element fails the consumer with the
/// "fan-out hub producer panicked" error.
///
/// The consumer may see the first element before the failure; in that case `expect_error`
/// itself panics with a known message and the error is checked on the next request.
#[test]
fn broadcast_hub_panicking_upstream_fails_consumers() {
    let materializer = Materializer::new();
    let hub = Source::from_fn_iter(|| {
        let mut yielded = false;
        std::iter::from_fn(move || {
            if !yielded {
                yielded = true;
                Some(1)
            } else {
                panic!("boom");
            }
        })
    })
    .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
    .expect("broadcast hub materializes");

    let consumer = hub
        .source()
        .run_with(TestSink::probe())
        .expect("consumer materializes");

    let expected = StreamError::Failed("fan-out hub producer panicked".to_owned());

    consumer.request(2);
    let first_attempt = panic::catch_unwind(AssertUnwindSafe(|| consumer.expect_error()));
    match first_attempt {
        Ok(error) => assert_eq!(error, expected),
        Err(payload) => {
            // The element won the race; the probe complained about it.
            assert_eq!(
                panic_message(payload),
                "expected stream error, got next element"
            );
            consumer.request(1);
            assert_eq!(consumer.expect_error(), expected);
        }
    }
}
1857
/// Extracts a readable message from a panic payload.
///
/// Returns the payload's text when it is a `String` or a `&'static str`, and the fixed
/// placeholder `"<non-string panic payload>"` for any other payload type.
fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
    if let Some(owned) = payload.downcast_ref::<String>() {
        return owned.clone();
    }
    if let Some(literal) = payload.downcast_ref::<&'static str>() {
        return (*literal).to_owned();
    }
    "<non-string panic payload>".to_owned()
}
1867
/// Polls `counter` until it equals `expected`, giving up after one second.
///
/// Returns as soon as the values match. If the deadline passes first, a final `assert_eq!`
/// reports the mismatch (or passes if the counter reached `expected` in the meantime).
fn wait_for_partition_calls(counter: &AtomicUsize, expected: usize) {
    let poll_interval = Duration::from_millis(5);
    let give_up_at = Instant::now() + Duration::from_secs(1);

    loop {
        if Instant::now() >= give_up_at {
            break;
        }
        if counter.load(Ordering::SeqCst) == expected {
            return;
        }
        thread::sleep(poll_interval);
    }

    assert_eq!(counter.load(Ordering::SeqCst), expected);
}
1878}