1use crate::agent::{Agent, AgentError, Decision};
6use crate::context::{AgentContext, AgentState};
7use crate::registry::ToolRegistry;
8use crate::types::{Message, SgrError};
9use futures::future::join_all;
10use std::collections::HashMap;
11
/// Maximum number of consecutive recoverable `decide` failures (see
/// `is_recoverable_error`) that the loops feed back to the agent as a
/// "Parse error" message before giving up and returning the error.
const MAX_PARSE_RETRIES: usize = 3;
14
15fn is_recoverable_error(e: &AgentError) -> bool {
17 matches!(
18 e,
19 AgentError::Llm(SgrError::Json(_))
20 | AgentError::Llm(SgrError::EmptyResponse)
21 | AgentError::Llm(SgrError::Schema(_))
22 )
23}
24
/// Tunable limits for the agent loop.
#[derive(Debug, Clone)]
pub struct LoopConfig {
    /// Upper bound on the number of loop iterations; exceeding it ends the
    /// loop with `AgentError::MaxSteps`.
    pub max_steps: usize,
    /// Threshold handed to the loop detector (repeated tool signatures,
    /// dominant tool usage and repeated tool outputs).
    pub loop_abort_threshold: usize,
    /// Message-history size above which the history is trimmed at the start
    /// of a step. `0` disables trimming.
    pub max_messages: usize,
    /// Threshold handed to the completion detector (number of identical
    /// consecutive situations treated as implicit completion).
    pub auto_complete_threshold: usize,
}

