1use crate::callable::Callable;
6use std::sync::Arc;
7
/// Exit policy for a [`LoopFlow`], evaluated after every iteration.
///
/// The iteration index passed to [`LoopCondition::should_exit`] is zero-based, so
/// `MaxIterations(n)` stops once the index reaches `n`. When driven by `LoopFlow`, that means
/// the callable runs `n + 1` times.
pub enum LoopCondition {
    /// Stop once the zero-based iteration index is `>=` this value.
    MaxIterations(usize),
    /// Stop on the first output for which the predicate returns `true`.
    OutputMatches(Box<dyn Fn(&str) -> bool + Send + Sync>),
    /// Stop on the first output that contains this substring.
    OutputContains(String),
    /// Stop when either the iteration limit is reached or the predicate matches.
    Either {
        /// Zero-based iteration index at which the loop stops unconditionally.
        max_iterations: usize,
        /// Output predicate; consulted only while the limit has not been reached.
        predicate: Box<dyn Fn(&str) -> bool + Send + Sync>,
    },
}

// `Debug` cannot be derived because the boxed closures are not `Debug`; predicates are
// rendered as an opaque `<predicate>` placeholder instead.
impl std::fmt::Debug for LoopCondition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            LoopCondition::MaxIterations(max) => {
                f.debug_tuple("MaxIterations").field(max).finish()
            }
            LoopCondition::OutputMatches(_) => f
                .debug_tuple("OutputMatches")
                .field(&format_args!("<predicate>"))
                .finish(),
            LoopCondition::OutputContains(needle) => {
                f.debug_tuple("OutputContains").field(needle).finish()
            }
            LoopCondition::Either { max_iterations, .. } => f
                .debug_struct("Either")
                .field("max_iterations", max_iterations)
                .field("predicate", &format_args!("<predicate>"))
                .finish(),
        }
    }
}
22
23impl LoopCondition {
24 pub fn should_exit(&self, iteration: usize, output: &str) -> bool {
26 match self {
27 LoopCondition::MaxIterations(max) => iteration >= *max,
28 LoopCondition::OutputMatches(pred) => pred(output),
29 LoopCondition::OutputContains(needle) => output.contains(needle),
30 LoopCondition::Either {
31 max_iterations,
32 predicate,
33 } => iteration >= *max_iterations || predicate(output),
34 }
35 }
36
37 pub fn max(n: usize) -> Self {
39 LoopCondition::MaxIterations(n)
40 }
41
42 pub fn until_contains(s: impl Into<String>) -> Self {
44 LoopCondition::OutputContains(s.into())
45 }
46
47 pub fn until(pred: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
49 LoopCondition::OutputMatches(Box::new(pred))
50 }
51
52 pub fn max_or_until(max: usize, pred: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
54 LoopCondition::Either {
55 max_iterations: max,
56 predicate: Box::new(pred),
57 }
58 }
59}
60
61pub struct LoopFlow<C: Callable> {
63 callable: Arc<C>,
65 condition: LoopCondition,
67 name: String,
69 feedback: bool,
71}
72
73impl<C: Callable> LoopFlow<C> {
74 pub fn new(name: impl Into<String>, callable: Arc<C>, condition: LoopCondition) -> Self {
76 Self {
77 callable,
78 condition,
79 name: name.into(),
80 feedback: true, }
82 }
83
84 pub fn times(name: impl Into<String>, n: usize, callable: Arc<C>) -> Self {
86 Self::new(name, callable, LoopCondition::MaxIterations(n))
87 }
88
89 pub fn until_contains(name: impl Into<String>, s: impl Into<String>, callable: Arc<C>) -> Self {
91 Self::new(name, callable, LoopCondition::OutputContains(s.into()))
92 }
93
94 pub fn with_feedback(mut self, feedback: bool) -> Self {
96 self.feedback = feedback;
97 self
98 }
99
100 pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
102 let mut current_input = input.to_string();
103 let mut iteration = 0;
104
105 loop {
106 let output = self.callable.run(¤t_input).await?;
107
108 if self.condition.should_exit(iteration, &output) {
109 return Ok(output);
110 }
111
112 if self.feedback {
114 current_input = output;
115 }
116 iteration += 1;
117 }
118 }
119
120 pub async fn execute_with_history(&self, input: &str) -> anyhow::Result<LoopHistory> {
122 let mut current_input = input.to_string();
123 let mut iteration = 0;
124 let mut outputs = Vec::new();
125
126 loop {
127 let output = self.callable.run(¤t_input).await?;
128 outputs.push(output.clone());
129
130 if self.condition.should_exit(iteration, &output) {
131 return Ok(LoopHistory {
132 iterations: iteration + 1,
133 outputs,
134 final_output: output,
135 });
136 }
137
138 if self.feedback {
139 current_input = output;
140 }
141 iteration += 1;
142 }
143 }
144
145 pub fn name(&self) -> &str {
147 &self.name
148 }
149}
150
/// Record of a completed loop run, as returned by `LoopFlow::execute_with_history`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopHistory {
    /// Number of times the callable was invoked (equal to `outputs.len()`).
    pub iterations: usize,
    /// Every output in the order it was produced, including the final one.
    pub outputs: Vec<String>,
    /// The output that satisfied the exit condition (the last entry of `outputs`).
    pub final_output: String,
}
161
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose output is `transform(input, call_index)`.
    ///
    /// `call_count` is incremented on every `run`, so tests can assert exactly how many
    /// times the loop invoked the callable.
    // The boxed two-argument closure type trips clippy's type-complexity lint; a type alias
    // would not make this small test-only struct any clearer.
    #[allow(clippy::type_complexity)]
    struct MockCallable {
        name: String,
        call_count: Arc<AtomicUsize>,
        transform: Box<dyn Fn(&str, usize) -> String + Send + Sync>,
    }

    impl MockCallable {
        fn new(
            name: &str,
            transform: impl Fn(&str, usize) -> String + Send + Sync + 'static,
        ) -> Self {
            Self {
                name: name.to_string(),
                call_count: Arc::new(AtomicUsize::new(0)),
                transform: Box::new(transform),
            }
        }

        /// Returns `"{input}:{call_index}"`, so fed-back outputs accumulate a visible trail.
        fn incrementing(name: &str) -> Self {
            Self::new(name, |input, n| format!("{}:{}", input, n))
        }

        /// Returns `"DONE"` from the `n`-th call onward (1-based); earlier calls behave like
        /// `incrementing`.
        fn done_on_call(name: &str, n: usize) -> Self {
            Self::new(name, move |input, call| {
                // `call` is zero-based. `call + 1 >= n` is equivalent to `call >= n - 1` for
                // n >= 1 but cannot underflow when n == 0.
                if call + 1 >= n {
                    "DONE".to_string()
                } else {
                    format!("{}:{}", input, call)
                }
            })
        }

        fn get_call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl Callable for MockCallable {
        fn name(&self) -> &str {
            &self.name
        }

        async fn run(&self, input: &str) -> anyhow::Result<String> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok((self.transform)(input, n))
        }
    }

    #[test]
    fn test_condition_max_iterations() {
        // Exits once the zero-based iteration index reaches the limit.
        let cond = LoopCondition::MaxIterations(3);
        assert!(!cond.should_exit(0, "any"));
        assert!(!cond.should_exit(1, "any"));
        assert!(!cond.should_exit(2, "any"));
        assert!(cond.should_exit(3, "any"));
        assert!(cond.should_exit(5, "any"));
    }

    #[test]
    fn test_condition_output_matches() {
        let cond = LoopCondition::OutputMatches(Box::new(|s| s.len() > 5));
        assert!(!cond.should_exit(0, "hi"));
        assert!(!cond.should_exit(10, "short"));
        assert!(cond.should_exit(0, "longer"));
        assert!(cond.should_exit(0, "this is long enough"));
    }

    #[test]
    fn test_condition_output_contains() {
        let cond = LoopCondition::OutputContains("DONE".to_string());
        assert!(!cond.should_exit(0, "not yet"));
        assert!(!cond.should_exit(5, "still working"));
        assert!(cond.should_exit(0, "DONE"));
        assert!(cond.should_exit(0, "task DONE here"));
    }

    #[test]
    fn test_condition_either() {
        let cond = LoopCondition::Either {
            max_iterations: 5,
            predicate: Box::new(|s| s.contains("STOP")),
        };

        // Neither the limit nor the predicate is satisfied.
        assert!(!cond.should_exit(0, "working"));
        assert!(!cond.should_exit(4, "still going"));

        // The iteration limit alone triggers an exit.
        assert!(cond.should_exit(5, "working"));

        // The predicate alone triggers an exit before the limit.
        assert!(cond.should_exit(2, "STOP now"));
    }

    #[test]
    fn test_condition_helpers() {
        let cond = LoopCondition::max(2);
        assert!(!cond.should_exit(1, "x"));
        assert!(cond.should_exit(2, "x"));

        let cond = LoopCondition::until_contains("END");
        assert!(!cond.should_exit(0, "not"));
        assert!(cond.should_exit(0, "END"));

        let cond = LoopCondition::until(|s| s == "target");
        assert!(!cond.should_exit(0, "other"));
        assert!(cond.should_exit(0, "target"));

        let cond = LoopCondition::max_or_until(3, |s| s.starts_with("!"));
        assert!(!cond.should_exit(0, "a"));
        assert!(cond.should_exit(3, "a"));
        assert!(cond.should_exit(0, "!bang"));
    }

    #[tokio::test]
    async fn test_loop_flow_new() {
        let callable = Arc::new(MockCallable::incrementing("inc"));
        let flow = LoopFlow::new("test_loop", callable, LoopCondition::MaxIterations(2));
        assert_eq!(flow.name(), "test_loop");
    }

    #[tokio::test]
    async fn test_loop_flow_times() {
        let callable = Arc::new(MockCallable::incrementing("inc"));
        let flow = LoopFlow::times("timer", 3, callable);
        assert_eq!(flow.name(), "timer");
    }

    #[tokio::test]
    async fn test_loop_flow_until_contains() {
        let callable = Arc::new(MockCallable::done_on_call("done", 2));
        let flow = LoopFlow::until_contains("stopper", "DONE", callable);
        assert_eq!(flow.name(), "stopper");
    }

    #[tokio::test]
    async fn test_loop_execute_max_iterations() {
        let callable = Arc::new(MockCallable::incrementing("inc"));
        let flow = LoopFlow::times("loop", 3, callable.clone());

        let result = flow.execute("start").await.unwrap();
        // `times(3)` stops when the zero-based index reaches 3: indices 0..=3, four calls.
        assert_eq!(callable.get_call_count(), 4);
        // Feedback is on by default, so the original input remains a prefix of the result.
        assert!(result.contains("start"));
    }

    #[tokio::test]
    async fn test_loop_execute_until_contains() {
        let callable = Arc::new(MockCallable::done_on_call("done", 3));
        let flow = LoopFlow::until_contains("wait_done", "DONE", callable.clone());

        let result = flow.execute("input").await.unwrap();
        assert_eq!(result, "DONE");
        assert_eq!(callable.get_call_count(), 3);
    }

    #[tokio::test]
    async fn test_loop_execute_with_predicate() {
        let callable = Arc::new(MockCallable::new("counter", |_, n| format!("count:{}", n)));
        let flow = LoopFlow::new(
            "until_five",
            callable.clone(),
            LoopCondition::until(|s| s == "count:5"),
        );

        let result = flow.execute("x").await.unwrap();
        assert_eq!(result, "count:5");
        // Calls 0..=5 were needed to produce "count:5".
        assert_eq!(callable.get_call_count(), 6);
    }

    #[tokio::test]
    async fn test_loop_execute_either_max_first() {
        let callable = Arc::new(MockCallable::new("counter", |_, n| format!("v{}", n)));
        let flow = LoopFlow::new(
            "either",
            callable.clone(),
            LoopCondition::max_or_until(3, |s| s == "never"),
        );

        let result = flow.execute("x").await.unwrap();
        // The predicate never matches, so the limit of 3 ends the loop after four calls.
        assert_eq!(callable.get_call_count(), 4);
        assert_eq!(result, "v3");
    }

    #[tokio::test]
    async fn test_loop_execute_either_predicate_first() {
        let callable = Arc::new(MockCallable::done_on_call("done", 2));
        let flow = LoopFlow::new(
            "either",
            callable.clone(),
            LoopCondition::max_or_until(10, |s| s == "DONE"),
        );

        let result = flow.execute("x").await.unwrap();
        assert_eq!(result, "DONE");
        // The predicate matches on the second call, well before the limit.
        assert_eq!(callable.get_call_count(), 2);
    }

    #[tokio::test]
    async fn test_loop_with_feedback_enabled() {
        // Records every input the callable receives.
        let inputs: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let inputs_clone = inputs.clone();

        let callable = Arc::new(MockCallable::new("fb", move |input, n| {
            inputs_clone.lock().unwrap().push(input.to_string());
            format!("out{}", n)
        }));

        let flow = LoopFlow::times("feedback_on", 3, callable).with_feedback(true);
        flow.execute("start").await.unwrap();

        let recorded = inputs.lock().unwrap().clone();
        // Each call after the first receives the previous call's output.
        assert_eq!(recorded, vec!["start", "out0", "out1", "out2"]);
    }

    #[tokio::test]
    async fn test_loop_with_feedback_disabled() {
        let inputs: Arc<std::sync::Mutex<Vec<String>>> =
            Arc::new(std::sync::Mutex::new(Vec::new()));
        let inputs_clone = inputs.clone();

        let callable = Arc::new(MockCallable::new("no_fb", move |input, n| {
            inputs_clone.lock().unwrap().push(input.to_string());
            format!("out{}", n)
        }));

        let flow = LoopFlow::times("feedback_off", 3, callable).with_feedback(false);
        flow.execute("same").await.unwrap();

        let recorded = inputs.lock().unwrap().clone();
        // Without feedback, every call receives the original input.
        assert_eq!(recorded, vec!["same", "same", "same", "same"]);
    }

    #[tokio::test]
    async fn test_loop_execute_with_history() {
        let callable = Arc::new(MockCallable::new("hist", |_, n| format!("iter{}", n)));
        let flow = LoopFlow::times("history_test", 4, callable);

        let history = flow.execute_with_history("start").await.unwrap();

        // `times(4)` runs indices 0..=4, i.e. five calls, each recorded.
        assert_eq!(history.iterations, 5);
        assert_eq!(history.outputs.len(), 5);
        assert_eq!(
            history.outputs,
            vec!["iter0", "iter1", "iter2", "iter3", "iter4"]
        );
        assert_eq!(history.final_output, "iter4");
    }

    #[tokio::test]
    async fn test_loop_execute_with_history_early_exit() {
        let callable = Arc::new(MockCallable::done_on_call("early", 2));
        let flow = LoopFlow::until_contains("early_exit", "DONE", callable);

        let history = flow.execute_with_history("x").await.unwrap();

        assert_eq!(history.iterations, 2);
        assert_eq!(history.outputs.len(), 2);
        assert_eq!(history.final_output, "DONE");
    }

    #[tokio::test]
    async fn test_loop_error_propagation() {
        /// Succeeds for the first `fail_on` calls, then returns an error.
        struct FailingCallable {
            fail_on: usize,
            call_count: Arc<AtomicUsize>,
        }

        #[async_trait]
        impl Callable for FailingCallable {
            fn name(&self) -> &str {
                "failing"
            }

            async fn run(&self, _input: &str) -> anyhow::Result<String> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n >= self.fail_on {
                    anyhow::bail!("Intentional failure at iteration {}", n)
                }
                Ok(format!("ok{}", n))
            }
        }

        let callable = Arc::new(FailingCallable {
            fail_on: 2,
            call_count: Arc::new(AtomicUsize::new(0)),
        });

        let flow = LoopFlow::times("fail_loop", 5, callable);
        let result = flow.execute("start").await;

        // The third call fails before the limit is reached; the error must surface unchanged.
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Intentional failure"));
    }

    #[tokio::test]
    async fn test_loop_zero_iterations() {
        let callable = Arc::new(MockCallable::incrementing("zero"));
        let flow = LoopFlow::times("zero_loop", 0, callable.clone());

        let result = flow.execute("input").await.unwrap();
        // The exit check runs after the call, so even a zero limit invokes the callable once.
        assert_eq!(callable.get_call_count(), 1);
        assert!(result.contains("input"));
    }

    #[tokio::test]
    async fn test_loop_immediate_exit_predicate() {
        let callable = Arc::new(MockCallable::new("imm", |_, _| "STOP".to_string()));
        let flow = LoopFlow::new(
            "immediate",
            callable.clone(),
            LoopCondition::until_contains("STOP"),
        );

        let result = flow.execute("x").await.unwrap();
        assert_eq!(result, "STOP");
        assert_eq!(callable.get_call_count(), 1);
    }

    #[tokio::test]
    async fn test_loop_single_iteration() {
        let callable = Arc::new(MockCallable::incrementing("single"));
        let flow = LoopFlow::times("one", 1, callable.clone());

        let history = flow.execute_with_history("x").await.unwrap();
        // `times(1)` runs indices 0 and 1: two calls.
        assert_eq!(history.iterations, 2);
        assert_eq!(callable.get_call_count(), 2);
    }
}