use futures::future::FutureExt;
use std::future::Future;
use crate::stream::{Flow, NotUsed, RunnableGraph, Sink, Source, StreamResult};
/// A [`Source`] whose elements each carry an associated context value through the stream.
///
/// Internally this is a `Source` of `(element, context)` pairs. The context-aware operators
/// defined on this type transform the element while threading the context alongside it.
pub struct SourceWithContext<Out, Ctx, Mat = NotUsed> {
    pub(crate) delegate: Source<(Out, Ctx), Mat>,
}

// Implemented by hand rather than derived. `#[derive(Clone)]` would require `Out: Clone`,
// `Ctx: Clone` and `Mat: Clone` even though only the wrapped `Source` has to be cloneable.
// This impl is never more restrictive than the derived one: whenever the derived bounds hold,
// this where-clause holds as well.
impl<Out, Ctx, Mat> Clone for SourceWithContext<Out, Ctx, Mat>
where
    Source<(Out, Ctx), Mat>: Clone,
{
    fn clone(&self) -> Self {
        Self {
            delegate: self.delegate.clone(),
        }
    }
}
/// A [`Flow`] whose elements each carry an associated context value through the stream.
///
/// Internally this is a `Flow` from `(In, CtxIn)` pairs to `(Out, CtxOut)` pairs. The
/// context-aware operators defined on this type transform the element while threading the
/// context alongside it.
pub struct FlowWithContext<In, CtxIn, Out, CtxOut, Mat = NotUsed> {
    pub(crate) delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
}

