1use crate::loop_detect::{LoopDetector, LoopStatus};
2use crate::session::{AgentMessage, MessageRole, Session};
3use std::fmt;
4use std::future::Future;
5
/// Outcome of executing a single action via [`SgrAgent::execute`].
///
/// Derives the common standard traits so results can be logged (`{:?}`), cloned, and
/// compared by callers and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ActionResult {
    /// Text recorded in the session as a tool message after the action runs.
    pub output: String,
    /// When `true`, [`process_step`] stops the loop right after this action and skips any
    /// remaining actions of the same step.
    pub done: bool,
}
13
/// One structured decision produced by the agent for a single loop step.
///
/// `A` is the agent's action type. The derived trait impls are only available when `A`
/// itself implements the corresponding trait.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepDecision<A> {
    /// Free-form description of the agent's current state, reported via `LoopEvent::Decision`.
    pub state: String,
    /// The agent's plan, reported alongside `state` via `LoopEvent::Decision`.
    pub plan: Vec<String>,
    /// When `true`, the loop ends at this step and `actions` are not executed.
    pub completed: bool,
    /// Actions to execute, in order, during this step.
    pub actions: Vec<A>,
}
25
/// Progress notification delivered to the `on_event` callback of `run_loop`,
/// `run_loop_stream` and `process_step`.
///
/// Payloads borrow from the loop's internal state, so they are only valid for the duration
/// of the callback invocation.
pub enum LoopEvent<'a, A> {
    /// A new step is starting; carries the 1-based step number.
    StepStart(usize),
    /// The agent produced a decision; emitted for every decision before anything else.
    Decision { state: &'a str, plan: &'a [String] },
    /// The decision was marked `completed`; the loop stops at this step.
    Completed,
    /// An action is about to be executed.
    ActionStart(&'a A),
    /// An executed action produced an [`ActionResult`].
    ActionDone(&'a ActionResult),
    /// The loop detector reported `LoopStatus::Warning(n)`; a system nudge was added to the
    /// session and the step continues. (`n` is presumably the repeat count — confirm in
    /// `loop_detect`.)
    LoopWarning(usize),
    /// The loop detector reported `LoopStatus::Abort(n)`; the loop stops at this step.
    LoopAbort(usize),
    /// `Session::trim` returned a non-zero count before a step (presumably the number of
    /// messages dropped — confirm in `session`).
    Trimmed(usize),
    /// The step budget was exhausted; carries `LoopConfig::max_steps`.
    MaxStepsReached(usize),
    /// A token reported by `SgrAgentStream::decide_stream` (only emitted by `run_loop_stream`).
    StreamToken(&'a str),
}
49
/// Tuning knobs for `run_loop` / `run_loop_stream`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LoopConfig {
    /// Maximum number of steps to run before giving up with `LoopEvent::MaxStepsReached`.
    pub max_steps: usize,
    /// Threshold handed to `LoopDetector::new` when the loop starts.
    pub loop_abort_threshold: usize,
}
55
56impl Default for LoopConfig {
57 fn default() -> Self {
58 Self {
59 max_steps: 50,
60 loop_abort_threshold: 6,
61 }
62 }
63}
64
/// An agent driven by `run_loop`: it decides what to do next from the session history and
/// executes the actions it chose.
pub trait SgrAgent {
    /// Action type carried in `StepDecision::actions` and passed to `execute`.
    type Action;
    /// Message type stored in the `Session` the loop operates on.
    type Msg: AgentMessage;
    /// Error type. It must implement `Display` because execution errors are formatted into
    /// the session as tool messages.
    type Error: fmt::Display;

    /// Produces the next decision from the current session messages.
    ///
    /// An `Err` here stops the loop and is propagated to the caller of `run_loop`.
    fn decide(
        &self,
        messages: &[Self::Msg],
    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send;

    /// Executes a single action.
    ///
    /// An `Err` here does not stop the loop: it is written into the session as a
    /// `"Tool error: ..."` tool message so the agent can react on the next step.
    fn execute(
        &self,
        action: &Self::Action,
    ) -> impl Future<Output = Result<ActionResult, Self::Error>> + Send;

    /// Returns a string identifying `action` for loop detection.
    ///
    /// The signatures of all actions in one step are joined with `|` and fed to the
    /// `LoopDetector`, so equal signatures across steps count as repetition.
    fn action_signature(action: &Self::Action) -> String;
}
99
/// Streaming extension of [`SgrAgent`], used by `run_loop_stream`.
pub trait SgrAgentStream: SgrAgent {
    /// Like [`SgrAgent::decide`], but may invoke `on_token` with text fragments (presumably
    /// model output chunks as they arrive) before resolving to the final decision.
    ///
    /// `run_loop_stream` forwards each fragment to its caller as `LoopEvent::StreamToken`.
    fn decide_stream<T>(
        &self,
        messages: &[Self::Msg],
        on_token: T,
    ) -> impl Future<Output = Result<StepDecision<Self::Action>, Self::Error>> + Send
    where
        T: FnMut(&str) + Send;
}
132
133pub async fn process_step<A, F>(
140 agent: &A,
141 session: &mut Session<A::Msg>,
142 decision: StepDecision<A::Action>,
143 step_num: usize,
144 detector: &mut LoopDetector,
145 on_event: &mut F,
146) -> Result<Option<usize>, A::Error>
147where
148 A: SgrAgent,
149 F: FnMut(LoopEvent<'_, A::Action>) + Send,
150{
151 on_event(LoopEvent::Decision {
152 state: &decision.state,
153 plan: &decision.plan,
154 });
155
156 if decision.completed {
157 on_event(LoopEvent::Completed);
158 return Ok(Some(step_num));
159 }
160
161 let sig = decision
163 .actions
164 .iter()
165 .map(A::action_signature)
166 .collect::<Vec<_>>()
167 .join("|");
168
169 match detector.check(&sig) {
170 LoopStatus::Abort(n) => {
171 on_event(LoopEvent::LoopAbort(n));
172 session.push(
173 <<A::Msg as AgentMessage>::Role>::system(),
174 "SYSTEM: You have been repeating the same action. Session terminated.".into(),
175 );
176 return Ok(Some(step_num));
177 }
178 LoopStatus::Warning(n) => {
179 on_event(LoopEvent::LoopWarning(n));
180 session.push(
181 <<A::Msg as AgentMessage>::Role>::system(),
182 "SYSTEM: You are repeating the same action. Try a different approach or report completion.".into(),
183 );
184 }
185 LoopStatus::Ok => {}
186 }
187
188 for action in &decision.actions {
190 on_event(LoopEvent::ActionStart(action));
191
192 match agent.execute(action).await {
193 Ok(result) => {
194 session.push(
195 <<A::Msg as AgentMessage>::Role>::tool(),
196 result.output.clone(),
197 );
198 let done = result.done;
199 on_event(LoopEvent::ActionDone(&result));
200 if done {
201 return Ok(Some(step_num));
202 }
203 }
204 Err(e) => {
205 session.push(
206 <<A::Msg as AgentMessage>::Role>::tool(),
207 format!("Tool error: {}", e),
208 );
209 }
210 }
211 }
212
213 Ok(None) }
215
216pub async fn run_loop<A, F>(
222 agent: &A,
223 session: &mut Session<A::Msg>,
224 config: &LoopConfig,
225 mut on_event: F,
226) -> Result<usize, A::Error>
227where
228 A: SgrAgent,
229 F: FnMut(LoopEvent<'_, A::Action>) + Send,
230{
231 let mut detector = LoopDetector::new(config.loop_abort_threshold);
232
233 for step_num in 1..=config.max_steps {
234 let trimmed = session.trim();
235 if trimmed > 0 {
236 on_event(LoopEvent::Trimmed(trimmed));
237 }
238
239 on_event(LoopEvent::StepStart(step_num));
240
241 let decision = agent.decide(session.messages()).await?;
242
243 if let Some(final_step) = process_step(agent, session, decision, step_num, &mut detector, &mut on_event).await? {
244 return Ok(final_step);
245 }
246 }
247
248 on_event(LoopEvent::MaxStepsReached(config.max_steps));
249 Ok(config.max_steps)
250}
251
252pub async fn run_loop_stream<A, F>(
259 agent: &A,
260 session: &mut Session<A::Msg>,
261 config: &LoopConfig,
262 mut on_event: F,
263) -> Result<usize, A::Error>
264where
265 A: SgrAgentStream,
266 F: FnMut(LoopEvent<'_, A::Action>) + Send,
267{
268 let mut detector = LoopDetector::new(config.loop_abort_threshold);
269
270 for step_num in 1..=config.max_steps {
271 let trimmed = session.trim();
272 if trimmed > 0 {
273 on_event(LoopEvent::Trimmed(trimmed));
274 }
275
276 on_event(LoopEvent::StepStart(step_num));
277
278 let decision = agent.decide_stream(session.messages(), |token| {
279 on_event(LoopEvent::StreamToken(token));
280 }).await?;
281
282 if let Some(final_step) = process_step(agent, session, decision, step_num, &mut detector, &mut on_event).await? {
283 return Ok(final_step);
284 }
285 }
286
287 on_event(LoopEvent::MaxStepsReached(config.max_steps));
288 Ok(config.max_steps)
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::tests::{TestMsg, TestRole};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Agent that issues one distinct action per step and reports completion once its
    /// countdown reaches the final step.
    struct MockAgent {
        // Decremented on every `decide` call; the call that observes a value <= 1 completes.
        steps_before_done: AtomicUsize,
    }

    impl SgrAgent for MockAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            // `fetch_sub` returns the value *before* the decrement.
            let remaining = self.steps_before_done.fetch_sub(1, Ordering::SeqCst);
            if remaining <= 1 {
                Ok(StepDecision {
                    state: "done".into(),
                    plan: vec![],
                    completed: true,
                    actions: vec![],
                })
            } else {
                Ok(StepDecision {
                    state: format!("{} steps left", remaining - 1),
                    plan: vec!["do something".into()],
                    completed: false,
                    // A different action each step; presumably keeps the loop detector quiet.
                    actions: vec![format!("action_{}", remaining)],
                })
            }
        }

        async fn execute(&self, action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: format!("result of {}", action),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    /// With a countdown of 3, the agent completes on the third step.
    #[tokio::test]
    async fn loop_completes_after_n_steps() {
        let dir = std::env::temp_dir().join("baml_loop_test_complete");
        let _ = std::fs::remove_dir_all(&dir);
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, "do something".into());

        let agent = MockAgent {
            steps_before_done: AtomicUsize::new(3),
        };
        let config = LoopConfig { max_steps: 10, loop_abort_threshold: 6 };

        let mut events = vec![];
        let steps = run_loop(&agent, &mut session, &config, |event| {
            match &event {
                LoopEvent::StepStart(n) => events.push(format!("step:{}", n)),
                LoopEvent::Completed => events.push("completed".into()),
                LoopEvent::ActionDone(r) => events.push(format!("done:{}", r.output)),
                _ => {}
            }
        }).await.unwrap();

        assert_eq!(steps, 3);
        assert!(events.contains(&"completed".to_string()));
        // Tool outputs were appended after the initial user message.
        assert!(session.len() > 1);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Agent that repeats the identical action forever, to exercise loop detection.
    struct LoopyAgent;

    impl SgrAgent for LoopyAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            Ok(StepDecision {
                state: "stuck".into(),
                plan: vec!["same thing again".into()],
                completed: false,
                actions: vec!["same_action".into()],
            })
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult {
                output: "same result".into(),
                done: false,
            })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    /// Repeating the same action should first warn and then abort within the threshold.
    #[tokio::test]
    async fn loop_detects_and_aborts() {
        let dir = std::env::temp_dir().join("baml_loop_test_abort");
        let _ = std::fs::remove_dir_all(&dir);
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, "do something".into());

        let config = LoopConfig { max_steps: 20, loop_abort_threshold: 4 };

        let mut got_warning = false;
        let mut got_abort = false;
        let steps = run_loop(&LoopyAgent, &mut session, &config, |event| {
            match event {
                LoopEvent::LoopWarning(_) => got_warning = true,
                LoopEvent::LoopAbort(_) => got_abort = true,
                _ => {}
            }
        }).await.unwrap();

        assert!(got_warning);
        assert!(got_abort);
        assert!(steps <= 4);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// Agent that completes immediately. Its streaming `decide_stream` reports three tokens
    /// first; its plain `decide` reports none.
    struct StreamingAgent;

    impl SgrAgent for StreamingAgent {
        type Action = String;
        type Msg = TestMsg;
        type Error = String;

        async fn decide(&self, _messages: &[TestMsg]) -> Result<StepDecision<String>, String> {
            Ok(StepDecision {
                state: "done".into(),
                plan: vec![],
                completed: true,
                actions: vec![],
            })
        }

        async fn execute(&self, _action: &String) -> Result<ActionResult, String> {
            Ok(ActionResult { output: "ok".into(), done: false })
        }

        fn action_signature(action: &String) -> String {
            action.clone()
        }
    }

    impl SgrAgentStream for StreamingAgent {
        fn decide_stream<T>(
            &self,
            _messages: &[TestMsg],
            mut on_token: T,
        ) -> impl Future<Output = Result<StepDecision<String>, String>> + Send
        where
            T: FnMut(&str) + Send,
        {
            async move {
                on_token("Thin");
                on_token("king");
                on_token("...");
                Ok(StepDecision {
                    state: "done".into(),
                    plan: vec![],
                    completed: true,
                    actions: vec![],
                })
            }
        }
    }

    /// `run_loop_stream` forwards every token, in order, as `StreamToken` events.
    #[tokio::test]
    async fn streaming_tokens_emitted() {
        let dir = std::env::temp_dir().join("baml_loop_test_stream");
        let _ = std::fs::remove_dir_all(&dir);
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, "hello".into());

        let config = LoopConfig { max_steps: 5, loop_abort_threshold: 6 };

        let mut tokens = vec![];
        let mut completed = false;
        run_loop_stream(&StreamingAgent, &mut session, &config, |event| {
            match event {
                LoopEvent::StreamToken(t) => tokens.push(t.to_string()),
                LoopEvent::Completed => completed = true,
                _ => {}
            }
        }).await.unwrap();

        assert!(completed);
        assert_eq!(tokens, vec!["Thin", "king", "..."]);

        let _ = std::fs::remove_dir_all(&dir);
    }

    /// A type that also implements `SgrAgentStream` still works with the plain `run_loop`,
    /// which goes through `decide`.
    #[tokio::test]
    async fn non_streaming_agent_works() {
        let dir = std::env::temp_dir().join("baml_loop_test_nostream");
        let _ = std::fs::remove_dir_all(&dir);
        let mut session = Session::<TestMsg>::new(dir.to_str().unwrap(), 60);
        session.push(TestRole::User, "hello".into());

        let config = LoopConfig { max_steps: 5, loop_abort_threshold: 6 };

        let mut completed = false;
        run_loop(&StreamingAgent, &mut session, &config, |event| {
            if matches!(event, LoopEvent::Completed) { completed = true; }
        }).await.unwrap();

        assert!(completed);
        let _ = std::fs::remove_dir_all(&dir);
    }
}