Skip to main content

adk_graph/
timeout.rs

//! Timeout enforcement for graph node execution.
//!
//! Provides wall-clock and idle timeout enforcement for individual graph nodes,
//! with configurable recovery actions (fail, retry, skip).
//!
//! # Overview
//!
//! The [`TimeoutPolicy`] struct configures timeout behavior for a node:
//! - `run_timeout`: Hard wall-clock limit from when the node starts executing.
//! - `idle_timeout`: Resets each time [`ProgressHandle::report_progress`] is called on the
//!   progress handle attached to the node's context.
//! - `on_timeout`: What to do when a timeout fires ([`OnTimeout`]).
//!
//! # Example
//!
//! ```rust,ignore
//! use std::time::Duration;
//! use adk_graph::timeout::{TimeoutPolicy, OnTimeout};
//!
//! // Up to three executions in total (the first run plus two retries).
//! let policy = TimeoutPolicy {
//!     run_timeout: Some(Duration::from_secs(30)),
//!     idle_timeout: Some(Duration::from_secs(5)),
//!     on_timeout: OnTimeout::Retry { max_attempts: 3 },
//! };
//! ```
25
26use std::sync::Arc;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::time::Duration;
29
30use crate::error::{GraphError, Result};
31use crate::node::{Node, NodeContext, NodeOutput};
32
/// Recovery action applied when a node exceeds one of its timeouts.
///
/// The default action is [`OnTimeout::Fail`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum OnTimeout {
    /// Fail the graph with `GraphError::NodeTimedOut`.
    #[default]
    Fail,
    /// Re-run the node after a timeout, failing once the attempt budget is spent.
    ///
    /// `max_attempts` is the total number of executions, counting the first
    /// one: `max_attempts: 3` means one initial run plus at most two retries.
    /// The node always runs at least once, so `0` and `1` both mean
    /// "no retries".
    Retry { max_attempts: usize },
    /// Skip the node and proceed with an empty output.
    Skip,
}
44
45/// Timeout configuration for a graph node.
46///
47/// # Example
48///
49/// ```rust,ignore
50/// use std::time::Duration;
51/// use adk_graph::timeout::{TimeoutPolicy, OnTimeout};
52///
53/// let policy = TimeoutPolicy {
54///     run_timeout: Some(Duration::from_secs(10)),
55///     idle_timeout: None,
56///     on_timeout: OnTimeout::Fail,
57/// };
58/// ```
59#[derive(Debug, Clone, Default)]
60pub struct TimeoutPolicy {
61    /// Hard wall-clock limit. Timer starts when node begins execution.
62    pub run_timeout: Option<Duration>,
63    /// Idle timeout: resets each time `report_progress()` is called.
64    pub idle_timeout: Option<Duration>,
65    /// Recovery action on timeout.
66    pub on_timeout: OnTimeout,
67}
68
/// Handle through which a running node signals that it is still making
/// progress, which restarts the idle-timeout window.
///
/// The most recent progress time is kept as milliseconds since the UNIX
/// epoch inside an atomic `u64`, so updates need no lock. The counter lives
/// behind an [`Arc`], so every clone of a handle observes and updates the
/// same timestamp.
#[derive(Debug, Clone)]
pub struct ProgressHandle {
    last_progress_ms: Arc<AtomicU64>,
}
78
79impl ProgressHandle {
80    /// Create a new progress handle initialized to the current time.
81    pub fn new() -> Self {
82        let now_ms = current_time_ms();
83        Self { last_progress_ms: Arc::new(AtomicU64::new(now_ms)) }
84    }
85
86    /// Report progress, resetting the idle timeout counter.
87    pub fn report_progress(&self) {
88        let now_ms = current_time_ms();
89        self.last_progress_ms.store(now_ms, Ordering::Release);
90    }
91
92    /// Get the last progress timestamp in milliseconds since epoch.
93    pub(crate) fn last_progress_ms(&self) -> u64 {
94        self.last_progress_ms.load(Ordering::Acquire)
95    }
96}
97
98impl Default for ProgressHandle {
99    fn default() -> Self {
100        Self::new()
101    }
102}
103
104/// Execute a node with timeout enforcement.
105///
106/// Uses `tokio::select!` to race node execution against configured timeouts.
107/// When a timeout fires, the configured [`OnTimeout`] recovery action is applied:
108///
109/// - [`OnTimeout::Fail`]: Returns `GraphError::NodeTimedOut`.
110/// - [`OnTimeout::Retry`]: Re-executes the node up to `max_attempts` times.
111/// - [`OnTimeout::Skip`]: Returns an empty [`NodeOutput`].
112///
113/// A `tracing::warn!` is emitted whenever a timeout triggers a recovery action.
114///
115/// # Arguments
116///
117/// * `node` - The node to execute.
118/// * `ctx` - The node execution context.
119/// * `policy` - The timeout policy to enforce.
120///
121/// # Example
122///
123/// ```rust,ignore
124/// use adk_graph::timeout::{execute_with_timeout, TimeoutPolicy, OnTimeout};
125/// use std::time::Duration;
126///
127/// let policy = TimeoutPolicy {
128///     run_timeout: Some(Duration::from_secs(5)),
129///     idle_timeout: None,
130///     on_timeout: OnTimeout::Fail,
131/// };
132///
133/// let result = execute_with_timeout(&my_node, &ctx, &policy).await;
134/// ```
135pub async fn execute_with_timeout(
136    node: &dyn Node,
137    ctx: &NodeContext,
138    policy: &TimeoutPolicy,
139) -> Result<NodeOutput> {
140    // If no timeouts are configured, execute directly
141    if policy.run_timeout.is_none() && policy.idle_timeout.is_none() {
142        return node.execute(ctx).await;
143    }
144
145    let mut attempts = 0;
146
147    loop {
148        attempts += 1;
149        let result = execute_once_with_timeout(node, ctx, policy).await;
150
151        match result {
152            Ok(output) => return Ok(output),
153            Err(GraphError::NodeTimedOut { ref node, ref elapsed }) => {
154                match &policy.on_timeout {
155                    OnTimeout::Fail => {
156                        tracing::warn!(
157                            node = %node,
158                            elapsed_ms = elapsed.as_millis(),
159                            action = "fail",
160                            "node timed out, failing execution"
161                        );
162                        return result;
163                    }
164                    OnTimeout::Retry { max_attempts } => {
165                        if attempts >= *max_attempts {
166                            tracing::warn!(
167                                node = %node,
168                                elapsed_ms = elapsed.as_millis(),
169                                attempts = attempts,
170                                action = "fail_after_retries",
171                                "node timed out after all retry attempts exhausted"
172                            );
173                            return result;
174                        }
175                        tracing::warn!(
176                            node = %node,
177                            elapsed_ms = elapsed.as_millis(),
178                            attempt = attempts,
179                            max_attempts = *max_attempts,
180                            action = "retry",
181                            "node timed out, retrying"
182                        );
183                        // Continue loop to retry
184                    }
185                    OnTimeout::Skip => {
186                        tracing::warn!(
187                            node = %node,
188                            elapsed_ms = elapsed.as_millis(),
189                            action = "skip",
190                            "node timed out, skipping with empty output"
191                        );
192                        return Ok(NodeOutput::new());
193                    }
194                }
195            }
196            Err(other) => return Err(other),
197        }
198    }
199}
200
201/// Execute a single attempt of a node with timeout enforcement.
202async fn execute_once_with_timeout(
203    node: &dyn Node,
204    ctx: &NodeContext,
205    policy: &TimeoutPolicy,
206) -> Result<NodeOutput> {
207    let node_name = node.name().to_string();
208    let progress_handle = ProgressHandle::new();
209
210    // Build a context with the progress handle attached so the node can
211    // call `report_progress()` to reset the idle timeout.
212    let mut timeout_ctx = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
213    timeout_ctx.set_progress_handle(progress_handle.clone());
214
215    tokio::select! {
216        result = node.execute(&timeout_ctx) => {
217            result
218        }
219        elapsed = wait_for_run_timeout(policy.run_timeout) => {
220            Err(GraphError::NodeTimedOut {
221                node: node_name,
222                elapsed,
223            })
224        }
225        elapsed = wait_for_idle_timeout(policy.idle_timeout, &progress_handle) => {
226            Err(GraphError::NodeTimedOut {
227                node: node_name,
228                elapsed,
229            })
230        }
231    }
232}
233
234/// Wait for the run timeout to expire. If no run timeout is configured,
235/// this future never completes (allowing the select to be driven by other branches).
236async fn wait_for_run_timeout(run_timeout: Option<Duration>) -> Duration {
237    match run_timeout {
238        Some(duration) => {
239            tokio::time::sleep(duration).await;
240            duration
241        }
242        None => {
243            // Never completes — effectively infinite timeout
244            std::future::pending::<()>().await;
245            unreachable!()
246        }
247    }
248}
249
250/// Poll for idle timeout expiry. Checks every 100ms whether the time since
251/// last progress exceeds the idle timeout. If no idle timeout is configured,
252/// this future never completes.
253async fn wait_for_idle_timeout(
254    idle_timeout: Option<Duration>,
255    progress_handle: &ProgressHandle,
256) -> Duration {
257    match idle_timeout {
258        Some(idle_duration) => {
259            let start_ms = current_time_ms();
260            let idle_ms = idle_duration.as_millis() as u64;
261            let poll_interval = Duration::from_millis(100);
262
263            loop {
264                tokio::time::sleep(poll_interval).await;
265                let now_ms = current_time_ms();
266                let last_progress = progress_handle.last_progress_ms();
267                let idle_elapsed = now_ms.saturating_sub(last_progress);
268
269                if idle_elapsed >= idle_ms {
270                    let total_elapsed_ms = now_ms.saturating_sub(start_ms);
271                    return Duration::from_millis(total_elapsed_ms);
272                }
273            }
274        }
275        None => {
276            // Never completes — effectively infinite timeout
277            std::future::pending::<()>().await;
278            unreachable!()
279        }
280    }
281}
282
/// Current wall-clock time as whole milliseconds since the UNIX epoch.
///
/// Yields `0` if the system clock reads earlier than the epoch.
fn current_time_ms() -> u64 {
    let now = std::time::SystemTime::now();
    now.duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::{ExecutionConfig, FunctionNode, NodeContext, NodeOutput};
    use crate::state::State;

    /// Fresh context with empty state, default config and step 0.
    fn test_ctx() -> NodeContext {
        NodeContext::new(State::new(), ExecutionConfig::default(), 0)
    }

    /// Policy with only a run timeout of `millis` and the given recovery action.
    fn run_timeout_policy(millis: u64, on_timeout: OnTimeout) -> TimeoutPolicy {
        TimeoutPolicy {
            run_timeout: Some(Duration::from_millis(millis)),
            idle_timeout: None,
            on_timeout,
        }
    }

    #[tokio::test]
    async fn test_no_timeout_executes_normally() {
        let node = FunctionNode::new("fast", |_ctx| async {
            Ok(NodeOutput::new().with_update("done", serde_json::json!(true)))
        });

        let ctx = test_ctx();
        let policy = TimeoutPolicy::default();

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.updates.get("done"), Some(&serde_json::json!(true)));
    }

    #[tokio::test]
    async fn test_run_timeout_fires_on_slow_node() {
        let node = FunctionNode::new("slow", |_ctx| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(NodeOutput::new())
        });

        let ctx = test_ctx();
        let policy = run_timeout_policy(100, OnTimeout::Fail);

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        match result {
            Err(GraphError::NodeTimedOut { node, .. }) => {
                assert_eq!(node, "slow");
            }
            Err(other) => panic!("expected NodeTimedOut, got: {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    /// The idle timeout alone must stop a node that never reports progress.
    #[tokio::test]
    async fn test_idle_timeout_fires_without_progress() {
        let node = FunctionNode::new("stalled", |_ctx| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(NodeOutput::new())
        });

        let ctx = test_ctx();
        let policy = TimeoutPolicy {
            run_timeout: None,
            idle_timeout: Some(Duration::from_millis(50)),
            on_timeout: OnTimeout::Fail,
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        match result {
            Err(GraphError::NodeTimedOut { node, .. }) => {
                assert_eq!(node, "stalled");
            }
            Err(other) => panic!("expected NodeTimedOut, got: {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[tokio::test]
    async fn test_skip_returns_empty_output() {
        let node = FunctionNode::new("slow", |_ctx| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(NodeOutput::new().with_update("should_not_appear", serde_json::json!(true)))
        });

        let ctx = test_ctx();
        let policy = run_timeout_policy(50, OnTimeout::Skip);

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.updates.is_empty());
    }

    /// `max_attempts` counts total executions, including the first one.
    #[tokio::test]
    async fn test_retry_retries_up_to_max_attempts() {
        use std::sync::atomic::AtomicUsize;

        let attempt_count = Arc::new(AtomicUsize::new(0));
        let count_clone = attempt_count.clone();

        let node = FunctionNode::new("flaky", move |_ctx| {
            let count = count_clone.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(NodeOutput::new())
            }
        });

        let ctx = test_ctx();
        let policy = run_timeout_policy(50, OnTimeout::Retry { max_attempts: 3 });

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(matches!(result, Err(GraphError::NodeTimedOut { .. })));
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_fast_node_with_timeout_succeeds() {
        let node = FunctionNode::new("fast", |_ctx| async {
            Ok(NodeOutput::new().with_update("value", serde_json::json!(42)))
        });

        let ctx = test_ctx();
        let policy = run_timeout_policy(5_000, OnTimeout::Fail);

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.updates.get("value"), Some(&serde_json::json!(42)));
    }

    #[test]
    fn test_progress_handle_updates_timestamp() {
        let handle = ProgressHandle::new();
        let initial = handle.last_progress_ms();

        // Small sleep to ensure time advances
        std::thread::sleep(Duration::from_millis(10));
        handle.report_progress();

        let updated = handle.last_progress_ms();
        assert!(updated >= initial);
    }

    /// Clones share one timestamp, so a report on a clone is visible on the original.
    #[test]
    fn test_progress_handle_clones_share_state() {
        let handle = ProgressHandle::new();
        let clone = handle.clone();

        clone.report_progress();

        assert_eq!(handle.last_progress_ms(), clone.last_progress_ms());
    }
}