use std::{
collections::{BTreeMap, VecDeque},
fmt,
sync::{Arc, Condvar, Mutex, MutexGuard},
};
use crate::stream::{BoxStream, NotUsed, Sink, Source, StreamCompletion};
use crate::{StreamError, StreamResult};
/// Shared, thread-safe routing function used by [`PartitionHub`].
///
/// It receives a snapshot of the currently registered consumers plus the element
/// being routed and returns the id of the consumer that should receive it; a
/// negative return value drops the element (see `FanOutHubShared::select_partition`).
type Partitioner<T> = Arc<dyn Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync>;
/// Handle materialized by [`MergeHub::source_with_draining`] that stops the hub
/// from accepting new producers and completes it once the attached ones finish.
#[derive(Clone)]
pub struct MergeHubDrainingControl {
    /// Draining flag shared with the hub.
    state: Arc<MergeHubState>,
    /// Hook run after the flag is raised; completes the hub right away when no
    /// producer is attached.
    on_drain: Arc<dyn Fn() + Send + Sync>,
}

impl fmt::Debug for MergeHubDrainingControl {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Neither the shared state nor the hook has a useful textual form, so
        // only the type name is printed.
        let mut builder = f.debug_struct("MergeHubDrainingControl");
        builder.finish()
    }
}

impl MergeHubDrainingControl {
    /// Puts the hub into draining mode.
    ///
    /// From then on new producers are rejected, and the hub's source completes
    /// when the last attached producer deregisters (or immediately if none is
    /// attached).
    pub fn drain_and_complete(&self) {
        {
            let mut flags = self.state.lock();
            flags.draining = true;
            self.state.condvar.notify_all();
        }
        // The flag lock is released before the hook runs, because the hook takes
        // the hub's main lock.
        let hook = &*self.on_drain;
        hook();
    }
}
/// Fan-in hub: one [`Source`] fed by any number of dynamically attached producers.
pub struct MergeHub;

impl MergeHub {
    /// Creates a merge hub source whose materialized value is the [`Sink`] that
    /// producers attach to.
    ///
    /// # Panics
    /// Panics if `per_producer_buffer_size` is zero.
    #[must_use]
    pub fn source<T: Send + 'static>(
        per_producer_buffer_size: usize,
    ) -> Source<T, Sink<T, NotUsed>> {
        let with_control = Self::source_with_draining(per_producer_buffer_size);
        with_control.map_materialized_value(|(producer_sink, _)| producer_sink)
    }

    /// Like [`MergeHub::source`], but also materializes a
    /// [`MergeHubDrainingControl`] for completing the hub gracefully.
    ///
    /// Every materialization creates an independent hub. Each producer may have
    /// at most `per_producer_buffer_size` elements buffered before it blocks.
    ///
    /// # Panics
    /// Panics if `per_producer_buffer_size` is zero.
    #[must_use]
    pub fn source_with_draining<T: Send + 'static>(
        per_producer_buffer_size: usize,
    ) -> Source<T, (Sink<T, NotUsed>, MergeHubDrainingControl)> {
        assert!(
            per_producer_buffer_size > 0,
            "MergeHub per_producer_buffer_size must be greater than zero"
        );
        Source::from_materialized_factory(move |_| {
            let hub = Arc::new(MergeHubShared::<T>::new(per_producer_buffer_size));
            let drain_hook = {
                let hub = Arc::clone(&hub);
                move || hub.finish_if_draining()
            };
            let control = MergeHubDrainingControl {
                state: Arc::clone(&hub.state),
                on_drain: Arc::new(drain_hook),
            };
            let producer_sink = merge_hub_sink(Arc::clone(&hub));
            let source: BoxStream<T> = Box::new(MergeHubSourceStream { state: hub });
            Ok((source, (producer_sink, control)))
        })
    }
}
/// Fan-out hub that delivers a clone of every element to all attached consumers.
pub struct BroadcastHub;