impl Default for LoopConfig {
    /// Returns the stock limits: 50 steps, abort threshold 6, 80 messages,
    /// auto-complete threshold 3.
    fn default() -> Self {
        LoopConfig {
            auto_complete_threshold: 3,
            max_messages: 80,
            loop_abort_threshold: 6,
            max_steps: 50,
        }
    }
}
49
/// Progress notifications passed to the `on_event` callback of the loops.
#[derive(Debug)]
pub enum LoopEvent {
    /// A new iteration is starting; `step` is 1-based.
    StepStart {
        step: usize,
    },
    /// The decision the agent returned for the current step.
    Decision(Decision),
    /// A tool call finished. `output` is the tool's content on success, or a
    /// `"Tool error: ..."` / `"Unknown tool: ..."` string on failure.
    ToolResult {
        name: String,
        output: String,
    },
    /// The loop finished successfully after `steps` iterations.
    Completed {
        steps: usize,
    },
    /// The loop detector aborted the run; `count` is the repeat counter that
    /// is also carried by the returned `AgentError::LoopDetected`.
    LoopDetected {
        count: usize,
    },
    /// A recoverable decide error that is being fed back to the agent.
    Error(AgentError),
    /// A tool reported that it is waiting for user input.
    WaitingForInput {
        /// The tool's output content, used as the question text.
        question: String,
        /// Id of the tool call that is waiting.
        tool_call_id: String,
    },
}
74
75pub async fn run_loop(
79 agent: &dyn Agent,
80 tools: &ToolRegistry,
81 ctx: &mut AgentContext,
82 messages: &mut Vec<Message>,
83 config: &LoopConfig,
84 mut on_event: impl FnMut(LoopEvent),
85) -> Result<usize, AgentError> {
86 let mut detector = LoopDetector::new(config.loop_abort_threshold);
87 let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
88 let mut parse_retries: usize = 0;
89
90 for step in 1..=config.max_steps {
91 if config.max_messages > 0 && messages.len() > config.max_messages {
93 trim_messages(messages, config.max_messages);
94 }
95 ctx.iteration = step;
96 on_event(LoopEvent::StepStart { step });
97
98 agent.prepare_context(ctx, messages);
100
101 let active_tool_names = agent.prepare_tools(ctx, tools);
103 let filtered_tools = if active_tool_names.len() == tools.list().len() {
104 None } else {
106 Some(active_tool_names)
107 };
108
109 let effective_tools = if let Some(ref names) = filtered_tools {
111 &tools.filter(names)
112 } else {
113 tools
114 };
115
116 let decision = match agent.decide(messages, effective_tools).await {
117 Ok(d) => {
118 parse_retries = 0;
119 d
120 }
121 Err(e) if is_recoverable_error(&e) => {
122 parse_retries += 1;
123 if parse_retries > MAX_PARSE_RETRIES {
124 return Err(e);
125 }
126 let err_msg = format!(
127 "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
128 parse_retries, MAX_PARSE_RETRIES, e
129 );
130 on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
131 err_msg.clone(),
132 ))));
133 messages.push(Message::user(&err_msg));
134 continue;
135 }
136 Err(e) => return Err(e),
137 };
138 on_event(LoopEvent::Decision(decision.clone()));
139
140 if completion_detector.check(&decision) {
142 ctx.state = AgentState::Completed;
143 if !decision.situation.is_empty() {
144 messages.push(Message::assistant(&decision.situation));
145 }
146 on_event(LoopEvent::Completed { steps: step });
147 return Ok(step);
148 }
149
150 if decision.completed || decision.tool_calls.is_empty() {
151 ctx.state = AgentState::Completed;
152 if !decision.situation.is_empty() {
154 messages.push(Message::assistant(&decision.situation));
155 }
156 on_event(LoopEvent::Completed { steps: step });
157 return Ok(step);
158 }
159
160 let sig: Vec<String> = decision
162 .tool_calls
163 .iter()
164 .map(|tc| tc.name.clone())
165 .collect();
166 match detector.check(&sig) {
167 LoopCheckResult::Abort => {
168 ctx.state = AgentState::Failed;
169 on_event(LoopEvent::LoopDetected {
170 count: detector.consecutive,
171 });
172 return Err(AgentError::LoopDetected(detector.consecutive));
173 }
174 LoopCheckResult::Tier2Warning(dominant_tool) => {
175 let hint = format!(
177 "LOOP WARNING: You are repeatedly using '{}' without making progress. \
178 Try a different approach: re-read the file with read_file to see current contents, \
179 use write_file instead of edit_file, or break the problem into smaller steps.",
180 dominant_tool
181 );
182 messages.push(Message::system(&hint));
183 }
184 LoopCheckResult::Ok => {}
185 }
186
187 messages.push(Message::assistant_with_tool_calls(
189 &decision.situation,
190 decision.tool_calls.clone(),
191 ));
192
193 let mut step_outputs: Vec<String> = Vec::new();
195 let mut early_done = false;
196
197 let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
199 .tool_calls
200 .iter()
201 .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
202
203 if !ro_calls.is_empty() {
205 let futs: Vec<_> = ro_calls
206 .iter()
207 .map(|tc| {
208 let tool = tools.get(&tc.name).unwrap();
209 let args = tc.arguments.clone();
210 let name = tc.name.clone();
211 let id = tc.id.clone();
212 async move { (id, name, tool.execute_readonly(args).await) }
213 })
214 .collect();
215
216 for (id, name, result) in join_all(futs).await {
217 match result {
218 Ok(output) => {
219 on_event(LoopEvent::ToolResult {
220 name: name.clone(),
221 output: output.content.clone(),
222 });
223 step_outputs.push(output.content.clone());
224 agent.after_action(ctx, &name, &output.content);
225 if output.waiting {
226 ctx.state = AgentState::WaitingInput;
227 on_event(LoopEvent::WaitingForInput {
228 question: output.content.clone(),
229 tool_call_id: id.clone(),
230 });
231 messages.push(Message::tool(&id, "[waiting for user input]"));
232 ctx.state = AgentState::Running;
233 } else {
234 messages.push(Message::tool(&id, &output.content));
235 }
236 if output.done {
237 early_done = true;
238 }
239 }
240 Err(e) => {
241 let err_msg = format!("Tool error: {}", e);
242 step_outputs.push(err_msg.clone());
243 messages.push(Message::tool(&id, &err_msg));
244 agent.after_action(ctx, &name, &err_msg);
245 on_event(LoopEvent::ToolResult {
246 name,
247 output: err_msg,
248 });
249 }
250 }
251 }
252 if early_done && rw_calls.is_empty() {
253 ctx.state = AgentState::Completed;
255 on_event(LoopEvent::Completed { steps: step });
256 return Ok(step);
257 }
258 }
259
260 for tc in &rw_calls {
262 if let Some(tool) = tools.get(&tc.name) {
263 match tool.execute(tc.arguments.clone(), ctx).await {
264 Ok(output) => {
265 on_event(LoopEvent::ToolResult {
266 name: tc.name.clone(),
267 output: output.content.clone(),
268 });
269 step_outputs.push(output.content.clone());
270 agent.after_action(ctx, &tc.name, &output.content);
271 if output.waiting {
272 ctx.state = AgentState::WaitingInput;
273 on_event(LoopEvent::WaitingForInput {
274 question: output.content.clone(),
275 tool_call_id: tc.id.clone(),
276 });
277 messages.push(Message::tool(&tc.id, "[waiting for user input]"));
278 ctx.state = AgentState::Running;
279 } else {
280 messages.push(Message::tool(&tc.id, &output.content));
281 }
282 if output.done {
283 ctx.state = AgentState::Completed;
284 on_event(LoopEvent::Completed { steps: step });
285 return Ok(step);
286 }
287 }
288 Err(e) => {
289 let err_msg = format!("Tool error: {}", e);
290 step_outputs.push(err_msg.clone());
291 messages.push(Message::tool(&tc.id, &err_msg));
292 agent.after_action(ctx, &tc.name, &err_msg);
293 on_event(LoopEvent::ToolResult {
294 name: tc.name.clone(),
295 output: err_msg,
296 });
297 }
298 }
299 } else {
300 let err_msg = format!("Unknown tool: {}", tc.name);
301 step_outputs.push(err_msg.clone());
302 messages.push(Message::tool(&tc.id, &err_msg));
303 on_event(LoopEvent::ToolResult {
304 name: tc.name.clone(),
305 output: err_msg,
306 });
307 }
308 }
309
310 if detector.check_outputs(&step_outputs) {
312 ctx.state = AgentState::Failed;
313 on_event(LoopEvent::LoopDetected {
314 count: detector.output_repeat_count,
315 });
316 return Err(AgentError::LoopDetected(detector.output_repeat_count));
317 }
318 }
319
320 ctx.state = AgentState::Failed;
321 Err(AgentError::MaxSteps(config.max_steps))
322}
323
324pub async fn run_loop_interactive<F, Fut>(
332 agent: &dyn Agent,
333 tools: &ToolRegistry,
334 ctx: &mut AgentContext,
335 messages: &mut Vec<Message>,
336 config: &LoopConfig,
337 mut on_event: impl FnMut(LoopEvent),
338 mut on_input: F,
339) -> Result<usize, AgentError>
340where
341 F: FnMut(String) -> Fut,
342 Fut: std::future::Future<Output = String>,
343{
344 let mut detector = LoopDetector::new(config.loop_abort_threshold);
345 let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
346 let mut parse_retries: usize = 0;
347
348 for step in 1..=config.max_steps {
349 if config.max_messages > 0 && messages.len() > config.max_messages {
350 trim_messages(messages, config.max_messages);
351 }
352 ctx.iteration = step;
353 on_event(LoopEvent::StepStart { step });
354
355 agent.prepare_context(ctx, messages);
356
357 let active_tool_names = agent.prepare_tools(ctx, tools);
358 let filtered_tools = if active_tool_names.len() == tools.list().len() {
359 None
360 } else {
361 Some(active_tool_names)
362 };
363 let effective_tools = if let Some(ref names) = filtered_tools {
364 &tools.filter(names)
365 } else {
366 tools
367 };
368
369 let decision = match agent.decide(messages, effective_tools).await {
370 Ok(d) => {
371 parse_retries = 0;
372 d
373 }
374 Err(e) if is_recoverable_error(&e) => {
375 parse_retries += 1;
376 if parse_retries > MAX_PARSE_RETRIES {
377 return Err(e);
378 }
379 let err_msg = format!(
380 "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
381 parse_retries, MAX_PARSE_RETRIES, e
382 );
383 on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
384 err_msg.clone(),
385 ))));
386 messages.push(Message::user(&err_msg));
387 continue;
388 }
389 Err(e) => return Err(e),
390 };
391 on_event(LoopEvent::Decision(decision.clone()));
392
393 if completion_detector.check(&decision) {
394 ctx.state = AgentState::Completed;
395 if !decision.situation.is_empty() {
396 messages.push(Message::assistant(&decision.situation));
397 }
398 on_event(LoopEvent::Completed { steps: step });
399 return Ok(step);
400 }
401
402 if decision.completed || decision.tool_calls.is_empty() {
403 ctx.state = AgentState::Completed;
404 if !decision.situation.is_empty() {
405 messages.push(Message::assistant(&decision.situation));
406 }
407 on_event(LoopEvent::Completed { steps: step });
408 return Ok(step);
409 }
410
411 let sig: Vec<String> = decision
412 .tool_calls
413 .iter()
414 .map(|tc| tc.name.clone())
415 .collect();
416 match detector.check(&sig) {
417 LoopCheckResult::Abort => {
418 ctx.state = AgentState::Failed;
419 on_event(LoopEvent::LoopDetected {
420 count: detector.consecutive,
421 });
422 return Err(AgentError::LoopDetected(detector.consecutive));
423 }
424 LoopCheckResult::Tier2Warning(dominant_tool) => {
425 let hint = format!(
426 "LOOP WARNING: You are repeatedly using '{}' without making progress. \
427 Try a different approach: re-read the file with read_file to see current contents, \
428 use write_file instead of edit_file, or break the problem into smaller steps.",
429 dominant_tool
430 );
431 messages.push(Message::system(&hint));
432 }
433 LoopCheckResult::Ok => {}
434 }
435
436 messages.push(Message::assistant_with_tool_calls(
438 &decision.situation,
439 decision.tool_calls.clone(),
440 ));
441
442 let mut step_outputs: Vec<String> = Vec::new();
443 let mut early_done = false;
444
445 let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
447 .tool_calls
448 .iter()
449 .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
450
451 if !ro_calls.is_empty() {
453 let futs: Vec<_> = ro_calls
454 .iter()
455 .map(|tc| {
456 let tool = tools.get(&tc.name).unwrap();
457 let args = tc.arguments.clone();
458 let name = tc.name.clone();
459 let id = tc.id.clone();
460 async move { (id, name, tool.execute_readonly(args).await) }
461 })
462 .collect();
463
464 for (id, name, result) in join_all(futs).await {
465 match result {
466 Ok(output) => {
467 on_event(LoopEvent::ToolResult {
468 name: name.clone(),
469 output: output.content.clone(),
470 });
471 step_outputs.push(output.content.clone());
472 agent.after_action(ctx, &name, &output.content);
473 if output.waiting {
474 ctx.state = AgentState::WaitingInput;
475 on_event(LoopEvent::WaitingForInput {
476 question: output.content.clone(),
477 tool_call_id: id.clone(),
478 });
479 let response = on_input(output.content).await;
480 ctx.state = AgentState::Running;
481 messages.push(Message::tool(&id, &response));
482 } else {
483 messages.push(Message::tool(&id, &output.content));
484 }
485 if output.done {
486 early_done = true;
487 }
488 }
489 Err(e) => {
490 let err_msg = format!("Tool error: {}", e);
491 step_outputs.push(err_msg.clone());
492 messages.push(Message::tool(&id, &err_msg));
493 agent.after_action(ctx, &name, &err_msg);
494 on_event(LoopEvent::ToolResult {
495 name,
496 output: err_msg,
497 });
498 }
499 }
500 }
501 if early_done && rw_calls.is_empty() {
502 ctx.state = AgentState::Completed;
504 on_event(LoopEvent::Completed { steps: step });
505 return Ok(step);
506 }
507 }
508
509 for tc in &rw_calls {
511 if let Some(tool) = tools.get(&tc.name) {
512 match tool.execute(tc.arguments.clone(), ctx).await {
513 Ok(output) => {
514 on_event(LoopEvent::ToolResult {
515 name: tc.name.clone(),
516 output: output.content.clone(),
517 });
518 step_outputs.push(output.content.clone());
519 agent.after_action(ctx, &tc.name, &output.content);
520 if output.waiting {
521 ctx.state = AgentState::WaitingInput;
522 on_event(LoopEvent::WaitingForInput {
523 question: output.content.clone(),
524 tool_call_id: tc.id.clone(),
525 });
526 let response = on_input(output.content.clone()).await;
527 ctx.state = AgentState::Running;
528 messages.push(Message::tool(&tc.id, &response));
529 } else {
530 messages.push(Message::tool(&tc.id, &output.content));
531 }
532 if output.done {
533 ctx.state = AgentState::Completed;
534 on_event(LoopEvent::Completed { steps: step });
535 return Ok(step);
536 }
537 }
538 Err(e) => {
539 let err_msg = format!("Tool error: {}", e);
540 step_outputs.push(err_msg.clone());
541 messages.push(Message::tool(&tc.id, &err_msg));
542 agent.after_action(ctx, &tc.name, &err_msg);
543 on_event(LoopEvent::ToolResult {
544 name: tc.name.clone(),
545 output: err_msg,
546 });
547 }
548 }
549 } else {
550 let err_msg = format!("Unknown tool: {}", tc.name);
551 step_outputs.push(err_msg.clone());
552 messages.push(Message::tool(&tc.id, &err_msg));
553 on_event(LoopEvent::ToolResult {
554 name: tc.name.clone(),
555 output: err_msg,
556 });
557 }
558 }
559
560 if detector.check_outputs(&step_outputs) {
561 ctx.state = AgentState::Failed;
562 on_event(LoopEvent::LoopDetected {
563 count: detector.output_repeat_count,
564 });
565 return Err(AgentError::LoopDetected(detector.output_repeat_count));
566 }
567 }
568
569 ctx.state = AgentState::Failed;
570 Err(AgentError::MaxSteps(config.max_steps))
571}
572
/// Outcome of a single `LoopDetector::check` call.
#[derive(Debug, PartialEq)]
enum LoopCheckResult {
    /// No loop detected.
    Ok,
    /// One tool dominates the run; carries that tool's name. Issued at most
    /// once per detector — the next dominance hit aborts instead.
    Tier2Warning(String),
    /// The run should be aborted.
    Abort,
}

