1use crate::agent::{Agent, AgentError, Decision};
10use crate::context::{AgentContext, AgentState};
11use crate::registry::ToolRegistry;
12use crate::retry::{RetryConfig, delay_for_attempt, is_retryable};
13use crate::types::{Message, Role, SgrError};
14use futures::future::join_all;
15use std::collections::HashMap;
16
/// Maximum number of consecutive recoverable decision errors (malformed JSON,
/// empty response, schema violation — see `is_recoverable_error`) that
/// `run_loop_interactive` feeds back to the model as a corrective user message
/// before giving up and returning the error. Reset after a successful decision.
const MAX_PARSE_RETRIES: usize = 3;

/// Maximum number of retries `decide_with_retry` performs for errors that
/// `is_retryable` accepts, sleeping `delay_for_attempt` between attempts.
const MAX_TRANSIENT_RETRIES: usize = 3;

/// Maximum number of consecutive `MaxOutputTokens` recoveries (asking the model
/// to resume a truncated response) before the error is returned to the caller.
/// Reset after a successful decision.
const MAX_OUTPUT_TOKENS_RECOVERIES: usize = 3;
25
26fn is_recoverable_error(e: &AgentError) -> bool {
28 matches!(
29 e,
30 AgentError::Llm(SgrError::Json(_))
31 | AgentError::Llm(SgrError::EmptyResponse)
32 | AgentError::Llm(SgrError::Schema(_))
33 )
34}
35
36async fn decide_with_retry(
39 agent: &dyn Agent,
40 messages: &[Message],
41 tools: &ToolRegistry,
42 previous_response_id: Option<&str>,
43) -> Result<(Decision, Option<String>), AgentError> {
44 let retry_config = RetryConfig {
45 max_retries: MAX_TRANSIENT_RETRIES,
46 base_delay_ms: 500,
47 max_delay_ms: 30_000,
48 };
49
50 for attempt in 0..=retry_config.max_retries {
51 match agent
52 .decide_stateful(messages, tools, previous_response_id)
53 .await
54 {
55 Ok(d) => return Ok(d),
56 Err(AgentError::Llm(sgr_err))
57 if is_retryable(&sgr_err) && attempt < retry_config.max_retries =>
58 {
59 let delay = delay_for_attempt(attempt, &retry_config, &sgr_err);
60 tracing::warn!(
61 attempt = attempt + 1,
62 max = retry_config.max_retries,
63 delay_ms = delay.as_millis() as u64,
64 "Retrying agent.decide(): {}",
65 sgr_err
66 );
67 tokio::time::sleep(delay).await;
68 }
70 Err(e) => return Err(e),
71 }
72 }
73 agent
75 .decide_stateful(messages, tools, previous_response_id)
76 .await
77}
78
79pub fn ensure_tool_result_pairing(messages: &mut Vec<Message>) {
88 let mut expected_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
90 for msg in messages.iter() {
91 if msg.role == Role::Assistant {
92 for tc in &msg.tool_calls {
93 expected_ids.insert(tc.id.clone());
94 }
95 }
96 }
97
98 let mut seen_result_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
100 let mut to_remove: Vec<usize> = Vec::new();
102
103 for (i, msg) in messages.iter().enumerate() {
104 if msg.role == Role::Tool
105 && let Some(ref id) = msg.tool_call_id
106 {
107 if !seen_result_ids.insert(id.clone()) {
108 to_remove.push(i);
110 } else if !expected_ids.contains(id) {
111 to_remove.push(i);
113 }
114 }
115 }
116
117 for i in to_remove.into_iter().rev() {
119 tracing::debug!(
120 tool_call_id = messages[i].tool_call_id.as_deref().unwrap_or("?"),
121 "Removing orphaned/duplicate tool_result"
122 );
123 messages.remove(i);
124 }
125
126 let mut i = 0;
129 while i < messages.len() {
130 if messages[i].role == Role::Assistant && !messages[i].tool_calls.is_empty() {
131 let tool_call_ids: Vec<String> = messages[i]
132 .tool_calls
133 .iter()
134 .map(|tc| tc.id.clone())
135 .collect();
136
137 let mut insert_pos = i + 1;
139 let mut found_ids: std::collections::HashSet<String> = std::collections::HashSet::new();
140 while insert_pos < messages.len() && messages[insert_pos].role == Role::Tool {
141 if let Some(ref id) = messages[insert_pos].tool_call_id {
142 found_ids.insert(id.clone());
143 }
144 insert_pos += 1;
145 }
146
147 for id in &tool_call_ids {
149 if !found_ids.contains(id) {
150 tracing::debug!(
151 tool_call_id = id.as_str(),
152 "Inserting synthetic tool_result for orphaned tool_use"
153 );
154 messages.insert(
155 insert_pos,
156 Message::tool(id, "[Tool result missing due to internal error]"),
157 );
158 insert_pos += 1;
159 }
160 }
161 i = insert_pos;
162 } else {
163 i += 1;
164 }
165 }
166}
167
168fn apply_context_modifier(
170 modifier: &crate::agent_tool::ContextModifier,
171 ctx: &mut AgentContext,
172 messages: &mut Vec<Message>,
173 effective_max_steps: &mut usize,
174) {
175 if let Some(ref injection) = modifier.system_injection {
176 messages.push(Message::user(format!("[Context update]: {injection}")));
179 }
180 for (key, value) in &modifier.custom_context {
181 ctx.set(key.clone(), value.clone());
182 }
183 if let Some(delta) = modifier.max_steps_delta {
184 if delta > 0 {
185 *effective_max_steps = effective_max_steps.saturating_add(delta as usize);
186 } else {
187 *effective_max_steps =
188 effective_max_steps.saturating_sub(delta.unsigned_abs() as usize);
189 }
190 }
191 if let Some(tokens) = modifier.max_tokens_override {
192 ctx.set(
193 crate::agent_tool::MAX_TOKENS_OVERRIDE_KEY.to_string(),
194 serde_json::Value::Number(tokens.into()),
195 );
196 }
197}
198
/// Tuning knobs for `run_loop` / `run_loop_interactive`.
#[derive(Debug, Clone)]
pub struct LoopConfig {
    /// Initial step budget; tools may adjust it at runtime via `max_steps_delta`.
    pub max_steps: usize,
    /// Repetition threshold handed to the loop detector (consecutive identical
    /// tool signatures, tool dominance, and repeated outputs).
    pub loop_abort_threshold: usize,
    /// When non-zero, history longer than this is trimmed at the start of each step.
    pub max_messages: usize,
    /// Number of identical consecutive situations after which the completion
    /// detector ends the run (the detector clamps it to at least 2).
    pub auto_complete_threshold: usize,
}

