// aster/agents/error_handling/timeout_handler.rs

//! Timeout Handler
//!
//! Provides timeout handling for agent execution: tracks how long each agent
//! has been running, marks agents as timed out, and emits timeout events.
//!
//! **Validates: Requirements 15.2**
7
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::sync::{broadcast, RwLock};
14
15/// Timeout status for an agent
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
17#[serde(rename_all = "snake_case")]
18pub enum TimeoutStatus {
19    /// Agent is running normally
20    #[default]
21    Running,
22    /// Agent is approaching timeout (warning)
23    Warning,
24    /// Agent has timed out
25    TimedOut,
26    /// Agent completed before timeout
27    Completed,
28    /// Agent was cancelled
29    Cancelled,
30}
31
32impl std::fmt::Display for TimeoutStatus {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        match self {
35            TimeoutStatus::Running => write!(f, "running"),
36            TimeoutStatus::Warning => write!(f, "warning"),
37            TimeoutStatus::TimedOut => write!(f, "timed_out"),
38            TimeoutStatus::Completed => write!(f, "completed"),
39            TimeoutStatus::Cancelled => write!(f, "cancelled"),
40        }
41    }
42}
43
44/// Timeout configuration
45#[derive(Debug, Clone, Serialize, Deserialize)]
46#[serde(rename_all = "camelCase")]
47pub struct TimeoutConfig {
48    /// Maximum execution time
49    pub timeout: Duration,
50    /// Warning threshold (percentage of timeout, e.g., 0.8 = 80%)
51    pub warning_threshold: f64,
52    /// Whether to emit events
53    pub emit_events: bool,
54    /// Grace period after timeout before forced termination
55    pub grace_period: Option<Duration>,
56}
57
58impl Default for TimeoutConfig {
59    fn default() -> Self {
60        Self {
61            timeout: Duration::from_secs(300), // 5 minutes
62            warning_threshold: 0.8,
63            emit_events: true,
64            grace_period: Some(Duration::from_secs(10)),
65        }
66    }
67}
68
69impl TimeoutConfig {
70    /// Create a new timeout config
71    pub fn new(timeout: Duration) -> Self {
72        Self {
73            timeout,
74            ..Default::default()
75        }
76    }
77
78    /// Set the warning threshold
79    pub fn with_warning_threshold(mut self, threshold: f64) -> Self {
80        self.warning_threshold = threshold.clamp(0.0, 1.0);
81        self
82    }
83
84    /// Set whether to emit events
85    pub fn with_emit_events(mut self, emit: bool) -> Self {
86        self.emit_events = emit;
87        self
88    }
89
90    /// Set the grace period
91    pub fn with_grace_period(mut self, grace: Duration) -> Self {
92        self.grace_period = Some(grace);
93        self
94    }
95
96    /// Get the warning duration
97    pub fn warning_duration(&self) -> Duration {
98        Duration::from_secs_f64(self.timeout.as_secs_f64() * self.warning_threshold)
99    }
100}
101
102/// Timeout event emitted when timeout status changes
103#[derive(Debug, Clone, Serialize, Deserialize)]
104#[serde(rename_all = "camelCase")]
105pub struct TimeoutEvent {
106    /// Agent ID
107    pub agent_id: String,
108    /// Previous status
109    pub previous_status: TimeoutStatus,
110    /// New status
111    pub new_status: TimeoutStatus,
112    /// Elapsed time
113    pub elapsed: Duration,
114    /// Configured timeout
115    pub timeout: Duration,
116    /// Event timestamp
117    pub timestamp: DateTime<Utc>,
118    /// Additional message
119    pub message: Option<String>,
120}
121
122impl TimeoutEvent {
123    /// Create a new timeout event
124    pub fn new(
125        agent_id: impl Into<String>,
126        previous_status: TimeoutStatus,
127        new_status: TimeoutStatus,
128        elapsed: Duration,
129        timeout: Duration,
130    ) -> Self {
131        Self {
132            agent_id: agent_id.into(),
133            previous_status,
134            new_status,
135            elapsed,
136            timeout,
137            timestamp: Utc::now(),
138            message: None,
139        }
140    }
141
142    /// Set the message
143    pub fn with_message(mut self, message: impl Into<String>) -> Self {
144        self.message = Some(message.into());
145        self
146    }
147
148    /// Check if this is a timeout event
149    pub fn is_timeout(&self) -> bool {
150        self.new_status == TimeoutStatus::TimedOut
151    }
152
153    /// Check if this is a warning event
154    pub fn is_warning(&self) -> bool {
155        self.new_status == TimeoutStatus::Warning
156    }
157}
158
159/// Tracked agent information
160#[derive(Debug, Clone)]
161#[allow(dead_code)]
162struct TrackedAgent {
163    agent_id: String,
164    config: TimeoutConfig,
165    start_time: DateTime<Utc>,
166    status: TimeoutStatus,
167    warning_emitted: bool,
168}
169
170impl TrackedAgent {
171    fn new(agent_id: impl Into<String>, config: TimeoutConfig) -> Self {
172        Self {
173            agent_id: agent_id.into(),
174            config,
175            start_time: Utc::now(),
176            status: TimeoutStatus::Running,
177            warning_emitted: false,
178        }
179    }
180
181    fn elapsed(&self) -> Duration {
182        let elapsed = Utc::now().signed_duration_since(self.start_time);
183        elapsed.to_std().unwrap_or(Duration::ZERO)
184    }
185
186    fn is_timed_out(&self) -> bool {
187        self.elapsed() > self.config.timeout
188    }
189
190    fn is_warning(&self) -> bool {
191        let elapsed = self.elapsed();
192        elapsed > self.config.warning_duration() && elapsed <= self.config.timeout
193    }
194}
195
196/// Timeout handler for managing agent timeouts
197#[derive(Debug)]
198pub struct TimeoutHandler {
199    /// Tracked agents
200    agents: HashMap<String, TrackedAgent>,
201    /// Event sender
202    event_sender: broadcast::Sender<TimeoutEvent>,
203    /// Default configuration
204    default_config: TimeoutConfig,
205}
206
207impl Default for TimeoutHandler {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
213impl TimeoutHandler {
214    /// Create a new timeout handler
215    pub fn new() -> Self {
216        let (event_sender, _) = broadcast::channel(100);
217        Self {
218            agents: HashMap::new(),
219            event_sender,
220            default_config: TimeoutConfig::default(),
221        }
222    }
223
224    /// Create with custom default configuration
225    pub fn with_default_config(config: TimeoutConfig) -> Self {
226        let (event_sender, _) = broadcast::channel(100);
227        Self {
228            agents: HashMap::new(),
229            event_sender,
230            default_config: config,
231        }
232    }
233
234    /// Start tracking an agent with default config
235    pub fn start_tracking(&mut self, agent_id: &str) {
236        self.start_tracking_with_config(agent_id, self.default_config.clone());
237    }
238
239    /// Start tracking an agent with custom config
240    pub fn start_tracking_with_config(&mut self, agent_id: &str, config: TimeoutConfig) {
241        let agent = TrackedAgent::new(agent_id, config);
242        self.agents.insert(agent_id.to_string(), agent);
243    }
244
245    /// Stop tracking an agent
246    pub fn stop_tracking(&mut self, agent_id: &str, completed: bool) -> Option<TimeoutEvent> {
247        if let Some(agent) = self.agents.remove(agent_id) {
248            let previous_status = agent.status;
249            let new_status = if completed {
250                TimeoutStatus::Completed
251            } else {
252                TimeoutStatus::Cancelled
253            };
254
255            if agent.config.emit_events && previous_status != new_status {
256                let event = TimeoutEvent::new(
257                    agent_id,
258                    previous_status,
259                    new_status,
260                    agent.elapsed(),
261                    agent.config.timeout,
262                );
263                let _ = self.event_sender.send(event.clone());
264                return Some(event);
265            }
266        }
267        None
268    }
269
270    /// Check timeout status for an agent
271    pub fn check_status(&mut self, agent_id: &str) -> Option<TimeoutStatus> {
272        let agent = self.agents.get_mut(agent_id)?;
273
274        let previous_status = agent.status;
275
276        if agent.is_timed_out() {
277            agent.status = TimeoutStatus::TimedOut;
278        } else if agent.is_warning() && !agent.warning_emitted {
279            agent.status = TimeoutStatus::Warning;
280            agent.warning_emitted = true;
281        }
282
283        // Emit event if status changed
284        if agent.config.emit_events && agent.status != previous_status {
285            let event = TimeoutEvent::new(
286                agent_id,
287                previous_status,
288                agent.status,
289                agent.elapsed(),
290                agent.config.timeout,
291            );
292            let _ = self.event_sender.send(event);
293        }
294
295        Some(agent.status)
296    }
297
298    /// Check all agents and return timed out ones
299    pub fn check_all(&mut self) -> Vec<TimeoutEvent> {
300        let mut events = Vec::new();
301        let agent_ids: Vec<_> = self.agents.keys().cloned().collect();
302
303        for agent_id in agent_ids {
304            if let Some(agent) = self.agents.get_mut(&agent_id) {
305                let previous_status = agent.status;
306
307                if agent.is_timed_out() && agent.status != TimeoutStatus::TimedOut {
308                    agent.status = TimeoutStatus::TimedOut;
309
310                    if agent.config.emit_events {
311                        let event = TimeoutEvent::new(
312                            &agent_id,
313                            previous_status,
314                            TimeoutStatus::TimedOut,
315                            agent.elapsed(),
316                            agent.config.timeout,
317                        )
318                        .with_message(format!(
319                            "Agent {} timed out after {:?}",
320                            agent_id,
321                            agent.elapsed()
322                        ));
323                        let _ = self.event_sender.send(event.clone());
324                        events.push(event);
325                    }
326                } else if agent.is_warning()
327                    && !agent.warning_emitted
328                    && agent.status == TimeoutStatus::Running
329                {
330                    agent.status = TimeoutStatus::Warning;
331                    agent.warning_emitted = true;
332
333                    if agent.config.emit_events {
334                        let event = TimeoutEvent::new(
335                            &agent_id,
336                            previous_status,
337                            TimeoutStatus::Warning,
338                            agent.elapsed(),
339                            agent.config.timeout,
340                        )
341                        .with_message(format!(
342                            "Agent {} approaching timeout ({:?} / {:?})",
343                            agent_id,
344                            agent.elapsed(),
345                            agent.config.timeout
346                        ));
347                        let _ = self.event_sender.send(event.clone());
348                        events.push(event);
349                    }
350                }
351            }
352        }
353
354        events
355    }
356
357    /// Mark an agent as timed out
358    pub fn mark_timed_out(&mut self, agent_id: &str) -> Option<TimeoutEvent> {
359        let agent = self.agents.get_mut(agent_id)?;
360
361        if agent.status == TimeoutStatus::TimedOut {
362            return None;
363        }
364
365        let previous_status = agent.status;
366        agent.status = TimeoutStatus::TimedOut;
367
368        if agent.config.emit_events {
369            let event = TimeoutEvent::new(
370                agent_id,
371                previous_status,
372                TimeoutStatus::TimedOut,
373                agent.elapsed(),
374                agent.config.timeout,
375            )
376            .with_message(format!("Agent {} manually marked as timed out", agent_id));
377            let _ = self.event_sender.send(event.clone());
378            return Some(event);
379        }
380
381        None
382    }
383
384    /// Get the status of an agent
385    pub fn get_status(&self, agent_id: &str) -> Option<TimeoutStatus> {
386        self.agents.get(agent_id).map(|a| a.status)
387    }
388
389    /// Get elapsed time for an agent
390    pub fn get_elapsed(&self, agent_id: &str) -> Option<Duration> {
391        self.agents.get(agent_id).map(|a| a.elapsed())
392    }
393
394    /// Get remaining time for an agent
395    pub fn get_remaining(&self, agent_id: &str) -> Option<Duration> {
396        self.agents.get(agent_id).map(|a| {
397            let elapsed = a.elapsed();
398            if elapsed >= a.config.timeout {
399                Duration::ZERO
400            } else {
401                a.config.timeout - elapsed
402            }
403        })
404    }
405
406    /// Check if an agent is timed out
407    pub fn is_timed_out(&self, agent_id: &str) -> bool {
408        self.agents
409            .get(agent_id)
410            .map(|a| a.status == TimeoutStatus::TimedOut || a.is_timed_out())
411            .unwrap_or(false)
412    }
413
414    /// Subscribe to timeout events
415    pub fn subscribe(&self) -> broadcast::Receiver<TimeoutEvent> {
416        self.event_sender.subscribe()
417    }
418
419    /// Get the number of tracked agents
420    pub fn tracked_count(&self) -> usize {
421        self.agents.len()
422    }
423
424    /// Get all timed out agents
425    pub fn get_timed_out_agents(&self) -> Vec<&str> {
426        self.agents
427            .iter()
428            .filter(|(_, a)| a.status == TimeoutStatus::TimedOut || a.is_timed_out())
429            .map(|(id, _)| id.as_str())
430            .collect()
431    }
432
433    /// Clear all tracked agents
434    pub fn clear(&mut self) {
435        self.agents.clear();
436    }
437
438    /// Set default configuration
439    pub fn set_default_config(&mut self, config: TimeoutConfig) {
440        self.default_config = config;
441    }
442}
443
444/// Thread-safe timeout handler wrapper
445#[allow(dead_code)]
446pub type SharedTimeoutHandler = Arc<RwLock<TimeoutHandler>>;
447
448/// Create a new shared timeout handler
449#[allow(dead_code)]
450pub fn new_shared_timeout_handler() -> SharedTimeoutHandler {
451    Arc::new(RwLock::new(TimeoutHandler::new()))
452}
453
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config has a 5-minute timeout, an 80% warning threshold
    /// and events enabled.
    #[test]
    fn test_timeout_config_default() {
        let TimeoutConfig {
            timeout,
            warning_threshold,
            emit_events,
            ..
        } = TimeoutConfig::default();

        assert_eq!(timeout, Duration::from_secs(300));
        assert!((warning_threshold - 0.8).abs() < 0.001);
        assert!(emit_events);
    }

    /// 80% of a 100-second timeout is an 80-second warning duration.
    #[test]
    fn test_timeout_config_warning_duration() {
        let timeout = Duration::from_secs(100);
        let warning = TimeoutConfig::new(timeout)
            .with_warning_threshold(0.8)
            .warning_duration();

        assert_eq!(warning, Duration::from_secs(80));
    }

    /// An event whose new status is `TimedOut` is a timeout, not a warning.
    #[test]
    fn test_timeout_event_creation() {
        let elapsed = Duration::from_secs(100);
        let timeout = Duration::from_secs(60);
        let event = TimeoutEvent::new(
            "agent-1",
            TimeoutStatus::Running,
            TimeoutStatus::TimedOut,
            elapsed,
            timeout,
        );

        assert_eq!(event.agent_id, "agent-1");
        assert!(event.is_timeout());
        assert!(!event.is_warning());
    }

    /// A newly tracked agent is counted and starts in the `Running` state.
    #[test]
    fn test_timeout_handler_start_tracking() {
        let mut handler = TimeoutHandler::new();
        handler.start_tracking("agent-1");

        assert_eq!(handler.tracked_count(), 1);
        assert_eq!(handler.get_status("agent-1"), Some(TimeoutStatus::Running));
    }

    /// Stopping a tracked agent yields an event and removes the agent.
    #[test]
    fn test_timeout_handler_stop_tracking() {
        let mut handler = TimeoutHandler::new();
        handler.start_tracking("agent-1");

        let stopped = handler.stop_tracking("agent-1", true);

        assert!(stopped.is_some());
        assert_eq!(handler.tracked_count(), 0);
    }

    /// Manually marking an agent yields an event and the agent reports as
    /// timed out.
    #[test]
    fn test_timeout_handler_mark_timed_out() {
        let mut handler = TimeoutHandler::new();
        handler.start_tracking("agent-1");

        let marked = handler.mark_timed_out("agent-1");

        assert!(marked.is_some());
        assert!(handler.is_timed_out("agent-1"));
    }

    /// Right after tracking starts, nearly the whole timeout remains.
    #[test]
    fn test_timeout_handler_get_remaining() {
        let mut handler = TimeoutHandler::new();
        handler.start_tracking_with_config("agent-1", TimeoutConfig::new(Duration::from_secs(100)));

        let remaining = handler
            .get_remaining("agent-1")
            .expect("a tracked agent reports its remaining time");

        // Should be close to 100 seconds (minus the small elapsed time).
        assert!(remaining > Duration::from_secs(99));
    }

    /// `Display` writes the snake_case label of each status.
    #[test]
    fn test_timeout_status_display() {
        assert_eq!(TimeoutStatus::Running.to_string(), "running");
        assert_eq!(TimeoutStatus::TimedOut.to_string(), "timed_out");
        assert_eq!(TimeoutStatus::Warning.to_string(), "warning");
    }
}