1#![forbid(unsafe_code)]
2#![warn(missing_docs)]
3
4use futures_core::{future::*, stream::*, task::Poll};
27use futures_sink::Sink;
28use futures_util::sink::SinkExt;
29use futures_util::stream::StreamExt;
30use futures_util::future::try_join_all;
31use slab::Slab;
32use std::fmt::{self, Debug};
33use std::sync::Arc;
34
35#[cfg(not(feature = "default-channels"))]
36use std::sync::RwLock;
37
38#[cfg(feature = "default-channels")]
39use parking_lot::RwLock;
40
41#[cfg(feature = "default-channels")]
42use futures_channel::mpsc::*;
43use futures_util::task::Context;
44use std::pin::Pin;
45
46pub struct BroadcastChannel<
49 T,
50 #[cfg(feature = "default-channels")] S = UnboundedSender<T>,
51 #[cfg(feature = "default-channels")] R = UnboundedReceiver<T>,
52 #[cfg(not(feature = "default-channels"))] S,
53 #[cfg(not(feature = "default-channels"))] R,
54> where
55 T: Send + Clone + 'static,
56 S: Send + Sync + Unpin + Clone + Sink<T>,
57 R: Unpin + Stream<Item = T>,
58{
59 senders: Arc<RwLock<Slab<S>>>,
60 sender_key: usize,
61 receiver: R,
62 ctor: Arc<dyn Fn() -> (S, R) + Send + Sync>,
63}
64
#[cfg(feature = "default-channels")]
impl<T: Send + Clone + 'static> BroadcastChannel<T> {
    /// Creates a broadcast channel backed by unbounded channels.
    ///
    /// Every handle (this one and each later clone) gets its own
    /// `(UnboundedSender<T>, UnboundedReceiver<T>)` pair produced by
    /// `unbounded`.
    pub fn new() -> Self {
        // Delegate to the generic constructor so the registry setup lives in
        // exactly one place.
        Self::with_ctor(Arc::new(unbounded::<T>))
    }
}