impl Default for LoopConfig {
    /// 50 steps, loop abort after 6 repeats, at most 80 messages, and
    /// auto-completion after 3 identical situations.
    fn default() -> Self {
        Self {
            auto_complete_threshold: 3,
            max_messages: 80,
            loop_abort_threshold: 6,
            max_steps: 50,
        }
    }
}
223
/// Progress notifications emitted by `run_loop_interactive` through its
/// `on_event` callback.
#[derive(Debug)]
pub enum LoopEvent {
    /// A new step is starting; `step` is 1-based.
    StepStart {
        step: usize,
    },
    /// The agent produced a decision for the current step.
    Decision(Decision),
    /// A tool call finished. `output` is its content, or a `Tool error: …` /
    /// `Unknown tool: …` message.
    ToolResult {
        name: String,
        output: String,
    },
    /// The run completed successfully on step `steps`.
    Completed {
        steps: usize,
    },
    /// The loop detector aborted the run; `count` is the repeat count that tripped it.
    LoopDetected {
        count: usize,
    },
    /// A non-fatal error. In this file it is emitted for parse/schema failures
    /// that are fed back to the model for another attempt.
    Error(AgentError),
    /// A tool asked for user input. `question` is the tool output handed to
    /// `on_input`; `tool_call_id` identifies the call awaiting the answer.
    WaitingForInput {
        question: String,
        tool_call_id: String,
    },
    /// The model's output was truncated and a continuation was requested;
    /// `attempt` counts consecutive recoveries starting at 1.
    MaxOutputTokensRecovery {
        attempt: usize,
    },
    /// The provider rejected the prompt as too long; the run ends with that error.
    PromptTooLong {
        message: String,
    },
    /// A context modifier returned by tool `tool_name` was applied.
    ContextModified {
        tool_name: String,
    },
}
260
261pub async fn run_loop(
268 agent: &dyn Agent,
269 tools: &ToolRegistry,
270 ctx: &mut AgentContext,
271 messages: &mut Vec<Message>,
272 config: &LoopConfig,
273 on_event: impl FnMut(LoopEvent),
274) -> Result<usize, AgentError> {
275 run_loop_interactive(
278 agent,
279 tools,
280 ctx,
281 messages,
282 config,
283 on_event,
284 |_question: String| async { "[waiting for user input]".to_string() },
285 )
286 .await
287}
288
289pub async fn run_loop_interactive<F, Fut>(
294 agent: &dyn Agent,
295 tools: &ToolRegistry,
296 ctx: &mut AgentContext,
297 messages: &mut Vec<Message>,
298 config: &LoopConfig,
299 mut on_event: impl FnMut(LoopEvent),
300 mut on_input: F,
301) -> Result<usize, AgentError>
302where
303 F: FnMut(String) -> Fut,
304 Fut: std::future::Future<Output = String>,
305{
306 let mut detector = LoopDetector::new(config.loop_abort_threshold);
307 let mut completion_detector = CompletionDetector::new(config.auto_complete_threshold);
308 let mut parse_retries: usize = 0;
309 let mut response_id: Option<String> = None;
310 let mut max_output_tokens_recoveries: usize = 0;
311 let mut effective_max_steps = config.max_steps;
312
313 let mut step = 0;
314 while {
315 step += 1;
316 step <= effective_max_steps
317 } {
318 if config.max_messages > 0 && messages.len() > config.max_messages {
319 trim_messages(messages, config.max_messages);
320 }
321
322 ensure_tool_result_pairing(messages);
324
325 ctx.iteration = step;
326 on_event(LoopEvent::StepStart { step });
327
328 agent.prepare_context(ctx, messages);
329
330 let active_tool_names = agent.prepare_tools(ctx, tools);
331 let filtered_tools = if active_tool_names.len() == tools.list().len() {
332 None
333 } else {
334 Some(active_tool_names)
335 };
336 let effective_tools = if let Some(ref names) = filtered_tools {
337 &tools.filter(names)
338 } else {
339 tools
340 };
341
342 let decision = match decide_with_retry(
343 agent,
344 messages,
345 effective_tools,
346 response_id.as_deref(),
347 )
348 .await
349 {
350 Ok((d, new_rid)) => {
351 parse_retries = 0;
352 max_output_tokens_recoveries = 0;
353 response_id = new_rid;
354 d
355 }
356 Err(AgentError::Llm(SgrError::MaxOutputTokens { partial_content })) => {
357 max_output_tokens_recoveries += 1;
358 if max_output_tokens_recoveries > MAX_OUTPUT_TOKENS_RECOVERIES {
359 return Err(AgentError::Llm(SgrError::MaxOutputTokens {
360 partial_content,
361 }));
362 }
363 if !partial_content.is_empty() {
364 messages.push(Message::assistant(&partial_content));
365 }
366 messages.push(Message::user(
367 "Your response was cut off. Resume directly from where you stopped. \
368 No apology, no recap — pick up mid-thought.",
369 ));
370 on_event(LoopEvent::MaxOutputTokensRecovery {
371 attempt: max_output_tokens_recoveries,
372 });
373 continue;
374 }
375 Err(AgentError::Llm(SgrError::PromptTooLong(msg))) => {
376 on_event(LoopEvent::PromptTooLong {
377 message: msg.clone(),
378 });
379 return Err(AgentError::Llm(SgrError::PromptTooLong(msg)));
380 }
381 Err(e) if is_recoverable_error(&e) => {
382 parse_retries += 1;
383 if parse_retries > MAX_PARSE_RETRIES {
384 return Err(e);
385 }
386 let err_msg = format!(
387 "Parse error (attempt {}/{}): {}. Please respond with valid JSON matching the schema.",
388 parse_retries, MAX_PARSE_RETRIES, e
389 );
390 on_event(LoopEvent::Error(AgentError::Llm(SgrError::Schema(
391 err_msg.clone(),
392 ))));
393 messages.push(Message::user(&err_msg));
394 continue;
395 }
396 Err(e) => return Err(e),
397 };
398 on_event(LoopEvent::Decision(decision.clone()));
399
400 if completion_detector.check(&decision) {
401 ctx.state = AgentState::Completed;
402 if !decision.situation.is_empty() {
403 messages.push(Message::assistant(&decision.situation));
404 }
405 on_event(LoopEvent::Completed { steps: step });
406 return Ok(step);
407 }
408
409 if decision.completed || decision.tool_calls.is_empty() {
410 ctx.state = AgentState::Completed;
411 if !decision.situation.is_empty() {
412 messages.push(Message::assistant(&decision.situation));
413 }
414 on_event(LoopEvent::Completed { steps: step });
415 return Ok(step);
416 }
417
418 let sig: Vec<String> = decision
419 .tool_calls
420 .iter()
421 .map(|tc| tc.name.clone())
422 .collect();
423 match detector.check(&sig) {
424 LoopCheckResult::Abort => {
425 ctx.state = AgentState::Failed;
426 on_event(LoopEvent::LoopDetected {
427 count: detector.consecutive,
428 });
429 return Err(AgentError::LoopDetected(detector.consecutive));
430 }
431 LoopCheckResult::Tier2Warning(dominant_tool) => {
432 let hint = format!(
433 "LOOP WARNING: You are repeatedly using '{}' without making progress. \
434 Try a different approach: re-read the file with read_file to see current contents, \
435 use write_file instead of edit_file, or break the problem into smaller steps.",
436 dominant_tool
437 );
438 messages.push(Message::system(&hint));
439 }
440 LoopCheckResult::Ok => {}
441 }
442
443 messages.push(Message::assistant_with_tool_calls(
445 &decision.situation,
446 decision.tool_calls.clone(),
447 ));
448
449 let mut step_outputs: Vec<String> = Vec::new();
450 let mut early_done = false;
451
452 let (ro_calls, rw_calls): (Vec<_>, Vec<_>) = decision
454 .tool_calls
455 .iter()
456 .partition(|tc| tools.get(&tc.name).is_some_and(|t| t.is_read_only()));
457
458 if !ro_calls.is_empty() {
460 let ctx_snapshot = ctx.clone(); let futs: Vec<_> = ro_calls
462 .iter()
463 .map(|tc| {
464 let tool = tools.get(&tc.name).unwrap();
465 let args = tc.arguments.clone();
466 let name = tc.name.clone();
467 let id = tc.id.clone();
468 let ctx_ref = &ctx_snapshot;
469 async move { (id, name, tool.execute_readonly(args, ctx_ref).await) }
470 })
471 .collect();
472
473 let mut pending_modifiers: Vec<(String, crate::agent_tool::ContextModifier)> =
474 Vec::new();
475
476 for (id, name, result) in join_all(futs).await {
477 match result {
478 Ok(output) => {
479 on_event(LoopEvent::ToolResult {
480 name: name.clone(),
481 output: output.content.clone(),
482 });
483 step_outputs.push(output.content.clone());
484 agent.after_action(ctx, &name, &output.content);
485 if let Some(modifier) = output.modifier.clone()
486 && !modifier.is_empty()
487 {
488 pending_modifiers.push((name.clone(), modifier));
489 }
490 if output.waiting {
491 ctx.state = AgentState::WaitingInput;
492 on_event(LoopEvent::WaitingForInput {
493 question: output.content.clone(),
494 tool_call_id: id.clone(),
495 });
496 let response = on_input(output.content).await;
497 ctx.state = AgentState::Running;
498 messages.push(Message::tool(&id, &response));
499 } else {
500 messages.push(Message::tool(&id, &output.content));
501 }
502 if output.done {
503 early_done = true;
504 }
505 }
506 Err(e) => {
507 let err_msg = format!("Tool error: {}", e);
508 step_outputs.push(err_msg.clone());
509 messages.push(Message::tool(&id, &err_msg));
510 agent.after_action(ctx, &name, &err_msg);
511 on_event(LoopEvent::ToolResult {
512 name,
513 output: err_msg,
514 });
515 }
516 }
517 }
518
519 for (name, modifier) in pending_modifiers {
520 apply_context_modifier(&modifier, ctx, messages, &mut effective_max_steps);
521 on_event(LoopEvent::ContextModified { tool_name: name });
522 }
523
524 if early_done && rw_calls.is_empty() {
525 ctx.state = AgentState::Completed;
526 on_event(LoopEvent::Completed { steps: step });
527 return Ok(step);
528 }
529 }
530
531 for tc in &rw_calls {
533 if let Some(tool) = tools.get(&tc.name) {
534 match tool.execute(tc.arguments.clone(), ctx).await {
535 Ok(output) => {
536 on_event(LoopEvent::ToolResult {
537 name: tc.name.clone(),
538 output: output.content.clone(),
539 });
540 step_outputs.push(output.content.clone());
541 agent.after_action(ctx, &tc.name, &output.content);
542 if let Some(ref modifier) = output.modifier
543 && !modifier.is_empty()
544 {
545 apply_context_modifier(
546 modifier,
547 ctx,
548 messages,
549 &mut effective_max_steps,
550 );
551 on_event(LoopEvent::ContextModified {
552 tool_name: tc.name.clone(),
553 });
554 }
555 if output.waiting {
556 ctx.state = AgentState::WaitingInput;
557 on_event(LoopEvent::WaitingForInput {
558 question: output.content.clone(),
559 tool_call_id: tc.id.clone(),
560 });
561 let response = on_input(output.content.clone()).await;
562 ctx.state = AgentState::Running;
563 messages.push(Message::tool(&tc.id, &response));
564 } else {
565 messages.push(Message::tool(&tc.id, &output.content));
566 }
567 if output.done {
568 ctx.state = AgentState::Completed;
569 on_event(LoopEvent::Completed { steps: step });
570 return Ok(step);
571 }
572 }
573 Err(e) => {
574 let err_msg = format!("Tool error: {}", e);
575 step_outputs.push(err_msg.clone());
576 messages.push(Message::tool(&tc.id, &err_msg));
577 agent.after_action(ctx, &tc.name, &err_msg);
578 on_event(LoopEvent::ToolResult {
579 name: tc.name.clone(),
580 output: err_msg,
581 });
582 }
583 }
584 } else {
585 let err_msg = format!("Unknown tool: {}", tc.name);
586 step_outputs.push(err_msg.clone());
587 messages.push(Message::tool(&tc.id, &err_msg));
588 on_event(LoopEvent::ToolResult {
589 name: tc.name.clone(),
590 output: err_msg,
591 });
592 }
593 }
594
595 if detector.check_outputs(&step_outputs) {
596 ctx.state = AgentState::Failed;
597 on_event(LoopEvent::LoopDetected {
598 count: detector.output_repeat_count,
599 });
600 return Err(AgentError::LoopDetected(detector.output_repeat_count));
601 }
602 }
603
604 ctx.state = AgentState::Failed;
605 Err(AgentError::MaxSteps(effective_max_steps))
606}
607
/// Verdict of `LoopDetector::check` for one step's tool-call signature.
#[derive(Debug, PartialEq)]
enum LoopCheckResult {
    /// No loop suspected.
    Ok,
    /// First time a single tool dominates the history. Carries that tool's name
    /// so the caller can nudge the model toward a different approach.
    Tier2Warning(String),
    /// The agent is looping and the run should be aborted.
    Abort,
}

