use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use crate::error::{GraphError, Result};
use crate::node::{Node, NodeContext, NodeOutput};
/// Action taken by `execute_with_timeout` when a node attempt ends with
/// `GraphError::NodeTimedOut`.
#[derive(Debug, Clone, Default)]
pub enum OnTimeout {
/// Return the timeout error to the caller (the default).
#[default]
Fail,
/// Run the node again after a timeout. `max_attempts` counts every
/// attempt including the first one; once that many attempts have timed
/// out the last timeout error is returned.
Retry { max_attempts: usize },
/// Swallow the timeout and return an empty `NodeOutput::new()` as success.
Skip,
}
/// Timeout configuration for a node execution.
///
/// When both `run_timeout` and `idle_timeout` are `None`,
/// `execute_with_timeout` calls the node directly and `on_timeout` is never
/// consulted.
#[derive(Debug, Clone, Default)]
pub struct TimeoutPolicy {
/// Maximum duration of a single attempt; `None` disables this limit.
pub run_timeout: Option<Duration>,
/// Maximum time an attempt may go without a `ProgressHandle::report_progress`
/// call; `None` disables this limit.
pub idle_timeout: Option<Duration>,
/// What to do when either timeout fires.
pub on_timeout: OnTimeout,
}
/// Cloneable handle through which a running node signals that it is still
/// making progress.
///
/// The timestamp lives behind an `Arc`, so every clone observes and updates
/// the same value.
#[derive(Debug, Clone)]
pub struct ProgressHandle {
// Time of the most recent progress report, in milliseconds since the Unix
// epoch (as produced by `current_time_ms`).
last_progress_ms: Arc<AtomicU64>,
}
impl ProgressHandle {
    /// Creates a handle whose last-progress timestamp is initialised to the
    /// current time, so a freshly created handle counts as "just progressed".
    pub fn new() -> Self {
        let created_at = AtomicU64::new(current_time_ms());
        Self {
            last_progress_ms: Arc::new(created_at),
        }
    }

    /// Records the current time as the most recent moment of progress.
    ///
    /// The store uses `Release` ordering, pairing with the `Acquire` load in
    /// [`ProgressHandle::last_progress_ms`].
    pub fn report_progress(&self) {
        self.last_progress_ms
            .store(current_time_ms(), Ordering::Release);
    }

