Skip to main content

forge_core/workflow/
parallel.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Serialize, de::DeserializeOwned};
7
8use super::CompensationHandler;
9use super::context::WorkflowContext;
10use crate::{ForgeError, Result};
11
12/// Type alias for parallel step handler.
13type ParallelStepHandler =
14    Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + 'static>>;
15
16/// A step to be executed in parallel.
17struct ParallelStep {
18    name: String,
19    handler: ParallelStepHandler,
20    compensate: Option<CompensationHandler>,
21}
22
23/// Builder for executing workflow steps in parallel.
24pub struct ParallelBuilder<'a> {
25    ctx: &'a WorkflowContext,
26    steps: Vec<ParallelStep>,
27}
28
29impl<'a> ParallelBuilder<'a> {
30    /// Create a new parallel builder.
31    pub fn new(ctx: &'a WorkflowContext) -> Self {
32        Self {
33            ctx,
34            steps: Vec::new(),
35        }
36    }
37
38    /// Add a step to be executed in parallel.
39    pub fn step<T, F, Fut>(mut self, name: &str, handler: F) -> Self
40    where
41        T: Serialize + Send + 'static,
42        F: FnOnce() -> Fut + Send + 'static,
43        Fut: Future<Output = Result<T>> + Send + 'static,
44    {
45        let step_handler: ParallelStepHandler = Box::pin(async move {
46            let result = handler().await?;
47            serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
48        });
49
50        self.steps.push(ParallelStep {
51            name: name.to_string(),
52            handler: step_handler,
53            compensate: None,
54        });
55
56        self
57    }
58
59    /// Add a step with compensation handler.
60    pub fn step_with_compensate<T, F, Fut, C, CFut>(
61        mut self,
62        name: &str,
63        handler: F,
64        compensate: C,
65    ) -> Self
66    where
67        T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
68        F: FnOnce() -> Fut + Send + 'static,
69        Fut: Future<Output = Result<T>> + Send + 'static,
70        C: Fn(T) -> CFut + Send + Sync + 'static,
71        CFut: Future<Output = Result<()>> + Send + 'static,
72    {
73        let step_handler: ParallelStepHandler = Box::pin(async move {
74            let result = handler().await?;
75            serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
76        });
77
78        let compensation: CompensationHandler = Arc::new(move |value: serde_json::Value| {
79            let result: std::result::Result<T, _> = serde_json::from_value(value);
80            match result {
81                Ok(typed_value) => Box::pin(compensate(typed_value))
82                    as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
83                Err(e) => Box::pin(async move {
84                    Err(ForgeError::Deserialization(format!(
85                        "Failed to deserialize compensation value: {}",
86                        e
87                    )))
88                }) as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
89            }
90        });
91
92        self.steps.push(ParallelStep {
93            name: name.to_string(),
94            handler: step_handler,
95            compensate: Some(compensation),
96        });
97
98        self
99    }
100
101    /// Execute all steps in parallel.
102    pub async fn run(self) -> Result<ParallelResults> {
103        let mut results = ParallelResults::new();
104        let mut compensation_handlers: Vec<(String, CompensationHandler)> = Vec::new();
105        let mut pending_steps = Vec::new();
106
107        // Check for cached results
108        for step in self.steps {
109            if let Some(cached) = self.ctx.get_step_result::<serde_json::Value>(&step.name) {
110                results.insert(step.name.clone(), cached);
111            } else {
112                pending_steps.push(step);
113            }
114        }
115
116        // If all steps are cached, return early
117        if pending_steps.is_empty() {
118            return Ok(results);
119        }
120
121        // Record step starts
122        for step in &pending_steps {
123            self.ctx.record_step_start(&step.name);
124        }
125
126        // Execute steps in parallel
127        type StepResult = (
128            String,
129            Result<serde_json::Value>,
130            Option<CompensationHandler>,
131        );
132
133        let handles: Vec<tokio::task::JoinHandle<StepResult>> = pending_steps
134            .into_iter()
135            .map(|step| {
136                let name = step.name;
137                let handler = step.handler;
138                let compensate = step.compensate;
139                tokio::spawn(async move {
140                    let result = handler.await;
141                    (name, result, compensate)
142                })
143            })
144            .collect();
145
146        // Collect results
147        let step_results = futures::future::join_all(handles).await;
148        let mut failed = false;
149        let mut first_error: Option<ForgeError> = None;
150
151        for join_result in step_results {
152            let (name, result, compensate): StepResult =
153                join_result.map_err(|e| ForgeError::Internal(format!("Task join error: {}", e)))?;
154
155            match result {
156                Ok(value) => {
157                    self.ctx.record_step_complete(&name, value.clone());
158                    results.insert(name.clone(), value);
159                    if let Some(comp) = compensate {
160                        compensation_handlers.push((name, comp));
161                    }
162                }
163                Err(e) => {
164                    self.ctx.record_step_failure(&name, e.to_string());
165                    failed = true;
166                    if first_error.is_none() {
167                        first_error = Some(e);
168                    }
169                }
170            }
171        }
172
173        // If any step failed, run compensation in reverse order
174        if failed {
175            for (name, handler) in compensation_handlers.into_iter().rev() {
176                self.ctx.register_compensation(&name, handler);
177            }
178            self.ctx.run_compensation().await;
179            return Err(first_error.expect("failed flag set implies at least one error"));
180        }
181
182        Ok(results)
183    }
184}
185
186/// Results from parallel step execution.
187#[derive(Debug, Clone, Default)]
188pub struct ParallelResults {
189    inner: HashMap<String, serde_json::Value>,
190}
191
192impl ParallelResults {
193    /// Create empty results.
194    pub fn new() -> Self {
195        Self {
196            inner: HashMap::new(),
197        }
198    }
199
200    /// Insert a result.
201    pub fn insert(&mut self, step_name: String, value: serde_json::Value) {
202        self.inner.insert(step_name, value);
203    }
204
205    /// Get a typed result by step name.
206    pub fn get<T: DeserializeOwned>(&self, step_name: &str) -> Result<T> {
207        let value = self
208            .inner
209            .get(step_name)
210            .ok_or_else(|| ForgeError::NotFound(format!("Step '{}' not found", step_name)))?;
211        serde_json::from_value(value.clone())
212            .map_err(|e| ForgeError::Deserialization(e.to_string()))
213    }
214
215    /// Check if a step result exists.
216    pub fn contains(&self, step_name: &str) -> bool {
217        self.inner.contains_key(step_name)
218    }
219
220    /// Get the number of results.
221    pub fn len(&self) -> usize {
222        self.inner.len()
223    }
224
225    /// Check if empty.
226    pub fn is_empty(&self) -> bool {
227        self.inner.is_empty()
228    }
229
230    /// Iterate over results.
231    pub fn iter(&self) -> impl Iterator<Item = (&String, &serde_json::Value)> {
232        self.inner.iter()
233    }
234}
235
#[cfg(test)]
// Panicking is the intended failure mode in tests, so these lints do not apply here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_results() {
        let mut results = ParallelResults::new();
        results.insert("step1".to_string(), serde_json::json!({"value": 42}));
        results.insert("step2".to_string(), serde_json::json!("hello"));

        assert!(results.contains("step1"));
        assert!(results.contains("step2"));
        assert!(!results.contains("step3"));
        assert_eq!(results.len(), 2);

        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct StepResult {
            value: i32,
        }

        let step1: StepResult = results.get("step1").unwrap();
        assert_eq!(step1.value, 42);

        let step2: String = results.get("step2").unwrap();
        assert_eq!(step2, "hello");
    }

    #[test]
    fn test_parallel_results_empty() {
        let results = ParallelResults::new();
        assert!(results.is_empty());
        assert_eq!(results.len(), 0);
        assert_eq!(results.iter().count(), 0);
    }

    #[test]
    fn test_parallel_results_get_errors() {
        let mut results = ParallelResults::new();
        results.insert("step1".to_string(), serde_json::json!("hello"));

        // Unknown step name.
        assert!(matches!(
            results.get::<String>("missing"),
            Err(ForgeError::NotFound(_))
        ));
        // Stored value has the wrong shape for the requested type.
        assert!(matches!(
            results.get::<i32>("step1"),
            Err(ForgeError::Deserialization(_))
        ));
    }
}