forge_core/workflow/
parallel.rs1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Serialize, de::DeserializeOwned};
7
8use super::CompensationHandler;
9use super::context::WorkflowContext;
10use crate::{ForgeError, Result};
11
12type ParallelStepHandler =
14 Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + 'static>>;
15
16struct ParallelStep {
18 name: String,
19 handler: ParallelStepHandler,
20 compensate: Option<CompensationHandler>,
21}
22
23pub struct ParallelBuilder<'a> {
25 ctx: &'a WorkflowContext,
26 steps: Vec<ParallelStep>,
27}
28
29impl<'a> ParallelBuilder<'a> {
30 pub fn new(ctx: &'a WorkflowContext) -> Self {
32 Self {
33 ctx,
34 steps: Vec::new(),
35 }
36 }
37
38 pub fn step<T, F, Fut>(mut self, name: &str, handler: F) -> Self
40 where
41 T: Serialize + Send + 'static,
42 F: FnOnce() -> Fut + Send + 'static,
43 Fut: Future<Output = Result<T>> + Send + 'static,
44 {
45 let step_handler: ParallelStepHandler = Box::pin(async move {
46 let result = handler().await?;
47 serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
48 });
49
50 self.steps.push(ParallelStep {
51 name: name.to_string(),
52 handler: step_handler,
53 compensate: None,
54 });
55
56 self
57 }
58
59 pub fn step_with_compensate<T, F, Fut, C, CFut>(
61 mut self,
62 name: &str,
63 handler: F,
64 compensate: C,
65 ) -> Self
66 where
67 T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
68 F: FnOnce() -> Fut + Send + 'static,
69 Fut: Future<Output = Result<T>> + Send + 'static,
70 C: Fn(T) -> CFut + Send + Sync + 'static,
71 CFut: Future<Output = Result<()>> + Send + 'static,
72 {
73 let step_handler: ParallelStepHandler = Box::pin(async move {
74 let result = handler().await?;
75 serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
76 });
77
78 let compensation: CompensationHandler = Arc::new(move |value: serde_json::Value| {
79 let result: std::result::Result<T, _> = serde_json::from_value(value);
80 match result {
81 Ok(typed_value) => Box::pin(compensate(typed_value))
82 as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
83 Err(e) => Box::pin(async move {
84 Err(ForgeError::Deserialization(format!(
85 "Failed to deserialize compensation value: {}",
86 e
87 )))
88 }) as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
89 }
90 });
91
92 self.steps.push(ParallelStep {
93 name: name.to_string(),
94 handler: step_handler,
95 compensate: Some(compensation),
96 });
97
98 self
99 }
100
101 pub async fn run(self) -> Result<ParallelResults> {
103 let mut results = ParallelResults::new();
104 let mut compensation_handlers: Vec<(String, CompensationHandler)> = Vec::new();
105 let mut pending_steps = Vec::new();
106
107 for step in self.steps {
109 if let Some(cached) = self.ctx.get_step_result::<serde_json::Value>(&step.name) {
110 results.insert(step.name.clone(), cached);
111 } else {
112 pending_steps.push(step);
113 }
114 }
115
116 if pending_steps.is_empty() {
118 return Ok(results);
119 }
120
121 for step in &pending_steps {
123 self.ctx.record_step_start(&step.name);
124 }
125
126 type StepResult = (
128 String,
129 Result<serde_json::Value>,
130 Option<CompensationHandler>,
131 );
132
133 let handles: Vec<tokio::task::JoinHandle<StepResult>> = pending_steps
134 .into_iter()
135 .map(|step| {
136 let name = step.name;
137 let handler = step.handler;
138 let compensate = step.compensate;
139 tokio::spawn(async move {
140 let result = handler.await;
141 (name, result, compensate)
142 })
143 })
144 .collect();
145
146 let step_results = futures::future::join_all(handles).await;
148 let mut failed = false;
149 let mut first_error: Option<ForgeError> = None;
150
151 for join_result in step_results {
152 let (name, result, compensate): StepResult =
153 join_result.map_err(|e| ForgeError::Internal(format!("Task join error: {}", e)))?;
154
155 match result {
156 Ok(value) => {
157 self.ctx.record_step_complete(&name, value.clone());
158 results.insert(name.clone(), value);
159 if let Some(comp) = compensate {
160 compensation_handlers.push((name, comp));
161 }
162 }
163 Err(e) => {
164 self.ctx.record_step_failure(&name, e.to_string());
165 failed = true;
166 if first_error.is_none() {
167 first_error = Some(e);
168 }
169 }
170 }
171 }
172
173 if failed {
175 for (name, handler) in compensation_handlers.into_iter().rev() {
176 self.ctx.register_compensation(&name, handler);
177 }
178 self.ctx.run_compensation().await;
179 return Err(first_error.unwrap());
180 }
181
182 Ok(results)
183 }
184}
185
186#[derive(Debug, Clone, Default)]
188pub struct ParallelResults {
189 inner: HashMap<String, serde_json::Value>,
190}
191
192impl ParallelResults {
193 pub fn new() -> Self {
195 Self {
196 inner: HashMap::new(),
197 }
198 }
199
200 pub fn insert(&mut self, step_name: String, value: serde_json::Value) {
202 self.inner.insert(step_name, value);
203 }
204
205 pub fn get<T: DeserializeOwned>(&self, step_name: &str) -> Result<T> {
207 let value = self
208 .inner
209 .get(step_name)
210 .ok_or_else(|| ForgeError::NotFound(format!("Step '{}' not found", step_name)))?;
211 serde_json::from_value(value.clone())
212 .map_err(|e| ForgeError::Deserialization(e.to_string()))
213 }
214
215 pub fn contains(&self, step_name: &str) -> bool {
217 self.inner.contains_key(step_name)
218 }
219
220 pub fn len(&self) -> usize {
222 self.inner.len()
223 }
224
225 pub fn is_empty(&self) -> bool {
227 self.inner.is_empty()
228 }
229
230 pub fn iter(&self) -> impl Iterator<Item = (&String, &serde_json::Value)> {
232 self.inner.iter()
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_results() {
        let mut results = ParallelResults::new();
        results.insert("step1".to_string(), serde_json::json!({"value": 42}));
        results.insert("step2".to_string(), serde_json::json!("hello"));

        assert!(results.contains("step1"));
        assert!(results.contains("step2"));
        assert!(!results.contains("step3"));
        assert_eq!(results.len(), 2);

        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct StepResult {
            value: i32,
        }

        let step1: StepResult = results.get("step1").unwrap();
        assert_eq!(step1.value, 42);

        let step2: String = results.get("step2").unwrap();
        assert_eq!(step2, "hello");
    }

    #[test]
    fn test_parallel_results_empty() {
        let results = ParallelResults::new();
        assert!(results.is_empty());
        assert_eq!(results.len(), 0);
        assert_eq!(results.iter().count(), 0);
    }

    #[test]
    fn test_parallel_results_get_errors() {
        let mut results = ParallelResults::new();
        results.insert("step1".to_string(), serde_json::json!("not a number"));

        let missing = results.get::<i32>("absent");
        assert!(matches!(missing, Err(ForgeError::NotFound(_))));

        let mismatched = results.get::<i32>("step1");
        assert!(matches!(mismatched, Err(ForgeError::Deserialization(_))));
    }

    #[test]
    fn test_parallel_results_insert_overwrites() {
        let mut results = ParallelResults::new();
        results.insert("s".to_string(), serde_json::json!(1));
        results.insert("s".to_string(), serde_json::json!(2));

        assert_eq!(results.len(), 1);
        assert_eq!(results.get::<i32>("s").unwrap(), 2);
    }
}