/// Detects an agent that is stuck repeating itself.
///
/// Three signals are tracked, all compared against the same `threshold`:
/// 1. The same tool-name signature on `threshold` consecutive steps aborts.
/// 2. A single tool seen at least `threshold` times, and in more than 90% of the
///    checked steps, warns once and aborts the next time the condition holds.
/// 3. Identical tool outputs on `threshold` consecutive steps are reported by
///    `check_outputs`.
struct LoopDetector {
    threshold: usize,
    /// Consecutive steps whose signature equalled `last_sig`.
    consecutive: usize,
    last_sig: Vec<String>,
    /// Occurrences of each tool name across all checked signatures.
    tool_freq: HashMap<String, usize>,
    /// Number of `check` calls so far (one per step, not per tool call).
    total_calls: usize,
    /// Hash of the previous step's outputs; `None` before the first step.
    last_output_hash: Option<u64>,
    /// Consecutive steps whose outputs hashed to `last_output_hash`.
    output_repeat_count: usize,
    /// Whether a `Tier2Warning` has already been issued.
    tier2_warned: bool,
}

impl LoopDetector {
    fn new(threshold: usize) -> Self {
        Self {
            threshold,
            consecutive: 0,
            last_sig: vec![],
            tool_freq: HashMap::new(),
            total_calls: 0,
            last_output_hash: None,
            output_repeat_count: 0,
            tier2_warned: false,
        }
    }

