aster/background/
timeout.rs1use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tokio::time::{sleep, Duration};
14
15use super::types::TimeoutStats;
16
/// Shared, thread-safe hook that `TimeoutManager` calls with a timer's id
/// when that timer fires.
pub(crate) type TimeoutCallback = Arc<dyn Fn(&str) + Send + Sync + 'static>;
19
/// Duration limits consulted by `TimeoutManager`.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Duration used when `set_timeout` is given `None`, in milliseconds.
    pub default_timeout_ms: u64,
    /// Ceiling that requested and extended durations are clamped to, in milliseconds.
    pub max_timeout_ms: u64,
    /// Shutdown grace period in milliseconds (in this module it is only
    /// reported through `get_stats`).
    pub graceful_shutdown_timeout_ms: u64,
}

impl Default for TimeoutConfig {
    /// Two-minute default, ten-minute ceiling, five-second shutdown grace.
    fn default() -> Self {
        let default_timeout_ms = 2 * 60 * 1_000;
        let max_timeout_ms = 10 * 60 * 1_000;
        let graceful_shutdown_timeout_ms = 5 * 1_000;
        Self {
            default_timeout_ms,
            max_timeout_ms,
            graceful_shutdown_timeout_ms,
        }
    }
}
37
/// Snapshot describing a timer registered with `TimeoutManager`.
#[derive(Debug, Clone)]
pub struct TimeoutHandle {
    /// Caller-chosen key that identifies the timer.
    pub id: String,
    /// Wall-clock start of the countdown, in Unix-epoch milliseconds (UTC).
    pub start_time: i64,
    /// Milliseconds after `start_time` at which the timer expires.
    pub duration_ms: u64,
    /// Cancellation flag carried on the snapshot.
    pub cancelled: bool,
}
46
47pub struct TimeoutManager {
49 timeouts: Arc<RwLock<HashMap<String, TimeoutHandle>>>,
50 config: TimeoutConfig,
51 on_timeout: Option<TimeoutCallback>,
52}
53
54impl TimeoutManager {
55 pub fn new(config: TimeoutConfig) -> Self {
57 Self {
58 timeouts: Arc::new(RwLock::new(HashMap::new())),
59 config,
60 on_timeout: None,
61 }
62 }
63
64 pub fn set_on_timeout<F>(&mut self, callback: F)
66 where
67 F: Fn(&str) + Send + Sync + 'static,
68 {
69 self.on_timeout = Some(Arc::new(callback));
70 }
71
72 pub async fn set_timeout<F>(
74 &self,
75 id: &str,
76 callback: F,
77 duration_ms: Option<u64>,
78 ) -> TimeoutHandle
79 where
80 F: FnOnce() + Send + 'static,
81 {
82 self.clear_timeout(id).await;
84
85 let actual_duration = duration_ms
86 .unwrap_or(self.config.default_timeout_ms)
87 .min(self.config.max_timeout_ms);
88
89 let handle = TimeoutHandle {
90 id: id.to_string(),
91 start_time: chrono::Utc::now().timestamp_millis(),
92 duration_ms: actual_duration,
93 cancelled: false,
94 };
95
96 self.timeouts
97 .write()
98 .await
99 .insert(id.to_string(), handle.clone());
100
101 let timeouts = Arc::clone(&self.timeouts);
103 let id_clone = id.to_string();
104 let on_timeout = self.on_timeout.clone();
105
106 tokio::spawn(async move {
107 sleep(Duration::from_millis(actual_duration)).await;
108
109 let mut guard = timeouts.write().await;
110 if let Some(h) = guard.get(&id_clone) {
111 if !h.cancelled {
112 if let Some(cb) = on_timeout {
113 cb(&id_clone);
114 }
115 callback();
116 guard.remove(&id_clone);
117 }
118 }
119 });
120
121 handle
122 }
123
124 pub async fn clear_timeout(&self, id: &str) -> bool {
126 let mut timeouts = self.timeouts.write().await;
127 if let Some(handle) = timeouts.get_mut(id) {
128 handle.cancelled = true;
129 timeouts.remove(id);
130 true
131 } else {
132 false
133 }
134 }
135
136 pub async fn get_remaining_time(&self, id: &str) -> Option<u64> {
138 let timeouts = self.timeouts.read().await;
139 if let Some(handle) = timeouts.get(id) {
140 let elapsed = (chrono::Utc::now().timestamp_millis() - handle.start_time) as u64;
141 Some(handle.duration_ms.saturating_sub(elapsed))
142 } else {
143 None
144 }
145 }
146
147 pub async fn is_timed_out(&self, id: &str) -> bool {
149 !self.timeouts.read().await.contains_key(id)
150 }
151
152 pub async fn reset_timeout(&self, id: &str) -> bool {
154 let mut timeouts = self.timeouts.write().await;
155 if let Some(handle) = timeouts.get_mut(id) {
156 handle.start_time = chrono::Utc::now().timestamp_millis();
157 true
158 } else {
159 false
160 }
161 }
162
163 pub async fn extend_timeout(&self, id: &str, additional_ms: u64) -> bool {
165 let mut timeouts = self.timeouts.write().await;
166 if let Some(handle) = timeouts.get_mut(id) {
167 let new_duration = (handle.duration_ms + additional_ms).min(self.config.max_timeout_ms);
168 handle.duration_ms = new_duration;
169 true
170 } else {
171 false
172 }
173 }
174
175 pub async fn get_all_timeouts(&self) -> Vec<TimeoutHandle> {
177 self.timeouts.read().await.values().cloned().collect()
178 }
179
180 pub async fn clear_all(&self) -> usize {
182 let mut timeouts = self.timeouts.write().await;
183 let count = timeouts.len();
184 for handle in timeouts.values_mut() {
185 handle.cancelled = true;
186 }
187 timeouts.clear();
188 count
189 }
190
191 pub async fn get_stats(&self) -> TimeoutStats {
193 TimeoutStats {
194 total: self.timeouts.read().await.len(),
195 default_timeout_ms: self.config.default_timeout_ms,
196 max_timeout_ms: self.config.max_timeout_ms,
197 graceful_shutdown_timeout_ms: self.config.graceful_shutdown_timeout_ms,
198 }
199 }
200}
201
202pub async fn promise_with_timeout<T, F>(
204 future: F,
205 timeout_ms: u64,
206 timeout_error: Option<&str>,
207) -> Result<T, String>
208where
209 F: std::future::Future<Output = T>,
210{
211 match tokio::time::timeout(Duration::from_millis(timeout_ms), future).await {
212 Ok(result) => Ok(result),
213 Err(_) => Err(timeout_error.unwrap_or("Operation timed out").to_string()),
214 }
215}
216
217pub struct CancellableDelay {
219 duration_ms: u64,
220 cancelled: Arc<RwLock<bool>>,
221}
222
223impl CancellableDelay {
224 pub fn new(duration_ms: u64) -> Self {
226 Self {
227 duration_ms,
228 cancelled: Arc::new(RwLock::new(false)),
229 }
230 }
231
232 pub async fn start(&self) -> Result<(), ()> {
234 let cancelled = Arc::clone(&self.cancelled);
235 let duration = Duration::from_millis(self.duration_ms);
236
237 tokio::select! {
238 _ = sleep(duration) => {
239 if *cancelled.read().await {
240 Err(())
241 } else {
242 Ok(())
243 }
244 }
245 }
246 }
247
248 pub async fn cancel(&self) {
250 *self.cancelled.write().await = true;
251 }
252}