1use std::sync::Arc;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::time::Duration;
29
30use crate::error::{GraphError, Result};
31use crate::node::{Node, NodeContext, NodeOutput};
32
/// What [`execute_with_timeout`] does when a node exceeds a limit in its [`TimeoutPolicy`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum OnTimeout {
    /// Return the [`GraphError::NodeTimedOut`] error to the caller. This is the default.
    #[default]
    Fail,
    /// Re-run the node immediately after each timeout.
    Retry {
        /// Total number of attempts, including the first one. Once this many attempts have
        /// timed out, the last timeout error is returned. A value of `0` behaves like `1`.
        max_attempts: usize,
    },
    /// Swallow the timeout and return an empty [`NodeOutput`] in place of the node's output.
    Skip,
}

/// Time limits applied to a single node execution, and the reaction when one is exceeded.
///
/// When both limits are `None` (the default), [`execute_with_timeout`] runs the node
/// directly without any supervision.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TimeoutPolicy {
    /// Maximum duration of one attempt, counted from when that attempt starts.
    pub run_timeout: Option<Duration>,
    /// Maximum time allowed without a progress report through the [`ProgressHandle`]
    /// installed in the node's context. A fresh handle is created per attempt, so the
    /// idle clock starts when the attempt starts.
    pub idle_timeout: Option<Duration>,
    /// Action taken when either limit is exceeded.
    pub on_timeout: OnTimeout,
}
68
/// Cloneable handle through which a running node signals that it is still making progress.
///
/// Every clone shares the same underlying timestamp, so a report made through one clone is
/// observed by whoever holds another (e.g. the idle-timeout watcher).
#[derive(Clone, Debug)]
pub struct ProgressHandle {
    /// Time of the most recent progress report, in milliseconds since the Unix epoch.
    last_progress_ms: Arc<AtomicU64>,
}
78
79impl ProgressHandle {
80 pub fn new() -> Self {
82 let now_ms = current_time_ms();
83 Self { last_progress_ms: Arc::new(AtomicU64::new(now_ms)) }
84 }
85
86 pub fn report_progress(&self) {
88 let now_ms = current_time_ms();
89 self.last_progress_ms.store(now_ms, Ordering::Release);
90 }
91
92 pub(crate) fn last_progress_ms(&self) -> u64 {
94 self.last_progress_ms.load(Ordering::Acquire)
95 }
96}
97
98impl Default for ProgressHandle {
99 fn default() -> Self {
100 Self::new()
101 }
102}
103
104pub async fn execute_with_timeout(
136 node: &dyn Node,
137 ctx: &NodeContext,
138 policy: &TimeoutPolicy,
139) -> Result<NodeOutput> {
140 if policy.run_timeout.is_none() && policy.idle_timeout.is_none() {
142 return node.execute(ctx).await;
143 }
144
145 let mut attempts = 0;
146
147 loop {
148 attempts += 1;
149 let result = execute_once_with_timeout(node, ctx, policy).await;
150
151 match result {
152 Ok(output) => return Ok(output),
153 Err(GraphError::NodeTimedOut { ref node, ref elapsed }) => {
154 match &policy.on_timeout {
155 OnTimeout::Fail => {
156 tracing::warn!(
157 node = %node,
158 elapsed_ms = elapsed.as_millis(),
159 action = "fail",
160 "node timed out, failing execution"
161 );
162 return result;
163 }
164 OnTimeout::Retry { max_attempts } => {
165 if attempts >= *max_attempts {
166 tracing::warn!(
167 node = %node,
168 elapsed_ms = elapsed.as_millis(),
169 attempts = attempts,
170 action = "fail_after_retries",
171 "node timed out after all retry attempts exhausted"
172 );
173 return result;
174 }
175 tracing::warn!(
176 node = %node,
177 elapsed_ms = elapsed.as_millis(),
178 attempt = attempts,
179 max_attempts = *max_attempts,
180 action = "retry",
181 "node timed out, retrying"
182 );
183 }
185 OnTimeout::Skip => {
186 tracing::warn!(
187 node = %node,
188 elapsed_ms = elapsed.as_millis(),
189 action = "skip",
190 "node timed out, skipping with empty output"
191 );
192 return Ok(NodeOutput::new());
193 }
194 }
195 }
196 Err(other) => return Err(other),
197 }
198 }
199}
200
201async fn execute_once_with_timeout(
203 node: &dyn Node,
204 ctx: &NodeContext,
205 policy: &TimeoutPolicy,
206) -> Result<NodeOutput> {
207 let node_name = node.name().to_string();
208 let progress_handle = ProgressHandle::new();
209
210 let mut timeout_ctx = NodeContext::new(ctx.state.clone(), ctx.config.clone(), ctx.step);
213 timeout_ctx.set_progress_handle(progress_handle.clone());
214
215 tokio::select! {
216 result = node.execute(&timeout_ctx) => {
217 result
218 }
219 elapsed = wait_for_run_timeout(policy.run_timeout) => {
220 Err(GraphError::NodeTimedOut {
221 node: node_name,
222 elapsed,
223 })
224 }
225 elapsed = wait_for_idle_timeout(policy.idle_timeout, &progress_handle) => {
226 Err(GraphError::NodeTimedOut {
227 node: node_name,
228 elapsed,
229 })
230 }
231 }
232}
233
234async fn wait_for_run_timeout(run_timeout: Option<Duration>) -> Duration {
237 match run_timeout {
238 Some(duration) => {
239 tokio::time::sleep(duration).await;
240 duration
241 }
242 None => {
243 std::future::pending::<()>().await;
245 unreachable!()
246 }
247 }
248}
249
250async fn wait_for_idle_timeout(
254 idle_timeout: Option<Duration>,
255 progress_handle: &ProgressHandle,
256) -> Duration {
257 match idle_timeout {
258 Some(idle_duration) => {
259 let start_ms = current_time_ms();
260 let idle_ms = idle_duration.as_millis() as u64;
261 let poll_interval = Duration::from_millis(100);
262
263 loop {
264 tokio::time::sleep(poll_interval).await;
265 let now_ms = current_time_ms();
266 let last_progress = progress_handle.last_progress_ms();
267 let idle_elapsed = now_ms.saturating_sub(last_progress);
268
269 if idle_elapsed >= idle_ms {
270 let total_elapsed_ms = now_ms.saturating_sub(start_ms);
271 return Duration::from_millis(total_elapsed_ms);
272 }
273 }
274 }
275 None => {
276 std::future::pending::<()>().await;
278 unreachable!()
279 }
280 }
281}
282
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0` rather than an error.
fn current_time_ms() -> u64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_millis() as u64
}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node::{ExecutionConfig, FunctionNode, NodeContext, NodeOutput};
    use crate::state::State;

    #[tokio::test]
    async fn test_no_timeout_executes_normally() {
        let node = FunctionNode::new("fast", |_ctx| async {
            Ok(NodeOutput::new().with_update("done", serde_json::json!(true)))
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy::default();

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.updates.get("done"), Some(&serde_json::json!(true)));
    }

    #[tokio::test]
    async fn test_run_timeout_fires_on_slow_node() {
        let node = FunctionNode::new("slow", |_ctx| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(NodeOutput::new())
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy {
            run_timeout: Some(Duration::from_millis(100)),
            idle_timeout: None,
            on_timeout: OnTimeout::Fail,
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_err());
        match result {
            Err(GraphError::NodeTimedOut { node, .. }) => {
                assert_eq!(node, "slow");
            }
            Err(other) => panic!("expected NodeTimedOut, got: {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    /// The idle watchdog must fire for a node that never reports progress, even with no
    /// overall run limit configured.
    #[tokio::test]
    async fn test_idle_timeout_fires_on_silent_node() {
        let node = FunctionNode::new("silent", |_ctx| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(NodeOutput::new())
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy {
            run_timeout: None,
            idle_timeout: Some(Duration::from_millis(100)),
            on_timeout: OnTimeout::Fail,
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        match result {
            Err(GraphError::NodeTimedOut { node, .. }) => {
                assert_eq!(node, "silent");
            }
            Err(other) => panic!("expected NodeTimedOut, got: {other:?}"),
            Ok(_) => panic!("expected error, got Ok"),
        }
    }

    #[tokio::test]
    async fn test_skip_returns_empty_output() {
        let node = FunctionNode::new("slow", |_ctx| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(NodeOutput::new().with_update("should_not_appear", serde_json::json!(true)))
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy {
            run_timeout: Some(Duration::from_millis(50)),
            idle_timeout: None,
            on_timeout: OnTimeout::Skip,
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert!(output.updates.is_empty());
    }

    #[tokio::test]
    async fn test_retry_retries_up_to_max_attempts() {
        use std::sync::atomic::AtomicUsize;

        let attempt_count = Arc::new(AtomicUsize::new(0));
        let count_clone = attempt_count.clone();

        let node = FunctionNode::new("flaky", move |_ctx| {
            let count = count_clone.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(NodeOutput::new())
            }
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy {
            run_timeout: Some(Duration::from_millis(50)),
            idle_timeout: None,
            on_timeout: OnTimeout::Retry { max_attempts: 3 },
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 3);
    }

    /// `max_attempts` counts the first attempt, so a budget of one means no retries at all.
    #[tokio::test]
    async fn test_retry_with_single_attempt_runs_once() {
        use std::sync::atomic::AtomicUsize;

        let attempt_count = Arc::new(AtomicUsize::new(0));
        let count_clone = attempt_count.clone();

        let node = FunctionNode::new("once", move |_ctx| {
            let count = count_clone.clone();
            async move {
                count.fetch_add(1, Ordering::SeqCst);
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(NodeOutput::new())
            }
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy {
            run_timeout: Some(Duration::from_millis(50)),
            idle_timeout: None,
            on_timeout: OnTimeout::Retry { max_attempts: 1 },
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_err());
        assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_fast_node_with_timeout_succeeds() {
        let node = FunctionNode::new("fast", |_ctx| async {
            Ok(NodeOutput::new().with_update("value", serde_json::json!(42)))
        });

        let ctx = NodeContext::new(State::new(), ExecutionConfig::default(), 0);
        let policy = TimeoutPolicy {
            run_timeout: Some(Duration::from_secs(5)),
            idle_timeout: None,
            on_timeout: OnTimeout::Fail,
        };

        let result = execute_with_timeout(&node, &ctx, &policy).await;
        assert!(result.is_ok());
        let output = result.unwrap();
        assert_eq!(output.updates.get("value"), Some(&serde_json::json!(42)));
    }

    #[test]
    fn test_progress_handle_updates_timestamp() {
        let handle = ProgressHandle::new();
        let initial = handle.last_progress_ms();

        std::thread::sleep(Duration::from_millis(10));
        handle.report_progress();

        let updated = handle.last_progress_ms();
        assert!(updated >= initial);
    }
}