use crate::callable::Callable;
use std::sync::Arc;
/// Rule that decides when a loop should stop iterating.
///
/// A condition is evaluated through [`LoopCondition::should_exit`] against the
/// current iteration index and the text produced by the latest run.
pub enum LoopCondition {
    /// Exit once the iteration index is greater than or equal to this value.
    MaxIterations(usize),
    /// Exit as soon as the boxed predicate returns `true` for the output.
    OutputMatches(Box<dyn Fn(&str) -> bool + Send + Sync>),
    /// Exit as soon as the output contains this substring.
    OutputContains(String),
    /// Exit when the iteration bound is reached or the predicate holds,
    /// whichever happens first.
    Either {
        /// Iteration index at which the loop exits regardless of the output.
        max_iterations: usize,
        /// Predicate consulted only while the bound has not been reached.
        predicate: Box<dyn Fn(&str) -> bool + Send + Sync>,
    },
}

impl LoopCondition {
    /// Reports whether the loop should stop.
    ///
    /// `iteration` is the index of the run that just finished and `output` is
    /// the text it produced. Bounded variants stop when `iteration` is at
    /// least their limit; output-based variants inspect `output`.
    pub fn should_exit(&self, iteration: usize, output: &str) -> bool {
        // Shared bound test for the two variants that carry an iteration limit.
        let reached = |limit: usize| limit <= iteration;

        match self {
            Self::MaxIterations(limit) => reached(*limit),
            Self::OutputMatches(check) => check(output),
            Self::OutputContains(fragment) => output.contains(fragment.as_str()),
            // The bound is tested first; the predicate is not invoked at all
            // once the bound has been reached.
            Self::Either { max_iterations, .. } if reached(*max_iterations) => true,
            Self::Either { predicate, .. } => predicate(output),
        }
    }

    /// Builds a condition that exits once the iteration index reaches `n`.
    pub fn max(n: usize) -> Self {
        Self::MaxIterations(n)
    }

    /// Builds a condition that exits when the output contains `s`.
    pub fn until_contains(s: impl Into<String>) -> Self {
        let fragment: String = s.into();
        Self::OutputContains(fragment)
    }

    /// Builds a condition that exits when `pred` returns `true` for the output.
    pub fn until(pred: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
        let boxed: Box<dyn Fn(&str) -> bool + Send + Sync> = Box::new(pred);
        Self::OutputMatches(boxed)
    }

    /// Builds a condition that exits when the iteration index reaches `max`
    /// or when `pred` returns `true` for the output.
    pub fn max_or_until(max: usize, pred: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
        let boxed: Box<dyn Fn(&str) -> bool + Send + Sync> = Box::new(pred);
        Self::Either {
            predicate: boxed,
            max_iterations: max,
        }
    }
}
/// A flow that runs one [`Callable`] repeatedly until its [`LoopCondition`]
/// reports that the loop should exit.
pub struct LoopFlow<C: Callable> {
/// Shared handle to the callable that is run on every iteration.
callable: Arc<C>,
/// Exit rule checked after each run.
condition: LoopCondition,
/// Name of this flow, exposed through `name()`.
name: String,
/// When `true`, each run's output becomes the next run's input; when
/// `false`, every run receives the original input.
feedback: bool,
}
impl<C: Callable> LoopFlow<C> {
    /// Creates a loop named `name` that runs `callable` until `condition`
    /// reports an exit.
    ///
    /// Feedback is enabled by default: each run's output is passed as the
    /// input of the next run. Use [`LoopFlow::with_feedback`] to change that.
    pub fn new(name: impl Into<String>, callable: Arc<C>, condition: LoopCondition) -> Self {
        Self {
            callable,
            condition,
            name: name.into(),
            feedback: true,
        }
    }

    /// Creates a loop bounded by `LoopCondition::MaxIterations(n)`.
    ///
    /// The exit check happens after each run and uses a zero-based iteration
    /// index, so the callable is run `n + 1` times (once even when `n` is 0).
    pub fn times(name: impl Into<String>, n: usize, callable: Arc<C>) -> Self {
        Self::new(name, callable, LoopCondition::MaxIterations(n))
    }

    /// Creates a loop that stops as soon as an output contains `s`.
    pub fn until_contains(name: impl Into<String>, s: impl Into<String>, callable: Arc<C>) -> Self {
        Self::new(name, callable, LoopCondition::OutputContains(s.into()))
    }

    /// Sets whether each run's output is fed back as the next run's input.
    ///
    /// With `false`, every run receives the original input unchanged.
    pub fn with_feedback(mut self, feedback: bool) -> Self {
        self.feedback = feedback;
        self
    }

