par_stream/
tee.rs

1use crate::{common::*, config::BufSize, rt, utils};
2use dashmap::DashSet;
3use tokio::sync::Mutex;
4
5/// Stream for the [tee()](crate::par_stream::ParStreamExt::tee) method.
6///
7/// Cloning this stream allocates a new channel for the new receiver, so that
8/// future copies of stream items are forwarded to the channel.
9#[derive(Derivative)]
10#[derivative(Debug)]
11pub struct Tee<T>
12where
13    T: 'static,
14{
15    pub(super) buf_size: Option<usize>,
16    #[derivative(Debug = "ignore")]
17    pub(super) future: Arc<Mutex<Option<rt::JoinHandle<()>>>>,
18    pub(super) sender_set: Weak<DashSet<ByAddress<Arc<flume::Sender<T>>>>>,
19    #[derivative(Debug = "ignore")]
20    pub(super) stream: flume::r#async::RecvStream<'static, T>,
21}
22
23impl<T> Tee<T>
24where
25    T: Send + Clone,
26{
27    pub fn new<B, St>(stream: St, buf_size: B) -> Tee<T>
28    where
29        St: 'static + Send + Stream<Item = T>,
30        B: Into<BufSize>,
31    {
32        let buf_size = buf_size.into().get();
33        let (tx, rx) = utils::channel(buf_size);
34        let sender_set = Arc::new(DashSet::new());
35        sender_set.insert(ByAddress(Arc::new(tx)));
36
37        let future = {
38            let sender_set = sender_set.clone();
39            let mut stream = stream.boxed();
40
41            let future = rt::spawn(async move {
42                while let Some(item) = stream.next().await {
43                    let futures: Vec<_> = sender_set
44                        .iter()
45                        .map(|tx| {
46                            let tx = tx.clone();
47                            let item = item.clone();
48                            async move {
49                                let result = tx.send_async(item).await;
50                                (result, tx)
51                            }
52                        })
53                        .collect();
54
55                    let results = future::join_all(futures).await;
56                    let success_count = results
57                        .iter()
58                        .filter(|(result, tx)| {
59                            let ok = result.is_ok();
60                            if !ok {
61                                sender_set.remove(tx);
62                            }
63                            ok
64                        })
65                        .count();
66
67                    if success_count == 0 {
68                        break;
69                    }
70                }
71            });
72
73            Arc::new(Mutex::new(Some(future)))
74        };
75
76        Tee {
77            future,
78            sender_set: Arc::downgrade(&sender_set),
79            stream: rx.into_stream(),
80            buf_size,
81        }
82    }
83}
84
85impl<T> Clone for Tee<T>
86where
87    T: 'static + Send,
88{
89    fn clone(&self) -> Self {
90        let buf_size = self.buf_size;
91        let (tx, rx) = utils::channel(buf_size);
92        let sender_set = self.sender_set.clone();
93
94        if let Some(sender_set) = sender_set.upgrade() {
95            sender_set.insert(ByAddress(Arc::new(tx)));
96        }
97
98        Self {
99            future: self.future.clone(),
100            sender_set,
101            stream: rx.into_stream(),
102            buf_size,
103        }
104    }
105}
106
107impl<T> Stream for Tee<T> {
108    type Item = T;
109
110    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111        if let Ok(mut future_opt) = self.future.try_lock() {
112            if let Some(future) = &mut *future_opt {
113                if Pin::new(future).poll(cx).is_ready() {
114                    *future_opt = None;
115                }
116            }
117        }
118
119        match Pin::new(&mut self.stream).poll_next(cx) {
120            Ready(Some(output)) => {
121                cx.waker().clone().wake();
122                Ready(Some(output))
123            }
124            Ready(None) => Ready(None),
125            Pending => {
126                cx.waker().clone().wake();
127                Pending
128            }
129        }
130    }
131}