impl BroadcastHub {
    /// Same as [`BroadcastHub::sink_starting_after`] with a threshold of zero; the
    /// producer still waits until at least one consumer is attached.
    ///
    /// # Panics
    /// Panics if `buffer_size` is zero.
    #[must_use]
    pub fn sink<T: Clone + Send + 'static>(
        buffer_size: usize,
    ) -> Sink<T, BroadcastHubConsumerSource<T>> {
        Self::sink_starting_after(0, buffer_size)
    }

    /// Creates a broadcast hub sink whose materialized value creates consumer sources.
    ///
    /// The upstream is drained on a stream spawned by the materializer. Before each
    /// element, the producer waits while fewer than `start_after_nr_of_consumers`
    /// consumers (or none at all) are attached. It also waits while any consumer
    /// queue already holds `buffer_size` elements, so the slowest consumer sets the
    /// pace.
    ///
    /// # Panics
    /// Panics if `buffer_size` is zero.
    #[must_use]
    pub fn sink_starting_after<T: Clone + Send + 'static>(
        start_after_nr_of_consumers: usize,
        buffer_size: usize,
    ) -> Sink<T, BroadcastHubConsumerSource<T>> {
        assert!(
            buffer_size > 0,
            "BroadcastHub buffer_size must be greater than zero"
        );
        Sink::from_runner(move |input, materializer| {
            let hub = Arc::new(FanOutHubShared::new(
                FanOutMode::Broadcast,
                start_after_nr_of_consumers,
                buffer_size,
                None::<Partitioner<T>>,
            ));
            let producer = FanOutProducer::new(input, Arc::clone(&hub));
            let handle = BroadcastHubConsumerSource {
                state: hub,
                completion: Arc::new(Mutex::new(None)),
            };
            let completion = materializer.spawn_stream(move |cancelled| producer.run(cancelled));
            handle.attach_completion(completion);
            Ok(handle)
        })
    }
}
pub struct PartitionHub;
impl PartitionHub {
#[must_use]
pub fn sink<T: Clone + Send + 'static, F>(
partitioner: F,
start_after_nr_of_consumers: usize,
buffer_size: usize,
) -> Sink<T, PartitionHubConsumerSource<T>>
where
F: Fn(&PartitionConsumerInfo, &T) -> isize + Send + Sync + 'static,
{
assert!(
buffer_size > 0,
"PartitionHub buffer_size must be greater than zero"
);
let partitioner = Arc::new(partitioner);
Sink::from_runner(move |input, materializer| {
let partitioner = Arc::clone(&partitioner);
let state = Arc::new(FanOutHubShared::new(
FanOutMode::Partition,
start_after_nr_of_consumers,
buffer_size,
Some(partitioner),
));
let source = PartitionHubConsumerSource {
state: Arc::clone(&state),
completion: Arc::new(Mutex::new(None)),
};
let completion = materializer
.spawn_stream(move |cancelled| FanOutProducer::new(input, state).run(cancelled));
source.attach_completion(completion);
Ok(source)
})
}
}
/// Materialized value of [`BroadcastHub::sink`]: a factory for consumer sources.
///
/// Clones share the same hub.
#[derive(Clone)]
pub struct BroadcastHubConsumerSource<T> {
    state: Arc<FanOutHubShared<T>>,
    /// Completion of the hub's producer stream. It is stored (and never read in
    /// this module), so it lives as long as any handle does.
    completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
}

impl<T: Clone + Send + 'static> BroadcastHubConsumerSource<T> {
    /// Records the completion handle of the producer stream.
    fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
        let mut slot = self
            .completion
            .lock()
            .expect("broadcast hub completion poisoned");
        *slot = Some(completion);
    }

    /// Returns a source that, each time it is materialized, registers a new
    /// consumer and receives every element pushed after that registration.
    #[must_use]
    pub fn source(&self) -> Source<T, NotUsed> {
        let hub = Arc::clone(&self.state);
        Source::from_materialized_factory(move |_| {
            let stream: BoxStream<T> = Box::new(FanOutConsumerStream {
                consumer_id: hub.register_consumer(),
                state: Arc::clone(&hub),
                detached: false,
            });
            Ok((stream, NotUsed))
        })
    }
}

impl<T: Clone + Send + 'static> fmt::Debug for BroadcastHubConsumerSource<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as an empty `debug_struct`: only the type name.
        f.write_str("BroadcastHubConsumerSource")
    }
}
/// Materialized value of [`PartitionHub::sink`]: a factory for consumer sources.
///
/// Clones share the same hub.
#[derive(Clone)]
pub struct PartitionHubConsumerSource<T> {
    state: Arc<FanOutHubShared<T>>,
    /// Completion of the hub's producer stream. It is stored (and never read in
    /// this module), so it lives as long as any handle does.
    completion: Arc<Mutex<Option<StreamCompletion<NotUsed>>>>,
}

impl<T: Clone + Send + 'static> PartitionHubConsumerSource<T> {
    /// Records the completion handle of the producer stream.
    fn attach_completion(&self, completion: StreamCompletion<NotUsed>) {
        let mut slot = self
            .completion
            .lock()
            .expect("partition hub completion poisoned");
        *slot = Some(completion);
    }

    /// Returns a source that, each time it is materialized, registers a new
    /// consumer and receives the elements the partitioner routes to it.
    #[must_use]
    pub fn source(&self) -> Source<T, NotUsed> {
        let hub = Arc::clone(&self.state);
        Source::from_materialized_factory(move |_| {
            let stream: BoxStream<T> = Box::new(FanOutConsumerStream {
                consumer_id: hub.register_consumer(),
                state: Arc::clone(&hub),
                detached: false,
            });
            Ok((stream, NotUsed))
        })
    }
}

impl<T: Clone + Send + 'static> fmt::Debug for PartitionHubConsumerSource<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as an empty `debug_struct`: only the type name.
        f.write_str("PartitionHubConsumerSource")
    }
}
/// Snapshot of the consumers attached to a [`PartitionHub`], passed to the
/// partitioner for every element.
#[derive(Clone, Debug)]
pub struct PartitionConsumerInfo {
    /// Registered consumer ids (the hub builds this from an ordered map, so ascending).
    consumer_ids: Vec<u64>,
    /// Buffered element count per consumer id.
    queue_sizes: BTreeMap<u64, usize>,
}