    /// Runs the loop and returns the output of the final run.
    ///
    /// The callable is always run at least once, because the exit condition
    /// is only evaluated after a run has produced an output.
    ///
    /// If the condition never reports an exit and the callable never fails,
    /// this future does not complete.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the callable; the loop stops
    /// immediately when that happens.
    pub async fn execute(&self, input: &str) -> anyhow::Result<String> {
        let mut current_input = input.to_string();
        let mut iteration = 0;
        loop {
            let output = self.callable.run(&current_input).await?;
            if self.condition.should_exit(iteration, &output) {
                return Ok(output);
            }
            if self.feedback {
                current_input = output;
            }
            iteration += 1;
        }
    }

    /// Runs the loop like [`LoopFlow::execute`], additionally recording the
    /// output of every run.
    ///
    /// The returned [`LoopHistory`] holds the number of runs performed, every
    /// output in order, and the final output.
    ///
    /// If the condition never reports an exit and the callable never fails,
    /// this future does not complete and the recorded outputs keep growing.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the callable; outputs collected
    /// before the failure are discarded.
    pub async fn execute_with_history(&self, input: &str) -> anyhow::Result<LoopHistory> {
        let mut current_input = input.to_string();
        let mut iteration = 0;
        let mut outputs = Vec::new();
        loop {
            let output = self.callable.run(&current_input).await?;
            // Record a copy first: `output` itself is still needed below,
            // either as the final result or as the next input.
            outputs.push(output.clone());
            if self.condition.should_exit(iteration, &output) {
                return Ok(LoopHistory {
                    // `iteration` is zero-based, so the run count is one more.
                    iterations: iteration + 1,
                    outputs,
                    final_output: output,
                });
            }
            if self.feedback {
                current_input = output;
            }
            iteration += 1;
        }
    }

