// amadeus_core/par_stream/chain.rs

1use derive_new::new;
2use futures::Stream;
3use pin_project::pin_project;
4use serde::{Deserialize, Serialize};
5use std::{
6	pin::Pin, task::{Context, Poll}
7};
8
9use super::{ParallelStream, StreamTask};
10
/// A parallel stream that hands out every task of `a`, followed by every task of `b`.
///
/// The "both halves yield the same item type" requirement lives on the
/// `ParallelStream` impl, not on this struct. Build one with the `new(a, b)`
/// constructor generated by `#[derive(new)]`. `#[must_use]` makes the compiler
/// warn if a `Chain` is constructed and then discarded.
#[pin_project]
#[derive(new)]
#[must_use]
pub struct Chain<A, B> {
	/// First stream: its tasks are handed out before any of `b`'s.
	#[pin]
	a: A,
	/// Second stream: asked for tasks only after `a` reports `Ready(None)`.
	#[pin]
	b: B,
}
20
21impl_par_dist! {
22	impl<A: ParallelStream, B: ParallelStream<Item = A::Item>> ParallelStream for Chain<A, B> {
23		type Item = A::Item;
24		type Task = ChainTask<A::Task, B::Task>;
25
26		fn size_hint(&self) -> (usize, Option<usize>) {
27			let (a_lower, a_upper) = self.a.size_hint();
28			let (b_lower, b_upper) = self.b.size_hint();
29			(
30				a_lower + b_lower,
31				if let (Some(a), Some(b)) = (a_upper, b_upper) {
32					Some(a + b)
33				} else {
34					None
35				},
36			)
37		}
38		fn next_task(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Task>> {
39			let self_ = self.project();
40			match self_.a.next_task(cx) {
41				Poll::Ready(Some(a)) => Poll::Ready(Some(ChainTask::A(a))),
42				Poll::Ready(None) => self_.b.next_task(cx).map(|task| task.map(ChainTask::B)),
43				Poll::Pending => Poll::Pending,
44			}
45		}
46	}
47}
48
/// A task handed out by `Chain`: either a task of the first stream (`A`) or a
/// task of the second stream (`B`).
///
/// The variant records which half of the chain the task came from. `Serialize`
/// and `Deserialize` are derived, so that tag is preserved when the task is
/// serialized. `pin_project` generates the `ChainTaskProj` projection, which
/// gives pinned access to the wrapped task.
#[pin_project(project = ChainTaskProj)]
#[derive(Serialize, Deserialize)]
pub enum ChainTask<A, B> {
	/// A task taken from the first stream of the chain.
	A(#[pin] A),
	/// A task taken from the second stream of the chain.
	B(#[pin] B),
}
55impl<A: StreamTask, B: StreamTask<Item = A::Item>> StreamTask for ChainTask<A, B> {
56	type Item = A::Item;
57	type Async = ChainTask<A::Async, B::Async>;
58
59	fn into_async(self) -> Self::Async {
60		match self {
61			ChainTask::A(a) => ChainTask::A(a.into_async()),
62			ChainTask::B(b) => ChainTask::B(b.into_async()),
63		}
64	}
65}
66impl<A: Stream, B: Stream<Item = A::Item>> Stream for ChainTask<A, B> {
67	type Item = A::Item;
68
69	fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
70		match self.project() {
71			ChainTaskProj::A(a) => a.poll_next(cx),
72			ChainTaskProj::B(b) => b.poll_next(cx),
73		}
74	}
75}