use super::*;
/// What a `buffer` stage does when an element arrives while its buffer is already full.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverflowStrategy {
    /// Discard the oldest buffered element, then buffer the incoming one.
    DropHead,
    /// Discard the newest buffered element, then buffer the incoming one.
    DropTail,
    /// Discard every buffered element, then buffer the incoming one.
    DropBuffer,
    /// Discard the incoming element and leave the buffer unchanged.
    DropNew,
    /// Stop pulling from upstream until downstream frees a slot.
    Backpressure,
    /// Fail the stream with a "Buffer overflow" error (buffered elements are discarded).
    Fail,
}
/// Optional timer for `aggregate_with_boundary`: every `interval` the predicate is evaluated
/// on the aggregate currently being built, and when it returns `true` that aggregate is
/// emitted even though `aggregate` has not reported a boundary yet.
pub struct AggregateTimer<Agg> {
    predicate: Arc<dyn Fn(&Agg) -> bool + Send + Sync>,
    interval: Duration,
}

// Written by hand instead of derived: `#[derive(Clone)]` would demand `Agg: Clone`, but only
// the shared predicate handle and the `Copy` interval need duplicating.
impl<Agg> Clone for AggregateTimer<Agg> {
    fn clone(&self) -> Self {
        let predicate = Arc::clone(&self.predicate);
        Self {
            interval: self.interval,
            predicate,
        }
    }
}

impl<Agg> AggregateTimer<Agg> {
    /// Builds a timer that checks `predicate` every `interval`.
    ///
    /// # Panics
    ///
    /// Panics if `interval` is zero.
    #[must_use]
    pub fn new<F>(predicate: F, interval: Duration) -> Self
    where
        F: Fn(&Agg) -> bool + Send + Sync + 'static,
    {
        // `Duration` is never negative, so "not zero" is the same as "greater than zero".
        assert!(
            !interval.is_zero(),
            "aggregate_with_boundary timer interval must be greater than zero"
        );
        let predicate: Arc<dyn Fn(&Agg) -> bool + Send + Sync> = Arc::new(predicate);
        Self {
            predicate,
            interval,
        }
    }
}
/// How a producer task finished. The first signal recorded in the shared state wins. It is
/// cloned (not taken) when reported, so every later poll of the consumer sees it again.
#[derive(Clone)]
enum TerminalSignal {
    /// Upstream was exhausted normally.
    Complete,
    /// The stream failed with this error.
    Error(StreamError),
}
/// State shared between a `buffer_stage` producer task and its `QueueStream` consumer.
struct QueueShared<T> {
    /// Buffered elements and the terminal signal. They are guarded together so the consumer
    /// can observe "queue drained and stream finished" atomically.
    state: Mutex<QueueState<T>>,
    /// Signalled whenever the queue or the terminal signal changes, and on cancellation.
    available: Condvar,
    /// Set when the consumer is dropped; the producer polls it to stop early.
    cancelled: Arc<AtomicBool>,
    /// Maximum number of buffered elements (the overflow threshold).
    capacity: usize,
}

/// Mutable part of [`QueueShared`], protected by its mutex.
struct QueueState<T> {
    queue: VecDeque<T>,
    terminal: Option<TerminalSignal>,
}

impl<T> QueueShared<T> {
    /// Upper bound on the number of queue slots reserved up front.
    ///
    /// Larger buffers grow on demand instead. Reserving the full `capacity` eagerly would make
    /// e.g. `buffer(usize::MAX, ..)` panic with a capacity overflow at materialization, and
    /// any large buffer would pay for its whole allocation even if it never fills.
    const MAX_PREALLOCATED: usize = 128;

    /// Creates empty shared state for a buffer holding at most `capacity` elements.
    fn new(capacity: usize) -> Arc<Self> {
        Arc::new(Self {
            state: Mutex::new(QueueState {
                queue: VecDeque::with_capacity(capacity.min(Self::MAX_PREALLOCATED)),
                terminal: None,
            }),
            available: Condvar::new(),
            cancelled: Arc::new(AtomicBool::new(false)),
            capacity,
        })
    }
}
/// Consumer side of `buffer_stage`: yields buffered elements, then the terminal signal.
struct QueueStream<T> {
    shared: Arc<QueueShared<T>>,
    /// Handle for the spawned producer task; released when the stream is dropped.
    completion: Option<StreamCompletion<NotUsed>>,
}
impl<T> Iterator for QueueStream<T> {
type Item = StreamResult<T>;
fn next(&mut self) -> Option<Self::Item> {
let mut state = self
.shared
.state
.lock()
.unwrap_or_else(|poison| poison.into_inner());
loop {
if let Some(item) = state.queue.pop_front() {
self.shared.available.notify_all();
return Some(Ok(item));
}
if let Some(terminal) = state.terminal.clone() {
return match terminal {
TerminalSignal::Complete => None,
TerminalSignal::Error(error) => Some(Err(error)),
};
}
state = self
.shared
.available
.wait(state)
.unwrap_or_else(|poison| poison.into_inner());
}
}
}
impl<T> Drop for QueueStream<T> {
    /// Cancels the producer task, wakes it if it is waiting, and releases the completion
    /// handle.
    fn drop(&mut self) {
        self.shared.cancelled.store(true, Ordering::SeqCst);
        // The producer reads `cancelled` while holding the state lock and then blocks on
        // `available`. Acquiring (and immediately releasing) that lock after setting the
        // flag guarantees one of two things: the producer is already parked in `wait`, or it
        // will see the flag on its next check. Either way the notification below cannot be
        // lost. Without this, a notify landing between the producer's check and its `wait`
        // would leave it blocked forever.
        let state = self
            .shared
            .state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        drop(state);
        self.shared.available.notify_all();
        let _ = self.completion.take();
    }
}
/// State shared between a single-slot producer task and its consumer stream (used by the
/// batch/conflate, expand and aggregate-with-boundary stages).
struct SlotShared<T, Extra> {
    /// The slot, the terminal signal and the stage-specific bookkeeping, guarded together.
    state: Mutex<SlotState<T, Extra>>,
    /// Signalled whenever the slot, the terminal signal or `extra` changes.
    available: Condvar,
    /// Set when the consumer is dropped; the producer polls it to stop early.
    cancelled: Arc<AtomicBool>,
}

/// Mutable part of [`SlotShared`], protected by its mutex.
struct SlotState<T, Extra> {
    slot: Option<T>,
    terminal: Option<TerminalSignal>,
    extra: Extra,
}

