amadeus_core/par_sink/group_by.rs

1#![allow(clippy::type_complexity)]
2
3use derive_new::new;
4use educe::Educe;
5use futures::{pin_mut, ready, stream, Stream, StreamExt};
6use indexmap::IndexMap;
7use pin_project::pin_project;
8use serde::{Deserialize, Serialize};
9use std::{
10	hash::Hash, marker::PhantomData, mem, pin::Pin, task::{Context, Poll}
11};
12use sum::Sum2;
13
14use super::{
15	DistributedPipe, DistributedSink, ParallelPipe, ParallelSink, PipeTask, Reducer, ReducerProcessSend, ReducerSend
16};
17use crate::{
18	pipe::{Pipe, Sink, StreamExt as _}, pool::ProcessSend
19};
20
21#[derive(new)]
22#[must_use]
23pub struct GroupBy<A, B> {
24	a: A,
25	b: B,
26}
27
28impl<A: ParallelPipe<Item, Output = (T, U)>, B: ParallelSink<U>, Item, T, U> ParallelSink<Item>
29	for GroupBy<A, B>
30where
31	T: Eq + Hash + Send + 'static,
32	<B::Pipe as ParallelPipe<U>>::Task: Clone + Send + 'static,
33	B::ReduceA: Clone + Send + 'static,
34	B::ReduceC: Clone,
35	B::Done: Send + 'static,
36{
37	type Done = IndexMap<T, B::Done>;
38	type Pipe = A;
39	type ReduceA = GroupByReducerA<<B::Pipe as ParallelPipe<U>>::Task, B::ReduceA, T, U>;
40	type ReduceC = GroupByReducerB<
41		B::ReduceC,
42		T,
43		<B::ReduceA as ReducerSend<<B::Pipe as ParallelPipe<U>>::Output>>::Done,
44	>;
45
46	fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceC) {
47		let (a, b, c) = self.b.reducers();
48		(
49			self.a,
50			GroupByReducerA::new(a.task(), b),
51			GroupByReducerB::new(c),
52		)
53	}
54}
55
56impl<A: DistributedPipe<Item, Output = (T, U)>, B: DistributedSink<U>, Item, T, U>
57	DistributedSink<Item> for GroupBy<A, B>
58where
59	T: Eq + Hash + ProcessSend + 'static,
60	<B::Pipe as DistributedPipe<U>>::Task: Clone + ProcessSend + 'static,
61	B::ReduceA: Clone + ProcessSend + 'static,
62	B::ReduceB: Clone,
63	B::ReduceC: Clone,
64	B::Done: ProcessSend + 'static,
65{
66	type Done = IndexMap<T, B::Done>;
67	type Pipe = A;
68	type ReduceA = GroupByReducerA<<B::Pipe as DistributedPipe<U>>::Task, B::ReduceA, T, U>;
69	type ReduceB = GroupByReducerB<
70		B::ReduceB,
71		T,
72		<B::ReduceA as ReducerSend<<B::Pipe as DistributedPipe<U>>::Output>>::Done,
73	>;
74	type ReduceC = GroupByReducerB<
75		B::ReduceC,
76		T,
77		<B::ReduceB as ReducerProcessSend<
78			<B::ReduceA as Reducer<<B::Pipe as DistributedPipe<U>>::Output>>::Done,
79		>>::Done,
80	>;
81
82	fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceB, Self::ReduceC) {
83		let (a, b, c, d) = self.b.reducers();
84		(
85			self.a,
86			GroupByReducerA::new(a.task(), b),
87			GroupByReducerB::new(c),
88			GroupByReducerB::new(d),
89		)
90	}
91}
92
93#[derive(Educe, Serialize, Deserialize, new)]
94#[educe(Clone(bound = "P: Clone, R: Clone"))]
95#[serde(
96	bound(serialize = "P: Serialize, R: Serialize"),
97	bound(deserialize = "P: Deserialize<'de>, R: Deserialize<'de>")
98)]
99pub struct GroupByReducerA<P, R, T, U>(P, R, PhantomData<fn() -> (T, U)>);
100
101impl<P, R, T, U> Reducer<(T, U)> for GroupByReducerA<P, R, T, U>
102where
103	P: PipeTask<U>,
104	R: Reducer<P::Output> + Clone,
105	T: Eq + Hash,
106{
107	type Done = IndexMap<T, R::Done>;
108	type Async = GroupByReducerAAsync<P::Async, R, T, U>;
109
110	fn into_async(self) -> Self::Async {
111		GroupByReducerAAsync::new(self.0.into_async(), self.1)
112	}
113}
114impl<P, R, T, U> ReducerProcessSend<(T, U)> for GroupByReducerA<P, R, T, U>
115where
116	P: PipeTask<U>,
117	R: Reducer<P::Output> + Clone,
118	T: Eq + Hash + ProcessSend + 'static,
119	R::Done: ProcessSend + 'static,
120{
121	type Done = IndexMap<T, R::Done>;
122}
123impl<P, R, T, U> ReducerSend<(T, U)> for GroupByReducerA<P, R, T, U>
124where
125	P: PipeTask<U>,
126	R: Reducer<P::Output> + Clone,
127	T: Eq + Hash + Send + 'static,
128	R::Done: Send + 'static,
129{
130	type Done = IndexMap<T, R::Done>;
131}
132
133#[pin_project]
134#[derive(new)]
135pub struct GroupByReducerAAsync<P, R, T, U>
136where
137	P: Pipe<U>,
138	R: Reducer<P::Output>,
139{
140	#[pin]
141	pipe: P,
142	factory: R,
143	#[new(default)]
144	pending: Option<Sum2<(T, Option<U>, Option<Pin<Box<R::Async>>>), Vec<Option<R::Done>>>>,
145	#[new(default)]
146	map: IndexMap<T, Pin<Box<R::Async>>>,
147}
148
149impl<P, R, T, U> Sink<(T, U)> for GroupByReducerAAsync<P, R, T, U>
150where
151	P: Pipe<U>,
152	R: Reducer<P::Output> + Clone,
153	T: Eq + Hash,
154{
155	type Done = IndexMap<T, R::Done>;
156
157	#[inline(always)]
158	fn poll_forward(
159		self: Pin<&mut Self>, cx: &mut Context, mut stream: Pin<&mut impl Stream<Item = (T, U)>>,
160	) -> Poll<Self::Done> {
161		let mut self_ = self.project();
162		loop {
163			if !self_.pending.is_some() {
164				*self_.pending = Some(
165					ready!(stream.as_mut().poll_next(cx))
166						.map(|(k, u)| {
167							let r = if !self_.map.contains_key(&k) {
168								Some(Box::pin(self_.factory.clone().into_async()))
169							} else {
170								None
171							};
172							(k, Some(u), r)
173						})
174						.map_or_else(
175							|| Sum2::B((0..self_.map.len()).map(|_| None).collect()),
176							Sum2::A,
177						),
178				);
179			}
180			match self_.pending.as_mut().unwrap() {
181				Sum2::A((k, u, r)) => {
182					let waker = cx.waker();
183					let stream = stream::poll_fn(|cx| {
184						u.take().map_or_else(
185							|| {
186								let waker_ = cx.waker();
187								if !waker.will_wake(waker_) {
188									waker_.wake_by_ref();
189								}
190								Poll::Pending
191							},
192							|u| Poll::Ready(Some(u)),
193						)
194					})
195					.fuse()
196					.pipe(self_.pipe.as_mut());
197					pin_mut!(stream);
198					let map = &mut *self_.map;
199					let r_ = r.as_mut().unwrap_or_else(|| map.get_mut(k).unwrap());
200					if r_.as_mut().poll_forward(cx, stream).is_ready() {
201						let _ = u.take();
202					}
203					if u.is_some() {
204						return Poll::Pending;
205					}
206					let (k, _u, r) = self_.pending.take().unwrap().a().unwrap();
207					if let Some(r) = r {
208						let _ = self_.map.insert(k, r);
209					}
210				}
211				Sum2::B(done) => {
212					let mut done_ = true;
213					self_
214						.map
215						.values_mut()
216						.zip(done.iter_mut())
217						.for_each(|(r, done)| {
218							if done.is_none() {
219								let stream = stream::empty();
220								pin_mut!(stream);
221								if let Poll::Ready(done_) = r.as_mut().poll_forward(cx, stream) {
222									*done = Some(done_);
223								} else {
224									done_ = false;
225								}
226							}
227						});
228					if !done_ {
229						return Poll::Pending;
230					}
231					let ret = mem::take(self_.map)
232						.into_iter()
233						.zip(done.iter_mut())
234						.map(|((k, _), v)| (k, v.take().unwrap()))
235						.collect();
236					return Poll::Ready(ret);
237				}
238			}
239		}
240	}
241}
242
243#[derive(Educe, Serialize, Deserialize, new)]
244#[educe(Clone(bound = "R: Clone"))]
245#[serde(
246	bound(serialize = "R: Serialize"),
247	bound(deserialize = "R: Deserialize<'de>")
248)]
249pub struct GroupByReducerB<R, T, U>(R, PhantomData<fn() -> (T, U)>);
250
251impl<R, T, U> Reducer<IndexMap<T, U>> for GroupByReducerB<R, T, U>
252where
253	R: Reducer<U> + Clone,
254	T: Eq + Hash,
255{
256	type Done = IndexMap<T, R::Done>;
257	type Async = GroupByReducerBAsync<R, T, U>;
258
259	fn into_async(self) -> Self::Async {
260		GroupByReducerBAsync::new(self.0)
261	}
262}
263impl<R, T, U> ReducerProcessSend<IndexMap<T, U>> for GroupByReducerB<R, T, U>
264where
265	R: Reducer<U> + Clone,
266	T: Eq + Hash + ProcessSend + 'static,
267	R::Done: ProcessSend + 'static,
268{
269	type Done = IndexMap<T, R::Done>;
270}
271impl<R, T, U> ReducerSend<IndexMap<T, U>> for GroupByReducerB<R, T, U>
272where
273	R: Reducer<U> + Clone,
274	T: Eq + Hash + Send + 'static,
275	R::Done: Send + 'static,
276{
277	type Done = IndexMap<T, R::Done>;
278}
279
280#[pin_project]
281#[derive(new)]
282pub struct GroupByReducerBAsync<R, T, U>
283where
284	R: Reducer<U>,
285{
286	f: R,
287	#[new(default)]
288	pending: Option<Sum2<IndexMap<T, (U, Option<Pin<Box<R::Async>>>)>, Vec<Option<R::Done>>>>,
289	#[new(default)]
290	map: IndexMap<T, Pin<Box<R::Async>>>,
291}
292
293impl<R, T, U> Sink<IndexMap<T, U>> for GroupByReducerBAsync<R, T, U>
294where
295	R: Reducer<U> + Clone,
296	T: Eq + Hash,
297{
298	type Done = IndexMap<T, R::Done>;
299
300	#[inline(always)]
301	fn poll_forward(
302		self: Pin<&mut Self>, cx: &mut Context,
303		mut stream: Pin<&mut impl Stream<Item = IndexMap<T, U>>>,
304	) -> Poll<Self::Done> {
305		let self_ = self.project();
306		loop {
307			if self_.pending.is_none() {
308				*self_.pending = Some(
309					ready!(stream.as_mut().poll_next(cx))
310						.map(|item| {
311							item.into_iter()
312								.map(|(k, v)| {
313									let r = if !self_.map.contains_key(&k) {
314										Some(Box::pin(self_.f.clone().into_async()))
315									} else {
316										None
317									};
318									(k, (v, r))
319								})
320								.collect()
321						})
322						.map_or_else(
323							|| Sum2::B((0..self_.map.len()).map(|_| None).collect()),
324							Sum2::A,
325						),
326				);
327			}
328			match self_.pending.as_mut().unwrap() {
329				Sum2::A(pending) => {
330					while let Some((k, (v, mut r))) = pending.pop() {
331						let mut v = Some(v);
332						let waker = cx.waker();
333						let stream = stream::poll_fn(|cx| {
334							v.take().map_or_else(
335								|| {
336									let waker_ = cx.waker();
337									if !waker.will_wake(waker_) {
338										waker_.wake_by_ref();
339									}
340									Poll::Pending
341								},
342								|v| Poll::Ready(Some(v)),
343							)
344						})
345						.fuse();
346						pin_mut!(stream);
347						let map = &mut *self_.map;
348						let r_ = r.as_mut().unwrap_or_else(|| map.get_mut(&k).unwrap());
349						if r_.as_mut().poll_forward(cx, stream).is_ready() {
350							let _ = v.take();
351						}
352						if let Some(v) = v {
353							let _ = pending.insert(k, (v, r));
354							return Poll::Pending;
355						}
356						if let Some(r) = r {
357							let _ = self_.map.insert(k, r);
358						}
359					}
360					*self_.pending = None;
361				}
362				Sum2::B(done) => {
363					let mut done_ = true;
364					self_
365						.map
366						.values_mut()
367						.zip(done.iter_mut())
368						.for_each(|(r, done)| {
369							if done.is_none() {
370								let stream = stream::empty();
371								pin_mut!(stream);
372								if let Poll::Ready(done_) = r.as_mut().poll_forward(cx, stream) {
373									*done = Some(done_);
374								} else {
375									done_ = false;
376								}
377							}
378						});
379					if !done_ {
380						return Poll::Pending;
381					}
382					let ret = mem::take(self_.map)
383						.into_iter()
384						.zip(done.iter_mut())
385						.map(|((k, _), v)| (k, v.take().unwrap()))
386						.collect();
387					return Poll::Ready(ret);
388				}
389			}
390		}
391	}
392}