Skip to main content

oximedia_graph/
async_exec.rs

1//! Async graph execution mode using tokio tasks for I/O-bound source/sink nodes.
2//!
3//! This module provides an async executor that dispatches each graph node as an
4//! independent tokio task.  I/O-bound source and sink nodes benefit the most
5//! because they can yield the thread while waiting for data, allowing other
6//! nodes to make progress concurrently.
7//!
8//! # Design
9//!
10//! The [`AsyncExecutor`] maps an [`ExecutionPlan`] onto tokio tasks:
11//!
12//! - Each **stage** becomes a `tokio::task::JoinSet` that runs all nodes in
13//!   the stage concurrently.
14//! - Stages are still sequentially ordered, so data-flow dependencies are
15//!   respected.
16//! - An optional timeout per stage prevents I/O-bound nodes from stalling the
17//!   whole graph.
18
19#![allow(dead_code)]
20
21use std::sync::Arc;
22use std::time::Duration;
23
24use tokio::task::JoinSet;
25use tokio::time::timeout;
26
27use crate::scheduler::{ExecutionPlan, ExecutionStage, ParallelNodeResult, ParallelRunStats};
28
/// Tuning knobs for [`AsyncExecutor::run`].
#[derive(Debug, Clone)]
pub struct AsyncExecutorConfig {
    /// Upper bound on how long a single stage may run.
    ///
    /// `None` waits indefinitely. With `Some(limit)`, a stage still running
    /// after `limit` is cancelled and its unfinished nodes are reported as
    /// failed.
    pub stage_timeout: Option<Duration>,
    /// Stop after the first stage that has at least one failed node.
    ///
    /// When `true`, any failure inside a stage (even if other nodes in that
    /// stage succeeded) prevents all later stages from running.
    pub fail_on_stage_error: bool,
}
39
40impl Default for AsyncExecutorConfig {
41    fn default() -> Self {
42        Self {
43            stage_timeout: Some(Duration::from_secs(30)),
44            fail_on_stage_error: false,
45        }
46    }
47}
48
/// Stateless entry point for running an [`ExecutionPlan`] on tokio tasks.
///
/// All work goes through the associated function [`AsyncExecutor::run`].
/// Within a stage, every node is spawned as its own task. Stages themselves
/// run strictly one after another.
///
/// # Example
///
/// ```rust,no_run
/// use oximedia_graph::async_exec::{AsyncExecutor, AsyncExecutorConfig};
/// use oximedia_graph::scheduler::ExecutionPlan;
///
/// async fn run_graph(plan: ExecutionPlan) {
///     let config = AsyncExecutorConfig::default();
///     let (results, stats) = AsyncExecutor::run(&plan, config, |_node_id| async move {
///         // Perform async I/O for this node.
///         Ok::<(), String>(())
///     }).await;
/// }
/// ```
pub struct AsyncExecutor;
66
67impl AsyncExecutor {
68    /// Execute the plan asynchronously.
69    ///
70    /// `executor` is an async factory that receives a node ID (owned `String`)
71    /// and returns a future resolving to `Result<(), String>`.
72    ///
73    /// Returns the collected per-node results and aggregate statistics.
74    pub async fn run<F, Fut>(
75        plan: &ExecutionPlan,
76        config: AsyncExecutorConfig,
77        executor: F,
78    ) -> (Vec<ParallelNodeResult>, ParallelRunStats)
79    where
80        F: Fn(String) -> Fut + Send + Sync + 'static,
81        Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
82    {
83        let mut all_results: Vec<ParallelNodeResult> = Vec::new();
84        let mut stats = ParallelRunStats::default();
85        let executor = Arc::new(executor);
86
87        'stages: for stage in &plan.stages {
88            if stage.nodes.is_empty() {
89                continue;
90            }
91            stats.stages_executed += 1;
92            stats.max_concurrency = stats.max_concurrency.max(stage.nodes.len());
93
94            let stage_results = Self::run_stage(stage, &config, Arc::clone(&executor)).await;
95            let failures = stage_results.iter().filter(|r| !r.success).count();
96            stats.nodes_executed += stage_results.len();
97            stats.failures += failures;
98
99            let abort = config.fail_on_stage_error && failures > 0;
100            all_results.extend(stage_results);
101
102            if abort {
103                break 'stages;
104            }
105        }
106
107        (all_results, stats)
108    }
109
110    /// Execute a single stage: spawn all nodes as tokio tasks, with optional timeout.
111    async fn run_stage<F, Fut>(
112        stage: &ExecutionStage,
113        config: &AsyncExecutorConfig,
114        executor: Arc<F>,
115    ) -> Vec<ParallelNodeResult>
116    where
117        F: Fn(String) -> Fut + Send + Sync + 'static,
118        Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
119    {
120        let mut set: JoinSet<ParallelNodeResult> = JoinSet::new();
121
122        for node_id in &stage.nodes {
123            let exec = Arc::clone(&executor);
124            let nid = node_id.clone();
125            set.spawn(async move {
126                let result = exec(nid.clone()).await;
127                match result {
128                    Ok(()) => ParallelNodeResult {
129                        node_id: nid,
130                        success: true,
131                        elapsed: Duration::ZERO,
132                        error: None,
133                    },
134                    Err(e) => ParallelNodeResult {
135                        node_id: nid,
136                        success: false,
137                        elapsed: Duration::ZERO,
138                        error: Some(e),
139                    },
140                }
141            });
142        }
143
144        let collect_future = async {
145            let mut results = Vec::new();
146            while let Some(join_result) = set.join_next().await {
147                match join_result {
148                    Ok(node_result) => results.push(node_result),
149                    Err(e) => results.push(ParallelNodeResult {
150                        node_id: "unknown".to_string(),
151                        success: false,
152                        elapsed: Duration::ZERO,
153                        error: Some(format!("task panic: {e}")),
154                    }),
155                }
156            }
157            results
158        };
159
160        match config.stage_timeout {
161            Some(dur) => timeout(dur, collect_future).await.unwrap_or_else(|_| {
162                // Stage timed out: mark remaining nodes as failed.
163                stage
164                    .nodes
165                    .iter()
166                    .map(|id| ParallelNodeResult {
167                        node_id: id.clone(),
168                        success: false,
169                        elapsed: Duration::ZERO,
170                        error: Some("stage timeout".to_string()),
171                    })
172                    .collect()
173            }),
174            None => collect_future.await,
175        }
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scheduler::ExecutionStage;

    /// Builds a plan with one `ExecutionStage` per inner vector. Each stage's
    /// CPU estimate equals its node count, and its memory estimate is 64 MB
    /// per node.
    fn make_plan(stages: Vec<Vec<&str>>) -> ExecutionPlan {
        let mut built = Vec::with_capacity(stages.len());
        for ids in stages {
            let count = ids.len();
            built.push(ExecutionStage {
                nodes: ids.into_iter().map(String::from).collect(),
                estimated_cpu_threads: count as u32,
                estimated_memory_mb: count as u64 * 64,
            });
        }
        ExecutionPlan { stages: built }
    }

    #[tokio::test]
    async fn test_async_all_succeed() {
        let plan = make_plan(vec![vec!["a", "b"], vec!["c"]]);
        let (results, stats) = AsyncExecutor::run(
            &plan,
            AsyncExecutorConfig::default(),
            |_node_id| async { Ok::<(), String>(()) },
        )
        .await;

        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|r| r.success));
        assert_eq!(stats.nodes_executed, 3);
        assert_eq!(stats.failures, 0);
    }

    #[tokio::test]
    async fn test_async_partial_failure() {
        let plan = make_plan(vec![vec!["ok1", "fail1", "ok2"]]);
        let (results, stats) = AsyncExecutor::run(
            &plan,
            AsyncExecutorConfig::default(),
            |node_id| async move {
                match node_id.as_str() {
                    "fail1" => Err("simulated failure".to_string()),
                    _ => Ok(()),
                }
            },
        )
        .await;

        let failed = results.iter().filter(|r| !r.success).count();
        assert_eq!(stats.failures, 1);
        assert_eq!(failed, 1);
    }

    #[tokio::test]
    async fn test_async_fail_on_stage_error_aborts_remaining() {
        let plan = make_plan(vec![vec!["fail-node"], vec!["should-not-run"]]);
        let config = AsyncExecutorConfig {
            stage_timeout: None,
            fail_on_stage_error: true,
        };
        let (results, stats) = AsyncExecutor::run(&plan, config, |node_id| async move {
            if node_id != "fail-node" {
                Ok(())
            } else {
                Err("stage error".to_string())
            }
        })
        .await;

        // The second stage must never have been started.
        assert!(results.iter().all(|r| r.node_id != "should-not-run"));
        assert_eq!(stats.stages_executed, 1);
    }

    #[tokio::test]
    async fn test_async_empty_plan() {
        let plan = ExecutionPlan { stages: Vec::new() };
        let (results, stats) = AsyncExecutor::run(
            &plan,
            AsyncExecutorConfig::default(),
            |_| async { Ok::<(), String>(()) },
        )
        .await;

        assert!(results.is_empty());
        assert_eq!(stats.stages_executed, 0);
    }
}