// Implemented by hand rather than derived. `#[derive(Clone)]` would require every type
// parameter to be `Clone` even though only the wrapped `Flow` has to be cloneable. This impl
// is never more restrictive than the derived one.
impl<In, CtxIn, Out, CtxOut, Mat> Clone for FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
where
    Flow<(In, CtxIn), (Out, CtxOut), Mat>: Clone,
{
    fn clone(&self) -> Self {
        Self {
            delegate: self.delegate.clone(),
        }
    }
}
impl<Out: Send + 'static, Ctx: Send + 'static, Mat: Send + 'static>
    SourceWithContext<Out, Ctx, Mat>
{
    /// Wraps a source of `(element, context)` pairs so that context-aware operators can be used.
    pub(crate) fn from_source(delegate: Source<(Out, Ctx), Mat>) -> Self {
        Self { delegate }
    }

    /// Unwraps into the underlying source of `(element, context)` pairs.
    pub fn as_source(self) -> Source<(Out, Ctx), Mat> {
        self.delegate
    }

    /// Runs the stream and collects every emitted `(element, context)` pair.
    ///
    /// # Errors
    ///
    /// Returns whatever error the underlying `Source::run_collect` reports.
    pub fn run_collect(self) -> StreamResult<Vec<(Out, Ctx)>> {
        self.delegate.run_collect()
    }

    /// Transforms each element with `f` and leaves its context untouched.
    pub fn map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> Next + Send + Sync + 'static,
    {
        SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
    }

    /// Keeps only the elements for which `predicate` returns `true`.
    ///
    /// When an element is rejected, its context is dropped with it.
    pub fn filter<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
    where
        F: Fn(&Out) -> bool + Send + Sync + 'static,
    {
        // `predicate` only borrows `out`, so the pair can be moved into the result afterwards.
        SourceWithContext::from_source(
            self.delegate
                .filter_map(move |(out, ctx)| predicate(&out).then_some((out, ctx))),
        )
    }

    /// Drops the elements for which `predicate` returns `true`; this is the inverse of [`Self::filter`].
    pub fn filter_not<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
    where
        F: Fn(&Out) -> bool + Send + Sync + 'static,
    {
        self.filter(move |out| !predicate(out))
    }

    /// Maps each element with `f` and keeps only the `Some` results.
    ///
    /// Elements mapped to `None` are dropped together with their context.
    pub fn filter_map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
    {
        SourceWithContext::from_source(
            self.delegate
                .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
        )
    }

    /// Expands each element into zero or more elements with `f`.
    ///
    /// Every produced element is paired with a clone of the originating element's context.
    pub fn map_concat<Next, F, I>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> I + Send + Sync + 'static,
        I: IntoIterator<Item = Next>,
        I::IntoIter: Send + 'static,
        Ctx: Clone,
    {
        // The closure already owns `ctx`. Move it into the iterator and clone it once per
        // produced element, instead of also making an extra copy up front.
        SourceWithContext::from_source(self.delegate.map_concat(move |(out, ctx)| {
            f(out).into_iter().map(move |next| (next, ctx.clone()))
        }))
    }

    /// Transforms each element with the future returned by `f`.
    ///
    /// The work runs through the underlying `Source::map_async` with the given `parallelism`.
    /// Each context travels with its own element's future, so it is re-attached to exactly
    /// that element's result. If a future resolves to `Err`, the error is passed through
    /// unchanged and the context is dropped.
    pub fn map_async<Next, F, Fut>(
        self,
        parallelism: usize,
        f: F,
    ) -> SourceWithContext<Next, Ctx, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = StreamResult<Next>> + Send + 'static,
    {
        SourceWithContext::from_source(self.delegate.map_async(parallelism, move |(out, ctx)| {
            f(out).map(|next| next.map(|next| (next, ctx)))
        }))
    }

    /// Transforms each context with `f` and leaves the element untouched.
    pub fn map_context<CtxOut, F>(self, f: F) -> SourceWithContext<Out, CtxOut, Mat>
    where
        CtxOut: Send + 'static,
        F: Fn(Ctx) -> CtxOut + Send + Sync + 'static,
    {
        SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
    }

    /// Batches consecutive pairs via `Source::grouped(size)`.
    ///
    /// Each batch is split into a vector of elements and a matching vector of contexts that
    /// line up index-for-index.
    pub fn grouped(self, size: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat> {
        SourceWithContext::from_source(self.delegate.grouped(size).map(unzip_pairs))
    }

    /// Produces sliding windows via `Source::sliding(size, step)`.
    ///
    /// Each window is split into a vector of elements and a matching vector of contexts.
    pub fn sliding(self, size: usize, step: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat>
    where
        Out: Clone,
        Ctx: Clone,
    {
        SourceWithContext::from_source(self.delegate.sliding(size, step).map(unzip_pairs))
    }

    /// Attaches a context-aware flow and keeps this source's materialized value.
    pub fn via<Out2, Ctx2, FlowMat>(
        self,
        flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
    ) -> SourceWithContext<Out2, Ctx2, Mat>
    where
        Out2: Send + 'static,
        Ctx2: Send + 'static,
        FlowMat: Send + 'static,
    {
        SourceWithContext::from_source(self.delegate.via(flow.delegate))
    }

    /// Attaches a context-aware flow and combines both materialized values with `combine`.
    pub fn via_mat<Out2, Ctx2, FlowMat, Combined, F>(
        self,
        flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
        combine: F,
    ) -> SourceWithContext<Out2, Ctx2, Combined>
    where
        Out2: Send + 'static,
        Ctx2: Send + 'static,
        FlowMat: Send + 'static,
        Combined: Send + 'static,
        F: Fn(Mat, FlowMat) -> Combined + Send + Sync + 'static,
    {
        SourceWithContext::from_source(self.delegate.via_mat(flow.delegate, combine))
    }

    /// Connects to a sink of `(element, context)` pairs and keeps this source's materialized value.
    pub fn to<SinkMat>(self, sink: Sink<(Out, Ctx), SinkMat>) -> RunnableGraph<Mat>
    where
        SinkMat: Send + 'static,
    {
        self.delegate.to(sink)
    }

    /// Connects to a sink of `(element, context)` pairs and combines both materialized values
    /// with `combine`.
    pub fn to_mat<SinkMat, Combined, F>(
        self,
        sink: Sink<(Out, Ctx), SinkMat>,
        combine: F,
    ) -> RunnableGraph<Combined>
    where
        SinkMat: Send + 'static,
        Combined: Send + 'static,
        F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
    {
        self.delegate.to_mat(sink, combine)
    }
}
impl<In: Send + 'static, CtxIn: Send + 'static> FlowWithContext<In, CtxIn, In, CtxIn, NotUsed> {
    /// Builds a context-aware flow that forwards every `(element, context)` pair unchanged.
    ///
    /// This is the usual starting point for assembling a `FlowWithContext` pipeline.
    pub fn identity() -> Self {
        Self {
            delegate: Flow::identity(),
        }
    }
}
impl<
        In: Send + 'static,
        CtxIn: Send + 'static,
        Out: Send + 'static,
        CtxOut: Send + 'static,
        Mat: Send + 'static,
    > FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
{
    /// Wraps a flow over `(element, context)` pairs so that context-aware operators can be used.
    pub(crate) fn from_flow(
        delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
    ) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat> {
        FlowWithContext { delegate }
    }

    /// Unwraps into the underlying flow of `(element, context)` pairs.
    pub fn as_flow(self) -> Flow<(In, CtxIn), (Out, CtxOut), Mat> {
        self.delegate
    }

    /// Transforms each output element with `f` and leaves its context untouched.
    pub fn map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> Next + Send + Sync + 'static,
    {
        FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
    }

    /// Keeps only the elements for which `predicate` returns `true`.
    ///
    /// When an element is rejected, its context is dropped with it.
    pub fn filter<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
    where
        F: Fn(&Out) -> bool + Send + Sync + 'static,
    {
        FlowWithContext::from_flow(self.delegate.filter(move |(out, _)| predicate(out)))
    }

    /// Drops the elements for which `predicate` returns `true`; this is the inverse of [`Self::filter`].
    pub fn filter_not<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
    where
        F: Fn(&Out) -> bool + Send + Sync + 'static,
    {
        self.filter(move |out| !predicate(out))
    }

    /// Maps each element with `f` and keeps only the `Some` results.
    ///
    /// Elements mapped to `None` are dropped together with their context.
    pub fn filter_map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
    {
        FlowWithContext::from_flow(
            self.delegate
                .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
        )
    }

    /// Expands each element into zero or more elements with `f`.
    ///
    /// Every produced element is paired with a clone of the originating element's context.
    pub fn map_concat<Next, F, I>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> I + Send + Sync + 'static,
        I: IntoIterator<Item = Next>,
        I::IntoIter: Send + 'static,
        CtxOut: Clone,
    {
        // The closure already owns `ctx`. Move it into the iterator and clone it once per
        // produced element, instead of also making an extra copy up front.
        FlowWithContext::from_flow(self.delegate.map_concat(move |(out, ctx)| {
            f(out).into_iter().map(move |next| (next, ctx.clone()))
        }))
    }

    /// Transforms each element with the future returned by `f`.
    ///
    /// The work runs through the underlying `Flow::map_async` with the given `parallelism`.
    /// Each context travels with its own element's future, so it is re-attached to exactly
    /// that element's result. If a future resolves to `Err`, the error is passed through
    /// unchanged and the context is dropped.
    pub fn map_async<Next, F, Fut>(
        self,
        parallelism: usize,
        f: F,
    ) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
    where
        Next: Send + 'static,
        F: Fn(Out) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = StreamResult<Next>> + Send + 'static,
    {
        FlowWithContext::from_flow(self.delegate.map_async(parallelism, move |(out, ctx)| {
            f(out).map(|next| next.map(|next| (next, ctx)))
        }))
    }

    /// Transforms each output context with `f` and leaves the element untouched.
    pub fn map_context<CtxOut2, F>(self, f: F) -> FlowWithContext<In, CtxIn, Out, CtxOut2, Mat>
    where
        CtxOut2: Send + 'static,
        F: Fn(CtxOut) -> CtxOut2 + Send + Sync + 'static,
    {
        FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
    }

    /// Batches consecutive pairs via `Flow::grouped(n)`.
    ///
    /// Each batch is split into a vector of elements and a matching vector of contexts that
    /// line up index-for-index.
    pub fn grouped(self, n: usize) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat> {
        FlowWithContext::from_flow(self.delegate.grouped(n).map(unzip_pairs))
    }

    /// Produces sliding windows via `Flow::sliding(n, step)`.
    ///
    /// Each window is split into a vector of elements and a matching vector of contexts.
    pub fn sliding(
        self,
        n: usize,
        step: usize,
    ) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat>
    where
        Out: Clone,
        CtxOut: Clone,
    {
        FlowWithContext::from_flow(self.delegate.sliding(n, step).map(unzip_pairs))
    }

    /// Appends another context-aware flow and keeps this flow's materialized value.
    pub fn via<Out2, Ctx2, FlowMat>(
        self,
        flow: FlowWithContext<Out, CtxOut, Out2, Ctx2, FlowMat>,
    ) -> FlowWithContext<In, CtxIn, Out2, Ctx2, Mat>
    where
        Out2: Send + 'static,
        Ctx2: Send + 'static,
        FlowMat: Send + 'static,
    {
        FlowWithContext::from_flow(self.delegate.via(flow.delegate))
    }

    /// Connects this flow to `sink`.
    ///
    /// The result is a sink that accepts `(In, CtxIn)` pairs and keeps this flow's
    /// materialized value.
    pub fn to<SinkMat>(self, sink: Sink<(Out, CtxOut), SinkMat>) -> Sink<(In, CtxIn), Mat>
    where
        SinkMat: Send + 'static,
    {
        self.delegate.to(sink)
    }

    /// Connects this flow to `sink` and combines both materialized values with `combine`.
    pub fn to_mat<SinkMat, Combined, F>(
        self,
        sink: Sink<(Out, CtxOut), SinkMat>,
        combine: F,
    ) -> Sink<(In, CtxIn), Combined>
    where
        SinkMat: Send + 'static,
        Combined: Send + 'static,
        F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
    {
        self.delegate.to_mat(sink, combine)
    }
}
/// Splits a batch of `(element, context)` pairs into two parallel vectors, preserving order.
///
/// `grouped` and `sliding` use this so that the element batch and the context batch line up
/// index-for-index. An empty batch yields two empty vectors.
fn unzip_pairs<Out, Ctx>(pairs: Vec<(Out, Ctx)>) -> (Vec<Out>, Vec<Ctx>) {
    // `Iterator::unzip` reserves from the exact size hint of `vec::IntoIter`, so both outputs
    // are allocated once, as in the hand-written loop this replaces.
    pairs.into_iter().unzip()
}
// Unit tests for the context-carrying wrappers. Each test builds a pipeline from
// `Source::from_iter(..).as_source_with_context(..)`, where the closure derives the initial
// context from each element. The element/context pairs are then checked after the operator
// under test has run.
#[cfg(test)]
mod tests {
use super::*;
use std::{thread, time::Duration};
// `map` changes only the element; `filter`/`filter_not` drop whole (element, context) pairs.
#[test]
fn source_with_context_preserves_context_for_map_and_filter() {
let values = Source::from_iter(0_i32..6)
.as_source_with_context(|item| item + 100)
.map(|item| item + 1)
.filter(|item| item % 2 == 0)
.filter_not(|item| *item == 4)
.run_collect()
.unwrap();
assert_eq!(values, vec![(2, 101), (6, 105)]);
}
// Rejected elements take their contexts with them.
// NOTE(review): the name mentions `map_filter`, but only `filter` is exercised here.
#[test]
fn source_with_context_filters_context_with_map_filter() {
let values = Source::from_iter(1_i32..5)
.as_source_with_context(|item| item * 10)
.filter(|item| item % 2 == 1)
.run_collect()
.unwrap();
assert_eq!(values, vec![(1, 10), (3, 30)]);
}
// `map_context` rewrites only the context. The `*item` below suggests the context extractor
// receives the element by reference (confirm against `as_source_with_context`).
#[test]
fn source_with_context_map_context_transform_is_supported() {
let values = Source::from_iter(1_i32..4)
.as_source_with_context(|item| *item)
.map_context(|ctx| ctx + 10)
.run_collect()
.unwrap();
assert_eq!(values, vec![(1, 11), (2, 12), (3, 13)]);
}
// Every element produced by `map_concat` carries a copy of its origin's context.
#[test]
fn source_with_context_map_concat_duplicates_context() {
let values = Source::from_iter([1_i32, 2])
.as_source_with_context(|item| item + 10)
.map_concat(|item| vec![item + 1, item + 2])
.run_collect()
.unwrap();
assert_eq!(values, vec![(2, 11), (3, 11), (3, 12), (4, 12)]);
}
// `grouped` and `sliding` emit an element vector plus an index-aligned context vector.
// The expected output includes a trailing partial group/window.
#[test]
fn source_with_context_groups_context_vectors_with_grouped_and_sliding() {
let grouped = Source::from_iter(1_i32..4)
.as_source_with_context(|item| item + 10)
.grouped(2)
.run_collect()
.unwrap();
assert_eq!(
grouped,
vec![(vec![1, 2], vec![11, 12]), (vec![3], vec![13])]
);
let sliding = Source::from_iter([1_i32, 2, 3, 4])
.as_source_with_context(|item| item + 10)
.sliding(3, 2)
.run_collect()
.unwrap();
assert_eq!(
sliding,
vec![
(vec![1, 2, 3], vec![11, 12, 13]),
(vec![3, 4], vec![13, 14])
]
);
}
// Even elements sleep longer than odd ones, yet the expected output is in input order with
// each context still attached to its own element.
// NOTE(review): `thread::sleep` inside an `async` block blocks whichever thread polls the
// future. Whether completions really happen out of order depends on how `map_async`
// drives futures — confirm this test exercises what its name claims.
#[test]
fn source_with_context_map_async_keeps_context_with_out_of_order_completions() {
let values = Source::from_iter([3_i32, 1, 2, 0])
.as_source_with_context(|item| item + 100)
.map_async(2, |item| async move {
if item % 2 == 0 {
thread::sleep(Duration::from_millis(20));
} else {
thread::sleep(Duration::from_millis(2));
}
Ok(item * 2)
})
.run_collect()
.unwrap();
assert_eq!(values, vec![(6, 103), (2, 101), (4, 102), (0, 100)]);
}
// Non-`Copy` elements and contexts. Applying filter+map directly on the source must match
// applying the same operators through a `FlowWithContext` attached with `via`.
#[test]
fn source_with_context_filter_and_map_with_string_elements() {
let input = ["a".to_string(), "bravo".to_string(), "charlie".to_string()];
let source_values = Source::from_iter(input.to_vec())
.as_source_with_context(|item| format!("ctx-{item}"))
.filter(|item| item.len() >= 5)
.map(|item| format!("mapped:{item}"))
.run_collect()
.unwrap();
assert_eq!(
source_values,
vec![
("mapped:bravo".to_string(), "ctx-bravo".to_string()),
("mapped:charlie".to_string(), "ctx-charlie".to_string()),
]
);
let flow_values = Source::from_iter(input)
.as_source_with_context(|item| format!("ctx-{item}"))
.via(
FlowWithContext::<String, String, String, String, NotUsed>::identity()
.filter(|item| item.len() >= 5)
.map(|item| format!("mapped:{item}")),
)
.run_collect()
.unwrap();
assert_eq!(flow_values, source_values);
}
// A flow built from a raw pair-level `Flow` can rewrite both element and context.
// `to_mat(.., |_, mat| mat)` keeps the sink's materialized value, whose `wait()` yields the
// collected pairs.
#[test]
fn source_with_context_via_context_flow_and_to_mat() {
let flow = FlowWithContext::<i32, i32, i32, i32, NotUsed>::from_flow(
Flow::identity().map(|(value, ctx)| (value * 2, ctx * 2)),
);
let sink = Sink::collect();
let completion = Source::from_iter([1_i32, 2, 3])
.as_source_with_context(|item| item + 10)
.via(flow)
.to_mat(sink, |_, mat| mat)
.run()
.unwrap();
assert_eq!(completion.wait().unwrap(), vec![(2, 22), (4, 24), (6, 26)]);
}
// `to` keeps the source's materialized value (`NotUsed` here), while `to_mat` can select
// the sink's value instead. The source is cloned so that it can be run twice.
#[test]
fn source_with_context_as_source_to_runs_and_to_mat_work() {
let source = Source::from_iter([1_i32, 2, 3]).as_source_with_context(|item| item + 10);
assert_eq!(source.clone().to(Sink::ignore()).run(), Ok(NotUsed));
let pair_sink = source
.to_mat(Sink::collect(), |_, pairs| pairs)
.run()
.unwrap();
assert_eq!(pair_sink.wait().unwrap(), vec![(1, 11), (2, 12), (3, 13)]);
}
}