1use std::{
5 collections::VecDeque,
6 sync::{Arc, Mutex},
7 task::Waker,
8};
9
10use futures::{stream::BoxStream, Stream, StreamExt};
11use pin_project::pin_project;
12use tokio::sync::Semaphore;
13use tokio_util::sync::PollSemaphore;
14
/// Identifies one of the two halves of a [`SharedStream`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Side {
    /// The first handle returned by [`SharedStream::new`].
    Left,
    /// The second handle returned by [`SharedStream::new`].
    Right,
}
20
/// How many items a pair of `SharedStream` handles may buffer for the half that is lagging
/// behind.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Capacity {
    /// At most this many items may be buffered at once; the half that is ahead waits once the
    /// buffer is full. `Bounded(0)` leaves no room to pull any item at all and is not a usable
    /// configuration.
    Bounded(u32),
    /// No limit on the number of buffered items.
    Unbounded,
}
27
/// State shared, behind a mutex, by the two halves of a `SharedStream`.
struct InnerState<'a, T> {
    /// The upstream stream. It is `None` while one half has taken it out to poll it
    /// without holding the lock; it is put back after that poll.
    inner: Option<BoxStream<'a, T>>,
    /// Clones of items that one half pulled from `inner` and the other half has not
    /// taken yet. Takes pop from the front.
    buffer: VecDeque<T>,
    /// Which half is currently responsible for polling `inner`, if any. It is set before
    /// polling and cleared when `inner` yields an item or ends. It stays set while `inner`
    /// is pending.
    polling: Option<Side>,
    /// Waker of a half that found the other half polling `inner`. It is woken when that
    /// poll yields an item or reports the end of the stream.
    waker: Option<Waker>,
    /// Set once `inner` has returned `None`.
    exhausted: bool,
    /// Number of buffered items the left half has yet to take.
    left_buffered: u32,
    /// Number of buffered items the right half has yet to take.
    right_buffered: u32,
    /// Semaphore whose permits bound how many items may be buffered. It is `None` for
    /// `Capacity::Unbounded`.
    available_buffer: Option<PollSemaphore>,
}
38
39pub struct SharedStream<'a, T: Clone> {
41 state: Arc<Mutex<InnerState<'a, T>>>,
42 side: Side,
43}
44
45impl<'a, T: Clone> SharedStream<'a, T> {
46 pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
47 let available_buffer = match capacity {
48 Capacity::Unbounded => None,
49 Capacity::Bounded(capacity) => Some(PollSemaphore::new(Arc::new(Semaphore::new(
50 capacity as usize,
51 )))),
52 };
53 let state = InnerState {
54 inner: Some(inner),
55 buffer: VecDeque::new(),
56 polling: None,
57 waker: None,
58 exhausted: false,
59 left_buffered: 0,
60 right_buffered: 0,
61 available_buffer,
62 };
63
64 let state = Arc::new(Mutex::new(state));
65
66 let left = Self {
67 state: state.clone(),
68 side: Side::Left,
69 };
70 let right = Self {
71 state,
72 side: Side::Right,
73 };
74 (left, right)
75 }
76}
77
78impl<T: Clone> Stream for SharedStream<'_, T> {
79 type Item = T;
80
81 fn poll_next(
82 self: std::pin::Pin<&mut Self>,
83 cx: &mut std::task::Context<'_>,
84 ) -> std::task::Poll<Option<Self::Item>> {
85 let mut inner_state = self.state.lock().unwrap();
86 let can_take_buffered = match self.side {
87 Side::Left => inner_state.left_buffered > 0,
88 Side::Right => inner_state.right_buffered > 0,
89 };
90 if can_take_buffered {
91 let item = inner_state.buffer.pop_front();
93 match self.side {
94 Side::Left => {
95 inner_state.left_buffered -= 1;
96 }
97 Side::Right => {
98 inner_state.right_buffered -= 1;
99 }
100 }
101 if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
102 available_buffer.add_permits(1);
103 }
104 std::task::Poll::Ready(item)
105 } else {
106 if inner_state.exhausted {
107 return std::task::Poll::Ready(None);
108 }
109 let permit = if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
111 match available_buffer.poll_acquire(cx) {
112 std::task::Poll::Ready(permit) => Some(permit.unwrap()),
115 std::task::Poll::Pending => {
116 return std::task::Poll::Pending;
117 }
118 }
119 } else {
120 None
121 };
122 if let Some(polling_side) = inner_state.polling.as_ref() {
123 if *polling_side != self.side {
124 inner_state.waker = Some(cx.waker().clone());
132 return std::task::Poll::Pending;
133 }
134 }
135 inner_state.polling = Some(self.side);
136 let mut to_poll = inner_state
138 .inner
139 .take()
140 .expect("Other half of shared stream panic'd while polling inner stream");
141 drop(inner_state);
142 let res = to_poll.poll_next_unpin(cx);
143 let mut inner_state = self.state.lock().unwrap();
144
145 let mut should_wake = true;
146 match &res {
147 std::task::Poll::Ready(None) => {
148 inner_state.exhausted = true;
149 inner_state.polling = None;
150 }
151 std::task::Poll::Ready(Some(item)) => {
152 if let Some(permit) = permit {
154 permit.forget();
155 }
156 inner_state.polling = None;
157 match self.side {
159 Side::Left => {
160 inner_state.right_buffered += 1;
161 }
162 Side::Right => {
163 inner_state.left_buffered += 1;
164 }
165 };
166 inner_state.buffer.push_back(item.clone());
167 }
168 std::task::Poll::Pending => {
169 should_wake = false;
170 }
171 };
172
173 inner_state.inner = Some(to_poll);
174
175 let to_wake = if should_wake {
177 inner_state.waker.take()
178 } else {
179 None
182 };
183 drop(inner_state);
184 if let Some(waker) = to_wake {
185 waker.wake();
186 }
187 res
188 }
189 }
190}
191
192pub trait SharedStreamExt<'a>: Stream + Send
193where
194 Self::Item: Clone,
195{
196 fn share(
209 self,
210 capacity: Capacity,
211 ) -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
212}
213
214impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
215 fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
216 SharedStream::new(self, capacity)
217 }
218}
219
/// Stream adapter returned by `FinallyStreamExt::finally`.
///
/// It yields the items of `stream` unchanged and calls `f` the first time `stream` returns
/// `None`. If the adapter is dropped before that happens, `f` is never called.
#[pin_project]
pub struct FinallyStream<S: Stream, F: FnOnce()> {
    /// The wrapped stream (structurally pinned).
    #[pin]
    stream: S,
    /// The callback. It is taken out (leaving `None`) when it runs, so it runs at most once.
    f: Option<F>,
}
226
227impl<S: Stream, F: FnOnce()> FinallyStream<S, F> {
228 pub fn new(stream: S, f: F) -> Self {
229 Self { stream, f: Some(f) }
230 }
231}
232
233impl<S: Stream, F: FnOnce()> Stream for FinallyStream<S, F> {
234 type Item = S::Item;
235
236 fn poll_next(
237 self: std::pin::Pin<&mut Self>,
238 cx: &mut std::task::Context<'_>,
239 ) -> std::task::Poll<Option<Self::Item>> {
240 let this = self.project();
241 let res = this.stream.poll_next(cx);
242 if matches!(res, std::task::Poll::Ready(None)) {
243 if let Some(f) = this.f.take() {
245 f();
246 }
247 }
248 res
249 }
250}
251
252pub trait FinallyStreamExt<S: Stream>: Stream + Sized {
253 fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
254 FinallyStream {
255 stream: self,
256 f: Some(f),
257 }
258 }
259}
260
261impl<S: Stream> FinallyStreamExt<S> for S {
262 fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
263 FinallyStream::new(self, f)
264 }
265}
266
#[cfg(test)]
mod tests {

