Skip to main content

baml_agent/
agent_loop.rs

1use crate::loop_detect::{LoopDetector, LoopStatus};
2use crate::session::{AgentMessage, MessageRole, Session};
3use std::fmt;
4use std::future::Future;
5
/// Outcome of executing a single action via [`SgrAgent::execute`].
///
/// The agent loop records `output` in the session as a tool message and
/// stops the loop when `done` is `true`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ActionResult {
    /// Text output from the tool; the loop stores it in history as a tool message.
    pub output: String,
    /// Whether this action signals task completion (e.g. FinishTask, ReportCompletion).
    /// When `true`, the loop skips any remaining actions of the step and stops.
    pub done: bool,
}

/// One decision produced by the LLM for a single loop step.
///
/// `state` and `plan` are forwarded to observers via [`LoopEvent::Decision`];
/// `actions` are executed in order by [`process_step`] unless `completed` is set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepDecision<A> {
    /// Current state description (shown to user).
    pub state: String,
    /// Remaining plan steps (shown to user).
    pub plan: Vec<String>,
    /// Whether the overall task is complete. When `true`, [`process_step`]
    /// stops the loop without executing `actions`.
    pub completed: bool,
    /// Actions to execute this step, in order.
    pub actions: Vec<A>,
}

26/// Events emitted by the agent loop (print, TUI, log).
27pub enum LoopEvent<'a, A> {
28    /// Step started (step number, 1-based).
29    StepStart(usize),
30    /// LLM returned a decision.
31    Decision { state: &'a str, plan: &'a [String] },
32    /// Task completed by LLM (task_completed=true).
33    Completed,
34    /// About to execute an action.
35    ActionStart(&'a A),
36    /// Action executed, result available.
37    ActionDone(&'a ActionResult),
38    /// Loop warning (repeated actions).
39    LoopWarning(usize),
40    /// Loop abort (too many repeats).
41    LoopAbort(usize),
42    /// Context trimmed.
43    Trimmed(usize),
44    /// Max steps reached.
45    MaxStepsReached(usize),
46    /// Streaming token from LLM (only emitted by `run_loop_stream`).
47    StreamToken(&'a str),
48}
49
/// Tuning knobs for [`run_loop`] and [`run_loop_stream`].
///
/// Defaults (see the `Default` impl): `max_steps = 50`, `loop_abort_threshold = 6`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopConfig {
    /// Maximum number of loop iterations (LLM decisions) before giving up.
    pub max_steps: usize,
    /// Threshold handed to `LoopDetector::new`; presumably the repeat count
    /// at which the loop is aborted — see `LoopDetector` to confirm.
    pub loop_abort_threshold: usize,
}

impl Default for LoopConfig {
    fn default() -> Self {
        Self {
            max_steps: 50,
            loop_abort_threshold: 6,
        }
    }
}

65/// Base SGR Agent trait — implement per project.
66///
67/// Covers the non-streaming case (va-agent, simple CLIs).
68/// For streaming, implement [`SgrAgentStream`] on top.
69///
70/// # Stateful executors
71///
72/// If `execute` needs mutable state (e.g. MCP connections), use interior
73/// mutability (`Mutex`, `RwLock`). The trait takes `&self` to allow
74/// concurrent action execution in the future.
75pub trait SgrAgent {
76    /// The action union type (BAML-generated, project-specific).
77    type Action;
78    /// The message type (implements AgentMessage).
79    type Msg: AgentMessage;
80    /// Error type.
81    type Error: fmt::Display;
82
83    /// Call LLM to decide next actions.
84    fn decide(
85        &self,
86        messages: &[Self::Msg],
87    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send;
88
89    /// Execute a single action. Returns tool output + done flag.
90    /// Does NOT push to session — the loop handles that.
91    fn execute(
92        &self,
93        action: &Self::Action,
94    ) -> impl Future<Output = Result<ActionResult, Self::Error>> + Send;
95
96    /// String signature for loop detection.
97    fn action_signature(action: &Self::Action) -> String;
98}
99
100/// Streaming extension for SGR agents.
101///
102/// Implement this alongside [`SgrAgent`] to get streaming tokens
103/// during the decision phase. Use with [`run_loop_stream`].
104///
105/// ```ignore
106/// impl SgrAgentStream for MyAgent {
107///     fn decide_stream<T>(&self, messages: &[Msg], mut on_token: T)
108///         -> impl Future<Output = Result<StepDecision<Action>, Error>> + Send
109///     where T: FnMut(&str) + Send
110///     {
111///         async move {
112///             let stream = B.MyFunction.stream(&messages).await?;
113///             while let Some(partial) = stream.next().await {
114///                 on_token(&partial.raw_text);
115///             }
116///             let result = stream.get_final_response().await?;
117///             Ok(StepDecision { ... })
118///         }
119///     }
120/// }
121/// ```
122pub trait SgrAgentStream: SgrAgent {
123    /// Call LLM with streaming — emits tokens via `on_token` callback.
124    fn decide_stream<T>(
125        &self,
126        messages: &[Self::Msg],
127        on_token: T,
128    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send
129    where
130        T: FnMut(&str) + Send;
131}
132
133// --- Shared loop internals ---
134
135/// Post-decision processing: loop detection, action execution, session updates.
136///
137/// Shared between `run_loop`, `run_loop_stream`, and custom loops (e.g. TUI).
138/// Returns `Some(step_num)` if the loop should stop, `None` to continue.
139pub async fn process_step<A, F>(
140    agent: &A,
141    session: &mut Session<A::Msg>,
142    decision: StepDecision<A::Action>,
143    step_num: usize,
144    detector: &mut LoopDetector,
145    on_event: &mut F,
146) -> Result<Option<usize>, A::Error>
147where
148    A: SgrAgent,
149    F: FnMut(LoopEvent<'_, A::Action>) + Send,
150{
151    on_event(LoopEvent::Decision {
152        state: &decision.state,
153        plan: &decision.plan,
154    });
155
156    if decision.completed {
157        on_event(LoopEvent::Completed);
158        return Ok(Some(step_num));
159    }
160
161    // Loop detection
162    let sig = decision
163        .actions
164        .iter()
165        .map(A::action_signature)
166        .collect::<Vec<_>>()
167        .join("|");
168
169    match detector.check(&sig) {
170        LoopStatus::Abort(n) => {
171            on_event(LoopEvent::LoopAbort(n));
172            session.push(
173                <<A::Msg as AgentMessage>::Role>::system(),
174                "SYSTEM: You have been repeating the same action. Session terminated.".into(),
175            );
176            return Ok(Some(step_num));
177        }
178        LoopStatus::Warning(n) => {
179            on_event(LoopEvent::LoopWarning(n));
180            session.push(
181                <<A::Msg as AgentMessage>::Role>::system(),
182                "SYSTEM: You are repeating the same action. Try a different approach or report completion.".into(),
183            );
184        }
185        LoopStatus::Ok => {}
186    }
187
188    // Execute actions
189    for action in &decision.actions {
190        on_event(LoopEvent::ActionStart(action));
191
192        match agent.execute(action).await {
193            Ok(result) => {
194                session.push(
195                    <<A::Msg as AgentMessage>::Role>::tool(),
196                    result.output.clone(),
197                );
198                let done = result.done;
199                on_event(LoopEvent::ActionDone(&result));
200                if done {
201                    return Ok(Some(step_num));
202                }
203            }
204            Err(e) => {
205                session.push(
206                    <<A::Msg as AgentMessage>::Role>::tool(),
207                    format!("Tool error: {}", e),
208                );
209            }
210        }
211    }
212
213    Ok(None) // continue
214}
215
216/// Run the SGR agent loop (non-streaming).
217///
218/// `trim → decide → check loop → execute → push results → repeat`
219///
220/// Returns the number of steps executed.
221pub async fn run_loop<A, F>(
222    agent: &A,
223    session: &mut Session<A::Msg>,
224    config: &LoopConfig,
225    mut on_event: F,
226) -> Result<usize, A::Error>
227where
228    A: SgrAgent,
229    F: FnMut(LoopEvent<'_, A::Action>) + Send,
230{
231    let mut detector = LoopDetector::new(config.loop_abort_threshold);
232
233    for step_num in 1..=config.max_steps {
234        let trimmed = session.trim();
235        if trimmed > 0 {
236            on_event(LoopEvent::Trimmed(trimmed));
237        }
238
239        on_event(LoopEvent::StepStart(step_num));
240
241        let decision = agent.decide(session.messages()).await?;
242
243        if let Some(final_step) = process_step(agent, session, decision, step_num, &mut detector, &mut on_event).await? {
244            return Ok(final_step);
245        }
246    }
247
248    on_event(LoopEvent::MaxStepsReached(config.max_steps));
249    Ok(config.max_steps)
250}
251
252/// Run the SGR agent loop with streaming tokens.
253///
254/// Same as [`run_loop`] but calls `decide_stream` instead of `decide`,
255/// emitting `LoopEvent::StreamToken` during the decision phase.
256///
257/// Requires the agent to implement [`SgrAgentStream`].
258pub async fn run_loop_stream<A, F>(
259    agent: &A,
260    session: &mut Session<A::Msg>,
261    config: &LoopConfig,
262    mut on_event: F,
263) -> Result<usize, A::Error>
264where
265    A: SgrAgentStream,
266    F: FnMut(LoopEvent<'_, A::Action>) + Send,
267{
268    let mut detector = LoopDetector::new(config.loop_abort_threshold);
269
270    for step_num in 1..=config.max_steps {
271        let trimmed = session.trim();
272        if trimmed > 0 {
273            on_event(LoopEvent::Trimmed(trimmed));
274        }
275
276        on_event(LoopEvent::StepStart(step_num));
277
278        let decision = agent.decide_stream(session.messages(), |token| {
279            on_event(LoopEvent::StreamToken(token));
280        }).await?;
281
282        if let Some(final_step) = process_step(agent, session, decision, step_num, &mut detector, &mut on_event).await? {
283            return Ok(final_step);
284        }
285    }
286
287    on_event(LoopEvent::MaxStepsReached(config.max_steps));
288    Ok(config.max_steps)
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::tests::{TestMsg, TestRole};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Clears `temp_dir()/dir_name`, opens a session there and seeds it with a
    /// single user message. Returns the directory so the test can remove it.
    fn fresh_session(dir_name: &str, prompt: &str) -> (std::path::PathBuf, Session<TestMsg>) {
        let dir = std::env::temp_dir().join(dir_name);
        let _ = std::fs::remove_dir_all(&dir);
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, prompt.into());
        (dir, session)
    }

    /// A decision that marks the task as finished with nothing left to do.
    fn finished() -> StepDecision<String> {
        StepDecision {
            state: "done".into(),
            plan: vec![],
            completed: true,
            actions: vec![],
        }
    }

    /// Agent that issues a distinct action per step until its countdown ends.
    struct MockAgent {
        steps_before_done: AtomicUsize,
    }

    impl SgrAgent for MockAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            let left = self.steps_before_done.fetch_sub(1, Ordering::SeqCst);
            if left > 1 {
                Ok(StepDecision {
                    state: format!("{} steps left", left - 1),
                    plan: vec!["do something".into()],
                    completed: false,
                    actions: vec![format!("action_{}", left)],
                })
            } else {
                Ok(finished())
            }
        }

        async fn execute(&self, action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: format!("result of {}", action),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    #[tokio::test]
    async fn loop_completes_after_n_steps() {
        let (dir, mut session) = fresh_session("baml_loop_test_complete", "do something");
        let agent = MockAgent {
            steps_before_done: AtomicUsize::new(3),
        };
        let config = LoopConfig { max_steps: 10, loop_abort_threshold: 6 };

        let mut events: Vec<String> = Vec::new();
        let steps = run_loop(&agent, &mut session, &config, |event| match &event {
            LoopEvent::StepStart(n) => events.push(format!("step:{}", n)),
            LoopEvent::Completed => events.push("completed".into()),
            LoopEvent::ActionDone(r) => events.push(format!("done:{}", r.output)),
            _ => {}
        })
        .await
        .unwrap();

        assert_eq!(steps, 3);
        assert!(events.iter().any(|e| e == "completed"));
        assert!(session.len() > 1);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Agent that repeats the exact same action forever.
    struct LoopyAgent;

    impl SgrAgent for LoopyAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            Ok(StepDecision {
                state: "stuck".into(),
                plan: vec!["same thing again".into()],
                completed: false,
                actions: vec!["same_action".into()],
            })
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: "same result".into(),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    #[tokio::test]
    async fn loop_detects_and_aborts() {
        let (dir, mut session) = fresh_session("baml_loop_test_abort", "do something");
        let config = LoopConfig { max_steps: 20, loop_abort_threshold: 4 };

        let (mut saw_warning, mut saw_abort) = (false, false);
        let steps = run_loop(&LoopyAgent, &mut session, &config, |event| match event {
            LoopEvent::LoopWarning(_) => saw_warning = true,
            LoopEvent::LoopAbort(_) => saw_abort = true,
            _ => {}
        })
        .await
        .unwrap();

        assert!(saw_warning);
        assert!(saw_abort);
        assert!(steps <= 4);

        let _ = std::fs::remove_dir_all(&dir);
    }

    // --- Streaming trait test ---

    /// Agent that finishes immediately; its streaming variant emits three tokens first.
    struct StreamingAgent;

    impl SgrAgent for StreamingAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            Ok(finished())
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult { output: "ok".into(), done: false })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    impl SgrAgentStream for StreamingAgent {
        fn decide_stream<T>(
            &self,
            _messages: &[TestMsg],
            mut on_token: T,
        ) -> impl Future<Output = Result<StepDecision<String>, String>> + Send
        where
            T: FnMut(&str) + Send,
        {
            async move {
                for chunk in ["Thin", "king", "..."] {
                    on_token(chunk);
                }
                Ok(finished())
            }
        }
    }

    #[tokio::test]
    async fn streaming_tokens_emitted() {
        let (dir, mut session) = fresh_session("baml_loop_test_stream", "hello");
        let config = LoopConfig { max_steps: 5, loop_abort_threshold: 6 };

        let mut tokens: Vec<String> = Vec::new();
        let mut completed = false;
        run_loop_stream(&StreamingAgent, &mut session, &config, |event| match event {
            LoopEvent::StreamToken(t) => tokens.push(t.to_owned()),
            LoopEvent::Completed => completed = true,
            _ => {}
        })
        .await
        .unwrap();

        assert!(completed);
        assert_eq!(tokens, ["Thin", "king", "..."]);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Non-streaming agent can also use run_loop (base trait only).
    #[tokio::test]
    async fn non_streaming_agent_works() {
        let (dir, mut session) = fresh_session("baml_loop_test_nostream", "hello");
        let config = LoopConfig { max_steps: 5, loop_abort_threshold: 6 };

        // StreamingAgent also implements SgrAgent, so run_loop works
        let mut completed = false;
        run_loop(&StreamingAgent, &mut session, &config, |event| {
            completed |= matches!(event, LoopEvent::Completed);
        })
        .await
        .unwrap();

        assert!(completed);
        let _ = std::fs::remove_dir_all(&dir);
    }
}