Skip to main content

punch_kernel/
workflow_loops.rs

//! Loop constructs for workflow steps.
//!
//! Supports `ForEach` (iterate over a JSON array), `While` (repeat while a
//! condition holds), and `Retry` (re-run a step with configurable backoff).
5
6use serde::{Deserialize, Serialize};
7
8use crate::workflow_conditions::Condition;
9
10/// A loop construct that can be attached to a workflow step.
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
12#[serde(rename_all = "snake_case")]
13pub enum LoopConfig {
14    /// Iterate over items from a JSON array in a previous step's output.
15    ForEach {
16        /// Name of the step whose output is a JSON array.
17        source_step: String,
18        /// Maximum number of iterations (safety limit).
19        max_iterations: usize,
20    },
21    /// Repeat while a condition is true.
22    While {
23        /// The condition to evaluate each iteration.
24        condition: Condition,
25        /// Maximum iterations (safety limit, prevents infinite loops).
26        max_iterations: usize,
27    },
28    /// Retry a step N times with configurable backoff.
29    Retry {
30        /// Maximum number of retry attempts.
31        max_retries: usize,
32        /// Initial backoff in milliseconds.
33        backoff_ms: u64,
34        /// Backoff multiplier (e.g. 2.0 for exponential backoff).
35        backoff_multiplier: f64,
36    },
37}
38
39/// Tracks the state of a loop during execution.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct LoopState {
42    /// Current iteration index (0-based).
43    pub index: usize,
44    /// Current item (for ForEach loops — the JSON string of the current element).
45    pub item: Option<String>,
46    /// Accumulated results from each iteration.
47    pub accumulated_results: Vec<String>,
48    /// Whether a break was requested.
49    pub should_break: bool,
50    /// Whether the current iteration should be skipped (continue).
51    pub should_continue: bool,
52}
53
54impl LoopState {
55    /// Create a new loop state.
56    pub fn new() -> Self {
57        Self {
58            index: 0,
59            item: None,
60            accumulated_results: Vec::new(),
61            should_break: false,
62            should_continue: false,
63        }
64    }
65
66    /// Request a break out of the loop.
67    pub fn request_break(&mut self) {
68        self.should_break = true;
69    }
70
71    /// Request a continue (skip rest of current iteration).
72    pub fn request_continue(&mut self) {
73        self.should_continue = true;
74    }
75
76    /// Advance to the next iteration, clearing per-iteration flags.
77    pub fn advance(&mut self) {
78        self.index += 1;
79        self.should_continue = false;
80    }
81
82    /// Record the result of the current iteration.
83    pub fn push_result(&mut self, result: String) {
84        self.accumulated_results.push(result);
85    }
86}
87
88impl Default for LoopState {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94/// Parse a JSON array string into a list of individual JSON value strings.
95///
96/// Returns an error string if parsing fails.
97pub fn parse_foreach_items(json_str: &str) -> Result<Vec<String>, String> {
98    let value: serde_json::Value =
99        serde_json::from_str(json_str).map_err(|e| format!("failed to parse JSON array: {e}"))?;
100
101    match value {
102        serde_json::Value::Array(arr) => Ok(arr
103            .into_iter()
104            .map(|v| match v {
105                serde_json::Value::String(s) => s,
106                other => other.to_string(),
107            })
108            .collect()),
109        _ => Err("expected a JSON array".to_string()),
110    }
111}
112
/// Calculate the backoff delay, in milliseconds, before retry attempt `attempt`.
///
/// The delay grows geometrically as `base_ms * multiplier^attempt`, so attempt `0`
/// waits exactly `base_ms`.
///
/// Edge cases:
/// - `attempt` values above `i32::MAX` are clamped to `i32::MAX` rather than
///   wrapped. A plain `as i32` cast would turn e.g. `2^31` into a negative
///   exponent and collapse the delay to zero.
/// - The final `f64 -> u64` conversion saturates. Results too large for `u64`
///   (including infinity) become `u64::MAX`; negative or NaN results (e.g. from
///   a negative `multiplier` with an odd `attempt`) become `0`.
pub fn calculate_backoff(attempt: usize, base_ms: u64, multiplier: f64) -> u64 {
    // `powi` takes an `i32`; clamp instead of truncating so very large attempt
    // counts keep the delay growing instead of wrapping to a negative exponent.
    let exponent = attempt.min(i32::MAX as usize) as i32;
    let factor = multiplier.powi(exponent);
    // Float-to-integer `as` casts saturate at the target bounds and map NaN to 0.
    (base_ms as f64 * factor) as u64
}
118
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
122
#[cfg(test)]
mod tests {
    use super::*;

    // --- LoopState ----------------------------------------------------------

    #[test]
    fn loop_state_new() {
        let state = LoopState::new();
        assert_eq!(state.index, 0);
        assert!(state.item.is_none());
        assert!(state.accumulated_results.is_empty());
        assert!(!state.should_break);
        assert!(!state.should_continue);
    }

    #[test]
    fn loop_state_default() {
        // `Default` must produce the same fresh state as `new`.
        let state = LoopState::default();
        assert_eq!(state.index, 0);
        assert!(state.item.is_none());
        assert!(state.accumulated_results.is_empty());
        assert!(!state.should_break);
        assert!(!state.should_continue);
    }

    #[test]
    fn loop_state_advance() {
        let mut state = LoopState::new();
        state.should_continue = true;
        state.advance();
        assert_eq!(state.index, 1);
        assert!(!state.should_continue);
        state.advance();
        assert_eq!(state.index, 2);
    }

    #[test]
    fn loop_state_break() {
        let mut state = LoopState::new();
        state.request_break();
        assert!(state.should_break);
    }

    #[test]
    fn loop_state_continue() {
        let mut state = LoopState::new();
        state.request_continue();
        assert!(state.should_continue);
    }

    #[test]
    fn loop_state_push_result() {
        let mut state = LoopState::new();
        state.push_result("result1".to_string());
        state.push_result("result2".to_string());
        assert_eq!(state.accumulated_results, vec!["result1", "result2"]);
    }

    #[test]
    fn loop_state_serialization() {
        let mut state = LoopState::new();
        state.index = 5;
        state.item = Some("test_item".to_string());
        state.push_result("r1".to_string());
        let json = serde_json::to_string(&state).expect("serialize");
        let deser: LoopState = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(deser.index, 5);
        assert_eq!(deser.item.as_deref(), Some("test_item"));
        assert_eq!(deser.accumulated_results.len(), 1);
    }

    // --- parse_foreach_items ------------------------------------------------

    #[test]
    fn parse_foreach_items_string_array() {
        let items = parse_foreach_items(r#"["a", "b", "c"]"#).expect("should parse");
        assert_eq!(items, vec!["a", "b", "c"]);
    }

    #[test]
    fn parse_foreach_items_number_array() {
        let items = parse_foreach_items("[1, 2, 3]").expect("should parse");
        assert_eq!(items, vec!["1", "2", "3"]);
    }

    #[test]
    fn parse_foreach_items_object_array() {
        let items =
            parse_foreach_items(r#"[{"name": "a"}, {"name": "b"}]"#).expect("should parse");
        assert_eq!(items.len(), 2);
        assert!(items[0].contains("name"));
    }

    #[test]
    fn parse_foreach_items_nested_values_are_compact_json() {
        // Non-string elements are re-serialized without insignificant whitespace.
        let items = parse_foreach_items(r#"[[1, 2], {"a": 1}]"#).expect("should parse");
        assert_eq!(items, vec!["[1,2]", r#"{"a":1}"#]);
    }

    #[test]
    fn parse_foreach_items_empty_array() {
        let items = parse_foreach_items("[]").expect("should parse");
        assert!(items.is_empty());
    }

    #[test]
    fn parse_foreach_items_not_array() {
        let err = parse_foreach_items(r#"{"key": "value"}"#).expect_err("object is not an array");
        assert!(err.contains("expected a JSON array"));
    }

    #[test]
    fn parse_foreach_items_invalid_json() {
        let err = parse_foreach_items("not json at all").expect_err("invalid JSON must fail");
        assert!(err.starts_with("failed to parse JSON array"));
    }

    #[test]
    fn parse_foreach_items_mixed_types() {
        let items = parse_foreach_items(r#"["hello", 42, true, null]"#).expect("should parse");
        assert_eq!(items, vec!["hello", "42", "true", "null"]);
    }

    // --- calculate_backoff --------------------------------------------------

    #[test]
    fn calculate_backoff_first_attempt() {
        assert_eq!(calculate_backoff(0, 100, 2.0), 100);
    }

    #[test]
    fn calculate_backoff_exponential() {
        assert_eq!(calculate_backoff(1, 100, 2.0), 200);
        assert_eq!(calculate_backoff(2, 100, 2.0), 400);
        assert_eq!(calculate_backoff(3, 100, 2.0), 800);
    }

    #[test]
    fn calculate_backoff_no_multiplier() {
        assert_eq!(calculate_backoff(0, 500, 1.0), 500);
        assert_eq!(calculate_backoff(1, 500, 1.0), 500);
        assert_eq!(calculate_backoff(5, 500, 1.0), 500);
    }

    #[test]
    fn calculate_backoff_fractional_multiplier() {
        // Powers of 1.5 and 0.5 are exact in f64, so these results are exact too.
        assert_eq!(calculate_backoff(1, 100, 1.5), 150);
        assert_eq!(calculate_backoff(2, 100, 1.5), 225);
        assert_eq!(calculate_backoff(1, 100, 0.5), 50);
    }

    // --- LoopConfig ---------------------------------------------------------

    #[test]
    fn loop_config_foreach_serialization() {
        let config = LoopConfig::ForEach {
            source_step: "step1".to_string(),
            max_iterations: 100,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: LoopConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(config, deser);
    }

    #[test]
    fn loop_config_while_serialization() {
        let config = LoopConfig::While {
            condition: Condition::Always,
            max_iterations: 10,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: LoopConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(config, deser);
    }

    #[test]
    fn loop_config_retry_serialization() {
        let config = LoopConfig::Retry {
            max_retries: 3,
            backoff_ms: 100,
            backoff_multiplier: 2.0,
        };
        let json = serde_json::to_string(&config).expect("serialize");
        let deser: LoopConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(config, deser);
    }

    #[test]
    fn loop_config_variant_tags_are_snake_case() {
        // Pins the on-disk/wire format that workflow definitions are written in.
        let for_each = serde_json::to_value(LoopConfig::ForEach {
            source_step: "step1".to_string(),
            max_iterations: 100,
        })
        .expect("serialize");
        assert_eq!(
            for_each,
            serde_json::json!({ "for_each": { "source_step": "step1", "max_iterations": 100 } })
        );

        let while_loop = serde_json::to_value(LoopConfig::While {
            condition: Condition::Always,
            max_iterations: 10,
        })
        .expect("serialize");
        assert!(while_loop.get("while").is_some());

        let retry = serde_json::to_value(LoopConfig::Retry {
            max_retries: 3,
            backoff_ms: 100,
            backoff_multiplier: 2.0,
        })
        .expect("serialize");
        assert!(retry.get("retry").is_some());
    }
}
290}