impl<T, Extra> SlotShared<T, Extra> {
    /// Creates shared state with an empty slot and the given stage-specific `extra` data.
    fn new(extra: Extra) -> Arc<Self> {
        let initial = SlotState {
            slot: None,
            terminal: None,
            extra,
        };
        let shared = Self {
            state: Mutex::new(initial),
            available: Condvar::new(),
            cancelled: Arc::new(AtomicBool::new(false)),
        };
        Arc::new(shared)
    }
}
/// Consumer side of `batch_stage`: yields whatever the producer left in the slot, then the
/// terminal signal.
struct SlotStream<T, Extra> {
    shared: Arc<SlotShared<T, Extra>>,
    /// Handle for the spawned producer task; released when the stream is dropped.
    completion: Option<StreamCompletion<NotUsed>>,
}
impl<T, Extra> Iterator for SlotStream<T, Extra> {
type Item = StreamResult<T>;
fn next(&mut self) -> Option<Self::Item> {
let mut state = self
.shared
.state
.lock()
.unwrap_or_else(|poison| poison.into_inner());
loop {
if let Some(item) = state.slot.take() {
self.shared.available.notify_all();
return Some(Ok(item));
}
if let Some(terminal) = state.terminal.clone() {
return match terminal {
TerminalSignal::Complete => None,
TerminalSignal::Error(error) => Some(Err(error)),
};
}
state = self
.shared
.available
.wait(state)
.unwrap_or_else(|poison| poison.into_inner());
}
}
}
impl<T, Extra> Drop for SlotStream<T, Extra> {
    /// Cancels the producer task, wakes it if it is waiting, and releases the completion
    /// handle.
    fn drop(&mut self) {
        self.shared.cancelled.store(true, Ordering::SeqCst);
        // The producer reads `cancelled` under the state lock before waiting for the slot to
        // empty. Taking the lock once after setting the flag ensures the producer is either
        // already waiting or will see the flag, so the notify below cannot be lost between
        // its check and its `wait`.
        let state = self
            .shared
            .state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        drop(state);
        self.shared.available.notify_all();
        let _ = self.completion.take();
    }
}
/// Records `terminal` as the queue's terminal signal unless one was already recorded, then
/// wakes every waiter. The first recorded signal is kept; later calls leave it untouched.
fn finish_queue<T>(shared: &QueueShared<T>, terminal: TerminalSignal) {
    let mut state = shared
        .state
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());
    state.terminal.get_or_insert(terminal);
    // Release the lock before notifying so woken threads can acquire it right away.
    drop(state);
    shared.available.notify_all();
}
/// Records `terminal` as the slot stage's terminal signal unless one was already recorded,
/// then wakes every waiter. The first recorded signal is kept; later calls leave it untouched.
fn finish_slot<T, Extra>(shared: &SlotShared<T, Extra>, terminal: TerminalSignal) {
    let mut state = shared
        .state
        .lock()
        .unwrap_or_else(|poison| poison.into_inner());
    state.terminal.get_or_insert(terminal);
    // Release the lock before notifying so woken threads can acquire it right away.
    drop(state);
    shared.available.notify_all();
}
struct ProducerPanicGuard<T> {
shared: Arc<QueueShared<T>>,
armed: bool,
}
impl<T> ProducerPanicGuard<T> {
fn new(shared: Arc<QueueShared<T>>) -> Self {
Self {
shared,
armed: true,
}
}
fn disarm(&mut self) {
self.armed = false;
}
}
impl<T> Drop for ProducerPanicGuard<T> {
fn drop(&mut self) {
if self.armed {
finish_queue(
&self.shared,
TerminalSignal::Error(StreamError::AbruptTermination),
);
}
}
}
struct SlotProducerPanicGuard<T, Extra> {
shared: Arc<SlotShared<T, Extra>>,
armed: bool,
}
impl<T, Extra> SlotProducerPanicGuard<T, Extra> {
fn new(shared: Arc<SlotShared<T, Extra>>) -> Self {
Self {
shared,
armed: true,
}
}
fn disarm(&mut self) {
self.armed = false;
}
}
impl<T, Extra> Drop for SlotProducerPanicGuard<T, Extra> {
fn drop(&mut self) {
if self.armed {
finish_slot(
&self.shared,
TerminalSignal::Error(StreamError::AbruptTermination),
);
}
}
}
/// Placeholder `Extra` bookkeeping for slot stages that need none (used by `expand_stage`).
#[derive(Default)]
struct NoExtra;
/// Per-batch bookkeeping for `batch_stage`.
struct BatchExtra<In> {
    /// Cost budget left in the current batch. It is signed and 128-bit because seeding a
    /// batch with an element costing more than the whole `u64` limit drives it negative.
    remaining: i128,
    /// An element that did not fit into the current batch, parked while the producer waits
    /// for downstream to take that batch.
    pending: Option<In>,
}