    use futures::{FutureExt, StreamExt};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::utils::futures::{Capacity, SharedStreamExt};

    /// Polls `fut` once with a no-op waker and returns whether it is still pending.
    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
        let noop_waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&noop_waker);
        fut.poll_unpin(&mut context).is_pending()
    }

    /// Checks several behaviors of a `Bounded(2)` pair:
    /// - in-order delivery to both halves;
    /// - the buffer limit;
    /// - hand-off of an item that arrives while both halves wait;
    /// - end-of-stream on both halves.
    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Pre-load 0, 1, 2 so the first reads complete immediately.
        for i in 0..3 {
            tx.send(i).await.unwrap();
        }

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

        // Left pulls 0 and 1 from the channel and buffers a clone of each for right.
        assert_eq!(left.next().await.unwrap(), 0);
        assert_eq!(left.next().await.unwrap(), 1);

        // Two items are now buffered (the limit), so left cannot pull 2 yet.
        let mut left_fut = left.next();

        assert!(is_pending(&mut left_fut));

        // Right taking a buffered item frees a slot, which lets left proceed.
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(left_fut.await.unwrap(), 2);

        assert_eq!(right.next().await.unwrap(), 1);
        assert_eq!(right.next().await.unwrap(), 2);

        // Both halves are caught up and the channel is empty, so both wait. Right is polled
        // first and becomes the half that polls the channel; left waits on right.
        let mut right_fut = right.next();
        let mut left_fut = left.next();
        assert!(is_pending(&mut right_fut));
        assert!(is_pending(&mut left_fut));

        tx.send(3).await.unwrap();

        // Right receives 3 and buffers a clone of it for left.
        assert_eq!(right_fut.await.unwrap(), 3);
        assert_eq!(left_fut.await.unwrap(), 3);

        // Closing the channel ends the stream for both halves...
        drop(tx);

        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);

        // ...and polling again after the end keeps returning `None`.
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
    }

    /// With no capacity limit, one half can run all the way to the end while every item is
    /// buffered for the other half, which then replays them in order.
    #[tokio::test]
    async fn test_unbounded_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        for i in 0..10 {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Unbounded);

        for i in 0..10 {
            assert_eq!(left.next().await.unwrap(), i);
        }
        assert_eq!(left.next().await, None);

        for i in 0..10 {
            assert_eq!(right.next().await.unwrap(), i);
        }
        assert_eq!(right.next().await, None);
    }

    /// Runs both halves on separate tasks of a multi-threaded runtime while a producer sends
    /// 1000 items. Each half asserts that it sees the items in order, and both tasks must run
    /// to completion. The whole scenario is repeated 100 times.
    #[tokio::test(flavor = "multi_thread")]
    async fn stress_shared_stream() {
        for _ in 0..100 {
            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
            let inner_stream = ReceiverStream::new(rx);
            let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

            let left_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = left.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            let right_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = right.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            for i in 0..1000 {
                tx.send(i).await.unwrap();
            }
            // Closing the channel lets both consumer loops terminate.
            drop(tx);
            left_handle.await.unwrap();
            right_handle.await.unwrap();
        }
    }
}