impl PartitionConsumerInfo {
    /// Number of consumers in the snapshot.
    #[must_use]
    pub fn size(&self) -> usize {
        self.consumer_ids().len()
    }

    /// All consumer ids in the snapshot.
    #[must_use]
    pub fn consumer_ids(&self) -> &[u64] {
        self.consumer_ids.as_slice()
    }

    /// Consumer id at position `idx` of [`PartitionConsumerInfo::consumer_ids`].
    ///
    /// # Panics
    /// Panics if `idx >= self.size()`.
    #[must_use]
    pub fn consumer_id_by_idx(&self, idx: usize) -> u64 {
        self.consumer_ids()[idx]
    }

    /// Number of elements buffered for `consumer_id`; `0` for unknown ids.
    #[must_use]
    pub fn queue_size(&self, consumer_id: u64) -> usize {
        match self.queue_sizes.get(&consumer_id) {
            Some(&queued) => queued,
            None => 0,
        }
    }
}
/// Builds the [`Sink`] through which producers attach to a merge hub.
///
/// Each materialization first registers a producer with `state`. If the hub is
/// draining, closed or failed, the registration error is returned. Otherwise it
/// spawns a stream that forwards the producer's elements into the hub. Within the
/// hub, `push_item` blocks while this producer already has
/// `per_producer_buffer_size` elements queued.
///
/// Every exit path of the forwarding loop deregisters the producer, so the hub's
/// `active_producers` count stays accurate and draining can complete. An upstream
/// error or a cancellation additionally fails the whole hub.
fn merge_hub_sink<T: Send + 'static>(state: Arc<MergeHubShared<T>>) -> Sink<T, NotUsed> {
    Sink::from_runner(move |input, materializer| {
        let producer_id = state.register_producer()?;
        let hub = Arc::clone(&state);
        let completion = materializer.spawn_stream(move |cancelled| {
            let mut input = input;
            loop {
                if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
                    hub.fail(StreamError::Cancelled);
                    hub.deregister_producer(producer_id);
                    return Err(StreamError::Cancelled);
                }
                match input.next() {
                    Some(Ok(item)) => {
                        if let Err(error) = hub.push_item(producer_id, item) {
                            // The hub rejected the element because it failed or its
                            // source closed. Deregister like every other exit path so
                            // `active_producers` is not left permanently inflated.
                            hub.deregister_producer(producer_id);
                            return Err(error);
                        }
                    }
                    Some(Err(error)) => {
                        hub.fail(error.clone());
                        hub.deregister_producer(producer_id);
                        return Err(error);
                    }
                    None => {
                        hub.deregister_producer(producer_id);
                        return Ok(NotUsed);
                    }
                }
            }
        });
        state.store_producer_completion(completion);
        Ok(NotUsed)
    })
}
/// State shared between a merge hub's source stream, its producer streams and
/// its draining control.
///
/// When both locks are needed, `shared` is locked before the draining flags in
/// `state` (see `register_producer` / `deregister_producer`); `finish_if_draining`
/// releases the flag lock before taking `shared`.
struct MergeHubShared<T> {
    /// Draining flag; the same `Arc` is handed to `MergeHubDrainingControl`.
    state: Arc<MergeHubState>,
    /// Queue and producer bookkeeping.
    shared: Mutex<MergeHubInner<T>>,
    /// Notified whenever `shared` changes in a way that a blocked producer or the
    /// source stream may be waiting for.
    condvar: Condvar,
}
/// Draining flag of a merge hub, shared with `MergeHubDrainingControl`.
#[derive(Debug)]
struct MergeHubState {
    inner: Mutex<MergeHubFlags>,
    /// Notified by `drain_and_complete`.
    /// NOTE(review): nothing in the visible code waits on this condvar.
    condvar: Condvar,
}
/// Flags guarded by `MergeHubState::inner`.
#[derive(Debug, Default)]
struct MergeHubFlags {
    /// Set once draining was requested; new producers are rejected from then on.
    draining: bool,
}
impl MergeHubState {
    /// Locks the draining flags, panicking if the mutex is poisoned.
    fn lock(&self) -> MutexGuard<'_, MergeHubFlags> {
        let result = self.inner.lock();
        result.expect("merge hub flags poisoned")
    }
}
/// Mutable merge hub state guarded by `MergeHubShared::shared`.
struct MergeHubInner<T> {
    /// Buffered elements, each tagged with the id of the producer that pushed it.
    queue: VecDeque<(u64, T)>,
    /// Number of elements in `queue` per registered producer.
    queued_per_producer: BTreeMap<u64, usize>,
    /// Completion handles of spawned producer streams; finished ones are pruned lazily.
    producer_completions: Vec<StreamCompletion<NotUsed>>,
    /// Producers registered and not yet deregistered.
    active_producers: usize,
    /// Id handed to the next registered producer.
    next_producer_id: u64,
    /// Set when the source stream completed, failed or was dropped.
    source_closed: bool,
    /// Set when draining with no active producers; the source ends once `queue` is empty.
    completed: bool,
    /// First failure reported to the hub.
    failed: Option<StreamError>,
    /// Maximum number of queued elements per producer.
    per_producer_buffer_size: usize,
}
impl<T> MergeHubShared<T> {
    /// Creates an empty, non-draining hub with the given per-producer limit.
    fn new(per_producer_buffer_size: usize) -> Self {
        let flags = MergeHubState {
            inner: Mutex::new(MergeHubFlags::default()),
            condvar: Condvar::new(),
        };
        let inner = MergeHubInner {
            queue: VecDeque::new(),
            queued_per_producer: BTreeMap::new(),
            producer_completions: Vec::new(),
            active_producers: 0,
            next_producer_id: 0,
            source_closed: false,
            completed: false,
            failed: None,
            per_producer_buffer_size,
        };
        Self {
            state: Arc::new(flags),
            shared: Mutex::new(inner),
            condvar: Condvar::new(),
        }
    }

