Skip to main content

forge_core/workflow/
step.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::Result;
10
/// Shared, thread-safe handle to a compensation callback.
///
/// The callback takes a `T` by value and returns a pinned, boxed `C`. The alias exists so
/// this long trait-object type is spelled out in exactly one place.
type CompensateFn<'a, T, C> = Arc<dyn Fn(T) -> Pin<Box<C>> + Send + Sync + 'a>;
13
/// Lifecycle state of a single workflow step.
///
/// Every variant has a fixed lowercase name, returned by [`StepStatus::as_str`], which is
/// the form used for database storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum StepStatus {
    /// The step has not started yet.
    Pending,
    /// The step is currently running.
    Running,
    /// The step completed successfully.
    Completed,
    /// The step failed.
    Failed,
    /// The step's compensation ran.
    Compensated,
    /// The step was skipped.
    Skipped,
    /// The step is waiting (suspended).
    Waiting,
}

impl StepStatus {
    /// Returns the lowercase name of this status for database storage.
    ///
    /// The match has no wildcard arm on purpose: adding a variant without giving it a
    /// name here is a compile error.
    pub fn as_str(&self) -> &'static str {
        match *self {
            StepStatus::Pending => "pending",
            StepStatus::Running => "running",
            StepStatus::Completed => "completed",
            StepStatus::Failed => "failed",
            StepStatus::Compensated => "compensated",
            StepStatus::Skipped => "skipped",
            StepStatus::Waiting => "waiting",
        }
    }
}
48
/// Error describing a string that is not a valid step status.
///
/// The public field holds the text reported in the error message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseStepStatusError(pub String);

impl std::fmt::Display for ParseStepStatusError {
    /// Writes `invalid step status: '<value>'`, echoing the wrapped string verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(rejected) = self;
        write!(f, "invalid step status: '{}'", rejected)
    }
}

// No underlying cause to expose, so the default `Error` methods are sufficient.
impl std::error::Error for ParseStepStatusError {}
59
60impl FromStr for StepStatus {
61    type Err = ParseStepStatusError;
62
63    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
64        match s {
65            "pending" => Ok(Self::Pending),
66            "running" => Ok(Self::Running),
67            "completed" => Ok(Self::Completed),
68            "failed" => Ok(Self::Failed),
69            "compensated" => Ok(Self::Compensated),
70            "skipped" => Ok(Self::Skipped),
71            "waiting" => Ok(Self::Waiting),
72            _ => Err(ParseStepStatusError(s.to_string())),
73        }
74    }
75}
76
77/// Result of a step execution.
78#[derive(Debug, Clone)]
79pub struct StepResult<T> {
80    /// Step name.
81    pub name: String,
82    /// Step status.
83    pub status: StepStatus,
84    /// Step result (if completed).
85    pub value: Option<T>,
86    /// Error message (if failed).
87    pub error: Option<String>,
88}
89
/// Named definition of a workflow step whose result type is `T`.
///
/// `T` exists only at the type level; no value of it is stored.
pub struct Step<T> {
    /// Step name.
    pub name: String,
    /// Zero-sized marker carrying the step's result type.
    _marker: std::marker::PhantomData<T>,
}

