amadeus_core/par_sink/fork.rs

1// TODO: document why this is sound
2#![allow(unsafe_code)]
3#![allow(clippy::type_complexity, clippy::too_many_lines)]
4
5use derive_new::new;
6use futures::{pin_mut, ready, stream, Stream, StreamExt as _};
7use pin_project::pin_project;
8use serde::{Deserialize, Serialize};
9use std::{
10	marker::PhantomData, pin::Pin, task::{Context, Poll}
11};
12use sum::Sum2;
13
14use super::{
15	DistributedPipe, DistributedSink, ParallelPipe, ParallelSink, PipeTask, ReduceA2, ReduceC2
16};
17use crate::{
18	par_stream::{ParallelStream, StreamTask}, pipe::Pipe, util::transmute
19};
20
21#[pin_project]
22#[derive(new)]
23#[must_use]
24pub struct Fork<A, B, C, RefAItem> {
25	#[pin]
26	a: A,
27	b: B,
28	c: C,
29	marker: PhantomData<fn() -> RefAItem>,
30}
31
32impl_par_dist! {
33	impl<A, B, C, RefAItem> ParallelStream for Fork<A, B, C, RefAItem>
34	where
35		A: ParallelStream,
36		B: ParallelPipe<A::Item>,
37		C: ParallelPipe<RefAItem>,
38		RefAItem: 'static,
39	{
40		type Item = Sum2<B::Output, C::Output>;
41		type Task = JoinTask<A::Task, B::Task, C::Task, RefAItem>;
42
43		#[inline(always)]
44		fn size_hint(&self) -> (usize, Option<usize>) {
45			self.a.size_hint()
46		}
47		#[inline(always)]
48		fn next_task(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Task>> {
49			let self_ = self.project();
50			let b = self_.b;
51			let c = self_.c;
52			self_.a.next_task(cx).map(|task| {
53				task.map(|task| JoinTask {
54					stream: task,
55					pipe: b.task(),
56					pipe_ref: c.task(),
57					marker: PhantomData,
58				})
59			})
60		}
61	}
62	impl<A, B, C, Input, RefAItem> ParallelPipe<Input> for Fork<A, B, C, RefAItem>
63	where
64		A: ParallelPipe<Input>,
65		B: ParallelPipe<A::Output>,
66		C: ParallelPipe<RefAItem>,
67		RefAItem: 'static,
68	{
69		type Output = Sum2<B::Output, C::Output>;
70		type Task = JoinTask<A::Task, B::Task, C::Task, RefAItem>;
71
72		#[inline(always)]
73		fn task(&self) -> Self::Task {
74			let stream = self.a.task();
75			let pipe = self.b.task();
76			let pipe_ref = self.c.task();
77			JoinTask {
78				stream,
79				pipe,
80				pipe_ref,
81				marker: PhantomData,
82			}
83		}
84	}
85}
86
87impl<A, B, C, Item, RefAItem> ParallelSink<Item> for Fork<A, B, C, RefAItem>
88where
89	A: ParallelPipe<Item>,
90	B: ParallelSink<A::Output>,
91	C: ParallelSink<RefAItem>,
92	RefAItem: 'static,
93{
94	type Done = (B::Done, C::Done);
95	type Pipe = Fork<A, B::Pipe, C::Pipe, RefAItem>;
96	type ReduceA = ReduceA2<B::ReduceA, C::ReduceA>;
97	type ReduceC = ReduceC2<B::ReduceC, C::ReduceC>;
98
99	#[inline(always)]
100	fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceC) {
101		let (iterator_a, reducer_a_a, reducer_a_c) = self.b.reducers();
102		let (iterator_b, reducer_b_a, reducer_b_c) = self.c.reducers();
103		(
104			Fork::new(self.a, iterator_a, iterator_b),
105			ReduceA2::new(reducer_a_a, reducer_b_a),
106			ReduceC2::new(reducer_a_c, reducer_b_c),
107		)
108	}
109}
110impl<A, B, C, Item, RefAItem> DistributedSink<Item> for Fork<A, B, C, RefAItem>
111where
112	A: DistributedPipe<Item>,
113	B: DistributedSink<A::Output>,
114	C: DistributedSink<RefAItem>,
115	RefAItem: 'static,
116{
117	type Done = (B::Done, C::Done);
118	type Pipe = Fork<A, B::Pipe, C::Pipe, RefAItem>;
119	type ReduceA = ReduceA2<B::ReduceA, C::ReduceA>;
120	type ReduceB = ReduceC2<B::ReduceB, C::ReduceB>;
121	type ReduceC = ReduceC2<B::ReduceC, C::ReduceC>;
122
123	#[inline(always)]
124	fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceB, Self::ReduceC) {
125		let (iterator_a, reducer_a_a, reducer_a_b, reducer_a_c) = self.b.reducers();
126		let (iterator_b, reducer_b_a, reducer_b_b, reducer_b_c) = self.c.reducers();
127		(
128			Fork::new(self.a, iterator_a, iterator_b),
129			ReduceA2::new(reducer_a_a, reducer_b_a),
130			ReduceC2::new(reducer_a_b, reducer_b_b),
131			ReduceC2::new(reducer_a_c, reducer_b_c),
132		)
133	}
134}
135
136#[derive(Serialize, Deserialize)]
137#[serde(
138	bound(serialize = "A: Serialize, B: Serialize, C: Serialize"),
139	bound(deserialize = "A: Deserialize<'de>, B: Deserialize<'de>, C: Deserialize<'de>")
140)]
141pub struct JoinTask<A, B, C, RefAItem> {
142	stream: A,
143	pipe: B,
144	pipe_ref: C,
145	marker: PhantomData<fn() -> RefAItem>,
146}
147impl<A, B, C, RefAItem> StreamTask for JoinTask<A, B, C, RefAItem>
148where
149	A: StreamTask,
150	B: PipeTask<A::Item>,
151	C: PipeTask<RefAItem>,
152{
153	type Item = Sum2<B::Output, C::Output>;
154	type Async = JoinStreamTaskAsync<A::Async, B::Async, C::Async, RefAItem, A::Item>;
155
156	fn into_async(self) -> Self::Async {
157		JoinStreamTaskAsync {
158			stream: self.stream.into_async(),
159			pipe: Some(self.pipe.into_async()),
160			pipe_ref: Some(self.pipe_ref.into_async()),
161			ref_given: false,
162			pending: None,
163			marker: PhantomData,
164		}
165	}
166}
167impl<A, B, C, Input, RefAItem> PipeTask<Input> for JoinTask<A, B, C, RefAItem>
168where
169	A: PipeTask<Input>,
170	B: PipeTask<A::Output>,
171	C: PipeTask<RefAItem>,
172{
173	type Output = Sum2<B::Output, C::Output>;
174	type Async = JoinStreamTaskAsync<A::Async, B::Async, C::Async, RefAItem, A::Output>;
175
176	fn into_async(self) -> Self::Async {
177		JoinStreamTaskAsync {
178			stream: self.stream.into_async(),
179			pipe: Some(self.pipe.into_async()),
180			pipe_ref: Some(self.pipe_ref.into_async()),
181			ref_given: false,
182			pending: None,
183			marker: PhantomData,
184		}
185	}
186}
187
188#[pin_project(project = JoinStreamTaskAsyncProj)]
189pub struct JoinStreamTaskAsync<A, B, C, RefAItem, T> {
190	#[pin]
191	stream: A,
192	#[pin]
193	pipe: Option<B>,
194	#[pin]
195	pipe_ref: Option<C>,
196	ref_given: bool,
197	pending: Option<Option<T>>,
198	marker: PhantomData<fn() -> RefAItem>,
199}
200
201impl<'a, A, B, C, AItem, RefAItem> JoinStreamTaskAsyncProj<'a, A, B, C, RefAItem, AItem>
202where
203	B: Pipe<AItem>,
204	C: Pipe<RefAItem>,
205{
206	// TODO: fairness
207	fn poll(&mut self, cx: &mut Context) -> Option<Poll<Option<Sum2<B::Output, C::Output>>>> {
208		if let pending @ Some(_) = self.pending.as_mut().unwrap() {
209			let ref_given = &mut *self.ref_given;
210			{
211				let waker = cx.waker();
212				let stream = stream::poll_fn(|cx| {
213					if !*ref_given {
214						*ref_given = true;
215						return Poll::Ready(Some(unsafe { transmute(pending.as_ref()) }));
216					}
217					let waker_ = cx.waker();
218					if !waker.will_wake(waker_) {
219						waker_.wake_by_ref();
220					}
221					Poll::Pending
222				})
223				.fuse();
224				pin_mut!(stream);
225				match self
226					.pipe_ref
227					.as_mut()
228					.as_pin_mut()
229					.map(|pipe_ref| pipe_ref.poll_next(cx, stream))
230				{
231					Some(Poll::Ready(Some(item))) => {
232						return Some(Poll::Ready(Some(Sum2::B(unsafe { transmute(item) }))))
233					}
234					Some(Poll::Ready(None)) | None => {
235						self.pipe_ref.set(None);
236						*ref_given = true
237					}
238					Some(Poll::Pending) => (),
239				}
240			}
241			let mut item_given = false;
242			{
243				let waker = cx.waker();
244				let stream = stream::poll_fn(|cx| {
245					if *ref_given {
246						item_given = true;
247						return Poll::Ready(Some(pending.take().unwrap()));
248					}
249					let waker_ = cx.waker();
250					if !waker.will_wake(waker_) {
251						waker_.wake_by_ref();
252					}
253					Poll::Pending
254				})
255				.fuse();
256				pin_mut!(stream);
257				let res = self
258					.pipe
259					.as_mut()
260					.as_pin_mut()
261					.map(|pipe| pipe.poll_next(cx, stream));
262				if item_given || matches!(res, Some(Poll::Ready(None)) | None) {
263					*ref_given = false;
264					*self.pending = None;
265				}
266				match res {
267					Some(Poll::Ready(Some(item))) => {
268						return Some(Poll::Ready(Some(Sum2::A(item))));
269					}
270					Some(Poll::Ready(None)) | None => self.pipe.set(None),
271					Some(Poll::Pending) => (),
272				}
273			}
274			if self.pending.is_some() {
275				return Some(Poll::Pending);
276			}
277			None
278		} else {
279			let stream = stream::empty();
280			pin_mut!(stream);
281			match self
282				.pipe_ref
283				.as_mut()
284				.as_pin_mut()
285				.map(|pipe_ref| pipe_ref.poll_next(cx, stream))
286			{
287				Some(Poll::Ready(Some(item))) => {
288					return Some(Poll::Ready(Some(Sum2::B(unsafe { transmute(item) }))))
289				}
290				Some(Poll::Ready(None)) => self.pipe_ref.set(None),
291				Some(Poll::Pending) | None => (),
292			}
293			let stream = stream::empty();
294			pin_mut!(stream);
295			match self
296				.pipe
297				.as_mut()
298				.as_pin_mut()
299				.map(|pipe| pipe.poll_next(cx, stream))
300			{
301				Some(Poll::Ready(Some(item))) => return Some(Poll::Ready(Some(Sum2::A(item)))),
302				Some(Poll::Ready(None)) => self.pipe.set(None),
303				Some(Poll::Pending) | None => (),
304			}
305			Some(if self.pipe_ref.is_none() && self.pipe.is_none() {
306				Poll::Ready(None)
307			} else {
308				Poll::Pending
309			})
310		}
311	}
312}
313
314impl<A, B, C, RefAItem> Stream for JoinStreamTaskAsync<A, B, C, RefAItem, A::Item>
315where
316	A: Stream,
317	B: Pipe<A::Item>,
318	C: Pipe<RefAItem>,
319{
320	type Item = Sum2<B::Output, C::Output>;
321
322	fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
323		let mut self_ = self.project();
324		loop {
325			if self_.pending.is_none() {
326				*self_.pending = Some(ready!(self_.stream.as_mut().poll_next(cx)));
327			}
328			if let Some(ret) = self_.poll(cx) {
329				break ret;
330			}
331		}
332	}
333}
334
335impl<A, B, C, Input, RefAItem> Pipe<Input> for JoinStreamTaskAsync<A, B, C, RefAItem, A::Output>
336where
337	A: Pipe<Input>,
338	B: Pipe<A::Output>,
339	C: Pipe<RefAItem>,
340{
341	type Output = Sum2<B::Output, C::Output>;
342
343	fn poll_next(
344		self: Pin<&mut Self>, cx: &mut Context, mut stream: Pin<&mut impl Stream<Item = Input>>,
345	) -> Poll<Option<Self::Output>> {
346		let mut self_ = self.project();
347		loop {
348			if self_.pending.is_none() {
349				*self_.pending = Some(ready!(self_.stream.as_mut().poll_next(cx, stream.as_mut())));
350			}
351			if let Some(ret) = self_.poll(cx) {
352				break ret;
353			}
354		}
355	}
356}