1use alloc::boxed::Box;
2use core::{future::Future, time::Duration};
3
4use futures_util::{
5 FutureExt as _, TryFutureExt as _,
6 future::{self, Either},
7};
8
9#[cfg(feature = "std")]
10use crate::sleep::sleep_until;
11use crate::{Sleepble, sleep::sleep};
12
13pub fn internal_timeout<SLEEP, T>(
15 dur: Duration,
16 future: T,
17) -> impl Future<Output = Result<T::Output, (Duration, T)>>
18where
19 SLEEP: Sleepble,
20 T: Future + Unpin,
21{
22 future::select(future, Box::pin(sleep::<SLEEP>(dur))).map(move |either| match either {
23 Either::Left((output, _)) => Ok(output),
24 Either::Right((_, future)) => Err((dur, future)),
25 })
26}
27
28pub fn timeout<SLEEP, T>(dur: Duration, future: T) -> impl Future<Output = Result<T::Output, Error>>
29where
30 SLEEP: Sleepble,
31 T: Future + Unpin,
32{
33 internal_timeout::<SLEEP, _>(dur, future).map_err(|(dur, _)| Error::Timeout(dur))
34}
35
/// Races `future` against a timer that fires at `deadline`.
///
/// Resolves to `Ok(output)` when `future` completes first. When the deadline
/// wins, resolves to `Err((deadline, future))`, returning the unfinished
/// future to the caller.
#[cfg(feature = "std")]
pub fn internal_timeout_at<SLEEP, T>(
    deadline: std::time::Instant,
    future: T,
) -> impl Future<Output = Result<T::Output, (std::time::Instant, T)>>
where
    SLEEP: Sleepble,
    T: Future + Unpin,
{
    // Heap-pin the alarm so both sides of the race are `Unpin`, as `select` needs.
    let alarm = Box::pin(sleep_until::<SLEEP>(deadline));
    future::select(future, alarm).map(move |winner| match winner {
        Either::Right((_, unfinished)) => Err((deadline, unfinished)),
        Either::Left((output, _)) => Ok(output),
    })
}
51
/// Runs `future` until the absolute `deadline`.
///
/// Resolves to the future's output, or to [`Error::TimeoutAt`] carrying the
/// deadline if it passes first. In the timeout case the unfinished future is
/// dropped.
#[cfg(feature = "std")]
pub fn timeout_at<SLEEP, T>(
    deadline: std::time::Instant,
    future: T,
) -> impl Future<Output = Result<T::Output, Error>>
where
    SLEEP: Sleepble,
    T: Future + Unpin,
{
    let raced = internal_timeout_at::<SLEEP, _>(deadline, future);
    raced.map_err(|(missed, _)| Error::TimeoutAt(missed))
}
64
/// Error produced when a future does not finish within its time limit.
///
/// Both payloads are plain `Copy` values, so the error is cheap to copy and
/// compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// A relative timeout elapsed; carries the duration that was allowed.
    Timeout(Duration),
    /// An absolute deadline passed; carries that deadline.
    #[cfg(feature = "std")]
    TimeoutAt(std::time::Instant),
}

impl core::fmt::Display for Error {
    /// Writes a human-readable message. Use `{:?}` for the raw variant form.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Neither `Duration` nor `Instant` implements `Display`, so their
            // `Debug` forms (e.g. `50ms`) are embedded in the message.
            Self::Timeout(dur) => write!(f, "operation timed out after {dur:?}"),
            #[cfg(feature = "std")]
            Self::TimeoutAt(deadline) => {
                write!(f, "operation timed out at deadline {deadline:?}")
            }
        }
    }
}

impl core::error::Error for Error {}
78
#[cfg(feature = "std")]
impl From<Error> for std::io::Error {
    /// Converts a timeout [`Error`] into an [`std::io::Error`] of kind
    /// [`std::io::ErrorKind::TimedOut`].
    ///
    /// The original error is kept as the inner payload (reachable through
    /// `get_ref` / `into_inner`), so the duration or deadline is not lost.
    fn from(err: Error) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::TimedOut, err)
    }
}
85
#[cfg(feature = "impl_tokio")]
#[cfg(test)]
mod tests {
    // Every test below needs `std` (for `std::time::Instant`), so the shared
    // items are gated the same way instead of silencing unused warnings.
    #[cfg(feature = "std")]
    use super::*;

    /// Extra milliseconds tolerated above the expected wall-clock time.
    ///
    /// Timer wake-ups may be delayed by scheduling on busy machines, so a
    /// slack of only a few milliseconds makes the tests flaky. The lower bound
    /// stays exact.
    #[cfg(feature = "std")]
    const SLACK_MS: u128 = 25;

    /// Test future: sleeps 100ms on the tokio timer, then resolves to `0`.
    #[cfg(feature = "std")]
    async fn foo() -> usize {
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
        0
    }

    /// Asserts the time since `start` lies in `[expected_ms, expected_ms + SLACK_MS]`.
    #[cfg(feature = "std")]
    #[track_caller]
    fn assert_elapsed_near(start: std::time::Instant, expected_ms: u128) {
        let elapsed_ms = start.elapsed().as_millis();
        let upper = expected_ms + SLACK_MS;
        assert!(
            (expected_ms..=upper).contains(&elapsed_ms),
            "elapsed {elapsed_ms}ms, expected {expected_ms}..={upper}ms"
        );
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_timeout() {
        // The sender is kept alive, so the receiver never resolves and the
        // 50ms timeout must fire.
        let now = std::time::Instant::now();
        let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(50), rx).await {
            Ok(v) => panic!("{v:?}"),
            Err(err) => assert_eq!(err, Error::Timeout(Duration::from_millis(50))),
        }
        assert_elapsed_near(now, 50);

        // `foo` needs 100ms, longer than the 50ms budget.
        let now = std::time::Instant::now();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(50), Box::pin(foo()))
            .await
        {
            Ok(v) => panic!("{v:?}"),
            Err(err) => assert_eq!(err, Error::Timeout(Duration::from_millis(50))),
        }
        assert_elapsed_near(now, 50);

        // `foo` finishes inside the 150ms budget, so the elapsed time is set by `foo`.
        let now = std::time::Instant::now();
        match timeout::<crate::impl_tokio::Sleep, _>(Duration::from_millis(150), Box::pin(foo()))
            .await
        {
            Ok(v) => assert_eq!(v, 0),
            Err(err) => panic!("{err:?}"),
        }
        assert_elapsed_near(now, 100);
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_timeout_at() {
        let now = std::time::Instant::now();

        match timeout_at::<crate::impl_tokio::Sleep, _>(
            std::time::Instant::now() + Duration::from_millis(50),
            Box::pin(foo()),
        )
        .await
        {
            Ok(v) => panic!("{v:?}"),
            Err(Error::Timeout(dur)) => panic!("{dur:?}"),
            Err(Error::TimeoutAt(instant)) => {
                // The reported deadline should have passed only just now.
                let overshoot_ms = instant.elapsed().as_millis();
                assert!(overshoot_ms <= SLACK_MS, "deadline overshoot {overshoot_ms}ms");
            }
        }

        assert_elapsed_near(now, 50);
    }
}