/// Detects an agent that is going in circles, using three signals:
/// identical consecutive tool signatures, a single tool dominating all
/// steps, and identical consecutive step outputs.
struct LoopDetector {
    /// Shared threshold for all three signals.
    threshold: usize,
    /// Number of consecutive `check` calls with the same signature.
    consecutive: usize,
    /// Signature seen by the most recent `check` call.
    last_sig: Vec<String>,
    /// How many times each tool name has appeared across all signatures.
    tool_freq: HashMap<String, usize>,
    /// Number of `check` calls so far (one per step, not per tool call).
    total_calls: usize,
    /// Hash of the outputs passed to the previous `check_outputs` call;
    /// `None` until the first call.
    last_output_hash: Option<u64>,
    /// Number of consecutive `check_outputs` calls with identical outputs.
    output_repeat_count: usize,
    /// Whether a `Tier2Warning` has already been issued.
    tier2_warned: bool,
}

impl LoopDetector {
    /// Creates a detector with all counters cleared.
    fn new(threshold: usize) -> Self {
        Self {
            threshold,
            consecutive: 0,
            last_sig: vec![],
            tool_freq: HashMap::new(),
            total_calls: 0,
            last_output_hash: None,
            output_repeat_count: 0,
            tier2_warned: false,
        }
    }

