1use std::future::Future;
6use std::time::Duration;
7
/// Error returned by [`with_timeout`].
///
/// Separates "the operation ran out of time" ([`TimeoutError::Timeout`]) from
/// "the operation finished in time but reported its own error"
/// ([`TimeoutError::ExecutionError`]). `E` is the operation's own error type.
#[derive(Debug, thiserror::Error)]
pub enum TimeoutError<E> {
    /// The operation did not complete within `duration`.
    ///
    /// Displays as `operation timed out after {duration:?}: {context}`.
    #[error("operation timed out after {duration:?}: {context}")]
    Timeout {
        /// The time budget that was exceeded.
        duration: Duration,
        /// Caller-supplied description of the operation, included in the message.
        context: String,
    },

    /// The operation completed in time but returned an error of type `E`.
    ///
    /// Displays as `execution error: {inner}`; the inner error is also
    /// returned from `std::error::Error::source`.
    // NOTE(review): the inner error is both interpolated into Display (`{0}`)
    // and exposed via `#[source]`, so chain-walking reporters (e.g. anyhow's
    // `{:#}` / `{:?}`) print its message twice. Left unchanged because callers
    // may rely on `to_string()` containing the inner message.
    #[error("execution error: {0}")]
    ExecutionError(#[source] E),
}
24
25impl<E> TimeoutError<E> {
26 pub fn is_timeout(&self) -> bool {
28 matches!(self, TimeoutError::Timeout { .. })
29 }
30
31 pub fn is_execution_error(&self) -> bool {
33 matches!(self, TimeoutError::ExecutionError(_))
34 }
35
36 pub fn timeout_duration(&self) -> Option<Duration> {
38 match self {
39 TimeoutError::Timeout { duration, .. } => Some(*duration),
40 _ => None,
41 }
42 }
43}
44
45pub async fn with_timeout<F, T, E>(
76 duration: Duration,
77 context: impl Into<String>,
78 future: F,
79) -> Result<T, TimeoutError<E>>
80where
81 F: Future<Output = Result<T, E>>,
82{
83 let context = context.into();
84
85 match tokio::time::timeout(duration, future).await {
86 Ok(Ok(result)) => Ok(result),
87 Ok(Err(e)) => Err(TimeoutError::ExecutionError(e)),
88 Err(_elapsed) => Err(TimeoutError::Timeout { duration, context }),
89 }
90}
91
92pub async fn with_timeout_infallible<F, T>(duration: Duration, future: F) -> Option<T>
101where
102 F: Future<Output = T>,
103{
104 tokio::time::timeout(duration, future).await.ok()
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;

    #[tokio::test]
    async fn test_timeout_success() {
        let result: Result<&str, TimeoutError<io::Error>> =
            with_timeout(Duration::from_secs(1), "test operation", async {
                Ok("success")
            })
            .await;

        // `expect` surfaces the actual error if this ever regresses.
        assert_eq!(
            result.expect("future should finish before the deadline"),
            "success"
        );
    }

    #[tokio::test]
    async fn test_timeout_fires() {
        let budget = Duration::from_millis(50);
        let result: Result<(), TimeoutError<io::Error>> =
            with_timeout(budget, "slow operation", async {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(())
            })
            .await;

        let err = result.expect_err("a 10s sleep must exceed a 50ms budget");
        assert!(err.is_timeout());
        assert!(!err.is_execution_error());
        // The reported budget must be exactly the one we passed in.
        assert_eq!(err.timeout_duration(), Some(budget));
        assert!(err.to_string().contains("slow operation"));
    }

    #[tokio::test]
    async fn test_timeout_execution_error() {
        let result: Result<(), TimeoutError<io::Error>> =
            with_timeout(Duration::from_secs(1), "failing operation", async {
                Err(io::Error::new(io::ErrorKind::NotFound, "not found"))
            })
            .await;

        let err = result.expect_err("the future returns Err before the deadline");
        assert!(err.is_execution_error());
        assert!(!err.is_timeout());
        assert!(err.timeout_duration().is_none());
        // The wrapped error must be passed through untouched.
        match &err {
            TimeoutError::ExecutionError(inner) => {
                assert_eq!(inner.kind(), io::ErrorKind::NotFound);
            }
            other => panic!("expected ExecutionError, got {other:?}"),
        }
        assert_eq!(err.to_string(), "execution error: not found");
    }

    #[tokio::test]
    async fn test_zero_duration_ready_future_succeeds() {
        // tokio documents that a future which completes immediately always
        // yields `Ok`, whatever the duration, so a zero budget still succeeds.
        let result: Result<u8, TimeoutError<io::Error>> =
            with_timeout(Duration::ZERO, "instant operation", async { Ok(7) }).await;

        assert_eq!(result.expect("an immediately-ready future must win"), 7);
    }

    #[tokio::test]
    async fn test_timeout_infallible_success() {
        let result = with_timeout_infallible(Duration::from_secs(1), async { 42 }).await;

        assert_eq!(result, Some(42));
    }

    #[tokio::test]
    async fn test_timeout_infallible_timeout() {
        let result = with_timeout_infallible(Duration::from_millis(50), async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            42
        })
        .await;

        assert_eq!(result, None);
    }
}