Skip to main content

forge_core/workflow/
parallel.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use serde::{Serialize, de::DeserializeOwned};
7
8use super::CompensationHandler;
9use super::context::WorkflowContext;
10use crate::{ForgeError, Result};
11
12/// Type alias for parallel step handler.
13type ParallelStepHandler =
14    Pin<Box<dyn Future<Output = Result<serde_json::Value>> + Send + 'static>>;
15
16/// A step to be executed in parallel.
17struct ParallelStep {
18    name: String,
19    handler: ParallelStepHandler,
20    compensate: Option<CompensationHandler>,
21}
22
23/// Builder for executing workflow steps in parallel.
24pub struct ParallelBuilder<'a> {
25    ctx: &'a WorkflowContext,
26    steps: Vec<ParallelStep>,
27}
28
29impl<'a> ParallelBuilder<'a> {
30    /// Create a new parallel builder.
31    pub fn new(ctx: &'a WorkflowContext) -> Self {
32        Self {
33            ctx,
34            steps: Vec::new(),
35        }
36    }
37
38    /// Add a step to be executed in parallel.
39    pub fn step<T, F, Fut>(mut self, name: &str, handler: F) -> Self
40    where
41        T: Serialize + Send + 'static,
42        F: FnOnce() -> Fut + Send + 'static,
43        Fut: Future<Output = Result<T>> + Send + 'static,
44    {
45        let step_handler: ParallelStepHandler = Box::pin(async move {
46            let result = handler().await?;
47            serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
48        });
49
50        self.steps.push(ParallelStep {
51            name: name.to_string(),
52            handler: step_handler,
53            compensate: None,
54        });
55
56        self
57    }
58
59    /// Add a step with compensation handler.
60    pub fn step_with_compensate<T, F, Fut, C, CFut>(
61        mut self,
62        name: &str,
63        handler: F,
64        compensate: C,
65    ) -> Self
66    where
67        T: Serialize + DeserializeOwned + Clone + Send + Sync + 'static,
68        F: FnOnce() -> Fut + Send + 'static,
69        Fut: Future<Output = Result<T>> + Send + 'static,
70        C: Fn(T) -> CFut + Send + Sync + 'static,
71        CFut: Future<Output = Result<()>> + Send + 'static,
72    {
73        let step_handler: ParallelStepHandler = Box::pin(async move {
74            let result = handler().await?;
75            serde_json::to_value(result).map_err(|e| ForgeError::Serialization(e.to_string()))
76        });
77
78        let compensation: CompensationHandler = Arc::new(move |value: serde_json::Value| {
79            let result: std::result::Result<T, _> = serde_json::from_value(value);
80            match result {
81                Ok(typed_value) => Box::pin(compensate(typed_value))
82                    as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
83                Err(e) => Box::pin(async move {
84                    Err(ForgeError::Deserialization(format!(
85                        "Failed to deserialize compensation value: {}",
86                        e
87                    )))
88                }) as Pin<Box<dyn Future<Output = Result<()>> + Send>>,
89            }
90        });
91
92        self.steps.push(ParallelStep {
93            name: name.to_string(),
94            handler: step_handler,
95            compensate: Some(compensation),
96        });
97
98        self
99    }
100
101    /// Execute all steps in parallel.
102    pub async fn run(self) -> Result<ParallelResults> {
103        let mut results = ParallelResults::new();
104        let mut compensation_handlers: Vec<(String, CompensationHandler)> = Vec::new();
105        let mut pending_steps = Vec::new();
106
107        // Check for cached results
108        for step in self.steps {
109            if let Some(cached) = self.ctx.get_step_result::<serde_json::Value>(&step.name) {
110                results.insert(step.name.clone(), cached);
111            } else {
112                pending_steps.push(step);
113            }
114        }
115
116        // If all steps are cached, return early
117        if pending_steps.is_empty() {
118            return Ok(results);
119        }
120
121        // Record step starts
122        for step in &pending_steps {
123            self.ctx.record_step_start(&step.name);
124        }
125
126        // Execute steps in parallel
127        type StepResult = (
128            String,
129            Result<serde_json::Value>,
130            Option<CompensationHandler>,
131        );
132
133        let handles: Vec<tokio::task::JoinHandle<StepResult>> = pending_steps
134            .into_iter()
135            .map(|step| {
136                let name = step.name;
137                let handler = step.handler;
138                let compensate = step.compensate;
139                tokio::spawn(async move {
140                    let result = handler.await;
141                    (name, result, compensate)
142                })
143            })
144            .collect();
145
146        // Collect results
147        let mut step_results = Vec::with_capacity(handles.len());
148        for handle in handles {
149            step_results.push(handle.await);
150        }
151        let mut failed = false;
152        let mut first_error: Option<ForgeError> = None;
153
154        for join_result in step_results {
155            let (name, result, compensate): StepResult =
156                join_result.map_err(|e| ForgeError::Internal(format!("Task join error: {}", e)))?;
157
158            match result {
159                Ok(value) => {
160                    self.ctx.record_step_complete(&name, value.clone());
161                    results.insert(name.clone(), value);
162                    if let Some(comp) = compensate {
163                        compensation_handlers.push((name, comp));
164                    }
165                }
166                Err(e) => {
167                    self.ctx.record_step_failure(&name, e.to_string());
168                    failed = true;
169                    if first_error.is_none() {
170                        first_error = Some(e);
171                    }
172                }
173            }
174        }
175
176        // If any step failed, run compensation in reverse order
177        if failed {
178            for (name, handler) in compensation_handlers.into_iter().rev() {
179                self.ctx.register_compensation(&name, handler);
180            }
181            self.ctx.run_compensation().await;
182            return Err(first_error.expect("failed flag set implies at least one error"));
183        }
184
185        Ok(results)
186    }
187}
188
189/// Results from parallel step execution.
190#[derive(Debug, Clone, Default)]
191pub struct ParallelResults {
192    inner: HashMap<String, serde_json::Value>,
193}
194
195impl ParallelResults {
196    /// Create empty results.
197    pub fn new() -> Self {
198        Self {
199            inner: HashMap::new(),
200        }
201    }
202
203    /// Insert a result.
204    pub fn insert(&mut self, step_name: String, value: serde_json::Value) {
205        self.inner.insert(step_name, value);
206    }
207
208    /// Get a typed result by step name.
209    pub fn get<T: DeserializeOwned>(&self, step_name: &str) -> Result<T> {
210        let value = self
211            .inner
212            .get(step_name)
213            .ok_or_else(|| ForgeError::NotFound(format!("Step '{}' not found", step_name)))?;
214        serde_json::from_value(value.clone())
215            .map_err(|e| ForgeError::Deserialization(e.to_string()))
216    }
217
218    /// Check if a step result exists.
219    pub fn contains(&self, step_name: &str) -> bool {
220        self.inner.contains_key(step_name)
221    }
222
223    /// Get the number of results.
224    pub fn len(&self) -> usize {
225        self.inner.len()
226    }
227
228    /// Check if empty.
229    pub fn is_empty(&self) -> bool {
230        self.inner.is_empty()
231    }
232
233    /// Iterate over results.
234    pub fn iter(&self) -> impl Iterator<Item = (&String, &serde_json::Value)> {
235        self.inner.iter()
236    }
237}
238
#[cfg(test)]
// Panicking via unwrap/indexing is the assertion mechanism in these tests.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_results() {
        let mut results = ParallelResults::new();
        results.insert("step1".to_string(), serde_json::json!({"value": 42}));
        results.insert("step2".to_string(), serde_json::json!("hello"));

        assert!(results.contains("step1"));
        assert!(results.contains("step2"));
        assert!(!results.contains("step3"));
        assert_eq!(results.len(), 2);

        #[derive(Debug, serde::Deserialize, PartialEq)]
        struct StepResult {
            value: i32,
        }

        let step1: StepResult = results.get("step1").unwrap();
        assert_eq!(step1.value, 42);

        let step2: String = results.get("step2").unwrap();
        assert_eq!(step2, "hello");
    }

    #[test]
    fn test_get_missing_step_is_not_found() {
        let results = ParallelResults::new();
        assert!(results.is_empty());
        assert!(matches!(
            results.get::<i32>("missing"),
            Err(ForgeError::NotFound(_))
        ));
    }

    #[test]
    fn test_get_wrong_type_is_deserialization_error() {
        let mut results = ParallelResults::new();
        results.insert("step".to_string(), serde_json::json!("hello"));
        assert!(matches!(
            results.get::<i32>("step"),
            Err(ForgeError::Deserialization(_))
        ));
    }

    #[test]
    fn test_insert_overwrites_and_iter_yields_all() {
        let mut results = ParallelResults::new();
        results.insert("b".to_string(), serde_json::json!(2));
        results.insert("a".to_string(), serde_json::json!(1));
        results.insert("a".to_string(), serde_json::json!(3));
        assert_eq!(results.len(), 2);
        assert_eq!(results.get::<i32>("a").unwrap(), 3);

        let mut entries: Vec<(String, serde_json::Value)> = results
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        entries.sort_by(|x, y| x.0.cmp(&y.0));
        assert_eq!(
            entries,
            vec![
                ("a".to_string(), serde_json::json!(3)),
                ("b".to_string(), serde_json::json!(2)),
            ]
        );
    }
}
266}