1use futures::future::FutureExt;
17use std::future::Future;
18
19use crate::stream::{Flow, NotUsed, RunnableGraph, Sink, Source, StreamResult};
20
21#[derive(Clone)]
25pub struct SourceWithContext<Out, Ctx, Mat = NotUsed> {
26 pub(crate) delegate: Source<(Out, Ctx), Mat>,
27}
28
29#[derive(Clone)]
33pub struct FlowWithContext<In, CtxIn, Out, CtxOut, Mat = NotUsed> {
34 pub(crate) delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
35}
36
37impl<Out: Send + 'static, Ctx: Send + 'static, Mat: Send + 'static>
38 SourceWithContext<Out, Ctx, Mat>
39{
40 pub(crate) fn from_source(delegate: Source<(Out, Ctx), Mat>) -> Self {
41 Self { delegate }
42 }
43
44 pub fn as_source(self) -> Source<(Out, Ctx), Mat> {
45 self.delegate
46 }
47
48 pub fn run_collect(self) -> StreamResult<Vec<(Out, Ctx)>> {
49 self.delegate.run_collect()
50 }
51
52 pub fn map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
53 where
54 Next: Send + 'static,
55 F: Fn(Out) -> Next + Send + Sync + 'static,
56 {
57 SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
58 }
59
60 pub fn filter<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
61 where
62 F: Fn(&Out) -> bool + Send + Sync + 'static,
63 {
64 SourceWithContext::from_source(self.delegate.filter_map(move |(out, ctx)| {
65 if predicate(&out) {
66 Some((out, ctx))
67 } else {
68 None
69 }
70 }))
71 }
72
73 pub fn filter_not<F>(self, predicate: F) -> SourceWithContext<Out, Ctx, Mat>
74 where
75 F: Fn(&Out) -> bool + Send + Sync + 'static,
76 {
77 self.filter(move |out| !predicate(out))
78 }
79
80 pub fn filter_map<Next, F>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
81 where
82 Next: Send + 'static,
83 F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
84 {
85 SourceWithContext::from_source(
86 self.delegate
87 .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
88 )
89 }
90
91 pub fn map_concat<Next, F, I>(self, f: F) -> SourceWithContext<Next, Ctx, Mat>
92 where
93 Next: Send + 'static,
94 F: Fn(Out) -> I + Send + Sync + 'static,
95 I: IntoIterator<Item = Next>,
96 I::IntoIter: Send + 'static,
97 Ctx: Clone,
98 {
99 SourceWithContext::from_source(self.delegate.map_concat(move |(out, ctx)| {
100 let ctx = ctx.clone();
101 f(out).into_iter().map(move |next| (next, ctx.clone()))
102 }))
103 }
104
105 pub fn map_async<Next, F, Fut>(
106 self,
107 parallelism: usize,
108 f: F,
109 ) -> SourceWithContext<Next, Ctx, Mat>
110 where
111 Next: Send + 'static,
112 F: Fn(Out) -> Fut + Send + Sync + 'static,
113 Fut: Future<Output = StreamResult<Next>> + Send + 'static,
114 {
115 SourceWithContext::from_source(self.delegate.map_async(parallelism, move |(out, ctx)| {
116 f(out).map(|next| next.map(|next| (next, ctx)))
117 }))
118 }
119
120 pub fn map_context<CtxOut, F>(self, f: F) -> SourceWithContext<Out, CtxOut, Mat>
121 where
122 CtxOut: Send + 'static,
123 F: Fn(Ctx) -> CtxOut + Send + Sync + 'static,
124 {
125 SourceWithContext::from_source(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
126 }
127
128 pub fn grouped(self, size: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat> {
129 SourceWithContext::from_source(self.delegate.grouped(size).map(unzip_pairs))
130 }
131
132 pub fn sliding(self, size: usize, step: usize) -> SourceWithContext<Vec<Out>, Vec<Ctx>, Mat>
133 where
134 Out: Clone,
135 Ctx: Clone,
136 {
137 SourceWithContext::from_source(self.delegate.sliding(size, step).map(unzip_pairs))
138 }
139
140 pub fn via<Out2, Ctx2, FlowMat>(
141 self,
142 flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
143 ) -> SourceWithContext<Out2, Ctx2, Mat>
144 where
145 Out2: Send + 'static,
146 Ctx2: Send + 'static,
147 FlowMat: Send + 'static,
148 {
149 SourceWithContext::from_source(self.delegate.via(flow.delegate))
150 }
151
152 pub fn via_mat<Out2, Ctx2, FlowMat, Combined, F>(
153 self,
154 flow: FlowWithContext<Out, Ctx, Out2, Ctx2, FlowMat>,
155 combine: F,
156 ) -> SourceWithContext<Out2, Ctx2, Combined>
157 where
158 Out2: Send + 'static,
159 Ctx2: Send + 'static,
160 FlowMat: Send + 'static,
161 Combined: Send + 'static,
162 F: Fn(Mat, FlowMat) -> Combined + Send + Sync + 'static,
163 {
164 SourceWithContext::from_source(self.delegate.via_mat(flow.delegate, combine))
165 }
166
167 pub fn to<SinkMat>(self, sink: Sink<(Out, Ctx), SinkMat>) -> RunnableGraph<Mat>
168 where
169 SinkMat: Send + 'static,
170 {
171 self.delegate.to(sink)
172 }
173
174 pub fn to_mat<SinkMat, Combined, F>(
175 self,
176 sink: Sink<(Out, Ctx), SinkMat>,
177 combine: F,
178 ) -> RunnableGraph<Combined>
179 where
180 SinkMat: Send + 'static,
181 Combined: Send + 'static,
182 F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
183 {
184 self.delegate.to_mat(sink, combine)
185 }
186}
187
188impl<In: Send + 'static, CtxIn: Send + 'static> FlowWithContext<In, CtxIn, In, CtxIn, NotUsed> {
189 pub fn identity() -> Self {
190 FlowWithContext::from_flow(Flow::identity())
191 }
192}
193
194impl<
195 In: Send + 'static,
196 CtxIn: Send + 'static,
197 Out: Send + 'static,
198 CtxOut: Send + 'static,
199 Mat: Send + 'static,
200> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
201{
202 pub(crate) fn from_flow(
203 delegate: Flow<(In, CtxIn), (Out, CtxOut), Mat>,
204 ) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat> {
205 FlowWithContext { delegate }
206 }
207
208 pub fn as_flow(self) -> Flow<(In, CtxIn), (Out, CtxOut), Mat> {
209 self.delegate
210 }
211
212 pub fn map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
213 where
214 Next: Send + 'static,
215 F: Fn(Out) -> Next + Send + Sync + 'static,
216 {
217 FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (f(out), ctx)))
218 }
219
220 pub fn filter<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
221 where
222 F: Fn(&Out) -> bool + Send + Sync + 'static,
223 {
224 FlowWithContext::from_flow(self.delegate.filter(move |(out, _)| predicate(out)))
225 }
226
227 pub fn filter_not<F>(self, predicate: F) -> FlowWithContext<In, CtxIn, Out, CtxOut, Mat>
228 where
229 F: Fn(&Out) -> bool + Send + Sync + 'static,
230 {
231 self.filter(move |out| !predicate(out))
232 }
233
234 pub fn filter_map<Next, F>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
235 where
236 Next: Send + 'static,
237 F: Fn(Out) -> Option<Next> + Send + Sync + 'static,
238 {
239 FlowWithContext::from_flow(
240 self.delegate
241 .filter_map(move |(out, ctx)| f(out).map(|item| (item, ctx))),
242 )
243 }
244
245 pub fn map_concat<Next, F, I>(self, f: F) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
246 where
247 Next: Send + 'static,
248 F: Fn(Out) -> I + Send + Sync + 'static,
249 I: IntoIterator<Item = Next>,
250 I::IntoIter: Send + 'static,
251 CtxOut: Clone,
252 {
253 FlowWithContext::from_flow(self.delegate.map_concat(move |(out, ctx)| {
254 let ctx = ctx.clone();
255 f(out).into_iter().map(move |next| (next, ctx.clone()))
256 }))
257 }
258
259 pub fn map_async<Next, F, Fut>(
260 self,
261 parallelism: usize,
262 f: F,
263 ) -> FlowWithContext<In, CtxIn, Next, CtxOut, Mat>
264 where
265 Next: Send + 'static,
266 F: Fn(Out) -> Fut + Send + Sync + 'static,
267 Fut: Future<Output = StreamResult<Next>> + Send + 'static,
268 {
269 FlowWithContext::from_flow(self.delegate.map_async(parallelism, move |(out, ctx)| {
270 f(out).map(|next| next.map(|next| (next, ctx)))
271 }))
272 }
273
274 pub fn map_context<CtxOut2, F>(self, f: F) -> FlowWithContext<In, CtxIn, Out, CtxOut2, Mat>
275 where
276 CtxOut2: Send + 'static,
277 F: Fn(CtxOut) -> CtxOut2 + Send + Sync + 'static,
278 {
279 FlowWithContext::from_flow(self.delegate.map(move |(out, ctx)| (out, f(ctx))))
280 }
281
282 pub fn grouped(self, n: usize) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat> {
283 FlowWithContext::from_flow(self.delegate.grouped(n).map(unzip_pairs))
284 }
285
286 pub fn sliding(
287 self,
288 n: usize,
289 step: usize,
290 ) -> FlowWithContext<In, CtxIn, Vec<Out>, Vec<CtxOut>, Mat>
291 where
292 Out: Clone,
293 CtxOut: Clone,
294 {
295 FlowWithContext::from_flow(self.delegate.sliding(n, step).map(unzip_pairs))
296 }
297
298 pub fn via<Out2, Ctx2, FlowMat>(
299 self,
300 flow: FlowWithContext<Out, CtxOut, Out2, Ctx2, FlowMat>,
301 ) -> FlowWithContext<In, CtxIn, Out2, Ctx2, Mat>
302 where
303 Out2: Send + 'static,
304 Ctx2: Send + 'static,
305 FlowMat: Send + 'static,
306 {
307 FlowWithContext::from_flow(self.delegate.via(flow.delegate))
308 }
309
310 pub fn to<SinkMat>(self, sink: Sink<(Out, CtxOut), SinkMat>) -> Sink<(In, CtxIn), Mat>
311 where
312 SinkMat: Send + 'static,
313 {
314 self.delegate.to(sink)
315 }
316
317 pub fn to_mat<SinkMat, Combined, F>(
318 self,
319 sink: Sink<(Out, CtxOut), SinkMat>,
320 combine: F,
321 ) -> Sink<(In, CtxIn), Combined>
322 where
323 SinkMat: Send + 'static,
324 Combined: Send + 'static,
325 F: Fn(Mat, SinkMat) -> Combined + Send + Sync + 'static,
326 {
327 self.delegate.to_mat(sink, combine)
328 }
329}
330
/// Splits a vector of `(element, context)` pairs into a vector of elements and
/// a vector of contexts.
///
/// Order is preserved, so index `i` of both returned vectors comes from input
/// pair `i`. An empty input yields two empty vectors.
fn unzip_pairs<Out, Ctx>(pairs: Vec<(Out, Ctx)>) -> (Vec<Out>, Vec<Ctx>) {
    pairs.into_iter().unzip()
}
342
#[cfg(test)]
mod tests {
    use super::*;
    use std::{thread, time::Duration};