#[cfg(feature = "default-channels")]
impl<T: Send + Clone + 'static> Default for BroadcastChannel<T> {
    /// Equivalent to [`BroadcastChannel::new`].
    fn default() -> Self {
        Self::new()
    }
}
80
#[cfg(feature = "default-channels")]
impl<T: Send + Clone + 'static> BroadcastChannel<T, Sender<T>, Receiver<T>> {
    /// Creates a broadcast channel backed by bounded channels.
    ///
    /// Every handle (this one and each later clone) gets its own
    /// `(Sender<T>, Receiver<T>)` pair produced by `channel(cap)`.
    pub fn with_cap(cap: usize) -> Self {
        // Delegate to the generic constructor so the registry setup lives in
        // exactly one place.
        Self::with_ctor(Arc::new(move || channel::<T>(cap)))
    }

    /// Tries to send a clone of `item` to every registered sender without
    /// waiting.
    ///
    /// The senders used are a snapshot (clone) of the registry taken at the
    /// start of the call.
    ///
    /// # Errors
    ///
    /// Returns the first `TrySendError` reported by a sender. Senders that
    /// come after the failing one in the snapshot are not attempted.
    pub fn try_send(&self, item: &T) -> Result<(), TrySendError<T>> {
        // `senders()` already encapsulates the feature-gated lock handling;
        // reuse it instead of repeating the lock-and-clone dance here.
        self.senders()
            .iter_mut()
            .map(|(_, sender)| sender.try_send(item.clone()))
            .collect()
    }
}
110
111impl<T, S, R> BroadcastChannel<T, S, R>
112where
113 T: Send + Clone + 'static,
114 S: Send + Sync + Unpin + Clone + Sink<T>,
115 R: Unpin + Stream<Item = T>,
116{
117 pub fn with_ctor(ctor: Arc<dyn Fn() -> (S, R) + Send + Sync>) -> Self {
120 let (tx, rx) = ctor();
121 let mut slab = Slab::new();
122 let sender_key = slab.insert(tx);
123 Self {
124 senders: Arc::new(RwLock::new(slab)),
125 sender_key,
126 receiver: rx,
127 ctor,
128 }
129 }
130
131 pub async fn send(&self, item: &T) -> Result<(), S::Error> {
135 let mut senders = self.senders();
136 try_join_all(senders.iter_mut().map(|(_, s)| s.send(item.clone()))).await?;
137 Ok(())
138 }
139
140 pub fn recv(&mut self) -> impl Future<Output = Option<T>> + '_ {
142 self.next()
143 }
144
145 fn senders(&self) -> Slab<S> {
147 #[cfg(feature = "parking-lot")]
149 let senders: Slab<S> = Slab::clone(&*self.senders.read());
150
151 #[cfg(not(feature = "parking-lot"))]
152 let senders: Slab<S> = Slab::clone(&*self.senders.read().unwrap());
153
154 senders
155 }
156}
157
158impl<T, S, R> Clone for BroadcastChannel<T, S, R>
159where
160 T: Send + Clone + 'static,
161 S: Send + Sync + Unpin + Clone + Sink<T>,
162 R: Unpin + Stream<Item = T>,
163{
164 fn clone(&self) -> Self {
165 let (tx, rx) = (self.ctor)();
166 #[cfg(feature = "parking-lot")]
167 let sender_key = self.senders.write().insert(tx);
168
169 #[cfg(not(feature = "parking-lot"))]
170 let sender_key = self.senders.write().unwrap().insert(tx);
171
172 Self {
173 senders: self.senders.clone(),
174 sender_key,
175 receiver: rx,
176 ctor: self.ctor.clone(),
177 }
178 }
179}
180
181impl<T, S, R> Drop for BroadcastChannel<T, S, R>
182where
183 T: Send + Clone + 'static,
184 S: Send + Sync + Unpin + Clone + Sink<T>,
185 R: Unpin + Stream<Item = T>,
186{
187 fn drop(&mut self) {
188 #[cfg(feature = "parking-lot")]
189 self.senders.write().remove(self.sender_key);
190
191 #[cfg(not(feature = "parking-lot"))]
192 self.senders.write().unwrap().remove(self.sender_key);
193 }
194}
195
196impl<T, S, R> Debug for BroadcastChannel<T, S, R>
197where
198 T: Send + Clone + 'static,
199 S: Send + Sync + Unpin + Clone + Debug + Sink<T>,
200 R: Unpin + Debug + Stream<Item = T>,
201{
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 f.debug_struct("BroadcastChannel")
204 .field("senders", &self.senders)
205 .field("sender_key", &self.sender_key)
206 .field("receiver", &self.receiver)
207 .finish()
208 }
209}
210
211impl<T, S, R> Stream for BroadcastChannel<T, S, R>
212where
213 T: Send + Clone + 'static,
214 S: Send + Sync + Unpin + Clone + Sink<T>,
215 R: Unpin + Stream<Item = T>,
216{
217 type Item = T;
218
219 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
220 (&mut self.receiver).poll_next_unpin(cx)
221 }
222}
223
224impl<T, S, R> Sink<T> for &BroadcastChannel<T, S, R>
225where
226 T: Send + Clone + 'static,
227 S: Send + Sync + Unpin + Clone + Sink<T>,
228 R: Unpin + Stream<Item = T>,
229{
230 type Error = S::Error;
231
232 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
233 (*self)
234 .senders()
235 .iter_mut()
236 .map(|(_, sender)| Pin::new(sender).poll_ready(cx))
237 .find_map(|poll| match poll {
238 Poll::Ready(Err(_)) | Poll::Pending => Some(poll),
239 _ => None,
240 })
241 .or_else(|| Some(Poll::Ready(Ok(()))))
242 .unwrap()
243 }
244
245 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
246 (*self)
247 .senders()
248 .iter_mut()
249 .map(|(_, sender)| Pin::new(sender).start_send(item.clone()))
250 .collect::<Result<_, _>>()
251 }
252
253 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
254 (*self)
255 .senders()
256 .iter_mut()
257 .map(|(_, sender)| Pin::new(sender).poll_flush(cx))
258 .find_map(|poll| match poll {
259 Poll::Ready(Err(_)) | Poll::Pending => Some(poll),
260 _ => None,
261 })
262 .or_else(|| Some(Poll::Ready(Ok(()))))
263 .unwrap()
264 }
265
266 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
267 (*self)
268 .senders()
269 .iter_mut()
270 .map(|(_, sender)| Pin::new(sender).poll_close(cx))
271 .find_map(|poll| match poll {
272 Poll::Ready(Err(_)) | Poll::Pending => Some(poll),
273 _ => None,
274 })
275 .or_else(|| Some(Poll::Ready(Ok(()))))
276 .unwrap()
277 }
278}
279
280impl<T, S, R> Sink<T> for BroadcastChannel<T, S, R>
281 where
282 T: Send + Clone + 'static,
283 S: Send + Sync + Unpin + Clone + Sink<T>,
284 R: Unpin + Stream<Item = T>,
285{
286 type Error = S::Error;
287
288 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
289 Sink::poll_ready(Pin::new(&mut &*self), cx)
290 }
291
292 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
293 Sink::start_send(Pin::new(&mut &*self), item)
294 }
295
296 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
297 Sink::poll_flush(Pin::new(&mut &*self), cx)
298 }
299
300 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
301 Sink::poll_close(Pin::new(&mut &*self), cx)
302 }
303}
304
#[cfg(all(feature = "default-channels", test))]
mod test {
    use super::BroadcastChannel;
    use futures_channel::mpsc::SendError;
    use futures_core::future::Future;
    use futures_executor::block_on;
    use futures_util::future::{ready, FutureExt};
    use futures_util::{SinkExt, StreamExt};