    /// Records the tool-name signature of one step and classifies it.
    ///
    /// Returns `Abort` when the same signature has been seen `threshold`
    /// times in a row, or when a tool dominates again after a warning was
    /// already issued. Returns `Tier2Warning(name)` the first time a tool
    /// dominates. Otherwise returns `Ok`.
    ///
    /// A tool dominates once at least `threshold` steps were recorded, it
    /// appeared at least `threshold` times, and its appearances divided by
    /// the number of steps exceed 0.9. When several tools qualify, the one
    /// with the highest count is reported; ties go to the lexicographically
    /// smallest name so the result does not depend on hash-map order.
    fn check(&mut self, sig: &[String]) -> LoopCheckResult {
        self.total_calls += 1;

        // Signal 1: exact same signature repeated back to back.
        if sig == self.last_sig {
            self.consecutive += 1;
        } else {
            self.consecutive = 1;
            self.last_sig = sig.to_vec();
        }
        if self.consecutive >= self.threshold {
            return LoopCheckResult::Abort;
        }

        // Signal 2: one tool accounts for (almost) every step.
        for name in sig {
            *self.tool_freq.entry(name.clone()).or_insert(0) += 1;
        }
        if self.total_calls < self.threshold {
            return LoopCheckResult::Ok;
        }

        let threshold = self.threshold;
        let total = self.total_calls as f64;
        let dominant = self
            .tool_freq
            .iter()
            .filter(|(_, count)| **count >= threshold && **count as f64 / total > 0.9)
            // Highest count wins; on equal counts the smaller name compares
            // as "greater" so `max_by` selects it.
            .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
            .map(|(name, _)| name.clone());

        match dominant {
            Some(_) if self.tier2_warned => LoopCheckResult::Abort,
            Some(name) => {
                self.tier2_warned = true;
                LoopCheckResult::Tier2Warning(name)
            }
            None => LoopCheckResult::Ok,
        }
    }

