1use crate::loop_detect::{normalize_signature, LoopDetector, LoopStatus};
2use crate::session::{AgentMessage, MessageRole, Session};
3use std::fmt;
4use std::future::Future;
5
/// Outcome of executing a single agent action.
///
/// Derives the common value traits so results can be logged with `{:?}`,
/// duplicated, and compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActionResult {
    /// Text produced by the action; `process_step` stores it in the session
    /// as a tool message and feeds it to the loop detector.
    pub output: String,
    /// When `true`, `process_step` stops the run right after this action.
    pub done: bool,
}
13
/// What the agent decided to do for one step of the loop.
///
/// `Debug` and `Clone` are derived; each is available whenever the action
/// type `A` implements the same trait.
#[derive(Debug, Clone)]
pub struct StepDecision<A> {
    /// The agent's description of the current situation; reported through
    /// `LoopEvent::Decision`.
    pub situation: String,
    /// Task entries reported alongside `situation` in `LoopEvent::Decision`.
    pub task: Vec<String>,
    /// When `true`, `process_step` emits `LoopEvent::Completed` and stops
    /// without executing `actions`.
    pub completed: bool,
    /// Actions to execute, in order, during this step.
    pub actions: Vec<A>,
}
25
/// Progress notifications passed to the `on_event` callback of `run_loop`,
/// `run_loop_stream` and `process_step`.
///
/// Borrowed payloads live only for the duration of the callback invocation.
pub enum LoopEvent<'a, A> {
    /// A new step is starting; carries the 1-based step number.
    StepStart(usize),
    /// The agent produced a decision for the current step.
    Decision {
        /// The decision's `situation` text.
        situation: &'a str,
        /// The decision's `task` entries.
        task: &'a [String],
    },
    /// The decision had `completed == true`; the run stops.
    Completed,
    /// An action is about to be executed.
    ActionStart(&'a A),
    /// An action finished successfully with the given result.
    ActionDone(&'a ActionResult),
    /// The loop detector reported a warning; carries the count it returned.
    LoopWarning(usize),
    /// The loop detector requested an abort; carries the count it returned.
    LoopAbort(usize),
    /// `Session::trim` returned a non-zero value, which is carried here.
    Trimmed(usize),
    /// The loop ran `max_steps` steps without stopping; carries that limit.
    MaxStepsReached(usize),
    /// A token passed to the `on_token` callback of `decide_stream`.
    StreamToken(&'a str),
}
52
/// Tunables for `run_loop` / `run_loop_stream`.
///
/// Besides `Clone`, the type derives `Debug`, `PartialEq` and `Eq` so a
/// configuration can be logged and compared.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopConfig {
    /// Maximum number of steps the loop runs before it gives up and emits
    /// `LoopEvent::MaxStepsReached`.
    pub max_steps: usize,
    /// Threshold handed to `LoopDetector::new` when a loop starts.
    pub loop_abort_threshold: usize,
}

impl Default for LoopConfig {
    /// Returns the default limits: 50 steps and an abort threshold of 6.
    fn default() -> Self {
        Self {
            max_steps: 50,
            loop_abort_threshold: 6,
        }
    }
}
68
69pub trait SgrAgent {
80 type Action: Send + Sync;
82 type Msg: AgentMessage + Send + Sync;
84 type Error: fmt::Display + Send;
86
87 fn decide(
89 &self,
90 messages: &[Self::Msg],
91 ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send;
92
93 fn execute(
96 &self,
97 action: &Self::Action,
98 ) -> impl Future<Output = Result<ActionResult, Self::Error>> + Send;
99
100 fn action_signature(action: &Self::Action) -> String;
102
103 fn action_category(action: &Self::Action) -> String {
108 normalize_signature(&Self::action_signature(action))
109 }
110}
111
/// Streaming extension of [`SgrAgent`].
///
/// `run_loop_stream` calls `decide_stream` instead of `decide` and forwards
/// every token it receives to the event callback as `LoopEvent::StreamToken`.
pub trait SgrAgentStream: SgrAgent {
    /// Produces the next decision from `messages`, invoking `on_token` with
    /// each token as it becomes available.
    ///
    /// An `Err` returned here is propagated out of `run_loop_stream`.
    fn decide_stream<T>(
        &self,
        messages: &[Self::Msg],
        on_token: T,
    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send
    where
        T: FnMut(&str) + Send;
}
144
145pub async fn process_step<A, F>(
152 agent: &A,
153 session: &mut Session<A::Msg>,
154 decision: StepDecision<A::Action>,
155 step_num: usize,
156 detector: &mut LoopDetector,
157 on_event: &mut F,
158) -> Result<Option<usize>, A::Error>
159where
160 A: SgrAgent,
161 F: FnMut(LoopEvent<'_, A::Action>) + Send,
162{
163 on_event(LoopEvent::Decision {
164 situation: &decision.situation,
165 task: &decision.task,
166 });
167
168 if decision.completed {
169 on_event(LoopEvent::Completed);
170 return Ok(Some(step_num));
171 }
172
173 let sig = decision
175 .actions
176 .iter()
177 .map(A::action_signature)
178 .collect::<Vec<_>>()
179 .join("|");
180
181 let category = decision
182 .actions
183 .iter()
184 .map(A::action_category)
185 .collect::<Vec<_>>()
186 .join("|");
187
188 if decision.actions.is_empty() {
190 match detector.check(&sig) {
191 LoopStatus::Abort(n) => {
192 on_event(LoopEvent::LoopAbort(n));
193 session.push(
194 <<A::Msg as AgentMessage>::Role>::system(),
195 "SYSTEM: Repeatedly returning empty actions. Session terminated.".into(),
196 );
197 return Ok(Some(step_num));
198 }
199 _ => {
200 session.push(
201 <<A::Msg as AgentMessage>::Role>::system(),
202 "SYSTEM: You returned empty next_actions. You MUST emit at least one tool call \
203 in next_actions array. Look at the TOOLS section and pick the right tool for \
204 your current phase.".into(),
205 );
206 return Ok(None);
207 }
208 }
209 }
210
211 match detector.check_with_category(&sig, &category) {
213 LoopStatus::Abort(n) => {
214 on_event(LoopEvent::LoopAbort(n));
215 session.push(
216 <<A::Msg as AgentMessage>::Role>::system(),
217 format!(
218 "SYSTEM: Detected {} repetitions of the same action (category: {}). \
219 The result will not change. Session terminated.",
220 n, category
221 ),
222 );
223 return Ok(Some(step_num));
224 }
225 LoopStatus::Warning(n) => {
226 on_event(LoopEvent::LoopWarning(n));
227 session.push(
228 <<A::Msg as AgentMessage>::Role>::system(),
229 format!(
230 "SYSTEM: You have repeated the same action {} times (category: {}). \
231 The result is DEFINITIVE. Do NOT retry — either proceed to the next \
232 step or use FinishTaskTool to report completion.",
233 n, category
234 ),
235 );
236 }
237 LoopStatus::Ok => {}
238 }
239
240 for action in &decision.actions {
242 on_event(LoopEvent::ActionStart(action));
243
244 match agent.execute(action).await {
245 Ok(result) => {
246 session.push(
247 <<A::Msg as AgentMessage>::Role>::tool(),
248 result.output.clone(),
249 );
250
251 let done = result.done;
252 on_event(LoopEvent::ActionDone(&result));
253
254 match detector.record_output(&result.output) {
256 LoopStatus::Abort(n) => {
257 on_event(LoopEvent::LoopAbort(n));
258 session.push(
259 <<A::Msg as AgentMessage>::Role>::system(),
260 format!(
261 "SYSTEM: Tool returned identical output {} times. The result is \
262 DEFINITIVE and will not change. If searching found nothing, \
263 nothing exists. Accept the result and proceed to the next task \
264 step or use FinishTaskTool.",
265 n
266 ),
267 );
268 return Ok(Some(step_num));
269 }
270 LoopStatus::Warning(n) => {
271 on_event(LoopEvent::LoopWarning(n));
272 session.push(
273 <<A::Msg as AgentMessage>::Role>::system(),
274 format!(
275 "SYSTEM: Same tool output {} times in a row. The result will \
276 not change — accept it and move forward. Do NOT retry the \
277 same operation.",
278 n
279 ),
280 );
281 }
282 LoopStatus::Ok => {}
283 }
284
285 if done {
286 return Ok(Some(step_num));
287 }
288 }
289 Err(e) => {
290 session.push(
291 <<A::Msg as AgentMessage>::Role>::tool(),
292 format!("Tool error: {}", e),
293 );
294 }
295 }
296 }
297
298 Ok(None) }
300
301pub async fn run_loop<A, F>(
307 agent: &A,
308 session: &mut Session<A::Msg>,
309 config: &LoopConfig,
310 mut on_event: F,
311) -> Result<usize, A::Error>
312where
313 A: SgrAgent,
314 F: FnMut(LoopEvent<'_, A::Action>) + Send,
315{
316 let mut detector = LoopDetector::new(config.loop_abort_threshold);
317
318 for step_num in 1..=config.max_steps {
319 let trimmed = session.trim();
320 if trimmed > 0 {
321 on_event(LoopEvent::Trimmed(trimmed));
322 }
323
324 on_event(LoopEvent::StepStart(step_num));
325
326 let decision = agent.decide(session.messages()).await?;
327
328 if let Some(final_step) = process_step(
329 agent,
330 session,
331 decision,
332 step_num,
333 &mut detector,
334 &mut on_event,
335 )
336 .await?
337 {
338 return Ok(final_step);
339 }
340 }
341
342 on_event(LoopEvent::MaxStepsReached(config.max_steps));
343 Ok(config.max_steps)
344}
345
346pub async fn run_loop_stream<A, F>(
353 agent: &A,
354 session: &mut Session<A::Msg>,
355 config: &LoopConfig,
356 mut on_event: F,
357) -> Result<usize, A::Error>
358where
359 A: SgrAgentStream,
360 F: FnMut(LoopEvent<'_, A::Action>) + Send,
361{
362 let mut detector = LoopDetector::new(config.loop_abort_threshold);
363
364 for step_num in 1..=config.max_steps {
365 let trimmed = session.trim();
366 if trimmed > 0 {
367 on_event(LoopEvent::Trimmed(trimmed));
368 }
369
370 on_event(LoopEvent::StepStart(step_num));
371
372 let decision = agent
373 .decide_stream(session.messages(), |token| {
374 on_event(LoopEvent::StreamToken(token));
375 })
376 .await?;
377
378 if let Some(final_step) = process_step(
379 agent,
380 session,
381 decision,
382 step_num,
383 &mut detector,
384 &mut on_event,
385 )
386 .await?
387 {
388 return Ok(final_step);
389 }
390 }
391
392 on_event(LoopEvent::MaxStepsReached(config.max_steps));
393 Ok(config.max_steps)
394}
395
396#[cfg(test)]
397mod tests {
398 use super::*;
399 use crate::session::tests::{TestMsg, TestRole};
400 use std::sync::atomic::{AtomicUsize, Ordering};
401
402 struct MockAgent {
403 steps_before_done: AtomicUsize,
404 }
405
406 impl SgrAgent for MockAgent {
407 type Action = String;
408 type Msg = TestMsg;
409 type Error = String;
410
411 async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
412 let remaining = self.steps_before_done.fetch_sub(1, Ordering::SeqCst);
413 if remaining <= 1 {
414 Ok(StepDecision {
415 situation: "done".into(),
416 task: vec![],
417 completed: true,
418 actions: vec![],
419 })
420 } else {
421 Ok(StepDecision {
422 situation: format!("{} steps left", remaining - 1),
423 task: vec!["do something".into()],
424 completed: false,
425 actions: vec![format!("action_{}", remaining)],
426 })
427 }
428 }
429
430 async fn execute(&self, action: &String) -> Result<ActionResult, String> {
431 Ok(ActionResult {
432 output: format!("result of {}", action),
433 done: false,
434 })
435 }
436
437 fn action_signature(action: &String) -> String {
438 action.clone()
439 }
440 }
441
442 #[tokio::test]
443 async fn loop_completes_after_n_steps() {
444 let dir = std::env::temp_dir().join("baml_loop_test_complete");
445 let _ = std::fs::remove_dir_all(&dir);
446 let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60).unwrap();
447 session.push(TestRole::User, "do something".into());
448
449 let agent = MockAgent {
450 steps_before_done: AtomicUsize::new(3),
451 };
452 let config = LoopConfig {
453 max_steps: 10,
454 loop_abort_threshold: 6,
455 };
456
457 let mut events = vec![];
458 let steps = run_loop(&agent, &mut session, &config, |event| match &event {
459 LoopEvent::StepStart(n) => events.push(format!("step:{}", n)),
460 LoopEvent::Completed => events.push("completed".into()),
461 LoopEvent::ActionDone(r) => events.push(format!("done:{}", r.output)),
462 _ => {}
463 })
464 .await
465 .unwrap();
466
467 assert_eq!(steps, 3);
468 assert!(events.contains(&"completed".to_string()));
469 assert!(session.len() > 1);
470
471 let _ = std::fs::remove_dir_all(&dir);
472 }
473
474 struct LoopyAgent;
475
476 impl SgrAgent for LoopyAgent {
477 type Action = String;
478 type Msg = TestMsg;
479 type Error = String;
480
481 async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
482 Ok(StepDecision {
483 situation: "stuck".into(),
484 task: vec!["same thing again".into()],
485 completed: false,
486 actions: vec!["same_action".into()],
487 })
488 }
489
490 async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
491 Ok(ActionResult {
492 output: "same result".into(),
493 done: false,
494 })
495 }
496
497 fn action_signature(action: &String) -> String {
498 action.clone()
499 }
500 }
501
502 #[tokio::test]
503 async fn loop_detects_and_aborts() {
504 let dir = std::env::temp_dir().join("baml_loop_test_abort");
505 let _ = std::fs::remove_dir_all(&dir);
506 let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60).unwrap();
507 session.push(TestRole::User, "do something".into());
508
509 let config = LoopConfig {
510 max_steps: 20,
511 loop_abort_threshold: 4,
512 };
513
514 let mut got_warning = false;
515 let mut got_abort = false;
516 let steps = run_loop(&LoopyAgent, &mut session, &config, |event| match event {
517 LoopEvent::LoopWarning(_) => got_warning = true,
518 LoopEvent::LoopAbort(_) => got_abort = true,
519 _ => {}
520 })
521 .await
522 .unwrap();
523
524 assert!(got_warning);
525 assert!(got_abort);
526 assert!(steps <= 4);
527
528 let _ = std::fs::remove_dir_all(&dir);
529 }
530
531 struct StreamingAgent;
534
535 impl SgrAgent for StreamingAgent {
536 type Action = String;
537 type Msg = TestMsg;
538 type Error = String;
539
540 async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
541 Ok(StepDecision {
542 situation: "done".into(),
543 task: vec![],
544 completed: true,
545 actions: vec![],
546 })
547 }
548
549 async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
550 Ok(ActionResult {
551 output: "ok".into(),
552 done: false,
553 })
554 }
555
556 fn action_signature(action: &String) -> String {
557 action.clone()
558 }
559 }
560
561 impl SgrAgentStream for StreamingAgent {
562 #[allow(clippy::manual_async_fn)]
563 fn decide_stream<T>(
564 &self,
565 _messages: &[TestMsg],
566 mut on_token: T,
567 ) -> impl Future<Output = Result<StepDecision<String>, String>> + Send
568 where
569 T: FnMut(&str) + Send,
570 {
571 async move {
572 on_token("Thin");
573 on_token("king");
574 on_token("...");
575 Ok(StepDecision {
576 situation: "done".into(),
577 task: vec![],
578 completed: true,
579 actions: vec![],
580 })
581 }
582 }
583 }
584
585 #[tokio::test]
586 async fn streaming_tokens_emitted() {
587 let dir = std::env::temp_dir().join("baml_loop_test_stream");
588 let _ = std::fs::remove_dir_all(&dir);
589 let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60).unwrap();
590 session.push(TestRole::User, "hello".into());
591
592 let config = LoopConfig {
593 max_steps: 5,
594 loop_abort_threshold: 6,
595 };
596
597 let mut tokens = vec![];
598 let mut completed = false;
599 run_loop_stream(
600 &StreamingAgent,
601 &mut session,
602 &config,
603 |event| match event {
604 LoopEvent::StreamToken(t) => tokens.push(t.to_string()),
605 LoopEvent::Completed => completed = true,
606 _ => {}
607 },
608 )
609 .await
610 .unwrap();
611
612 assert!(completed);
613 assert_eq!(tokens, vec!["Thin", "king", "..."]);
614
615 let _ = std::fs::remove_dir_all(&dir);
616 }
617
618 struct EmptyActionsAgent {
621 call_count: AtomicUsize,
622 }
623
624 impl SgrAgent for EmptyActionsAgent {
625 type Action = String;
626 type Msg = TestMsg;
627 type Error = String;
628
629 async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
630 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
631 if n < 2 {
632 Ok(StepDecision {
634 situation: "thinking...".into(),
635 task: vec!["do something".into()],
636 completed: false,
637 actions: vec![],
638 })
639 } else {
640 Ok(StepDecision {
642 situation: "done".into(),
643 task: vec![],
644 completed: true,
645 actions: vec![],
646 })
647 }
648 }
649
650 async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
651 Ok(ActionResult {
652 output: "ok".into(),
653 done: false,
654 })
655 }
656
657 fn action_signature(action: &String) -> String {
658 action.clone()
659 }
660 }
661
662 #[tokio::test]
663 async fn empty_actions_nudges_model() {
664 let dir = std::env::temp_dir().join("baml_loop_test_empty_actions");
665 let _ = std::fs::remove_dir_all(&dir);
666 let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60).unwrap();
667 session.push(TestRole::User, "do something".into());
668
669 let agent = EmptyActionsAgent {
670 call_count: AtomicUsize::new(0),
671 };
672 let config = LoopConfig {
673 max_steps: 10,
674 loop_abort_threshold: 6,
675 };
676
677 let mut completed = false;
678 let steps = run_loop(&agent, &mut session, &config, |event| {
679 if matches!(event, LoopEvent::Completed) {
680 completed = true;
681 }
682 })
683 .await
684 .unwrap();
685
686 assert!(completed, "agent should recover after nudge");
687 assert_eq!(steps, 3);
690
691 let messages: Vec<&str> = session.messages().iter().map(|m| m.content()).collect();
693 let nudges = messages
694 .iter()
695 .filter(|m| m.contains("empty next_actions"))
696 .count();
697 assert_eq!(
698 nudges, 2,
699 "should have 2 nudge messages for 2 empty action steps"
700 );
701
702 let _ = std::fs::remove_dir_all(&dir);
703 }
704
705 #[tokio::test]
706 async fn empty_actions_aborts_after_threshold() {
707 let dir = std::env::temp_dir().join("baml_loop_test_empty_abort");
708 let _ = std::fs::remove_dir_all(&dir);
709 let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60).unwrap();
710 session.push(TestRole::User, "do something".into());
711
712 let config = LoopConfig {
715 max_steps: 20,
716 loop_abort_threshold: 4,
717 };
718
719 struct NeverRecoverAgent;
721 impl SgrAgent for NeverRecoverAgent {
722 type Action = String;
723 type Msg = TestMsg;
724 type Error = String;
725 async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
726 Ok(StepDecision {
727 situation: "stuck".into(),
728 task: vec!["try again".into()],
729 completed: false,
730 actions: vec![],
731 })
732 }
733 async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
734 Ok(ActionResult {
735 output: "ok".into(),
736 done: false,
737 })
738 }
739 fn action_signature(action: &String) -> String {
740 action.clone()
741 }
742 }
743
744 let mut got_abort = false;
745 let _steps = run_loop(&NeverRecoverAgent, &mut session, &config, |event| {
746 if matches!(event, LoopEvent::LoopAbort(_)) {
747 got_abort = true;
748 }
749 })
750 .await
751 .unwrap();
752
753 assert!(got_abort, "should abort after repeated empty actions");
754
755 let _ = std::fs::remove_dir_all(&dir);
756 }
757
758 #[tokio::test]
760 async fn non_streaming_agent_works() {
761 let dir = std::env::temp_dir().join("baml_loop_test_nostream");
762 let _ = std::fs::remove_dir_all(&dir);
763 let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60).unwrap();
764 session.push(TestRole::User, "hello".into());
765
766 let config = LoopConfig {
767 max_steps: 5,
768 loop_abort_threshold: 6,
769 };
770
771 let mut completed = false;
773 run_loop(&StreamingAgent, &mut session, &config, |event| {
774 if matches!(event, LoopEvent::Completed) {
775 completed = true;
776 }
777 })
778 .await
779 .unwrap();
780
781 assert!(completed);
782 let _ = std::fs::remove_dir_all(&dir);
783 }
784}