Skip to main content

datum/
context.rs

1//! Context-aware stream wrappers.
2//!
3//! `SourceWithContext` and `FlowWithContext` are thin wrappers around regular pair
4//! streams that keep an explicit context element adjacent to each data element.
5//!
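//! The operator set mirrors Akka's `SourceWithContext` / `FlowWithContext`: only
//! context-preserving operators are exposed, so `(data, context)` alignment cannot be
//! broken through this API. Every emitted element carries exactly the context of the
//! input element(s) it was derived from: `filter`/`filter_map` drop a data element
//! together with its context, `map_concat` copies the context onto each expanded
//! element, and `grouped`/`sliding` collect data and contexts into index-aligned vectors.
//!
//! This is intentional. Reordering elements, or dropping/expanding data independently of
//! its context, could detach a context from its intended element and silently break
//! downstream sinks that rely on commit/trace/offset semantics. Keeping the API
//! wrapper-centric preserves this contract by construction while still enabling
//! downstream composition; `as_source`/`as_flow` are the explicit opt-out.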
15
16use futures::future::FutureExt;
17use std::future::Future;
18
19use crate::stream::{Flow, NotUsed, RunnableGraph, Sink, Source, StreamResult};
20
21/// A `Source` whose elements each carry an adjacent context value (mirrors Akka's
22/// `SourceWithContext`). Exposes only context-preserving operators; convert via
23/// [`SourceWithContext::as_source`].
24#[derive(Clone)]
25pub struct SourceWithContext<Out, Ctx, Mat = NotUsed> {
26    pub(crate) delegate: Source<(Out, Ctx), Mat>,
27}
28
29/// A `Flow` whose input and output elements each carry an adjacent context value (mirrors Akka's
30/// `FlowWithContext`). Exposes only context-preserving operators; convert via
31/// [`FlowWithContext::as_flow`].
32#[derive(Clone)]
33pub struct FlowWithContext<In, CtxIn, Out, CtxOut, Mat = NotUsed> {
34    pub(crate) delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
35}
36
37impl<Out: Send + 'static, Ctx: Send + 'static, Mat: Send + 'static>
38    SourceWithContext<Out, Ctx, Mat>
39{
40    pub(crate) fn from_source(delegate: Source<(Out, Ctx), Mat>) -> Self {
41        Self { delegate }
42    }
43
44    pub fn as_source(self) -> Source<(Out, Ctx), Mat> {
45        self.delegate
46    }
47
48    pub fn run_collect(self) -> StreamResult<Vec<(Out, Ctx)>> {
49        self.delegate.run_collect()
50    }
51
52    pub fn map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
53    where
54        Next: Send + 'static,
55        F: Fn(Out) -> Next + Send + Sync + 'static,
56    {
57        SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
58    }
59
60    pub fn filter<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
61    where
62        F: Fn(&Out) -> bool + Send + Sync + 'static,
63    {
64        SourceWithContext::from_source(self.delegate.filter_map(move |(out, ctx)| {
65            if predicate(&out) {
66                Some((out, ctx))
67            } else {
68                None
69            }
70        }))
71    }
72
73    pub fn filter_not<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
74    where
75        F: Fn(&Out) -> bool + Send + Sync + 'static,
76    {
77        self.filter(move |out| !predicate(out))
78    }
79
80    pub fn filter_map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
81    where
82        Next: Send + 'static,
83        F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
84    {
85        SourceWithContext::from_source(
86            self.delegate
87                .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
88        )
89    }
90
91    pub fn map_concat<Next, F, I>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
92    where
93        Next: Send + 'static,
94        F: Fn(Out) -> I + Send + Sync + 'static,
95        I: IntoIterator<Item = Next>,
96        I::IntoIter: Send + 'static,
97        Ctx: Clone,
98    {
99        SourceWithContext::from_source(self.delegate.map_concat(move |(out, ctx)| {
100            let ctx = ctx.clone();
101            f(out).into_iter().map(move |next| (next, ctx.clone()))
102        }))
103    }
104
105    pub fn map_async<Next, F, Fut>(
106        self,
107        parallelism: usize,
108        f: F,
109    ) -> SourceWithContext<Next, Ctx, Mat>
110    where
111        Next: Send + 'static,
112        F: Fn(Out) -> Fut + Send + Sync + 'static,
113        Fut: Future<Output = StreamResult<Next>> + Send + 'static,
114    {
115        SourceWithContext::from_source(self.delegate.map_async(parallelism, move |(out, ctx)| {
116            f(out).map(|next| next.map(|next| (next, ctx)))
117        }))
118    }
119
120    pub fn map_context<CtxOut, F>(self, f: F) -> SourceWithContext<Out, CtxOut, Mat>
121    where
122        CtxOut: Send + 'static,
123        F: Fn(Ctx) -> CtxOut + Send + Sync + 'static,
124    {
125        SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
126    }
127
128    pub fn grouped(self, size: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat> {
129        SourceWithContext::from_source(self.delegate.grouped(size).map(unzip_pairs))
130    }
131
132    pub fn sliding(self, size: usize, step: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat>
133    where
134        Out: Clone,
135        Ctx: Clone,
136    {
137        SourceWithContext::from_source(self.delegate.sliding(size, step).map(unzip_pairs))
138    }
139
140    pub fn via<Out2, Ctx2, FlowMat>(
141        self,
142        flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
143    ) -> SourceWithContext<Out2, Ctx2, Mat>
144    where
145        Out2: Send + 'static,
146        Ctx2: Send + 'static,
147        FlowMat: Send + 'static,
148    {
149        SourceWithContext::from_source(self.delegate.via(flow.delegate))
150    }
151
152    pub fn via_mat<Out2, Ctx2, FlowMat, Combined, F>(
153        self,
154        flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
155        combine: F,
156    ) -> SourceWithContext<Out2, Ctx2, Combined>
157    where
158        Out2: Send + 'static,
159        Ctx2: Send + 'static,
160        FlowMat: Send + 'static,
161        Combined: Send + 'static,
162        F: Fn(Mat, FlowMat) -> Combined + Send + Sync + 'static,
163    {
164        SourceWithContext::from_source(self.delegate.via_mat(flow.delegate, combine))
165    }
166
167    pub fn to<SinkMat>(self, sink: Sink<(Out, Ctx), SinkMat>) -> RunnableGraph<Mat>
168    where
169        SinkMat: Send + 'static,
170    {
171        self.delegate.to(sink)
172    }
173
174    pub fn to_mat<SinkMat, Combined, F>(
175        self,
176        sink: Sink<(Out, Ctx), SinkMat>,
177        combine: F,
178    ) -> RunnableGraph<Combined>
179    where
180        SinkMat: Send + 'static,
181        Combined: Send + 'static,
182        F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
183    {
184        self.delegate.to_mat(sink, combine)
185    }
186}
187
188impl<In: Send + 'static, CtxIn: Send + 'static> FlowWithContext<In, CtxIn, In, CtxIn, NotUsed> {
189    pub fn identity() -> Self {
190        FlowWithContext::from_flow(Flow::identity())
191    }
192}
193
194impl<
195    In: Send + 'static,
196    CtxIn: Send + 'static,
197    Out: Send + 'static,
198    CtxOut: Send + 'static,
199    Mat: Send + 'static,
200> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
201{
202    pub(crate) fn from_flow(
203        delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
204    ) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat> {
205        FlowWithContext { delegate }
206    }
207
208    pub fn as_flow(self) -> Flow<(In, CtxIn), (Out, CtxOut), Mat> {
209        self.delegate
210    }
211
212    pub fn map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
213    where
214        Next: Send + 'static,
215        F: Fn(Out) -> Next + Send + Sync + 'static,
216    {
217        FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
218    }
219
220    pub fn filter<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
221    where
222        F: Fn(&Out) -> bool + Send + Sync + 'static,
223    {
224        FlowWithContext::from_flow(self.delegate.filter(move |(out, _)| predicate(out)))
225    }
226
227    pub fn filter_not<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
228    where
229        F: Fn(&Out) -> bool + Send + Sync + 'static,
230    {
231        self.filter(move |out| !predicate(out))
232    }
233
234    pub fn filter_map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
235    where
236        Next: Send + 'static,
237        F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
238    {
239        FlowWithContext::from_flow(
240            self.delegate
241                .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
242        )
243    }
244
245    pub fn map_concat<Next, F, I>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
246    where
247        Next: Send + 'static,
248        F: Fn(Out) -> I + Send + Sync + 'static,
249        I: IntoIterator<Item = Next>,
250        I::IntoIter: Send + 'static,
251        CtxOut: Clone,
252    {
253        FlowWithContext::from_flow(self.delegate.map_concat(move |(out, ctx)| {
254            let ctx = ctx.clone();
255            f(out).into_iter().map(move |next| (next, ctx.clone()))
256        }))
257    }
258
259    pub fn map_async<Next, F, Fut>(
260        self,
261        parallelism: usize,
262        f: F,
263    ) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
264    where
265        Next: Send + 'static,
266        F: Fn(Out) -> Fut + Send + Sync + 'static,
267        Fut: Future<Output = StreamResult<Next>> + Send + 'static,
268    {
269        FlowWithContext::from_flow(self.delegate.map_async(parallelism, move |(out, ctx)| {
270            f(out).map(|next| next.map(|next| (next, ctx)))
271        }))
272    }
273
274    pub fn map_context<CtxOut2, F>(self, f: F) -> FlowWithContext<In, CtxIn, Out, CtxOut2, Mat>
275    where
276        CtxOut2: Send + 'static,
277        F: Fn(CtxOut) -> CtxOut2 + Send + Sync + 'static,
278    {
279        FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
280    }
281
282    pub fn grouped(self, n: usize) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat> {
283        FlowWithContext::from_flow(self.delegate.grouped(n).map(unzip_pairs))
284    }
285
286    pub fn sliding(
287        self,
288        n: usize,
289        step: usize,
290    ) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat>
291    where
292        Out: Clone,
293        CtxOut: Clone,
294    {
295        FlowWithContext::from_flow(self.delegate.sliding(n, step).map(unzip_pairs))
296    }
297
298    pub fn via<Out2, Ctx2, FlowMat>(
299        self,
300        flow: FlowWithContext<Out, CtxOut, Out2, Ctx2, FlowMat>,
301    ) -> FlowWithContext<In, CtxIn, Out2, Ctx2, Mat>
302    where
303        Out2: Send + 'static,
304        Ctx2: Send + 'static,
305        FlowMat: Send + 'static,
306    {
307        FlowWithContext::from_flow(self.delegate.via(flow.delegate))
308    }
309
310    pub fn to<SinkMat>(self, sink: Sink<(Out, CtxOut), SinkMat>) -> Sink<(In, CtxIn), Mat>
311    where
312        SinkMat: Send + 'static,
313    {
314        self.delegate.to(sink)
315    }
316
317    pub fn to_mat<SinkMat, Combined, F>(
318        self,
319        sink: Sink<(Out, CtxOut), SinkMat>,
320        combine: F,
321    ) -> Sink<(In, CtxIn), Combined>
322    where
323        SinkMat: Send + 'static,
324        Combined: Send + 'static,
325        F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
326    {
327        self.delegate.to_mat(sink, combine)
328    }
329}
330
/// Splits a batch of `(data, context)` pairs into two index-aligned vectors.
///
/// Position `i` of the data vector and position `i` of the context vector always come
/// from the same input pair, so grouped data keeps its contexts in the same order.
fn unzip_pairs<Out, Ctx>(pairs: Vec<(Out, Ctx)>) -> (Vec<Out>, Vec<Ctx>) {
    pairs.into_iter().unzip()
}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, time::Duration};

    #[test]
    fn source_with_context_preserves_context_for_map_and_filter() {
        let values = Source::from_iter(0_i32..6)
            .as_source_with_context(|item| item + 100)
            .map(|item| item + 1)
            .filter(|item| item % 2 == 0)
            .filter_not(|item| *item == 4)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(2, 101), (6, 105)]);
    }

    // NOTE(review): despite the name, this only exercises `filter`.
    #[test]
    fn source_with_context_filters_context_with_map_filter() {
        let values = Source::from_iter(1_i32..5)
            .as_source_with_context(|item| item * 10)
            .filter(|item| item % 2 == 1)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(1, 10), (3, 30)]);
    }

    #[test]
    fn source_with_context_filter_map_drops_context_with_element() {
        let values = Source::from_iter(1_i32..6)
            .as_source_with_context(|item| item * 10)
            .filter_map(|item| (item % 2 == 1).then(|| item.to_string()))
            .run_collect()
            .unwrap();

        assert_eq!(
            values,
            vec![
                ("1".to_string(), 10),
                ("3".to_string(), 30),
                ("5".to_string(), 50),
            ]
        );
    }

    #[test]
    fn source_with_context_map_context_transform_is_supported() {
        let values = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| *item)
            .map_context(|ctx| ctx + 10)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(1, 11), (2, 12), (3, 13)]);
    }

    #[test]
    fn source_with_context_map_concat_duplicates_context() {
        let values = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item + 10)
            .map_concat(|item| vec![item + 1, item + 2])
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(2, 11), (3, 11), (3, 12), (4, 12)]);
    }

    #[test]
    fn source_with_context_groups_context_vectors_with_grouped_and_sliding() {
        let grouped = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| item + 10)
            .grouped(2)
            .run_collect()
            .unwrap();

        assert_eq!(
            grouped,
            vec![(vec![1, 2], vec![11, 12]), (vec![3], vec![13])]
        );

        let sliding = Source::from_iter([1_i32, 2, 3, 4])
            .as_source_with_context(|item| item + 10)
            .sliding(3, 2)
            .run_collect()
            .unwrap();

        assert_eq!(
            sliding,
            vec![
                (vec![1, 2, 3], vec![11, 12, 13]),
                (vec![3, 4], vec![13, 14])
            ]
        );
    }

    // NOTE(review): `thread::sleep` blocks whichever thread polls the future, so whether
    // completions really happen out of order depends on the executor behind `map_async`.
    // The assertion still pins the key property: contexts stay with their elements and
    // output order matches input order.
    #[test]
    fn source_with_context_map_async_keeps_context_with_out_of_order_completions() {
        let values = Source::from_iter([3_i32, 1, 2, 0])
            .as_source_with_context(|item| item + 100)
            .map_async(2, |item| async move {
                if item % 2 == 0 {
                    thread::sleep(Duration::from_millis(20));
                } else {
                    thread::sleep(Duration::from_millis(2));
                }
                Ok(item * 2)
            })
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(6, 103), (2, 101), (4, 102), (0, 100)]);
    }

    #[test]
    fn source_with_context_filter_and_map_with_string_elements() {
        let input = ["a".to_string(), "bravo".to_string(), "charlie".to_string()];

        let source_values = Source::from_iter(input.to_vec())
            .as_source_with_context(|item| format!("ctx-{item}"))
            .filter(|item| item.len() >= 5)
            .map(|item| format!("mapped:{item}"))
            .run_collect()
            .unwrap();

        assert_eq!(
            source_values,
            vec![
                ("mapped:bravo".to_string(), "ctx-bravo".to_string()),
                ("mapped:charlie".to_string(), "ctx-charlie".to_string()),
            ]
        );

        let flow_values = Source::from_iter(input)
            .as_source_with_context(|item| format!("ctx-{item}"))
            .via(
                FlowWithContext::<String, String, String, String, NotUsed>::identity()
                    .filter(|item| item.len() >= 5)
                    .map(|item| format!("mapped:{item}")),
            )
            .run_collect()
            .unwrap();

        assert_eq!(flow_values, source_values);
    }

    #[test]
    fn flow_with_context_map_concat_and_map_context_keep_alignment() {
        let flow = FlowWithContext::<i32, i32, i32, i32, NotUsed>::identity()
            .map_concat(|item| vec![item, item * 10])
            .map_context(|ctx| format!("ctx-{ctx}"));

        let values = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item + 100)
            .via(flow)
            .run_collect()
            .unwrap();

        assert_eq!(
            values,
            vec![
                (1, "ctx-101".to_string()),
                (10, "ctx-101".to_string()),
                (2, "ctx-102".to_string()),
                (20, "ctx-102".to_string()),
            ]
        );
    }

    #[test]
    fn flow_with_context_grouped_collects_context_vectors() {
        let values = Source::from_iter(1_i32..6)
            .as_source_with_context(|item| item * 10)
            .via(FlowWithContext::<i32, i32, i32, i32, NotUsed>::identity().grouped(2))
            .run_collect()
            .unwrap();

        assert_eq!(
            values,
            vec![
                (vec![1, 2], vec![10, 20]),
                (vec![3, 4], vec![30, 40]),
                (vec![5], vec![50]),
            ]
        );
    }

    #[test]
    fn source_with_context_via_mat_keeps_pairs_aligned() {
        let values = Source::from_iter([1_i32, 2, 3])
            .as_source_with_context(|item| item + 10)
            .via_mat(
                FlowWithContext::<i32, i32, i32, i32, NotUsed>::identity().map(|item| item * 3),
                |left, _| left,
            )
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(3, 11), (6, 12), (9, 13)]);
    }

    #[test]
    fn source_with_context_via_context_flow_and_to_mat() {
        let flow = FlowWithContext::<i32, i32, i32, i32, NotUsed>::from_flow(
            Flow::identity().map(|(value, ctx)| (value * 2, ctx * 2)),
        );

        let sink = Sink::collect();
        let completion = Source::from_iter([1_i32, 2, 3])
            .as_source_with_context(|item| item + 10)
            .via(flow)
            .to_mat(sink, |_, mat| mat)
            .run()
            .unwrap();

        assert_eq!(completion.wait().unwrap(), vec![(2, 22), (4, 24), (6, 26)]);
    }

    #[test]
    fn source_with_context_as_source_to_runs_and_to_mat_work() {
        let source = Source::from_iter([1_i32, 2, 3]).as_source_with_context(|item| item + 10);

        assert_eq!(source.clone().to(Sink::ignore()).run(), Ok(NotUsed));

        let pair_sink = source
            .to_mat(Sink::collect(), |_, pairs| pairs)
            .run()
            .unwrap();
        assert_eq!(pair_sink.wait().unwrap(), vec![(1, 11), (2, 12), (3, 13)]);
    }
}