pollable_map/stream/
timeout_set.rs1use crate::common::Timed;
2use crate::stream::set::StreamSet;
3use futures::stream::FusedStream;
4use futures::{Stream, StreamExt};
5use std::ops::{Deref, DerefMut};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10pub struct TimeoutStreamSet<S> {
11 duration: Duration,
12 set: StreamSet<Timed<S>>,
13}
14
15impl<S> Deref for TimeoutStreamSet<S> {
16 type Target = StreamSet<Timed<S>>;
17 fn deref(&self) -> &Self::Target {
18 &self.set
19 }
20}
21
22impl<S> DerefMut for TimeoutStreamSet<S> {
23 fn deref_mut(&mut self) -> &mut Self::Target {
24 &mut self.set
25 }
26}
27
28impl<S> TimeoutStreamSet<S>
29where
30 S: Stream + Send + Unpin + 'static,
31{
32 pub fn new(duration: Duration) -> Self {
34 Self {
35 duration,
36 set: StreamSet::new(),
37 }
38 }
39
40 pub fn insert(&mut self, stream: S) -> bool {
42 self.set.insert(Timed::new(stream, self.duration))
43 }
44}
45
46impl<S> Stream for TimeoutStreamSet<S>
47where
48 S: Stream + Send + Unpin + 'static,
49{
50 type Item = std::io::Result<S::Item>;
51 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
52 self.set.poll_next_unpin(cx)
53 }
54
55 fn size_hint(&self) -> (usize, Option<usize>) {
56 self.set.size_hint()
57 }
58}
59
60impl<S> FusedStream for TimeoutStreamSet<S>
61where
62 S: Stream + Send + Unpin + 'static,
63{
64 fn is_terminated(&self) -> bool {
65 self.set.is_terminated()
66 }
67}
68
#[cfg(test)]
mod test {
    use super::TimeoutStreamSet;
    use futures::StreamExt;
    use std::time::Duration;

    // A stream that never yields must surface a `TimedOut` error once the
    // configured duration has elapsed.
    #[test]
    fn timeout_set() {
        let mut list = TimeoutStreamSet::new(Duration::from_millis(100));
        assert!(list.insert(futures::stream::pending::<()>()));

        futures::executor::block_on(async move {
            // Failing here means the set yielded a value or ended instead of
            // timing out, so print what actually arrived.
            let e = match list.next().await {
                Some(Err(e)) => e,
                other => panic!("expected a timeout error, got {other:?}"),
            };

            assert_eq!(e.kind(), std::io::ErrorKind::TimedOut);
        });
    }

    // A stream that yields well within the timeout passes its value through
    // unchanged as `Ok`.
    #[test]
    fn valid_stream() {
        let mut list = TimeoutStreamSet::new(Duration::from_secs(10));
        // `boxed()` makes the `async`-block-backed stream `Unpin`, as `insert`
        // requires.
        assert!(list.insert(futures::stream::once(async { 0 }).boxed()));

        futures::executor::block_on(async move {
            let val = match list.next().await {
                Some(Ok(val)) => val,
                other => panic!("expected the stream's value, got {other:?}"),
            };

            assert_eq!(val, 0);
        });
    }
}