    /// Records one step's tool-name signature and returns the loop verdict.
    fn check(&mut self, sig: &[String]) -> LoopCheckResult {
        self.total_calls += 1;

        // Tier 1: exact repetition of the whole signature.
        if sig == self.last_sig {
            self.consecutive += 1;
        } else {
            self.consecutive = 1;
            self.last_sig = sig.to_vec();
        }
        if self.consecutive >= self.threshold {
            return LoopCheckResult::Abort;
        }

        // Tier 2: one tool dominating the whole history.
        for name in sig {
            *self.tool_freq.entry(name.clone()).or_insert(0) += 1;
        }
        if self.total_calls < self.threshold {
            return LoopCheckResult::Ok;
        }

        // Pick the most frequent qualifying tool, with ties broken by the
        // smallest name, so the reported tool does not depend on `HashMap`
        // iteration order.
        let threshold = self.threshold;
        let total = self.total_calls as f64;
        let dominant = self
            .tool_freq
            .iter()
            .filter(|&(_, &count)| count >= threshold && count as f64 / total > 0.9)
            .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
            .map(|(name, _)| name.clone());

        match dominant {
            Some(_) if self.tier2_warned => LoopCheckResult::Abort,
            Some(name) => {
                self.tier2_warned = true;
                LoopCheckResult::Tier2Warning(name)
            }
            None => LoopCheckResult::Ok,
        }
    }