    /// `map`, `filter` and `filter_not` act on the element only; every
    /// surviving element keeps the context it started with.
    #[test]
    fn source_with_context_preserves_context_for_map_and_filter() {
        let pairs = Source::from_iter(0_i32..6)
            .as_source_with_context(|item| item + 100)
            .map(|item| item + 1)
            .filter(|item| item % 2 == 0)
            .filter_not(|item| *item == 4)
            .run_collect()
            .unwrap();

        let expected = vec![(2, 101), (6, 105)];
        assert_eq!(pairs, expected);
    }

    /// A plain `filter` drops rejected elements together with their contexts.
    #[test]
    fn source_with_context_filters_context_with_map_filter() {
        let pairs = Source::from_iter(1_i32..5)
            .as_source_with_context(|item| item * 10)
            .filter(|item| item % 2 == 1)
            .run_collect()
            .unwrap();

        let expected = vec![(1, 10), (3, 30)];
        assert_eq!(pairs, expected);
    }

    /// `map_context` rewrites the context and leaves the element alone.
    #[test]
    fn source_with_context_map_context_transform_is_supported() {
        let pairs = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| *item)
            .map_context(|ctx| ctx + 10)
            .run_collect()
            .unwrap();

        let expected = vec![(1, 11), (2, 12), (3, 13)];
        assert_eq!(pairs, expected);
    }

    /// Every item produced by `map_concat` carries its parent's context.
    #[test]
    fn source_with_context_map_concat_duplicates_context() {
        let pairs = Source::from_iter([1_i32, 2])
            .as_source_with_context(|item| item + 10)
            .map_concat(|item| vec![item + 1, item + 2])
            .run_collect()
            .unwrap();

        let expected = vec![(2, 11), (3, 11), (3, 12), (4, 12)];
        assert_eq!(pairs, expected);
    }

    /// `grouped` and `sliding` emit parallel element and context vectors.
    #[test]
    fn source_with_context_groups_context_vectors_with_grouped_and_sliding() {
        let batches = Source::from_iter(1_i32..4)
            .as_source_with_context(|item| item + 10)
            .grouped(2)
            .run_collect()
            .unwrap();
        let expected_batches = vec![(vec![1, 2], vec![11, 12]), (vec![3], vec![13])];

        assert_eq!(batches, expected_batches);

        let windows = Source::from_iter([1_i32, 2, 3, 4])
            .as_source_with_context(|item| item + 10)
            .sliding(3, 2)
            .run_collect()
            .unwrap();
        let expected_windows = vec![
            (vec![1, 2, 3], vec![11, 12, 13]),
            (vec![3, 4], vec![13, 14]),
        ];

        assert_eq!(windows, expected_windows);
    }

    /// Results come out in input order with the matching context even though
    /// even inputs are delayed longer than odd ones.
    #[test]
    fn source_with_context_map_async_keeps_context_with_out_of_order_completions() {
        let pairs = Source::from_iter([3_i32, 1, 2, 0])
            .as_source_with_context(|item| item + 100)
            .map_async(2, |item| async move {
                // NOTE(review): this is a blocking sleep inside a future; it
                // presumably relies on how the stream executor polls futures
                // to produce out-of-order completion — confirm.
                let delay_ms = if item % 2 == 0 { 20 } else { 2 };
                thread::sleep(Duration::from_millis(delay_ms));
                Ok(item * 2)
            })
            .run_collect()
            .unwrap();

        let expected = vec![(6, 103), (2, 101), (4, 102), (0, 100)];
        assert_eq!(pairs, expected);
    }

    /// The same `filter` + `map` pipeline gives identical results whether it is
    /// applied on the source directly or supplied as a `FlowWithContext`.
    #[test]
    fn source_with_context_filter_and_map_with_string_elements() {
        let words = ["a".to_string(), "bravo".to_string(), "charlie".to_string()];

        let via_source_ops = Source::from_iter(words.to_vec())
            .as_source_with_context(|item| format!("ctx-{item}"))
            .filter(|item| item.len() >= 5)
            .map(|item| format!("mapped:{item}"))
            .run_collect()
            .unwrap();
        let expected = vec![
            ("mapped:bravo".to_string(), "ctx-bravo".to_string()),
            ("mapped:charlie".to_string(), "ctx-charlie".to_string()),
        ];

        assert_eq!(via_source_ops, expected);

        let pipeline = FlowWithContext::<String, String, String, String, NotUsed>::identity()
            .filter(|item| item.len() >= 5)
            .map(|item| format!("mapped:{item}"));
        let via_flow = Source::from_iter(words)
            .as_source_with_context(|item| format!("ctx-{item}"))
            .via(pipeline)
            .run_collect()
            .unwrap();

        assert_eq!(via_flow, via_source_ops);
    }

    /// A flow built from a raw tuple `Flow` can be attached with `via`, and
    /// `to_mat` exposes the sink's materialized value.
    #[test]
    fn source_with_context_via_context_flow_and_to_mat() {
        let doubling = FlowWithContext::<i32, i32, i32, i32, NotUsed>::from_flow(
            Flow::identity().map(|(value, ctx)| (value * 2, ctx * 2)),
        );
        let collector = Sink::collect();

        let completion = Source::from_iter([1_i32, 2, 3])
            .as_source_with_context(|item| item + 10)
            .via(doubling)
            .to_mat(collector, |_, mat| mat)
            .run()
            .unwrap();

        let expected = vec![(2, 22), (4, 24), (6, 26)];
        assert_eq!(completion.wait().unwrap(), expected);
    }

    /// `to` keeps the source's `NotUsed` materialized value, while `to_mat`
    /// can surface the collecting sink's result.
    #[test]
    fn source_with_context_as_source_to_runs_and_to_mat_work() {
        let source = Source::from_iter([1_i32, 2, 3]).as_source_with_context(|item| item + 10);

        let ignored = source.clone().to(Sink::ignore()).run();
        assert_eq!(ignored, Ok(NotUsed));

        let collected = source
            .to_mat(Sink::collect(), |_, pairs| pairs)
            .run()
            .unwrap();
        let expected = vec![(1, 11), (2, 12), (3, 13)];
        assert_eq!(collected.wait().unwrap(), expected);
    }
}