impl<In> BatchExtra<In> {
    /// Starts with the full `limit` budget and nothing parked.
    fn new(limit: u64) -> Self {
        let remaining: i128 = limit.into();
        Self {
            pending: None,
            remaining,
        }
    }
}
/// Bookkeeping for `aggregate_with_boundary_stage`.
struct BoundaryExtra {
    /// Set when the aggregate in the slot should be emitted downstream. Either `aggregate`
    /// reported a boundary or the timer predicate fired. The consumer clears it when it takes
    /// the aggregate.
    ready: bool,
}
/// Consumer side of `aggregate_with_boundary_stage`: harvests ready aggregates.
struct BoundaryStream<Agg, Emit> {
    shared: Arc<SlotShared<Agg, BoundaryExtra>>,
    /// Handle for the spawned producer task; released when the stream is dropped.
    completion: Option<StreamCompletion<NotUsed>>,
    /// Scheduled `emit_on_timer` check, cancelled when the stream is dropped.
    timer: Option<Cancellable>,
    /// Converts a taken aggregate into the emitted value; run outside the state lock.
    harvest: Arc<dyn Fn(Agg) -> Emit + Send + Sync>,
}
impl<Agg, Emit> Iterator for BoundaryStream<Agg, Emit> {
    type Item = StreamResult<Emit>;
    /// Blocks until an aggregate is marked ready or the producer has finished, then harvests
    /// the aggregate outside the lock. Once a terminal signal is recorded, a pending partial
    /// aggregate is still emitted before the terminal signal itself.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (slot, terminal) = {
                let mut state = self
                    .shared
                    .state
                    .lock()
                    .unwrap_or_else(|poison| poison.into_inner());
                loop {
                    if state.extra.ready {
                        // Take the ready aggregate and let a waiting producer continue.
                        let slot = state.slot.take();
                        state.extra.ready = false;
                        self.shared.available.notify_all();
                        break (slot, None);
                    }
                    if state.terminal.is_some() {
                        // Flush any partial aggregate before reporting termination.
                        if let Some(slot) = state.slot.take() {
                            self.shared.available.notify_all();
                            break (Some(slot), None);
                        }
                        // Cloned, not taken, so the terminal signal is reported on every
                        // later call too.
                        break (None, state.terminal.clone());
                    }
                    state = self
                        .shared
                        .available
                        .wait(state)
                        .unwrap_or_else(|poison| poison.into_inner());
                }
            };
            // The lock is released here, so user `harvest` code cannot block the producer.
            if let Some(agg) = slot {
                return Some(Ok((self.harvest)(agg)));
            }
            if let Some(terminal) = terminal {
                return match terminal {
                    TerminalSignal::Complete => None,
                    TerminalSignal::Error(error) => Some(Err(error)),
                };
            }
            // Neither an aggregate nor a terminal signal was obtained; wait again.
        }
    }
}
impl<Agg, Emit> Drop for BoundaryStream<Agg, Emit> {
    /// Cancels the producer task and the timer, wakes the producer if it is waiting, and
    /// releases the completion handle.
    fn drop(&mut self) {
        self.shared.cancelled.store(true, Ordering::SeqCst);
        // The producer reads `cancelled` under the state lock before waiting for the ready
        // aggregate to be taken. Taking the lock once after setting the flag ensures the
        // producer is either already waiting or will see the flag, so the notify below cannot
        // be lost between its check and its `wait`.
        let state = self
            .shared
            .state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        drop(state);
        self.shared.available.notify_all();
        if let Some(timer) = self.timer.take() {
            timer.cancel();
        }
        let _ = self.completion.take();
    }
}
/// Spawns a producer task that eagerly pulls `input` into a queue of at most `capacity`
/// elements, and returns a [`QueueStream`] that drains that queue.
///
/// When the queue is full, `strategy` decides what happens to the incoming element (see
/// [`OverflowStrategy`]). The producer exits in any of these cases:
/// - upstream completes or fails;
/// - the strategy is `Fail` and the queue overflows;
/// - the returned stream is dropped.
///
/// If the producer task unwinds, its panic guard records
/// `StreamError::AbruptTermination`.
fn buffer_stage<T: Send + 'static>(
    input: BoxStream<T>,
    capacity: usize,
    strategy: OverflowStrategy,
    materializer: &Materializer,
) -> StreamResult<BoxStream<T>> {
    let shared = QueueShared::new(capacity);
    let producer_shared = Arc::clone(&shared);
    let cancelled = Arc::clone(&shared.cancelled);
    let state = Arc::clone(&materializer.inner.state);
    let completion = materializer.spawn_stream(move |_| {
        // Disarmed on every deliberate exit below; still armed only if the task unwinds.
        let mut panic_guard = ProducerPanicGuard::new(Arc::clone(&producer_shared));
        let mut input = runtime_checked_stream(input, state, Some(Arc::clone(&cancelled)));
        loop {
            // Set by the consumer when it is dropped.
            if cancelled.load(Ordering::SeqCst) {
                panic_guard.disarm();
                return Ok(NotUsed);
            }
            match input.next() {
                Some(Ok(item)) => {
                    let mut guard = producer_shared
                        .state
                        .lock()
                        .unwrap_or_else(|poison| poison.into_inner());
                    match strategy {
                        OverflowStrategy::Backpressure => {
                            // Wait (lock released while waiting) until the consumer pops an
                            // element or cancels.
                            while guard.queue.len() == producer_shared.capacity
                                && !cancelled.load(Ordering::SeqCst)
                            {
                                guard = producer_shared
                                    .available
                                    .wait(guard)
                                    .unwrap_or_else(|poison| poison.into_inner());
                            }
                            if cancelled.load(Ordering::SeqCst) {
                                panic_guard.disarm();
                                return Ok(NotUsed);
                            }
                            guard.queue.push_back(item);
                        }
                        OverflowStrategy::DropHead => {
                            // Evict the oldest buffered element.
                            if guard.queue.len() == producer_shared.capacity {
                                let _ = guard.queue.pop_front();
                            }
                            guard.queue.push_back(item);
                        }
                        OverflowStrategy::DropTail => {
                            // Evict the newest buffered element.
                            if guard.queue.len() == producer_shared.capacity {
                                let _ = guard.queue.pop_back();
                            }
                            guard.queue.push_back(item);
                        }
                        OverflowStrategy::DropBuffer => {
                            if guard.queue.len() == producer_shared.capacity {
                                guard.queue.clear();
                            }
                            guard.queue.push_back(item);
                        }
                        OverflowStrategy::DropNew => {
                            // When full, the incoming `item` is simply dropped.
                            if guard.queue.len() < producer_shared.capacity {
                                guard.queue.push_back(item);
                            }
                        }
                        OverflowStrategy::Fail => {
                            if guard.queue.len() == producer_shared.capacity {
                                // Discard buffered elements so the consumer, which drains
                                // the queue before reporting termination, sees the error
                                // immediately.
                                guard.queue.clear();
                                drop(guard);
                                panic_guard.disarm();
                                finish_queue(
                                    &producer_shared,
                                    TerminalSignal::Error(StreamError::Failed(format!(
                                        "Buffer overflow (max capacity was: {capacity})!"
                                    ))),
                                );
                                return Ok(NotUsed);
                            }
                            guard.queue.push_back(item);
                        }
                    }
                    drop(guard);
                    producer_shared.available.notify_all();
                }
                Some(Err(error)) => {
                    panic_guard.disarm();
                    finish_queue(&producer_shared, TerminalSignal::Error(error));
                    return Ok(NotUsed);
                }
                None => {
                    panic_guard.disarm();
                    finish_queue(&producer_shared, TerminalSignal::Complete);
                    return Ok(NotUsed);
                }
            }
        }
    });
    Ok(Box::new(QueueStream {
        shared,
        completion: Some(completion),
    }))
}
/// Spawns a producer task that folds upstream elements into a single aggregate slot, and
/// returns a [`SlotStream`] that hands each aggregate to downstream.
///
/// Each element costs `cost_fn(&item)`. The current aggregate keeps absorbing elements (via
/// `aggregate`) while its remaining budget, which starts at `limit`, covers their cost.
/// An element that does not fit is parked until downstream takes the current aggregate;
/// it then starts the next one via `seed`. The seed path does not check the budget, so an
/// element costing more than `limit` still forms a batch of its own.
fn batch_stage<In, Agg, Cost, Seed, Aggregate>(
    input: BoxStream<In>,
    limit: u64,
    cost_fn: Arc<Cost>,
    seed: Arc<Seed>,
    aggregate: Arc<Aggregate>,
    materializer: &Materializer,
) -> StreamResult<BoxStream<Agg>>
where
    In: Send + 'static,
    Agg: Send + 'static,
    Cost: Fn(&In) -> u64 + Send + Sync + 'static,
    Seed: Fn(In) -> Agg + Send + Sync + 'static,
    Aggregate: Fn(Agg, In) -> Agg + Send + Sync + 'static,
{
    let shared = SlotShared::new(BatchExtra::new(limit));
    let producer_shared = Arc::clone(&shared);
    let cancelled = Arc::clone(&shared.cancelled);
    let state = Arc::clone(&materializer.inner.state);
    let completion = materializer.spawn_stream(move |_| {
        let mut panic_guard = SlotProducerPanicGuard::new(Arc::clone(&producer_shared));
        let mut input = runtime_checked_stream(input, state, Some(Arc::clone(&cancelled)));
        // An element that did not fit into the previous batch; processed before pulling again.
        let mut carry = None::<In>;
        loop {
            if cancelled.load(Ordering::SeqCst) {
                panic_guard.disarm();
                return Ok(NotUsed);
            }
            let next = if let Some(item) = carry.take() {
                Some(Ok(item))
            } else {
                input.next()
            };
            match next {
                Some(Ok(item)) => {
                    let cost = i128::from(cost_fn(&item));
                    // `None`: the slot is empty, so `item` seeds a new aggregate.
                    // `Some(agg)`: `item` fits and is folded into `agg` (taken out of the
                    // slot so user code runs without the lock).
                    let current = {
                        let mut guard = producer_shared
                            .state
                            .lock()
                            .unwrap_or_else(|poison| poison.into_inner());
                        if guard.slot.is_none() {
                            None
                        } else if guard.extra.remaining < cost {
                            // Does not fit: park the element and wait for downstream to take
                            // the current aggregate, then retry it as the seed of a new one.
                            guard.extra.pending = Some(item);
                            while guard.slot.is_some() && !cancelled.load(Ordering::SeqCst) {
                                guard = producer_shared
                                    .available
                                    .wait(guard)
                                    .unwrap_or_else(|poison| poison.into_inner());
                            }
                            if cancelled.load(Ordering::SeqCst) {
                                panic_guard.disarm();
                                return Ok(NotUsed);
                            }
                            carry = guard.extra.pending.take();
                            guard.extra.remaining = i128::from(limit);
                            continue;
                        } else {
                            let current = guard.slot.take();
                            guard.extra.remaining -= cost;
                            current
                        }
                    };
                    match current {
                        None => {
                            let next_agg = seed(item);
                            let mut guard = producer_shared
                                .state
                                .lock()
                                .unwrap_or_else(|poison| poison.into_inner());
                            // May go negative for an element heavier than `limit`; the next
                            // element then never fits into this batch.
                            guard.extra.remaining = i128::from(limit) - cost;
                            guard.slot = Some(next_agg);
                            drop(guard);
                            producer_shared.available.notify_all();
                        }
                        Some(current) => {
                            let next_agg = aggregate(current, item);
                            let mut guard = producer_shared
                                .state
                                .lock()
                                .unwrap_or_else(|poison| poison.into_inner());
                            guard.slot = Some(next_agg);
                            drop(guard);
                            producer_shared.available.notify_all();
                        }
                    }
                }
                Some(Err(error)) => {
                    panic_guard.disarm();
                    finish_slot(&producer_shared, TerminalSignal::Error(error));
                    return Ok(NotUsed);
                }
                None => {
                    panic_guard.disarm();
                    finish_slot(&producer_shared, TerminalSignal::Complete);
                    return Ok(NotUsed);
                }
            }
        }
    });
    Ok(Box::new(SlotStream {
        shared,
        completion: Some(completion),
    }))
}
/// Spawns a producer task that forwards upstream elements one at a time through a single
/// slot, waiting while the slot is occupied. Returns an [`ExpandStream`] that turns each
/// element into an iterator with `expand`.
///
/// `initial`, if given, is an iterator the consumer may emit from before the first upstream
/// element arrives.
fn expand_stage<In, Out, Expand, Iter>(
    input: BoxStream<In>,
    expand: Arc<Expand>,
    initial: Option<Box<dyn Iterator<Item = Out> + Send>>,
    materializer: &Materializer,
) -> StreamResult<BoxStream<Out>>
where
    In: Send + 'static,
    Out: Send + 'static,
    Expand: Fn(In) -> Iter + Send + Sync + 'static,
    Iter: Iterator<Item = Out> + Send + 'static,
{
    let shared = SlotShared::new(NoExtra);
    let producer_shared = Arc::clone(&shared);
    let cancelled = Arc::clone(&shared.cancelled);
    let state = Arc::clone(&materializer.inner.state);
    let completion = materializer.spawn_stream(move |_| {
        let mut panic_guard = SlotProducerPanicGuard::new(Arc::clone(&producer_shared));
        let mut input = runtime_checked_stream(input, state, Some(Arc::clone(&cancelled)));
        loop {
            if cancelled.load(Ordering::SeqCst) {
                panic_guard.disarm();
                return Ok(NotUsed);
            }
            match input.next() {
                Some(Ok(item)) => {
                    let mut guard = producer_shared
                        .state
                        .lock()
                        .unwrap_or_else(|poison| poison.into_inner());
                    // At most one upstream element is handed over at a time: wait until the
                    // consumer has taken the previous one.
                    while guard.slot.is_some() && !cancelled.load(Ordering::SeqCst) {
                        guard = producer_shared
                            .available
                            .wait(guard)
                            .unwrap_or_else(|poison| poison.into_inner());
                    }
                    if cancelled.load(Ordering::SeqCst) {
                        panic_guard.disarm();
                        return Ok(NotUsed);
                    }
                    guard.slot = Some(item);
                    drop(guard);
                    producer_shared.available.notify_all();
                }
                Some(Err(error)) => {
                    panic_guard.disarm();
                    finish_slot(&producer_shared, TerminalSignal::Error(error));
                    return Ok(NotUsed);
                }
                None => {
                    panic_guard.disarm();
                    finish_slot(&producer_shared, TerminalSignal::Complete);
                    return Ok(NotUsed);
                }
            }
        }
    });
    Ok(Box::new(ExpandStream {
        shared,
        completion: Some(completion),
        current: initial,
        expanded_once: false,
        seeded_from_upstream: false,
        expand,
    }))
}
/// Consumer side of `expand_stage`: turns upstream elements into iterators and emits from
/// the iterator of the newest element.
struct ExpandStream<In, Out, Expand, Iter>
where
    Expand: Fn(In) -> Iter + Send + Sync + 'static,
    Iter: Iterator<Item = Out> + Send + 'static,
{
    shared: Arc<SlotShared<In, NoExtra>>,
    /// Handle for the spawned producer task; released when the stream is dropped.
    completion: Option<StreamCompletion<NotUsed>>,
    /// Iterator currently being emitted from, if any.
    current: Option<Box<dyn Iterator<Item = Out> + Send>>,
    /// Whether `current` has already yielded at least one item.
    expanded_once: bool,
    /// Set once the first upstream element has been expanded (as opposed to emitting from
    /// the `initial` iterator); it is never cleared afterwards.
    seeded_from_upstream: bool,
    expand: Arc<Expand>,
}
impl<In, Out, Expand, Iter> Iterator for ExpandStream<In, Out, Expand, Iter>
where
    In: Send + 'static,
    Out: Send + 'static,
    Expand: Fn(In) -> Iter + Send + Sync + 'static,
    Iter: Iterator<Item = Out> + Send + 'static,
{
    type Item = StreamResult<Out>;
    /// Emits the next item using these rules:
    /// - The first item of an iterator built from an upstream element is always emitted.
    /// - After that, a newly available upstream element takes priority over further items of
    ///   the current iterator.
    /// - A recorded terminal signal ends the stream once no newer element is pending, even if
    ///   the current iterator still has items.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            // A fresh iterator built from an upstream element emits its first item before
            // anything else is considered.
            if let Some(current) = &mut self.current
                && !self.expanded_once
                && self.seeded_from_upstream
            {
                if let Some(item) = current.next() {
                    self.expanded_once = true;
                    return Some(Ok(item));
                }
                self.current = None;
                self.expanded_once = false;
            }
            enum Decision<In> {
                NewElement(In),
                Extrapolate,
                EmitInitial,
                Terminal(TerminalSignal),
            }
            // Priority: new upstream element, then terminal signal, then the current
            // iterator; with none of those available, wait for the producer.
            let decision = {
                let mut state = self
                    .shared
                    .state
                    .lock()
                    .unwrap_or_else(|poison| poison.into_inner());
                loop {
                    if let Some(item) = state.slot.take() {
                        // Free the slot so the producer can hand over the next element.
                        self.shared.available.notify_all();
                        break Decision::NewElement(item);
                    }
                    if let Some(terminal) = state.terminal.clone() {
                        break Decision::Terminal(terminal);
                    }
                    if self.current.is_some() {
                        break if self.expanded_once {
                            Decision::Extrapolate
                        } else {
                            Decision::EmitInitial
                        };
                    }
                    state = self
                        .shared
                        .available
                        .wait(state)
                        .unwrap_or_else(|poison| poison.into_inner());
                }
            };
            // The lock is released here, so user `expand`/iterator code runs unlocked.
            match decision {
                Decision::NewElement(item) => {
                    self.current = Some(Box::new((self.expand)(item)));
                    self.expanded_once = false;
                    self.seeded_from_upstream = true;
                }
                Decision::EmitInitial => {
                    if let Some(current) = &mut self.current
                        && let Some(item) = current.next()
                    {
                        self.expanded_once = true;
                        return Some(Ok(item));
                    }
                    self.current = None;
                    self.expanded_once = false;
                }
                Decision::Extrapolate => {
                    if let Some(current) = &mut self.current
                        && let Some(item) = current.next()
                    {
                        return Some(Ok(item));
                    }
                    // Iterator exhausted: wait for the next upstream element.
                    self.current = None;
                    self.expanded_once = false;
                }
                Decision::Terminal(terminal) => {
                    return match terminal {
                        TerminalSignal::Complete => None,
                        TerminalSignal::Error(error) => Some(Err(error)),
                    };
                }
            }
        }
    }
}
impl<In, Out, Expand, Iter> Drop for ExpandStream<In, Out, Expand, Iter>
where
    Expand: Fn(In) -> Iter + Send + Sync + 'static,
    Iter: Iterator<Item = Out> + Send + 'static,
{
    /// Cancels the producer task, wakes it if it is waiting, and releases the completion
    /// handle.
    fn drop(&mut self) {
        self.shared.cancelled.store(true, Ordering::SeqCst);
        // The producer reads `cancelled` under the state lock before waiting for the slot to
        // empty. Taking the lock once after setting the flag ensures the producer is either
        // already waiting or will see the flag, so the notify below cannot be lost between
        // its check and its `wait`.
        let state = self
            .shared
            .state
            .lock()
            .unwrap_or_else(|poison| poison.into_inner());
        drop(state);
        self.shared.available.notify_all();
        let _ = self.completion.take();
    }
}
fn aggregate_with_boundary_stage<In, Agg, Emit, Allocate, Aggregate, Harvest>(
input: BoxStream<In>,
allocate: Arc<Allocate>,
aggregate: Arc<Aggregate>,
harvest: Arc<Harvest>,
emit_on_timer: Option<AggregateTimer<Agg>>,
materializer: &Materializer,
) -> StreamResult<BoxStream<Emit>>
where
In: Send + 'static,
Agg: Send + 'static,
Emit: Send + 'static,
Allocate: Fn() -> Agg + Send + Sync + 'static,
Aggregate: Fn(Agg, In) -> (Agg, bool) + Send + Sync + 'static,
Harvest: Fn(Agg) -> Emit + Send + Sync + 'static,
{
let shared = SlotShared::new(BoundaryExtra { ready: false });
let producer_shared = Arc::clone(&shared);
let cancelled = Arc::clone(&shared.cancelled);
let state = Arc::clone(&materializer.inner.state);
let completion = materializer.spawn_stream(move |_| {
let mut panic_guard = SlotProducerPanicGuard::new(Arc::clone(&producer_shared));
let mut input = runtime_checked_stream(input, state, Some(Arc::clone(&cancelled)));
loop {
if cancelled.load(Ordering::SeqCst) {
panic_guard.disarm();
return Ok(NotUsed);
}
match input.next() {
Some(Ok(item)) => {
let current = {
let mut guard = producer_shared
.state
.lock()
.unwrap_or_else(|poison| poison.into_inner());
while guard.extra.ready && !cancelled.load(Ordering::SeqCst) {
guard = producer_shared
.available
.wait(guard)
.unwrap_or_else(|poison| poison.into_inner());
}
if cancelled.load(Ordering::SeqCst) {
panic_guard.disarm();
return Ok(NotUsed);
}
guard.slot.take()
};
let current = current.unwrap_or_else(|| allocate());
let (updated, ready) = aggregate(current, item);
let mut guard = producer_shared
.state
.lock()
.unwrap_or_else(|poison| poison.into_inner());
guard.slot = Some(updated);
guard.extra.ready = ready;
drop(guard);
producer_shared.available.notify_all();
}
Some(Err(error)) => {
panic_guard.disarm();
finish_slot(&producer_shared, TerminalSignal::Error(error));
return Ok(NotUsed);
}
None => {
panic_guard.disarm();
finish_slot(&producer_shared, TerminalSignal::Complete);
return Ok(NotUsed);
}
}
}
});
let timer = emit_on_timer.map(|timer| {
let shared = Arc::clone(&shared);
let cancelled = Arc::clone(&shared.cancelled);
materializer.schedule_with_fixed_delay(timer.interval, timer.interval, move || {
if cancelled.load(Ordering::SeqCst) {
return;
}
let should_emit = {
let state = shared
.state
.lock()
.unwrap_or_else(|poison| poison.into_inner());
if state.slot.is_none() || state.extra.ready || state.terminal.is_some() {
None
} else {
Some(std::panic::catch_unwind(std::panic::AssertUnwindSafe(
|| (timer.predicate)(state.slot.as_ref().expect("aggregate present")),
)))
}
};
match should_emit {
Some(Ok(true)) => {
let mut state = shared
.state
.lock()
.unwrap_or_else(|poison| poison.into_inner());
if state.slot.is_some() && !state.extra.ready && state.terminal.is_none() {
state.extra.ready = true;
}
drop(state);
shared.available.notify_all();
}
Some(Ok(false)) | None => {}
Some(Err(_)) => {
finish_slot(
&shared,
TerminalSignal::Error(StreamError::AbruptTermination),
);
}
}
})
});
Ok(Box::new(BoundaryStream {
shared,
completion: Some(completion),
timer,
harvest,
}))
}
impl<In: Send + 'static, Out: Send + 'static, Mat: Send + 'static> Flow<In, Out, Mat> {
    /// Appends a buffer of up to `size` elements. A separate producer task pulls from
    /// upstream eagerly, ahead of downstream demand. When the buffer is full, `strategy`
    /// decides what happens to the incoming element.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn buffer(self, size: usize, strategy: OverflowStrategy) -> Flow<In, Out, Mat> {
        assert!(size > 0, "buffer size must be greater than zero");
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            buffer_stage(input, size, strategy, materializer)
        }))
    }
    /// While downstream is slower than upstream, folds incoming elements into one aggregate.
    /// The first element of each aggregate goes through `seed`, later ones through
    /// `aggregate`. Upstream is never backpressured: every element has cost 0, so it always
    /// fits the current aggregate.
    pub fn conflate_with_seed<Agg, Seed, Aggregate>(
        self,
        seed: Seed,
        aggregate: Aggregate,
    ) -> Flow<In, Agg, Mat>
    where
        Agg: Send + 'static,
        Seed: Fn(Out) -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> Agg + Send + Sync + 'static,
    {
        let seed = Arc::new(seed);
        let aggregate = Arc::new(aggregate);
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            batch_stage(
                input,
                1,
                Arc::new(|_: &Out| 0),
                Arc::clone(&seed),
                Arc::clone(&aggregate),
                materializer,
            )
        }))
    }
    /// [`conflate_with_seed`](Self::conflate_with_seed) whose seed is the element itself.
    pub fn conflate(self, aggregate: impl Fn(Out, Out) -> Out + Send + Sync + 'static) -> Self {
        self.conflate_with_seed(|item| item, aggregate)
    }
    /// While downstream is slower than upstream, groups up to `max` elements into one
    /// aggregate (first via `seed`, then via `aggregate`). Once the current aggregate holds
    /// `max` elements, upstream is backpressured until downstream takes it.
    ///
    /// # Panics
    ///
    /// Panics if `max` is zero.
    pub fn batch<Agg, Seed, Aggregate>(
        self,
        max: u64,
        seed: Seed,
        aggregate: Aggregate,
    ) -> Flow<In, Agg, Mat>
    where
        Agg: Send + 'static,
        Seed: Fn(Out) -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> Agg + Send + Sync + 'static,
    {
        assert!(max > 0, "batch max must be greater than zero");
        let seed = Arc::new(seed);
        let aggregate = Arc::new(aggregate);
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            batch_stage(
                input,
                max,
                Arc::new(|_: &Out| 1),
                Arc::clone(&seed),
                Arc::clone(&aggregate),
                materializer,
            )
        }))
    }
    /// Like [`batch`](Self::batch), but each element costs `cost_fn(&element)` and an
    /// aggregate accepts elements while its total cost stays within `max`.
    ///
    /// An element that does not fit waits until downstream takes the current aggregate,
    /// then starts the next one. An element costing more than `max` on its own still
    /// forms a single-element aggregate.
    ///
    /// # Panics
    ///
    /// Panics if `max` is zero.
    pub fn batch_weighted<Agg, Cost, Seed, Aggregate>(
        self,
        max: u64,
        cost_fn: Cost,
        seed: Seed,
        aggregate: Aggregate,
    ) -> Flow<In, Agg, Mat>
    where
        Agg: Send + 'static,
        Cost: Fn(&Out) -> u64 + Send + Sync + 'static,
        Seed: Fn(Out) -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> Agg + Send + Sync + 'static,
    {
        assert!(max > 0, "batch_weighted max must be greater than zero");
        let cost_fn = Arc::new(cost_fn);
        let seed = Arc::new(seed);
        let aggregate = Arc::new(aggregate);
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            batch_stage(
                input,
                max,
                Arc::clone(&cost_fn),
                Arc::clone(&seed),
                Arc::clone(&aggregate),
                materializer,
            )
        }))
    }
    /// Lets a faster downstream keep receiving items while upstream is slow.
    ///
    /// Each upstream element is turned into an iterator with `expand`, and the first item of
    /// that iterator is always emitted. Further items are emitted only while no newer
    /// upstream element is available. When the iterator runs dry, the stage waits for the
    /// next upstream element.
    pub fn expand<Next, Expand, Iter>(self, expand: Expand) -> Flow<In, Next, Mat>
    where
        Next: Send + 'static,
        Expand: Fn(Out) -> Iter + Send + Sync + 'static,
        Iter: Iterator<Item = Next> + Send + 'static,
    {
        let expand = Arc::new(expand);
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            expand_stage(input, Arc::clone(&expand), None, materializer)
        }))
    }
    /// Like [`expand`](Self::expand), but each upstream element is re-emitted itself before
    /// the items produced by `extrapolator(element)`.
    ///
    /// `initial`, if provided, is emitted only when downstream pulls before the first
    /// upstream element is available and before upstream has terminated.
    pub fn extrapolate<Expand, Iter>(
        self,
        extrapolator: Expand,
        initial: Option<Out>,
    ) -> Flow<In, Out, Mat>
    where
        Out: Clone + Sync,
        Expand: Fn(Out) -> Iter + Send + Sync + 'static,
        Iter: Iterator<Item = Out> + Send + 'static,
    {
        let extrapolator = Arc::new(extrapolator);
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            let extrapolator = Arc::clone(&extrapolator);
            let initial = initial.clone().map(|item| {
                Box::new(std::iter::once(item)) as Box<dyn Iterator<Item = Out> + Send>
            });
            expand_stage(
                input,
                Arc::new(move |item: Out| {
                    std::iter::once(item.clone()).chain((extrapolator)(item))
                }),
                initial,
                materializer,
            )
        }))
    }
    /// Folds elements into an aggregate and emits it (converted by `harvest`) at boundaries.
    ///
    /// * `allocate` creates a fresh aggregate whenever none is in progress.
    /// * `aggregate` returns the updated aggregate plus whether it is complete. A complete
    ///   aggregate is emitted, and upstream is not pulled again until downstream has taken it.
    /// * `emit_on_timer`, if set, periodically marks the in-progress aggregate complete when
    ///   its predicate returns `true`.
    ///
    /// When upstream finishes, an in-progress aggregate is emitted before the stream ends.
    pub fn aggregate_with_boundary<Agg, Emit, Allocate, Aggregate, Harvest>(
        self,
        allocate: Allocate,
        aggregate: Aggregate,
        harvest: Harvest,
        emit_on_timer: Option<AggregateTimer<Agg>>,
    ) -> Flow<In, Emit, Mat>
    where
        Agg: Send + 'static,
        Emit: Send + 'static,
        Allocate: Fn() -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> (Agg, bool) + Send + Sync + 'static,
        Harvest: Fn(Agg) -> Emit + Send + Sync + 'static,
    {
        let allocate = Arc::new(allocate);
        let aggregate = Arc::new(aggregate);
        let harvest = Arc::new(harvest);
        self.via(Flow::from_runtime_transform(move |input, materializer| {
            aggregate_with_boundary_stage(
                input,
                Arc::clone(&allocate),
                Arc::clone(&aggregate),
                Arc::clone(&harvest),
                emit_on_timer.clone(),
                materializer,
            )
        }))
    }
    /// Decouples upstream from downstream through a one-element backpressured buffer
    /// (equivalent to `buffer(1, OverflowStrategy::Backpressure)`).
    pub fn detach(self) -> Flow<In, Out, Mat> {
        self.buffer(1, OverflowStrategy::Backpressure)
    }
}
// Every method here applies the same-named operator through `Flow::identity()`; see the
// corresponding `Flow` method for its semantics.
impl<Out: Send + 'static, Mat: Send + 'static> Source<Out, Mat> {
    /// Buffers up to `size` elements, pulling eagerly; `strategy` handles overflow.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero.
    pub fn buffer(self, size: usize, strategy: OverflowStrategy) -> Self {
        self.via(Flow::identity().buffer(size, strategy))
    }
    /// Folds elements into one aggregate while downstream is slower; never backpressures.
    pub fn conflate_with_seed<Agg, Seed, Aggregate>(
        self,
        seed: Seed,
        aggregate: Aggregate,
    ) -> Source<Agg, Mat>
    where
        Agg: Send + 'static,
        Seed: Fn(Out) -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> Agg + Send + Sync + 'static,
    {
        self.via(Flow::identity().conflate_with_seed(seed, aggregate))
    }
    /// Conflation whose seed is the element itself.
    pub fn conflate(self, aggregate: impl Fn(Out, Out) -> Out + Send + Sync + 'static) -> Self {
        self.via(Flow::identity().conflate(aggregate))
    }
    /// Groups up to `max` elements per aggregate while downstream is slower, then
    /// backpressures.
    ///
    /// # Panics
    ///
    /// Panics if `max` is zero.
    pub fn batch<Agg, Seed, Aggregate>(
        self,
        max: u64,
        seed: Seed,
        aggregate: Aggregate,
    ) -> Source<Agg, Mat>
    where
        Agg: Send + 'static,
        Seed: Fn(Out) -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> Agg + Send + Sync + 'static,
    {
        self.via(Flow::identity().batch(max, seed, aggregate))
    }
    /// Like `batch`, with each element weighted by `cost_fn` against the `max` budget.
    ///
    /// # Panics
    ///
    /// Panics if `max` is zero.
    pub fn batch_weighted<Agg, Cost, Seed, Aggregate>(
        self,
        max: u64,
        cost_fn: Cost,
        seed: Seed,
        aggregate: Aggregate,
    ) -> Source<Agg, Mat>
    where
        Agg: Send + 'static,
        Cost: Fn(&Out) -> u64 + Send + Sync + 'static,
        Seed: Fn(Out) -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> Agg + Send + Sync + 'static,
    {
        self.via(Flow::identity().batch_weighted(max, cost_fn, seed, aggregate))
    }
    /// Expands each element into an iterator and emits from it while upstream is slow.
    pub fn expand<Next, Expand, Iter>(self, expand: Expand) -> Source<Next, Mat>
    where
        Next: Send + 'static,
        Expand: Fn(Out) -> Iter + Send + Sync + 'static,
        Iter: Iterator<Item = Next> + Send + 'static,
    {
        self.via(Flow::identity().expand(expand))
    }
    /// Re-emits each element followed by `extrapolator(element)` while upstream is slow;
    /// `initial` may be emitted before the first element arrives.
    pub fn extrapolate<Expand, Iter>(self, extrapolator: Expand, initial: Option<Out>) -> Self
    where
        Out: Clone + Sync,
        Expand: Fn(Out) -> Iter + Send + Sync + 'static,
        Iter: Iterator<Item = Out> + Send + 'static,
    {
        self.via(Flow::identity().extrapolate(extrapolator, initial))
    }
    /// Folds elements into an aggregate and emits it at boundaries reported by `aggregate`
    /// or by the optional timer.
    pub fn aggregate_with_boundary<Agg, Emit, Allocate, Aggregate, Harvest>(
        self,
        allocate: Allocate,
        aggregate: Aggregate,
        harvest: Harvest,
        emit_on_timer: Option<AggregateTimer<Agg>>,
    ) -> Source<Emit, Mat>
    where
        Agg: Send + 'static,
        Emit: Send + 'static,
        Allocate: Fn() -> Agg + Send + Sync + 'static,
        Aggregate: Fn(Agg, Out) -> (Agg, bool) + Send + Sync + 'static,
        Harvest: Fn(Agg) -> Emit + Send + Sync + 'static,
    {
        self.via(Flow::identity().aggregate_with_boundary(
            allocate,
            aggregate,
            harvest,
            emit_on_timer,
        ))
    }
    /// Decouples upstream from downstream through a one-element backpressured buffer.
    pub fn detach(self) -> Self {
        self.via(Flow::identity().detach())
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::testkit::{TestSink, TestSource};
use std::sync::mpsc;
use std::time::Duration;
fn materialize_buffered_stream<T: Send + 'static>(
source: Source<T>,
) -> (BoxStream<T>, Materializer) {
let materializer = Materializer::new();
let (stream, _) = Arc::clone(&source.factory)
.create(&materializer)
.expect("buffer source materializes");
(stream, materializer)
}
fn buffered_probe(
strategy: OverflowStrategy,
) -> (
crate::testkit::TestPublisherProbe<i32>,
crate::testkit::TestSubscriberProbe<i32>,
) {
TestSource::probe::<i32>()
.buffer(2, strategy)
.to_mat(TestSink::probe(), Keep::both)
.run()
.expect("buffer probe materializes")
}
fn rate_probe<T, U>(
flow: impl FnOnce(
Source<T, crate::testkit::TestPublisherProbe<T>>,
) -> Source<U, crate::testkit::TestPublisherProbe<T>>,
) -> (
crate::testkit::TestPublisherProbe<T>,
crate::testkit::TestSubscriberProbe<U>,
)
where
T: Send + 'static,
U: Send + 'static,
{
flow(TestSource::probe::<T>())
.to_mat(TestSink::probe(), Keep::both)
.run()
.expect("rate probe materializes")
}
#[test]
fn buffer_drop_head_drops_oldest_buffered_elements() {
let (publisher, subscriber) = buffered_probe(OverflowStrategy::DropHead);
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
publisher.expect_request();
publisher.send_next(3);
publisher.expect_request();
publisher.send_next(4);
publisher.expect_request();
publisher.send_complete();
subscriber.request(3);
assert_eq!(subscriber.expect_next(), 3);
assert_eq!(subscriber.expect_next(), 4);
subscriber.expect_complete();
}
#[test]
fn buffer_pulls_eagerly_before_downstream_demand() {
let (mut publisher, mut subscriber) = buffered_probe(OverflowStrategy::Backpressure);
publisher.set_timeout(Duration::from_millis(250));
subscriber.set_timeout(Duration::from_millis(250));
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
subscriber.expect_no_message(Duration::from_millis(250));
subscriber.request(2);
assert_eq!(subscriber.expect_next(), 1);
assert_eq!(subscriber.expect_next(), 2);
}
#[test]
fn buffer_drop_tail_drops_newest_buffered_element() {
let (publisher, subscriber) = buffered_probe(OverflowStrategy::DropTail);
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
publisher.expect_request();
publisher.send_next(3);
publisher.expect_request();
publisher.send_next(4);
publisher.expect_request();
publisher.send_complete();
subscriber.request(3);
assert_eq!(subscriber.expect_next(), 1);
assert_eq!(subscriber.expect_next(), 4);
subscriber.expect_complete();
}
#[test]
fn buffer_drop_buffer_drops_all_buffered_elements() {
let (publisher, subscriber) = buffered_probe(OverflowStrategy::DropBuffer);
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
publisher.expect_request();
publisher.send_next(3);
publisher.expect_request();
publisher.send_next(4);
publisher.expect_request();
publisher.send_complete();
subscriber.request(3);
assert_eq!(subscriber.expect_next(), 3);
assert_eq!(subscriber.expect_next(), 4);
subscriber.expect_complete();
}
#[test]
fn buffer_drop_new_drops_incoming_elements() {
let (publisher, subscriber) = buffered_probe(OverflowStrategy::DropNew);
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
publisher.expect_request();
publisher.send_next(3);
publisher.expect_request();
publisher.send_next(4);
publisher.expect_request();
publisher.send_complete();
subscriber.request(3);
assert_eq!(subscriber.expect_next(), 1);
assert_eq!(subscriber.expect_next(), 2);
subscriber.expect_complete();
}
#[test]
fn buffer_backpressure_stops_pulling_when_full() {
let (mut publisher, mut subscriber) = buffered_probe(OverflowStrategy::Backpressure);
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
publisher.set_timeout(Duration::from_millis(250));
subscriber.set_timeout(Duration::from_millis(250));
subscriber.expect_no_message(Duration::from_millis(250));
subscriber.request(1);
assert_eq!(subscriber.expect_next(), 1);
publisher.expect_request();
publisher.send_next(3);
publisher.send_complete();
subscriber.request(3);
assert_eq!(subscriber.expect_next(), 2);
assert_eq!(subscriber.expect_next(), 3);
subscriber.expect_complete();
}
#[test]
fn buffer_fail_surfaces_overflow_error() {
let (publisher, subscriber) = buffered_probe(OverflowStrategy::Fail);
publisher.expect_request();
publisher.send_next(1);
publisher.expect_request();
publisher.send_next(2);
publisher.expect_request();
publisher.send_next(3);
publisher.expect_cancellation();
subscriber.request(1);
assert_eq!(
subscriber.expect_error(),
StreamError::Failed("Buffer overflow (max capacity was: 2)!".to_owned())
);
}
#[test]
fn buffer_terminal_completion_is_sticky_across_repolls() {
let (mut stream, _materializer) = materialize_buffered_stream(
Source::from_iter([1, 2]).buffer(2, OverflowStrategy::Backpressure),
);
assert_eq!(stream.next(), Some(Ok(1)));
assert_eq!(stream.next(), Some(Ok(2)));
assert_eq!(stream.next(), None);
assert_eq!(stream.next(), None);
}
#[test]
fn buffer_terminal_failure_is_sticky_across_repolls() {
let (mut stream, _materializer) = materialize_buffered_stream(
Source::<i32>::failed(StreamError::Failed("boom".to_owned()))
.buffer(2, OverflowStrategy::Backpressure),
);
assert_eq!(
stream.next(),
Some(Err(StreamError::Failed("boom".to_owned())))
);
assert_eq!(
stream.next(),
Some(Err(StreamError::Failed("boom".to_owned())))
);
}
#[test]
fn buffer_surfaces_producer_panics_as_stream_failure() {
let (sender, receiver) = mpsc::channel();
std::thread::spawn(move || {
let (mut stream, _materializer) = materialize_buffered_stream(
Source::from_iter([1, 2, 3])
.map(|item| {
if item == 2 {
panic!("boom");
}
item
})
.buffer(2, OverflowStrategy::Backpressure),
);
sender
.send((stream.next(), stream.next(), stream.next()))
.expect("test thread sends buffered panic results");
});
let (first, second, third) = receiver
.recv_timeout(Duration::from_secs(1))
.expect("buffer panic path should not hang");
assert_eq!(first, Some(Ok(1)));
assert_eq!(second, Some(Err(StreamError::AbruptTermination)));
assert_eq!(third, Some(Err(StreamError::AbruptTermination)));
}
#[test]
fn conflate_passes_through_without_rate_difference() {
let (publisher, subscriber) =
rate_probe(|source| source.conflate(|left, right| left + right));
for value in 1..=4 {
subscriber.request(1);
publisher.expect_request();
publisher.send_next(value);
assert_eq!(subscriber.expect_next(), value);
}
}
#[test]
fn conflate_aggregates_while_downstream_is_silent() {
    // Downstream issues no demand while the four elements arrive, so they must all be
    // folded into a single seeded `Vec` that is emitted once demand appears.
    let (publisher, subscriber) = rate_probe(|source| {
        source.conflate_with_seed(
            |first| vec![first],
            |mut acc, next| {
                acc.push(next);
                acc
            },
        )
    });
    // Four elements, then completion. Each signal is sent only after the probe has
    // observed the matching upstream request.
    for signal in (1..=4).map(Some).chain(std::iter::once(None)) {
        publisher.expect_request();
        match signal {
            Some(value) => {
                publisher.send_next(value);
            }
            None => {
                publisher.send_complete();
            }
        }
    }
    subscriber.request(2);
    assert_eq!(subscriber.expect_next(), vec![1, 2, 3, 4]);
    subscriber.expect_complete();
}
#[test]
fn batch_passes_through_without_rate_difference() {
    // With demand in place before each element, batch never accumulates and every
    // element must be emitted as-is.
    let (publisher, subscriber) =
        rate_probe(|source| source.batch(2, |seed| seed, |acc, next| acc + next));
    (1..=4).for_each(|element| {
        subscriber.request(1);
        publisher.expect_request();
        publisher.send_next(element);
        assert_eq!(subscriber.expect_next(), element);
    });
}
#[test]
fn batch_aggregates_while_downstream_is_silent() {
    // With an effectively unbounded limit and no downstream demand, all ten elements
    // fold into one batch. Prepending each element makes the fold order observable.
    let (publisher, subscriber) = rate_probe(|source| {
        source.batch(
            u64::MAX,
            |first| vec![first],
            |mut batch, next| {
                batch.insert(0, next);
                batch
            },
        )
    });
    for element in 1..=10 {
        publisher.expect_request();
        publisher.send_next(element);
    }
    publisher.expect_request();
    publisher.send_complete();
    subscriber.request(1);
    let newest_first: Vec<_> = (1..=10).rev().collect();
    assert_eq!(subscriber.expect_next(), newest_first);
}
#[test]
fn batch_weighted_keeps_heavy_elements_separate() {
    // Every element weighs 4, more than the limit of 3, so no two elements may ever
    // be combined into the same batch.
    let (publisher, subscriber) = rate_probe(|source| {
        source.batch_weighted(3, |_| 4, |seed| seed, |acc, next| acc + next)
    });
    for heavy in [1, 2] {
        publisher.expect_request();
        publisher.send_next(heavy);
    }
    subscriber.request(1);
    assert_eq!(subscriber.expect_next(), 1);
    publisher.send_next(3);
    publisher.send_complete();
    subscriber.request(2);
    for expected in [2, 3] {
        assert_eq!(subscriber.expect_next(), expected);
    }
}
#[test]
fn batch_backpressures_at_max_aggregate() {
    // How the elements are split into batches of at most two is timing dependent.
    // The test therefore checks that the sum is conserved and that every emitted
    // aggregate stays within bounds, rather than checking one fixed sequence.
    let (publisher, subscriber) =
        rate_probe(|source| source.batch(2, |seed| seed, |acc, next| acc + next));
    for element in 1..=3 {
        publisher.expect_request();
        publisher.send_next(element);
    }
    // No downstream demand yet, so nothing may be delivered.
    subscriber.expect_no_message(Duration::from_millis(200));
    subscriber.request(1);
    let mut emitted = vec![subscriber.expect_next()];
    publisher.expect_request();
    publisher.send_next(4);
    publisher.send_complete();
    emitted.extend(subscriber.drain_until_complete());
    assert_eq!(
        emitted.iter().sum::<i32>(),
        10,
        "all elements should be accounted for"
    );
    if let Some(out_of_range) = emitted.iter().find(|value| !(1..=7).contains(*value)) {
        panic!("batch value {out_of_range} out of range");
    }
}
#[test]
fn expand_passes_through_without_rate_difference() {
    // A single-value expansion paired with one-for-one demand must look exactly like
    // an identity stage.
    let (publisher, subscriber) = rate_probe(|source| source.expand(std::iter::once::<i32>));
    (1..=4).for_each(|element| {
        publisher.expect_request();
        publisher.send_next(element);
        subscriber.request(1);
        assert_eq!(subscriber.expect_next(), element);
    });
}
#[test]
fn expand_elements_while_upstream_is_silent() {
    // `repeat` makes each expansion infinite, so downstream demand keeps being served
    // from the latest element until upstream delivers a new one.
    let (publisher, mut subscriber) = rate_probe(|source| source.expand(std::iter::repeat));
    subscriber.set_timeout(Duration::from_millis(250));
    let seed = 42;
    publisher.expect_request();
    publisher.send_next(seed);
    subscriber.request(4);
    for _ in 0..4 {
        assert_eq!(subscriber.expect_next(), seed);
    }
    publisher.expect_request();
    publisher.send_next(-seed);
    // All four demanded elements have already been delivered, so nothing else may
    // arrive until more demand is issued.
    subscriber.expect_no_message(Duration::from_millis(250));
    subscriber.request(1);
    assert_eq!(subscriber.expect_next(), -seed);
}
#[test]
fn expand_does_not_drop_last_element() {
    // Upstream completing immediately after its final element must not cause that
    // element to be lost.
    let source = Source::from_iter([1, 2]).expand(std::iter::once::<i32>);
    let (mut stream, _materializer) = materialize_buffered_stream(source);
    for expected in [Some(Ok(1)), Some(Ok(2)), None] {
        assert_eq!(stream.next(), expected);
    }
}
#[test]
fn expand_handles_finite_extrapolations() {
    // Each real element expands to `(item, 0)`, `(item, 1)`, `(item, 2)`. The test
    // accepts any number of those extrapolations per item, so it checks the shape of
    // the output rather than one exact sequence.
    let (mut stream, _materializer) = materialize_buffered_stream(
        Source::from_iter([1, 2]).expand(|item| (0..3).map(move |index| (item, index))),
    );
    // Pull until the stream reports completion for the first time.
    let output: Vec<_> = stream
        .by_ref()
        .map(|result| result.expect("expand should not fail"))
        .collect();
    assert!(
        !output.is_empty(),
        "expand must emit at least the first real element"
    );
    let first_values: Vec<_> = output
        .iter()
        .filter(|(_, index)| *index == 0)
        .map(|&(item, _)| item)
        .collect();
    assert_eq!(
        first_values,
        vec![1, 2],
        "each real upstream element must emit its first expanded value exactly once"
    );
    // The output is a sequence of runs, one per real item. Each run starts at index 0
    // and counts upwards, and may only be followed by the run of the next real item.
    let mut current_item = output[0].0;
    let mut expected_index = 0;
    for &(item, index) in &output {
        let advanced = item == current_item + 1;
        assert!(
            advanced || item == current_item,
            "expanded output must stay on the current item or advance to the next real item"
        );
        if advanced {
            assert_eq!(
                index, 0,
                "a new real item must begin with its first expanded value"
            );
            current_item = item;
            expected_index = 0;
        }
        assert_eq!(
            index, expected_index,
            "each real item may be followed only by its own ordered extrapolations"
        );
        expected_index += 1;
    }
    // Completion has already been observed; further polls must keep reporting it.
    for _ in 0..2 {
        assert_eq!(stream.next(), None);
    }
}
#[test]
fn expand_emits_first_value_even_if_terminal_is_already_visible() {
    // Before the stream is ever polled, the slot holds both an element and a
    // completion signal. The element's first expanded value must still be emitted
    // before completion is reported.
    let shared = SlotShared::new(NoExtra);
    let mut state = shared
        .state
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    state.slot = Some(42);
    state.terminal = Some(TerminalSignal::Complete);
    // Release the lock before the stream takes ownership of `shared`.
    drop(state);
    let mut stream = ExpandStream {
        shared,
        completion: None,
        current: None,
        expanded_once: false,
        seeded_from_upstream: false,
        expand: Arc::new(|item| (0..3).map(move |index| (item, index))),
    };
    assert_eq!(stream.next(), Some(Ok((42, 0))));
    assert_eq!(stream.next(), None);
}
#[test]
fn extrapolate_initial_yields_to_real_element_already_in_slot() {
    // The stream starts with an in-progress expansion (`repeat(99)`), apparently
    // standing in for an extrapolated initial value. A real element is already
    // waiting in the slot, and that real element must be emitted instead.
    let shared = SlotShared::new(NoExtra);
    let mut state = shared
        .state
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    state.slot = Some(7);
    // Release the lock before the stream starts polling the slot.
    drop(state);
    let mut stream = ExpandStream {
        shared: Arc::clone(&shared),
        completion: None,
        current: Some(Box::new(std::iter::repeat(99))),
        expanded_once: false,
        seeded_from_upstream: false,
        expand: Arc::new(std::iter::repeat),
    };
    assert_eq!(stream.next(), Some(Ok(7)));
    finish_slot(&shared, TerminalSignal::Complete);
    for _ in 0..2 {
        assert_eq!(stream.next(), None);
    }
}
#[test]
fn extrapolate_preserves_original_before_filling_gaps() {
    // The real element must be emitted first; only then may its extrapolation (+100)
    // fill the remaining demand.
    let (publisher, subscriber) =
        rate_probe(|source| source.extrapolate(|item| std::iter::once(item + 100), None));
    publisher.expect_request();
    publisher.send_next(1);
    subscriber.request(2);
    for expected in [1, 101] {
        assert_eq!(subscriber.expect_next(), expected);
    }
}
#[test]
fn extrapolate_emits_initial_element_before_upstream_arrives() {
    // The `Some(0)` initial value satisfies demand before upstream produces anything.
    // After that, the real element and its repetitions take over.
    let (publisher, subscriber) =
        rate_probe(|source| source.extrapolate(std::iter::repeat, Some(0)));
    subscriber.request(1);
    assert_eq!(subscriber.expect_next(), 0);
    publisher.expect_request();
    publisher.send_next(42);
    subscriber.request(3);
    for _ in 0..3 {
        assert_eq!(subscriber.expect_next(), 42);
    }
}
#[test]
fn aggregate_with_boundary_splits_by_size() {
    // The boundary fires once a chunk holds three items. The trailing partial chunk
    // is flushed on completion.
    let chunks = Source::from_iter(1..=7)
        .aggregate_with_boundary(
            Vec::<i32>::new,
            |mut chunk, item| {
                chunk.push(item);
                let full = chunk.len() >= 3;
                (chunk, full)
            },
            |chunk| chunk,
            None,
        )
        .run_collect()
        .unwrap();
    let expected: Vec<Vec<i32>> = (1..=7)
        .collect::<Vec<i32>>()
        .chunks(3)
        .map(<[i32]>::to_vec)
        .collect();
    assert_eq!(chunks, expected);
}
#[test]
fn aggregate_with_boundary_honors_timer_trigger() {
    // The boundary function never reports a full aggregate. Without upstream
    // completion, the aggregate is released by the 10ms timer, whose predicate
    // accepts any non-empty buffer.
    let (publisher, mut subscriber) = rate_probe(|source| {
        source.aggregate_with_boundary(
            Vec::<i32>::new,
            |mut pending, item| {
                pending.push(item);
                (pending, false)
            },
            |pending| pending,
            Some(AggregateTimer::new(
                |pending: &Vec<i32>| !pending.is_empty(),
                Duration::from_millis(10),
            )),
        )
    });
    subscriber.set_timeout(Duration::from_millis(200));
    for element in [1, 2] {
        publisher.expect_request();
        publisher.send_next(element);
    }
    subscriber.request(1);
    assert_eq!(subscriber.expect_next(), vec![1, 2]);
}
#[test]
fn detach_passes_through_all_elements() {
    // For a finite, successful stream, `detach` must not change the elements or their
    // order.
    let collected = Source::from_iter(1..=100).detach().run_collect().unwrap();
    let expected: Vec<_> = (1..=100).collect();
    assert_eq!(collected, expected);
}
#[test]
fn detach_passes_through_failure() {
    // An upstream failure must propagate through `detach` unchanged.
    let boom = || StreamError::Failed("boom".to_owned());
    let outcome = Source::<i32>::failed(boom()).detach().run_collect();
    assert_eq!(outcome, Err(boom()));
}
#[test]
fn detach_emits_last_element_when_completed_without_demand() {
    // The single element must survive upstream completing before downstream has
    // asked for anything.
    let (mut stream, _materializer) = materialize_buffered_stream(Source::single(42).detach());
    for expected in [Some(Ok(42)), None] {
        assert_eq!(stream.next(), expected);
    }
}
}