    /// Records the outputs of one step and returns `true` once identical
    /// outputs (compared by hash) have been seen `threshold` times in a row.
    fn check_outputs(&mut self, outputs: &[String]) -> bool {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        outputs.hash(&mut hasher);
        let hash = hasher.finish();

        if self.last_output_hash == Some(hash) {
            self.output_repeat_count += 1;
        } else {
            self.output_repeat_count = 1;
            self.last_output_hash = Some(hash);
        }

        self.output_repeat_count >= self.threshold
    }
}
671
672struct CompletionDetector {
678 threshold: usize,
679 last_situation: String,
680 repeat_count: usize,
681}
682
683const COMPLETION_KEYWORDS: &[&str] = &[
685 "task is complete",
686 "task is done",
687 "task is finished",
688 "all done",
689 "successfully completed",
690 "nothing more",
691 "no further action",
692 "no more steps",
693];
694
695impl CompletionDetector {
696 fn new(threshold: usize) -> Self {
697 Self {
698 threshold: threshold.max(2),
699 last_situation: String::new(),
700 repeat_count: 0,
701 }
702 }
703
704 fn check(&mut self, decision: &Decision) -> bool {
706 if decision.completed || decision.tool_calls.is_empty() {
708 return false;
709 }
710
711 let sit_lower = decision.situation.to_lowercase();
713 for keyword in COMPLETION_KEYWORDS {
714 if sit_lower.contains(keyword) {
715 return true;
716 }
717 }
718
719 if !decision.situation.is_empty() && decision.situation == self.last_situation {
721 self.repeat_count += 1;
722 } else {
723 self.repeat_count = 1;
724 self.last_situation = decision.situation.clone();
725 }
726
727 self.repeat_count >= self.threshold
728 }
729}
730
731fn trim_messages(messages: &mut Vec<Message>, max: usize) {
734 use crate::types::Role;
735
736 if messages.len() <= max || max < 4 {
737 return;
738 }
739 let keep_start = 2; let remove_count = messages.len() - max + 1;
741 let mut trim_end = keep_start + remove_count;
742
743 while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
749 trim_end += 1;
750 }
751 if trim_end > keep_start && trim_end < messages.len() {
758 let last_removed = trim_end - 1;
759 if messages[last_removed].role == Role::Assistant
760 && !messages[last_removed].tool_calls.is_empty()
761 {
762 while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
765 trim_end += 1;
766 }
767 }
768 }
769
770 let removed_range = keep_start..trim_end;
771
772 let summary = format!(
773 "[{} messages trimmed from context to stay within {} message limit]",
774 trim_end - keep_start,
775 max
776 );
777
778 messages.drain(removed_range);
779 messages.insert(keep_start, Message::system(&summary));
780}
781
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::{Agent, AgentError, Decision};
    use crate::agent_tool::{Tool, ToolError, ToolOutput};
    use crate::context::AgentContext;
    use crate::registry::ToolRegistry;
    use crate::types::{Message, SgrError, ToolCall};
    use serde_json::Value;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test agent that requests one `echo` call per decision for the first
    /// `max_calls` decisions and then reports completion.
    struct CountingAgent {
        max_calls: usize,
        call_count: Arc<AtomicUsize>,
    }

    #[async_trait::async_trait]
    impl Agent for CountingAgent {
        async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
            let n = self.call_count.fetch_add(1, Ordering::SeqCst);
            if n >= self.max_calls {
                Ok(Decision {
                    situation: "done".into(),
                    task: vec![],
                    tool_calls: vec![],
                    completed: true,
                })
            } else {
                Ok(Decision {
                    // A distinct situation per step keeps the completion
                    // detector from firing on repetition.
                    situation: format!("step {}", n),
                    task: vec![],
                    tool_calls: vec![ToolCall {
                        id: format!("call_{}", n),
                        name: "echo".into(),
                        arguments: serde_json::json!({"msg": "hi"}),
                    }],
                    completed: false,
                })
            }
        }
    }

    /// Minimal tool named `echo` that always returns the text "echoed".
    struct EchoTool;

    #[async_trait::async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }
        fn description(&self) -> &str {
            "echo"
        }
        fn parameters_schema(&self) -> Value {
            serde_json::json!({"type": "object"})
        }
        async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("echoed"))
        }
    }

    // Three tool steps followed by the completing decision.
    #[tokio::test]
    async fn loop_runs_and_completes() {
        let agent = CountingAgent {
            max_calls: 3,
            call_count: Arc::new(AtomicUsize::new(0)),
        };
        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig::default();

        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
            .await
            .unwrap();
        // 3 tool steps + 1 completion step.
        assert_eq!(steps, 4);
        assert_eq!(ctx.state, AgentState::Completed);
    }

    // An agent that always issues the same tool call is aborted by the loop
    // detector once the signature repeats `loop_abort_threshold` times.
    #[tokio::test]
    async fn loop_detects_repetition() {
        struct LoopingAgent;

        #[async_trait::async_trait]
        impl Agent for LoopingAgent {
            async fn decide(
                &self,
                _: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                Ok(Decision {
                    situation: "stuck".into(),
                    task: vec![],
                    tool_calls: vec![ToolCall {
                        id: "1".into(),
                        name: "echo".into(),
                        arguments: serde_json::json!({}),
                    }],
                    completed: false,
                })
            }
        }

        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig {
            max_steps: 50,
            loop_abort_threshold: 3,
            // Raised so the repeated "stuck" situation does not trigger
            // auto-completion before the loop detector runs.
            auto_complete_threshold: 100,
            ..Default::default()
        };

        let result = run_loop(
            &LoopingAgent,
            &tools,
            &mut ctx,
            &mut messages,
            &config,
            |_| {},
        )
        .await;
        assert!(matches!(result, Err(AgentError::LoopDetected(3))));
        assert_eq!(ctx.state, AgentState::Failed);
    }

    // An agent that never completes runs into the `max_steps` limit.
    #[tokio::test]
    async fn loop_max_steps() {
        struct NeverDoneAgent;

        #[async_trait::async_trait]
        impl Agent for NeverDoneAgent {
            async fn decide(
                &self,
                _: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                // A fresh tool name per call keeps signatures and outputs
                // from repeating, so only the step limit can stop the loop.
                static COUNTER: AtomicUsize = AtomicUsize::new(0);

                let n = COUNTER.fetch_add(1, Ordering::SeqCst);
                Ok(Decision {
                    situation: String::new(),
                    task: vec![],
                    tool_calls: vec![ToolCall {
                        id: format!("{}", n),
                        name: format!("tool_{}", n),
                        arguments: serde_json::json!({}),
                    }],
                    completed: false,
                })
            }
        }

        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig {
            max_steps: 5,
            loop_abort_threshold: 100,
            ..Default::default()
        };

        let result = run_loop(
            &NeverDoneAgent,
            &tools,
            &mut ctx,
            &mut messages,
            &config,
            |_| {},
        )
        .await;
        assert!(matches!(result, Err(AgentError::MaxSteps(5))));
    }

    // Same signature three times in a row aborts at threshold 3.
    #[test]
    fn loop_detector_exact_sig() {
        let mut d = LoopDetector::new(3);
        let sig = vec!["bash".to_string()];
        assert_eq!(d.check(&sig), LoopCheckResult::Ok);
        assert_eq!(d.check(&sig), LoopCheckResult::Ok);
        assert_eq!(d.check(&sig), LoopCheckResult::Abort);
    }

    // A different signature resets the consecutive counter.
    #[test]
    fn loop_detector_different_sigs_reset() {
        let mut d = LoopDetector::new(3);
        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
        // Different signature: counter restarts.
        assert_eq!(d.check(&["read".into()]), LoopCheckResult::Ok);
        assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
    }

    // A dominant tool first yields a warning, then an abort.
    #[test]
    fn loop_detector_tier2_warning_then_abort() {
        let mut d = LoopDetector::new(3);
        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok);
        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok);
        // Signature differs, but edit_file has now appeared in 3 of 3 steps.
        assert_eq!(
            d.check(&["edit_file".into(), "read_file".into()]),
            LoopCheckResult::Tier2Warning("edit_file".into())
        );
        // Already warned once: the next dominance hit aborts.
        assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Abort);
    }

    #[test]
    fn loop_config_default() {
        let c = LoopConfig::default();
        assert_eq!(c.max_steps, 50);
        assert_eq!(c.loop_abort_threshold, 6);
    }

    // Identical outputs three times in a row trip the stagnation check.
    #[test]
    fn loop_detector_output_stagnation() {
        let mut d = LoopDetector::new(3);
        let outputs = vec!["same result".to_string()];
        assert!(!d.check_outputs(&outputs));
        assert!(!d.check_outputs(&outputs));
        assert!(d.check_outputs(&outputs));
    }

    // A completion keyword in the situation fires immediately.
    #[test]
    fn completion_detector_keyword() {
        let mut cd = CompletionDetector::new(3);
        let d = Decision {
            situation: "The task is complete, all files written.".into(),
            task: vec![],
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "echo".into(),
                arguments: serde_json::json!({}),
            }],
            completed: false,
        };
        assert!(cd.check(&d));
    }

    // The same situation three times in a row fires at threshold 3.
    #[test]
    fn completion_detector_repeated_situation() {
        let mut cd = CompletionDetector::new(3);
        let d = Decision {
            situation: "working on it".into(),
            task: vec![],
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "echo".into(),
                arguments: serde_json::json!({}),
            }],
            completed: false,
        };
        assert!(!cd.check(&d));
        assert!(!cd.check(&d));
        assert!(cd.check(&d));
    }

    // Explicitly completed decisions are left to the regular completion path.
    #[test]
    fn completion_detector_ignores_explicit_completion() {
        let mut cd = CompletionDetector::new(2);
        let d = Decision {
            situation: "task is complete".into(),
            task: vec![],
            tool_calls: vec![],
            completed: true,
        };
        assert!(!cd.check(&d));
    }

    // Trimming 10 messages to a limit of 6 leaves 6, with the summary at index 2.
    #[test]
    fn trim_messages_basic() {
        let mut msgs: Vec<Message> = (0..10).map(|i| Message::user(format!("msg {i}"))).collect();
        trim_messages(&mut msgs, 6);
        assert_eq!(msgs.len(), 6);
        assert!(msgs[2].content.contains("trimmed"));
    }

    #[test]
    fn trim_messages_no_op_when_under_limit() {
        let mut msgs = vec![Message::user("a"), Message::user("b")];
        trim_messages(&mut msgs, 10);
        assert_eq!(msgs.len(), 2);
    }

    // After trimming, no tool message may be left without a preceding
    // assistant tool-call message (or another tool message).
    #[test]
    fn trim_messages_preserves_assistant_tool_call_pair() {
        use crate::types::Role;
        let mut msgs = vec![
            Message::system("sys"),
            Message::user("prompt"),
            Message::assistant_with_tool_calls(
                "calling",
                vec![
                    ToolCall {
                        id: "c1".into(),
                        name: "read".into(),
                        arguments: serde_json::json!({}),
                    },
                    ToolCall {
                        id: "c2".into(),
                        name: "read".into(),
                        arguments: serde_json::json!({}),
                    },
                ],
            ),
            Message::tool("c1", "result1"),
            Message::tool("c2", "result2"),
            Message::user("next"),
            Message::assistant("done"),
        ];
        trim_messages(&mut msgs, 5);
        for (i, msg) in msgs.iter().enumerate() {
            if msg.role == Role::Tool {
                assert!(i > 0, "Tool message at start");
                assert!(
                    msgs[i - 1].role == Role::Assistant && !msgs[i - 1].tool_calls.is_empty()
                        || msgs[i - 1].role == Role::Tool,
                    "Orphaned Tool at position {i}"
                );
            }
        }
    }

    // A different output resets the stagnation counter.
    #[test]
    fn loop_detector_output_stagnation_resets_on_change() {
        let mut d = LoopDetector::new(3);
        let a = vec!["result A".to_string()];
        let b = vec!["result B".to_string()];
        assert!(!d.check_outputs(&a));
        assert!(!d.check_outputs(&a));
        // Changed output: counter restarts.
        assert!(!d.check_outputs(&b));
        assert!(!d.check_outputs(&a));
    }

    // A non-recoverable LLM error is returned without adding feedback messages.
    #[tokio::test]
    async fn loop_handles_non_recoverable_llm_error() {
        struct FailingAgent;
        #[async_trait::async_trait]
        impl Agent for FailingAgent {
            async fn decide(
                &self,
                _: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                Err(AgentError::Llm(SgrError::Api {
                    status: 500,
                    body: "internal server error".into(),
                }))
            }
        }

        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig::default();

        let result = run_loop(
            &FailingAgent,
            &tools,
            &mut ctx,
            &mut messages,
            &config,
            |_| {},
        )
        .await;
        assert!(result.is_err());
        // Only the original user message remains.
        assert_eq!(messages.len(), 1);
    }

    // A recoverable error is fed back as a "Parse error" message and the
    // agent gets another chance on the next step.
    #[tokio::test]
    async fn loop_recovers_from_parse_error() {
        struct ParseRetryAgent {
            call_count: Arc<AtomicUsize>,
        }
        #[async_trait::async_trait]
        impl Agent for ParseRetryAgent {
            async fn decide(
                &self,
                msgs: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    // First call fails with a recoverable schema error.
                    Err(AgentError::Llm(SgrError::Schema(
                        "Missing required field: situation".into(),
                    )))
                } else {
                    // Second call must see the feedback as the last message.
                    let last = msgs.last().unwrap();
                    assert!(
                        last.content.contains("Parse error"),
                        "expected parse error feedback, got: {}",
                        last.content
                    );
                    Ok(Decision {
                        situation: "recovered from parse error".into(),
                        task: vec![],
                        tool_calls: vec![],
                        completed: true,
                    })
                }
            }
        }

        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig::default();
        let agent = ParseRetryAgent {
            call_count: Arc::new(AtomicUsize::new(0)),
        };

        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
            .await
            .unwrap();
        // The failed attempt consumes a step as well.
        assert_eq!(steps, 2);
        assert_eq!(ctx.state, AgentState::Completed);
    }

    // Persistent parse failures end the run after MAX_PARSE_RETRIES feedback messages.
    #[tokio::test]
    async fn loop_aborts_after_max_parse_retries() {
        struct AlwaysFailParseAgent;
        #[async_trait::async_trait]
        impl Agent for AlwaysFailParseAgent {
            async fn decide(
                &self,
                _: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                Err(AgentError::Llm(SgrError::Schema("bad json".into())))
            }
        }

        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig::default();

        let result = run_loop(
            &AlwaysFailParseAgent,
            &tools,
            &mut ctx,
            &mut messages,
            &config,
            |_| {},
        )
        .await;
        assert!(result.is_err());
        let feedback_count = messages
            .iter()
            .filter(|m| m.content.contains("Parse error"))
            .count();
        assert_eq!(feedback_count, MAX_PARSE_RETRIES);
    }

    // A call to an unregistered tool is reported back to the agent as an
    // "Unknown tool" tool message instead of failing the run.
    #[tokio::test]
    async fn loop_feeds_tool_errors_back() {
        struct ErrorRecoveryAgent {
            call_count: Arc<AtomicUsize>,
        }
        #[async_trait::async_trait]
        impl Agent for ErrorRecoveryAgent {
            async fn decide(
                &self,
                msgs: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                let n = self.call_count.fetch_add(1, Ordering::SeqCst);
                if n == 0 {
                    // First decision calls a tool that is not registered.
                    Ok(Decision {
                        situation: "trying".into(),
                        task: vec![],
                        tool_calls: vec![ToolCall {
                            id: "1".into(),
                            name: "nonexistent_tool".into(),
                            arguments: serde_json::json!({}),
                        }],
                        completed: false,
                    })
                } else {
                    // The error must be visible as the last message.
                    let last = msgs.last().unwrap();
                    assert!(last.content.contains("Unknown tool"));
                    Ok(Decision {
                        situation: "recovered".into(),
                        task: vec![],
                        tool_calls: vec![],
                        completed: true,
                    })
                }
            }
        }

        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig::default();
        let agent = ErrorRecoveryAgent {
            call_count: Arc::new(AtomicUsize::new(0)),
        };

        let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
            .await
            .unwrap();
        assert_eq!(steps, 2);
        assert_eq!(ctx.state, AgentState::Completed);
    }

    // Two read-only tool calls in one decision go through the concurrent path.
    #[tokio::test]
    async fn parallel_readonly_tools() {
        struct ReadOnlyTool {
            name: &'static str,
        }

        #[async_trait::async_trait]
        impl Tool for ReadOnlyTool {
            fn name(&self) -> &str {
                self.name
            }
            fn description(&self) -> &str {
                "read-only tool"
            }
            fn is_read_only(&self) -> bool {
                true
            }
            fn parameters_schema(&self) -> Value {
                serde_json::json!({"type": "object"})
            }
            async fn execute(
                &self,
                _: Value,
                _: &mut AgentContext,
            ) -> Result<ToolOutput, ToolError> {
                Ok(ToolOutput::text(format!("{} result", self.name)))
            }
            async fn execute_readonly(&self, _: Value) -> Result<ToolOutput, ToolError> {
                Ok(ToolOutput::text(format!("{} result", self.name)))
            }
        }

        struct ParallelAgent;
        #[async_trait::async_trait]
        impl Agent for ParallelAgent {
            async fn decide(
                &self,
                msgs: &[Message],
                _: &ToolRegistry,
            ) -> Result<Decision, AgentError> {
                // Once the tool results are in the history, finish.
                if msgs.len() > 3 {
                    return Ok(Decision {
                        situation: "done".into(),
                        task: vec![],
                        tool_calls: vec![],
                        completed: true,
                    });
                }
                Ok(Decision {
                    situation: "reading".into(),
                    task: vec![],
                    tool_calls: vec![
                        ToolCall {
                            id: "1".into(),
                            name: "reader_a".into(),
                            arguments: serde_json::json!({}),
                        },
                        ToolCall {
                            id: "2".into(),
                            name: "reader_b".into(),
                            arguments: serde_json::json!({}),
                        },
                    ],
                    completed: false,
                })
            }
        }

        let tools = ToolRegistry::new()
            .register(ReadOnlyTool { name: "reader_a" })
            .register(ReadOnlyTool { name: "reader_b" });
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("read stuff")];
        let config = LoopConfig::default();

        let steps = run_loop(
            &ParallelAgent,
            &tools,
            &mut ctx,
            &mut messages,
            &config,
            |_| {},
        )
        .await
        .unwrap();
        assert!(steps > 0);
        assert_eq!(ctx.state, AgentState::Completed);
    }

    // The event callback is invoked for the steps of a short run.
    #[tokio::test]
    async fn loop_events_are_emitted() {
        let agent = CountingAgent {
            max_calls: 1,
            call_count: Arc::new(AtomicUsize::new(0)),
        };
        let tools = ToolRegistry::new().register(EchoTool);
        let mut ctx = AgentContext::new();
        let mut messages = vec![Message::user("go")];
        let config = LoopConfig::default();

        let mut events = Vec::new();
        run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |e| {
            events.push(format!("{:?}", std::mem::discriminant(&e)));
        })
        .await
        .unwrap();

        // At least: StepStart, Decision, ToolResult, StepStart (more follow).
        assert!(events.len() >= 4);
    }
}