1#![allow(clippy::type_complexity)]
2
3use derive_new::new;
4use educe::Educe;
5use futures::{pin_mut, ready, stream, Stream, StreamExt};
6use indexmap::IndexMap;
7use pin_project::pin_project;
8use serde::{Deserialize, Serialize};
9use std::{
10 hash::Hash, marker::PhantomData, mem, pin::Pin, task::{Context, Poll}
11};
12use sum::Sum2;
13
14use super::{
15 DistributedPipe, DistributedSink, ParallelPipe, ParallelSink, PipeTask, Reducer, ReducerProcessSend, ReducerSend
16};
17use crate::{
18 pipe::{Pipe, Sink, StreamExt as _}, pool::ProcessSend
19};
20
/// Sink adaptor pairing a pipe `a` with an inner sink `b`.
///
/// The sink implementations in this file require `a` to output `(T, U)` pairs and `b` to
/// accept `U` values; the final result is an `IndexMap<T, B::Done>`, i.e. one result of `b`
/// per distinct key `T`.
///
/// The `new(a, b)` constructor is generated by `derive_new`.
#[derive(new)]
#[must_use]
pub struct GroupBy<A, B> {
    /// Pipe applied to the incoming items; returned unchanged as the sink's `Pipe`.
    a: A,
    /// Inner sink whose pipe task and reducers are reused for every key.
    b: B,
}
27
28impl<A: ParallelPipe<Item, Output = (T, U)>, B: ParallelSink<U>, Item, T, U> ParallelSink<Item>
29 for GroupBy<A, B>
30where
31 T: Eq + Hash + Send + 'static,
32 <B::Pipe as ParallelPipe<U>>::Task: Clone + Send + 'static,
33 B::ReduceA: Clone + Send + 'static,
34 B::ReduceC: Clone,
35 B::Done: Send + 'static,
36{
37 type Done = IndexMap<T, B::Done>;
38 type Pipe = A;
39 type ReduceA = GroupByReducerA<<B::Pipe as ParallelPipe<U>>::Task, B::ReduceA, T, U>;
40 type ReduceC = GroupByReducerB<
41 B::ReduceC,
42 T,
43 <B::ReduceA as ReducerSend<<B::Pipe as ParallelPipe<U>>::Output>>::Done,
44 >;
45
46 fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceC) {
47 let (a, b, c) = self.b.reducers();
48 (
49 self.a,
50 GroupByReducerA::new(a.task(), b),
51 GroupByReducerB::new(c),
52 )
53 }
54}
55
56impl<A: DistributedPipe<Item, Output = (T, U)>, B: DistributedSink<U>, Item, T, U>
57 DistributedSink<Item> for GroupBy<A, B>
58where
59 T: Eq + Hash + ProcessSend + 'static,
60 <B::Pipe as DistributedPipe<U>>::Task: Clone + ProcessSend + 'static,
61 B::ReduceA: Clone + ProcessSend + 'static,
62 B::ReduceB: Clone,
63 B::ReduceC: Clone,
64 B::Done: ProcessSend + 'static,
65{
66 type Done = IndexMap<T, B::Done>;
67 type Pipe = A;
68 type ReduceA = GroupByReducerA<<B::Pipe as DistributedPipe<U>>::Task, B::ReduceA, T, U>;
69 type ReduceB = GroupByReducerB<
70 B::ReduceB,
71 T,
72 <B::ReduceA as ReducerSend<<B::Pipe as DistributedPipe<U>>::Output>>::Done,
73 >;
74 type ReduceC = GroupByReducerB<
75 B::ReduceC,
76 T,
77 <B::ReduceB as ReducerProcessSend<
78 <B::ReduceA as Reducer<<B::Pipe as DistributedPipe<U>>::Output>>::Done,
79 >>::Done,
80 >;
81
82 fn reducers(self) -> (Self::Pipe, Self::ReduceA, Self::ReduceB, Self::ReduceC) {
83 let (a, b, c, d) = self.b.reducers();
84 (
85 self.a,
86 GroupByReducerA::new(a.task(), b),
87 GroupByReducerB::new(c),
88 GroupByReducerB::new(d),
89 )
90 }
91}
92
/// First-stage group-by reducer over `(T, U)` pairs.
///
/// Field `0` is the pipe task `P` applied to the `U` values, field `1` is the reducer `R`
/// that is cloned for each key. `T` and `U` occur only in the `PhantomData` marker, so the
/// `Clone` and serde bounds below are placed on `P` and `R` alone.
#[derive(Educe, Serialize, Deserialize, new)]
#[educe(Clone(bound = "P: Clone, R: Clone"))]
#[serde(
    bound(serialize = "P: Serialize, R: Serialize"),
    bound(deserialize = "P: Deserialize<'de>, R: Deserialize<'de>")
)]
pub struct GroupByReducerA<P, R, T, U>(P, R, PhantomData<fn() -> (T, U)>);
100
101impl<P, R, T, U> Reducer<(T, U)> for GroupByReducerA<P, R, T, U>
102where
103 P: PipeTask<U>,
104 R: Reducer<P::Output> + Clone,
105 T: Eq + Hash,
106{
107 type Done = IndexMap<T, R::Done>;
108 type Async = GroupByReducerAAsync<P::Async, R, T, U>;
109
110 fn into_async(self) -> Self::Async {
111 GroupByReducerAAsync::new(self.0.into_async(), self.1)
112 }
113}
/// Marks [`GroupByReducerA`] as a `ReducerProcessSend` when both the key type and the
/// per-key result are `ProcessSend + 'static`. `Done` is the same map type as in the
/// `Reducer` impl.
impl<P, R, T, U> ReducerProcessSend<(T, U)> for GroupByReducerA<P, R, T, U>
where
    P: PipeTask<U>,
    R: Reducer<P::Output> + Clone,
    T: Eq + Hash + ProcessSend + 'static,
    R::Done: ProcessSend + 'static,
{
    type Done = IndexMap<T, R::Done>;
}
/// Marks [`GroupByReducerA`] as a `ReducerSend` when both the key type and the per-key
/// result are `Send + 'static`. `Done` is the same map type as in the `Reducer` impl.
impl<P, R, T, U> ReducerSend<(T, U)> for GroupByReducerA<P, R, T, U>
where
    P: PipeTask<U>,
    R: Reducer<P::Output> + Clone,
    T: Eq + Hash + Send + 'static,
    R::Done: Send + 'static,
{
    type Done = IndexMap<T, R::Done>;
}
132
/// Async state of [`GroupByReducerA`]; driven by its `Sink<(T, U)>` implementation.
///
/// `new(pipe, factory)` is generated by `derive_new`; `pending` and `map` start out as their
/// `Default` values.
#[pin_project]
#[derive(new)]
pub struct GroupByReducerAAsync<P, R, T, U>
where
    P: Pipe<U>,
    R: Reducer<P::Output>,
{
    /// Pipe every value is passed through before it reaches a per-key reducer.
    #[pin]
    pipe: P,
    /// Template reducer; cloned and converted with `into_async` for each new key.
    factory: R,
    /// Work in flight:
    /// * `None` — nothing buffered; the next poll pulls from the input stream.
    /// * `Some(Sum2::A((key, value, reducer)))` — one pair being delivered. `value` is `Some`
    ///   until it has been handed over; `reducer` is `Some` only when `key` had no entry in
    ///   `map` at the time the pair was pulled.
    /// * `Some(Sum2::B(results))` — the input stream has ended; `results` has one slot per
    ///   entry of `map`, filled as the reducers complete.
    #[new(default)]
    pending: Option<Sum2<(T, Option<U>, Option<Pin<Box<R::Async>>>), Vec<Option<R::Done>>>>,
    /// Async reducer for each key seen so far.
    #[new(default)]
    map: IndexMap<T, Pin<Box<R::Async>>>,
}
148
149impl<P, R, T, U> Sink<(T, U)> for GroupByReducerAAsync<P, R, T, U>
150where
151 P: Pipe<U>,
152 R: Reducer<P::Output> + Clone,
153 T: Eq + Hash,
154{
155 type Done = IndexMap<T, R::Done>;
156
157 #[inline(always)]
158 fn poll_forward(
159 self: Pin<&mut Self>, cx: &mut Context, mut stream: Pin<&mut impl Stream<Item = (T, U)>>,
160 ) -> Poll<Self::Done> {
161 let mut self_ = self.project();
162 loop {
163 if !self_.pending.is_some() {
164 *self_.pending = Some(
165 ready!(stream.as_mut().poll_next(cx))
166 .map(|(k, u)| {
167 let r = if !self_.map.contains_key(&k) {
168 Some(Box::pin(self_.factory.clone().into_async()))
169 } else {
170 None
171 };
172 (k, Some(u), r)
173 })
174 .map_or_else(
175 || Sum2::B((0..self_.map.len()).map(|_| None).collect()),
176 Sum2::A,
177 ),
178 );
179 }
180 match self_.pending.as_mut().unwrap() {
181 Sum2::A((k, u, r)) => {
182 let waker = cx.waker();
183 let stream = stream::poll_fn(|cx| {
184 u.take().map_or_else(
185 || {
186 let waker_ = cx.waker();
187 if !waker.will_wake(waker_) {
188 waker_.wake_by_ref();
189 }
190 Poll::Pending
191 },
192 |u| Poll::Ready(Some(u)),
193 )
194 })
195 .fuse()
196 .pipe(self_.pipe.as_mut());
197 pin_mut!(stream);
198 let map = &mut *self_.map;
199 let r_ = r.as_mut().unwrap_or_else(|| map.get_mut(k).unwrap());
200 if r_.as_mut().poll_forward(cx, stream).is_ready() {
201 let _ = u.take();
202 }
203 if u.is_some() {
204 return Poll::Pending;
205 }
206 let (k, _u, r) = self_.pending.take().unwrap().a().unwrap();
207 if let Some(r) = r {
208 let _ = self_.map.insert(k, r);
209 }
210 }
211 Sum2::B(done) => {
212 let mut done_ = true;
213 self_
214 .map
215 .values_mut()
216 .zip(done.iter_mut())
217 .for_each(|(r, done)| {
218 if done.is_none() {
219 let stream = stream::empty();
220 pin_mut!(stream);
221 if let Poll::Ready(done_) = r.as_mut().poll_forward(cx, stream) {
222 *done = Some(done_);
223 } else {
224 done_ = false;
225 }
226 }
227 });
228 if !done_ {
229 return Poll::Pending;
230 }
231 let ret = mem::take(self_.map)
232 .into_iter()
233 .zip(done.iter_mut())
234 .map(|((k, _), v)| (k, v.take().unwrap()))
235 .collect();
236 return Poll::Ready(ret);
237 }
238 }
239 }
240 }
241}
242
/// Later-stage group-by reducer that merges `IndexMap<T, U>` partial results.
///
/// Field `0` is the reducer `R` that is cloned for each key. `T` and `U` occur only in the
/// `PhantomData` marker, so the `Clone` and serde bounds below are placed on `R` alone.
#[derive(Educe, Serialize, Deserialize, new)]
#[educe(Clone(bound = "R: Clone"))]
#[serde(
    bound(serialize = "R: Serialize"),
    bound(deserialize = "R: Deserialize<'de>")
)]
pub struct GroupByReducerB<R, T, U>(R, PhantomData<fn() -> (T, U)>);
250
251impl<R, T, U> Reducer<IndexMap<T, U>> for GroupByReducerB<R, T, U>
252where
253 R: Reducer<U> + Clone,
254 T: Eq + Hash,
255{
256 type Done = IndexMap<T, R::Done>;
257 type Async = GroupByReducerBAsync<R, T, U>;
258
259 fn into_async(self) -> Self::Async {
260 GroupByReducerBAsync::new(self.0)
261 }
262}
/// Marks [`GroupByReducerB`] as a `ReducerProcessSend` when both the key type and the
/// per-key result are `ProcessSend + 'static`. `Done` is the same map type as in the
/// `Reducer` impl.
impl<R, T, U> ReducerProcessSend<IndexMap<T, U>> for GroupByReducerB<R, T, U>
where
    R: Reducer<U> + Clone,
    T: Eq + Hash + ProcessSend + 'static,
    R::Done: ProcessSend + 'static,
{
    type Done = IndexMap<T, R::Done>;
}
/// Marks [`GroupByReducerB`] as a `ReducerSend` when both the key type and the per-key
/// result are `Send + 'static`. `Done` is the same map type as in the `Reducer` impl.
impl<R, T, U> ReducerSend<IndexMap<T, U>> for GroupByReducerB<R, T, U>
where
    R: Reducer<U> + Clone,
    T: Eq + Hash + Send + 'static,
    R::Done: Send + 'static,
{
    type Done = IndexMap<T, R::Done>;
}
279
/// Async state of [`GroupByReducerB`]; driven by its `Sink<IndexMap<T, U>>` implementation.
///
/// `new(f)` is generated by `derive_new`; `pending` and `map` start out as their `Default`
/// values.
#[pin_project]
#[derive(new)]
pub struct GroupByReducerBAsync<R, T, U>
where
    R: Reducer<U>,
{
    /// Template reducer; cloned and converted with `into_async` for each new key.
    f: R,
    /// Work in flight:
    /// * `None` — nothing buffered; the next poll pulls from the input stream.
    /// * `Some(Sum2::A(entries))` — entries of one incoming map that still have to be
    ///   delivered. Each entry holds its value and, when the key had no entry in `map` at
    ///   the time the incoming map was pulled, a freshly created reducer.
    /// * `Some(Sum2::B(results))` — the input stream has ended; `results` has one slot per
    ///   entry of `map`, filled as the reducers complete.
    #[new(default)]
    pending: Option<Sum2<IndexMap<T, (U, Option<Pin<Box<R::Async>>>)>, Vec<Option<R::Done>>>>,
    /// Async reducer for each key seen so far.
    #[new(default)]
    map: IndexMap<T, Pin<Box<R::Async>>>,
}
292
/// Drives the per-key reducers of [`GroupByReducerBAsync`] from a stream of
/// `IndexMap<T, U>` values.
impl<R, T, U> Sink<IndexMap<T, U>> for GroupByReducerBAsync<R, T, U>
where
    R: Reducer<U> + Clone,
    T: Eq + Hash,
{
    type Done = IndexMap<T, R::Done>;

    /// Pulls maps from `stream`, delivers each of their values to the reducer of its key,
    /// and, once `stream` ends, finishes every reducer.
    ///
    /// Returns `Poll::Pending` while the input stream is pending, while a value has not yet
    /// been taken by its reducer, or while any reducer is still finishing. Returns
    /// `Poll::Ready` with one result per key, in the iteration order of the internal map.
    #[inline(always)]
    fn poll_forward(
        self: Pin<&mut Self>, cx: &mut Context,
        mut stream: Pin<&mut impl Stream<Item = IndexMap<T, U>>>,
    ) -> Poll<Self::Done> {
        let self_ = self.project();
        loop {
            // Nothing buffered: fetch the next map, or switch to the finishing phase when
            // the input stream is exhausted.
            if self_.pending.is_none() {
                *self_.pending = Some(
                    ready!(stream.as_mut().poll_next(cx))
                        .map(|item| {
                            item.into_iter()
                                .map(|(k, v)| {
                                    // A reducer is only created for keys not yet present in
                                    // the map; it is inserted after delivery of the value.
                                    let r = if !self_.map.contains_key(&k) {
                                        Some(Box::pin(self_.f.clone().into_async()))
                                    } else {
                                        None
                                    };
                                    (k, (v, r))
                                })
                                .collect()
                        })
                        .map_or_else(
                            || Sum2::B((0..self_.map.len()).map(|_| None).collect()),
                            Sum2::A,
                        ),
                );
            }
            match self_.pending.as_mut().unwrap() {
                Sum2::A(pending) => {
                    // Deliver the buffered entries one at a time, taking them off the end.
                    while let Some((k, (v, mut r))) = pending.pop() {
                        let mut v = Some(v);
                        let waker = cx.waker();
                        // One-shot stream: yields the value once and is `Pending` from then
                        // on (it never signals end-of-stream). If it is polled with a waker
                        // that would not wake the outer task, that waker is woken
                        // immediately.
                        let stream = stream::poll_fn(|cx| {
                            v.take().map_or_else(
                                || {
                                    let waker_ = cx.waker();
                                    if !waker.will_wake(waker_) {
                                        waker_.wake_by_ref();
                                    }
                                    Poll::Pending
                                },
                                |v| Poll::Ready(Some(v)),
                            )
                        })
                        .fuse();
                        pin_mut!(stream);
                        let map = &mut *self_.map;
                        // Use the freshly created reducer if there is one, otherwise the one
                        // already registered for this key.
                        let r_ = r.as_mut().unwrap_or_else(|| map.get_mut(&k).unwrap());
                        if r_.as_mut().poll_forward(cx, stream).is_ready() {
                            // The reducer reported completion; its result is ignored here and
                            // the value is dropped so that the entry counts as delivered.
                            let _ = v.take();
                        }
                        if let Some(v) = v {
                            // The value has not been taken yet: put the entry back and retry
                            // on the next poll.
                            let _ = pending.insert(k, (v, r));
                            return Poll::Pending;
                        }
                        if let Some(r) = r {
                            let _ = self_.map.insert(k, r);
                        }
                    }
                    // Every entry was delivered; go back to pulling from the input stream.
                    *self_.pending = None;
                }
                Sum2::B(done) => {
                    // Finishing phase: poll every unfinished reducer with an empty stream and
                    // record the results as they become available.
                    let mut done_ = true;
                    self_
                        .map
                        .values_mut()
                        .zip(done.iter_mut())
                        .for_each(|(r, done)| {
                            if done.is_none() {
                                let stream = stream::empty();
                                pin_mut!(stream);
                                if let Poll::Ready(done_) = r.as_mut().poll_forward(cx, stream) {
                                    *done = Some(done_);
                                } else {
                                    done_ = false;
                                }
                            }
                        });
                    if !done_ {
                        return Poll::Pending;
                    }
                    // All reducers are finished: pair each key with its result.
                    let ret = mem::take(self_.map)
                        .into_iter()
                        .zip(done.iter_mut())
                        .map(|((k, _), v)| (k, v.take().unwrap()))
                        .collect();
                    return Poll::Ready(ret);
                }
            }
        }
    }
}