    /// Registers a new producer and returns its id.
    ///
    /// # Errors
    /// - `StreamError::Failed` if the hub is draining, its source is closed or it
    ///   has completed.
    /// - Otherwise, the hub's recorded failure, if any.
    fn register_producer(&self) -> StreamResult<u64> {
        let mut inner = self.shared.lock().expect("merge hub poisoned");
        prune_finished_producer_completions(&mut inner.producer_completions);
        // Lock order: `shared` first, then the flags; the flag guard is held until return.
        let flags = self.state.lock();
        let closed_to_new = flags.draining || inner.source_closed || inner.completed;
        if closed_to_new {
            return Err(StreamError::Failed(
                "merge hub is draining or closed to new producers".to_owned(),
            ));
        }
        if let Some(error) = inner.failed.clone() {
            return Err(error);
        }
        let id = inner.next_producer_id;
        inner.next_producer_id = id + 1;
        inner.active_producers += 1;
        inner.queued_per_producer.insert(id, 0);
        Ok(id)
    }

    /// Keeps the completion handle of a spawned producer stream.
    fn store_producer_completion(&self, completion: StreamCompletion<NotUsed>) {
        let mut inner = self.shared.lock().expect("merge hub poisoned");
        let completions = &mut inner.producer_completions;
        prune_finished_producer_completions(completions);
        completions.push(completion);
    }

    /// Enqueues `item` for `producer_id`, blocking while that producer already has
    /// `per_producer_buffer_size` elements queued.
    ///
    /// # Errors
    /// Returns the hub's failure, or `StreamError::Cancelled` once the source is
    /// closed. In both cases the producer's queue counter is dropped.
    fn push_item(&self, producer_id: u64, item: T) -> StreamResult<()> {
        let mut inner = self.shared.lock().expect("merge hub poisoned");
        prune_finished_producer_completions(&mut inner.producer_completions);
        loop {
            let refusal = match (&inner.failed, inner.source_closed) {
                (Some(error), _) => Some(error.clone()),
                (None, true) => Some(StreamError::Cancelled),
                (None, false) => None,
            };
            if let Some(error) = refusal {
                inner.queued_per_producer.remove(&producer_id);
                return Err(error);
            }
            let queued = inner
                .queued_per_producer
                .get(&producer_id)
                .copied()
                .unwrap_or(0);
            if queued < inner.per_producer_buffer_size {
                inner.queue.push_back((producer_id, item));
                inner.queued_per_producer.insert(producer_id, queued + 1);
                self.condvar.notify_all();
                return Ok(());
            }
            inner = self
                .condvar
                .wait(inner)
                .expect("merge hub poisoned while waiting");
        }
    }

    /// Removes a producer; completes the hub if it was the last one while draining.
    fn deregister_producer(&self, producer_id: u64) {
        let mut inner = self.shared.lock().expect("merge hub poisoned");
        prune_finished_producer_completions(&mut inner.producer_completions);
        inner.queued_per_producer.remove(&producer_id);
        inner.active_producers = inner.active_producers.saturating_sub(1);
        if inner.active_producers == 0 {
            let flags = self.state.lock();
            // Only ever flips `completed` to true, and only while draining.
            inner.completed |= flags.draining;
        }
        self.condvar.notify_all();
    }

    /// Records `error` unless a failure was already recorded, then wakes all waiters.
    fn fail(&self, error: StreamError) {
        let mut inner = self.shared.lock().expect("merge hub poisoned");
        inner.failed = inner.failed.take().or(Some(error));
        self.condvar.notify_all();
    }

