// enact_core/flow/repeat.rs

//! Loop Flow - Iteration with exit condition
//!
//! Repeatedly execute a callable until an exit condition is met.

5use crate::callable::Callable;
6use std::sync::Arc;
7
/// Condition that decides when a [`LoopFlow`] stops iterating.
///
/// The condition is evaluated after every run of the callable, with the
/// zero-based index of that run and the output it produced (see
/// [`LoopCondition::should_exit`]).
pub enum LoopCondition {
    /// Exit once the zero-based iteration index reaches this value.
    ///
    /// Because the check happens after each run and the first run has index 0,
    /// `MaxIterations(n)` lets the callable run `n + 1` times inside a
    /// [`LoopFlow`].
    MaxIterations(usize),
    /// Exit when the predicate returns `true` for the latest output.
    OutputMatches(Box<dyn Fn(&str) -> bool + Send + Sync>),
    /// Exit when the latest output contains this substring.
    OutputContains(String),
    /// Exit when either the iteration limit is reached or the predicate matches.
    Either {
        /// Same semantics as [`LoopCondition::MaxIterations`].
        max_iterations: usize,
        /// Same semantics as [`LoopCondition::OutputMatches`].
        predicate: Box<dyn Fn(&str) -> bool + Send + Sync>,
    },
}

