1use std::{
2 collections::{BTreeMap, VecDeque},
3 fmt,
4 sync::{Arc, Condvar, Mutex, MutexGuard},
5};
6
7use crate::stream::{BoxStream, NotUsed, Sink, Source, StreamCompletion};
8use crate::{StreamError, StreamResult};
9
/// Routing function used by [`PartitionHub`].
///
/// Called with a snapshot of the currently registered consumers and the element to route;
/// returns the id of the consumer that should receive the element, or a negative value to
/// drop the element.
type Partitioner<T> = Arc<dyn Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync>;
11
/// Handle materialized by [`MergeHub::source_with_draining`] that switches the hub into
/// draining mode via [`MergeHubDrainingControl::drain_and_complete`].
#[derive(Clone)]
pub struct MergeHubDrainingControl {
    // Draining flag shared with the hub; read when producers register and deregister.
    state: Arc<MergeHubState>,
    // Callback run after the flag is set (wired to `MergeHubShared::finish_if_draining`).
    on_drain: Arc<dyn Fn() + Send + Sync>,
}
17
18impl fmt::Debug for MergeHubDrainingControl {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 f.debug_struct("MergeHubDrainingControl").finish()
21 }
22}
23
24impl MergeHubDrainingControl {
25 pub fn drain_and_complete(&self) {
26 let mut state = self.state.lock();
27 state.draining = true;
28 self.state.condvar.notify_all();
29 drop(state);
30 (self.on_drain)();
31 }
32}
33
/// Entry point for a many-to-one hub: any number of dynamically attached producers feed a
/// single source stream.
pub struct MergeHub;
35
36impl MergeHub {
37 #[must_use]
39 pub fn source<T: Send + 'static>(
40 per_producer_buffer_size: usize,
41 ) -> Source<T, Sink<T, NotUsed>> {
42 Self::source_with_draining(per_producer_buffer_size)
43 .map_materialized_value(|(sink, _)| sink)
44 }
45
46 #[must_use]
49 pub fn source_with_draining<T: Send + 'static>(
50 per_producer_buffer_size: usize,
51 ) -> Source<T, (Sink<T, NotUsed>, MergeHubDrainingControl)> {
52 assert!(
53 per_producer_buffer_size > 0,
54 "MergeHub per_producer_buffer_size must be greater than zero"
55 );
56 Source::from_materialized_factory(move |_| {
57 let state = Arc::new(MergeHubShared::<T>::new(per_producer_buffer_size));
58 let source = Box::new(MergeHubSourceStream {
59 state: Arc::clone(&state),
60 }) as BoxStream<T>;
61 let sink = merge_hub_sink(Arc::clone(&state));
62 let control = MergeHubDrainingControl {
63 state: Arc::clone(&state.state),
64 on_drain: Arc::new({
65 let state = Arc::clone(&state);
66 move || state.finish_if_draining()
67 }),
68 };
69 Ok((source, (sink, control)))
70 })
71 }
72}
73
74pub struct BroadcastHub;
75
76impl BroadcastHub {
77 #[must_use]
85 pub fn sink<T: Clone + Send + 'static>(
86 buffer_size: usize,
87 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
88 Self::sink_starting_after(0, buffer_size)
89 }
90
91 #[must_use]
93 pub fn sink_starting_after<T: Clone + Send + 'static>(
94 start_after_nr_of_consumers: usize,
95 buffer_size: usize,
96 ) -> Sink<T, BroadcastHubConsumerSource<T>> {
97 assert!(
98 buffer_size > 0,
99 "BroadcastHub buffer_size must be greater than zero"
100 );
101 Sink::from_runner(move |input, materializer| {
102 let state = Arc::new(FanOutHubShared::new(
103 FanOutMode::Broadcast,
104 start_after_nr_of_consumers,
105 buffer_size,
106 None::<Partitioner<T>>,
107 ));
108 let source = BroadcastHubConsumerSource {
109 state: Arc::clone(&state),
110 completion: Arc::new(Mutex::new(None)),
111 };
112 let completion = materializer
113 .spawn_stream(move |cancelled| FanOutProducer::new(input, state).run(cancelled));
114 source.attach_completion(completion);
115 Ok(source)
116 })
117 }
118}
119
120pub struct PartitionHub;
121
122impl PartitionHub {
123 #[must_use]
131 pub fn sink<T: Clone + Send + 'static, F>(
132 partitioner: F,
133 start_after_nr_of_consumers: usize,
134 buffer_size: usize,
135 ) -> Sink<T, PartitionHubConsumerSource<T>>
136 where
137 F: Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync + 'static,
138 {
139 assert!(
140 buffer_size > 0,
141 "PartitionHub buffer_size must be greater than zero"
142 );
143 let partitioner = Arc::new(partitioner);
144 Sink::from_runner(move |input, materializer| {
145 let partitioner = Arc::clone(&partitioner);
146 let state = Arc::new(FanOutHubShared::new(
147 FanOutMode::Partition,
148 start_after_nr_of_consumers,
149 buffer_size,
150 Some(partitioner),
151 ));
152 let source = PartitionHubConsumerSource {
153 state: Arc::clone(&state),
154 completion: Arc::new(Mutex::new(None)),
155 };
156 let completion = materializer
157 .spawn_stream(move |cancelled| FanOutProducer::new(input, state).run(cancelled));
158 source.attach_completion(completion);
159 Ok(source)
160 })
161 }
162}
163
164#[derive(Clone)]
165pub struct BroadcastHubConsumerSource<T> {
166 state: Arc<FanOutHubShared<T>>,
167 completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
168}
169
170impl<T: Clone + Send + 'static> BroadcastHubConsumerSource<T> {
171 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
172 *self
173 .completion
174 .lock()
175 .expect("broadcast hub completion poisoned") = Some(completion);
176 }
177
178 #[must_use]
179 pub fn source(&self) -> Source<T, NotUsed> {
180 let state = Arc::clone(&self.state);
181 Source::from_materialized_factory(move |_| {
182 let consumer_id = state.register_consumer();
183 let stream = Box::new(FanOutConsumerStream {
184 state: Arc::clone(&state),
185 consumer_id,
186 detached: false,
187 }) as BoxStream<T>;
188 Ok((stream, NotUsed))
189 })
190 }
191}
192
193impl<T: Clone + Send + 'static> fmt::Debug for BroadcastHubConsumerSource<T> {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("BroadcastHubConsumerSource").finish()
196 }
197}
198
199#[derive(Clone)]
200pub struct PartitionHubConsumerSource<T> {
201 state: Arc<FanOutHubShared<T>>,
202 completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
203}
204
205impl<T: Clone + Send + 'static> PartitionHubConsumerSource<T> {
206 fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
207 *self
208 .completion
209 .lock()
210 .expect("partition hub completion poisoned") = Some(completion);
211 }
212
213 #[must_use]
214 pub fn source(&self) -> Source<T, NotUsed> {
215 let state = Arc::clone(&self.state);
216 Source::from_materialized_factory(move |_| {
217 let consumer_id = state.register_consumer();
218 let stream = Box::new(FanOutConsumerStream {
219 state: Arc::clone(&state),
220 consumer_id,
221 detached: false,
222 }) as BoxStream<T>;
223 Ok((stream, NotUsed))
224 })
225 }
226}
227
228impl<T: Clone + Send + 'static> fmt::Debug for PartitionHubConsumerSource<T> {
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 f.debug_struct("PartitionHubConsumerSource").finish()
231 }
232}
233
234#[derive(Clone, Debug)]
235pub struct PartitionConsumerInfo {
236 consumer_ids: Vec<u64>,
237 queue_sizes: BTreeMap<u64, usize>,
238}
239
240impl PartitionConsumerInfo {
241 #[must_use]
242 pub fn size(&self) -> usize {
243 self.consumer_ids.len()
244 }
245
246 #[must_use]
247 pub fn consumer_ids(&self) -> &[u64] {
248 &self.consumer_ids
249 }
250
251 #[must_use]
252 pub fn consumer_id_by_idx(&self, idx: usize) -> u64 {
253 self.consumer_ids[idx]
254 }
255
256 #[must_use]
257 pub fn queue_size(&self, consumer_id: u64) -> usize {
258 self.queue_sizes.get(&consumer_id).copied().unwrap_or(0)
259 }
260}
261
262fn merge_hub_sink<T: Send + 'static>(state: Arc<MergeHubShared<T>>) -> Sink<T, NotUsed> {
263 Sink::from_runner(move |input, materializer| {
264 let producer_id = state.register_producer()?;
265 let hub = Arc::clone(&state);
266 let completion = materializer.spawn_stream(move |cancelled| {
267 let mut input = input;
268 loop {
269 if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
270 hub.fail(StreamError::Cancelled);
271 hub.deregister_producer(producer_id);
272 return Err(StreamError::Cancelled);
273 }
274 match input.next() {
275 Some(Ok(item)) => hub.push_item(producer_id, item)?,
276 Some(Err(error)) => {
277 hub.fail(error.clone());
278 hub.deregister_producer(producer_id);
279 return Err(error);
280 }
281 None => {
282 hub.deregister_producer(producer_id);
283 return Ok(NotUsed);
284 }
285 }
286 }
287 });
288 state.store_producer_completion(completion);
289 Ok(NotUsed)
290 })
291}
292
293struct MergeHubShared<T> {
294 state: Arc<MergeHubState>,
295 shared: Mutex<MergeHubInner<T>>,
296 condvar: Condvar,
297}
298
299#[derive(Debug)]
300struct MergeHubState {
301 inner: Mutex<MergeHubFlags>,
302 condvar: Condvar,
303}
304
305#[derive(Debug, Default)]
306struct MergeHubFlags {
307 draining: bool,
308}
309
310impl MergeHubState {
311 fn lock(&self) -> MutexGuard<'_, MergeHubFlags> {
312 self.inner.lock().expect("merge hub flags poisoned")
313 }
314}
315
316struct MergeHubInner<T> {
317 queue: VecDeque<(u64, T)>,
318 queued_per_producer: BTreeMap<u64, usize>,
319 producer_completions: Vec<StreamCompletion<NotUsed>>,
320 active_producers: usize,
321 next_producer_id: u64,
322 source_closed: bool,
323 completed: bool,
324 failed: Option<StreamError>,
325 per_producer_buffer_size: usize,
326}
327
328impl<T> MergeHubShared<T> {
329 fn new(per_producer_buffer_size: usize) -> Self {
330 Self {
331 state: Arc::new(MergeHubState {
332 inner: Mutex::new(MergeHubFlags::default()),
333 condvar: Condvar::new(),
334 }),
335 shared: Mutex::new(MergeHubInner {
336 queue: VecDeque::new(),
337 queued_per_producer: BTreeMap::new(),
338 producer_completions: Vec::new(),
339 active_producers: 0,
340 next_producer_id: 0,
341 source_closed: false,
342 completed: false,
343 failed: None,
344 per_producer_buffer_size,
345 }),
346 condvar: Condvar::new(),
347 }
348 }
349
350 fn register_producer(&self) -> StreamResult<u64> {
351 let mut inner = self.shared.lock().expect("merge hub poisoned");
352 prune_finished_producer_completions(&mut inner.producer_completions);
353 let flags = self.state.lock();
354 if flags.draining || inner.source_closed || inner.completed {
355 return Err(StreamError::Failed(
356 "merge hub is draining or closed to new producers".to_owned(),
357 ));
358 }
359 if let Some(error) = inner.failed.clone() {
360 return Err(error);
361 }
362 let id = inner.next_producer_id;
363 inner.next_producer_id += 1;
364 inner.active_producers += 1;
365 inner.queued_per_producer.insert(id, 0);
366 Ok(id)
367 }
368
369 fn store_producer_completion(&self, completion: StreamCompletion<NotUsed>) {
370 let mut inner = self.shared.lock().expect("merge hub poisoned");
371 prune_finished_producer_completions(&mut inner.producer_completions);
372 inner.producer_completions.push(completion);
373 }
374
375 fn push_item(&self, producer_id: u64, item: T) -> StreamResult<()> {
376 let mut inner = self.shared.lock().expect("merge hub poisoned");
377 prune_finished_producer_completions(&mut inner.producer_completions);
378 loop {
379 if let Some(error) = inner.failed.clone() {
380 inner.queued_per_producer.remove(&producer_id);
381 return Err(error);
382 }
383 if inner.source_closed {
384 inner.queued_per_producer.remove(&producer_id);
385 return Err(StreamError::Cancelled);
386 }
387 let queued = inner
388 .queued_per_producer
389 .get(&producer_id)
390 .copied()
391 .unwrap_or(0);
392 if queued < inner.per_producer_buffer_size {
393 inner.queue.push_back((producer_id, item));
394 inner.queued_per_producer.insert(producer_id, queued + 1);
395 self.condvar.notify_all();
396 return Ok(());
397 }
398 inner = self
399 .condvar
400 .wait(inner)
401 .expect("merge hub poisoned while waiting");
402 }
403 }
404
405 fn deregister_producer(&self, producer_id: u64) {
406 let mut inner = self.shared.lock().expect("merge hub poisoned");
407 prune_finished_producer_completions(&mut inner.producer_completions);
408 inner.queued_per_producer.remove(&producer_id);
409 inner.active_producers = inner.active_producers.saturating_sub(1);
410 if inner.active_producers == 0 {
411 let flags = self.state.lock();
412 if flags.draining {
413 inner.completed = true;
414 }
415 }
416 self.condvar.notify_all();
417 }
418
419 fn fail(&self, error: StreamError) {
420 let mut inner = self.shared.lock().expect("merge hub poisoned");
421 if inner.failed.is_none() {
422 inner.failed = Some(error);
423 }
424 self.condvar.notify_all();
425 }
426
427 fn finish_if_draining(&self) {
428 let flags = self.state.lock();
429 if !flags.draining {
430 return;
431 }
432 drop(flags);
433
434 let mut inner = self.shared.lock().expect("merge hub poisoned");
435 prune_finished_producer_completions(&mut inner.producer_completions);
436 if inner.active_producers == 0 {
437 inner.completed = true;
438 self.condvar.notify_all();
439 }
440 }
441}
442
443fn prune_finished_producer_completions(completions: &mut Vec<StreamCompletion<NotUsed>>) {
444 let mut index = 0;
445 while index < completions.len() {
446 if completions[index].try_wait().is_some() {
447 drop(completions.swap_remove(index));
448 } else {
449 index += 1;
450 }
451 }
452}
453
454struct MergeHubSourceStream<T> {
455 state: Arc<MergeHubShared<T>>,
456}
457
458impl<T> Iterator for MergeHubSourceStream<T> {
459 type Item = StreamResult<T>;
460
461 fn next(&mut self) -> Option<Self::Item> {
462 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
463 loop {
464 if let Some(error) = inner.failed.clone() {
465 inner.source_closed = true;
466 return Some(Err(error));
467 }
468 if let Some((producer_id, item)) = inner.queue.pop_front() {
469 if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
470 *queued = queued.saturating_sub(1);
471 }
472 self.state.condvar.notify_all();
473 return Some(Ok(item));
474 }
475 if inner.completed {
476 inner.source_closed = true;
477 return None;
478 }
479 inner = self
480 .state
481 .condvar
482 .wait(inner)
483 .expect("merge hub poisoned while waiting");
484 }
485 }
486}
487
488impl<T> Drop for MergeHubSourceStream<T> {
489 fn drop(&mut self) {
490 let mut inner = self.state.shared.lock().expect("merge hub poisoned");
491 inner.source_closed = true;
492 self.state.condvar.notify_all();
493 }
494}
495
496#[derive(Clone, Copy)]
497enum FanOutMode {
498 Broadcast,
499 Partition,
500}
501
502struct FanOutHubShared<T> {
503 state: Mutex<FanOutState<T>>,
504 condvar: Condvar,
505 mode: FanOutMode,
506 start_after_nr_of_consumers: usize,
507 buffer_size: usize,
508 partitioner: Option<Partitioner<T>>,
509}
510
511struct FanOutState<T> {
512 consumers: BTreeMap<u64, VecDeque<T>>,
513 next_consumer_id: u64,
514 completed: bool,
515 failed: Option<StreamError>,
516}
517
518impl<T> FanOutHubShared<T> {
519 fn new(
520 mode: FanOutMode,
521 start_after_nr_of_consumers: usize,
522 buffer_size: usize,
523 partitioner: Option<Partitioner<T>>,
524 ) -> Self {
525 Self {
526 state: Mutex::new(FanOutState {
527 consumers: BTreeMap::new(),
528 next_consumer_id: 0,
529 completed: false,
530 failed: None,
531 }),
532 condvar: Condvar::new(),
533 mode,
534 start_after_nr_of_consumers,
535 buffer_size,
536 partitioner,
537 }
538 }
539
540 fn register_consumer(&self) -> u64 {
541 let mut state = self.state.lock().expect("fan-out hub poisoned");
542 let id = state.next_consumer_id;
543 state.next_consumer_id += 1;
544 state.consumers.insert(id, VecDeque::new());
545 self.condvar.notify_all();
546 id
547 }
548
549 fn remove_consumer(&self, consumer_id: u64) {
550 let mut state = self.state.lock().expect("fan-out hub poisoned");
551 state.consumers.remove(&consumer_id);
552 self.condvar.notify_all();
553 }
554
555 fn push(&self, item: T) -> StreamResult<()>
556 where
557 T: Clone,
558 {
559 let mut state = self.state.lock().expect("fan-out hub poisoned");
560 loop {
561 if state.failed.is_some() || state.completed {
562 return Err(StreamError::Cancelled);
563 }
564 if state.consumers.len() < self.start_after_nr_of_consumers
565 || state.consumers.is_empty()
566 {
567 state = self
568 .condvar
569 .wait(state)
570 .expect("fan-out hub poisoned while waiting");
571 continue;
572 }
573 match self.mode {
574 FanOutMode::Broadcast => {
575 if state
576 .consumers
577 .values()
578 .any(|queue| queue.len() >= self.buffer_size)
579 {
580 state = self
581 .condvar
582 .wait(state)
583 .expect("fan-out hub poisoned while waiting");
584 continue;
585 }
586 for queue in state.consumers.values_mut() {
587 queue.push_back(item.clone());
588 }
589 self.condvar.notify_all();
590 return Ok(());
591 }
592 FanOutMode::Partition => {
593 let Some(selected) = self.select_partition(&state, &item)? else {
594 return Ok(());
595 };
596 loop {
597 if state.failed.is_some() || state.completed {
598 return Err(StreamError::Cancelled);
599 }
600 let Some(queue) = state.consumers.get_mut(&selected) else {
601 return Err(StreamError::Failed(
602 "partition hub selected unknown consumer".to_owned(),
603 ));
604 };
605 if queue.len() < self.buffer_size {
606 queue.push_back(item);
607 self.condvar.notify_all();
608 return Ok(());
609 }
610 state = self
611 .condvar
612 .wait(state)
613 .expect("fan-out hub poisoned while waiting");
614 }
615 }
616 }
617 }
618 }
619
620 fn select_partition(&self, state: &FanOutState<T>, item: &T) -> StreamResult<Option<u64>> {
621 let info = PartitionConsumerInfo {
622 consumer_ids: state.consumers.keys().copied().collect(),
623 queue_sizes: state
624 .consumers
625 .iter()
626 .map(|(id, queue)| (*id, queue.len()))
627 .collect(),
628 };
629 let Some(partitioner) = &self.partitioner else {
630 return Err(StreamError::Failed(
631 "partition hub partitioner missing".to_owned(),
632 ));
633 };
634 let selected = partitioner(&info, item);
635 if selected < 0 {
636 return Ok(None);
637 }
638 Ok(Some(selected as u64))
639 }
640
641 fn complete(&self) {
642 let mut state = self.state.lock().expect("fan-out hub poisoned");
643 state.completed = true;
644 self.condvar.notify_all();
645 }
646
647 fn fail(&self, error: StreamError) {
648 let mut state = self.state.lock().expect("fan-out hub poisoned");
649 state.failed = Some(error);
650 self.condvar.notify_all();
651 }
652}
653
654struct FanOutProducer<T> {
655 input: BoxStream<T>,
656 state: Arc<FanOutHubShared<T>>,
657}
658
659impl<T> FanOutProducer<T> {
660 fn new(input: BoxStream<T>, state: Arc<FanOutHubShared<T>>) -> Self {
661 Self { input, state }
662 }
663}
664
665impl<T: Send + 'static + Clone> FanOutProducer<T> {
666 fn run(mut self, cancelled: Arc<std::sync::atomic::AtomicBool>) -> StreamResult<NotUsed> {
667 struct ProducerDropGuard<T> {
668 state: Arc<FanOutHubShared<T>>,
669 disarmed: bool,
670 }
671
672 impl<T> ProducerDropGuard<T> {
673 fn new(state: Arc<FanOutHubShared<T>>) -> Self {
674 Self {
675 state,
676 disarmed: false,
677 }
678 }
679
680 fn disarm(&mut self) {
681 self.disarmed = true;
682 }
683 }
684
685 impl<T> Drop for ProducerDropGuard<T> {
686 fn drop(&mut self) {
687 if !self.disarmed && std::thread::panicking() {
688 self.state.fail(StreamError::Failed(
689 "fan-out hub producer panicked".to_owned(),
690 ));
691 }
692 }
693 }
694
695 let mut guard = ProducerDropGuard::new(Arc::clone(&self.state));
696 loop {
697 if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
698 self.state.fail(StreamError::Cancelled);
699 guard.disarm();
700 return Err(StreamError::Cancelled);
701 }
702 match self.input.next() {
703 Some(Ok(item)) => self.state.push(item)?,
704 Some(Err(error)) => {
705 self.state.fail(error.clone());
706 guard.disarm();
707 return Err(error);
708 }
709 None => {
710 self.state.complete();
711 guard.disarm();
712 return Ok(NotUsed);
713 }
714 }
715 }
716 }
717}
718
719struct FanOutConsumerStream<T> {
720 state: Arc<FanOutHubShared<T>>,
721 consumer_id: u64,
722 detached: bool,
723}
724
725impl<T: Clone + Send + 'static> Iterator for FanOutConsumerStream<T> {
726 type Item = StreamResult<T>;
727
728 fn next(&mut self) -> Option<Self::Item> {
729 let mut state = self.state.state.lock().expect("fan-out hub poisoned");
730 loop {
731 if let Some(error) = state.failed.clone() {
732 return Some(Err(error));
733 }
734 if let Some(queue) = state.consumers.get_mut(&self.consumer_id)
735 && let Some(item) = queue.pop_front()
736 {
737 self.state.condvar.notify_all();
738 return Some(Ok(item));
739 }
740 if state.completed {
741 return None;
742 }
743 state = self
744 .state
745 .condvar
746 .wait(state)
747 .expect("fan-out hub poisoned while waiting");
748 }
749 }
750}
751
752impl<T> Drop for FanOutConsumerStream<T> {
753 fn drop(&mut self) {
754 if !self.detached {
755 self.state.remove_consumer(self.consumer_id);
756 self.detached = true;
757 }
758 }
759}
760
761#[cfg(test)]
762mod tests {
763 use super::*;
764 use crate::testkit::{TestSink, TestSource};
765 use crate::{Keep, Materializer, Sink, Source};
766 use std::{
767 panic::{self, AssertUnwindSafe},
768 sync::{
769 Arc,
770 atomic::{AtomicUsize, Ordering},
771 },
772 thread,
773 time::{Duration, Instant},
774 };
775
776 #[test]
777 fn merge_hub_accepts_dynamic_producers_and_drains() {
778 let materializer = Materializer::new();
779 let ((hub_sink, control), completion) = MergeHub::source_with_draining::<i32>(4)
780 .to_mat(Sink::collect(), Keep::both)
781 .run_with_materializer(&materializer)
782 .expect("merge hub materializes");
783
784 hub_sink
785 .clone()
786 .run_with(Source::from_iter([1, 2, 3]))
787 .expect("first producer attaches");
788 hub_sink
789 .run_with(Source::from_iter([4, 5]))
790 .expect("second producer attaches");
791 control.drain_and_complete();
792 let mut result = completion.wait().expect("merge hub completes");
793 result.sort_unstable();
794 assert_eq!(result, vec![1, 2, 3, 4, 5]);
795 }
796
797 #[test]
798 fn merge_hub_producer_error_fails_downstream_consumer() {
799 let materializer = Materializer::new();
800 let (hub_sink, sink) = MergeHub::source::<i32>(4)
801 .to_mat(TestSink::probe(), Keep::both)
802 .run_with_materializer(&materializer)
803 .expect("merge hub materializes");
804
805 let producer_ok = TestSource::probe::<i32>()
806 .to_mat(hub_sink.clone(), Keep::left)
807 .run_with_materializer(&materializer)
808 .expect("successful producer attaches");
809 let producer_fail = TestSource::probe::<i32>()
810 .to_mat(hub_sink, Keep::left)
811 .run_with_materializer(&materializer)
812 .expect("failing producer attaches");
813
814 sink.request(1);
815 assert_eq!(producer_ok.expect_request(), 1);
816 producer_ok.send_next(1);
817 sink.assert_next(1);
818
819 producer_fail.send_error(StreamError::Failed("producer failed".to_owned()));
820 sink.request(1);
821 assert_eq!(
822 sink.expect_error(),
823 StreamError::Failed("producer failed".to_owned())
824 );
825 }
826
827 #[test]
828 fn broadcast_hub_backpressures_slowest_consumer() {
829 let materializer = Materializer::new();
830 let (publisher, hub_source) = TestSource::probe::<i32>()
831 .to_mat(BroadcastHub::sink(1), Keep::both)
832 .run_with_materializer(&materializer)
833 .expect("broadcast hub materializes");
834
835 let sink_a = hub_source
836 .source()
837 .run_with(TestSink::probe())
838 .expect("first consumer materializes");
839 let sink_b = hub_source
840 .source()
841 .run_with(TestSink::probe())
842 .expect("second consumer materializes");
843
844 sink_a.request(1);
845 sink_b.request(1);
846 assert_eq!(publisher.expect_request(), 1);
847 publisher.send_next(1);
848 sink_a.assert_next(1);
849 sink_b.assert_next(1);
850
851 sink_a.request(1);
852 assert_eq!(publisher.expect_request(), 1);
853 publisher.send_next(2);
854 sink_a.assert_next(2);
855 sink_b.expect_no_message(Duration::from_millis(250));
856
857 sink_a.request(1);
858 sink_a.expect_no_message(Duration::from_millis(250));
859
860 sink_b.request(1);
861 sink_b.assert_next(2);
862 assert_eq!(publisher.expect_request(), 1);
863 }
864
865 #[test]
866 fn broadcast_hub_late_consumer_sees_only_late_elements() {
867 let materializer = Materializer::new();
868 let (publisher, hub_source) = TestSource::probe::<i32>()
869 .to_mat(BroadcastHub::sink(2), Keep::both)
870 .run_with_materializer(&materializer)
871 .expect("broadcast hub materializes");
872
873 let sink_a = hub_source
874 .source()
875 .run_with(TestSink::probe())
876 .expect("first consumer materializes");
877 sink_a.request(1);
878 assert_eq!(publisher.expect_request(), 1);
879 publisher.send_next(1);
880 sink_a.assert_next(1);
881
882 let sink_b = hub_source
883 .source()
884 .run_with(TestSink::probe())
885 .expect("late consumer materializes");
886 sink_a.request(1);
887 sink_b.request(1);
888 assert_eq!(publisher.expect_request(), 1);
889 publisher.send_next(2);
890 sink_a.assert_next(2);
891 sink_b.assert_next(2);
892
893 publisher.send_complete();
894 sink_a.request(1);
895 sink_b.request(1);
896 sink_a.expect_complete();
897 sink_b.expect_complete();
898 }
899
900 #[test]
901 fn partition_hub_routes_elements_to_selected_consumers() {
902 let materializer = Materializer::new();
903 let hub = Source::from_iter([0, 1, 2, 3])
904 .run_with_materializer(
905 PartitionHub::sink(
906 |info, item| {
907 let idx = (*item as usize) % info.size();
908 info.consumer_id_by_idx(idx) as isize
909 },
910 2,
911 8,
912 ),
913 &materializer,
914 )
915 .expect("partition hub materializes");
916
917 let sink_a = hub
918 .source()
919 .run_with(TestSink::probe())
920 .expect("first consumer materializes");
921 let sink_b = hub
922 .source()
923 .run_with(TestSink::probe())
924 .expect("second consumer materializes");
925
926 sink_a.request(2);
927 sink_b.request(2);
928 sink_a.assert_next_n([0, 2]);
929 sink_b.assert_next_n([1, 3]);
930 }
931
932 #[test]
933 fn partition_hub_evaluates_stateful_partitioner_once_per_blocked_element() {
934 let materializer = Materializer::new();
935 let partition_calls = Arc::new(AtomicUsize::new(0));
936 let partition_calls_for_hub = Arc::clone(&partition_calls);
937
938 let (publisher, hub) = TestSource::probe::<i32>()
939 .to_mat(
940 PartitionHub::sink(
941 move |info, _item| {
942 partition_calls_for_hub.fetch_add(1, Ordering::SeqCst);
943 info.consumer_id_by_idx(0) as isize
944 },
945 1,
946 1,
947 ),
948 Keep::both,
949 )
950 .run_with_materializer(&materializer)
951 .expect("partition hub materializes");
952
953 let sink = hub
954 .source()
955 .run_with(TestSink::probe())
956 .expect("consumer materializes");
957
958 assert_eq!(publisher.expect_request(), 1);
959 publisher.send_next(1);
960 wait_for_partition_calls(&partition_calls, 1);
961
962 assert_eq!(publisher.expect_request(), 1);
963 publisher.send_next(2);
964 sink.expect_no_message(Duration::from_millis(250));
965 wait_for_partition_calls(&partition_calls, 2);
966
967 sink.request(1);
968 sink.assert_next(1);
969 sink.request(1);
970 sink.assert_next(2);
971 assert_eq!(partition_calls.load(Ordering::SeqCst), 2);
972 }
973
974 #[test]
975 fn broadcast_hub_panicking_upstream_fails_consumers() {
976 let materializer = Materializer::new();
977 let hub = Source::from_fn_iter(|| {
978 let mut yielded = false;
979 std::iter::from_fn(move || {
980 if !yielded {
981 yielded = true;
982 Some(1)
983 } else {
984 panic!("boom");
985 }
986 })
987 })
988 .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
989 .expect("broadcast hub materializes");
990
991 let sink = hub
992 .source()
993 .run_with(TestSink::probe())
994 .expect("consumer materializes");
995
996 sink.request(2);
997 match panic::catch_unwind(AssertUnwindSafe(|| sink.expect_error())) {
998 Ok(error) => assert_eq!(
999 error,
1000 StreamError::Failed("fan-out hub producer panicked".to_owned())
1001 ),
1002 Err(payload) => {
1003 assert_eq!(
1004 panic_message(payload),
1005 "expected stream error, got next element"
1006 );
1007 sink.request(1);
1008 assert_eq!(
1009 sink.expect_error(),
1010 StreamError::Failed("fan-out hub producer panicked".to_owned())
1011 );
1012 }
1013 }
1014 }
1015
1016 fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
1017 match payload.downcast::<String>() {
1018 Ok(message) => *message,
1019 Err(payload) => match payload.downcast::<&'static str>() {
1020 Ok(message) => (*message).to_owned(),
1021 Err(_) => "<non-string panic payload>".to_owned(),
1022 },
1023 }
1024 }
1025
1026 fn wait_for_partition_calls(counter: &AtomicUsize, expected: usize) {
1027 let deadline = Instant::now() + Duration::from_secs(1);
1028 while Instant::now() < deadline {
1029 if counter.load(Ordering::SeqCst) == expected {
1030 return;
1031 }
1032 thread::sleep(Duration::from_millis(5));
1033 }
1034 assert_eq!(counter.load(Ordering::SeqCst), expected);
1035 }
1036}