    /// Records one step's tool outputs. Returns `true` once identical outputs
    /// have been seen on `threshold` consecutive steps.
    fn check_outputs(&mut self, outputs: &[String]) -> bool {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        outputs.hash(&mut hasher);
        let hash = hasher.finish();

        if self.last_output_hash == Some(hash) {
            self.output_repeat_count += 1;
        } else {
            self.output_repeat_count = 1;
            self.last_output_hash = Some(hash);
        }

        self.output_repeat_count >= self.threshold
    }
}
706
707struct CompletionDetector {
713 threshold: usize,
714 last_situation: String,
715 repeat_count: usize,
716}
717
718const COMPLETION_KEYWORDS: &[&str] = &[
720 "task is complete",
721 "task is done",
722 "task is finished",
723 "all done",
724 "successfully completed",
725 "nothing more",
726 "no further action",
727 "no more steps",
728];
729
730impl CompletionDetector {
731 fn new(threshold: usize) -> Self {
732 Self {
733 threshold: threshold.max(2),
734 last_situation: String::new(),
735 repeat_count: 0,
736 }
737 }
738
739 fn check(&mut self, decision: &Decision) -> bool {
741 if decision.completed || decision.tool_calls.is_empty() {
743 return false;
744 }
745
746 let sit_lower = decision.situation.to_lowercase();
748 for keyword in COMPLETION_KEYWORDS {
749 if sit_lower.contains(keyword) {
750 return true;
751 }
752 }
753
754 if !decision.situation.is_empty() && decision.situation == self.last_situation {
756 self.repeat_count += 1;
757 } else {
758 self.repeat_count = 1;
759 self.last_situation = decision.situation.clone();
760 }
761
762 self.repeat_count >= self.threshold
763 }
764}
765
766fn trim_messages(messages: &mut Vec<Message>, max: usize) {
769 if messages.len() <= max || max < 4 {
770 return;
771 }
772 let keep_start = 2; let remove_count = messages.len() - max + 1;
774 let mut trim_end = keep_start + remove_count;
775
776 while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
782 trim_end += 1;
783 }
784 if trim_end > keep_start && trim_end < messages.len() {
791 let last_removed = trim_end - 1;
792 if messages[last_removed].role == Role::Assistant
793 && !messages[last_removed].tool_calls.is_empty()
794 {
795 while trim_end < messages.len() && messages[trim_end].role == Role::Tool {
798 trim_end += 1;
799 }
800 }
801 }
802
803 let removed_range = keep_start..trim_end;
804
805 let summary = format!(
806 "[{} messages trimmed from context to stay within {} message limit]",
807 trim_end - keep_start,
808 max
809 );
810
811 messages.drain(removed_range);
812 messages.insert(keep_start, Message::system(&summary));
813}
814
815#[cfg(test)]
816mod tests {
817 use super::*;
818 use crate::agent::{Agent, AgentError, Decision};
819 use crate::agent_tool::{Tool, ToolError, ToolOutput};
820 use crate::context::AgentContext;
821 use crate::registry::ToolRegistry;
822 use crate::types::{Message, SgrError, ToolCall};
823 use serde_json::Value;
824 use std::sync::Arc;
825 use std::sync::atomic::{AtomicUsize, Ordering};
826
827 struct CountingAgent {
828 max_calls: usize,
829 call_count: Arc<AtomicUsize>,
830 }
831
832 #[async_trait::async_trait]
833 impl Agent for CountingAgent {
834 async fn decide(&self, _: &[Message], _: &ToolRegistry) -> Result<Decision, AgentError> {
835 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
836 if n >= self.max_calls {
837 Ok(Decision {
838 situation: "done".into(),
839 task: vec![],
840 tool_calls: vec![],
841 completed: true,
842 })
843 } else {
844 Ok(Decision {
845 situation: format!("step {}", n),
846 task: vec![],
847 tool_calls: vec![ToolCall {
848 id: format!("call_{}", n),
849 name: "echo".into(),
850 arguments: serde_json::json!({"msg": "hi"}),
851 }],
852 completed: false,
853 })
854 }
855 }
856 }
857
858 struct EchoTool;
859
860 #[async_trait::async_trait]
861 impl Tool for EchoTool {
862 fn name(&self) -> &str {
863 "echo"
864 }
865 fn description(&self) -> &str {
866 "echo"
867 }
868 fn parameters_schema(&self) -> Value {
869 serde_json::json!({"type": "object"})
870 }
871 async fn execute(&self, _: Value, _: &mut AgentContext) -> Result<ToolOutput, ToolError> {
872 Ok(ToolOutput::text("echoed"))
873 }
874 }
875
876 #[tokio::test]
877 async fn loop_runs_and_completes() {
878 let agent = CountingAgent {
879 max_calls: 3,
880 call_count: Arc::new(AtomicUsize::new(0)),
881 };
882 let tools = ToolRegistry::new().register(EchoTool);
883 let mut ctx = AgentContext::new();
884 let mut messages = vec![Message::user("go")];
885 let config = LoopConfig::default();
886
887 let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
888 .await
889 .unwrap();
890 assert_eq!(steps, 4); assert_eq!(ctx.state, AgentState::Completed);
892 }
893
894 #[tokio::test]
895 async fn loop_detects_repetition() {
896 struct LoopingAgent;
898 #[async_trait::async_trait]
899 impl Agent for LoopingAgent {
900 async fn decide(
901 &self,
902 _: &[Message],
903 _: &ToolRegistry,
904 ) -> Result<Decision, AgentError> {
905 Ok(Decision {
906 situation: "stuck".into(),
907 task: vec![],
908 tool_calls: vec![ToolCall {
909 id: "1".into(),
910 name: "echo".into(),
911 arguments: serde_json::json!({}),
912 }],
913 completed: false,
914 })
915 }
916 }
917
918 let tools = ToolRegistry::new().register(EchoTool);
919 let mut ctx = AgentContext::new();
920 let mut messages = vec![Message::user("go")];
921 let config = LoopConfig {
922 max_steps: 50,
923 loop_abort_threshold: 3,
924 auto_complete_threshold: 100, ..Default::default()
926 };
927
928 let result = run_loop(
929 &LoopingAgent,
930 &tools,
931 &mut ctx,
932 &mut messages,
933 &config,
934 |_| {},
935 )
936 .await;
937 assert!(matches!(result, Err(AgentError::LoopDetected(3))));
938 assert_eq!(ctx.state, AgentState::Failed);
939 }
940
941 #[tokio::test]
942 async fn loop_max_steps() {
943 struct NeverDoneAgent;
945 #[async_trait::async_trait]
946 impl Agent for NeverDoneAgent {
947 async fn decide(
948 &self,
949 _: &[Message],
950 _: &ToolRegistry,
951 ) -> Result<Decision, AgentError> {
952 static COUNTER: AtomicUsize = AtomicUsize::new(0);
954 let n = COUNTER.fetch_add(1, Ordering::SeqCst);
955 Ok(Decision {
956 situation: String::new(),
957 task: vec![],
958 tool_calls: vec![ToolCall {
959 id: format!("{}", n),
960 name: format!("tool_{}", n),
961 arguments: serde_json::json!({}),
962 }],
963 completed: false,
964 })
965 }
966 }
967
968 let tools = ToolRegistry::new().register(EchoTool);
969 let mut ctx = AgentContext::new();
970 let mut messages = vec![Message::user("go")];
971 let config = LoopConfig {
972 max_steps: 5,
973 loop_abort_threshold: 100,
974 ..Default::default()
975 };
976
977 let result = run_loop(
978 &NeverDoneAgent,
979 &tools,
980 &mut ctx,
981 &mut messages,
982 &config,
983 |_| {},
984 )
985 .await;
986 assert!(matches!(result, Err(AgentError::MaxSteps(5))));
987 }
988
989 #[test]
990 fn loop_detector_exact_sig() {
991 let mut d = LoopDetector::new(3);
992 let sig = vec!["bash".to_string()];
993 assert_eq!(d.check(&sig), LoopCheckResult::Ok);
994 assert_eq!(d.check(&sig), LoopCheckResult::Ok);
995 assert_eq!(d.check(&sig), LoopCheckResult::Abort); }
997
998 #[test]
999 fn loop_detector_different_sigs_reset() {
1000 let mut d = LoopDetector::new(3);
1001 assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1002 assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1003 assert_eq!(d.check(&["read".into()]), LoopCheckResult::Ok); assert_eq!(d.check(&["bash".into()]), LoopCheckResult::Ok);
1005 }
1006
1007 #[test]
1008 fn loop_detector_tier2_warning_then_abort() {
1009 let mut d = LoopDetector::new(3);
1012 assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Ok); assert_eq!(
1018 d.check(&["edit_file".into(), "read_file".into()]),
1019 LoopCheckResult::Tier2Warning("edit_file".into())
1020 );
1021 assert_eq!(d.check(&["edit_file".into()]), LoopCheckResult::Abort);
1023 }
1024
1025 #[test]
1026 fn loop_config_default() {
1027 let c = LoopConfig::default();
1028 assert_eq!(c.max_steps, 50);
1029 assert_eq!(c.loop_abort_threshold, 6);
1030 }
1031
1032 #[test]
1033 fn loop_detector_output_stagnation() {
1034 let mut d = LoopDetector::new(3);
1035 let outputs = vec!["same result".to_string()];
1036 assert!(!d.check_outputs(&outputs));
1037 assert!(!d.check_outputs(&outputs));
1038 assert!(d.check_outputs(&outputs)); }
1040
1041 #[test]
1042 fn completion_detector_keyword() {
1043 let mut cd = CompletionDetector::new(3);
1044 let d = Decision {
1045 situation: "The task is complete, all files written.".into(),
1046 task: vec![],
1047 tool_calls: vec![ToolCall {
1048 id: "1".into(),
1049 name: "echo".into(),
1050 arguments: serde_json::json!({}),
1051 }],
1052 completed: false,
1053 };
1054 assert!(cd.check(&d));
1055 }
1056
1057 #[test]
1058 fn completion_detector_repeated_situation() {
1059 let mut cd = CompletionDetector::new(3);
1060 let d = Decision {
1061 situation: "working on it".into(),
1062 task: vec![],
1063 tool_calls: vec![ToolCall {
1064 id: "1".into(),
1065 name: "echo".into(),
1066 arguments: serde_json::json!({}),
1067 }],
1068 completed: false,
1069 };
1070 assert!(!cd.check(&d));
1071 assert!(!cd.check(&d));
1072 assert!(cd.check(&d)); }
1074
1075 #[test]
1076 fn completion_detector_ignores_explicit_completion() {
1077 let mut cd = CompletionDetector::new(2);
1078 let d = Decision {
1079 situation: "task is complete".into(),
1080 task: vec![],
1081 tool_calls: vec![],
1082 completed: true,
1083 };
1084 assert!(!cd.check(&d));
1086 }
1087
1088 #[test]
1089 fn trim_messages_basic() {
1090 let mut msgs: Vec<Message> = (0..10).map(|i| Message::user(format!("msg {i}"))).collect();
1091 trim_messages(&mut msgs, 6);
1092 assert_eq!(msgs.len(), 6);
1094 assert!(msgs[2].content.contains("trimmed"));
1095 }
1096
1097 #[test]
1098 fn trim_messages_no_op_when_under_limit() {
1099 let mut msgs = vec![Message::user("a"), Message::user("b")];
1100 trim_messages(&mut msgs, 10);
1101 assert_eq!(msgs.len(), 2);
1102 }
1103
1104 #[test]
1105 fn trim_messages_preserves_assistant_tool_call_pair() {
1106 use crate::types::Role;
1107 let mut msgs = vec![
1109 Message::system("sys"),
1110 Message::user("prompt"),
1111 Message::assistant_with_tool_calls(
1112 "calling",
1113 vec![
1114 ToolCall {
1115 id: "c1".into(),
1116 name: "read".into(),
1117 arguments: serde_json::json!({}),
1118 },
1119 ToolCall {
1120 id: "c2".into(),
1121 name: "read".into(),
1122 arguments: serde_json::json!({}),
1123 },
1124 ],
1125 ),
1126 Message::tool("c1", "result1"),
1127 Message::tool("c2", "result2"),
1128 Message::user("next"),
1129 Message::assistant("done"),
1130 ];
1131 trim_messages(&mut msgs, 5);
1133 for (i, msg) in msgs.iter().enumerate() {
1135 if msg.role == Role::Tool {
1136 assert!(i > 0, "Tool message at start");
1138 assert!(
1139 msgs[i - 1].role == Role::Assistant && !msgs[i - 1].tool_calls.is_empty()
1140 || msgs[i - 1].role == Role::Tool,
1141 "Orphaned Tool at position {i}"
1142 );
1143 }
1144 }
1145 }
1146
1147 #[test]
1148 fn loop_detector_output_stagnation_resets_on_change() {
1149 let mut d = LoopDetector::new(3);
1150 let a = vec!["result A".to_string()];
1151 let b = vec!["result B".to_string()];
1152 assert!(!d.check_outputs(&a));
1153 assert!(!d.check_outputs(&a));
1154 assert!(!d.check_outputs(&b)); assert!(!d.check_outputs(&a));
1156 }
1157
1158 #[tokio::test]
1159 async fn loop_handles_non_recoverable_llm_error() {
1160 struct FailingAgent;
1161 #[async_trait::async_trait]
1162 impl Agent for FailingAgent {
1163 async fn decide(
1164 &self,
1165 _: &[Message],
1166 _: &ToolRegistry,
1167 ) -> Result<Decision, AgentError> {
1168 Err(AgentError::Llm(SgrError::Api {
1169 status: 500,
1170 body: "internal server error".into(),
1171 }))
1172 }
1173 }
1174
1175 let tools = ToolRegistry::new().register(EchoTool);
1176 let mut ctx = AgentContext::new();
1177 let mut messages = vec![Message::user("go")];
1178 let config = LoopConfig::default();
1179
1180 let result = run_loop(
1181 &FailingAgent,
1182 &tools,
1183 &mut ctx,
1184 &mut messages,
1185 &config,
1186 |_| {},
1187 )
1188 .await;
1189 assert!(result.is_err());
1191 assert_eq!(messages.len(), 1); }
1193
1194 #[tokio::test]
1195 async fn loop_recovers_from_parse_error() {
1196 struct ParseRetryAgent {
1198 call_count: Arc<AtomicUsize>,
1199 }
1200 #[async_trait::async_trait]
1201 impl Agent for ParseRetryAgent {
1202 async fn decide(
1203 &self,
1204 msgs: &[Message],
1205 _: &ToolRegistry,
1206 ) -> Result<Decision, AgentError> {
1207 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1208 if n == 0 {
1209 Err(AgentError::Llm(SgrError::Schema(
1211 "Missing required field: situation".into(),
1212 )))
1213 } else {
1214 let last = msgs.last().unwrap();
1216 assert!(
1217 last.content.contains("Parse error"),
1218 "expected parse error feedback, got: {}",
1219 last.content
1220 );
1221 Ok(Decision {
1222 situation: "recovered from parse error".into(),
1223 task: vec![],
1224 tool_calls: vec![],
1225 completed: true,
1226 })
1227 }
1228 }
1229 }
1230
1231 let tools = ToolRegistry::new().register(EchoTool);
1232 let mut ctx = AgentContext::new();
1233 let mut messages = vec![Message::user("go")];
1234 let config = LoopConfig::default();
1235 let agent = ParseRetryAgent {
1236 call_count: Arc::new(AtomicUsize::new(0)),
1237 };
1238
1239 let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1240 .await
1241 .unwrap();
1242 assert_eq!(steps, 2); assert_eq!(ctx.state, AgentState::Completed);
1244 }
1245
1246 #[tokio::test]
1247 async fn loop_aborts_after_max_parse_retries() {
1248 struct AlwaysFailParseAgent;
1249 #[async_trait::async_trait]
1250 impl Agent for AlwaysFailParseAgent {
1251 async fn decide(
1252 &self,
1253 _: &[Message],
1254 _: &ToolRegistry,
1255 ) -> Result<Decision, AgentError> {
1256 Err(AgentError::Llm(SgrError::Schema("bad json".into())))
1257 }
1258 }
1259
1260 let tools = ToolRegistry::new().register(EchoTool);
1261 let mut ctx = AgentContext::new();
1262 let mut messages = vec![Message::user("go")];
1263 let config = LoopConfig::default();
1264
1265 let result = run_loop(
1266 &AlwaysFailParseAgent,
1267 &tools,
1268 &mut ctx,
1269 &mut messages,
1270 &config,
1271 |_| {},
1272 )
1273 .await;
1274 assert!(result.is_err());
1275 let feedback_count = messages
1277 .iter()
1278 .filter(|m| m.content.contains("Parse error"))
1279 .count();
1280 assert_eq!(feedback_count, MAX_PARSE_RETRIES);
1281 }
1282
1283 #[tokio::test]
1284 async fn loop_feeds_tool_errors_back() {
1285 struct ErrorRecoveryAgent {
1287 call_count: Arc<AtomicUsize>,
1288 }
1289 #[async_trait::async_trait]
1290 impl Agent for ErrorRecoveryAgent {
1291 async fn decide(
1292 &self,
1293 msgs: &[Message],
1294 _: &ToolRegistry,
1295 ) -> Result<Decision, AgentError> {
1296 let n = self.call_count.fetch_add(1, Ordering::SeqCst);
1297 if n == 0 {
1298 Ok(Decision {
1300 situation: "trying".into(),
1301 task: vec![],
1302 tool_calls: vec![ToolCall {
1303 id: "1".into(),
1304 name: "nonexistent_tool".into(),
1305 arguments: serde_json::json!({}),
1306 }],
1307 completed: false,
1308 })
1309 } else {
1310 let last = msgs.last().unwrap();
1312 assert!(last.content.contains("Unknown tool"));
1313 Ok(Decision {
1314 situation: "recovered".into(),
1315 task: vec![],
1316 tool_calls: vec![],
1317 completed: true,
1318 })
1319 }
1320 }
1321 }
1322
1323 let tools = ToolRegistry::new().register(EchoTool);
1324 let mut ctx = AgentContext::new();
1325 let mut messages = vec![Message::user("go")];
1326 let config = LoopConfig::default();
1327 let agent = ErrorRecoveryAgent {
1328 call_count: Arc::new(AtomicUsize::new(0)),
1329 };
1330
1331 let steps = run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1332 .await
1333 .unwrap();
1334 assert_eq!(steps, 2);
1335 assert_eq!(ctx.state, AgentState::Completed);
1336 }
1337
1338 #[tokio::test]
1339 async fn parallel_readonly_tools() {
1340 struct ReadOnlyTool {
1341 name: &'static str,
1342 }
1343
1344 #[async_trait::async_trait]
1345 impl Tool for ReadOnlyTool {
1346 fn name(&self) -> &str {
1347 self.name
1348 }
1349 fn description(&self) -> &str {
1350 "read-only tool"
1351 }
1352 fn is_read_only(&self) -> bool {
1353 true
1354 }
1355 fn parameters_schema(&self) -> Value {
1356 serde_json::json!({"type": "object"})
1357 }
1358 async fn execute(
1359 &self,
1360 _: Value,
1361 _: &mut AgentContext,
1362 ) -> Result<ToolOutput, ToolError> {
1363 Ok(ToolOutput::text(format!("{} result", self.name)))
1364 }
1365 async fn execute_readonly(
1366 &self,
1367 _: Value,
1368 _ctx: &crate::context::AgentContext,
1369 ) -> Result<ToolOutput, ToolError> {
1370 Ok(ToolOutput::text(format!("{} result", self.name)))
1371 }
1372 }
1373
1374 struct ParallelAgent;
1375 #[async_trait::async_trait]
1376 impl Agent for ParallelAgent {
1377 async fn decide(
1378 &self,
1379 msgs: &[Message],
1380 _: &ToolRegistry,
1381 ) -> Result<Decision, AgentError> {
1382 if msgs.len() > 3 {
1383 return Ok(Decision {
1384 situation: "done".into(),
1385 task: vec![],
1386 tool_calls: vec![],
1387 completed: true,
1388 });
1389 }
1390 Ok(Decision {
1391 situation: "reading".into(),
1392 task: vec![],
1393 tool_calls: vec![
1394 ToolCall {
1395 id: "1".into(),
1396 name: "reader_a".into(),
1397 arguments: serde_json::json!({}),
1398 },
1399 ToolCall {
1400 id: "2".into(),
1401 name: "reader_b".into(),
1402 arguments: serde_json::json!({}),
1403 },
1404 ],
1405 completed: false,
1406 })
1407 }
1408 }
1409
1410 let tools = ToolRegistry::new()
1411 .register(ReadOnlyTool { name: "reader_a" })
1412 .register(ReadOnlyTool { name: "reader_b" });
1413 let mut ctx = AgentContext::new();
1414 let mut messages = vec![Message::user("read stuff")];
1415 let config = LoopConfig::default();
1416
1417 let steps = run_loop(
1418 &ParallelAgent,
1419 &tools,
1420 &mut ctx,
1421 &mut messages,
1422 &config,
1423 |_| {},
1424 )
1425 .await
1426 .unwrap();
1427 assert!(steps > 0);
1428 assert_eq!(ctx.state, AgentState::Completed);
1429 }
1430
1431 #[tokio::test]
1432 async fn loop_events_are_emitted() {
1433 let agent = CountingAgent {
1434 max_calls: 1,
1435 call_count: Arc::new(AtomicUsize::new(0)),
1436 };
1437 let tools = ToolRegistry::new().register(EchoTool);
1438 let mut ctx = AgentContext::new();
1439 let mut messages = vec![Message::user("go")];
1440 let config = LoopConfig::default();
1441
1442 let mut events = Vec::new();
1443 run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |e| {
1444 events.push(format!("{:?}", std::mem::discriminant(&e)));
1445 })
1446 .await
1447 .unwrap();
1448
1449 assert!(events.len() >= 4);
1451 }
1452
1453 #[tokio::test]
1454 async fn tool_output_done_stops_loop() {
1455 struct DoneTool;
1457 #[async_trait::async_trait]
1458 impl Tool for DoneTool {
1459 fn name(&self) -> &str {
1460 "done_tool"
1461 }
1462 fn description(&self) -> &str {
1463 "returns done"
1464 }
1465 fn parameters_schema(&self) -> Value {
1466 serde_json::json!({"type": "object"})
1467 }
1468 async fn execute(
1469 &self,
1470 _: Value,
1471 _: &mut AgentContext,
1472 ) -> Result<ToolOutput, ToolError> {
1473 Ok(ToolOutput::done("final answer"))
1474 }
1475 }
1476
1477 struct OneShotAgent;
1478 #[async_trait::async_trait]
1479 impl Agent for OneShotAgent {
1480 async fn decide(
1481 &self,
1482 _: &[Message],
1483 _: &ToolRegistry,
1484 ) -> Result<Decision, AgentError> {
1485 Ok(Decision {
1486 situation: "calling done tool".into(),
1487 task: vec![],
1488 tool_calls: vec![ToolCall {
1489 id: "1".into(),
1490 name: "done_tool".into(),
1491 arguments: serde_json::json!({}),
1492 }],
1493 completed: false,
1494 })
1495 }
1496 }
1497
1498 let tools = ToolRegistry::new().register(DoneTool);
1499 let mut ctx = AgentContext::new();
1500 let mut messages = vec![Message::user("go")];
1501 let config = LoopConfig::default();
1502
1503 let steps = run_loop(
1504 &OneShotAgent,
1505 &tools,
1506 &mut ctx,
1507 &mut messages,
1508 &config,
1509 |_| {},
1510 )
1511 .await
1512 .unwrap();
1513 assert_eq!(
1514 steps, 1,
1515 "Loop should stop on first step when tool returns done"
1516 );
1517 assert_eq!(ctx.state, AgentState::Completed);
1518 }
1519
1520 #[tokio::test]
1521 async fn tool_messages_formatted_correctly() {
1522 let agent = CountingAgent {
1525 max_calls: 1,
1526 call_count: Arc::new(AtomicUsize::new(0)),
1527 };
1528 let tools = ToolRegistry::new().register(EchoTool);
1529 let mut ctx = AgentContext::new();
1530 let mut messages = vec![Message::user("go")];
1531 let config = LoopConfig::default();
1532
1533 run_loop(&agent, &tools, &mut ctx, &mut messages, &config, |_| {})
1534 .await
1535 .unwrap();
1536
1537 assert!(messages.len() >= 4);
1540
1541 let assistant_tc = messages
1543 .iter()
1544 .find(|m| m.role == crate::types::Role::Assistant && !m.tool_calls.is_empty());
1545 assert!(
1546 assistant_tc.is_some(),
1547 "Should have an assistant message with tool_calls"
1548 );
1549 let atc = assistant_tc.unwrap();
1550 assert_eq!(atc.tool_calls[0].name, "echo");
1551 assert_eq!(atc.tool_calls[0].id, "call_0");
1552
1553 let tc_idx = messages
1555 .iter()
1556 .position(|m| m.role == crate::types::Role::Assistant && !m.tool_calls.is_empty())
1557 .unwrap();
1558 let tool_msg = &messages[tc_idx + 1];
1559 assert_eq!(tool_msg.role, crate::types::Role::Tool);
1560 assert_eq!(tool_msg.tool_call_id.as_deref(), Some("call_0"));
1561 assert_eq!(tool_msg.content, "echoed");
1562 }
1563
#[test]
fn pairing_adds_missing_tool_result() {
    // The assistant issues two calls, but only c1 receives a result.
    let bash_call = ToolCall {
        id: "c1".into(),
        name: "bash".into(),
        arguments: serde_json::json!({}),
    };
    let read_call = ToolCall {
        id: "c2".into(),
        name: "read".into(),
        arguments: serde_json::json!({}),
    };
    let mut msgs = vec![
        Message::user("go"),
        Message::assistant_with_tool_calls("calling", vec![bash_call, read_call]),
        Message::tool("c1", "ok"),
    ];

    ensure_tool_result_pairing(&mut msgs);

    // The pairing pass should have synthesized a placeholder result for c2.
    let synthetic = msgs
        .iter()
        .find(|m| m.tool_call_id.as_deref() == Some("c2"))
        .expect("Should have synthetic result for c2");
    assert!(synthetic.content.contains("missing"));
}
1597
#[test]
fn pairing_removes_duplicate_tool_result() {
    let bash_call = ToolCall {
        id: "c1".into(),
        name: "bash".into(),
        arguments: serde_json::json!({}),
    };
    let mut msgs = vec![
        Message::user("go"),
        Message::assistant_with_tool_calls("calling", vec![bash_call]),
        Message::tool("c1", "first"),
        // A second result for the same call id.
        Message::tool("c1", "duplicate"),
    ];

    ensure_tool_result_pairing(&mut msgs);

    let results_for_c1: Vec<_> = msgs
        .iter()
        .filter(|m| m.tool_call_id.as_deref() == Some("c1"))
        .collect();
    assert_eq!(results_for_c1.len(), 1, "Should remove duplicate tool_result");
}
1621
#[test]
fn pairing_removes_orphaned_tool_result() {
    // A tool result whose id no assistant message ever issued as a call.
    let mut msgs = vec![
        Message::user("go"),
        Message::tool("orphan_id", "orphaned result"),
        Message::assistant("done"),
    ];

    ensure_tool_result_pairing(&mut msgs);

    let orphan_kept = msgs
        .iter()
        .any(|m| m.tool_call_id.as_deref() == Some("orphan_id"));
    assert!(!orphan_kept, "Should remove orphaned tool_result");
}
1636
#[test]
fn pairing_noop_for_valid_transcript() {
    // Captures role, content, tool_call_id and issued call ids for each message.
    // This lets the test catch reordering or in-place rewrites by the pairing
    // pass, not only a change in message count.
    fn fingerprint(msgs: &[Message]) -> Vec<(String, String, Option<String>, Vec<String>)> {
        msgs.iter()
            .map(|m| {
                (
                    format!("{:?}", m.role),
                    m.content.to_string(),
                    m.tool_call_id.as_deref().map(str::to_owned),
                    m.tool_calls.iter().map(|tc| tc.id.to_string()).collect(),
                )
            })
            .collect()
    }

    let mut msgs = vec![
        Message::user("go"),
        Message::assistant_with_tool_calls(
            "calling",
            vec![ToolCall {
                id: "c1".into(),
                name: "bash".into(),
                arguments: serde_json::json!({}),
            }],
        ),
        Message::tool("c1", "result"),
        Message::assistant("done"),
    ];
    let before = fingerprint(&msgs);

    ensure_tool_result_pairing(&mut msgs);

    assert_eq!(fingerprint(&msgs), before, "Valid transcript should not change");
}
1656
#[test]
fn context_modifier_system_injection() {
    use crate::agent_tool::ContextModifier;

    let mut ctx = AgentContext::new();
    let mut messages = vec![Message::user("go")];
    let mut max_steps = 50;
    let modifier = ContextModifier::system("Extra instructions for next step");

    apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);

    // Exactly one message is appended after the original prompt, and, as
    // asserted here, it is delivered with the User role.
    assert_eq!(messages.len(), 2);
    let injected = &messages[1];
    assert_eq!(injected.role, Role::User);
    assert!(injected.content.contains("Extra instructions"));
}
1674
#[test]
fn context_modifier_extra_steps() {
    use crate::agent_tool::ContextModifier;

    let mut ctx = AgentContext::new();
    let mut messages = Vec::new();
    let mut max_steps = 50;

    // Deltas accumulate on the same budget: +20 raises 50 to 70, and -10 then
    // lowers it to 60.
    for (delta, expected) in [(20, 70), (-10, 60)] {
        let modifier = ContextModifier::extra_steps(delta);
        apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);
        assert_eq!(max_steps, expected);
    }
}
1691
#[test]
fn context_modifier_custom_context() {
    use crate::agent_tool::ContextModifier;

    let mut ctx = AgentContext::new();
    let mut messages = Vec::new();
    let mut max_steps = 50;
    let modifier = ContextModifier::custom("my_key", serde_json::json!("my_value"));

    apply_context_modifier(&modifier, &mut ctx, &mut messages, &mut max_steps);

    // The custom key/value pair should be readable back from the context.
    let stored = ctx.get("my_key").unwrap();
    assert_eq!(stored, "my_value");
}
1705
#[test]
fn context_modifier_is_empty() {
    use crate::agent_tool::ContextModifier;

    // A default modifier carries nothing.
    assert!(ContextModifier::default().is_empty());

    // Each constructor populates one field, so none of these is empty.
    let populated = [
        ContextModifier::system("hi"),
        ContextModifier::max_tokens(100),
        ContextModifier::extra_steps(5),
        ContextModifier::custom("k", serde_json::json!("v")),
    ];
    for modifier in &populated {
        assert!(!modifier.is_empty());
    }
}
1716
#[test]
fn context_modifier_max_tokens_stored_in_context() {
    use crate::agent_tool::ContextModifier;

    let mut ctx = AgentContext::new();
    let mut messages = Vec::new();
    let mut max_steps = 50;

    // The token cap should be recorded on the context and exposed through
    // `max_tokens_override`.
    apply_context_modifier(
        &ContextModifier::max_tokens(4096),
        &mut ctx,
        &mut messages,
        &mut max_steps,
    );

    assert_eq!(ctx.max_tokens_override(), Some(4096));
}
1730}