Skip to main content

baml_agent/
agent_loop.rs

1use crate::loop_detect::{normalize_signature, LoopDetector, LoopStatus};
2use crate::session::{AgentMessage, MessageRole, Session};
3use std::fmt;
4use std::future::Future;
5
/// Result of executing a single action.
///
/// Produced by [`SgrAgent::execute`]. The loop appends `output` to the session as a
/// tool message and stops as soon as an action reports `done`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActionResult {
    /// Text output from the tool (goes into history as a tool message).
    pub output: String,
    /// Whether this action signals task completion (e.g. FinishTask, ReportCompletion).
    pub done: bool,
}
13
/// Result of one LLM decision step (STAR: Situation → Task → Action, plus the
/// Result flag `completed`).
///
/// When `completed` is `true` the loop stops without executing `actions`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepDecision<A> {
    /// S — Situation: current state assessment.
    pub situation: String,
    /// T — Task: remaining steps (first = current).
    pub task: Vec<String>,
    /// R — Result: whether the overall task is complete.
    pub completed: bool,
    /// A — Action: tools to execute, in order.
    pub actions: Vec<A>,
}
25
26/// Events emitted by the agent loop (print, TUI, log).
27pub enum LoopEvent<'a, A> {
28    /// Step started (step number, 1-based).
29    StepStart(usize),
30    /// LLM returned a decision.
31    Decision {
32        situation: &'a str,
33        task: &'a [String],
34    },
35    /// Task completed by LLM (task_completed=true).
36    Completed,
37    /// About to execute an action.
38    ActionStart(&'a A),
39    /// Action executed, result available.
40    ActionDone(&'a ActionResult),
41    /// Loop warning (repeated actions).
42    LoopWarning(usize),
43    /// Loop abort (too many repeats).
44    LoopAbort(usize),
45    /// Context trimmed.
46    Trimmed(usize),
47    /// Max steps reached.
48    MaxStepsReached(usize),
49    /// Streaming token from LLM (only emitted by `run_loop_stream`).
50    StreamToken(&'a str),
51}
52
/// Configuration for the agent loop.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopConfig {
    /// Maximum number of decide/execute steps before the loop gives up
    /// (default: 50).
    pub max_steps: usize,
    /// Threshold passed to `LoopDetector::new`, controlling when repetition
    /// aborts the loop (default: 6).
    pub loop_abort_threshold: usize,
}

impl Default for LoopConfig {
    fn default() -> Self {
        Self {
            max_steps: 50,
            loop_abort_threshold: 6,
        }
    }
}
68
69/// Base SGR Agent trait — implement per project.
70///
71/// Covers the non-streaming case (va-agent, simple CLIs).
72/// For streaming, implement [`SgrAgentStream`] on top.
73///
74/// # Stateful executors
75///
76/// If `execute` needs mutable state (e.g. MCP connections), use interior
77/// mutability (`Mutex`, `RwLock`). The trait takes `&self` to allow
78/// concurrent action execution in the future.
79pub trait SgrAgent {
80    /// The action union type (BAML-generated, project-specific).
81    type Action: Send + Sync;
82    /// The message type (implements AgentMessage).
83    type Msg: AgentMessage + Send + Sync;
84    /// Error type.
85    type Error: fmt::Display + Send;
86
87    /// Call LLM to decide next actions.
88    fn decide(
89        &self,
90        messages: &[Self::Msg],
91    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send;
92
93    /// Execute a single action. Returns tool output + done flag.
94    /// Does NOT push to session — the loop handles that.
95    fn execute(
96        &self,
97        action: &Self::Action,
98    ) -> impl Future<Output = Result<ActionResult, Self::Error>> + Send;
99
100    /// String signature for loop detection (exact match).
101    fn action_signature(action: &Self::Action) -> String;
102
103    /// Coarse category for semantic loop detection.
104    ///
105    /// Default: normalize the signature (strips bash flags, quotes, fallbacks).
106    /// Override for project-specific normalization.
107    fn action_category(action: &Self::Action) -> String {
108        normalize_signature(&Self::action_signature(action))
109    }
110}
111
112/// Streaming extension for SGR agents.
113///
114/// Implement this alongside [`SgrAgent`] to get streaming tokens
115/// during the decision phase. Use with [`run_loop_stream`].
116///
117/// ```ignore
118/// impl SgrAgentStream for MyAgent {
119///     fn decide_stream<T>(&self, messages: &[Msg], mut on_token: T)
120///         -> impl Future<Output = Result<StepDecision<Action>, Error>> + Send
121///     where T: FnMut(&str) + Send
122///     {
123///         async move {
124///             let stream = B.MyFunction.stream(&messages).await?;
125///             while let Some(partial) = stream.next().await {
126///                 on_token(&partial.raw_text);
127///             }
128///             let result = stream.get_final_response().await?;
129///             Ok(StepDecision { ... })
130///         }
131///     }
132/// }
133/// ```
134pub trait SgrAgentStream: SgrAgent {
135    /// Call LLM with streaming — emits tokens via `on_token` callback.
136    fn decide_stream<T>(
137        &self,
138        messages: &[Self::Msg],
139        on_token: T,
140    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send
141    where
142        T: FnMut(&str) + Send;
143}
144
145// --- Shared loop internals ---
146
147/// Post-decision processing: loop detection, action execution, session updates.
148///
149/// Shared between `run_loop`, `run_loop_stream`, and custom loops (e.g. TUI).
150/// Returns `Some(step_num)` if the loop should stop, `None` to continue.
151pub async fn process_step<A, F>(
152    agent: &A,
153    session: &mut Session<A::Msg>,
154    decision: StepDecision<A::Action>,
155    step_num: usize,
156    detector: &mut LoopDetector,
157    on_event: &mut F,
158) -> Result<Option<usize>, A::Error>
159where
160    A: SgrAgent,
161    F: FnMut(LoopEvent<'_, A::Action>) + Send,
162{
163    on_event(LoopEvent::Decision {
164        situation: &decision.situation,
165        task: &decision.task,
166    });
167
168    if decision.completed {
169        on_event(LoopEvent::Completed);
170        return Ok(Some(step_num));
171    }
172
173    // --- Signatures: exact + normalized category ---
174    let sig = decision
175        .actions
176        .iter()
177        .map(A::action_signature)
178        .collect::<Vec<_>>()
179        .join("|");
180
181    let category = decision
182        .actions
183        .iter()
184        .map(A::action_category)
185        .collect::<Vec<_>>()
186        .join("|");
187
188    // --- Empty actions guard ---
189    if decision.actions.is_empty() {
190        match detector.check(&sig) {
191            LoopStatus::Abort(n) => {
192                on_event(LoopEvent::LoopAbort(n));
193                session.push(
194                    <<A::Msg as AgentMessage>::Role>::system(),
195                    "SYSTEM: Repeatedly returning empty actions. Session terminated.".into(),
196                );
197                return Ok(Some(step_num));
198            }
199            _ => {
200                session.push(
201                    <<A::Msg as AgentMessage>::Role>::system(),
202                    "SYSTEM: You returned empty next_actions. You MUST emit at least one tool call \
203                     in next_actions array. Look at the TOOLS section and pick the right tool for \
204                     your current phase.".into(),
205                );
206                return Ok(None);
207            }
208        }
209    }
210
211    // --- Tier 1+2: exact + category loop detection ---
212    match detector.check_with_category(&sig, &category) {
213        LoopStatus::Abort(n) => {
214            on_event(LoopEvent::LoopAbort(n));
215            session.push(
216                <<A::Msg as AgentMessage>::Role>::system(),
217                format!(
218                    "SYSTEM: Detected {} repetitions of the same action (category: {}). \
219                     The result will not change. Session terminated.",
220                    n, category
221                ),
222            );
223            return Ok(Some(step_num));
224        }
225        LoopStatus::Warning(n) => {
226            on_event(LoopEvent::LoopWarning(n));
227            session.push(
228                <<A::Msg as AgentMessage>::Role>::system(),
229                format!(
230                    "SYSTEM: You have repeated the same action {} times (category: {}). \
231                     The result is DEFINITIVE. Do NOT retry — either proceed to the next \
232                     step or use FinishTaskTool to report completion.",
233                    n, category
234                ),
235            );
236        }
237        LoopStatus::Ok => {}
238    }
239
240    // --- Execute actions ---
241    for action in &decision.actions {
242        on_event(LoopEvent::ActionStart(action));
243
244        match agent.execute(action).await {
245            Ok(result) => {
246                session.push(
247                    <<A::Msg as AgentMessage>::Role>::tool(),
248                    result.output.clone(),
249                );
250
251                let done = result.done;
252                on_event(LoopEvent::ActionDone(&result));
253
254                // --- Tier 3: output stagnation ---
255                match detector.record_output(&result.output) {
256                    LoopStatus::Abort(n) => {
257                        on_event(LoopEvent::LoopAbort(n));
258                        session.push(
259                            <<A::Msg as AgentMessage>::Role>::system(),
260                            format!(
261                                "SYSTEM: Tool returned identical output {} times. The result is \
262                                 DEFINITIVE and will not change. If searching found nothing, \
263                                 nothing exists. Accept the result and proceed to the next task \
264                                 step or use FinishTaskTool.",
265                                n
266                            ),
267                        );
268                        return Ok(Some(step_num));
269                    }
270                    LoopStatus::Warning(n) => {
271                        on_event(LoopEvent::LoopWarning(n));
272                        session.push(
273                            <<A::Msg as AgentMessage>::Role>::system(),
274                            format!(
275                                "SYSTEM: Same tool output {} times in a row. The result will \
276                                 not change — accept it and move forward. Do NOT retry the \
277                                 same operation.",
278                                n
279                            ),
280                        );
281                    }
282                    LoopStatus::Ok => {}
283                }
284
285                if done {
286                    return Ok(Some(step_num));
287                }
288            }
289            Err(e) => {
290                session.push(
291                    <<A::Msg as AgentMessage>::Role>::tool(),
292                    format!("Tool error: {}", e),
293                );
294            }
295        }
296    }
297
298    Ok(None) // continue
299}
300
301/// Run the SGR agent loop (non-streaming).
302///
303/// `trim → decide → check loop → execute → push results → repeat`
304///
305/// Returns the number of steps executed.
306pub async fn run_loop<A, F>(
307    agent: &A,
308    session: &mut Session<A::Msg>,
309    config: &LoopConfig,
310    mut on_event: F,
311) -> Result<usize, A::Error>
312where
313    A: SgrAgent,
314    F: FnMut(LoopEvent<'_, A::Action>) + Send,
315{
316    let mut detector = LoopDetector::new(config.loop_abort_threshold);
317
318    for step_num in 1..=config.max_steps {
319        let trimmed = session.trim();
320        if trimmed > 0 {
321            on_event(LoopEvent::Trimmed(trimmed));
322        }
323
324        on_event(LoopEvent::StepStart(step_num));
325
326        let decision = agent.decide(session.messages()).await?;
327
328        if let Some(final_step) = process_step(
329            agent,
330            session,
331            decision,
332            step_num,
333            &mut detector,
334            &mut on_event,
335        )
336        .await?
337        {
338            return Ok(final_step);
339        }
340    }
341
342    on_event(LoopEvent::MaxStepsReached(config.max_steps));
343    Ok(config.max_steps)
344}
345
346/// Run the SGR agent loop with streaming tokens.
347///
348/// Same as [`run_loop`] but calls `decide_stream` instead of `decide`,
349/// emitting `LoopEvent::StreamToken` during the decision phase.
350///
351/// Requires the agent to implement [`SgrAgentStream`].
352pub async fn run_loop_stream<A, F>(
353    agent: &A,
354    session: &mut Session<A::Msg>,
355    config: &LoopConfig,
356    mut on_event: F,
357) -> Result<usize, A::Error>
358where
359    A: SgrAgentStream,
360    F: FnMut(LoopEvent<'_, A::Action>) + Send,
361{
362    let mut detector = LoopDetector::new(config.loop_abort_threshold);
363
364    for step_num in 1..=config.max_steps {
365        let trimmed = session.trim();
366        if trimmed > 0 {
367            on_event(LoopEvent::Trimmed(trimmed));
368        }
369
370        on_event(LoopEvent::StepStart(step_num));
371
372        let decision = agent
373            .decide_stream(session.messages(), |token| {
374                on_event(LoopEvent::StreamToken(token));
375            })
376            .await?;
377
378        if let Some(final_step) = process_step(
379            agent,
380            session,
381            decision,
382            step_num,
383            &mut detector,
384            &mut on_event,
385        )
386        .await?
387        {
388            return Ok(final_step);
389        }
390    }
391
392    on_event(LoopEvent::MaxStepsReached(config.max_steps));
393    Ok(config.max_steps)
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::tests::{TestMsg, TestRole};
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Per-test scratch directory under the system temp dir.
    ///
    /// Leftovers from an earlier run are removed on creation. The directory is
    /// removed again on drop, so a failing assertion does not leave it behind.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(name);
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }

        /// Opens a session rooted in this directory and seeds it with one user
        /// message.
        fn session_with(&self, first_user_msg: &str) -> Session<TestMsg> {
            let mut session = Session::<TestMsg>::new(self.0.to_str().unwrap(), 60).unwrap();
            session.push(TestRole::User, first_user_msg.into());
            session
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Emits a distinct action per step, then completes after
    /// `steps_before_done` decisions.
    struct MockAgent {
        steps_before_done: AtomicUsize,
    }

    impl SgrAgent for MockAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            let remaining = self.steps_before_done.fetch_sub(1, Ordering::SeqCst);
            if remaining <= 1 {
                Ok(StepDecision {
                    situation: "done".into(),
                    task: vec![],
                    completed: true,
                    actions: vec![],
                })
            } else {
                Ok(StepDecision {
                    situation: format!("{} steps left", remaining - 1),
                    task: vec!["do something".into()],
                    completed: false,
                    actions: vec![format!("action_{}", remaining)],
                })
            }
        }

        async fn execute(&self, action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: format!("result of {}", action),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    #[tokio::test]
    async fn loop_completes_after_n_steps() {
        let dir = ScratchDir::new("baml_loop_test_complete");
        let mut session = dir.session_with("do something");

        let agent = MockAgent {
            steps_before_done: AtomicUsize::new(3),
        };
        let config = LoopConfig {
            max_steps: 10,
            loop_abort_threshold: 6,
        };

        let mut events = vec![];
        let steps = run_loop(&agent, &mut session, &config, |event| match &event {
            LoopEvent::StepStart(n) => events.push(format!("step:{}", n)),
            LoopEvent::Completed => events.push("completed".into()),
            LoopEvent::ActionDone(r) => events.push(format!("done:{}", r.output)),
            _ => {}
        })
        .await
        .unwrap();

        assert_eq!(steps, 3);
        // Pin the full event order, not just the presence of "completed".
        assert_eq!(
            events,
            [
                "step:1",
                "done:result of action_3",
                "step:2",
                "done:result of action_2",
                "step:3",
                "completed",
            ]
        );
        assert!(session.len() > 1);
    }

    /// Repeats the same action with the same output forever.
    struct LoopyAgent;

    impl SgrAgent for LoopyAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            Ok(StepDecision {
                situation: "stuck".into(),
                task: vec!["same thing again".into()],
                completed: false,
                actions: vec!["same_action".into()],
            })
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: "same result".into(),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    #[tokio::test]
    async fn loop_detects_and_aborts() {
        let dir = ScratchDir::new("baml_loop_test_abort");
        let mut session = dir.session_with("do something");

        let config = LoopConfig {
            max_steps: 20,
            loop_abort_threshold: 4,
        };

        let mut got_warning = false;
        let mut got_abort = false;
        let steps = run_loop(&LoopyAgent, &mut session, &config, |event| match event {
            LoopEvent::LoopWarning(_) => got_warning = true,
            LoopEvent::LoopAbort(_) => got_abort = true,
            _ => {}
        })
        .await
        .unwrap();

        assert!(got_warning);
        assert!(got_abort);
        assert!(steps <= 4);
    }

    // --- Streaming trait test ---

    /// Completes immediately; its streaming path emits three fixed tokens.
    struct StreamingAgent;

    impl SgrAgent for StreamingAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            Ok(StepDecision {
                situation: "done".into(),
                task: vec![],
                completed: true,
                actions: vec![],
            })
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: "ok".into(),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    impl SgrAgentStream for StreamingAgent {
        // Written as an explicit `impl Future` to mirror the trait declaration
        // and its doc example, which clippy would otherwise suggest collapsing.
        #[allow(clippy::manual_async_fn)]
        fn decide_stream<T>(
            &self,
            _messages: &[TestMsg],
            mut on_token: T,
        ) -> impl Future<Output = Result<StepDecision<String>, String>> + Send
        where
            T: FnMut(&str) + Send,
        {
            async move {
                on_token("Thin");
                on_token("king");
                on_token("...");
                Ok(StepDecision {
                    situation: "done".into(),
                    task: vec![],
                    completed: true,
                    actions: vec![],
                })
            }
        }
    }

    #[tokio::test]
    async fn streaming_tokens_emitted() {
        let dir = ScratchDir::new("baml_loop_test_stream");
        let mut session = dir.session_with("hello");

        let config = LoopConfig {
            max_steps: 5,
            loop_abort_threshold: 6,
        };

        let mut tokens = vec![];
        let mut completed = false;
        run_loop_stream(&StreamingAgent, &mut session, &config, |event| match event {
            LoopEvent::StreamToken(t) => tokens.push(t.to_string()),
            LoopEvent::Completed => completed = true,
            _ => {}
        })
        .await
        .unwrap();

        assert!(completed);
        assert_eq!(tokens, vec!["Thin", "king", "..."]);
    }

    // --- Empty actions guard test ---

    /// Returns empty action lists twice, then completes.
    struct EmptyActionsAgent {
        call_count: AtomicUsize,
    }

    impl SgrAgent for EmptyActionsAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                // First 2 calls: empty actions (model forgot tool calls)
                Ok(StepDecision {
                    situation: "thinking...".into(),
                    task: vec!["do something".into()],
                    completed: false,
                    actions: vec![],
                })
            } else {
                // After nudge, model recovers
                Ok(StepDecision {
                    situation: "done".into(),
                    task: vec![],
                    completed: true,
                    actions: vec![],
                })
            }
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: "ok".into(),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    #[tokio::test]
    async fn empty_actions_nudges_model() {
        let dir = ScratchDir::new("baml_loop_test_empty_actions");
        let mut session = dir.session_with("do something");

        let agent = EmptyActionsAgent {
            call_count: AtomicUsize::new(0),
        };
        let config = LoopConfig {
            max_steps: 10,
            loop_abort_threshold: 6,
        };

        let mut completed = false;
        let steps = run_loop(&agent, &mut session, &config, |event| {
            if matches!(event, LoopEvent::Completed) {
                completed = true;
            }
        })
        .await
        .unwrap();

        assert!(completed, "agent should recover after nudge");
        // Empty steps return Ok(None), so the counter advances:
        // step 1 (empty), step 2 (empty), step 3 (completed).
        assert_eq!(steps, 3);

        // Session should contain one nudge per empty step.
        let nudges = session
            .messages()
            .iter()
            .filter(|m| m.content().contains("empty next_actions"))
            .count();
        assert_eq!(
            nudges, 2,
            "should have 2 nudge messages for 2 empty action steps"
        );
    }

    #[tokio::test]
    async fn empty_actions_aborts_after_threshold() {
        let dir = ScratchDir::new("baml_loop_test_empty_abort");
        let mut session = dir.session_with("do something");

        // Low threshold so an agent that never recovers aborts quickly.
        let config = LoopConfig {
            max_steps: 20,
            loop_abort_threshold: 4,
        };

        /// Always returns an empty action list and never completes.
        struct NeverRecoverAgent;
        impl SgrAgent for NeverRecoverAgent {
            type Action = String;
            type Msg = TestMsg;
            type Error = String;
            async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
                Ok(StepDecision {
                    situation: "stuck".into(),
                    task: vec!["try again".into()],
                    completed: false,
                    actions: vec![],
                })
            }
            async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
                Ok(ActionResult {
                    output: "ok".into(),
                    done: false,
                })
            }
            fn action_signature(action: &String) -> String {
                action.clone()
            }
        }

        let mut got_abort = false;
        let steps = run_loop(&NeverRecoverAgent, &mut session, &config, |event| {
            if matches!(event, LoopEvent::LoopAbort(_)) {
                got_abort = true;
            }
        })
        .await
        .unwrap();

        assert!(got_abort, "should abort after repeated empty actions");
        assert!(
            steps < config.max_steps,
            "abort should stop the loop before the step budget runs out"
        );
    }

    /// Non-streaming agent can also use run_loop (base trait only).
    #[tokio::test]
    async fn non_streaming_agent_works() {
        let dir = ScratchDir::new("baml_loop_test_nostream");
        let mut session = dir.session_with("hello");

        let config = LoopConfig {
            max_steps: 5,
            loop_abort_threshold: 6,
        };

        // StreamingAgent also implements SgrAgent, so run_loop works
        let mut completed = false;
        run_loop(&StreamingAgent, &mut session, &config, |event| {
            if matches!(event, LoopEvent::Completed) {
                completed = true;
            }
        })
        .await
        .unwrap();

        assert!(completed);
    }
}