oximedia_graph/
async_exec.rs1#![allow(dead_code)]
20
21use std::sync::Arc;
22use std::time::Duration;
23
24use tokio::task::JoinSet;
25use tokio::time::timeout;
26
27use crate::scheduler::{ExecutionPlan, ExecutionStage, ParallelNodeResult, ParallelRunStats};
28
/// Options controlling how [`AsyncExecutor::run`] executes a plan.
#[derive(Debug, Clone)]
pub struct AsyncExecutorConfig {
    /// Maximum time a single stage may take to finish all of its nodes.
    ///
    /// When the limit is exceeded, every node of that stage is reported as
    /// failed with a "stage timeout" error. `None` disables the limit.
    pub stage_timeout: Option<Duration>,
    /// When `true`, execution stops after the first stage that contains at
    /// least one failed node; later stages are not run.
    pub fail_on_stage_error: bool,
}

impl Default for AsyncExecutorConfig {
    /// Returns a configuration with a 30-second stage timeout that keeps
    /// running subsequent stages even if a stage reports failures.
    fn default() -> Self {
        let thirty_seconds = Duration::from_secs(30);
        Self {
            stage_timeout: Some(thirty_seconds),
            fail_on_stage_error: false,
        }
    }
}
48
49pub struct AsyncExecutor;
66
67impl AsyncExecutor {
68 pub async fn run<F, Fut>(
75 plan: &ExecutionPlan,
76 config: AsyncExecutorConfig,
77 executor: F,
78 ) -> (Vec<ParallelNodeResult>, ParallelRunStats)
79 where
80 F: Fn(String) -> Fut + Send + Sync + 'static,
81 Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
82 {
83 let mut all_results: Vec<ParallelNodeResult> = Vec::new();
84 let mut stats = ParallelRunStats::default();
85 let executor = Arc::new(executor);
86
87 'stages: for stage in &plan.stages {
88 if stage.nodes.is_empty() {
89 continue;
90 }
91 stats.stages_executed += 1;
92 stats.max_concurrency = stats.max_concurrency.max(stage.nodes.len());
93
94 let stage_results = Self::run_stage(stage, &config, Arc::clone(&executor)).await;
95 let failures = stage_results.iter().filter(|r| !r.success).count();
96 stats.nodes_executed += stage_results.len();
97 stats.failures += failures;
98
99 let abort = config.fail_on_stage_error && failures > 0;
100 all_results.extend(stage_results);
101
102 if abort {
103 break 'stages;
104 }
105 }
106
107 (all_results, stats)
108 }
109
110 async fn run_stage<F, Fut>(
112 stage: &ExecutionStage,
113 config: &AsyncExecutorConfig,
114 executor: Arc<F>,
115 ) -> Vec<ParallelNodeResult>
116 where
117 F: Fn(String) -> Fut + Send + Sync + 'static,
118 Fut: std::future::Future<Output = Result<(), String>> + Send + 'static,
119 {
120 let mut set: JoinSet<ParallelNodeResult> = JoinSet::new();
121
122 for node_id in &stage.nodes {
123 let exec = Arc::clone(&executor);
124 let nid = node_id.clone();
125 set.spawn(async move {
126 let result = exec(nid.clone()).await;
127 match result {
128 Ok(()) => ParallelNodeResult {
129 node_id: nid,
130 success: true,
131 elapsed: Duration::ZERO,
132 error: None,
133 },
134 Err(e) => ParallelNodeResult {
135 node_id: nid,
136 success: false,
137 elapsed: Duration::ZERO,
138 error: Some(e),
139 },
140 }
141 });
142 }
143
144 let collect_future = async {
145 let mut results = Vec::new();
146 while let Some(join_result) = set.join_next().await {
147 match join_result {
148 Ok(node_result) => results.push(node_result),
149 Err(e) => results.push(ParallelNodeResult {
150 node_id: "unknown".to_string(),
151 success: false,
152 elapsed: Duration::ZERO,
153 error: Some(format!("task panic: {e}")),
154 }),
155 }
156 }
157 results
158 };
159
160 match config.stage_timeout {
161 Some(dur) => timeout(dur, collect_future).await.unwrap_or_else(|_| {
162 stage
164 .nodes
165 .iter()
166 .map(|id| ParallelNodeResult {
167 node_id: id.clone(),
168 success: false,
169 elapsed: Duration::ZERO,
170 error: Some("stage timeout".to_string()),
171 })
172 .collect()
173 }),
174 None => collect_future.await,
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scheduler::ExecutionStage;

    /// Builds a plan with one stage per inner list; resource estimates are
    /// derived from the number of nodes in the stage.
    fn make_plan(stages: Vec<Vec<&str>>) -> ExecutionPlan {
        let mut built = Vec::with_capacity(stages.len());
        for names in stages {
            let count = names.len();
            built.push(ExecutionStage {
                nodes: names.into_iter().map(String::from).collect(),
                estimated_cpu_threads: count as u32,
                estimated_memory_mb: count as u64 * 64,
            });
        }
        ExecutionPlan { stages: built }
    }

    /// Every node succeeds across two stages.
    #[tokio::test]
    async fn test_async_all_succeed() {
        let plan = make_plan(vec![vec!["a", "b"], vec!["c"]]);
        let always_ok = |_node_id: String| async { Ok::<(), String>(()) };

        let (results, stats) =
            AsyncExecutor::run(&plan, AsyncExecutorConfig::default(), always_ok).await;

        assert_eq!(results.len(), 3);
        assert!(!results.iter().any(|r| !r.success));
        assert_eq!(stats.nodes_executed, 3);
        assert_eq!(stats.failures, 0);
    }

    /// One failing node in a stage is counted without affecting its siblings.
    #[tokio::test]
    async fn test_async_partial_failure() {
        let plan = make_plan(vec![vec!["ok1", "fail1", "ok2"]]);

        let (results, stats) = AsyncExecutor::run(
            &plan,
            AsyncExecutorConfig::default(),
            |node_id: String| async move {
                match node_id.as_str() {
                    "fail1" => Err("simulated failure".to_string()),
                    _ => Ok(()),
                }
            },
        )
        .await;

        let failed = results.iter().filter(|r| !r.success).count();
        assert_eq!(stats.failures, 1);
        assert_eq!(failed, 1);
    }

    /// With `fail_on_stage_error`, stages after a failing stage are skipped.
    #[tokio::test]
    async fn test_async_fail_on_stage_error_aborts_remaining() {
        let plan = make_plan(vec![vec!["fail-node"], vec!["should-not-run"]]);
        let config = AsyncExecutorConfig {
            stage_timeout: None,
            fail_on_stage_error: true,
        };

        let (results, stats) = AsyncExecutor::run(&plan, config, |node_id: String| async move {
            match node_id.as_str() {
                "fail-node" => Err("stage error".to_string()),
                _ => Ok(()),
            }
        })
        .await;

        // The second stage must never have produced a result.
        assert!(results.iter().all(|r| r.node_id != "should-not-run"));
        assert_eq!(stats.stages_executed, 1);
    }

    /// A plan without stages yields no results and executes nothing.
    #[tokio::test]
    async fn test_async_empty_plan() {
        let plan = ExecutionPlan { stages: Vec::new() };

        let (results, stats) = AsyncExecutor::run(
            &plan,
            AsyncExecutorConfig::default(),
            |_| async { Ok::<(), String>(()) },
        )
        .await;

        assert!(results.is_empty());
        assert_eq!(stats.stages_executed, 0);
    }
}