1mod fn_factory;
4mod future_factory;
5
6pub use fn_factory::*;
7pub use future_factory::*;
8
9use crate::{
10 common::*,
11 config::ParParams,
12 index_stream::{IndexStreamExt as _, ReorderEnumerated},
13 par_stream::ParStreamExt as _,
14 rt, utils,
15};
16use flume::r#async::RecvStream;
17use tokio::sync::broadcast;
18
/// Stream returned by the `build_unordered_stream` methods: outputs are yielded
/// as worker tasks send them into the output channel, with no ordering guarantee.
pub type UnorderedStream<T> = RecvStream<'static, T>;
/// Stream returned by the `build_ordered_stream` methods: workers send
/// `(index, output)` pairs, and [`ReorderEnumerated`] yields the outputs back in
/// the order of the source stream.
pub type OrderedStream<T> = ReorderEnumerated<RecvStream<'static, (usize, T)>, T>;
21
22pub struct ParBuilder<St>
24where
25 St: ?Sized + Stream,
26{
27 stream: St,
28}
29
30pub struct ParAsyncBuilder<St, Fac>
32where
33 St: ?Sized + Stream,
34 St::Item: 'static + Send,
35 Fac: FutureFactory<St::Item>,
36 Fac::Fut: 'static + Send + Future,
37 <Fac::Fut as Future>::Output: Send,
38{
39 fac: Fac,
40 stream: St,
41}
42
43pub struct ParAsyncTailBlockBuilder<St, FutFac, FnFac, Out>
45where
46 St: ?Sized + Stream,
47 St::Item: 'static + Send,
48 FutFac: FutureFactory<St::Item>,
49 FutFac::Fut: 'static + Send + Future,
50 <FutFac::Fut as Future>::Output: Send,
51 FnFac: FnFactory<<FutFac::Fut as Future>::Output, Out>,
52 FnFac::Fn: 'static + Send + FnOnce() -> Out,
53 Out: 'static + Send,
54{
55 fut_fac: FutFac,
56 fn_fac: FnFac,
57 _phantom: PhantomData<Out>,
58 stream: St,
59}
60
61pub struct ParBlockingBuilder<St, Fac, Out>
63where
64 St: ?Sized + Stream,
65 St::Item: 'static + Send,
66 Fac: FnFactory<St::Item, Out>,
67 Fac::Fn: 'static + Send + FnOnce() -> Out,
68 Out: 'static + Send,
69{
70 fac: Fac,
71 _phantom: PhantomData<Out>,
72 stream: St,
73}
74
75impl<St> ParBuilder<St>
76where
77 St: Stream,
78{
79 pub fn new(stream: St) -> Self {
81 Self { stream }
82 }
83
84 pub fn map_async<Fut, Fac>(self, fac: Fac) -> ParAsyncBuilder<St, Fac>
86 where
87 St::Item: Send,
88 Fac: 'static + Send + FnMut(St::Item) -> Fut,
89 Fut: 'static + Send + Future,
90 Fut::Output: Send,
91 {
92 let Self { stream } = self;
93
94 ParAsyncBuilder { fac, stream }
95 }
96
97 pub fn map_blocking<Fac, Func, Out>(self, fac: Fac) -> ParBlockingBuilder<St, Fac, Out>
99 where
100 St::Item: 'static + Send,
101 Fac: 'static + Send + FnMut(St::Item) -> Func,
102 Func: 'static + Send + FnOnce() -> Out,
103 Out: Send,
104 {
105 let Self { stream } = self;
106
107 ParBlockingBuilder {
108 fac,
109 stream,
110 _phantom: PhantomData,
111 }
112 }
113}
114
115impl<St, Fac> ParAsyncBuilder<St, Fac>
116where
117 St: Stream,
118 St::Item: 'static + Send,
119 Fac: 'static + Send + FutureFactory<St::Item>,
120 Fac::Fut: 'static + Send + Future,
121 <Fac::Fut as Future>::Output: 'static + Send,
122{
123 pub fn map_async<NewFac, NewFut>(
125 self,
126 new_fac: NewFac,
127 ) -> ParAsyncBuilder<St, ComposeFutureFactory<St::Item, Fac, NewFac>>
128 where
129 NewFac: 'static + Send + Clone + FnMut(<Fac::Fut as Future>::Output) -> NewFut,
130 NewFut: 'static + Send + Future,
131 NewFut::Output: 'static + Send,
132 {
133 let Self {
134 fac: orig_fac,
135 stream,
136 ..
137 } = self;
138
139 ParAsyncBuilder {
140 fac: orig_fac.compose(new_fac),
141 stream,
142 }
143 }
144
145 pub fn map_blocking<NewOut, NewFac, NewFunc>(
150 self,
151 new_fac: NewFac,
152 ) -> ParAsyncTailBlockBuilder<St, Fac, NewFac, NewOut>
153 where
154 NewFac: 'static + Send + Clone + FnMut(<Fac::Fut as Future>::Output) -> NewFunc,
155 NewFunc: 'static + Send + FnOnce() -> NewOut,
156 NewFunc::Output: 'static + Send,
157 {
158 let Self {
159 fac: orig_fac,
160 stream,
161 ..
162 } = self;
163
164 ParAsyncTailBlockBuilder {
165 fut_fac: orig_fac,
166 fn_fac: new_fac,
167 _phantom: PhantomData,
168 stream,
169 }
170 }
171
172 pub fn build_unordered_stream<P>(
174 self,
175 params: P,
176 ) -> UnorderedStream<<Fac::Fut as Future>::Output>
177 where
178 St: 'static + Send,
179 P: Into<ParParams>,
180 {
181 let Self {
182 mut fac, stream, ..
183 } = self;
184 let ParParams {
185 num_workers,
186 buf_size,
187 } = params.into();
188
189 let stream = stream.map(move |item| fac.generate(item)).spawned(buf_size);
190 let (output_tx, output_rx) = utils::channel(buf_size);
191
192 (0..num_workers).for_each(move |_| {
193 let stream = stream.clone();
194 let output_tx = output_tx.clone();
195
196 rt::spawn(async move {
197 let _ = stream
198 .then(|fut| fut)
199 .map(Ok)
200 .forward(output_tx.into_sink())
201 .await;
202 });
203 });
204
205 output_rx.into_stream()
206 }
207
208 pub fn build_ordered_stream<P>(self, params: P) -> OrderedStream<<Fac::Fut as Future>::Output>
210 where
211 St: 'static + Send,
212 P: Into<ParParams>,
213 {
214 let Self {
215 mut fac, stream, ..
216 } = self;
217 let ParParams {
218 num_workers,
219 buf_size,
220 } = params.into();
221
222 let stream = stream
223 .map(move |item| fac.generate(item))
224 .enumerate()
225 .spawned(buf_size);
226 let (output_tx, output_rx) = utils::channel(buf_size);
227
228 (0..num_workers).for_each(move |_| {
229 let stream = stream.clone();
230 let output_tx = output_tx.clone();
231
232 rt::spawn(async move {
233 let _ = stream
234 .then(|(index, fut)| async move { (index, fut.await) })
235 .map(Ok)
236 .forward(output_tx.into_sink())
237 .await;
238 });
239 });
240
241 output_rx.into_stream().reorder_enumerated()
242 }
243}
244
245impl<St, Fac> ParAsyncBuilder<St, Fac>
246where
247 St: 'static + Send + Stream,
248 St::Item: 'static + Send,
249 Fac: 'static + Send + FutureFactory<St::Item>,
250 Fac::Fut: 'static + Send + Future<Output = ()>,
251{
252 pub async fn for_each<P>(self, params: P)
254 where
255 P: Into<ParParams>,
256 {
257 let ParParams {
258 num_workers,
259 buf_size,
260 } = params.into();
261 let Self {
262 mut fac, stream, ..
263 } = self;
264 let stream = stream.map(move |item| fac.generate(item)).spawned(buf_size);
265
266 let worker_futures = (0..num_workers).map(move |_| {
267 let stream = stream.clone();
268
269 rt::spawn(async move {
270 stream.for_each(|fut| fut).await;
271 })
272 });
273
274 future::join_all(worker_futures).await;
275 }
276}
277
impl<St, Fac, Error> ParAsyncBuilder<St, Fac>
where
    St: 'static + Send + Stream,
    St::Item: 'static + Send,
    Fac: 'static + Send + FutureFactory<St::Item>,
    Fac::Fut: 'static + Send + Future<Output = Result<(), Error>>,
    Error: 'static + Send,
{
    /// Drives a fallible pipeline, stopping early once any future fails.
    ///
    /// `num_workers` tasks are started with `rt::spawn`; each awaits one shared
    /// future at a time and stops at its first `Err`. Returns `Ok(())` when every
    /// worker finished without error, otherwise the first `Err` observed by
    /// `future::try_join_all`.
    ///
    /// After a failure no further items are pulled from the source stream, but
    /// futures that were already generated (buffered by `spawned(buf_size)`,
    /// presumably up to `buf_size` of them — TODO confirm) may still be awaited
    /// by the other workers.
    pub async fn try_for_each<P>(self, params: P) -> Result<(), Error>
    where
        P: Into<ParParams>,
    {
        let ParParams {
            num_workers,
            buf_size,
        } = params.into();
        let Self {
            mut fac, stream, ..
        } = self;
        // Termination signal shared by all workers. Capacity 1 is enough: the
        // receiver ignores the `recv` result, so a message, a lag error or the
        // closing of the channel all end the source stream.
        let (terminate_tx, mut terminate_rx) = broadcast::channel(1);
        let stream = stream
            // Stop pulling from the source once a worker signals termination.
            .take_until(async move {
                let _ = terminate_rx.recv().await;
            })
            .map(move |item| fac.generate(item))
            .spawned(buf_size);

        let worker_futures = (0..num_workers).map(move |_| {
            let stream = stream.clone();
            let terminate_tx = terminate_tx.clone();

            rt::spawn(async move {
                // `try_for_each` stops this worker at its first `Err`.
                let result = stream.map(Ok).try_for_each(|fut| fut).await;

                if result.is_err() {
                    // Ask the source to stop; a send error only means no
                    // receiver is active anymore.
                    let _ = terminate_tx.send(());
                }

                result
            })
        });

        // Resolves as soon as any worker yields `Err`; the remaining handles are
        // then dropped. NOTE(review): whether dropping them detaches or cancels
        // the workers depends on `rt::JoinHandle` — confirm.
        future::try_join_all(worker_futures).await?;
        Ok(())
    }
}
325
326impl<St, Fac, Out> ParBlockingBuilder<St, Fac, Out>
327where
328 St: Stream,
329 St::Item: 'static + Send,
330 Fac: 'static + Send + FnFactory<St::Item, Out>,
331 Fac::Fn: 'static + Send + FnOnce() -> Out,
332 Out: 'static + Send,
333{
334 pub fn map_async<NewFac, NewFut>(
336 self,
337 new_fac: NewFac,
338 ) -> ParAsyncBuilder<
339 St,
340 ComposeFutureFactory<St::Item, impl FnMut(St::Item) -> rt::JoinHandle<Out>, NewFac>,
341 >
342 where
343 NewFac: Send + Clone + FnMut(Out) -> NewFut,
344 NewFut: 'static + Send + Future,
345 NewFut::Output: 'static + Send,
346 {
347 let Self {
348 fac: mut orig_fac,
349 stream,
350 ..
351 } = self;
352
353 let orig_fac_async = move |input: St::Item| rt::spawn_blocking(orig_fac.generate(input));
354
355 ParAsyncBuilder {
356 fac: orig_fac_async.compose(new_fac),
357 stream,
358 }
359 }
360
361 pub fn map_blocking<NewOut, NewFac, NewFunc>(
363 self,
364 new_fac: NewFac,
365 ) -> ParBlockingBuilder<St, BoxFnFactory<St::Item, NewOut>, NewOut>
366 where
367 NewFac: 'static + Send + Clone + FnMut(Out) -> NewFunc,
368 NewFunc: 'static + Send + FnOnce() -> NewOut,
369 NewFunc::Output: 'static + Send,
370 {
371 let Self {
372 fac: orig_fac,
373 stream,
374 ..
375 } = self;
376
377 ParBlockingBuilder {
378 fac: orig_fac.chain(new_fac),
379 _phantom: PhantomData,
380 stream,
381 }
382 }
383
384 pub fn build_unordered_stream<P>(self, params: P) -> UnorderedStream<Out>
386 where
387 St: 'static + Send,
388 P: Into<ParParams>,
389 {
390 let Self {
391 mut fac, stream, ..
392 } = self;
393 let ParParams {
394 num_workers,
395 buf_size,
396 } = params.into();
397
398 let stream = stream.map(move |item| fac.generate(item)).spawned(buf_size);
399 let (output_tx, output_rx) = utils::channel(buf_size);
400
401 (0..num_workers).for_each(move |_| {
402 let mut stream = stream.clone();
403 let output_tx = output_tx.clone();
404
405 rt::spawn_blocking(move || {
406 while let Some(func) = rt::block_on(stream.next()) {
407 let output = func();
408 let result = output_tx.send(output);
409 if result.is_err() {
410 break;
411 }
412 }
413 });
414 });
415
416 output_rx.into_stream()
417 }
418
419 pub fn build_ordered_stream<P>(self, params: P) -> OrderedStream<Out>
421 where
422 St: 'static + Send,
423 P: Into<ParParams>,
424 {
425 let Self {
426 mut fac, stream, ..
427 } = self;
428 let ParParams {
429 num_workers,
430 buf_size,
431 } = params.into();
432
433 let stream = stream
434 .map(move |item| fac.generate(item))
435 .enumerate()
436 .spawned(buf_size);
437 let (output_tx, output_rx) = utils::channel(buf_size);
438
439 (0..num_workers).for_each(move |_| {
440 let mut stream = stream.clone();
441 let output_tx = output_tx.clone();
442
443 rt::spawn_blocking(move || {
444 while let Some((index, func)) = rt::block_on(stream.next()) {
445 let output = func();
446 let result = output_tx.send((index, output));
447 if result.is_err() {
448 break;
449 }
450 }
451 });
452 });
453
454 output_rx.into_stream().reorder_enumerated()
455 }
456}
457
458impl<St, Fac> ParBlockingBuilder<St, Fac, ()>
459where
460 St: 'static + Send + Stream,
461 St::Item: 'static + Send,
462 Fac: 'static + Send + FnFactory<St::Item, ()>,
463 Fac::Fn: 'static + Send + FnOnce(),
464{
465 pub async fn for_each<P>(self, params: P)
467 where
468 P: Into<ParParams>,
469 {
470 let Self {
471 mut fac, stream, ..
472 } = self;
473 let ParParams {
474 num_workers,
475 buf_size,
476 } = params.into();
477 let stream = stream.map(move |item| fac.generate(item)).spawned(buf_size);
478
479 let worker_futures = (0..num_workers).map(move |_| {
480 let mut stream = stream.clone();
481
482 rt::spawn_blocking(move || {
483 while let Some(func) = rt::block_on(stream.next()) {
484 func();
485 }
486 })
487 });
488
489 future::join_all(worker_futures).await;
490 }
491}
492
493impl<St, FutFac, FnFac, Out> ParAsyncTailBlockBuilder<St, FutFac, FnFac, Out>
539where
540 St: Stream,
541 St::Item: 'static + Send,
542 FutFac: 'static + Send + FutureFactory<St::Item>,
543 FutFac::Fut: 'static + Send + Future,
544 <FutFac::Fut as Future>::Output: 'static + Send,
545 FnFac: 'static + Send + Clone + FnFactory<<FutFac::Fut as Future>::Output, Out>,
546 FnFac::Fn: 'static + Send + FnOnce() -> Out,
547 Out: 'static + Send,
548{
549 pub fn map_async<NewFac, NewFut>(
551 self,
552 new_fac: NewFac,
553 ) -> ParAsyncBuilder<St, BoxFutureFactory<'static, St::Item, NewFut::Output>>
554 where
555 NewFac: 'static + Send + Clone + FnMut(Out) -> NewFut,
556 NewFut: 'static + Send + Future,
557 NewFut::Output: 'static + Send,
558 {
559 let Self {
560 fut_fac,
561 mut fn_fac,
562 stream,
563 ..
564 } = self;
565
566 let fn_fac_async = move |input: <FutFac::Fut as Future>::Output| {
567 rt::spawn_blocking(fn_fac.generate(input))
568 };
569
570 ParAsyncBuilder {
571 fac: fut_fac.compose(fn_fac_async).compose(new_fac).boxed(),
572 stream,
573 }
574 }
575
576 pub fn map_blocking<NewOut, NewFac, NewFunc>(
578 self,
579 new_fac: NewFac,
580 ) -> ParAsyncTailBlockBuilder<
581 St,
582 FutFac,
583 BoxFnFactory<<FutFac::Fut as Future>::Output, NewOut>,
584 NewOut,
585 >
586 where
587 NewFac: 'static + Send + Clone + FnMut(Out) -> NewFunc,
588 NewFunc: 'static + Send + FnOnce() -> NewOut,
589 NewFunc::Output: 'static + Send,
590 {
591 let Self {
592 fut_fac,
593 fn_fac,
594 stream,
595 ..
596 } = self;
597
598 ParAsyncTailBlockBuilder {
599 fut_fac,
600 fn_fac: fn_fac.chain(new_fac),
601 _phantom: PhantomData,
602 stream,
603 }
604 }
605
606 pub fn build_unordered_stream<P>(self, params: P) -> UnorderedStream<Out>
608 where
609 St: 'static + Send,
610 P: Into<ParParams>,
611 {
612 self.into_async_builder().build_unordered_stream(params)
613 }
614
615 pub fn build_ordered_stream<P>(self, params: P) -> OrderedStream<Out>
617 where
618 St: 'static + Send,
619 P: Into<ParParams>,
620 {
621 self.into_async_builder().build_ordered_stream(params)
622 }
623
624 fn into_async_builder(self) -> ParAsyncBuilder<St, BoxFutureFactory<'static, St::Item, Out>> {
625 let Self {
626 fut_fac,
627 mut fn_fac,
628 stream,
629 ..
630 } = self;
631
632 let fn_fac_async = move |input: <FutFac::Fut as Future>::Output| {
633 rt::spawn_blocking(fn_fac.generate(input))
634 };
635
636 ParAsyncBuilder {
637 fac: fut_fac.compose(fn_fac_async).boxed(),
638 stream,
639 }
640 }
641}
642
643impl<St, FutFac, FnFac> ParAsyncTailBlockBuilder<St, FutFac, FnFac, ()>
644where
645 St: 'static + Send + Stream,
646 St::Item: 'static + Send,
647 FutFac: 'static + Send + FutureFactory<St::Item>,
648 FutFac::Fut: 'static + Send + Future,
649 <FutFac::Fut as Future>::Output: 'static + Send,
650 FnFac: 'static + Send + Clone + FnFactory<<FutFac::Fut as Future>::Output, ()>,
651 FnFac::Fn: 'static + Send + FnOnce(),
652{
653 pub async fn for_each<P>(self, params: P)
655 where
656 P: Into<ParParams>,
657 {
658 self.into_async_builder().for_each(params).await;
659 }
660}
661
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::async_test;
    use std::sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    };

    async_test! {
        async fn par_builder_blocking_test() {
            let vec: Vec<_> = stream::iter(1u64..=1000)
                .par_builder()
                .map_blocking(|val| move || val.pow(5))
                .map_blocking(|val| move || val + 1)
                .build_ordered_stream(None)
                .collect()
                .await;
            let expect: Vec<_> = (1u64..=1000).map(|val| val.pow(5) + 1).collect();

            assert_eq!(vec, expect);
        }

        async fn par_builder_async_test() {
            let vec: Vec<_> = stream::iter(1u64..=1000)
                .par_builder()
                .map_async(|val| async move { val.pow(5) })
                .map_async(|val| async move { val + 1 })
                .build_ordered_stream(None)
                .collect()
                .await;
            let expect: Vec<_> = (1u64..=1000).map(|val| val.pow(5) + 1).collect();

            assert_eq!(vec, expect);
        }

        async fn par_builder_mixed_async_blocking_test() {
            {
                let vec: Vec<_> = stream::iter(1u64..=1000)
                    .par_builder()
                    .map_async(|val| async move { val.pow(5) })
                    .map_blocking(|val| move || val + 1)
                    .build_ordered_stream(None)
                    .collect()
                    .await;
                let expect: Vec<_> = (1u64..=1000).map(|val| val.pow(5) + 1).collect();

                assert_eq!(vec, expect);
            }

            {
                let vec: Vec<_> = stream::iter(1u64..=1000)
                    .par_builder()
                    .map_blocking(|val| move || val.pow(5))
                    .map_async(|val| async move { val + 1 })
                    .build_ordered_stream(None)
                    .collect()
                    .await;
                let expect: Vec<_> = (1u64..=1000).map(|val| val.pow(5) + 1).collect();

                assert_eq!(vec, expect);
            }

            {
                let vec: Vec<_> = stream::iter(1u64..=1000)
                    .par_builder()
                    .map_blocking(|val| move || val.pow(5))
                    .map_async(|val| async move { val + 1 })
                    .map_blocking(|val| move || val / 2)
                    .build_ordered_stream(None)
                    .collect()
                    .await;
                let expect: Vec<_> = (1u64..=1000).map(|val| (val.pow(5) + 1) / 2).collect();

                assert_eq!(vec, expect);
            }

            {
                let vec: Vec<_> = stream::iter(1u64..=1000)
                    .par_builder()
                    .map_async(|val| async move { val.pow(5) })
                    .map_blocking(|val| move || val + 1)
                    .map_async(|val| async move { val / 2 })
                    .build_ordered_stream(None)
                    .collect()
                    .await;
                let expect: Vec<_> = (1u64..=1000).map(|val| (val.pow(5) + 1) / 2).collect();

                assert_eq!(vec, expect);
            }
        }

        // Unordered builders must yield every output exactly once; the order is
        // unspecified, so sort before comparing.
        async fn par_builder_unordered_test() {
            let expect: Vec<_> = (1u64..=1000).map(|val| val.pow(5)).collect();

            {
                let mut vec: Vec<_> = stream::iter(1u64..=1000)
                    .par_builder()
                    .map_async(|val| async move { val.pow(5) })
                    .build_unordered_stream(None)
                    .collect()
                    .await;
                vec.sort_unstable();

                assert_eq!(vec, expect);
            }

            {
                let mut vec: Vec<_> = stream::iter(1u64..=1000)
                    .par_builder()
                    .map_blocking(|val| move || val.pow(5))
                    .build_unordered_stream(None)
                    .collect()
                    .await;
                vec.sort_unstable();

                assert_eq!(vec, expect);
            }
        }

        // `for_each` must process every item before it resolves.
        async fn par_builder_for_each_test() {
            {
                let sum = Arc::new(AtomicU64::new(0));
                let sum_in = sum.clone();
                stream::iter(1u64..=1000)
                    .par_builder()
                    .map_async(move |val| {
                        let sum = sum_in.clone();
                        async move {
                            sum.fetch_add(val, Ordering::SeqCst);
                        }
                    })
                    .for_each(None)
                    .await;

                assert_eq!(sum.load(Ordering::SeqCst), 500500);
            }

            {
                let sum = Arc::new(AtomicU64::new(0));
                let sum_in = sum.clone();
                stream::iter(1u64..=1000)
                    .par_builder()
                    .map_blocking(move |val| {
                        let sum = sum_in.clone();
                        move || {
                            sum.fetch_add(val, Ordering::SeqCst);
                        }
                    })
                    .for_each(None)
                    .await;

                assert_eq!(sum.load(Ordering::SeqCst), 500500);
            }
        }

        // `try_for_each` must surface a worker error and succeed when none occurs.
        async fn par_builder_try_for_each_test() {
            let result = stream::iter(1u64..=1000)
                .par_builder()
                .map_async(|val| async move {
                    if val == 500 {
                        Err(val)
                    } else {
                        Ok(())
                    }
                })
                .try_for_each(None)
                .await;
            assert_eq!(result, Err(500));

            let result = stream::iter(1u64..=1000)
                .par_builder()
                .map_async(|_val| async move { Ok::<(), u64>(()) })
                .try_for_each(None)
                .await;
            assert_eq!(result, Ok(()));
        }
    }
}