impl std::fmt::Debug for LoopCondition {
    /// Formats the condition like a derived `Debug` would.
    ///
    /// Boxed closures are opaque, so they are rendered as the placeholder `<fn>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LoopCondition::MaxIterations(max) => f.debug_tuple("MaxIterations").field(max).finish(),
            LoopCondition::OutputMatches(_) => f
                .debug_tuple("OutputMatches")
                .field(&format_args!("<fn>"))
                .finish(),
            LoopCondition::OutputContains(needle) => {
                f.debug_tuple("OutputContains").field(needle).finish()
            }
            LoopCondition::Either { max_iterations, .. } => f
                .debug_struct("Either")
                .field("max_iterations", max_iterations)
                .field("predicate", &format_args!("<fn>"))
                .finish(),
        }
    }
}
22
23impl LoopCondition {
24    /// Check if loop should exit
25    pub fn should_exit(&self, iteration: usize, output: &str) -> bool {
26        match self {
27            LoopCondition::MaxIterations(max) => iteration >= *max,
28            LoopCondition::OutputMatches(pred) => pred(output),
29            LoopCondition::OutputContains(needle) => output.contains(needle),
30            LoopCondition::Either {
31                max_iterations,
32                predicate,
33            } => iteration >= *max_iterations || predicate(output),
34        }
35    }
36
37    /// Create a max iterations condition
38    pub fn max(n: usize) -> Self {
39        LoopCondition::MaxIterations(n)
40    }
41
42    /// Create an output contains condition
43    pub fn until_contains(s: impl Into<String>) -> Self {
44        LoopCondition::OutputContains(s.into())
45    }
46
47    /// Create a predicate condition
48    pub fn until(pred: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
49        LoopCondition::OutputMatches(Box::new(pred))
50    }
51
52    /// Create a combined condition
53    pub fn max_or_until(max: usize, pred: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
54        LoopCondition::Either {
55            max_iterations: max,
56            predicate: Box::new(pred),
57        }
58    }
59}
60
61/// Loop execution flow
62pub struct LoopFlow<C: Callable> {
63    /// The callable to execute repeatedly
64    callable: Arc<C>,
65    /// Exit condition
66    condition: LoopCondition,
67    /// Flow name
68    name: String,
69    /// Whether to pass previous output as input (feedback loop)
70    feedback: bool,
71}
72
73impl<C: Callable> LoopFlow<C> {
74    /// Create a new loop flow
75    pub fn new(name: impl Into<String>, callable: Arc<C>, condition: LoopCondition) -> Self {
76        Self {
77            callable,
78            condition,
79            name: name.into(),
80            feedback: true, // Default: use output as next input
81        }
82    }
83
84    /// Create a loop that runs N times
85    pub fn times(name: impl Into<String>, n: usize, callable: Arc<C>) -> Self {
86        Self::new(name, callable, LoopCondition::MaxIterations(n))
87    }
88
89    /// Create a loop that runs until output contains string
90    pub fn until_contains(name: impl Into<String>, s: impl Into<String>, callable: Arc<C>) -> Self {
91        Self::new(name, callable, LoopCondition::OutputContains(s.into()))
92    }
93
94    /// Set whether to use output as next input
95    pub fn with_feedback(mut self, feedback: bool) -> Self {
96        self.feedback = feedback;
97        self
98    }
99
100    /// Execute the loop
101    pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
102        let mut current_input = input.to_string();
103        let mut iteration = 0;
104
105        loop {
106            let output = self.callable.run(&current_input).await?;
107
108            if self.condition.should_exit(iteration, &output) {
109                return Ok(output);
110            }
111
112            // Prepare next iteration
113            if self.feedback {
114                current_input = output;
115            }
116            iteration += 1;
117        }
118    }
119
120    /// Execute with iteration tracking
121    pub async fn execute_with_history(&self, input: &str) -> anyhow::Result<LoopHistory> {
122        let mut current_input = input.to_string();
123        let mut iteration = 0;
124        let mut outputs = Vec::new();
125
126        loop {
127            let output = self.callable.run(&current_input).await?;
128            outputs.push(output.clone());
129
130            if self.condition.should_exit(iteration, &output) {
131                return Ok(LoopHistory {
132                    iterations: iteration + 1,
133                    outputs,
134                    final_output: output,
135                });
136            }
137
138            if self.feedback {
139                current_input = output;
140            }
141            iteration += 1;
142        }
143    }
144
145    /// Get the flow name
146    pub fn name(&self) -> &str {
147        &self.name
148    }
149}
150
/// Record of a single `LoopFlow::execute_with_history` run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopHistory {
    /// Number of times the callable ran.
    ///
    /// When produced by `execute_with_history`, this equals `outputs.len()`.
    pub iterations: usize,
    /// Output of every iteration, in execution order.
    pub outputs: Vec<String>,
    /// Output of the last iteration, the one that satisfied the exit condition.
    pub final_output: String,
}
161
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Mock callable that counts its calls and builds each output from the
    /// input and the zero-based call index.
    // The boxed transform type appears only here, so a type alias would not
    // make it any clearer.
    #[allow(clippy::type_complexity)]
    struct MockCallable {
        name: String,
        call_count: Arc<AtomicUsize>,
        transform: Box<dyn Fn(&str, usize) -> String + Send + Sync>,
    }

    impl MockCallable {
        fn new(
            name: &str,
            transform: impl Fn(&str, usize) -> String + Send + Sync + 'static,
        ) -> Self {
            Self {
                name: name.to_string(),
                call_count: Arc::new(AtomicUsize::new(0)),
                transform: Box::new(transform),
            }
        }

        /// Appends the zero-based call index to the input (`"in"` -> `"in:0"`).
        fn incrementing(name: &str) -> Self {
            Self::new(name, |input, n| format!("{}:{}", input, n))
        }

        /// Emits `"DONE"` from the `n`-th call (1-based) onwards.
        ///
        /// Earlier calls append the call index to the input. `n == 0` yields
        /// `"DONE"` immediately.
        fn done_on_call(name: &str, n: usize) -> Self {
            Self::new(name, move |input, call| {
                // `call + 1` is the 1-based call number. Comparing this way
                // avoids the `n - 1` underflow panic when `n == 0`.
                if call + 1 >= n {
                    "DONE".to_string()
                } else {
                    format!("{}:{}", input, call)
                }
            })
        }

        fn get_call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok((self.transform)(input, n))
        }
    }

    // ============ LoopCondition Tests ============

    #[test]
    fn test_condition_max_iterations() {
        let cond = LoopCondition::MaxIterations(3);
        assert!(!cond.should_exit(0, "any"));
        assert!(!cond.should_exit(1, "any"));
        assert!(!cond.should_exit(2, "any"));
        assert!(cond.should_exit(3, "any")); // Exit at iteration 3
        assert!(cond.should_exit(5, "any")); // Also exit if past
    }

    #[test]
    fn test_condition_output_matches() {
        let cond = LoopCondition::OutputMatches(Box::new(|s| s.len() > 5));
        assert!(!cond.should_exit(0, "hi"));
        assert!(!cond.should_exit(10, "short"));
        assert!(cond.should_exit(0, "longer"));
        assert!(cond.should_exit(0, "this is long enough"));
    }

    #[test]
    fn test_condition_output_contains() {
        let cond = LoopCondition::OutputContains("DONE".to_string());
        assert!(!cond.should_exit(0, "not yet"));
        assert!(!cond.should_exit(5, "still working"));
        assert!(cond.should_exit(0, "DONE"));
        assert!(cond.should_exit(0, "task DONE here"));
    }

    #[test]
    fn test_condition_either() {
        let cond = LoopCondition::Either {
            max_iterations: 5,
            predicate: Box::new(|s| s.contains("STOP")),
        };

        // Not exit before max or predicate
        assert!(!cond.should_exit(0, "working"));
        assert!(!cond.should_exit(4, "still going"));

        // Exit at max iterations
        assert!(cond.should_exit(5, "working"));

        // Exit when predicate matches, even before max
        assert!(cond.should_exit(2, "STOP now"));
    }

    #[test]
    fn test_condition_helpers() {
        // max()
        let cond = LoopCondition::max(2);
        assert!(!cond.should_exit(1, "x"));
        assert!(cond.should_exit(2, "x"));

        // until_contains()
        let cond = LoopCondition::until_contains("END");
        assert!(!cond.should_exit(0, "not"));
        assert!(cond.should_exit(0, "END"));

        // until()
        let cond = LoopCondition::until(|s| s == "target");
        assert!(!cond.should_exit(0, "other"));
        assert!(cond.should_exit(0, "target"));

        // max_or_until()
        let cond = LoopCondition::max_or_until(3, |s| s.starts_with("!"));
        assert!(!cond.should_exit(0, "a"));
        assert!(cond.should_exit(3, "a")); // max hit
        assert!(cond.should_exit(0, "!bang")); // predicate hit
    }

    // ============ LoopFlow Construction Tests ============
    // Construction never awaits, so these are plain synchronous tests.

    #[test]
    fn test_loop_flow_new() {
        let callable = Arc::new(MockCallable::incrementing("inc"));
        let flow = LoopFlow::new("test_loop", callable, LoopCondition::MaxIterations(2));
        assert_eq!(flow.name(), "test_loop");
    }

    #[test]
    fn test_loop_flow_times() {
        let callable = Arc::new(MockCallable::incrementing("inc"));
        let flow = LoopFlow::times("timer", 3, callable);
        assert_eq!(flow.name(), "timer");
    }

    #[test]
    fn test_loop_flow_until_contains() {
        let callable = Arc::new(MockCallable::done_on_call("done", 2));
        let flow = LoopFlow::until_contains("stopper", "DONE", callable);
        assert_eq!(flow.name(), "stopper");
    }

    // ============ LoopFlow Execute Tests ============

    #[tokio::test]
    async fn test_loop_execute_max_iterations() {
        let callable = Arc::new(MockCallable::incrementing("inc"));
        let flow = LoopFlow::times("loop", 3, callable.clone());

        let result = flow.execute("start").await.unwrap();
        // With MaxIterations(3), exit condition is iteration >= 3
        // iteration 0: run, check 0>=3=false, inc to 1
        // iteration 1: run, check 1>=3=false, inc to 2
        // iteration 2: run, check 2>=3=false, inc to 3
        // iteration 3: run, check 3>=3=true, exit
        // So 4 total calls
        assert_eq!(callable.get_call_count(), 4);
        assert!(result.contains("start"));
    }

    #[tokio::test]
    async fn test_loop_execute_until_contains() {
        let callable = Arc::new(MockCallable::done_on_call("done", 3));
        let flow = LoopFlow::until_contains("wait_done", "DONE", callable.clone());

        let result = flow.execute("input").await.unwrap();
        assert_eq!(result, "DONE");
        assert_eq!(callable.get_call_count(), 3);
    }

    #[tokio::test]
    async fn test_loop_execute_with_predicate() {
        let callable = Arc::new(MockCallable::new("counter", |_, n| format!("count:{}", n)));
        let flow = LoopFlow::new(
            "until_five",
            callable.clone(),
            LoopCondition::until(|s| s == "count:5"),
        );

        let result = flow.execute("x").await.unwrap();
        assert_eq!(result, "count:5");
        assert_eq!(callable.get_call_count(), 6); // 0,1,2,3,4,5
    }

    #[tokio::test]
    async fn test_loop_execute_either_max_first() {
        let callable = Arc::new(MockCallable::new("counter", |_, n| format!("v{}", n)));
        let flow = LoopFlow::new(
            "either",
            callable.clone(),
            LoopCondition::max_or_until(3, |s| s == "never"),
        );

        let result = flow.execute("x").await.unwrap();
        // Exit at iteration >= 3, so 4 calls total
        assert_eq!(callable.get_call_count(), 4);
        assert_eq!(result, "v3");
    }

    #[tokio::test]
    async fn test_loop_execute_either_predicate_first() {
        let callable = Arc::new(MockCallable::done_on_call("done", 2));
        let flow = LoopFlow::new(
            "either",
            callable.clone(),
            LoopCondition::max_or_until(10, |s| s == "DONE"),
        );

        let result = flow.execute("x").await.unwrap();
        assert_eq!(result, "DONE");
        assert_eq!(callable.get_call_count(), 2); // Exit before max
    }

    // ============ Feedback Mode Tests ============

    #[tokio::test]
    async fn test_loop_with_feedback_enabled() {
        // Track inputs to verify feedback
        let inputs: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let inputs_clone = inputs.clone();

        let callable = Arc::new(MockCallable::new("fb", move |input, n| {
            inputs_clone.lock().unwrap().push(input.to_string());
            format!("out{}", n)
        }));

        let flow = LoopFlow::times("feedback_on", 3, callable).with_feedback(true);
        flow.execute("start").await.unwrap();

        let recorded = inputs.lock().unwrap().clone();
        // MaxIterations(3) runs 4 times (exits when iteration >= 3)
        assert_eq!(recorded, vec!["start", "out0", "out1", "out2"]);
    }

    #[tokio::test]
    async fn test_loop_with_feedback_disabled() {
        let inputs: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let inputs_clone = inputs.clone();

        let callable = Arc::new(MockCallable::new("no_fb", move |input, n| {
            inputs_clone.lock().unwrap().push(input.to_string());
            format!("out{}", n)
        }));

        let flow = LoopFlow::times("feedback_off", 3, callable).with_feedback(false);
        flow.execute("same").await.unwrap();

        let recorded = inputs.lock().unwrap().clone();
        // Without feedback, input stays the same; MaxIterations(3) runs 4 times
        assert_eq!(recorded, vec!["same", "same", "same", "same"]);
    }

    // ============ Execute with History Tests ============

    #[tokio::test]
    async fn test_loop_execute_with_history() {
        let callable = Arc::new(MockCallable::new("hist", |_, n| format!("iter{}", n)));
        let flow = LoopFlow::times("history_test", 4, callable);

        let history = flow.execute_with_history("start").await.unwrap();

        // MaxIterations(4) exits when iteration >= 4, so runs 5 times
        assert_eq!(history.iterations, 5);
        assert_eq!(history.outputs.len(), 5);
        assert_eq!(
            history.outputs,
            vec!["iter0", "iter1", "iter2", "iter3", "iter4"]
        );
        assert_eq!(history.final_output, "iter4");
    }

    #[tokio::test]
    async fn test_loop_execute_with_history_early_exit() {
        let callable = Arc::new(MockCallable::done_on_call("early", 2));
        let flow = LoopFlow::until_contains("early_exit", "DONE", callable);

        let history = flow.execute_with_history("x").await.unwrap();

        assert_eq!(history.iterations, 2);
        assert_eq!(history.outputs.len(), 2);
        assert_eq!(history.final_output, "DONE");
    }

    // ============ Error Handling Tests ============

    #[tokio::test]
    async fn test_loop_error_propagation() {
        struct FailingCallable {
            fail_on: usize,
            call_count: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl Callable for FailingCallable {
            fn name(&self) -> &str {
                "failing"
            }

            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n >= self.fail_on {
                    anyhow::bail!("Intentional failure at iteration {}", n)
                }
                Ok(format!("ok{}", n))
            }
        }

        let call_count = Arc::new(AtomicUsize::new(0));
        let callable = Arc::new(FailingCallable {
            fail_on: 2,
            call_count: Arc::clone(&call_count),
        });

        let flow = LoopFlow::times("fail_loop", 5, callable);
        let result = flow.execute("start").await;

        let err = result.expect_err("the third call fails, so the loop must error");
        assert!(err.to_string().contains("Intentional failure"));
        // Calls 0 and 1 succeed, call 2 fails, and nothing runs after the error.
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // ============ Edge Cases ============

    #[tokio::test]
    async fn test_loop_zero_iterations() {
        let callable = Arc::new(MockCallable::incrementing("zero"));
        let flow = LoopFlow::times("zero_loop", 0, callable.clone());

        // With max 0, should_exit returns true on iteration 0
        // So we run once and immediately exit
        let result = flow.execute("input").await.unwrap();
        assert_eq!(callable.get_call_count(), 1);
        assert!(result.contains("input"));
    }

    #[tokio::test]
    async fn test_loop_immediate_exit_predicate() {
        let callable = Arc::new(MockCallable::new("imm", |_, _| "STOP".to_string()));
        let flow = LoopFlow::new(
            "immediate",
            callable.clone(),
            LoopCondition::until_contains("STOP"),
        );

        let result = flow.execute("x").await.unwrap();
        assert_eq!(result, "STOP");
        assert_eq!(callable.get_call_count(), 1);
    }

    #[tokio::test]
    async fn test_loop_single_iteration() {
        let callable = Arc::new(MockCallable::incrementing("single"));
        let flow = LoopFlow::times("one", 1, callable.clone());

        let history = flow.execute_with_history("x").await.unwrap();
        // MaxIterations(1) exits when iteration >= 1, so runs 2 times
        assert_eq!(history.iterations, 2);
        assert_eq!(callable.get_call_count(), 2);
    }

    #[tokio::test]
    async fn test_done_on_call_zero_is_immediate() {
        // done_on_call(_, 0) used to underflow on `n - 1`; it now emits DONE at once.
        let callable = Arc::new(MockCallable::done_on_call("zero_done", 0));
        let flow = LoopFlow::until_contains("zero_done", "DONE", callable.clone());

        assert_eq!(flow.execute("x").await.unwrap(), "DONE");
        assert_eq!(callable.get_call_count(), 1);
    }
}