aster/agents/error_handling/
timeout_handler.rs1use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{broadcast, RwLock};
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17#[serde(rename_all = "snake_case")]
18pub enum TimeoutStatus {
19 #[default]
21 Running,
22 Warning,
24 TimedOut,
26 Completed,
28 Cancelled,
30}
31
32impl std::fmt::Display for TimeoutStatus {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 match self {
35 TimeoutStatus::Running => write!(f, "running"),
36 TimeoutStatus::Warning => write!(f, "warning"),
37 TimeoutStatus::TimedOut => write!(f, "timed_out"),
38 TimeoutStatus::Completed => write!(f, "completed"),
39 TimeoutStatus::Cancelled => write!(f, "cancelled"),
40 }
41 }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct TimeoutConfig {
48 pub timeout: Duration,
50 pub warning_threshold: f64,
52 pub emit_events: bool,
54 pub grace_period: Option<Duration>,
56}
57
58impl Default for TimeoutConfig {
59 fn default() -> Self {
60 Self {
61 timeout: Duration::from_secs(300), warning_threshold: 0.8,
63 emit_events: true,
64 grace_period: Some(Duration::from_secs(10)),
65 }
66 }
67}
68
69impl TimeoutConfig {
70 pub fn new(timeout: Duration) -> Self {
72 Self {
73 timeout,
74 ..Default::default()
75 }
76 }
77
78 pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
80 self.warning_threshold = threshold.clamp(0.0, 1.0);
81 self
82 }
83
84 pub fn with_emit_events(mut self, emit: bool) -> Self {
86 self.emit_events = emit;
87 self
88 }
89
90 pub fn with_grace_period(mut self, grace: Duration) -> Self {
92 self.grace_period = Some(grace);
93 self
94 }
95
96 pub fn warning_duration(&self) -> Duration {
98 Duration::from_secs_f64(self.timeout.as_secs_f64() * self.warning_threshold)
99 }
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(rename_all = "camelCase")]
105pub struct TimeoutEvent {
106 pub agent_id: String,
108 pub previous_status: TimeoutStatus,
110 pub new_status: TimeoutStatus,
112 pub elapsed: Duration,
114 pub timeout: Duration,
116 pub timestamp: DateTime<Utc>,
118 pub message: Option<String>,
120}
121
122impl TimeoutEvent {
123 pub fn new(
125 agent_id: impl Into<String>,
126 previous_status: TimeoutStatus,
127 new_status: TimeoutStatus,
128 elapsed: Duration,
129 timeout: Duration,
130 ) -> Self {
131 Self {
132 agent_id: agent_id.into(),
133 previous_status,
134 new_status,
135 elapsed,
136 timeout,
137 timestamp: Utc::now(),
138 message: None,
139 }
140 }
141
142 pub fn with_message(mut self, message: impl Into<String>) -> Self {
144 self.message = Some(message.into());
145 self
146 }
147
148 pub fn is_timeout(&self) -> bool {
150 self.new_status == TimeoutStatus::TimedOut
151 }
152
153 pub fn is_warning(&self) -> bool {
155 self.new_status == TimeoutStatus::Warning
156 }
157}
158
159#[derive(Debug, Clone)]
161#[allow(dead_code)]
162struct TrackedAgent {
163 agent_id: String,
164 config: TimeoutConfig,
165 start_time: DateTime<Utc>,
166 status: TimeoutStatus,
167 warning_emitted: bool,
168}
169
170impl TrackedAgent {
171 fn new(agent_id: impl Into<String>, config: TimeoutConfig) -> Self {
172 Self {
173 agent_id: agent_id.into(),
174 config,
175 start_time: Utc::now(),
176 status: TimeoutStatus::Running,
177 warning_emitted: false,
178 }
179 }
180
181 fn elapsed(&self) -> Duration {
182 let elapsed = Utc::now().signed_duration_since(self.start_time);
183 elapsed.to_std().unwrap_or(Duration::ZERO)
184 }
185
186 fn is_timed_out(&self) -> bool {
187 self.elapsed() > self.config.timeout
188 }
189
190 fn is_warning(&self) -> bool {
191 let elapsed = self.elapsed();
192 elapsed > self.config.warning_duration() && elapsed <= self.config.timeout
193 }
194}
195
196#[derive(Debug)]
198pub struct TimeoutHandler {
199 agents: HashMap<String, TrackedAgent>,
201 event_sender: broadcast::Sender<TimeoutEvent>,
203 default_config: TimeoutConfig,
205}
206
207impl Default for TimeoutHandler {
208 fn default() -> Self {
209 Self::new()
210 }
211}
212
213impl TimeoutHandler {
214 pub fn new() -> Self {
216 let (event_sender, _) = broadcast::channel(100);
217 Self {
218 agents: HashMap::new(),
219 event_sender,
220 default_config: TimeoutConfig::default(),
221 }
222 }
223
224 pub fn with_default_config(config: TimeoutConfig) -> Self {
226 let (event_sender, _) = broadcast::channel(100);
227 Self {
228 agents: HashMap::new(),
229 event_sender,
230 default_config: config,
231 }
232 }
233
234 pub fn start_tracking(&mut self, agent_id: &str) {
236 self.start_tracking_with_config(agent_id, self.default_config.clone());
237 }
238
239 pub fn start_tracking_with_config(&mut self, agent_id: &str, config: TimeoutConfig) {
241 let agent = TrackedAgent::new(agent_id, config);
242 self.agents.insert(agent_id.to_string(), agent);
243 }
244
245 pub fn stop_tracking(&mut self, agent_id: &str, completed: bool) -> Option<TimeoutEvent> {
247 if let Some(agent) = self.agents.remove(agent_id) {
248 let previous_status = agent.status;
249 let new_status = if completed {
250 TimeoutStatus::Completed
251 } else {
252 TimeoutStatus::Cancelled
253 };
254
255 if agent.config.emit_events && previous_status != new_status {
256 let event = TimeoutEvent::new(
257 agent_id,
258 previous_status,
259 new_status,
260 agent.elapsed(),
261 agent.config.timeout,
262 );
263 let _ = self.event_sender.send(event.clone());
264 return Some(event);
265 }
266 }
267 None
268 }
269
270 pub fn check_status(&mut self, agent_id: &str) -> Option<TimeoutStatus> {
272 let agent = self.agents.get_mut(agent_id)?;
273
274 let previous_status = agent.status;
275
276 if agent.is_timed_out() {
277 agent.status = TimeoutStatus::TimedOut;
278 } else if agent.is_warning() && !agent.warning_emitted {
279 agent.status = TimeoutStatus::Warning;
280 agent.warning_emitted = true;
281 }
282
283 if agent.config.emit_events && agent.status != previous_status {
285 let event = TimeoutEvent::new(
286 agent_id,
287 previous_status,
288 agent.status,
289 agent.elapsed(),
290 agent.config.timeout,
291 );
292 let _ = self.event_sender.send(event);
293 }
294
295 Some(agent.status)
296 }
297
298 pub fn check_all(&mut self) -> Vec<TimeoutEvent> {
300 let mut events = Vec::new();
301 let agent_ids: Vec<_> = self.agents.keys().cloned().collect();
302
303 for agent_id in agent_ids {
304 if let Some(agent) = self.agents.get_mut(&agent_id) {
305 let previous_status = agent.status;
306
307 if agent.is_timed_out() && agent.status != TimeoutStatus::TimedOut {
308 agent.status = TimeoutStatus::TimedOut;
309
310 if agent.config.emit_events {
311 let event = TimeoutEvent::new(
312 &agent_id,
313 previous_status,
314 TimeoutStatus::TimedOut,
315 agent.elapsed(),
316 agent.config.timeout,
317 )
318 .with_message(format!(
319 "Agent {} timed out after {:?}",
320 agent_id,
321 agent.elapsed()
322 ));
323 let _ = self.event_sender.send(event.clone());
324 events.push(event);
325 }
326 } else if agent.is_warning()
327 && !agent.warning_emitted
328 && agent.status == TimeoutStatus::Running
329 {
330 agent.status = TimeoutStatus::Warning;
331 agent.warning_emitted = true;
332
333 if agent.config.emit_events {
334 let event = TimeoutEvent::new(
335 &agent_id,
336 previous_status,
337 TimeoutStatus::Warning,
338 agent.elapsed(),
339 agent.config.timeout,
340 )
341 .with_message(format!(
342 "Agent {} approaching timeout ({:?} / {:?})",
343 agent_id,
344 agent.elapsed(),
345 agent.config.timeout
346 ));
347 let _ = self.event_sender.send(event.clone());
348 events.push(event);
349 }
350 }
351 }
352 }
353
354 events
355 }
356
357 pub fn mark_timed_out(&mut self, agent_id: &str) -> Option<TimeoutEvent> {
359 let agent = self.agents.get_mut(agent_id)?;
360
361 if agent.status == TimeoutStatus::TimedOut {
362 return None;
363 }
364
365 let previous_status = agent.status;
366 agent.status = TimeoutStatus::TimedOut;
367
368 if agent.config.emit_events {
369 let event = TimeoutEvent::new(
370 agent_id,
371 previous_status,
372 TimeoutStatus::TimedOut,
373 agent.elapsed(),
374 agent.config.timeout,
375 )
376 .with_message(format!("Agent {} manually marked as timed out", agent_id));
377 let _ = self.event_sender.send(event.clone());
378 return Some(event);
379 }
380
381 None
382 }
383
384 pub fn get_status(&self, agent_id: &str) -> Option<TimeoutStatus> {
386 self.agents.get(agent_id).map(|a| a.status)
387 }
388
389 pub fn get_elapsed(&self, agent_id: &str) -> Option<Duration> {
391 self.agents.get(agent_id).map(|a| a.elapsed())
392 }
393
394 pub fn get_remaining(&self, agent_id: &str) -> Option<Duration> {
396 self.agents.get(agent_id).map(|a| {
397 let elapsed = a.elapsed();
398 if elapsed >= a.config.timeout {
399 Duration::ZERO
400 } else {
401 a.config.timeout - elapsed
402 }
403 })
404 }
405
406 pub fn is_timed_out(&self, agent_id: &str) -> bool {
408 self.agents
409 .get(agent_id)
410 .map(|a| a.status == TimeoutStatus::TimedOut || a.is_timed_out())
411 .unwrap_or(false)
412 }
413
414 pub fn subscribe(&self) -> broadcast::Receiver<TimeoutEvent> {
416 self.event_sender.subscribe()
417 }
418
419 pub fn tracked_count(&self) -> usize {
421 self.agents.len()
422 }
423
424 pub fn get_timed_out_agents(&self) -> Vec<&str> {
426 self.agents
427 .iter()
428 .filter(|(_, a)| a.status == TimeoutStatus::TimedOut || a.is_timed_out())
429 .map(|(id, _)| id.as_str())
430 .collect()
431 }
432
433 pub fn clear(&mut self) {
435 self.agents.clear();
436 }
437
438 pub fn set_default_config(&mut self, config: TimeoutConfig) {
440 self.default_config = config;
441 }
442}
443
444#[allow(dead_code)]
446pub type SharedTimeoutHandler = Arc<RwLock<TimeoutHandler>>;
447
448#[allow(dead_code)]
450pub fn new_shared_timeout_handler() -> SharedTimeoutHandler {
451 Arc::new(RwLock::new(TimeoutHandler::new()))
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a handler that already tracks `agent-1` with the default config.
    fn handler_tracking_agent_1() -> TimeoutHandler {
        let mut handler = TimeoutHandler::new();
        handler.start_tracking("agent-1");
        handler
    }

    #[test]
    fn test_timeout_config_default() {
        let TimeoutConfig {
            timeout,
            warning_threshold,
            emit_events,
            ..
        } = TimeoutConfig::default();
        assert_eq!(timeout, Duration::from_secs(300));
        assert!((warning_threshold - 0.8).abs() < 0.001);
        assert!(emit_events);
    }

    #[test]
    fn test_timeout_config_warning_duration() {
        let warning = TimeoutConfig::new(Duration::from_secs(100))
            .with_warning_threshold(0.8)
            .warning_duration();
        assert_eq!(warning, Duration::from_secs(80));
    }

    #[test]
    fn test_timeout_event_creation() {
        let event = TimeoutEvent::new(
            "agent-1",
            TimeoutStatus::Running,
            TimeoutStatus::TimedOut,
            Duration::from_secs(100),
            Duration::from_secs(60),
        );

        assert_eq!(event.agent_id, "agent-1");
        assert!(event.is_timeout());
        assert!(!event.is_warning());
    }

    #[test]
    fn test_timeout_handler_start_tracking() {
        let handler = handler_tracking_agent_1();

        assert_eq!(handler.tracked_count(), 1);
        assert_eq!(handler.get_status("agent-1"), Some(TimeoutStatus::Running));
    }

    #[test]
    fn test_timeout_handler_stop_tracking() {
        let mut handler = handler_tracking_agent_1();

        assert!(handler.stop_tracking("agent-1", true).is_some());
        assert_eq!(handler.tracked_count(), 0);
    }

    #[test]
    fn test_timeout_handler_mark_timed_out() {
        let mut handler = handler_tracking_agent_1();

        assert!(handler.mark_timed_out("agent-1").is_some());
        assert!(handler.is_timed_out("agent-1"));
    }

    #[test]
    fn test_timeout_handler_get_remaining() {
        let mut handler = TimeoutHandler::new();
        handler.start_tracking_with_config("agent-1", TimeoutConfig::new(Duration::from_secs(100)));

        let remaining = handler
            .get_remaining("agent-1")
            .expect("agent-1 was just registered");
        assert!(remaining > Duration::from_secs(99));
    }

    #[test]
    fn test_timeout_status_display() {
        let cases = [
            (TimeoutStatus::Running, "running"),
            (TimeoutStatus::TimedOut, "timed_out"),
            (TimeoutStatus::Warning, "warning"),
        ];
        for (status, expected) in cases {
            assert_eq!(status.to_string(), expected);
        }
    }
}