marigold_impl/
multi_consumer_stream.rs

1use core::marker::PhantomData;
2use core::pin::Pin;
3use futures::channel::mpsc::Receiver;
4use futures::channel::mpsc::Sender;
5use futures::future::Future;
6use futures::sink::SinkExt;
7use futures::stream::FuturesUnordered;
8use futures::stream::Stream;
9use futures::stream::StreamExt;
10use futures::task::Context;
11use futures::task::Poll;
12
13const BUFFER_SIZE: usize = 1;
14
15pub struct MultiConsumerStream<
16    T: std::marker::Send + 'static,
17    S: Stream<Item = T> + std::marker::Unpin + std::marker::Send + 'static,
18> {
19    inner_stream: S,
20    senders: Vec<Sender<T>>,
21}
22
23impl<
24        T: std::marker::Send + Copy + 'static,
25        S: Stream<Item = T> + std::marker::Unpin + std::marker::Send + 'static,
26    > MultiConsumerStream<T, S>
27{
28    pub fn new(s: S) -> Self {
29        MultiConsumerStream {
30            inner_stream: s,
31            senders: Vec::new(),
32        }
33    }
34
35    pub fn get(&mut self) -> Receiver<T> {
36        let (sender, receiver) = futures::channel::mpsc::channel(BUFFER_SIZE);
37        self.senders.push(sender);
38        receiver
39    }
40
41    pub async fn run(mut self) {
42        self.senders.shrink_to_fit();
43
44        #[cfg(any(feature = "async-std", feature = "tokio"))]
45        crate::async_runtime::spawn(async move {
46            while let Some(v) = self.inner_stream.next().await {
47                let mut futures = self
48                    .senders
49                    .iter_mut()
50                    .map(|sender| sender.feed(v))
51                    .collect::<FuturesUnordered<_>>();
52                while let Some(_result) = futures.next().await {}
53            }
54            self.senders.iter_mut().for_each(|s| s.disconnect());
55        });
56
57        #[cfg(not(any(feature = "async-std", feature = "tokio")))]
58        {
59            while let Some(v) = self.inner_stream.next().await {
60                let mut futures = self
61                    .senders
62                    .iter_mut()
63                    .map(|sender| sender.feed(v))
64                    .collect::<FuturesUnordered<_>>();
65                while let Some(_result) = futures.next().await {}
66            }
67            self.senders.iter_mut().for_each(|s| s.disconnect());
68        }
69    }
70}
71
72pub struct RunFutureAsStream<T: Unpin, O, F: Future<Output = O>> {
73    future: Pin<Box<F>>,
74    t: PhantomData<T>,
75}
76
77impl<T: Unpin, O, F: Future<Output = O>> RunFutureAsStream<T, O, F> {
78    pub fn new(f: Pin<Box<F>>) -> RunFutureAsStream<T, O, F> {
79        RunFutureAsStream {
80            future: f,
81            t: PhantomData,
82        }
83    }
84}
85
86impl<T: std::marker::Send + Unpin + 'static, O, F: Future<Output = O>> Stream
87    for RunFutureAsStream<T, O, F>
88{
89    type Item = T;
90
91    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        let future = &mut self.future;
93        match Pin::new(future).poll(cx) {
94            Poll::Pending => Poll::Pending,
95            Poll::Ready(_) => Poll::Ready(None),
96        }
97    }
98
99    fn size_hint(&self) -> (usize, Option<usize>) {
100        (0, None)
101    }
102}