    /// Returns the name given to this loop at construction.
    pub fn name(&self) -> &str {
        &self.name
    }
}
/// Record of a completed loop run, as returned by
/// `LoopFlow::execute_with_history`.
#[derive(Debug)]
pub struct LoopHistory {
/// Number of times the callable was run (zero-based exit index plus one).
pub iterations: usize,
/// Output of every run, in execution order.
pub outputs: Vec<String>,
/// Output of the last run, i.e. the one on which the loop exited.
pub final_output: String,
}
// Unit tests for `LoopCondition` and `LoopFlow`.
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::atomic::{AtomicUsize, Ordering};
// Test double for `Callable`: counts its invocations and derives each output
// from a caller-supplied closure.
// The lint allowance is presumably for the boxed `transform` closure type.
#[allow(clippy::type_complexity)]
struct MockCallable {
name: String,
// Number of `run` calls made so far.
call_count: Arc<AtomicUsize>,
// Receives the input and the zero-based call index; returns the output.
transform: Box<dyn Fn(&str, usize) -> String + Send + Sync>,
}
impl MockCallable {
// Builds a mock whose output is computed by `transform`.
fn new(
name: &str,
transform: impl Fn(&str, usize) -> String + Send + Sync + 'static,
) -> Self {
Self {
name: name.to_string(),
call_count: Arc::new(AtomicUsize::new(0)),
transform: Box::new(transform),
}
}
// Mock that outputs "<input>:<call index>".
fn incrementing(name: &str) -> Self {
Self::new(name, |input, n| format!("{}:{}", input, n))
}
// Mock that outputs "DONE" from its n-th call onwards (call index >= n - 1)
// and "<input>:<call index>" before that.
// NOTE: `n` must be at least 1; `n - 1` underflows for 0.
fn done_on_call(name: &str, n: usize) -> Self {
Self::new(name, move |input, call| {
if call >= n - 1 {
"DONE".to_string()
} else {
format!("{}:{}", input, call)
}
})
}
// Returns how many times `run` has been called.
fn get_call_count(&self) -> usize {
self.call_count.load(Ordering::SeqCst)
}
}
#[async_trait]
impl Callable for MockCallable {
fn name(&self) -> &str {
&self.name
}
async fn run(&self, input: &str) -> anyhow::Result<String> {
// `fetch_add` returns the previous value, so `n` is the zero-based call index.
let n = self.call_count.fetch_add(1, Ordering::SeqCst);
Ok((self.transform)(input, n))
}
}
// MaxIterations exits only once the iteration index reaches the limit.
#[test]
fn test_condition_max_iterations() {
let cond = LoopCondition::MaxIterations(3);
assert!(!cond.should_exit(0, "any"));
assert!(!cond.should_exit(1, "any"));
assert!(!cond.should_exit(2, "any"));
assert!(cond.should_exit(3, "any")); assert!(cond.should_exit(5, "any")); }
// OutputMatches depends only on the predicate, not on the iteration index.
#[test]
fn test_condition_output_matches() {
let cond = LoopCondition::OutputMatches(Box::new(|s| s.len() > 5));
assert!(!cond.should_exit(0, "hi"));
assert!(!cond.should_exit(10, "short"));
assert!(cond.should_exit(0, "longer"));
assert!(cond.should_exit(0, "this is long enough"));
}
// OutputContains exits when the needle appears anywhere in the output.
#[test]
fn test_condition_output_contains() {
let cond = LoopCondition::OutputContains("DONE".to_string());
assert!(!cond.should_exit(0, "not yet"));
assert!(!cond.should_exit(5, "still working"));
assert!(cond.should_exit(0, "DONE"));
assert!(cond.should_exit(0, "task DONE here"));
}
// Either exits on the iteration bound or on the predicate, whichever holds.
#[test]
fn test_condition_either() {
let cond = LoopCondition::Either {
max_iterations: 5,
predicate: Box::new(|s| s.contains("STOP")),
};
assert!(!cond.should_exit(0, "working"));
assert!(!cond.should_exit(4, "still going"));
assert!(cond.should_exit(5, "working"));
assert!(cond.should_exit(2, "STOP now"));
}
// The constructor helpers build conditions equivalent to the raw variants.
#[test]
fn test_condition_helpers() {
let cond = LoopCondition::max(2);
assert!(!cond.should_exit(1, "x"));
assert!(cond.should_exit(2, "x"));
let cond = LoopCondition::until_contains("END");
assert!(!cond.should_exit(0, "not"));
assert!(cond.should_exit(0, "END"));
let cond = LoopCondition::until(|s| s == "target");
assert!(!cond.should_exit(0, "other"));
assert!(cond.should_exit(0, "target"));
let cond = LoopCondition::max_or_until(3, |s| s.starts_with("!"));
assert!(!cond.should_exit(0, "a"));
assert!(cond.should_exit(3, "a")); assert!(cond.should_exit(0, "!bang")); }
// `new` stores the given name.
#[tokio::test]
async fn test_loop_flow_new() {
let callable = Arc::new(MockCallable::incrementing("inc"));
let flow = LoopFlow::new("test_loop", callable, LoopCondition::MaxIterations(2));
assert_eq!(flow.name(), "test_loop");
}
// `times` stores the given name.
#[tokio::test]
async fn test_loop_flow_times() {
let callable = Arc::new(MockCallable::incrementing("inc"));
let flow = LoopFlow::times("timer", 3, callable);
assert_eq!(flow.name(), "timer");
}
// `until_contains` stores the given name.
#[tokio::test]
async fn test_loop_flow_until_contains() {
let callable = Arc::new(MockCallable::done_on_call("done", 2));
let flow = LoopFlow::until_contains("stopper", "DONE", callable);
assert_eq!(flow.name(), "stopper");
}
// A limit of 3 yields 4 runs: the exit check follows each run and uses a
// zero-based iteration index.
#[tokio::test]
async fn test_loop_execute_max_iterations() {
let callable = Arc::new(MockCallable::incrementing("inc"));
let flow = LoopFlow::times("loop", 3, callable.clone());
let result = flow.execute("start").await.unwrap();
assert_eq!(callable.get_call_count(), 4);
assert!(result.contains("start"));
}
// The loop stops on the first output that contains the marker.
#[tokio::test]
async fn test_loop_execute_until_contains() {
let callable = Arc::new(MockCallable::done_on_call("done", 3));
let flow = LoopFlow::until_contains("wait_done", "DONE", callable.clone());
let result = flow.execute("input").await.unwrap();
assert_eq!(result, "DONE");
assert_eq!(callable.get_call_count(), 3);
}
// A custom predicate ends the loop; "count:5" is produced by the sixth run.
#[tokio::test]
async fn test_loop_execute_with_predicate() {
let callable = Arc::new(MockCallable::new("counter", |_, n| format!("count:{}", n)));
let flow = LoopFlow::new(
"until_five",
callable.clone(),
LoopCondition::until(|s| s == "count:5"),
);
let result = flow.execute("x").await.unwrap();
assert_eq!(result, "count:5");
assert_eq!(callable.get_call_count(), 6); }
// With a predicate that never matches, the iteration bound ends the loop.
#[tokio::test]
async fn test_loop_execute_either_max_first() {
let callable = Arc::new(MockCallable::new("counter", |_, n| format!("v{}", n)));
let flow = LoopFlow::new(
"either",
callable.clone(),
LoopCondition::max_or_until(3, |s| s == "never"),
);
let result = flow.execute("x").await.unwrap();
assert_eq!(callable.get_call_count(), 4);
assert_eq!(result, "v3");
}
// The predicate ends the loop before the iteration bound is reached.
#[tokio::test]
async fn test_loop_execute_either_predicate_first() {
let callable = Arc::new(MockCallable::done_on_call("done", 2));
let flow = LoopFlow::new(
"either",
callable.clone(),
LoopCondition::max_or_until(10, |s| s == "DONE"),
);
let result = flow.execute("x").await.unwrap();
assert_eq!(result, "DONE");
assert_eq!(callable.get_call_count(), 2); }
// With feedback on, each run receives the previous run's output.
#[tokio::test]
async fn test_loop_with_feedback_enabled() {
let inputs: Arc<std::sync::Mutex<Vec<String>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let inputs_clone = inputs.clone();
let callable = Arc::new(MockCallable::new("fb", move |input, n| {
inputs_clone.lock().unwrap().push(input.to_string());
format!("out{}", n)
}));
let flow = LoopFlow::times("feedback_on", 3, callable).with_feedback(true);
flow.execute("start").await.unwrap();
let recorded = inputs.lock().unwrap().clone();
assert_eq!(recorded, vec!["start", "out0", "out1", "out2"]);
}
// With feedback off, every run receives the original input.
#[tokio::test]
async fn test_loop_with_feedback_disabled() {
let inputs: Arc<std::sync::Mutex<Vec<String>>> =
Arc::new(std::sync::Mutex::new(Vec::new()));
let inputs_clone = inputs.clone();
let callable = Arc::new(MockCallable::new("no_fb", move |input, n| {
inputs_clone.lock().unwrap().push(input.to_string());
format!("out{}", n)
}));
let flow = LoopFlow::times("feedback_off", 3, callable).with_feedback(false);
flow.execute("same").await.unwrap();
let recorded = inputs.lock().unwrap().clone();
assert_eq!(recorded, vec!["same", "same", "same", "same"]);
}
// The history records every output in order plus the run count.
#[tokio::test]
async fn test_loop_execute_with_history() {
let callable = Arc::new(MockCallable::new("hist", |_, n| format!("iter{}", n)));
let flow = LoopFlow::times("history_test", 4, callable);
let history = flow.execute_with_history("start").await.unwrap();
assert_eq!(history.iterations, 5);
assert_eq!(history.outputs.len(), 5);
assert_eq!(
history.outputs,
vec!["iter0", "iter1", "iter2", "iter3", "iter4"]
);
assert_eq!(history.final_output, "iter4");
}
// The history stops at the run that satisfied the exit condition.
#[tokio::test]
async fn test_loop_execute_with_history_early_exit() {
let callable = Arc::new(MockCallable::done_on_call("early", 2));
let flow = LoopFlow::until_contains("early_exit", "DONE", callable);
let history = flow.execute_with_history("x").await.unwrap();
assert_eq!(history.iterations, 2);
assert_eq!(history.outputs.len(), 2);
assert_eq!(history.final_output, "DONE");
}
// An error returned by the callable aborts the loop and reaches the caller.
#[tokio::test]
async fn test_loop_error_propagation() {
// Callable that succeeds until its zero-based call index reaches `fail_on`.
struct FailingCallable {
fail_on: usize,
call_count: Arc<AtomicUsize>,
}
#[async_trait]
impl Callable for FailingCallable {
fn name(&self) -> &str {
"failing"
}
async fn run(&self, _input: &str) -> anyhow::Result<String> {
let n = self.call_count.fetch_add(1, Ordering::SeqCst);
if n >= self.fail_on {
anyhow::bail!("Intentional failure at iteration {}", n)
}
Ok(format!("ok{}", n))
}
}
let callable = Arc::new(FailingCallable {
fail_on: 2,
call_count: Arc::new(AtomicUsize::new(0)),
});
let flow = LoopFlow::times("fail_loop", 5, callable);
let result = flow.execute("start").await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Intentional failure"));
}
// A limit of 0 still runs the callable once, since the check follows the run.
#[tokio::test]
async fn test_loop_zero_iterations() {
let callable = Arc::new(MockCallable::incrementing("zero"));
let flow = LoopFlow::times("zero_loop", 0, callable.clone());
let result = flow.execute("input").await.unwrap();
assert_eq!(callable.get_call_count(), 1);
assert!(result.contains("input"));
}
// A condition satisfied by the first output ends the loop after one run.
#[tokio::test]
async fn test_loop_immediate_exit_predicate() {
let callable = Arc::new(MockCallable::new("imm", |_, _| "STOP".to_string()));
let flow = LoopFlow::new(
"immediate",
callable.clone(),
LoopCondition::until_contains("STOP"),
);
let result = flow.execute("x").await.unwrap();
assert_eq!(result, "STOP");
assert_eq!(callable.get_call_count(), 1);
}
// A limit of 1 results in two runs and a history reporting two iterations.
#[tokio::test]
async fn test_loop_single_iteration() {
let callable = Arc::new(MockCallable::incrementing("single"));
let flow = LoopFlow::times("one", 1, callable.clone());
let history = flow.execute_with_history("x").await.unwrap();
assert_eq!(history.iterations, 2);
assert_eq!(callable.get_call_count(), 2);
}
}