Skip to main content

aster/background/
timeout.rs

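//! Timeout handling module.
//!
//! Provides task timeout management, process termination strategies, and timeout configuration.
//!
//! # Features
//! - Timeout duration management
//! - Graceful termination strategy
//! - Timeout extension and reset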
9
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tokio::time::{sleep, Duration};
14
15use super::types::TimeoutStats;
16
/// Shared callback invoked with the id of a timeout that has fired.
///
/// Held behind an `Arc` so it can be cloned cheaply into each spawned timer task.
pub(crate) type TimeoutCallback = Arc<dyn for<'a> Fn(&'a str) + Send + Sync>;
19
/// Timeout configuration. Every value is expressed in milliseconds.
#[derive(Clone, Debug)]
pub struct TimeoutConfig {
    /// Duration used when a timeout is registered without an explicit one.
    pub default_timeout_ms: u64,
    /// Upper bound: requested and extended durations are clamped to this value.
    pub max_timeout_ms: u64,
    /// Presumably the grace period before a forceful termination (TODO confirm with
    /// callers); in this file it is only reported back through the stats.
    pub graceful_shutdown_timeout_ms: u64,
}

impl Default for TimeoutConfig {
    /// 2-minute default timeout, 10-minute cap, 5-second shutdown grace period.
    fn default() -> Self {
        const TWO_MINUTES_MS: u64 = 2 * 60 * 1_000;
        const TEN_MINUTES_MS: u64 = 10 * 60 * 1_000;
        const FIVE_SECONDS_MS: u64 = 5 * 1_000;

        TimeoutConfig {
            graceful_shutdown_timeout_ms: FIVE_SECONDS_MS,
            max_timeout_ms: TEN_MINUTES_MS,
            default_timeout_ms: TWO_MINUTES_MS,
        }
    }
}
37
/// Snapshot describing one registered timeout.
#[derive(Clone, Debug)]
pub struct TimeoutHandle {
    /// Caller-chosen identifier the timeout is registered under.
    pub id: String,
    /// Registration (or most recent reset) time, in Unix epoch milliseconds.
    pub start_time: i64,
    /// Effective duration after clamping to the configured maximum.
    pub duration_ms: u64,
    /// Cancellation flag. Handles returned by the timeout manager carry `false`;
    /// cancelled timeouts are removed from the manager rather than kept around.
    pub cancelled: bool,
}
46
47/// 超时管理器
48pub struct TimeoutManager {
49    timeouts: Arc<RwLock<HashMap<String, TimeoutHandle>>>,
50    config: TimeoutConfig,
51    on_timeout: Option<TimeoutCallback>,
52}
53
54impl TimeoutManager {
55    /// 创建新的超时管理器
56    pub fn new(config: TimeoutConfig) -> Self {
57        Self {
58            timeouts: Arc::new(RwLock::new(HashMap::new())),
59            config,
60            on_timeout: None,
61        }
62    }
63
64    /// 设置超时回调
65    pub fn set_on_timeout<F>(&mut self, callback: F)
66    where
67        F: Fn(&str) + Send + Sync + 'static,
68    {
69        self.on_timeout = Some(Arc::new(callback));
70    }
71
72    /// 设置超时
73    pub async fn set_timeout<F>(
74        &self,
75        id: &str,
76        callback: F,
77        duration_ms: Option<u64>,
78    ) -> TimeoutHandle
79    where
80        F: FnOnce() + Send + 'static,
81    {
82        // 清除已存在的超时
83        self.clear_timeout(id).await;
84
85        let actual_duration = duration_ms
86            .unwrap_or(self.config.default_timeout_ms)
87            .min(self.config.max_timeout_ms);
88
89        let handle = TimeoutHandle {
90            id: id.to_string(),
91            start_time: chrono::Utc::now().timestamp_millis(),
92            duration_ms: actual_duration,
93            cancelled: false,
94        };
95
96        self.timeouts
97            .write()
98            .await
99            .insert(id.to_string(), handle.clone());
100
101        // 启动超时任务
102        let timeouts = Arc::clone(&self.timeouts);
103        let id_clone = id.to_string();
104        let on_timeout = self.on_timeout.clone();
105
106        tokio::spawn(async move {
107            sleep(Duration::from_millis(actual_duration)).await;
108
109            let mut guard = timeouts.write().await;
110            if let Some(h) = guard.get(&id_clone) {
111                if !h.cancelled {
112                    if let Some(cb) = on_timeout {
113                        cb(&id_clone);
114                    }
115                    callback();
116                    guard.remove(&id_clone);
117                }
118            }
119        });
120
121        handle
122    }
123
124    /// 清除超时
125    pub async fn clear_timeout(&self, id: &str) -> bool {
126        let mut timeouts = self.timeouts.write().await;
127        if let Some(handle) = timeouts.get_mut(id) {
128            handle.cancelled = true;
129            timeouts.remove(id);
130            true
131        } else {
132            false
133        }
134    }
135
136    /// 获取剩余时间
137    pub async fn get_remaining_time(&self, id: &str) -> Option<u64> {
138        let timeouts = self.timeouts.read().await;
139        if let Some(handle) = timeouts.get(id) {
140            let elapsed = (chrono::Utc::now().timestamp_millis() - handle.start_time) as u64;
141            Some(handle.duration_ms.saturating_sub(elapsed))
142        } else {
143            None
144        }
145    }
146
147    /// 检查是否已超时
148    pub async fn is_timed_out(&self, id: &str) -> bool {
149        !self.timeouts.read().await.contains_key(id)
150    }
151
152    /// 重置超时
153    pub async fn reset_timeout(&self, id: &str) -> bool {
154        let mut timeouts = self.timeouts.write().await;
155        if let Some(handle) = timeouts.get_mut(id) {
156            handle.start_time = chrono::Utc::now().timestamp_millis();
157            true
158        } else {
159            false
160        }
161    }
162
163    /// 延长超时时间
164    pub async fn extend_timeout(&self, id: &str, additional_ms: u64) -> bool {
165        let mut timeouts = self.timeouts.write().await;
166        if let Some(handle) = timeouts.get_mut(id) {
167            let new_duration = (handle.duration_ms + additional_ms).min(self.config.max_timeout_ms);
168            handle.duration_ms = new_duration;
169            true
170        } else {
171            false
172        }
173    }
174
175    /// 获取所有超时信息
176    pub async fn get_all_timeouts(&self) -> Vec<TimeoutHandle> {
177        self.timeouts.read().await.values().cloned().collect()
178    }
179
180    /// 清除所有超时
181    pub async fn clear_all(&self) -> usize {
182        let mut timeouts = self.timeouts.write().await;
183        let count = timeouts.len();
184        for handle in timeouts.values_mut() {
185            handle.cancelled = true;
186        }
187        timeouts.clear();
188        count
189    }
190
191    /// 获取统计信息
192    pub async fn get_stats(&self) -> TimeoutStats {
193        TimeoutStats {
194            total: self.timeouts.read().await.len(),
195            default_timeout_ms: self.config.default_timeout_ms,
196            max_timeout_ms: self.config.max_timeout_ms,
197            graceful_shutdown_timeout_ms: self.config.graceful_shutdown_timeout_ms,
198        }
199    }
200}
201
202/// 带超时的 Promise
203pub async fn promise_with_timeout<T, F>(
204    future: F,
205    timeout_ms: u64,
206    timeout_error: Option<&str>,
207) -> Result<T, String>
208where
209    F: std::future::Future<Output = T>,
210{
211    match tokio::time::timeout(Duration::from_millis(timeout_ms), future).await {
212        Ok(result) => Ok(result),
213        Err(_) => Err(timeout_error.unwrap_or("Operation timed out").to_string()),
214    }
215}
216
217/// 可取消的延迟
218pub struct CancellableDelay {
219    duration_ms: u64,
220    cancelled: Arc<RwLock<bool>>,
221}
222
223impl CancellableDelay {
224    /// 创建新的可取消延迟
225    pub fn new(duration_ms: u64) -> Self {
226        Self {
227            duration_ms,
228            cancelled: Arc::new(RwLock::new(false)),
229        }
230    }
231
232    /// 开始延迟
233    pub async fn start(&self) -> Result<(), ()> {
234        let cancelled = Arc::clone(&self.cancelled);
235        let duration = Duration::from_millis(self.duration_ms);
236
237        tokio::select! {
238            _ = sleep(duration) => {
239                if *cancelled.read().await {
240                    Err(())
241                } else {
242                    Ok(())
243                }
244            }
245        }
246    }
247
248    /// 取消延迟
249    pub async fn cancel(&self) {
250        *self.cancelled.write().await = true;
251    }
252}