    // Compile-time helpers: they only type-check when the bound holds.
    fn assert_impl_send<T: Send>() {}
    fn assert_impl_sync<T: Sync>() {}
    fn assert_val_impl_send<T: Send>(_val: &T) {}
    fn assert_val_impl_sync<T: Sync>(_val: &T) {}

    /// An item sent through a handle is delivered back to that same handle.
    #[test]
    fn send_next() {
        let mut channel = BroadcastChannel::new();

        let sent = block_on(channel.send(&5));
        sent.unwrap();

        let received = block_on(channel.next());
        assert_eq!(received, Some(5));
    }

    /// The channel can be split into sink and stream halves, and the stream
    /// half sees items sent through the sink half as well as through a clone.
    #[test]
    fn split() {
        fn increment(num: usize) -> impl Future<Output = Result<usize, SendError>> {
            ready(Ok(num + 1))
        }

        let original = BroadcastChannel::new();
        let sibling = original.clone();

        let (sink_half, stream_half) = original.split();
        let mut incrementing_sink = sink_half.with(increment);

        block_on(incrementing_sink.send(5)).unwrap();
        block_on(sibling.send(&10)).unwrap();

        let first_two = block_on(stream_half.take(2).collect::<Vec<_>>());
        assert_eq!(first_two, vec![6, 10]);
    }

    /// A send/receive round trip over two handles finishes on the first poll.
    #[test]
    fn now_or_never() {
        let round_trip = async {
            let mut first = BroadcastChannel::new();
            first.send(&5i32).await?;
            assert_eq!(first.next().await, Some(5));

            let mut second = first.clone();
            second.send(&6i32).await?;
            assert_eq!(first.next().await, Some(6));
            assert_eq!(second.next().await, Some(6));

            Ok::<(), SendError>(())
        };

        let outcome = round_trip.now_or_never();
        outcome.unwrap().unwrap();
    }

    /// `try_send` on a bounded channel delivers to every handle.
    #[test]
    fn try_send() {
        let round_trip = async {
            let mut first = BroadcastChannel::with_cap(2);
            first.try_send(&5i32)?;
            assert_eq!(first.next().await, Some(5));

            let mut second = first.clone();
            second.try_send(&6i32)?;
            assert_eq!(first.next().await, Some(6));
            assert_eq!(second.next().await, Some(6));

            Ok::<(), futures_channel::mpsc::TrySendError<i32>>(())
        };

        let outcome = round_trip.now_or_never();
        outcome.unwrap().unwrap();
    }

    /// Same round trip as `now_or_never`, driven by `block_on`; also checks
    /// that the resulting future is `Send` and `Sync`.
    #[test]
    fn recv_two() {
        let round_trip = async {
            let mut first = BroadcastChannel::new();
            first.send(&5i32).await?;
            assert_eq!(first.next().await, Some(5));

            let mut second = first.clone();
            second.send(&6i32).await?;
            assert_eq!(first.next().await, Some(6));
            assert_eq!(second.next().await, Some(6));

            Ok::<(), SendError>(())
        };

        assert_val_impl_send(&round_trip);
        assert_val_impl_sync(&round_trip);
        block_on(round_trip).unwrap();
    }

    /// The default channel type is both `Send` and `Sync`.
    #[test]
    fn send_sync() {
        assert_impl_send::<BroadcastChannel<i32>>();
        assert_impl_sync::<BroadcastChannel<i32>>();
    }
}