impl<T> Step<T> {
    /// Builds a step from anything convertible into a `String` name.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Step {
            name,
            _marker: std::marker::PhantomData::<T>,
        }
    }
}
107
108/// Builder for configuring and executing a step.
109pub struct StepBuilder<'a, T, F, C>
110where
111    T: Serialize + DeserializeOwned + Send + 'static,
112    F: Future<Output = Result<T>> + Send + 'a,
113    C: Future<Output = Result<()>> + Send + 'a,
114{
115    name: String,
116    run_fn: Option<Pin<Box<dyn FnOnce() -> F + Send + 'a>>>,
117    compensate_fn: Option<CompensateFn<'a, T, C>>,
118    timeout: Option<Duration>,
119    retry_count: u32,
120    retry_delay: Duration,
121    optional: bool,
122    _marker: std::marker::PhantomData<(T, F, C)>,
123}
124
125impl<'a, T, F, C> StepBuilder<'a, T, F, C>
126where
127    T: Serialize + DeserializeOwned + Send + Clone + 'static,
128    F: Future<Output = Result<T>> + Send + 'a,
129    C: Future<Output = Result<()>> + Send + 'a,
130{
131    /// Create a new step builder.
132    pub fn new(name: impl Into<String>) -> Self {
133        Self {
134            name: name.into(),
135            run_fn: None,
136            compensate_fn: None,
137            timeout: None,
138            retry_count: 0,
139            retry_delay: Duration::from_secs(1),
140            optional: false,
141            _marker: std::marker::PhantomData,
142        }
143    }
144
145    /// Set the step execution function.
146    pub fn run<RF>(mut self, f: RF) -> Self
147    where
148        RF: FnOnce() -> F + Send + 'a,
149    {
150        self.run_fn = Some(Box::pin(f));
151        self
152    }
153
154    /// Set the compensation function.
155    ///
156    /// # Warning
157    ///
158    /// Compensation handlers are in-memory closures. They do **not** survive
159    /// process restarts. If the workflow suspends (via `ctx.sleep()` or
160    /// `ctx.wait_for_event()`) and the process restarts before the workflow
161    /// completes, registered compensation handlers are lost. The executor
162    /// detects this and fails the workflow with a message requiring manual
163    /// remediation.
164    pub fn compensate<CF>(mut self, f: CF) -> Self
165    where
166        CF: Fn(T) -> Pin<Box<C>> + Send + Sync + 'a,
167    {
168        self.compensate_fn = Some(Arc::new(f));
169        self
170    }
171
172    /// Set step timeout.
173    pub fn timeout(mut self, duration: Duration) -> Self {
174        self.timeout = Some(duration);
175        self
176    }
177
178    /// Configure retry behavior.
179    pub fn retry(mut self, count: u32, delay: Duration) -> Self {
180        self.retry_count = count;
181        self.retry_delay = delay;
182        self
183    }
184
185    /// Mark the step as optional (failure won't trigger compensation).
186    pub fn optional(mut self) -> Self {
187        self.optional = true;
188        self
189    }
190
191    /// Get step name.
192    pub fn name(&self) -> &str {
193        &self.name
194    }
195
196    /// Check if step is optional.
197    pub fn is_optional(&self) -> bool {
198        self.optional
199    }
200
201    /// Get retry count.
202    pub fn retry_count(&self) -> u32 {
203        self.retry_count
204    }
205
206    /// Get retry delay.
207    pub fn retry_delay(&self) -> Duration {
208        self.retry_delay
209    }
210
211    /// Get timeout.
212    pub fn get_timeout(&self) -> Option<Duration> {
213        self.timeout
214    }
215}
216
/// Closure-free description of a step's settings, suitable for storage.
#[derive(Debug, Clone)]
pub struct StepConfig {
    /// Name of the step.
    pub name: String,
    /// Timeout for the step, if any.
    pub timeout: Option<Duration>,
    /// Retry count.
    pub retry_count: u32,
    /// Retry delay.
    pub retry_delay: Duration,
    /// Whether the step is optional.
    pub optional: bool,
    /// Whether the step has a compensation function.
    pub has_compensation: bool,
}

