1use futures::future::FutureExt;
17use std::future::Future;
18
19use crate::stream::{Flow, NotUsed, RunnableGraph, Sink, Source, StreamResult};
20
21#[derive(Clone)]
22pub struct SourceWithContext<Out, Ctx, Mat = NotUsed> {
23 pub(crate) delegate: Source<(Out, Ctx), Mat>,
24}
25
26#[derive(Clone)]
27pub struct FlowWithContext<In, CtxIn, Out, CtxOut, Mat = NotUsed> {
28 pub(crate) delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
29}
30
31impl<Out: Send + 'static, Ctx: Send + 'static, Mat: Send + 'static>
32 SourceWithContext<Out, Ctx, Mat>
33{
34 pub(crate) fn from_source(delegate: Source<(Out, Ctx), Mat>) -> Self {
35 Self { delegate }
36 }
37
38 pub fn as_source(self) -> Source<(Out, Ctx), Mat> {
39 self.delegate
40 }
41
42 pub fn run_collect(self) -> StreamResult<Vec<(Out, Ctx)>> {
43 self.delegate.run_collect()
44 }
45
46 pub fn map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
47 where
48 Next: Send + 'static,
49 F: Fn(Out) -> Next + Send + Sync + 'static,
50 {
51 SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
52 }
53
54 pub fn filter<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
55 where
56 F: Fn(&Out) -> bool + Send + Sync + 'static,
57 {
58 SourceWithContext::from_source(self.delegate.filter_map(move |(out, ctx)| {
59 if predicate(&out) {
60 Some((out, ctx))
61 } else {
62 None
63 }
64 }))
65 }
66
67 pub fn filter_not<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
68 where
69 F: Fn(&Out) -> bool + Send + Sync + 'static,
70 {
71 self.filter(move |out| !predicate(out))
72 }
73
74 pub fn filter_map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
75 where
76 Next: Send + 'static,
77 F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
78 {
79 SourceWithContext::from_source(
80 self.delegate
81 .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
82 )
83 }
84
85 pub fn map_concat<Next, F, I>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
86 where
87 Next: Send + 'static,
88 F: Fn(Out) -> I + Send + Sync + 'static,
89 I: IntoIterator<Item = Next>,
90 I::IntoIter: Send + 'static,
91 Ctx: Clone,
92 {
93 SourceWithContext::from_source(self.delegate.map_concat(move |(out, ctx)| {
94 let ctx = ctx.clone();
95 f(out).into_iter().map(move |next| (next, ctx.clone()))
96 }))
97 }
98
99 pub fn map_async<Next, F, Fut>(
100 self,
101 parallelism: usize,
102 f: F,
103 ) -> SourceWithContext<Next, Ctx, Mat>
104 where
105 Next: Send + 'static,
106 F: Fn(Out) -> Fut + Send + Sync + 'static,
107 Fut: Future<Output = StreamResult<Next>> + Send + 'static,
108 {
109 SourceWithContext::from_source(self.delegate.map_async(parallelism, move |(out, ctx)| {
110 f(out).map(|next| next.map(|next| (next, ctx)))
111 }))
112 }
113
114 pub fn map_context<CtxOut, F>(self, f: F) -> SourceWithContext<Out, CtxOut, Mat>
115 where
116 CtxOut: Send + 'static,
117 F: Fn(Ctx) -> CtxOut + Send + Sync + 'static,
118 {
119 SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
120 }
121
122 pub fn grouped(self, size: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat> {
123 SourceWithContext::from_source(self.delegate.grouped(size).map(unzip_pairs))
124 }
125
126 pub fn sliding(self, size: usize, step: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat>
127 where
128 Out: Clone,
129 Ctx: Clone,
130 {
131 SourceWithContext::from_source(self.delegate.sliding(size, step).map(unzip_pairs))
132 }
133
134 pub fn via<Out2, Ctx2, FlowMat>(
135 self,
136 flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
137 ) -> SourceWithContext<Out2, Ctx2, Mat>
138 where
139 Out2: Send + 'static,
140 Ctx2: Send + 'static,
141 FlowMat: Send + 'static,
142 {
143 SourceWithContext::from_source(self.delegate.via(flow.delegate))
144 }
145
146 pub fn via_mat<Out2, Ctx2, FlowMat, Combined, F>(
147 self,
148 flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
149 combine: F,
150 ) -> SourceWithContext<Out2, Ctx2, Combined>
151 where
152 Out2: Send + 'static,
153 Ctx2: Send + 'static,
154 FlowMat: Send + 'static,
155 Combined: Send + 'static,
156 F: Fn(Mat, FlowMat) -> Combined + Send + Sync + 'static,
157 {
158 SourceWithContext::from_source(self.delegate.via_mat(flow.delegate, combine))
159 }
160
161 pub fn to<SinkMat>(self, sink: Sink<(Out, Ctx), SinkMat>) -> RunnableGraph<Mat>
162 where
163 SinkMat: Send + 'static,
164 {
165 self.delegate.to(sink)
166 }
167
168 pub fn to_mat<SinkMat, Combined, F>(
169 self,
170 sink: Sink<(Out, Ctx), SinkMat>,
171 combine: F,
172 ) -> RunnableGraph<Combined>
173 where
174 SinkMat: Send + 'static,
175 Combined: Send + 'static,
176 F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
177 {
178 self.delegate.to_mat(sink, combine)
179 }
180}
181
182impl<In: Send + 'static, CtxIn: Send + 'static> FlowWithContext<In, CtxIn, In, CtxIn, NotUsed> {
183 pub fn identity() -> Self {
184 FlowWithContext::from_flow(Flow::identity())
185 }
186}
187
188impl<
189 In: Send + 'static,
190 CtxIn: Send + 'static,
191 Out: Send + 'static,
192 CtxOut: Send + 'static,
193 Mat: Send + 'static,
194> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
195{
196 pub(crate) fn from_flow(
197 delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
198 ) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat> {
199 FlowWithContext { delegate }
200 }
201
202 pub fn as_flow(self) -> Flow<(In, CtxIn), (Out, CtxOut), Mat> {
203 self.delegate
204 }
205
206 pub fn map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
207 where
208 Next: Send + 'static,
209 F: Fn(Out) -> Next + Send + Sync + 'static,
210 {
211 FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
212 }
213
214 pub fn filter<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
215 where
216 F: Fn(&Out) -> bool + Send + Sync + 'static,
217 {
218 FlowWithContext::from_flow(self.delegate.filter(move |(out, _)| predicate(out)))
219 }
220
221 pub fn filter_not<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
222 where
223 F: Fn(&Out) -> bool + Send + Sync + 'static,
224 {
225 self.filter(move |out| !predicate(out))
226 }
227
228 pub fn filter_map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
229 where
230 Next: Send + 'static,
231 F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
232 {
233 FlowWithContext::from_flow(
234 self.delegate
235 .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
236 )
237 }
238
239 pub fn map_concat<Next, F, I>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
240 where
241 Next: Send + 'static,
242 F: Fn(Out) -> I + Send + Sync + 'static,
243 I: IntoIterator<Item = Next>,
244 I::IntoIter: Send + 'static,
245 CtxOut: Clone,
246 {
247 FlowWithContext::from_flow(self.delegate.map_concat(move |(out, ctx)| {
248 let ctx = ctx.clone();
249 f(out).into_iter().map(move |next| (next, ctx.clone()))
250 }))
251 }
252
253 pub fn map_async<Next, F, Fut>(
254 self,
255 parallelism: usize,
256 f: F,
257 ) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
258 where
259 Next: Send + 'static,
260 F: Fn(Out) -> Fut + Send + Sync + 'static,
261 Fut: Future<Output = StreamResult<Next>> + Send + 'static,
262 {
263 FlowWithContext::from_flow(self.delegate.map_async(parallelism, move |(out, ctx)| {
264 f(out).map(|next| next.map(|next| (next, ctx)))
265 }))
266 }
267
268 pub fn map_context<CtxOut2, F>(self, f: F) -> FlowWithContext<In, CtxIn, Out, CtxOut2, Mat>
269 where
270 CtxOut2: Send + 'static,
271 F: Fn(CtxOut) -> CtxOut2 + Send + Sync + 'static,
272 {
273 FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
274 }
275
276 pub fn grouped(self, n: usize) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat> {
277 FlowWithContext::from_flow(self.delegate.grouped(n).map(unzip_pairs))
278 }
279
280 pub fn sliding(
281 self,
282 n: usize,
283 step: usize,
284 ) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat>
285 where
286 Out: Clone,
287 CtxOut: Clone,
288 {
289 FlowWithContext::from_flow(self.delegate.sliding(n, step).map(unzip_pairs))
290 }
291
292 pub fn via<Out2, Ctx2, FlowMat>(
293 self,
294 flow: FlowWithContext<Out, CtxOut, Out2, Ctx2, FlowMat>,
295 ) -> FlowWithContext<In, CtxIn, Out2, Ctx2, Mat>
296 where
297 Out2: Send + 'static,
298 Ctx2: Send + 'static,
299 FlowMat: Send + 'static,
300 {
301 FlowWithContext::from_flow(self.delegate.via(flow.delegate))
302 }
303
304 pub fn to<SinkMat>(self, sink: Sink<(Out, CtxOut), SinkMat>) -> Sink<(In, CtxIn), Mat>
305 where
306 SinkMat: Send + 'static,
307 {
308 self.delegate.to(sink)
309 }
310
311 pub fn to_mat<SinkMat, Combined, F>(
312 self,
313 sink: Sink<(Out, CtxOut), SinkMat>,
314 combine: F,
315 ) -> Sink<(In, CtxIn), Combined>
316 where
317 SinkMat: Send + 'static,
318 Combined: Send + 'static,
319 F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
320 {
321 self.delegate.to_mat(sink, combine)
322 }
323}
324
/// Splits a batch of `(element, context)` pairs into a vector of elements and a vector
/// of contexts. Both vectors keep the batch's original order and have equal length.
fn unzip_pairs<Out, Ctx>(pairs: Vec<(Out, Ctx)>) -> (Vec<Out>, Vec<Ctx>) {
    pairs.into_iter().unzip()
}
336
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, time::Duration};

    #[test]
    fn source_with_context_preserves_context_for_map_and_filter() {
        let values = Source::from_iter(0_i32..6)
            .as_source_with_context(|item| item + 100)
            .map(|item| item + 1)
            .filter(|item| item % 2 == 0)
            .filter_not(|item| *item == 4)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(2, 101), (6, 105)]);
    }

    #[test]
    fn source_with_context_filter_drops_context_of_rejected_elements() {
        let values = Source::from_iter(1_i32..5)
            .as_source_with_context(|item| item * 10)
            .filter(|item| item % 2 == 1)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(1, 10), (3, 30)]);
    }

    #[test]
    fn source_with_context_map_context_transform_is_supported() {
        let values = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| *item)
            .map_context(|ctx| ctx + 10)
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(1, 11), (2, 12), (3, 13)]);
    }

    #[test]
    fn source_with_context_map_concat_duplicates_context() {
        let values = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item + 10)
            .map_concat(|item| vec![item + 1, item + 2])
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(2, 11), (3, 11), (3, 12), (4, 12)]);
    }

    #[test]
    fn source_with_context_groups_context_vectors_with_grouped_and_sliding() {
        let grouped = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| item + 10)
            .grouped(2)
            .run_collect()
            .unwrap();

        assert_eq!(
            grouped,
            vec![(vec![1, 2], vec![11, 12]), (vec![3], vec![13])]
        );

        let sliding = Source::from_iter([1_i32, 2, 3, 4])
            .as_source_with_context(|item| item + 10)
            .sliding(3, 2)
            .run_collect()
            .unwrap();

        assert_eq!(
            sliding,
            vec![
                (vec![1, 2, 3], vec![11, 12, 13]),
                (vec![3, 4], vec![13, 14])
            ]
        );
    }

    #[test]
    fn source_with_context_map_async_keeps_context_with_out_of_order_completions() {
        // Input alternates slow (even) and fast (odd) items. If `map_async(2, ..)` keeps
        // two futures in flight, the second item of each pair can finish first. The output
        // must still follow input order, with each result keeping its own context.
        // NOTE(review): `thread::sleep` blocks whichever thread polls the future; whether
        // the two futures truly overlap depends on how `map_async` drives them — confirm.
        let values = Source::from_iter([2_i32, 1, 0, 3])
            .as_source_with_context(|item| item + 100)
            .map_async(2, |item| async move {
                if item % 2 == 0 {
                    thread::sleep(Duration::from_millis(20));
                } else {
                    thread::sleep(Duration::from_millis(2));
                }
                Ok(item * 2)
            })
            .run_collect()
            .unwrap();

        assert_eq!(values, vec![(4, 102), (2, 101), (0, 100), (6, 103)]);
    }

    #[test]
    fn source_with_context_filter_and_map_with_string_elements() {
        let input = ["a".to_string(), "bravo".to_string(), "charlie".to_string()];

        let source_values = Source::from_iter(input.to_vec())
            .as_source_with_context(|item| format!("ctx-{item}"))
            .filter(|item| item.len() >= 5)
            .map(|item| format!("mapped:{item}"))
            .run_collect()
            .unwrap();

        assert_eq!(
            source_values,
            vec![
                ("mapped:bravo".to_string(), "ctx-bravo".to_string()),
                ("mapped:charlie".to_string(), "ctx-charlie".to_string()),
            ]
        );

        let flow_values = Source::from_iter(input)
            .as_source_with_context(|item| format!("ctx-{item}"))
            .via(
                FlowWithContext::<String, String, String, String, NotUsed>::identity()
                    .filter(|item| item.len() >= 5)
                    .map(|item| format!("mapped:{item}")),
            )
            .run_collect()
            .unwrap();

        assert_eq!(flow_values, source_values);
    }

    #[test]
    fn source_with_context_via_context_flow_and_to_mat() {
        let flow = FlowWithContext::<i32, i32, i32, i32, NotUsed>::from_flow(
            Flow::identity().map(|(value, ctx)| (value * 2, ctx * 2)),
        );

        let sink = Sink::collect();
        let completion = Source::from_iter([1_i32, 2, 3])
            .as_source_with_context(|item| item + 10)
            .via(flow)
            .to_mat(sink, |_, mat| mat)
            .run()
            .unwrap();

        assert_eq!(completion.wait().unwrap(), vec![(2, 22), (4, 24), (6, 26)]);
    }

    #[test]
    fn flow_with_context_to_mat_keeps_sink_materialized_value() {
        let sink = FlowWithContext::<i32, i32, i32, i32, NotUsed>::identity()
            .map(|value| value + 1)
            .map_context(|ctx| ctx * 2)
            .to_mat(Sink::collect(), |_, collected| collected);

        let completion = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item * 10)
            .to_mat(sink, |_, collected| collected)
            .run()
            .unwrap();

        assert_eq!(completion.wait().unwrap(), vec![(2, 20), (3, 40)]);
    }

    #[test]
    fn source_with_context_as_source_to_runs_and_to_mat_work() {
        let source = Source::from_iter([1_i32, 2, 3]).as_source_with_context(|item| item + 10);

        assert_eq!(
            source.clone().as_source().run_collect().unwrap(),
            vec![(1, 11), (2, 12), (3, 13)]
        );

        assert_eq!(source.clone().to(Sink::ignore()).run(), Ok(NotUsed));

        let pair_sink = source
            .to_mat(Sink::collect(), |_, pairs| pairs)
            .run()
            .unwrap();
        assert_eq!(pair_sink.wait().unwrap(), vec![(1, 11), (2, 12), (3, 13)]);
    }
}