1use crate::{common::*, config::BufSize, rt, utils};
2use dashmap::DashSet;
3use tokio::sync::Mutex;
4
5#[derive(Derivative)]
10#[derivative(Debug)]
11pub struct Tee<T>
12where
13 T: 'static,
14{
15 pub(super) buf_size: Option<usize>,
16 #[derivative(Debug = "ignore")]
17 pub(super) future: Arc<Mutex<Option<rt::JoinHandle<()>>>>,
18 pub(super) sender_set: Weak<DashSet<ByAddress<Arc<flume::Sender<T>>>>>,
19 #[derivative(Debug = "ignore")]
20 pub(super) stream: flume::r#async::RecvStream<'static, T>,
21}
22
23impl<T> Tee<T>
24where
25 T: Send + Clone,
26{
27 pub fn new<B, St>(stream: St, buf_size: B) -> Tee<T>
28 where
29 St: 'static + Send + Stream<Item = T>,
30 B: Into<BufSize>,
31 {
32 let buf_size = buf_size.into().get();
33 let (tx, rx) = utils::channel(buf_size);
34 let sender_set = Arc::new(DashSet::new());
35 sender_set.insert(ByAddress(Arc::new(tx)));
36
37 let future = {
38 let sender_set = sender_set.clone();
39 let mut stream = stream.boxed();
40
41 let future = rt::spawn(async move {
42 while let Some(item) = stream.next().await {
43 let futures: Vec<_> = sender_set
44 .iter()
45 .map(|tx| {
46 let tx = tx.clone();
47 let item = item.clone();
48 async move {
49 let result = tx.send_async(item).await;
50 (result, tx)
51 }
52 })
53 .collect();
54
55 let results = future::join_all(futures).await;
56 let success_count = results
57 .iter()
58 .filter(|(result, tx)| {
59 let ok = result.is_ok();
60 if !ok {
61 sender_set.remove(tx);
62 }
63 ok
64 })
65 .count();
66
67 if success_count == 0 {
68 break;
69 }
70 }
71 });
72
73 Arc::new(Mutex::new(Some(future)))
74 };
75
76 Tee {
77 future,
78 sender_set: Arc::downgrade(&sender_set),
79 stream: rx.into_stream(),
80 buf_size,
81 }
82 }
83}
84
85impl<T> Clone for Tee<T>
86where
87 T: 'static + Send,
88{
89 fn clone(&self) -> Self {
90 let buf_size = self.buf_size;
91 let (tx, rx) = utils::channel(buf_size);
92 let sender_set = self.sender_set.clone();
93
94 if let Some(sender_set) = sender_set.upgrade() {
95 sender_set.insert(ByAddress(Arc::new(tx)));
96 }
97
98 Self {
99 future: self.future.clone(),
100 sender_set,
101 stream: rx.into_stream(),
102 buf_size,
103 }
104 }
105}
106
107impl<T> Stream for Tee<T> {
108 type Item = T;
109
110 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
111 if let Ok(mut future_opt) = self.future.try_lock() {
112 if let Some(future) = &mut *future_opt {
113 if Pin::new(future).poll(cx).is_ready() {
114 *future_opt = None;
115 }
116 }
117 }
118
119 match Pin::new(&mut self.stream).poll_next(cx) {
120 Ready(Some(output)) => {
121 cx.waker().clone().wake();
122 Ready(Some(output))
123 }
124 Ready(None) => Ready(None),
125 Pending => {
126 cx.waker().clone().wake();
127 Pending
128 }
129 }
130 }
131}