futures_rx/stream_ext/
share.rs1use core::pin::Pin;
2use core::task::Context;
3use core::task::Poll;
4use futures::stream::Fuse;
5use futures::stream::FusedStream;
6use futures::Stream;
7use futures::StreamExt;
8use pin_project_lite::pin_project;
9use std::cell::RefCell;
10use std::rc::Rc;
11
12use crate::subject::shareable_subject::ShareableSubject;
13use crate::subject::Subject;
14use crate::Event;
15use crate::Observable;
16
17pin_project! {
18 #[must_use = "streams do nothing unless polled"]
20 pub struct Shared<S: Stream, Sub: Subject<Item = S::Item>> {
21 inner: Rc<RefCell<ShareableSubject<S, Sub>>>,
22 #[pin]
23 stream: Fuse<Observable<S::Item>>,
24 }
25}
26
27impl<S: Stream, Sub: Subject<Item = S::Item>> Shared<S, Sub> {
28 pub(crate) fn new(stream: S, subject: Sub) -> Self {
29 let mut subject = ShareableSubject::new(stream, subject);
30 let stream = subject.subscribe().fuse();
31
32 Self {
33 inner: Rc::new(RefCell::new(subject)),
34 stream,
35 }
36 }
37}
38
39impl<S: Stream, Sub: Subject<Item = S::Item>> Clone for Shared<S, Sub> {
40 fn clone(&self) -> Self {
41 let stream = self.inner.borrow_mut().subscribe().fuse();
42
43 Self {
44 inner: Rc::clone(&self.inner),
45 stream,
46 }
47 }
48}
49
50impl<S: Stream, Sub: Subject<Item = S::Item>> Stream for Shared<S, Sub> {
51 type Item = Event<S::Item>;
52
53 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
54 self.inner.borrow_mut().poll_next(cx);
55 self.stream.poll_next_unpin(cx)
56 }
57
58 fn size_hint(&self) -> (usize, Option<usize>) {
59 self.stream.size_hint()
60 }
61}
62
63impl<S: Stream, Sub: Subject<Item = S::Item>> FusedStream for Shared<S, Sub> {
64 fn is_terminated(&self) -> bool {
65 self.stream.is_terminated()
66 }
67}
68
#[cfg(test)]
mod test {
    use futures::{executor::block_on, future::join, stream, StreamExt};

    use crate::RxExt;

    /// Two handles onto one shared source both receive every item.
    #[test]
    fn smoke() {
        let first = stream::iter(1usize..=3usize).share();
        let second = first.clone();

        let (from_first, from_second) = block_on(join(
            first.collect::<Vec<_>>(),
            second.collect::<Vec<_>>(),
        ));

        assert_eq!(from_first, [1.into(), 2.into(), 3.into()]);
        assert_eq!(from_second, [1.into(), 2.into(), 3.into()]);
    }
}