forge_core/workflow/
parallel.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Serialize, de::DeserializeOwned};
7
8use super::CompensationHandler;
9use super::context::WorkflowContext;
10use crate::{ForgeError, Result};
11
12type ParallelStepHandler =
14 Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + 'static>>;
15
16struct ParallelStep {
18 name: String,
19 handler: ParallelStepHandler,
20 compensate: Option<CompensationHandler>,
21}
22
23pub struct ParallelBuilder<'a> {
25 ctx: &'a WorkflowContext,
26 steps: Vec<ParallelStep>,
27}
28
29impl<'a> ParallelBuilder<'a> {
30 pub fn new(ctx: &'a WorkflowContext) -> Self {
32 Self {
33 ctx,
34 steps: Vec::new(),
35 }
36 }
37
38 pub fn step<T, F, Fut>(mut self, name: &str, handler: F) -> Self
40 where
41 T: Serialize + Send + 'static,
42 F: FnOnce() -> Fut + Send + 'static,
43 Fut: Future<Output = Result<T>> + Send + 'static,
44 {
45 let step_handler: ParallelStepHandler = Box::pin(async move {
46 let result = handler().await?;
47 serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
48 });
49
50 self.steps.push(ParallelStep {
51 name: name.to_string(),
52 handler: step_handler,
53 compensate: None,
54 });
55
56 self
57 }
58
59 pub fn step_with_compensate<T, F, Fut, C, CFut>(
61 mut self,
62 name: &str,
63 handler: F,
64 compensate: C,
65 ) -> Self
66 where
67 T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
68 F: FnOnce() -> Fut + Send + 'static,
69 Fut: Future<Output = Result<T>> + Send + 'static,
70 C: Fn(T) -> CFut + Send + Sync + 'static,
71 CFut: Future<Output = Result<()>> + Send + 'static,
72 {
73 let step_handler: ParallelStepHandler = Box::pin(async move {
74 let result = handler().await?;
75 serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
76 });
77
78 let compensation: CompensationHandler = Arc::new(move |value: serde_json::Value| {
79 let result: std::result::Result<T, _> = serde_json::from_value(value);
80 match result {
81 Ok(typed_value) => Box::pin(compensate(typed_value))
82 as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
83 Err(e) => Box::pin(async move {
84 Err(ForgeError::Deserialization(format!(
85 "Failed to deserialize compensation value: {}",
86 e
87 )))
88 }) as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
89 }
90 });
91
92 self.steps.push(ParallelStep {
93 name: name.to_string(),
94 handler: step_handler,
95 compensate: Some(compensation),
96 });
97
98 self
99 }
100
101 pub async fn run(self) -> Result<ParallelResults> {
103 let mut results = ParallelResults::new();
104 let mut compensation_handlers: Vec<(String, CompensationHandler)> = Vec::new();
105 let mut pending_steps = Vec::new();
106
107 for step in self.steps {
109 if let Some(cached) = self.ctx.get_step_result::<serde_json::Value>(&step.name) {
110 results.insert(step.name.clone(), cached);
111 } else {
112 pending_steps.push(step);
113 }
114 }
115
116 if pending_steps.is_empty() {
118 return Ok(results);
119 }
120
121 for step in &pending_steps {
123 self.ctx.record_step_start(&step.name);
124 }
125
126 type StepResult = (
128 String,
129 Result<serde_json::Value>,
130 Option<CompensationHandler>,
131 );
132
133 let handles: Vec<tokio::task::JoinHandle<StepResult>> = pending_steps
134 .into_iter()
135 .map(|step| {
136 let name = step.name;
137 let handler = step.handler;
138 let compensate = step.compensate;
139 tokio::spawn(async move {
140 let result = handler.await;
141 (name, result, compensate)
142 })
143 })
144 .collect();
145
146 let mut step_results = Vec::with_capacity(handles.len());
148 for handle in handles {
149 step_results.push(handle.await);
150 }
151 let mut failed = false;
152 let mut first_error: Option<ForgeError> = None;
153
154 for join_result in step_results {
155 let (name, result, compensate): StepResult =
156 join_result.map_err(|e| ForgeError::Internal(format!("Task join error: {}", e)))?;
157
158 match result {
159 Ok(value) => {
160 self.ctx.record_step_complete(&name, value.clone());
161 results.insert(name.clone(), value);
162 if let Some(comp) = compensate {
163 compensation_handlers.push((name, comp));
164 }
165 }
166 Err(e) => {
167 self.ctx.record_step_failure(&name, e.to_string());
168 failed = true;
169 if first_error.is_none() {
170 first_error = Some(e);
171 }
172 }
173 }
174 }
175
176 if failed {
178 for (name, handler) in compensation_handlers.into_iter().rev() {
179 self.ctx.register_compensation(&name, handler);
180 }
181 self.ctx.run_compensation().await;
182 return Err(first_error.expect("failed flag set implies at least one error"));
183 }
184
185 Ok(results)
186 }
187}
188
189#[derive(Debug, Clone, Default)]
191pub struct ParallelResults {
192 inner: HashMap<String, serde_json::Value>,
193}
194
195impl ParallelResults {
196 pub fn new() -> Self {
198 Self {
199 inner: HashMap::new(),
200 }
201 }
202
203 pub fn insert(&mut self, step_name: String, value: serde_json::Value) {
205 self.inner.insert(step_name, value);
206 }
207
208 pub fn get<T: DeserializeOwned>(&self, step_name: &str) -> Result<T> {
210 let value = self
211 .inner
212 .get(step_name)
213 .ok_or_else(|| ForgeError::NotFound(format!("Step '{}' not found", step_name)))?;
214 serde_json::from_value(value.clone())
215 .map_err(|e| ForgeError::Deserialization(e.to_string()))
216 }
217
218 pub fn contains(&self, step_name: &str) -> bool {
220 self.inner.contains_key(step_name)
221 }
222
223 pub fn len(&self) -> usize {
225 self.inner.len()
226 }
227
228 pub fn is_empty(&self) -> bool {
230 self.inner.is_empty()
231 }
232
233 pub fn iter(&self) -> impl Iterator<Item = (&String, &serde_json::Value)> {
235 self.inner.iter()
236 }
237}
238
#[cfg(test)]
// Panicking on unexpected values is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_results() {
        let mut results = ParallelResults::new();
        results.insert("step1".to_string(), serde_json::json!({"value": 42}));
        results.insert("step2".to_string(), serde_json::json!("hello"));

        assert!(results.contains("step1"));
        assert!(results.contains("step2"));
        assert!(!results.contains("step3"));
        assert_eq!(results.len(), 2);

        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct StepResult {
            value: i32,
        }

        let step1: StepResult = results.get("step1").unwrap();
        assert_eq!(step1.value, 42);

        let step2: String = results.get("step2").unwrap();
        assert_eq!(step2, "hello");
    }

    #[test]
    fn test_parallel_results_missing_step_is_not_found() {
        let results = ParallelResults::new();
        assert!(results.is_empty());
        assert!(matches!(
            results.get::<i32>("absent"),
            Err(ForgeError::NotFound(_))
        ));
    }

    #[test]
    fn test_parallel_results_type_mismatch_is_deserialization_error() {
        let mut results = ParallelResults::new();
        results.insert("step".to_string(), serde_json::json!("not a number"));
        assert!(matches!(
            results.get::<i32>("step"),
            Err(ForgeError::Deserialization(_))
        ));
    }

    #[test]
    fn test_parallel_results_insert_overwrites() {
        let mut results = ParallelResults::new();
        results.insert("step".to_string(), serde_json::json!(1));
        results.insert("step".to_string(), serde_json::json!(2));
        assert_eq!(results.len(), 1);
        assert_eq!(results.get::<i32>("step").unwrap(), 2);
        assert_eq!(results.iter().count(), 1);
    }
}