    /// Completes the hub if draining was requested and no producer is attached.
    fn finish_if_draining(&self) {
        // Read the flag and release its lock before taking `shared`, so the two
        // locks are never held in the reverse of `register_producer`'s order.
        let draining = self.state.lock().draining;
        if !draining {
            return;
        }
        let mut inner = self.shared.lock().expect("merge hub poisoned");
        prune_finished_producer_completions(&mut inner.producer_completions);
        if inner.active_producers == 0 {
            inner.completed = true;
            self.condvar.notify_all();
        }
    }
}
/// Drops the handles of producer streams that have already finished, so the list
/// does not grow with every producer ever attached.
///
/// The order of the remaining handles is not preserved.
fn prune_finished_producer_completions(completions: &mut Vec<StreamCompletion<NotUsed>>) {
    let mut cursor = 0;
    while let Some(completion) = completions.get_mut(cursor) {
        let finished = completion.try_wait().is_some();
        if finished {
            // O(1) removal: the last handle moves into `cursor` and is examined next.
            drop(completions.swap_remove(cursor));
        } else {
            cursor += 1;
        }
    }
}
/// Downstream side of a merge hub: yields queued elements in arrival order.
struct MergeHubSourceStream<T> {
    state: Arc<MergeHubShared<T>>,
}

impl<T> Iterator for MergeHubSourceStream<T> {
    type Item = StreamResult<T>;

    /// Blocks until an element is queued, the hub fails, or the hub completes.
    ///
    /// A recorded failure takes precedence over queued elements. Returning a
    /// failure or the final `None` marks the source as closed.
    fn next(&mut self) -> Option<Self::Item> {
        let hub = &*self.state;
        let mut inner = hub.shared.lock().expect("merge hub poisoned");
        loop {
            if let Some(error) = inner.failed.clone() {
                inner.source_closed = true;
                return Some(Err(error));
            }
            match inner.queue.pop_front() {
                Some((producer_id, item)) => {
                    if let Some(queued) = inner.queued_per_producer.get_mut(&producer_id) {
                        *queued = queued.saturating_sub(1);
                    }
                    // A producer may be blocked on its per-producer limit.
                    hub.condvar.notify_all();
                    return Some(Ok(item));
                }
                None if inner.completed => {
                    inner.source_closed = true;
                    return None;
                }
                None => {}
            }
            inner = hub
                .condvar
                .wait(inner)
                .expect("merge hub poisoned while waiting");
        }
    }
}

impl<T> Drop for MergeHubSourceStream<T> {
    /// Marks the source as closed. Blocked producers then get `Cancelled` from
    /// `push_item`, and new registrations are rejected.
    fn drop(&mut self) {
        let hub = &*self.state;
        let mut inner = hub.shared.lock().expect("merge hub poisoned");
        inner.source_closed = true;
        hub.condvar.notify_all();
    }
}
/// Delivery strategy of a fan-out hub.
#[derive(Clone, Copy)]
enum FanOutMode {
    /// Every element is cloned into every consumer queue (`BroadcastHub`).
    Broadcast,
    /// Every element goes to the single consumer picked by the partitioner (`PartitionHub`).
    Partition,
}
/// State shared by a fan-out hub's producer and all of its consumers.
struct FanOutHubShared<T> {
    /// Consumer queues and terminal flags.
    state: Mutex<FanOutState<T>>,
    /// Notified whenever consumers, queues or terminal flags change.
    condvar: Condvar,
    /// Broadcast or partition delivery.
    mode: FanOutMode,
    /// Minimum number of attached consumers before the producer pushes an element.
    start_after_nr_of_consumers: usize,
    /// Maximum number of buffered elements per consumer queue.
    buffer_size: usize,
    /// Routing function; only consulted (and required) in `FanOutMode::Partition`.
    partitioner: Option<Partitioner<T>>,
}
/// Mutable fan-out hub state guarded by `FanOutHubShared::state`.
struct FanOutState<T> {
    /// Pending elements per registered consumer id, ordered by id.
    consumers: BTreeMap<u64, VecDeque<T>>,
    /// Id handed to the next registered consumer.
    next_consumer_id: u64,
    /// Set when upstream completed normally.
    completed: bool,
    /// Set when the producer failed, was cancelled or panicked.
    failed: Option<StreamError>,
}
impl<T> FanOutHubShared<T> {
fn new(
mode: FanOutMode,
start_after_nr_of_consumers: usize,
buffer_size: usize,
partitioner: Option<Partitioner<T>>,
) -> Self {
Self {
state: Mutex::new(FanOutState {
consumers: BTreeMap::new(),
next_consumer_id: 0,
completed: false,
failed: None,
}),
condvar: Condvar::new(),
mode,
start_after_nr_of_consumers,
buffer_size,
partitioner,
}
}
fn register_consumer(&self) -> u64 {
let mut state = self.state.lock().expect("fan-out hub poisoned");
let id = state.next_consumer_id;
state.next_consumer_id += 1;
state.consumers.insert(id, VecDeque::new());
self.condvar.notify_all();
id
}
fn remove_consumer(&self, consumer_id: u64) {
let mut state = self.state.lock().expect("fan-out hub poisoned");
state.consumers.remove(&consumer_id);
self.condvar.notify_all();
}
fn push(&self, item: T) -> StreamResult<()>
where
T: Clone,
{
let mut state = self.state.lock().expect("fan-out hub poisoned");
loop {
if state.failed.is_some() || state.completed {
return Err(StreamError::Cancelled);
}
if state.consumers.len() < self.start_after_nr_of_consumers
|| state.consumers.is_empty()
{
state = self
.condvar
.wait(state)
.expect("fan-out hub poisoned while waiting");
continue;
}
match self.mode {
FanOutMode::Broadcast => {
if state
.consumers
.values()
.any(|queue| queue.len() >= self.buffer_size)
{
state = self
.condvar
.wait(state)
.expect("fan-out hub poisoned while waiting");
continue;
}
for queue in state.consumers.values_mut() {
queue.push_back(item.clone());
}
self.condvar.notify_all();
return Ok(());
}
FanOutMode::Partition => {
let Some(selected) = self.select_partition(&state, &item)? else {
return Ok(());
};
loop {
if state.failed.is_some() || state.completed {
return Err(StreamError::Cancelled);
}
let Some(queue) = state.consumers.get_mut(&selected) else {
return Err(StreamError::Failed(
"partition hub selected unknown consumer".to_owned(),
));
};
if queue.len() < self.buffer_size {
queue.push_back(item);
self.condvar.notify_all();
return Ok(());
}
state = self
.condvar
.wait(state)
.expect("fan-out hub poisoned while waiting");
}
}
}
}
}
fn select_partition(&self, state: &FanOutState<T>, item: &T) -> StreamResult<Option<u64>> {
let info = PartitionConsumerInfo {
consumer_ids: state.consumers.keys().copied().collect(),
queue_sizes: state
.consumers
.iter()
.map(|(id, queue)| (*id, queue.len()))
.collect(),
};
let Some(partitioner) = &self.partitioner else {
return Err(StreamError::Failed(
"partition hub partitioner missing".to_owned(),
));
};
let selected = partitioner(&info, item);
if selected < 0 {
return Ok(None);
}
Ok(Some(selected as u64))
}
fn complete(&self) {
let mut state = self.state.lock().expect("fan-out hub poisoned");
state.completed = true;
self.condvar.notify_all();
}
fn fail(&self, error: StreamError) {
let mut state = self.state.lock().expect("fan-out hub poisoned");
state.failed = Some(error);
self.condvar.notify_all();
}
}
/// Upstream side of a fan-out hub: pulls elements from `input` and pushes them
/// into the shared hub state.
struct FanOutProducer<T> {
    input: BoxStream<T>,
    state: Arc<FanOutHubShared<T>>,
}

