1use std::future::Future;
2use std::pin::Pin;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::Result;
10
/// Shared, thread-safe handle to a compensation callback.
///
/// The callback takes a value of type `In` and hands back its future `Fut`
/// already boxed and pinned. Wrapping it in an `Arc` lets the handle be
/// cloned cheaply; the `'f` bound lets the callback borrow data that lives
/// at least that long.
type CompensateFn<'f, In, Fut> =
    Arc<dyn Fn(In) -> Pin<Box<Fut>> + Send + Sync + 'f>;
13
/// Status of a single step.
///
/// Each variant has a lowercase name returned by `StepStatus::as_str` and
/// accepted by `str::parse::<StepStatus>()`. The enum is `#[non_exhaustive]`,
/// so code outside this crate must keep a wildcard arm when matching on it.
///
/// `Hash` is derived alongside `Eq` so statuses can be used as keys in
/// `HashMap` / `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum StepStatus {
    /// Name: `"pending"`.
    Pending,
    /// Name: `"running"`.
    Running,
    /// Name: `"completed"`.
    Completed,
    /// Name: `"failed"`.
    Failed,
    /// Name: `"compensated"`.
    Compensated,
    /// Name: `"skipped"`.
    Skipped,
    /// Name: `"waiting"`.
    Waiting,
}
33
34impl StepStatus {
35 pub fn as_str(&self) -> &'static str {
37 match self {
38 Self::Pending => "pending",
39 Self::Running => "running",
40 Self::Completed => "completed",
41 Self::Failed => "failed",
42 Self::Compensated => "compensated",
43 Self::Skipped => "skipped",
44 Self::Waiting => "waiting",
45 }
46 }
47}
48
/// Error produced when a string is not one of the `StepStatus` names.
///
/// The wrapped `String` is the exact input that was rejected.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParseStepStatusError(pub String);
51
52impl std::fmt::Display for ParseStepStatusError {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "invalid step status: '{}'", self.0)
55 }
56}
57
58impl std::error::Error for ParseStepStatusError {}
59
60impl FromStr for StepStatus {
61 type Err = ParseStepStatusError;
62
63 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
64 match s {
65 "pending" => Ok(Self::Pending),
66 "running" => Ok(Self::Running),
67 "completed" => Ok(Self::Completed),
68 "failed" => Ok(Self::Failed),
69 "compensated" => Ok(Self::Compensated),
70 "skipped" => Ok(Self::Skipped),
71 "waiting" => Ok(Self::Waiting),
72 _ => Err(ParseStepStatusError(s.to_string())),
73 }
74 }
75}
76
77#[derive(Debug, Clone)]
79pub struct StepResult<T> {
80 pub name: String,
82 pub status: StepStatus,
84 pub value: Option<T>,
86 pub error: Option<String>,
88}
89
/// A named step, parameterised by a value type `T`.
///
/// `T` is tracked only at the type level: a `Step<T>` never stores a `T`.
/// The marker is therefore `PhantomData<fn() -> T>` rather than
/// `PhantomData<T>`. The plain form made `Step<T>` inherit `T`'s auto traits,
/// so e.g. `Step<Rc<_>>` was neither `Send` nor `Sync` even though it only
/// holds a `String`. Variance in `T` is unchanged (covariant).
pub struct Step<T> {
    /// Name of the step.
    pub name: String,
    _marker: std::marker::PhantomData<fn() -> T>,
}
97
98impl<T> Step<T> {
99 pub fn new(name: impl Into<String>) -> Self {
101 Self {
102 name: name.into(),
103 _marker: std::marker::PhantomData,
104 }
105 }
106}
107
108pub struct StepBuilder<'a, T, F, C>
110where
111 T: Serialize + DeserializeOwned + Send + 'static,
112 F: Future<Output = Result<T>> + Send + 'a,
113 C: Future<Output = Result<()>> + Send + 'a,
114{
115 name: String,
116 run_fn: Option<Pin<Box<dyn FnOnce() -> F + Send + 'a>>>,
117 compensate_fn: Option<CompensateFn<'a, T, C>>,
118 timeout: Option<Duration>,
119 retry_count: u32,
120 retry_delay: Duration,
121 optional: bool,
122 _marker: std::marker::PhantomData<(T, F, C)>,
123}
124
125impl<'a, T, F, C> StepBuilder<'a, T, F, C>
126where
127 T: Serialize + DeserializeOwned + Send + Clone + 'static,
128 F: Future<Output = Result<T>> + Send + 'a,
129 C: Future<Output = Result<()>> + Send + 'a,
130{
131 pub fn new(name: impl Into<String>) -> Self {
133 Self {
134 name: name.into(),
135 run_fn: None,
136 compensate_fn: None,
137 timeout: None,
138 retry_count: 0,
139 retry_delay: Duration::from_secs(1),
140 optional: false,
141 _marker: std::marker::PhantomData,
142 }
143 }
144
145 pub fn run<RF>(mut self, f: RF) -> Self
147 where
148 RF: FnOnce() -> F + Send + 'a,
149 {
150 self.run_fn = Some(Box::pin(f));
151 self
152 }
153
154 pub fn compensate<CF>(mut self, f: CF) -> Self
165 where
166 CF: Fn(T) -> Pin<Box<C>> + Send + Sync + 'a,
167 {
168 self.compensate_fn = Some(Arc::new(f));
169 self
170 }
171
172 pub fn timeout(mut self, duration: Duration) -> Self {
174 self.timeout = Some(duration);
175 self
176 }
177
178 pub fn retry(mut self, count: u32, delay: Duration) -> Self {
180 self.retry_count = count;
181 self.retry_delay = delay;
182 self
183 }
184
185 pub fn optional(mut self) -> Self {
187 self.optional = true;
188 self
189 }
190
191 pub fn name(&self) -> &str {
193 &self.name
194 }
195
196 pub fn is_optional(&self) -> bool {
198 self.optional
199 }
200
201 pub fn retry_count(&self) -> u32 {
203 self.retry_count
204 }
205
206 pub fn retry_delay(&self) -> Duration {
208 self.retry_delay
209 }
210
211 pub fn get_timeout(&self) -> Option<Duration> {
213 self.timeout
214 }
215}
216
/// Plain-data settings for a step.
///
/// Holds no closures, so it can derive `Clone` and `Debug`. Whether a
/// compensation exists is recorded only as a flag (`has_compensation`).
#[derive(Clone, Debug)]
pub struct StepConfig {
    /// Name of the step.
    pub name: String,
    /// Optional timeout.
    pub timeout: Option<Duration>,
    /// Number of retries.
    pub retry_count: u32,
    /// Delay setting used with `retry_count`.
    pub retry_delay: Duration,
    /// Whether the step is marked optional.
    pub optional: bool,
    /// Whether a compensation is configured.
    pub has_compensation: bool,
}
233
234impl Default for StepConfig {
235 fn default() -> Self {
236 Self {
237 name: String::new(),
238 timeout: None,
239 retry_count: 0,
240 retry_delay: Duration::from_secs(1),
241 optional: false,
242 has_compensation: false,
243 }
244 }
245}
246
#[cfg(test)]
// `unwrap()` and direct indexing are how a test fails loudly, so these clippy
// restriction lints are relaxed for the test module only.
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use super::*;

    /// Every `StepStatus` variant, used by the exhaustive tests below.
    const ALL_STATUSES: [StepStatus; 7] = [
        StepStatus::Pending,
        StepStatus::Running,
        StepStatus::Completed,
        StepStatus::Failed,
        StepStatus::Compensated,
        StepStatus::Skipped,
        StepStatus::Waiting,
    ];

    #[test]
    fn test_step_status_conversion() {
        assert_eq!(StepStatus::Pending.as_str(), "pending");
        assert_eq!(StepStatus::Running.as_str(), "running");
        assert_eq!(StepStatus::Completed.as_str(), "completed");
        assert_eq!(StepStatus::Failed.as_str(), "failed");
        assert_eq!(StepStatus::Compensated.as_str(), "compensated");

        assert_eq!("pending".parse::<StepStatus>(), Ok(StepStatus::Pending));
        assert_eq!("completed".parse::<StepStatus>(), Ok(StepStatus::Completed));
    }

    #[test]
    fn test_step_config_default() {
        let config = StepConfig::default();
        assert!(config.name.is_empty());
        assert!(config.timeout.is_none());
        assert_eq!(config.retry_count, 0);
        assert_eq!(config.retry_delay, Duration::from_secs(1));
        assert!(!config.optional);
        assert!(!config.has_compensation);
    }

    #[test]
    fn step_status_as_str_covers_all_variants() {
        assert_eq!(StepStatus::Pending.as_str(), "pending");
        assert_eq!(StepStatus::Running.as_str(), "running");
        assert_eq!(StepStatus::Completed.as_str(), "completed");
        assert_eq!(StepStatus::Failed.as_str(), "failed");
        assert_eq!(StepStatus::Compensated.as_str(), "compensated");
        assert_eq!(StepStatus::Skipped.as_str(), "skipped");
        assert_eq!(StepStatus::Waiting.as_str(), "waiting");
    }

    #[test]
    fn step_status_names_are_distinct() {
        let names: std::collections::HashSet<&str> =
            ALL_STATUSES.iter().map(StepStatus::as_str).collect();
        assert_eq!(names.len(), ALL_STATUSES.len());
    }

    #[test]
    fn step_status_parse_roundtrips_every_variant() {
        for status in ALL_STATUSES {
            let s = status.as_str();
            let parsed: StepStatus = s.parse().unwrap();
            assert_eq!(parsed, status, "{s} did not round-trip");
        }
    }

    #[test]
    fn step_status_parse_rejects_unknown() {
        let err = "garbage".parse::<StepStatus>().unwrap_err();
        assert_eq!(err.0, "garbage");
        assert_eq!(err.to_string(), "invalid step status: 'garbage'");
    }

    #[test]
    fn step_status_parse_is_exact_and_case_sensitive() {
        for input in ["Pending", "PENDING", " pending", "pending ", ""] {
            let err = input.parse::<StepStatus>().unwrap_err();
            assert_eq!(err, ParseStepStatusError(input.to_string()));
        }
    }

    #[test]
    fn parse_error_is_a_leaf_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(ParseStepStatusError("x".to_string()));
        assert!(std::error::Error::source(&*err).is_none());
    }

    #[test]
    fn step_constructor_records_name() {
        let s: Step<String> = Step::new("send_email");
        assert_eq!(s.name, "send_email");
    }

    type NoFut = Pin<Box<dyn Future<Output = Result<u32>> + Send + 'static>>;
    type NoComp = Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>;

    fn fresh_builder<'a>() -> StepBuilder<'a, u32, NoFut, NoComp> {
        StepBuilder::new("noop")
    }

    #[test]
    fn step_builder_defaults() {
        let b = fresh_builder();
        assert_eq!(b.name(), "noop");
        assert!(!b.is_optional());
        assert_eq!(b.retry_count(), 0);
        assert_eq!(b.retry_delay(), Duration::from_secs(1));
        assert!(b.get_timeout().is_none());
    }

    #[test]
    fn step_builder_defaults_match_step_config_default() {
        // Guards against the two independently hard-coded defaults drifting.
        let b = fresh_builder();
        let cfg = StepConfig::default();
        assert_eq!(b.retry_count(), cfg.retry_count);
        assert_eq!(b.retry_delay(), cfg.retry_delay);
        assert_eq!(b.get_timeout(), cfg.timeout);
        assert_eq!(b.is_optional(), cfg.optional);
    }

    #[test]
    fn step_builder_optional_flag_flips() {
        let b = fresh_builder().optional();
        assert!(b.is_optional());
    }

    #[test]
    fn step_builder_retry_sets_count_and_delay() {
        let b = fresh_builder().retry(3, Duration::from_millis(250));
        assert_eq!(b.retry_count(), 3);
        assert_eq!(b.retry_delay(), Duration::from_millis(250));
    }

    #[test]
    fn step_builder_timeout_setter() {
        let b = fresh_builder().timeout(Duration::from_secs(5));
        assert_eq!(b.get_timeout(), Some(Duration::from_secs(5)));
    }

    #[test]
    fn step_builder_setters_compose_and_last_call_wins() {
        let b = fresh_builder()
            .timeout(Duration::from_secs(2))
            .retry(5, Duration::from_millis(10))
            .optional()
            .retry(1, Duration::from_millis(20));
        assert_eq!(b.name(), "noop");
        assert!(b.is_optional());
        assert_eq!(b.get_timeout(), Some(Duration::from_secs(2)));
        assert_eq!(b.retry_count(), 1);
        assert_eq!(b.retry_delay(), Duration::from_millis(20));
    }
}