1#![allow(unsafe_code)]
3#![allow(clippy::type_complexity, clippy::too_many_lines)]
4
5use derive_new::new;
6use futures::{pin_mut, ready, stream, Stream, StreamExt as _};
7use pin_project::pin_project;
8use serde::{Deserialize, Serialize};
9use std::{
10 marker::PhantomData, pin::Pin, task::{Context, Poll}
11};
12use sum::Sum2;
13
14use super::{
15 DistributedPipe, DistributedSink, ParallelPipe, ParallelSink, PipeTask, ReduceA2, ReduceC2
16};
17use crate::{
18 par_stream::{ParallelStream, StreamTask}, pipe::Pipe, util::transmute
19};
20
/// Combinator that runs two branches over a single upstream.
///
/// For every item produced by `a`, the `c` branch is handed a reference to the item (typed as
/// `RefAItem`) and the `b` branch then receives the item by value. Outputs of `b` are emitted
/// as `Sum2::A` and outputs of `c` as `Sum2::B`.
///
/// NOTE(review): `RefAItem` is presumably a reference-to-item type; nothing here checks that —
/// the reference is produced via an unsafe `transmute` in the async polling logic.
#[pin_project]
#[derive(new)]
#[must_use]
pub struct Fork<A, B, C, RefAItem> {
    /// Upstream stream or pipe whose items are forked.
    #[pin]
    a: A,
    /// Branch that receives each upstream item by value.
    b: B,
    /// Branch that receives a reference to each upstream item.
    c: C,
    // `fn() -> RefAItem` records the type parameter without storing a value of it.
    marker: PhantomData<fn() -> RefAItem>,
}
31
32impl_par_dist! {
33 impl<A, B, C, RefAItem> ParallelStream for Fork<A, B, C, RefAItem>
34 where
35 A: ParallelStream,
36 B: ParallelPipe<A::Item>,
37 C: ParallelPipe<RefAItem>,
38 RefAItem: 'static,
39 {
40 type Item = Sum2<B::Output, C::Output>;
41 type Task = JoinTask<A::Task, B::Task, C::Task, RefAItem>;
42
43 #[inline(always)]
44 fn size_hint(&self) -> (usize, Option<usize>) {
45 self.a.size_hint()
46 }
47 #[inline(always)]
48 fn next_task(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Task>> {
49 let self_ = self.project();
50 let b = self_.b;
51 let c = self_.c;
52 self_.a.next_task(cx).map(|task| {
53 task.map(|task| JoinTask {
54 stream: task,
55 pipe: b.task(),
56 pipe_ref: c.task(),
57 marker: PhantomData,
58 })
59 })
60 }
61 }
62 impl<A, B, C, Input, RefAItem> ParallelPipe<Input> for Fork<A, B, C, RefAItem>
63 where
64 A: ParallelPipe<Input>,
65 B: ParallelPipe<A::Output>,
66 C: ParallelPipe<RefAItem>,
67 RefAItem: 'static,
68 {
69 type Output = Sum2<B::Output, C::Output>;
70 type Task = JoinTask<A::Task, B::Task, C::Task, RefAItem>;
71
72 #[inline(always)]
73 fn task(&self) -> Self::Task {
74 let stream = self.a.task();
75 let pipe = self.b.task();
76 let pipe_ref = self.c.task();
77 JoinTask {
78 stream,
79 pipe,
80 pipe_ref,
81 marker: PhantomData,
82 }
83 }
84 }
85}
86
87impl<A, B, C, Item, RefAItem> ParallelSink<Item> for Fork<A, B, C, RefAItem>
88where
89 A: ParallelPipe<Item>,
90 B: ParallelSink<A::Output>,
91 C: ParallelSink<RefAItem>,
92 RefAItem: 'static,
93{
94 type Done = (B::Done, C::Done);
95 type Pipe = Fork<A, B::Pipe, C::Pipe, RefAItem>;
96 type ReduceA = ReduceA2<B::ReduceA, C::ReduceA>;
97 type ReduceC = ReduceC2<B::ReduceC, C::ReduceC>;
98
99 #[inline(always)]
100 fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceC) {
101 let (iterator_a, reducer_a_a, reducer_a_c) = self.b.reducers();
102 let (iterator_b, reducer_b_a, reducer_b_c) = self.c.reducers();
103 (
104 Fork::new(self.a, iterator_a, iterator_b),
105 ReduceA2::new(reducer_a_a, reducer_b_a),
106 ReduceC2::new(reducer_a_c, reducer_b_c),
107 )
108 }
109}
110impl<A, B, C, Item, RefAItem> DistributedSink<Item> for Fork<A, B, C, RefAItem>
111where
112 A: DistributedPipe<Item>,
113 B: DistributedSink<A::Output>,
114 C: DistributedSink<RefAItem>,
115 RefAItem: 'static,
116{
117 type Done = (B::Done, C::Done);
118 type Pipe = Fork<A, B::Pipe, C::Pipe, RefAItem>;
119 type ReduceA = ReduceA2<B::ReduceA, C::ReduceA>;
120 type ReduceB = ReduceC2<B::ReduceB, C::ReduceB>;
121 type ReduceC = ReduceC2<B::ReduceC, C::ReduceC>;
122
123 #[inline(always)]
124 fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceB, Self::ReduceC) {
125 let (iterator_a, reducer_a_a, reducer_a_b, reducer_a_c) = self.b.reducers();
126 let (iterator_b, reducer_b_a, reducer_b_b, reducer_b_c) = self.c.reducers();
127 (
128 Fork::new(self.a, iterator_a, iterator_b),
129 ReduceA2::new(reducer_a_a, reducer_b_a),
130 ReduceC2::new(reducer_a_b, reducer_b_b),
131 ReduceC2::new(reducer_a_c, reducer_b_c),
132 )
133 }
134}
135
/// Task produced by [`Fork`]: one task for the upstream plus one task for each branch.
///
/// The explicit serde bounds mention only `A`, `B` and `C`, so `RefAItem` (held only inside a
/// `PhantomData`) does not itself need to implement `Serialize`/`Deserialize`.
#[derive(Serialize, Deserialize)]
#[serde(
    bound(serialize = "A: Serialize, B: Serialize, C: Serialize"),
    bound(deserialize = "A: Deserialize<'de>, B: Deserialize<'de>, C: Deserialize<'de>")
)]
pub struct JoinTask<A, B, C, RefAItem> {
    /// Task of the upstream stream or pipe.
    stream: A,
    /// Task of the branch that receives items by value.
    pipe: B,
    /// Task of the branch that receives references to items.
    pipe_ref: C,
    // `fn() -> RefAItem` records the type parameter without storing a value of it.
    marker: PhantomData<fn() -> RefAItem>,
}
147impl<A, B, C, RefAItem> StreamTask for JoinTask<A, B, C, RefAItem>
148where
149 A: StreamTask,
150 B: PipeTask<A::Item>,
151 C: PipeTask<RefAItem>,
152{
153 type Item = Sum2<B::Output, C::Output>;
154 type Async = JoinStreamTaskAsync<A::Async, B::Async, C::Async, RefAItem, A::Item>;
155
156 fn into_async(self) -> Self::Async {
157 JoinStreamTaskAsync {
158 stream: self.stream.into_async(),
159 pipe: Some(self.pipe.into_async()),
160 pipe_ref: Some(self.pipe_ref.into_async()),
161 ref_given: false,
162 pending: None,
163 marker: PhantomData,
164 }
165 }
166}
167impl<A, B, C, Input, RefAItem> PipeTask<Input> for JoinTask<A, B, C, RefAItem>
168where
169 A: PipeTask<Input>,
170 B: PipeTask<A::Output>,
171 C: PipeTask<RefAItem>,
172{
173 type Output = Sum2<B::Output, C::Output>;
174 type Async = JoinStreamTaskAsync<A::Async, B::Async, C::Async, RefAItem, A::Output>;
175
176 fn into_async(self) -> Self::Async {
177 JoinStreamTaskAsync {
178 stream: self.stream.into_async(),
179 pipe: Some(self.pipe.into_async()),
180 pipe_ref: Some(self.pipe_ref.into_async()),
181 ref_given: false,
182 pending: None,
183 marker: PhantomData,
184 }
185 }
186}
187
/// Async state machine behind a forked stream or pipe.
///
/// It pulls items from the upstream `stream` one at a time and feeds each one first to
/// `pipe_ref` (as a reference) and then to `pipe` (by value).
#[pin_project(project = JoinStreamTaskAsyncProj)]
pub struct JoinStreamTaskAsync<A, B, C, RefAItem, T> {
    /// Upstream stream or pipe.
    #[pin]
    stream: A,
    /// By-value branch; set to `None` once it has returned `Ready(None)`.
    #[pin]
    pipe: Option<B>,
    /// By-reference branch; set to `None` once it has returned `Ready(None)`.
    #[pin]
    pipe_ref: Option<C>,
    /// Whether the current pending item has already been offered to `pipe_ref`.
    /// Also set once `pipe_ref` has finished and no longer needs it.
    ref_given: bool,
    /// Distribution state:
    /// * `None`: the next upstream item must be fetched.
    /// * `Some(Some(item))`: `item` is being distributed.
    /// * `Some(None)`: the upstream is exhausted.
    pending: Option<Option<T>>,
    // `fn() -> RefAItem` records the type parameter without storing a value of it.
    marker: PhantomData<fn() -> RefAItem>,
}
200
201impl<'a, A, B, C, AItem, RefAItem> JoinStreamTaskAsyncProj<'a, A, B, C, RefAItem, AItem>
202where
203 B: Pipe<AItem>,
204 C: Pipe<RefAItem>,
205{
206 fn poll(&mut self, cx: &mut Context) -> Option<Poll<Option<Sum2<B::Output, C::Output>>>> {
208 if let pending @ Some(_) = self.pending.as_mut().unwrap() {
209 let ref_given = &mut *self.ref_given;
210 {
211 let waker = cx.waker();
212 let stream = stream::poll_fn(|cx| {
213 if !*ref_given {
214 *ref_given = true;
215 return Poll::Ready(Some(unsafe { transmute(pending.as_ref()) }));
216 }
217 let waker_ = cx.waker();
218 if !waker.will_wake(waker_) {
219 waker_.wake_by_ref();
220 }
221 Poll::Pending
222 })
223 .fuse();
224 pin_mut!(stream);
225 match self
226 .pipe_ref
227 .as_mut()
228 .as_pin_mut()
229 .map(|pipe_ref| pipe_ref.poll_next(cx, stream))
230 {
231 Some(Poll::Ready(Some(item))) => {
232 return Some(Poll::Ready(Some(Sum2::B(unsafe { transmute(item) }))))
233 }
234 Some(Poll::Ready(None)) | None => {
235 self.pipe_ref.set(None);
236 *ref_given = true
237 }
238 Some(Poll::Pending) => (),
239 }
240 }
241 let mut item_given = false;
242 {
243 let waker = cx.waker();
244 let stream = stream::poll_fn(|cx| {
245 if *ref_given {
246 item_given = true;
247 return Poll::Ready(Some(pending.take().unwrap()));
248 }
249 let waker_ = cx.waker();
250 if !waker.will_wake(waker_) {
251 waker_.wake_by_ref();
252 }
253 Poll::Pending
254 })
255 .fuse();
256 pin_mut!(stream);
257 let res = self
258 .pipe
259 .as_mut()
260 .as_pin_mut()
261 .map(|pipe| pipe.poll_next(cx, stream));
262 if item_given || matches!(res, Some(Poll::Ready(None)) | None) {
263 *ref_given = false;
264 *self.pending = None;
265 }
266 match res {
267 Some(Poll::Ready(Some(item))) => {
268 return Some(Poll::Ready(Some(Sum2::A(item))));
269 }
270 Some(Poll::Ready(None)) | None => self.pipe.set(None),
271 Some(Poll::Pending) => (),
272 }
273 }
274 if self.pending.is_some() {
275 return Some(Poll::Pending);
276 }
277 None
278 } else {
279 let stream = stream::empty();
280 pin_mut!(stream);
281 match self
282 .pipe_ref
283 .as_mut()
284 .as_pin_mut()
285 .map(|pipe_ref| pipe_ref.poll_next(cx, stream))
286 {
287 Some(Poll::Ready(Some(item))) => {
288 return Some(Poll::Ready(Some(Sum2::B(unsafe { transmute(item) }))))
289 }
290 Some(Poll::Ready(None)) => self.pipe_ref.set(None),
291 Some(Poll::Pending) | None => (),
292 }
293 let stream = stream::empty();
294 pin_mut!(stream);
295 match self
296 .pipe
297 .as_mut()
298 .as_pin_mut()
299 .map(|pipe| pipe.poll_next(cx, stream))
300 {
301 Some(Poll::Ready(Some(item))) => return Some(Poll::Ready(Some(Sum2::A(item)))),
302 Some(Poll::Ready(None)) => self.pipe.set(None),
303 Some(Poll::Pending) | None => (),
304 }
305 Some(if self.pipe_ref.is_none() && self.pipe.is_none() {
306 Poll::Ready(None)
307 } else {
308 Poll::Pending
309 })
310 }
311 }
312}
313
314impl<A, B, C, RefAItem> Stream for JoinStreamTaskAsync<A, B, C, RefAItem, A::Item>
315where
316 A: Stream,
317 B: Pipe<A::Item>,
318 C: Pipe<RefAItem>,
319{
320 type Item = Sum2<B::Output, C::Output>;
321
322 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
323 let mut self_ = self.project();
324 loop {
325 if self_.pending.is_none() {
326 *self_.pending = Some(ready!(self_.stream.as_mut().poll_next(cx)));
327 }
328 if let Some(ret) = self_.poll(cx) {
329 break ret;
330 }
331 }
332 }
333}
334
335impl<A, B, C, Input, RefAItem> Pipe<Input> for JoinStreamTaskAsync<A, B, C, RefAItem, A::Output>
336where
337 A: Pipe<Input>,
338 B: Pipe<A::Output>,
339 C: Pipe<RefAItem>,
340{
341 type Output = Sum2<B::Output, C::Output>;
342
343 fn poll_next(
344 self: Pin<&mut Self>, cx: &mut Context, mut stream: Pin<&mut impl Stream<Item = Input>>,
345 ) -> Poll<Option<Self::Output>> {
346 let mut self_ = self.project();
347 loop {
348 if self_.pending.is_none() {
349 *self_.pending = Some(ready!(self_.stream.as_mut().poll_next(cx, stream.as_mut())));
350 }
351 if let Some(ret) = self_.poll(cx) {
352 break ret;
353 }
354 }
355 }
356}