impl<T> FanOutProducer<T> {
    /// Pairs an upstream stream with the hub it feeds.
    fn new(input: BoxStream<T>, state: Arc<FanOutHubShared<T>>) -> Self {
        FanOutProducer { input, state }
    }
}
impl<T: Send + 'static + Clone> FanOutProducer<T> {
    /// Pumps upstream elements into the hub until a terminal event occurs.
    ///
    /// Every terminal outcome is published to the hub before returning, so that
    /// consumers blocked in `next` wake up:
    /// - upstream completion → `complete`, returns `Ok(NotUsed)`;
    /// - upstream error, cancellation, or an element rejected by `push` → `fail`,
    ///   returns that error;
    /// - a panic while running → `fail` with "fan-out hub producer panicked", via
    ///   a drop guard.
    ///
    /// `cancelled` is only checked between elements, not while `push` is blocked.
    ///
    /// # Errors
    /// The upstream error, `StreamError::Cancelled`, or the error returned by `push`.
    fn run(mut self, cancelled: Arc<std::sync::atomic::AtomicBool>) -> StreamResult<NotUsed> {
        /// Fails the hub if the producer unwinds before reaching a terminal state.
        struct ProducerDropGuard<T> {
            state: Arc<FanOutHubShared<T>>,
            disarmed: bool,
        }
        impl<T> ProducerDropGuard<T> {
            fn new(state: Arc<FanOutHubShared<T>>) -> Self {
                Self {
                    state,
                    disarmed: false,
                }
            }
            /// Called once a terminal outcome has been published explicitly.
            fn disarm(&mut self) {
                self.disarmed = true;
            }
        }
        impl<T> Drop for ProducerDropGuard<T> {
            fn drop(&mut self) {
                if !self.disarmed && std::thread::panicking() {
                    self.state.fail(StreamError::Failed(
                        "fan-out hub producer panicked".to_owned(),
                    ));
                }
            }
        }
        let mut guard = ProducerDropGuard::new(Arc::clone(&self.state));
        loop {
            if cancelled.load(std::sync::atomic::Ordering::SeqCst) {
                self.state.fail(StreamError::Cancelled);
                guard.disarm();
                return Err(StreamError::Cancelled);
            }
            match self.input.next() {
                Some(Ok(item)) => {
                    if let Err(error) = self.state.push(item) {
                        // Publish the rejection (e.g. an unknown partition target):
                        // otherwise consumers would wait forever for an element or
                        // a terminal signal that never comes.
                        self.state.fail(error.clone());
                        guard.disarm();
                        return Err(error);
                    }
                }
                Some(Err(error)) => {
                    self.state.fail(error.clone());
                    guard.disarm();
                    return Err(error);
                }
                None => {
                    self.state.complete();
                    guard.disarm();
                    return Ok(NotUsed);
                }
            }
        }
    }
}
struct FanOutConsumerStream<T> {
state: Arc<FanOutHubShared<T>>,
consumer_id: u64,
detached: bool,
}
impl<T: Clone + Send + 'static> Iterator for FanOutConsumerStream<T> {
type Item = StreamResult<T>;
fn next(&mut self) -> Option<Self::Item> {
let mut state = self.state.state.lock().expect("fan-out hub poisoned");
loop {
if let Some(error) = state.failed.clone() {
return Some(Err(error));
}
if let Some(queue) = state.consumers.get_mut(&self.consumer_id)
&& let Some(item) = queue.pop_front()
{
self.state.condvar.notify_all();
return Some(Ok(item));
}
if state.completed {
return None;
}
state = self
.state
.condvar
.wait(state)
.expect("fan-out hub poisoned while waiting");
}
}
}
impl<T> Drop for FanOutConsumerStream<T> {
fn drop(&mut self) {
if !self.detached {
self.state.remove_consumer(self.consumer_id);
self.detached = true;
}
}
}
// Behavioural tests for the merge, broadcast and partition hubs. They drive the
// hubs through the crate's test probes and a real `Materializer`; several rely on
// short timeouts (`expect_no_message`) to assert that backpressure is applied.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{TestSink, TestSource};
    use crate::{Keep, Materializer, Sink, Source};
    use std::{
        panic::{self, AssertUnwindSafe},
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        thread,
        time::{Duration, Instant},
    };
    /// Two producers attach through the materialized sink; after
    /// `drain_and_complete` the downstream collects every element. The order
    /// across producers is not fixed, hence the sort.
    #[test]
    fn merge_hub_accepts_dynamic_producers_and_drains() {
        let materializer = Materializer::new();
        let ((hub_sink, control), completion) = MergeHub::source_with_draining::<i32>(4)
            .to_mat(Sink::collect(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");
        hub_sink
            .clone()
            .run_with(Source::from_iter([1, 2, 3]))
            .expect("first producer attaches");
        hub_sink
            .run_with(Source::from_iter([4, 5]))
            .expect("second producer attaches");
        control.drain_and_complete();
        let mut result = completion.wait().expect("merge hub completes");
        result.sort_unstable();
        assert_eq!(result, vec![1, 2, 3, 4, 5]);
    }
    /// An error from one producer fails the downstream consumer, even though
    /// another producer is still attached.
    #[test]
    fn merge_hub_producer_error_fails_downstream_consumer() {
        let materializer = Materializer::new();
        let (hub_sink, sink) = MergeHub::source::<i32>(4)
            .to_mat(TestSink::probe(), Keep::both)
            .run_with_materializer(&materializer)
            .expect("merge hub materializes");
        let producer_ok = TestSource::probe::<i32>()
            .to_mat(hub_sink.clone(), Keep::left)
            .run_with_materializer(&materializer)
            .expect("successful producer attaches");
        let producer_fail = TestSource::probe::<i32>()
            .to_mat(hub_sink, Keep::left)
            .run_with_materializer(&materializer)
            .expect("failing producer attaches");
        sink.request(1);
        assert_eq!(producer_ok.expect_request(), 1);
        producer_ok.send_next(1);
        sink.assert_next(1);
        producer_fail.send_error(StreamError::Failed("producer failed".to_owned()));
        sink.request(1);
        assert_eq!(
            sink.expect_error(),
            StreamError::Failed("producer failed".to_owned())
        );
    }
    /// With `buffer_size` 1, the publisher is not asked for more elements until
    /// the slower consumer (`sink_b`) has taken element 2.
    #[test]
    fn broadcast_hub_backpressures_slowest_consumer() {
        let materializer = Materializer::new();
        let (publisher, hub_source) = TestSource::probe::<i32>()
            .to_mat(BroadcastHub::sink(1), Keep::both)
            .run_with_materializer(&materializer)
            .expect("broadcast hub materializes");
        let sink_a = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        let sink_b = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("second consumer materializes");
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);
        sink_b.assert_next(1);
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_a.assert_next(2);
        sink_b.expect_no_message(Duration::from_millis(250));
        sink_a.request(1);
        sink_a.expect_no_message(Duration::from_millis(250));
        sink_b.request(1);
        sink_b.assert_next(2);
        assert_eq!(publisher.expect_request(), 1);
    }
    /// A consumer attached after element 1 was delivered only receives later
    /// elements, then sees completion together with the early consumer.
    #[test]
    fn broadcast_hub_late_consumer_sees_only_late_elements() {
        let materializer = Materializer::new();
        let (publisher, hub_source) = TestSource::probe::<i32>()
            .to_mat(BroadcastHub::sink(2), Keep::both)
            .run_with_materializer(&materializer)
            .expect("broadcast hub materializes");
        let sink_a = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        sink_a.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        sink_a.assert_next(1);
        let sink_b = hub_source
            .source()
            .run_with(TestSink::probe())
            .expect("late consumer materializes");
        sink_a.request(1);
        sink_b.request(1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink_a.assert_next(2);
        sink_b.assert_next(2);
        publisher.send_complete();
        sink_a.request(1);
        sink_b.request(1);
        sink_a.expect_complete();
        sink_b.expect_complete();
    }
    /// With two consumers required before start, even elements go to the first
    /// consumer id and odd elements to the second.
    #[test]
    fn partition_hub_routes_elements_to_selected_consumers() {
        let materializer = Materializer::new();
        let hub = Source::from_iter([0, 1, 2, 3])
            .run_with_materializer(
                PartitionHub::sink(
                    |info, item| {
                        let idx = (*item as usize) % info.size();
                        info.consumer_id_by_idx(idx) as isize
                    },
                    2,
                    8,
                ),
                &materializer,
            )
            .expect("partition hub materializes");
        let sink_a = hub
            .source()
            .run_with(TestSink::probe())
            .expect("first consumer materializes");
        let sink_b = hub
            .source()
            .run_with(TestSink::probe())
            .expect("second consumer materializes");
        sink_a.request(2);
        sink_b.request(2);
        sink_a.assert_next_n([0, 2]);
        sink_b.assert_next_n([1, 3]);
    }
    /// Element 2 blocks on the consumer's full queue (`buffer_size` 1), yet the
    /// partitioner is called exactly once for it: two calls in total.
    #[test]
    fn partition_hub_evaluates_stateful_partitioner_once_per_blocked_element() {
        let materializer = Materializer::new();
        let partition_calls = Arc::new(AtomicUsize::new(0));
        let partition_calls_for_hub = Arc::clone(&partition_calls);
        let (publisher, hub) = TestSource::probe::<i32>()
            .to_mat(
                PartitionHub::sink(
                    move |info, _item| {
                        partition_calls_for_hub.fetch_add(1, Ordering::SeqCst);
                        info.consumer_id_by_idx(0) as isize
                    },
                    1,
                    1,
                ),
                Keep::both,
            )
            .run_with_materializer(&materializer)
            .expect("partition hub materializes");
        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(1);
        wait_for_partition_calls(&partition_calls, 1);
        assert_eq!(publisher.expect_request(), 1);
        publisher.send_next(2);
        sink.expect_no_message(Duration::from_millis(250));
        wait_for_partition_calls(&partition_calls, 2);
        sink.request(1);
        sink.assert_next(1);
        sink.request(1);
        sink.assert_next(2);
        assert_eq!(partition_calls.load(Ordering::SeqCst), 2);
    }
    /// An upstream panic fails consumers with "fan-out hub producer panicked".
    /// Depending on timing, the consumer may first observe element 1; the `Err`
    /// arm handles that case and then expects the failure.
    #[test]
    fn broadcast_hub_panicking_upstream_fails_consumers() {
        let materializer = Materializer::new();
        let hub = Source::from_fn_iter(|| {
            let mut yielded = false;
            std::iter::from_fn(move || {
                if !yielded {
                    yielded = true;
                    Some(1)
                } else {
                    panic!("boom");
                }
            })
        })
        .run_with_materializer(BroadcastHub::sink_starting_after(1, 8), &materializer)
        .expect("broadcast hub materializes");
        let sink = hub
            .source()
            .run_with(TestSink::probe())
            .expect("consumer materializes");
        sink.request(2);
        match panic::catch_unwind(AssertUnwindSafe(|| sink.expect_error())) {
            Ok(error) => assert_eq!(
                error,
                StreamError::Failed("fan-out hub producer panicked".to_owned())
            ),
            Err(payload) => {
                assert_eq!(
                    panic_message(payload),
                    "expected stream error, got next element"
                );
                sink.request(1);
                assert_eq!(
                    sink.expect_error(),
                    StreamError::Failed("fan-out hub producer panicked".to_owned())
                );
            }
        }
    }
    /// Extracts the text of a panic payload (`String` or `&'static str`).
    fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
        match payload.downcast::<String>() {
            Ok(message) => *message,
            Err(payload) => match payload.downcast::<&'static str>() {
                Ok(message) => (*message).to_owned(),
                Err(_) => "<non-string panic payload>".to_owned(),
            },
        }
    }
    /// Polls `counter` every 5 ms for up to one second until it equals `expected`,
    /// then asserts the final value.
    fn wait_for_partition_calls(counter: &AtomicUsize, expected: usize) {
        let deadline = Instant::now() + Duration::from_secs(1);
        while Instant::now() < deadline {
            if counter.load(Ordering::SeqCst) == expected {
                return;
            }
            thread::sleep(Duration::from_millis(5));
        }
        assert_eq!(counter.load(Ordering::SeqCst), expected);
    }
}