    /// Returns the last recorded progress time in milliseconds since the
    /// Unix epoch.
    pub(crate) fn last_progress_ms(&self) -> u64 {
        let stamp = &self.last_progress_ms;
        stamp.load(Ordering::Acquire)
    }
}
impl Default for ProgressHandle {
fn default() -> Self {
Self::new()
}
}
/// Executes `node` with the timeouts and timeout reaction described by `policy`.
///
/// If the policy sets neither a run timeout nor an idle timeout, the node is
/// executed directly against `ctx`. Otherwise each attempt goes through
/// `execute_once_with_timeout`, and a timed-out attempt is handled according
/// to `policy.on_timeout`:
///
/// * `Fail`  – the timeout error is returned.
/// * `Retry` – the node is run again until `max_attempts` attempts (counting
///   the first) have timed out, after which the last timeout error is returned.
/// * `Skip`  – an empty `NodeOutput` is returned as success.
///
/// # Errors
///
/// Returns `GraphError::NodeTimedOut` as described above. Any other error
/// produced by the node is returned unchanged and is never retried.
pub async fn execute_with_timeout(
    node: &dyn Node,
    ctx: &NodeContext,
    policy: &TimeoutPolicy,
) -> Result<NodeOutput> {
    let has_timeout = policy.run_timeout.is_some() || policy.idle_timeout.is_some();
    if !has_timeout {
        return node.execute(ctx).await;
    }

    let mut attempt_no = 0;
    loop {
        attempt_no += 1;
        let outcome = execute_once_with_timeout(node, ctx, policy).await;
        match outcome {
            Err(GraphError::NodeTimedOut {
                node: ref timed_out_node,
                ref elapsed,
            }) => match &policy.on_timeout {
                OnTimeout::Skip => {
                    tracing::warn!(
                        node = %timed_out_node,
                        elapsed_ms = elapsed.as_millis(),
                        action = "skip",
                        "node timed out, skipping with empty output"
                    );
                    return Ok(NodeOutput::new());
                }
                OnTimeout::Fail => {
                    tracing::warn!(
                        node = %timed_out_node,
                        elapsed_ms = elapsed.as_millis(),
                        action = "fail",
                        "node timed out, failing execution"
                    );
                    return outcome;
                }
                // Attempts remain: log and fall through to the next loop iteration.
                OnTimeout::Retry { max_attempts: limit } if attempt_no < *limit => {
                    tracing::warn!(
                        node = %timed_out_node,
                        elapsed_ms = elapsed.as_millis(),
                        attempt = attempt_no,
                        max_attempts = *limit,
                        action = "retry",
                        "node timed out, retrying"
                    );
                }
                // Retry budget used up: surface the last timeout error.
                OnTimeout::Retry { .. } => {
                    tracing::warn!(
                        node = %timed_out_node,
                        elapsed_ms = elapsed.as_millis(),
                        attempts = attempt_no,
                        action = "fail_after_retries",
                        "node timed out after all retry attempts exhausted"
                    );
                    return outcome;
                }
            },
            // Success, or an error that is not a timeout: hand it back untouched.
            finished => return finished,
        }
    }
}
async fn execute_once_with_timeout(
node: &dyn Node,
ctx: &NodeContext,
policy: &TimeoutPolicy,
) -> Result<NodeOutput> {
let node_name = node.name().to_string();
let progress_handle = ProgressHandle::new();
let mut timeout_ctx = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
timeout_ctx.set_progress_handle(progress_handle.clone());
tokio::select! {
result = node.execute(&timeout_ctx) => {
result
}
elapsed = wait_for_run_timeout(policy.run_timeout) => {
Err(GraphError::NodeTimedOut {
node: node_name,
elapsed,
})
}
elapsed = wait_for_idle_timeout(policy.idle_timeout, &progress_handle) => {
Err(GraphError::NodeTimedOut {
node: node_name,
elapsed,
})
}
}
}
/// Resolves once `run_timeout` has elapsed, yielding that same duration.
///
/// With `None` the future never resolves, so the corresponding `select!`
/// branch can never win.
async fn wait_for_run_timeout(run_timeout: Option<Duration>) -> Duration {
    if let Some(limit) = run_timeout {
        tokio::time::sleep(limit).await;
        return limit;
    }
    // No run timeout configured: stay pending forever.
    std::future::pending().await
}
/// Resolves once `idle_timeout` has passed without a progress report on
/// `progress_handle`, yielding the total time this future has been waiting.
///
/// With `None` the future never resolves, so the corresponding `select!`
/// branch can never win.
///
/// The wait is adaptive: each sleep lasts until the earliest instant at which
/// the idle deadline could expire, capped at 100 ms. Progress reports can only
/// push the deadline later, so waking at that instant and re-checking is
/// sufficient; the timeout fires on time instead of up to one fixed poll
/// interval late, and idle timeouts shorter than 100 ms are honoured.
async fn wait_for_idle_timeout(
    idle_timeout: Option<Duration>,
    progress_handle: &ProgressHandle,
) -> Duration {
    // Longest single sleep. Keeps the loop responsive to wall-clock
    // adjustments, since progress timestamps are wall-clock based.
    const MAX_POLL_INTERVAL: Duration = Duration::from_millis(100);
    // Shortest single sleep, so a zero or sub-millisecond timeout still
    // yields to the runtime once instead of spinning.
    const MIN_POLL_INTERVAL: Duration = Duration::from_millis(1);

    let idle_duration = match idle_timeout {
        Some(duration) => duration,
        // No idle timeout configured: stay pending forever.
        None => return std::future::pending().await,
    };

    // Monotonic clock for the reported total, so a wall-clock step cannot
    // distort (or zero out) the elapsed value.
    let started = std::time::Instant::now();
    // Saturate rather than truncate if the duration exceeds u64 milliseconds.
    let idle_ms = idle_duration.as_millis().min(u128::from(u64::MAX)) as u64;

    let mut nap = idle_duration.clamp(MIN_POLL_INTERVAL, MAX_POLL_INTERVAL);
    loop {
        tokio::time::sleep(nap).await;

        // Progress timestamps are wall-clock milliseconds, so the idle span
        // must be measured against the same clock.
        let idle_elapsed = current_time_ms().saturating_sub(progress_handle.last_progress_ms());
        if idle_elapsed >= idle_ms {
            return started.elapsed();
        }

        // `idle_elapsed < idle_ms` here, so the subtraction cannot underflow.
        let remaining = Duration::from_millis(idle_ms - idle_elapsed);
        nap = remaining.clamp(MIN_POLL_INTERVAL, MAX_POLL_INTERVAL);
    }
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, `0` is returned.
fn current_time_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        // Clock is set before the epoch: same as a zero duration.
        Err(_) => 0,
    }
}
// Tests for the timeout wrapper. They use real (short) timeouts against nodes
// that either return immediately or sleep for 10 s, so a firing timeout is
// the only way the slow cases can finish quickly.
#[cfg(test)]
mod tests {
use super::*;
use crate::node::{ExecutionConfig, FunctionNode, NodeContext, NodeOutput};
use crate::state::State;
// A default policy has no timeouts, so the node is executed directly and its
// output is passed through unchanged.
#[tokio::test]
async fn test_no_timeout_executes_normally() {
let node = FunctionNode::new("fast", |_ctx| async {
Ok(NodeOutput::new().with_update("done", serde_json::json!(true)))
});
let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
let policy = TimeoutPolicy::default();
let result = execute_with_timeout(&node, &ctx, &policy).await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.updates.get("done"), Some(&serde_json::json!(true)));
}
// A 100 ms run timeout against a 10 s node must produce `NodeTimedOut`
// naming the node when the policy is `Fail`.
#[tokio::test]
async fn test_run_timeout_fires_on_slow_node() {
let node = FunctionNode::new("slow", |_ctx| async {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(NodeOutput::new())
});
let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
let policy = TimeoutPolicy {
run_timeout: Some(Duration::from_millis(100)),
idle_timeout: None,
on_timeout: OnTimeout::Fail,
};
let result = execute_with_timeout(&node, &ctx, &policy).await;
assert!(result.is_err());
match result {
Err(GraphError::NodeTimedOut { node, .. }) => {
assert_eq!(node, "slow");
}
Err(other) => panic!("expected NodeTimedOut, got: {other:?}"),
Ok(_) => panic!("expected error, got Ok"),
}
}
// With `Skip`, a timeout is reported as success with an empty output; the
// update the node would have produced must not appear.
#[tokio::test]
async fn test_skip_returns_empty_output() {
let node = FunctionNode::new("slow", |_ctx| async {
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(NodeOutput::new().with_update("should_not_appear", serde_json::json!(true)))
});
let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
let policy = TimeoutPolicy {
run_timeout: Some(Duration::from_millis(50)),
idle_timeout: None,
on_timeout: OnTimeout::Skip,
};
let result = execute_with_timeout(&node, &ctx, &policy).await;
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.updates.is_empty());
}
// Every attempt times out, so with `max_attempts: 3` the node body must be
// entered exactly three times before the error is returned.
#[tokio::test]
async fn test_retry_retries_up_to_max_attempts() {
use std::sync::atomic::AtomicUsize;
// Counts how many times the node body started executing.
let attempt_count = Arc::new(AtomicUsize::new(0));
let count_clone = attempt_count.clone();
let node = FunctionNode::new("flaky", move |_ctx| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
tokio::time::sleep(Duration::from_secs(10)).await;
Ok(NodeOutput::new())
}
});
let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
let policy = TimeoutPolicy {
run_timeout: Some(Duration::from_millis(50)),
idle_timeout: None,
on_timeout: OnTimeout::Retry { max_attempts: 3 },
};
let result = execute_with_timeout(&node, &ctx, &policy).await;
assert!(result.is_err());
assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
}
// A node that finishes well inside the run timeout returns its own output.
#[tokio::test]
async fn test_fast_node_with_timeout_succeeds() {
let node = FunctionNode::new("fast", |_ctx| async {
Ok(NodeOutput::new().with_update("value", serde_json::json!(42)))
});
let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
let policy = TimeoutPolicy {
run_timeout: Some(Duration::from_secs(5)),
idle_timeout: None,
on_timeout: OnTimeout::Fail,
};
let result = execute_with_timeout(&node, &ctx, &policy).await;
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.updates.get("value"), Some(&serde_json::json!(42)));
}
// Reporting progress must never move the stored timestamp backwards. The
// comparison is `>=` rather than `>`, so the test only checks non-regression.
#[test]
fn test_progress_handle_updates_timestamp() {
let handle = ProgressHandle::new();
let initial = handle.last_progress_ms();
std::thread::sleep(Duration::from_millis(10));
handle.report_progress();
let updated = handle.last_progress_ms();
assert!(updated >= initial);
}
}