Skip to main content

datum/
context.rs

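//! Context-aware stream wrappers.
//!
//! `SourceWithContext` and `FlowWithContext` are thin wrappers around regular pair
//! streams that keep an explicit context element adjacent to each data element.
//!
//! The operator set mirrors Akka's `SourceWithContext` / `FlowWithContext`. The API
//! only exposes operators that keep every context attached to the element(s) derived
//! from it, so `(data, context)` alignment cannot be broken through this API. These are:
//!
//! - one-to-one transforms: `map`, `map_async`, `map_context`;
//! - one-to-zero filters that drop a context together with its element: `filter`,
//!   `filter_not`, `filter_map`;
//! - one-to-many expansion that repeats the context on every produced element:
//!   `map_concat`;
//! - batching that gathers contexts into a vector parallel to the batched data:
//!   `grouped`, `sliding`.
//!
//! This is intentional. Some operators could detach a context from its intended
//! element:
//!
//! - reordering elements;
//! - dropping contexts independently of their elements;
//! - expanding an element without carrying its context along.
//!
//! That can silently break downstream sinks that rely on commit/trace/offset
//! semantics. Keeping the API wrapper-centric preserves this contract by construction
//! while still enabling downstream composition.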
15
16use futures::future::FutureExt;
17use std::future::Future;
18
19use crate::stream::{Flow, NotUsed, RunnableGraph, Sink, Source, StreamResult};
20
21#[derive(Clone)]
22pub struct SourceWithContext<Out, Ctx, Mat = NotUsed> {
23    pub(crate) delegate: Source<(Out, Ctx), Mat>,
24}
25
26#[derive(Clone)]
27pub struct FlowWithContext<In, CtxIn, Out, CtxOut, Mat = NotUsed> {
28    pub(crate) delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
29}
30
31impl<Out: Send + 'static, Ctx: Send + 'static, Mat: Send + 'static>
32    SourceWithContext<Out, Ctx, Mat>
33{
34    pub(crate) fn from_source(delegate: Source<(Out, Ctx), Mat>) -> Self {
35        Self { delegate }
36    }
37
38    pub fn as_source(self) -> Source<(Out, Ctx), Mat> {
39        self.delegate
40    }
41
42    pub fn run_collect(self) -> StreamResult<Vec<(Out, Ctx)>> {
43        self.delegate.run_collect()
44    }
45
46    pub fn map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
47    where
48        Next: Send + 'static,
49        F: Fn(Out) -> Next + Send + Sync + 'static,
50    {
51        SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
52    }
53
54    pub fn filter<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
55    where
56        F: Fn(&Out) -> bool + Send + Sync + 'static,
57    {
58        SourceWithContext::from_source(self.delegate.filter_map(move |(out, ctx)| {
59            if predicate(&out) {
60                Some((out, ctx))
61            } else {
62                None
63            }
64        }))
65    }
66
67    pub fn filter_not<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
68    where
69        F: Fn(&Out) -> bool + Send + Sync + 'static,
70    {
71        self.filter(move |out| !predicate(out))
72    }
73
74    pub fn filter_map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
75    where
76        Next: Send + 'static,
77        F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
78    {
79        SourceWithContext::from_source(
80            self.delegate
81                .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
82        )
83    }
84
85    pub fn map_concat<Next, F, I>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
86    where
87        Next: Send + 'static,
88        F: Fn(Out) -> I + Send + Sync + 'static,
89        I: IntoIterator<Item = Next>,
90        I::IntoIter: Send + 'static,
91        Ctx: Clone,
92    {
93        SourceWithContext::from_source(self.delegate.map_concat(move |(out, ctx)| {
94            let ctx = ctx.clone();
95            f(out).into_iter().map(move |next| (next, ctx.clone()))
96        }))
97    }
98
99    pub fn map_async<Next, F, Fut>(
100        self,
101        parallelism: usize,
102        f: F,
103    ) -> SourceWithContext<Next, Ctx, Mat>
104    where
105        Next: Send + 'static,
106        F: Fn(Out) -> Fut + Send + Sync + 'static,
107        Fut: Future<Output = StreamResult<Next>> + Send + 'static,
108    {
109        SourceWithContext::from_source(self.delegate.map_async(parallelism, move |(out, ctx)| {
110            f(out).map(|next| next.map(|next| (next, ctx)))
111        }))
112    }
113
114    pub fn map_context<CtxOut, F>(self, f: F) -> SourceWithContext<Out, CtxOut, Mat>
115    where
116        CtxOut: Send + 'static,
117        F: Fn(Ctx) -> CtxOut + Send + Sync + 'static,
118    {
119        SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
120    }
121
122    pub fn grouped(self, size: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat> {
123        SourceWithContext::from_source(self.delegate.grouped(size).map(unzip_pairs))
124    }
125
126    pub fn sliding(self, size: usize, step: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat>
127    where
128        Out: Clone,
129        Ctx: Clone,
130    {
131        SourceWithContext::from_source(self.delegate.sliding(size, step).map(unzip_pairs))
132    }
133
134    pub fn via<Out2, Ctx2, FlowMat>(
135        self,
136        flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
137    ) -> SourceWithContext<Out2, Ctx2, Mat>
138    where
139        Out2: Send + 'static,
140        Ctx2: Send + 'static,
141        FlowMat: Send + 'static,
142    {
143        SourceWithContext::from_source(self.delegate.via(flow.delegate))
144    }
145
146    pub fn via_mat<Out2, Ctx2, FlowMat, Combined, F>(
147        self,
148        flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
149        combine: F,
150    ) -> SourceWithContext<Out2, Ctx2, Combined>
151    where
152        Out2: Send + 'static,
153        Ctx2: Send + 'static,
154        FlowMat: Send + 'static,
155        Combined: Send + 'static,
156        F: Fn(Mat, FlowMat) -> Combined + Send + Sync + 'static,
157    {
158        SourceWithContext::from_source(self.delegate.via_mat(flow.delegate, combine))
159    }
160
161    pub fn to<SinkMat>(self, sink: Sink<(Out, Ctx), SinkMat>) -> RunnableGraph<Mat>
162    where
163        SinkMat: Send + 'static,
164    {
165        self.delegate.to(sink)
166    }
167
168    pub fn to_mat<SinkMat, Combined, F>(
169        self,
170        sink: Sink<(Out, Ctx), SinkMat>,
171        combine: F,
172    ) -> RunnableGraph<Combined>
173    where
174        SinkMat: Send + 'static,
175        Combined: Send + 'static,
176        F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
177    {
178        self.delegate.to_mat(sink, combine)
179    }
180}
181
182impl<In: Send + 'static, CtxIn: Send + 'static> FlowWithContext<In, CtxIn, In, CtxIn, NotUsed> {
183    pub fn identity() -> Self {
184        FlowWithContext::from_flow(Flow::identity())
185    }
186}
187
188impl<
189    In: Send + 'static,
190    CtxIn: Send + 'static,
191    Out: Send + 'static,
192    CtxOut: Send + 'static,
193    Mat: Send + 'static,
194> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
195{
196    pub(crate) fn from_flow(
197        delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
198    ) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat> {
199        FlowWithContext { delegate }
200    }
201
202    pub fn as_flow(self) -> Flow<(In, CtxIn), (Out, CtxOut), Mat> {
203        self.delegate
204    }
205
206    pub fn map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
207    where
208        Next: Send + 'static,
209        F: Fn(Out) -> Next + Send + Sync + 'static,
210    {
211        FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
212    }
213
214    pub fn filter<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
215    where
216        F: Fn(&Out) -> bool + Send + Sync + 'static,
217    {
218        FlowWithContext::from_flow(self.delegate.filter(move |(out, _)| predicate(out)))
219    }
220
221    pub fn filter_not<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
222    where
223        F: Fn(&Out) -> bool + Send + Sync + 'static,
224    {
225        self.filter(move |out| !predicate(out))
226    }
227
228    pub fn filter_map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
229    where
230        Next: Send + 'static,
231        F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
232    {
233        FlowWithContext::from_flow(
234            self.delegate
235                .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
236        )
237    }
238
239    pub fn map_concat<Next, F, I>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
240    where
241        Next: Send + 'static,
242        F: Fn(Out) -> I + Send + Sync + 'static,
243        I: IntoIterator<Item = Next>,
244        I::IntoIter: Send + 'static,
245        CtxOut: Clone,
246    {
247        FlowWithContext::from_flow(self.delegate.map_concat(move |(out, ctx)| {
248            let ctx = ctx.clone();
249            f(out).into_iter().map(move |next| (next, ctx.clone()))
250        }))
251    }
252
253    pub fn map_async<Next, F, Fut>(
254        self,
255        parallelism: usize,
256        f: F,
257    ) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
258    where
259        Next: Send + 'static,
260        F: Fn(Out) -> Fut + Send + Sync + 'static,
261        Fut: Future<Output = StreamResult<Next>> + Send + 'static,
262    {
263        FlowWithContext::from_flow(self.delegate.map_async(parallelism, move |(out, ctx)| {
264            f(out).map(|next| next.map(|next| (next, ctx)))
265        }))
266    }
267
268    pub fn map_context<CtxOut2, F>(self, f: F) -> FlowWithContext<In, CtxIn, Out, CtxOut2, Mat>
269    where
270        CtxOut2: Send + 'static,
271        F: Fn(CtxOut) -> CtxOut2 + Send + Sync + 'static,
272    {
273        FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
274    }
275
276    pub fn grouped(self, n: usize) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat> {
277        FlowWithContext::from_flow(self.delegate.grouped(n).map(unzip_pairs))
278    }
279
280    pub fn sliding(
281        self,
282        n: usize,
283        step: usize,
284    ) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat>
285    where
286        Out: Clone,
287        CtxOut: Clone,
288    {
289        FlowWithContext::from_flow(self.delegate.sliding(n, step).map(unzip_pairs))
290    }
291
292    pub fn via<Out2, Ctx2, FlowMat>(
293        self,
294        flow: FlowWithContext<Out, CtxOut, Out2, Ctx2, FlowMat>,
295    ) -> FlowWithContext<In, CtxIn, Out2, Ctx2, Mat>
296    where
297        Out2: Send + 'static,
298        Ctx2: Send + 'static,
299        FlowMat: Send + 'static,
300    {
301        FlowWithContext::from_flow(self.delegate.via(flow.delegate))
302    }
303
304    pub fn to<SinkMat>(self, sink: Sink<(Out, CtxOut), SinkMat>) -> Sink<(In, CtxIn), Mat>
305    where
306        SinkMat: Send + 'static,
307    {
308        self.delegate.to(sink)
309    }
310
311    pub fn to_mat<SinkMat, Combined, F>(
312        self,
313        sink: Sink<(Out, CtxOut), SinkMat>,
314        combine: F,
315    ) -> Sink<(In, CtxIn), Combined>
316    where
317        SinkMat: Send + 'static,
318        Combined: Send + 'static,
319        F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
320    {
321        self.delegate.to_mat(sink, combine)
322    }
323}
324
/// Splits a batch of `(data, context)` pairs into two index-aligned vectors.
///
/// The context at position `i` of the second vector belongs to the data element
/// at position `i` of the first.
fn unzip_pairs<Out, Ctx>(pairs: Vec<(Out, Ctx)>) -> (Vec<Out>, Vec<Ctx>) {
    pairs.into_iter().unzip()
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, time::Duration};

    #[test]
    fn source_with_context_preserves_context_for_map_and_filter() {
        let values = Source::from_iter(0_i32..6)
            .as_source_with_context(|item| item + 100)
            .map(|item| item + 1)
            .filter(|item| item % 2 == 0)
            .filter_not(|item| *item == 4)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(2, 101), (6, 105)]);
    }

    #[test]
    fn source_with_context_filter_drops_context_of_rejected_elements() {
        let values = Source::from_iter(1_i32..5)
            .as_source_with_context(|item| item * 10)
            .filter(|item| item % 2 == 1)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(1, 10), (3, 30)]);
    }

    #[test]
    fn source_with_context_filter_map_keeps_context_of_kept_elements() {
        let values = Source::from_iter(1_i32..6)
            .as_source_with_context(|item| item * 10)
            .filter_map(|item| if item % 2 == 1 { Some(item * 100) } else { None })
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(100, 10), (300, 30), (500, 50)]);
    }

    #[test]
    fn source_with_context_map_context_transform_is_supported() {
        let values = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| *item)
            .map_context(|ctx| ctx + 10)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(1, 11), (2, 12), (3, 13)]);
    }

    #[test]
    fn source_with_context_map_concat_duplicates_context() {
        let values = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item + 10)
            .map_concat(|item| vec![item + 1, item + 2])
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(2, 11), (3, 11), (3, 12), (4, 12)]);
    }

    #[test]
    fn source_with_context_groups_context_vectors_with_grouped_and_sliding() {
        let grouped = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| item + 10)
            .grouped(2)
            .run_collect()
            .unwrap();

        assert_eq!(
            grouped,
            vec![(vec![1, 2], vec![11, 12]), (vec![3], vec![13])]
        );

        let sliding = Source::from_iter([1_i32, 2, 3, 4])
            .as_source_with_context(|item| item + 10)
            .sliding(3, 2)
            .run_collect()
            .unwrap();

        assert_eq!(
            sliding,
            vec![
                (vec![1, 2, 3], vec![11, 12, 13]),
                (vec![3, 4], vec![13, 14])
            ]
        );
    }

    #[test]
    fn source_with_context_map_async_keeps_context_with_out_of_order_completions() {
        // NOTE(review): `thread::sleep` blocks whichever thread polls these futures.
        // Whether completions actually interleave depends on how `map_async` drives
        // them. Confirm this test really exercises out-of-order completion.
        let values = Source::from_iter([3_i32, 1, 2, 0])
            .as_source_with_context(|item| item + 100)
            .map_async(2, |item| async move {
                if item % 2 == 0 {
                    thread::sleep(Duration::from_millis(20));
                } else {
                    thread::sleep(Duration::from_millis(2));
                }
                Ok(item * 2)
            })
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(6, 103), (2, 101), (4, 102), (0, 100)]);
    }

    #[test]
    fn source_with_context_filter_and_map_with_string_elements() {
        let input = ["a".to_string(), "bravo".to_string(), "charlie".to_string()];

        let source_values = Source::from_iter(input.to_vec())
            .as_source_with_context(|item| format!("ctx-{item}"))
            .filter(|item| item.len() >= 5)
            .map(|item| format!("mapped:{item}"))
            .run_collect()
            .unwrap();

        assert_eq!(
            source_values,
            vec![
                ("mapped:bravo".to_string(), "ctx-bravo".to_string()),
                ("mapped:charlie".to_string(), "ctx-charlie".to_string()),
            ]
        );

        let flow_values = Source::from_iter(input)
            .as_source_with_context(|item| format!("ctx-{item}"))
            .via(
                FlowWithContext::<String, String, String, String, NotUsed>::identity()
                    .filter(|item| item.len() >= 5)
                    .map(|item| format!("mapped:{item}")),
            )
            .run_collect()
            .unwrap();

        assert_eq!(flow_values, source_values);
    }

    #[test]
    fn flow_with_context_map_concat_map_context_and_grouped_keep_alignment() {
        let flow = FlowWithContext::<i32, i32, i32, i32, NotUsed>::identity()
            .map_concat(|item| vec![item, item * 10])
            .map_context(|ctx| ctx - 100)
            .grouped(3);

        let values = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item + 100)
            .via(flow)
            .run_collect()
            .unwrap();

        assert_eq!(
            values,
            vec![(vec![1, 10, 2], vec![1, 1, 2]), (vec![20], vec![2])]
        );
    }

    #[test]
    fn source_with_context_via_context_flow_and_to_mat() {
        let flow = FlowWithContext::<i32, i32, i32, i32, NotUsed>::from_flow(
            Flow::identity().map(|(value, ctx)| (value * 2, ctx * 2)),
        );

        let sink = Sink::collect();
        let completion = Source::from_iter([1_i32, 2, 3])
            .as_source_with_context(|item| item + 10)
            .via(flow)
            .to_mat(sink, |_, mat| mat)
            .run()
            .unwrap();

        assert_eq!(completion.wait().unwrap(), vec![(2, 22), (4, 24), (6, 26)]);
    }

    #[test]
    fn flow_with_context_to_mat_builds_a_pair_sink() {
        let sink = FlowWithContext::<i32, i32, i32, i32, NotUsed>::identity()
            .map(|item| item * 3)
            .to_mat(Sink::collect(), |_, pairs| pairs);

        let completion = Source::from_iter([(1_i32, 10_i32), (2, 20)])
            .to_mat(sink, |_, pairs| pairs)
            .run()
            .unwrap();

        assert_eq!(completion.wait().unwrap(), vec![(3, 10), (6, 20)]);
    }

    #[test]
    fn source_with_context_as_source_to_runs_and_to_mat_work() {
        let source = Source::from_iter([1_i32, 2, 3]).as_source_with_context(|item| item + 10);

        // `as_source` exposes the same pairs as a plain `Source`.
        assert_eq!(
            source.clone().as_source().run_collect().unwrap(),
            vec![(1, 11), (2, 12), (3, 13)]
        );

        assert_eq!(source.clone().to(Sink::ignore()).run(), Ok(NotUsed));

        let pair_sink = source
            .to_mat(Sink::collect(), |_, pairs| pairs)
            .run()
            .unwrap();
        assert_eq!(pair_sink.wait().unwrap(), vec![(1, 11), (2, 12), (3, 13)]);
    }
}