impl Default for StepConfig {
    /// Empty name, no timeout, zero retries, a one-second retry delay, and both flags off.
    fn default() -> Self {
        let retry_delay = Duration::from_secs(1);
        StepConfig {
            name: String::new(),
            timeout: None,
            retry_count: 0,
            retry_delay,
            optional: false,
            has_compensation: false,
        }
    }
}
246
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Spot-checks the string form of the core statuses and parsing of two of them.
    #[test]
    fn test_step_status_conversion() {
        let expected = [
            (StepStatus::Pending, "pending"),
            (StepStatus::Running, "running"),
            (StepStatus::Completed, "completed"),
            (StepStatus::Failed, "failed"),
            (StepStatus::Compensated, "compensated"),
        ];
        for (status, text) in expected {
            assert_eq!(status.as_str(), text);
        }

        assert_eq!(StepStatus::from_str("pending"), Ok(StepStatus::Pending));
        assert_eq!(StepStatus::from_str("completed"), Ok(StepStatus::Completed));
    }

    /// The default config is an unnamed, required step with no retries.
    #[test]
    fn test_step_config_default() {
        let StepConfig {
            name,
            optional,
            retry_count,
            ..
        } = StepConfig::default();
        assert!(name.is_empty());
        assert!(!optional);
        assert_eq!(retry_count, 0);
    }

    /// Every variant must map to its lowercase storage name.
    #[test]
    fn step_status_as_str_covers_all_variants() {
        let expected = [
            (StepStatus::Pending, "pending"),
            (StepStatus::Running, "running"),
            (StepStatus::Completed, "completed"),
            (StepStatus::Failed, "failed"),
            (StepStatus::Compensated, "compensated"),
            (StepStatus::Skipped, "skipped"),
            (StepStatus::Waiting, "waiting"),
        ];
        for (status, text) in expected {
            assert_eq!(status.as_str(), text);
        }
    }

    /// `as_str` followed by `parse` must give back the starting variant.
    #[test]
    fn step_status_parse_roundtrips_every_variant() {
        let all = [
            StepStatus::Pending,
            StepStatus::Running,
            StepStatus::Completed,
            StepStatus::Failed,
            StepStatus::Compensated,
            StepStatus::Skipped,
            StepStatus::Waiting,
        ];
        for status in all {
            let s = status.as_str();
            let parsed = StepStatus::from_str(s).unwrap();
            assert_eq!(parsed, status, "{s} did not round-trip");
        }
    }

    /// Unknown input is rejected and carried inside the error.
    #[test]
    fn step_status_parse_rejects_unknown() {
        let err = StepStatus::from_str("garbage").unwrap_err();
        // Display must echo the bad value so logs pinpoint the typo.
        assert!(err.to_string().contains("garbage"));
        assert_eq!(err.0, "garbage");
    }

    /// `Step::new` stores the name it is given.
    #[test]
    fn step_constructor_records_name() {
        let step = Step::<String>::new("send_email");
        assert_eq!(step.name, "send_email");
    }

    // Concrete future types so a builder can be named without supplying closures.
    type NoFut = Pin<Box<dyn Future<Output = Result<u32>> + Send + 'static>>;
    type NoComp = Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>;

    /// Returns an untouched builder named `"noop"`.
    fn fresh_builder<'a>() -> StepBuilder<'a, u32, NoFut, NoComp> {
        StepBuilder::new("noop")
    }

    /// A new builder reports the documented defaults.
    #[test]
    fn step_builder_defaults() {
        let builder = fresh_builder();
        assert_eq!(builder.name(), "noop");
        assert!(!builder.is_optional());
        assert_eq!(builder.retry_count(), 0);
        assert_eq!(builder.retry_delay(), Duration::from_secs(1));
        assert!(builder.get_timeout().is_none());
    }

    /// `optional()` switches the flag on.
    #[test]
    fn step_builder_optional_flag_flips() {
        assert!(fresh_builder().optional().is_optional());
    }

    /// `retry()` stores both the count and the delay.
    #[test]
    fn step_builder_retry_sets_count_and_delay() {
        let delay = Duration::from_millis(250);
        let builder = fresh_builder().retry(3, delay);
        assert_eq!(builder.retry_count(), 3);
        assert_eq!(builder.retry_delay(), delay);
    }

    /// `timeout()` stores the duration it is given.
    #[test]
    fn step_builder_timeout_setter() {
        let limit = Duration::from_secs(5);
        let builder = fresh_builder().timeout(limit);
        assert_eq!(builder.get_timeout(), Some(limit));
    }
}