1use std::{
5 collections::VecDeque,
6 sync::{Arc, Mutex},
7 task::Waker,
8};
9
10use futures::{Stream, StreamExt, stream::BoxStream};
11use pin_project::{pin_project, pinned_drop};
12use tokio::sync::Semaphore;
13use tokio_util::sync::PollSemaphore;
14
/// Which of the two handles produced by splitting a stream this is.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Side {
    /// The first handle of the pair.
    Left,
    /// The second handle of the pair.
    Right,
}
20
/// How many items a `SharedStream` pair may buffer for its slower half.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Capacity {
    /// At most this many items are buffered. While the buffer is full, the faster half
    /// returns `Pending` until the slower half catches up.
    ///
    /// `Bounded(0)` reserves no slot at all, so as long as both halves are alive neither
    /// of them can pull an item. Use at least `1`.
    Bounded(u32),
    /// The buffer grows without limit, so the faster half never waits for the slower one.
    Unbounded,
}
27
28struct InnerState<'a, T> {
29 inner: Option<BoxStream<'a, T>>,
30 buffer: VecDeque<T>,
31 polling: Option<Side>,
32 waker: Option<Waker>,
33 exhausted: bool,
34 left_buffered: u32,
35 right_buffered: u32,
36 available_buffer: Option<PollSemaphore>,
37}
38
39pub struct SharedStream<'a, T: Clone> {
41 state: Arc<Mutex<InnerState<'a, T>>>,
42 side: Side,
43}
44
45impl<'a, T: Clone> SharedStream<'a, T> {
46 pub fn new(inner: BoxStream<'a, T>, capacity: Capacity) -> (Self, Self) {
47 let available_buffer = match capacity {
48 Capacity::Unbounded => None,
49 Capacity::Bounded(capacity) => Some(PollSemaphore::new(Arc::new(Semaphore::new(
50 capacity as usize,
51 )))),
52 };
53 let state = InnerState {
54 inner: Some(inner),
55 buffer: VecDeque::new(),
56 polling: None,
57 waker: None,
58 exhausted: false,
59 left_buffered: 0,
60 right_buffered: 0,
61 available_buffer,
62 };
63
64 let state = Arc::new(Mutex::new(state));
65
66 let left = Self {
67 state: state.clone(),
68 side: Side::Left,
69 };
70 let right = Self {
71 state,
72 side: Side::Right,
73 };
74 (left, right)
75 }
76}
77
78impl<T: Clone> Stream for SharedStream<'_, T> {
79 type Item = T;
80
81 fn poll_next(
82 self: std::pin::Pin<&mut Self>,
83 cx: &mut std::task::Context<'_>,
84 ) -> std::task::Poll<Option<Self::Item>> {
85 let mut inner_state = self.state.lock().unwrap();
86 let can_take_buffered = match self.side {
87 Side::Left => inner_state.left_buffered > 0,
88 Side::Right => inner_state.right_buffered > 0,
89 };
90 if can_take_buffered {
91 let item = inner_state.buffer.pop_front();
93 match self.side {
94 Side::Left => {
95 inner_state.left_buffered -= 1;
96 }
97 Side::Right => {
98 inner_state.right_buffered -= 1;
99 }
100 }
101 if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
102 available_buffer.add_permits(1);
103 }
104 std::task::Poll::Ready(item)
105 } else {
106 if inner_state.exhausted {
107 return std::task::Poll::Ready(None);
108 }
109 let permit = if let Some(available_buffer) = inner_state.available_buffer.as_mut() {
111 match available_buffer.poll_acquire(cx) {
112 std::task::Poll::Ready(permit) => Some(permit.unwrap()),
115 std::task::Poll::Pending => {
116 return std::task::Poll::Pending;
117 }
118 }
119 } else {
120 None
121 };
122 if let Some(polling_side) = inner_state.polling.as_ref()
123 && *polling_side != self.side
124 {
125 inner_state.waker = Some(cx.waker().clone());
133 return std::task::Poll::Pending;
134 }
135 inner_state.polling = Some(self.side);
136 let mut to_poll = inner_state
138 .inner
139 .take()
140 .expect("Other half of shared stream panic'd while polling inner stream");
141 drop(inner_state);
142 let res = to_poll.poll_next_unpin(cx);
143 let mut inner_state = self.state.lock().unwrap();
144
145 let mut should_wake = true;
146 match &res {
147 std::task::Poll::Ready(None) => {
148 inner_state.exhausted = true;
149 inner_state.polling = None;
150 }
151 std::task::Poll::Ready(Some(item)) => {
152 if let Some(permit) = permit {
154 permit.forget();
155 }
156 inner_state.polling = None;
157 match self.side {
159 Side::Left => {
160 inner_state.right_buffered += 1;
161 }
162 Side::Right => {
163 inner_state.left_buffered += 1;
164 }
165 };
166 inner_state.buffer.push_back(item.clone());
167 }
168 std::task::Poll::Pending => {
169 should_wake = false;
170 }
171 };
172
173 inner_state.inner = Some(to_poll);
174
175 let to_wake = if should_wake {
177 inner_state.waker.take()
178 } else {
179 None
182 };
183 drop(inner_state);
184 if let Some(waker) = to_wake {
185 waker.wake();
186 }
187 res
188 }
189 }
190}
191
192pub trait SharedStreamExt<'a>: Stream + Send
193where
194 Self::Item: Clone,
195{
196 fn share(
209 self,
210 capacity: Capacity,
211 ) -> (SharedStream<'a, Self::Item>, SharedStream<'a, Self::Item>);
212}
213
214impl<'a, T: Clone> SharedStreamExt<'a> for BoxStream<'a, T> {
215 fn share(self, capacity: Capacity) -> (SharedStream<'a, T>, SharedStream<'a, T>) {
216 SharedStream::new(self, capacity)
217 }
218}
219
220#[pin_project]
221pub struct FinallyStream<S: Stream, F: FnOnce()> {
222 #[pin]
223 stream: S,
224 f: Option<F>,
225}
226
227impl<S: Stream, F: FnOnce()> FinallyStream<S, F> {
228 pub fn new(stream: S, f: F) -> Self {
229 Self { stream, f: Some(f) }
230 }
231}
232
233impl<S: Stream, F: FnOnce()> Stream for FinallyStream<S, F> {
234 type Item = S::Item;
235
236 fn poll_next(
237 self: std::pin::Pin<&mut Self>,
238 cx: &mut std::task::Context<'_>,
239 ) -> std::task::Poll<Option<Self::Item>> {
240 let this = self.project();
241 let res = this.stream.poll_next(cx);
242 if matches!(res, std::task::Poll::Ready(None)) {
243 if let Some(f) = this.f.take() {
245 f();
246 }
247 }
248 res
249 }
250}
251
252pub trait FinallyStreamExt<S: Stream>: Stream + Sized {
253 fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
254 FinallyStream {
255 stream: self,
256 f: Some(f),
257 }
258 }
259}
260
261impl<S: Stream> FinallyStreamExt<S> for S {
262 fn finally<F: FnOnce()>(self, f: F) -> FinallyStream<Self, F> {
263 FinallyStream::new(self, f)
264 }
265}
266
267#[pin_project(PinnedDrop)]
273pub struct OnDropStream<S: Stream, F: FnOnce()> {
274 #[pin]
275 stream: S,
276 f: Option<F>,
277}
278
279impl<S: Stream, F: FnOnce()> OnDropStream<S, F> {
280 pub fn new(stream: S, f: F) -> Self {
281 Self { stream, f: Some(f) }
282 }
283}
284
285impl<S: Stream, F: FnOnce()> Stream for OnDropStream<S, F> {
286 type Item = S::Item;
287
288 fn poll_next(
289 self: std::pin::Pin<&mut Self>,
290 cx: &mut std::task::Context<'_>,
291 ) -> std::task::Poll<Option<Self::Item>> {
292 self.project().stream.poll_next(cx)
293 }
294}
295
296#[pinned_drop]
297impl<S: Stream, F: FnOnce()> PinnedDrop for OnDropStream<S, F> {
298 fn drop(self: std::pin::Pin<&mut Self>) {
299 let this = self.project();
300 if let Some(f) = this.f.take() {
301 f();
302 }
303 }
304}
305
306pub trait StreamOnDropExt: Stream + Sized {
307 fn on_drop<F: FnOnce()>(self, f: F) -> OnDropStream<Self, F> {
309 OnDropStream::new(self, f)
310 }
311}
312
313impl<S: Stream> StreamOnDropExt for S {}
314
#[cfg(test)]
mod tests {

    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};

    use futures::{FutureExt, StreamExt};
    use tokio_stream::wrappers::ReceiverStream;

    use crate::utils::futures::{Capacity, SharedStreamExt, StreamOnDropExt};

    /// Polls `fut` once with a no-op waker and reports whether it returned `Pending`.
    fn is_pending(fut: &mut (impl std::future::Future + Unpin)) -> bool {
        let noop_waker = futures::task::noop_waker();
        let mut context = std::task::Context::from_waker(&noop_waker);
        fut.poll_unpin(&mut context).is_pending()
    }

    /// Walks a bounded (capacity 2) pair through buffering, back-pressure, a shared
    /// `Pending` inner stream, and termination.
    #[tokio::test]
    async fn test_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Pre-load three items into the channel.
        for i in 0..3 {
            tx.send(i).await.unwrap();
        }

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

        // Left runs ahead; both items are buffered for right.
        assert_eq!(left.next().await.unwrap(), 0);
        assert_eq!(left.next().await.unwrap(), 1);

        let mut left_fut = left.next();

        // The buffer is full (2 items owed to right), so left has to wait.
        assert!(is_pending(&mut left_fut));

        // Right takes one buffered item, freeing a slot so left can pull the next one.
        assert_eq!(right.next().await.unwrap(), 0);
        assert_eq!(left_fut.await.unwrap(), 2);

        assert_eq!(right.next().await.unwrap(), 1);
        assert_eq!(right.next().await.unwrap(), 2);

        // Both halves are caught up and the channel is empty. Right becomes the designated
        // poller of the inner stream, and left waits on right.
        let mut right_fut = right.next();
        let mut left_fut = left.next();
        assert!(is_pending(&mut right_fut));
        assert!(is_pending(&mut left_fut));

        tx.send(3).await.unwrap();

        // Right pulls 3 from the inner stream and buffers a clone for left.
        assert_eq!(right_fut.await.unwrap(), 3);
        assert_eq!(left_fut.await.unwrap(), 3);

        // Closing the channel ends the inner stream.
        drop(tx);

        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);

        // Both halves keep reporting `None` after exhaustion.
        assert_eq!(left.next().await, None);
        assert_eq!(right.next().await, None);
    }

    /// With an unbounded buffer, one half can consume everything before the other starts.
    #[tokio::test]
    async fn test_unbounded_shared_stream() {
        let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
        let inner_stream = ReceiverStream::new(rx);

        // Fill the channel and close it, so the inner stream yields 0..10 and then `None`.
        for i in 0..10 {
            tx.send(i).await.unwrap();
        }
        drop(tx);

        let (mut left, mut right) = inner_stream.boxed().share(Capacity::Unbounded);

        // Left pulls every item from the inner stream, buffering each for right.
        for i in 0..10 {
            assert_eq!(left.next().await.unwrap(), i);
        }
        assert_eq!(left.next().await, None);

        // Right replays everything from the buffer, in order.
        for i in 0..10 {
            assert_eq!(right.next().await.unwrap(), i);
        }
        assert_eq!(right.next().await, None);
    }

    /// Runs both halves on separate tasks of a multi-threaded runtime, repeatedly; each half
    /// must see every item exactly once and in order.
    #[tokio::test(flavor = "multi_thread")]
    async fn stress_shared_stream() {
        for _ in 0..100 {
            let (tx, rx) = tokio::sync::mpsc::channel::<u32>(10);
            let inner_stream = ReceiverStream::new(rx);
            let (mut left, mut right) = inner_stream.boxed().share(Capacity::Bounded(2));

            let left_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = left.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            let right_handle = tokio::spawn(async move {
                let mut counter = 0;
                while let Some(item) = right.next().await {
                    assert_eq!(item, counter);
                    counter += 1;
                }
            });

            for i in 0..1000 {
                tx.send(i).await.unwrap();
            }
            drop(tx);
            left_handle.await.unwrap();
            right_handle.await.unwrap();
        }
    }

    /// The callback runs when the stream is dropped part-way through.
    #[tokio::test]
    async fn test_on_drop_fires_on_early_drop() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let stream = futures::stream::iter(vec![1, 2, 3]);
        let mut stream = stream.on_drop(move || {
            called_clone.store(true, Ordering::SeqCst);
        });

        assert_eq!(stream.next().await, Some(1));
        // Polling alone must not trigger the callback.
        assert!(!called.load(Ordering::SeqCst));
        drop(stream);
        assert!(called.load(Ordering::SeqCst));
    }

    /// Exhaustion alone does not run the callback; only the drop does.
    #[tokio::test]
    async fn test_on_drop_fires_after_exhaustion() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let stream = futures::stream::iter(vec![1]);
        let mut stream = stream.on_drop(move || {
            called_clone.store(true, Ordering::SeqCst);
        });

        assert_eq!(stream.next().await, Some(1));
        assert_eq!(stream.next().await, None);
        assert!(!called.load(Ordering::SeqCst));
        drop(stream);
        assert!(called.load(Ordering::SeqCst));
    }

    /// The callback runs even if the stream was never polled.
    #[tokio::test]
    async fn test_on_drop_fires_without_polling() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let stream = futures::stream::iter(vec![1, 2, 3]);
        let stream = stream.on_drop(move || {
            called_clone.store(true, Ordering::SeqCst);
        });

        assert!(!called.load(Ordering::SeqCst));
        drop(stream);
        assert!(called.load(Ordering::SeqCst));
    }
}