Skip to main content

rs_adk/text/
loop_agent.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4
5use super::TextAgent;
6use crate::error::AgentError;
7use crate::state::State;
8
9/// Runs a text agent repeatedly until max iterations or a state predicate.
10pub struct LoopTextAgent {
11    name: String,
12    body: Arc<dyn TextAgent>,
13    max: u32,
14    until: Option<Arc<dyn Fn(&State) -> bool + Send + Sync>>,
15}
16
17impl LoopTextAgent {
18    /// Create a new loop agent that repeats up to `max` iterations.
19    pub fn new(name: impl Into<String>, body: Arc<dyn TextAgent>, max: u32) -> Self {
20        Self {
21            name: name.into(),
22            body,
23            max,
24            until: None,
25        }
26    }
27
28    /// Add a predicate — loop breaks when predicate returns true.
29    pub fn until(mut self, pred: impl Fn(&State) -> bool + Send + Sync + 'static) -> Self {
30        self.until = Some(Arc::new(pred));
31        self
32    }
33}
34
35#[async_trait]
36impl TextAgent for LoopTextAgent {
37    fn name(&self) -> &str {
38        &self.name
39    }
40
41    async fn run(&self, state: &State) -> Result<String, AgentError> {
42        let mut last_output = String::new();
43
44        for _iter in 0..self.max {
45            last_output = self.body.run(state).await?;
46
47            if let Some(pred) = &self.until {
48                if pred(state) {
49                    break;
50                }
51            }
52        }
53
54        Ok(last_output)
55    }
56}