1use std::sync::Arc;
2
3use crate::common::privacy_helpers::{is_network_tool, resolve_boundary};
4use crate::loop_detection::{LoopDetectionConfig, ToolLoopDetector};
5use crate::security::redaction::redact_text;
6use crate::types::{
7 AgentConfig, AgentError, AssistantMessage, AuditEvent, AuditSink, ConversationMessage,
8 HookEvent, HookFailureMode, HookRiskTier, HookSink, LoopAction, MemoryEntry, MemoryStore,
9 MetricsSink, Provider, ResearchTrigger, StopReason, StreamSink, Tool, ToolContext,
10 ToolDefinition, ToolResultMessage, ToolSelector, ToolSummary, ToolUseRequest, UserMessage,
11};
12use crate::validation::validate_json;
13use serde_json::json;
14use std::sync::atomic::{AtomicU64, Ordering};
15use std::time::{Instant, SystemTime, UNIX_EPOCH};
16use tokio::time::{timeout, Duration};
17use tracing::{info, info_span, instrument, warn, Instrument};
18
// Process-wide monotonically increasing sequence; combined with a millisecond
// timestamp in `Agent::next_request_id` so ids stay unique even when several
// requests start within the same millisecond.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(0);
20
/// Caps `input` to at most `max_chars` characters, keeping the **tail** of
/// the string (the most recent content), and reports whether truncation
/// occurred.
///
/// Returns `(capped_string, was_truncated)`.
fn cap_prompt(input: &str, max_chars: usize) -> (String, bool) {
    let len = input.chars().count();
    if len <= max_chars {
        return (input.to_string(), false);
    }
    // Keep the trailing `max_chars` chars in a single forward pass instead of
    // the previous double-`rev()` through an intermediate Vec (same result,
    // no extra allocation). `len > max_chars` here, so the subtraction is safe.
    let truncated: String = input.chars().skip(len - max_chars).collect();
    (truncated, true)
}
36
37fn build_provider_prompt(
38 current_prompt: &str,
39 recent_memory: &[MemoryEntry],
40 max_prompt_chars: usize,
41 context_summary: Option<&str>,
42) -> (String, bool) {
43 if recent_memory.is_empty() && context_summary.is_none() {
44 return cap_prompt(current_prompt, max_prompt_chars);
45 }
46
47 let mut prompt = String::new();
48
49 if let Some(summary) = context_summary {
50 prompt.push_str("Context summary:\n");
51 prompt.push_str(summary);
52 prompt.push_str("\n\n");
53 }
54
55 if !recent_memory.is_empty() {
56 prompt.push_str("Recent conversation:\n");
57 for entry in recent_memory.iter().rev() {
58 prompt.push_str("- ");
59 prompt.push_str(&entry.role);
60 prompt.push_str(": ");
61 prompt.push_str(&entry.content);
62 prompt.push('\n');
63 }
64 prompt.push('\n');
65 }
66
67 prompt.push_str("Current input:\n");
68 prompt.push_str(current_prompt);
69
70 cap_prompt(&prompt, max_prompt_chars)
71}
72
/// Extracts `tool:`-prefixed lines from a provider reply as
/// `(tool_name, raw_input)` pairs. A line with no space after the name
/// yields an empty input string.
fn parse_tool_calls(prompt: &str) -> Vec<(&str, &str)> {
    let mut calls = Vec::new();
    for line in prompt.lines() {
        let Some(rest) = line.strip_prefix("tool:") else {
            continue;
        };
        let rest = rest.trim();
        // First token is the tool name; the remainder (if any) is its input.
        match rest.split_once(' ') {
            Some((name, input)) => calls.push((name, input)),
            None => calls.push((rest, "")),
        }
    }
    calls
}
88
89fn extract_tool_call_from_text(
102 text: &str,
103 known_tools: &[ToolDefinition],
104) -> Option<ToolUseRequest> {
105 let json_str = extract_json_block(text).or_else(|| extract_bare_json(text))?;
107 let obj: serde_json::Value = serde_json::from_str(json_str).ok()?;
108 let name = obj.get("name")?.as_str()?;
109
110 if !known_tools.iter().any(|t| t.name == name) {
112 return None;
113 }
114
115 let args = obj
116 .get("arguments")
117 .or_else(|| obj.get("parameters"))
118 .cloned()
119 .unwrap_or(serde_json::Value::Object(Default::default()));
120
121 Some(ToolUseRequest {
122 id: format!("text_extracted_{}", name),
123 name: name.to_string(),
124 input: args,
125 })
126}
127
/// Pulls the body of the first ``` fenced block out of `text`, returning it
/// only when (after trimming) it looks like a JSON object. The fence's info
/// string (e.g. `json`) is skipped along with the first newline.
fn extract_json_block(text: &str) -> Option<&str> {
    let fence = text.find("```")?;
    let rest = &text[fence + 3..];
    // Skip the rest of the fence line (language tag, if any).
    let body_start = rest.find('\n')? + 1;
    let body = &rest[body_start..];
    let close = body.find("```")?;
    let candidate = body[..close].trim();
    candidate.starts_with('{').then_some(candidate)
}
143
/// Finds the first balanced `{…}` object in `text`, honoring JSON string
/// literals so braces inside quoted strings (and escaped quotes) do not
/// confuse the depth counter. Returns `None` when no balanced object exists.
fn extract_bare_json(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let slice = &text[open..];
    let mut nesting = 0i32;
    let mut inside_quotes = false;
    let mut skip_next = false;
    for (idx, c) in slice.char_indices() {
        if skip_next {
            // Previous char was a backslash inside a string: this char is escaped.
            skip_next = false;
        } else if inside_quotes {
            match c {
                '\\' => skip_next = true,
                '"' => inside_quotes = false,
                _ => {}
            }
        } else {
            match c {
                '"' => inside_quotes = true,
                '{' => nesting += 1,
                '}' => {
                    nesting -= 1;
                    if nesting == 0 {
                        // `idx` is the byte offset of `}`; inclusive slice.
                        return Some(&slice[..=idx]);
                    }
                }
                _ => {}
            }
        }
    }
    None
}
172
173fn single_required_string_field(schema: &serde_json::Value) -> Option<String> {
175 let required = schema.get("required")?.as_array()?;
176 if required.len() != 1 {
177 return None;
178 }
179 let field_name = required[0].as_str()?;
180 let properties = schema.get("properties")?.as_object()?;
181 let field_schema = properties.get(field_name)?;
182 if field_schema.get("type")?.as_str()? == "string" {
183 Some(field_name.to_string())
184 } else {
185 None
186 }
187}
188
189fn prepare_tool_input(tool: &dyn Tool, raw_input: &serde_json::Value) -> Result<String, String> {
198 if let Some(s) = raw_input.as_str() {
200 return Ok(s.to_string());
201 }
202
203 if let Some(schema) = tool.input_schema() {
205 if let Err(errors) = validate_json(raw_input, &schema) {
206 return Err(format!(
207 "Invalid input for tool '{}': {}",
208 tool.name(),
209 errors.join("; ")
210 ));
211 }
212
213 if let Some(field) = single_required_string_field(&schema) {
214 if let Some(val) = raw_input.get(&field).and_then(|v| v.as_str()) {
215 return Ok(val.to_string());
216 }
217 }
218 }
219
220 Ok(serde_json::to_string(raw_input).unwrap_or_default())
221}
222
/// Shrinks `messages` to fit within `max_chars` (sum of per-message char
/// counts) by dropping messages from the *middle* of the conversation.
///
/// Invariants preserved:
/// - everything up to and including the first user message (the "prefix",
///   typically system prompt + opening user turn) is always kept;
/// - messages are dropped oldest-first after the prefix;
/// - an assistant message that issued tool calls is kept together with the
///   tool-result run that follows it, even if that pushes past the budget;
/// - the kept suffix never starts with orphaned tool results.
fn truncate_messages(messages: &mut Vec<ConversationMessage>, max_chars: usize) {
    let total_chars: usize = messages.iter().map(|m| m.char_count()).sum();
    // Already within budget, or too few messages to safely drop anything.
    if total_chars <= max_chars || messages.len() <= 2 {
        return;
    }

    // Locate the end of the protected prefix: everything through the first
    // user message. Defaults to 1 (keep at least the first message) when no
    // user message is found.
    let mut prefix_end = 1; for (i, msg) in messages.iter().enumerate() {
        if matches!(msg, ConversationMessage::User { .. }) {
            prefix_end = i + 1;
            break;
        }
    }

    // Charge the prefix against the budget up front.
    let prefix_cost: usize = messages[..prefix_end].iter().map(|m| m.char_count()).sum();
    let mut budget = max_chars.saturating_sub(prefix_cost);

    // Walk backwards from the newest message, keeping whole messages while
    // they fit. `in_tool_result_run` marks that we are scanning through tool
    // results whose triggering assistant message must also be kept — that
    // assistant message is charged via saturating_sub so it is retained even
    // if it overflows the remaining budget.
    let mut keep_from_end = 0;
    let mut in_tool_result_run = false;
    for msg in messages[prefix_end..].iter().rev() {
        let cost = msg.char_count();
        let is_tool_result = matches!(msg, ConversationMessage::ToolResult(_));
        let is_assistant_with_tools = matches!(
            msg,
            ConversationMessage::Assistant { tool_calls, .. } if !tool_calls.is_empty()
        );

        if is_tool_result {
            in_tool_result_run = true;
        }

        if in_tool_result_run && is_assistant_with_tools {
            // Force-keep the assistant message that produced the results we
            // already decided to keep.
            budget = budget.saturating_sub(cost);
            keep_from_end += 1;
            in_tool_result_run = false;
            continue;
        }

        if !is_tool_result {
            in_tool_result_run = false;
        }

        if cost > budget {
            break;
        }
        budget -= cost;
        keep_from_end += 1;
    }

    if keep_from_end == 0 {
        // Nothing after the prefix fits: keep only the prefix.
        messages.truncate(prefix_end);
        return;
    }

    let split_point = messages.len() - keep_from_end;
    if split_point <= prefix_end {
        // Suffix meets prefix; there is no middle to drop.
        return;
    }

    // Drop the middle slice (oldest non-prefix messages).
    messages.drain(prefix_end..split_point);

    // The cut may have left tool results at the head of the suffix whose
    // originating assistant message was dropped; remove those orphans.
    while messages.len() > prefix_end {
        if matches!(&messages[prefix_end], ConversationMessage::ToolResult(_)) {
            messages.remove(prefix_end);
        } else {
            break;
        }
    }
}
316
317fn memory_to_messages(entries: &[MemoryEntry]) -> Vec<ConversationMessage> {
320 entries
321 .iter()
322 .rev()
323 .map(|entry| {
324 if entry.role == "assistant" {
325 ConversationMessage::Assistant {
326 content: Some(entry.content.clone()),
327 tool_calls: vec![],
328 }
329 } else {
330 ConversationMessage::user(entry.content.clone())
331 }
332 })
333 .collect()
334}
335
/// Supplies extra tools on top of the agent's locally registered ones;
/// consulted each time tool definitions are built (see
/// `Agent::build_tool_definitions`, where locally registered names win on
/// conflict).
pub trait ToolSource: Send + Sync {
    /// Returns the additional tools to merge in. Called per definition build,
    /// so implementations should be reasonably cheap.
    fn additional_tools(&self) -> Vec<Box<dyn Tool>>;
}
345
/// Orchestrates one agent: provider calls, memory, tool execution, and the
/// optional audit/hook/metrics sinks wired up via the `with_*` builder
/// methods.
pub struct Agent {
    // Static behavior configuration (budgets, privacy boundary, hook policy, ...).
    config: AgentConfig,
    // LLM backend used for completions and context summarization.
    provider: Box<dyn Provider>,
    // Conversation memory store, scoped by privacy boundary.
    memory: Box<dyn MemoryStore>,
    // Locally registered tools.
    tools: Vec<Box<dyn Tool>>,
    // Optional best-effort audit trail (record failures are ignored).
    audit: Option<Box<dyn AuditSink>>,
    // Optional lifecycle hooks; may block execution depending on failure mode.
    hooks: Option<Box<dyn HookSink>>,
    // Optional metrics sink for counters and latency histograms.
    metrics: Option<Box<dyn MetricsSink>>,
    // When set, enables tool-call loop detection in the tool-use loop.
    loop_detection_config: Option<LoopDetectionConfig>,
    // Optional pre-filter that narrows the tool list per request.
    tool_selector: Option<Box<dyn ToolSelector>>,
    // Extra dynamically supplied tools — presumably plugins; see `ToolSource`.
    extra_tool_source: Option<Arc<dyn ToolSource>>,
}
358
359impl Agent {
    /// Creates an agent from its required collaborators. All optional sinks
    /// and features start disabled; enable them with the `with_*` builders.
    pub fn new(
        config: AgentConfig,
        provider: Box<dyn Provider>,
        memory: Box<dyn MemoryStore>,
        tools: Vec<Box<dyn Tool>>,
    ) -> Self {
        Self {
            config,
            provider,
            memory,
            tools,
            audit: None,
            hooks: None,
            metrics: None,
            loop_detection_config: None,
            tool_selector: None,
            extra_tool_source: None,
        }
    }
379
380 pub fn with_loop_detection(mut self, config: LoopDetectionConfig) -> Self {
382 self.loop_detection_config = Some(config);
383 self
384 }
385
386 pub fn with_audit(mut self, audit: Box<dyn AuditSink>) -> Self {
387 self.audit = Some(audit);
388 self
389 }
390
391 pub fn with_hooks(mut self, hooks: Box<dyn HookSink>) -> Self {
392 self.hooks = Some(hooks);
393 self
394 }
395
396 pub fn with_metrics(mut self, metrics: Box<dyn MetricsSink>) -> Self {
397 self.metrics = Some(metrics);
398 self
399 }
400
401 pub fn with_tool_selector(mut self, selector: Box<dyn ToolSelector>) -> Self {
402 self.tool_selector = Some(selector);
403 self
404 }
405
406 pub fn with_tool_source(mut self, source: Arc<dyn ToolSource>) -> Self {
411 self.extra_tool_source = Some(source);
412 self
413 }
414
    /// Registers an additional local tool after construction.
    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
        self.tools.push(tool);
    }
419
420 fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
423 let mut defs: Vec<ToolDefinition> = self
424 .tools
425 .iter()
426 .filter_map(|tool| ToolDefinition::from_tool(&**tool))
427 .collect();
428
429 if let Some(ref source) = self.extra_tool_source {
431 let extra = source.additional_tools();
432 for tool in &extra {
433 if let Some(def) = ToolDefinition::from_tool(&**tool) {
434 if !defs.iter().any(|d| d.name == def.name) {
436 defs.push(def);
437 }
438 }
439 }
440 }
441
442 defs
443 }
444
445 fn has_tool_definitions(&self) -> bool {
447 self.tools.iter().any(|t| t.input_schema().is_some())
448 }
449
450 async fn audit(&self, stage: &str, detail: serde_json::Value) {
451 if let Some(sink) = &self.audit {
452 let _ = sink
453 .record(AuditEvent {
454 stage: stage.to_string(),
455 detail,
456 })
457 .await;
458 }
459 }
460
461 fn next_request_id() -> String {
462 let ts_ms = SystemTime::now()
463 .duration_since(UNIX_EPOCH)
464 .unwrap_or_default()
465 .as_millis();
466 let seq = REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed);
467 format!("req-{ts_ms}-{seq}")
468 }
469
470 fn hook_risk_tier(stage: &str) -> HookRiskTier {
471 if matches!(
472 stage,
473 "before_tool_call" | "after_tool_call" | "before_plugin_call" | "after_plugin_call"
474 ) {
475 HookRiskTier::High
476 } else if matches!(
477 stage,
478 "before_provider_call"
479 | "after_provider_call"
480 | "before_memory_write"
481 | "after_memory_write"
482 | "before_run"
483 | "after_run"
484 ) {
485 HookRiskTier::Medium
486 } else {
487 HookRiskTier::Low
488 }
489 }
490
491 fn hook_failure_mode_for_stage(&self, stage: &str) -> HookFailureMode {
492 if self.config.hooks.fail_closed {
493 return HookFailureMode::Block;
494 }
495
496 match Self::hook_risk_tier(stage) {
497 HookRiskTier::Low => self.config.hooks.low_tier_mode,
498 HookRiskTier::Medium => self.config.hooks.medium_tier_mode,
499 HookRiskTier::High => self.config.hooks.high_tier_mode,
500 }
501 }
502
    /// Dispatches a lifecycle event to the configured hook sink, enforcing a
    /// timeout and the per-tier failure policy.
    ///
    /// Outcomes by failure mode:
    /// - `Block`: hook error or timeout aborts the request (`AgentError::Hook`);
    /// - `Warn`: logged + audited, execution continues;
    /// - `Ignore`: audited only, execution continues.
    ///
    /// No-op (Ok) when hooks are disabled or no sink is installed.
    async fn hook(&self, stage: &str, detail: serde_json::Value) -> Result<(), AgentError> {
        if !self.config.hooks.enabled {
            return Ok(());
        }
        let Some(sink) = &self.hooks else {
            return Ok(());
        };

        let event = HookEvent {
            stage: stage.to_string(),
            detail,
        };
        let hook_call = sink.record(event);
        // Mode and tier are resolved before awaiting so the match arms below
        // stay branch-only.
        let mode = self.hook_failure_mode_for_stage(stage);
        let tier = Self::hook_risk_tier(stage);
        match timeout(
            Duration::from_millis(self.config.hooks.timeout_ms),
            hook_call,
        )
        .await
        {
            Ok(Ok(())) => Ok(()),
            Ok(Err(err)) => {
                // Redact before logging/auditing so secrets carried in hook
                // error messages never leak into logs or the audit trail.
                let redacted = redact_text(&err.to_string());
                match mode {
                    HookFailureMode::Block => Err(AgentError::Hook {
                        stage: stage.to_string(),
                        source: err,
                    }),
                    HookFailureMode::Warn => {
                        warn!(
                            stage = stage,
                            tier = ?tier,
                            mode = "warn",
                            "hook error (continuing): {redacted}"
                        );
                        self.audit(
                            "hook_error_warn",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                    HookFailureMode::Ignore => {
                        self.audit(
                            "hook_error_ignored",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                }
            }
            // Timeout path mirrors the error path, with a synthesized error
            // for Block mode.
            Err(_) => match mode {
                HookFailureMode::Block => Err(AgentError::Hook {
                    stage: stage.to_string(),
                    source: anyhow::anyhow!(
                        "hook execution timed out after {} ms",
                        self.config.hooks.timeout_ms
                    ),
                }),
                HookFailureMode::Warn => {
                    warn!(
                        stage = stage,
                        tier = ?tier,
                        mode = "warn",
                        timeout_ms = self.config.hooks.timeout_ms,
                        "hook timeout (continuing)"
                    );
                    self.audit(
                        "hook_timeout_warn",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
                HookFailureMode::Ignore => {
                    self.audit(
                        "hook_timeout_ignored",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
            },
        }
    }
590
591 fn increment_counter(&self, name: &'static str) {
592 if let Some(metrics) = &self.metrics {
593 metrics.increment_counter(name, 1);
594 }
595 }
596
597 fn observe_histogram(&self, name: &'static str, value: f64) {
598 if let Some(metrics) = &self.metrics {
599 metrics.observe_histogram(name, value);
600 }
601 }
602
    /// Runs a single tool call end-to-end: privacy-boundary enforcement,
    /// before/after tool (and plugin) hooks, audit records, a tracing span,
    /// latency/error metrics, and an optional per-call timeout.
    ///
    /// # Errors
    /// `AgentError::Tool` when the privacy boundary forbids the tool, the
    /// tool itself fails, or it exceeds `tool_timeout_ms`; blocking hooks may
    /// surface `AgentError::Hook` via `?`.
    #[instrument(skip(self, tool, tool_input, ctx), fields(tool = tool_name, request_id, iteration))]
    async fn execute_tool(
        &self,
        tool: &dyn Tool,
        tool_name: &str,
        tool_input: &str,
        ctx: &ToolContext,
        request_id: &str,
        iteration: usize,
    ) -> Result<crate::types::ToolResult, AgentError> {
        // Per-tool boundary overrides the agent-wide default; empty string
        // means "no tool-specific override".
        let tool_specific = self
            .config
            .tool_boundaries
            .get(tool_name)
            .map(|s| s.as_str())
            .unwrap_or("");
        let resolved = resolve_boundary(tool_specific, &self.config.privacy_boundary);
        if resolved == "local_only" && is_network_tool(tool_name) {
            return Err(AgentError::Tool {
                tool: tool_name.to_string(),
                source: anyhow::anyhow!(
                    "tool '{}' requires network access but privacy boundary is 'local_only'",
                    tool_name
                ),
            });
        }

        // Plugin-backed tools get an extra pair of plugin-specific hooks.
        let is_plugin_call = tool_name.starts_with("plugin:");
        self.hook(
            "before_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "before_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name}),
            )
            .await?;
        }
        self.audit(
            "tool_execute_start",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await;
        let tool_started = Instant::now();
        let tool_timeout_ms = self.config.tool_timeout_ms;
        // Span factory: a fresh span is needed in each of the two execution
        // branches below.
        let make_tool_span = || {
            info_span!(
                "tool_execute",
                tool_name = %tool_name,
                request_id = %request_id,
                iteration = iteration,
            )
        };
        // timeout_ms == 0 disables the timeout entirely. Both branches record
        // latency + error metrics on failure before returning.
        let result = if tool_timeout_ms > 0 {
            match timeout(
                Duration::from_millis(tool_timeout_ms),
                tool.execute(tool_input, ctx).instrument(make_tool_span()),
            )
            .await
            {
                Ok(Ok(result)) => result,
                Ok(Err(source)) => {
                    self.observe_histogram(
                        "tool_latency_ms",
                        tool_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("tool_errors_total");
                    return Err(AgentError::Tool {
                        tool: tool_name.to_string(),
                        source,
                    });
                }
                Err(_elapsed) => {
                    self.observe_histogram(
                        "tool_latency_ms",
                        tool_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("tool_errors_total");
                    self.increment_counter("tool_timeouts_total");
                    warn!(
                        tool = %tool_name,
                        timeout_ms = tool_timeout_ms,
                        "tool execution timed out"
                    );
                    return Err(AgentError::Tool {
                        tool: tool_name.to_string(),
                        source: anyhow::anyhow!(
                            "tool '{}' timed out after {}ms",
                            tool_name,
                            tool_timeout_ms
                        ),
                    });
                }
            }
        } else {
            match tool
                .execute(tool_input, ctx)
                .instrument(make_tool_span())
                .await
            {
                Ok(result) => result,
                Err(source) => {
                    self.observe_histogram(
                        "tool_latency_ms",
                        tool_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("tool_errors_total");
                    return Err(AgentError::Tool {
                        tool: tool_name.to_string(),
                        source,
                    });
                }
            }
        };
        // Success path: latency, audit, log, then after-call hooks.
        self.observe_histogram(
            "tool_latency_ms",
            tool_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "tool_execute_success",
            json!({
                "request_id": request_id,
                "iteration": iteration,
                "tool_name": tool_name,
                "tool_output_len": result.output.len(),
                "duration_ms": tool_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "tool",
            tool_name = %tool_name,
            duration_ms = %tool_started.elapsed().as_millis(),
            "tool execution finished"
        );
        self.hook(
            "after_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name, "status": "ok"}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "after_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name, "status": "ok"}),
            )
            .await?;
        }
        Ok(result)
    }
757
758 async fn maybe_summarize_context(&self, entries: &[MemoryEntry]) -> Option<String> {
762 let cfg = &self.config.summarization;
763 if !cfg.enabled || entries.len() < cfg.min_entries_for_summarization {
764 return None;
765 }
766
767 let keep = cfg.keep_recent.min(entries.len());
768 let older = &entries[keep..];
769 if older.is_empty() {
770 return None;
771 }
772
773 let mut text = String::new();
775 for entry in older.iter().rev() {
776 text.push_str(&entry.role);
777 text.push_str(": ");
778 text.push_str(&entry.content);
779 text.push('\n');
780 }
781
782 let summarization_prompt = format!(
783 "Summarize the following conversation history in {} characters or less. \
784 Preserve key facts, decisions, and context that would be important for \
785 continuing the conversation. Be concise.\n\n{}",
786 cfg.max_summary_chars, text
787 );
788
789 let summary_timeout = Duration::from_secs(10);
791 match timeout(
792 summary_timeout,
793 self.provider.complete(&summarization_prompt),
794 )
795 .await
796 {
797 Ok(Ok(chat_result)) => {
798 let text = chat_result.output_text;
799 let summary = if text.chars().count() > cfg.max_summary_chars {
800 text.chars().take(cfg.max_summary_chars).collect::<String>()
801 } else {
802 text
803 };
804 Some(summary)
805 }
806 Ok(Err(e)) => {
807 warn!("context summarization failed, falling back to truncation: {e}");
808 None
809 }
810 Err(_) => {
811 warn!("context summarization timed out, falling back to truncation");
812 None
813 }
814 }
815 }
816
    /// Single-shot provider call for the plain (non-structured) path: loads
    /// recent memory scoped by privacy boundary, optionally summarizes older
    /// context, builds the capped prompt, and wraps the call in before/after
    /// provider hooks, audit records, and latency metrics.
    ///
    /// Returns the provider's raw output text.
    async fn call_provider_with_context(
        &self,
        prompt: &str,
        request_id: &str,
        stream_sink: Option<StreamSink>,
        source_channel: Option<&str>,
    ) -> Result<String, AgentError> {
        let recent_memory = self
            .memory
            .recent_for_boundary(
                self.config.memory_window_size,
                &self.config.privacy_boundary,
                source_channel,
            )
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.audit(
            "memory_recent_loaded",
            json!({"request_id": request_id, "items": recent_memory.len()}),
        )
        .await;
        // When a summary was produced, only the `keep_recent` newest entries
        // are passed verbatim; older history is represented by the summary.
        let context_summary = self.maybe_summarize_context(&recent_memory).await;
        let (effective_memory, summary_ref) = if let Some(ref summary) = context_summary {
            let keep = self
                .config
                .summarization
                .keep_recent
                .min(recent_memory.len());
            (&recent_memory[..keep], Some(summary.as_str()))
        } else {
            (recent_memory.as_slice(), None)
        };

        let (provider_prompt, prompt_truncated) = build_provider_prompt(
            prompt,
            effective_memory,
            self.config.max_prompt_chars,
            summary_ref,
        );
        if prompt_truncated {
            self.audit(
                "provider_prompt_truncated",
                json!({
                    "request_id": request_id,
                    "max_prompt_chars": self.config.max_prompt_chars,
                }),
            )
            .await;
        }
        self.hook(
            "before_provider_call",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await?;
        self.audit(
            "provider_call_start",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await;
        let provider_started = Instant::now();
        // Streaming and non-streaming use different provider entry points;
        // the reasoning config is only applied on the non-streaming call here.
        let provider_result = if let Some(sink) = stream_sink {
            self.provider
                .complete_streaming(&provider_prompt, sink)
                .await
        } else {
            self.provider
                .complete_with_reasoning(&provider_prompt, &self.config.reasoning)
                .await
        };
        let completion = match provider_result {
            Ok(result) => result,
            Err(source) => {
                // Latency and error metrics are recorded even on failure.
                self.observe_histogram(
                    "provider_latency_ms",
                    provider_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("provider_errors_total");
                return Err(AgentError::Provider { source });
            }
        };
        self.observe_histogram(
            "provider_latency_ms",
            provider_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "provider_call_success",
            json!({
                "request_id": request_id,
                "response_len": completion.output_text.len(),
                "duration_ms": provider_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "provider",
            duration_ms = %provider_started.elapsed().as_millis(),
            "provider call finished"
        );
        self.hook(
            "after_provider_call",
            json!({"request_id": request_id, "response_len": completion.output_text.len(), "status": "ok"}),
        )
        .await?;
        Ok(completion.output_text)
    }
935
936 async fn write_to_memory(
937 &self,
938 role: &str,
939 content: &str,
940 request_id: &str,
941 source_channel: Option<&str>,
942 conversation_id: &str,
943 agent_id: Option<&str>,
944 ) -> Result<(), AgentError> {
945 self.hook(
946 "before_memory_write",
947 json!({"request_id": request_id, "role": role}),
948 )
949 .await?;
950 self.memory
951 .append(MemoryEntry {
952 role: role.to_string(),
953 content: content.to_string(),
954 privacy_boundary: self.config.privacy_boundary.clone(),
955 source_channel: source_channel.map(String::from),
956 conversation_id: conversation_id.to_string(),
957 created_at: None,
958 expires_at: None,
959 org_id: String::new(),
960 agent_id: agent_id.unwrap_or("").to_string(),
961 embedding: None,
962 })
963 .await
964 .map_err(|source| AgentError::Memory { source })?;
965 self.hook(
966 "after_memory_write",
967 json!({"request_id": request_id, "role": role}),
968 )
969 .await?;
970 self.audit(
971 &format!("memory_append_{role}"),
972 json!({"request_id": request_id}),
973 )
974 .await;
975 Ok(())
976 }
977
978 fn should_research(&self, user_text: &str) -> bool {
979 if !self.config.research.enabled {
980 return false;
981 }
982 match self.config.research.trigger {
983 ResearchTrigger::Never => false,
984 ResearchTrigger::Always => true,
985 ResearchTrigger::Keywords => {
986 let lower = user_text.to_lowercase();
987 self.config
988 .research
989 .keywords
990 .iter()
991 .any(|kw| lower.contains(&kw.to_lowercase()))
992 }
993 ResearchTrigger::Length => user_text.len() >= self.config.research.min_message_length,
994 ResearchTrigger::Question => user_text.trim_end().ends_with('?'),
995 }
996 }
997
    /// Pre-answer research loop: repeatedly asks the provider for
    /// `tool:`-style calls, executes the *first* call of each reply via the
    /// locally registered tools, and feeds the result back — until a reply no
    /// longer starts with `tool:` (treated as the final summary) or
    /// `research.max_iterations` is exhausted.
    ///
    /// Returns the accumulated findings joined with newlines.
    ///
    /// NOTE(review): only `self.tools` is searched here; tools from
    /// `extra_tool_source` are not available to the research phase — confirm
    /// this is intentional. Also, a reply that embeds `tool:` lines but does
    /// not *begin* with the prefix ends the loop immediately.
    async fn run_research_phase(
        &self,
        user_text: &str,
        ctx: &ToolContext,
        request_id: &str,
    ) -> Result<String, AgentError> {
        self.audit(
            "research_phase_start",
            json!({
                "request_id": request_id,
                "max_iterations": self.config.research.max_iterations,
            }),
        )
        .await;
        self.increment_counter("research_phase_started");

        let research_prompt = format!(
            "You are in RESEARCH mode. The user asked: \"{user_text}\"\n\
             Gather relevant information using available tools. \
             Respond with tool: calls to collect data. \
             When done gathering, summarize your findings without a tool: prefix."
        );
        let mut prompt = self
            .call_provider_with_context(
                &research_prompt,
                request_id,
                None,
                ctx.source_channel.as_deref(),
            )
            .await?;
        let mut findings: Vec<String> = Vec::new();

        for iteration in 0..self.config.research.max_iterations {
            // A reply that does not begin with `tool:` is the final summary.
            if !prompt.starts_with("tool:") {
                findings.push(prompt.clone());
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Only the first tool call of each reply is executed.
            let (name, input) = calls[0];
            if let Some(tool) = self.tools.iter().find(|t| t.name() == name) {
                let result = self
                    .execute_tool(&**tool, name, input, ctx, request_id, iteration)
                    .await?;
                findings.push(format!("{name}: {}", result.output));

                if self.config.research.show_progress {
                    info!(iteration, tool = name, "research phase: tool executed");
                    self.audit(
                        "research_phase_iteration",
                        json!({
                            "request_id": request_id,
                            "iteration": iteration,
                            "tool_name": name,
                        }),
                    )
                    .await;
                }

                // Feed the tool output back so the model can continue or wrap up.
                let next_prompt = format!(
                    "Research iteration {iteration}: tool `{name}` returned: {}\n\
                     Continue researching or summarize findings.",
                    result.output
                );
                prompt = self
                    .call_provider_with_context(
                        &next_prompt,
                        request_id,
                        None,
                        ctx.source_channel.as_deref(),
                    )
                    .await?;
            } else {
                // Unknown tool name: stop researching rather than erroring.
                break;
            }
        }

        self.audit(
            "research_phase_complete",
            json!({
                "request_id": request_id,
                "findings_count": findings.len(),
            }),
        )
        .await;
        self.increment_counter("research_phase_completed");

        Ok(findings.join("\n"))
    }
1091
1092 #[instrument(
1095 skip(self, user_text, research_context, ctx, stream_sink),
1096 fields(request_id)
1097 )]
1098 async fn respond_with_tools(
1099 &self,
1100 request_id: &str,
1101 user_text: &str,
1102 research_context: &str,
1103 ctx: &ToolContext,
1104 stream_sink: Option<StreamSink>,
1105 ) -> Result<AssistantMessage, AgentError> {
1106 self.audit(
1107 "structured_tool_use_start",
1108 json!({
1109 "request_id": request_id,
1110 "max_tool_iterations": self.config.max_tool_iterations,
1111 }),
1112 )
1113 .await;
1114
1115 let all_tool_definitions = self.build_tool_definitions();
1116
1117 let tool_definitions = if let Some(ref selector) = self.tool_selector {
1119 let summaries: Vec<ToolSummary> = all_tool_definitions
1120 .iter()
1121 .map(|td| ToolSummary {
1122 name: td.name.clone(),
1123 description: td.description.clone(),
1124 })
1125 .collect();
1126 match selector.select(user_text, &summaries).await {
1127 Ok(selected_names) => {
1128 let selected: Vec<ToolDefinition> = all_tool_definitions
1129 .iter()
1130 .filter(|td| selected_names.contains(&td.name))
1131 .cloned()
1132 .collect();
1133 info!(
1134 total = all_tool_definitions.len(),
1135 selected = selected.len(),
1136 mode = %self.config.tool_selection,
1137 "tool selection applied"
1138 );
1139 selected
1140 }
1141 Err(e) => {
1142 warn!(error = %e, "tool selection failed, falling back to all tools");
1143 all_tool_definitions
1144 }
1145 }
1146 } else {
1147 all_tool_definitions
1148 };
1149
1150 let recent_memory = if let Some(ref cid) = ctx.conversation_id {
1152 self.memory
1153 .recent_for_conversation(cid, self.config.memory_window_size)
1154 .await
1155 .map_err(|source| AgentError::Memory { source })?
1156 } else {
1157 self.memory
1158 .recent_for_boundary(
1159 self.config.memory_window_size,
1160 &self.config.privacy_boundary,
1161 ctx.source_channel.as_deref(),
1162 )
1163 .await
1164 .map_err(|source| AgentError::Memory { source })?
1165 };
1166 self.audit(
1167 "memory_recent_loaded",
1168 json!({"request_id": request_id, "items": recent_memory.len()}),
1169 )
1170 .await;
1171
1172 let mut messages: Vec<ConversationMessage> = Vec::new();
1173
1174 if let Some(ref sp) = self.config.system_prompt {
1176 messages.push(ConversationMessage::System {
1177 content: sp.clone(),
1178 });
1179 }
1180
1181 messages.extend(memory_to_messages(&recent_memory));
1182
1183 if !research_context.is_empty() {
1184 messages.push(ConversationMessage::user(format!(
1185 "Research findings:\n{research_context}",
1186 )));
1187 }
1188
1189 messages.push(ConversationMessage::user(user_text.to_string()));
1190
1191 let mut tool_history: Vec<(String, String, String)> = Vec::new();
1192 let mut failure_streak: usize = 0;
1193 let mut loop_detector = self
1194 .loop_detection_config
1195 .as_ref()
1196 .map(|cfg| ToolLoopDetector::new(cfg.clone()));
1197 let mut restricted_tools: Vec<String> = Vec::new();
1198
1199 for iteration in 0..self.config.max_tool_iterations {
1200 if ctx.is_cancelled() {
1202 warn!(request_id = %request_id, "agent execution cancelled");
1203 self.audit(
1204 "execution_cancelled",
1205 json!({"request_id": request_id, "iteration": iteration}),
1206 )
1207 .await;
1208 return Ok(AssistantMessage {
1209 text: "[Execution cancelled]".to_string(),
1210 });
1211 }
1212
1213 let last_tool_result_chars = messages.last().and_then(|m| {
1218 if matches!(m, ConversationMessage::ToolResult(_)) {
1219 Some(m.char_count())
1220 } else {
1221 None
1222 }
1223 });
1224 let pre_truncate_len = messages.len();
1225 truncate_messages(&mut messages, self.config.max_prompt_chars);
1226 if let Some(result_chars) = last_tool_result_chars {
1227 let still_present = messages
1228 .last()
1229 .is_some_and(|m| matches!(m, ConversationMessage::ToolResult(_)));
1230 if !still_present || messages.len() < pre_truncate_len {
1231 warn!(
1232 request_id = %request_id,
1233 iteration = iteration,
1234 tool_result_chars = result_chars,
1235 max_prompt_chars = self.config.max_prompt_chars,
1236 "tool result truncated from context: tool result ({} chars) exceeds \
1237 available budget within max_prompt_chars={}; the model will likely \
1238 call the same tool again — raise max_prompt_chars in your config",
1239 result_chars,
1240 self.config.max_prompt_chars,
1241 );
1242 }
1243 }
1244
1245 self.hook(
1246 "before_provider_call",
1247 json!({
1248 "request_id": request_id,
1249 "iteration": iteration,
1250 "message_count": messages.len(),
1251 "tool_count": tool_definitions.len(),
1252 }),
1253 )
1254 .await?;
1255 self.audit(
1256 "provider_call_start",
1257 json!({
1258 "request_id": request_id,
1259 "iteration": iteration,
1260 "message_count": messages.len(),
1261 }),
1262 )
1263 .await;
1264
1265 let effective_tools: Vec<ToolDefinition> = if restricted_tools.is_empty() {
1267 tool_definitions.clone()
1268 } else {
1269 tool_definitions
1270 .iter()
1271 .filter(|td| !restricted_tools.contains(&td.name))
1272 .cloned()
1273 .collect()
1274 };
1275
1276 let provider_span = info_span!(
1277 "provider_call",
1278 request_id = %request_id,
1279 iteration = iteration,
1280 tool_count = effective_tools.len(),
1281 );
1282 let _provider_guard = provider_span.enter();
1283 let provider_started = Instant::now();
1284 let provider_result = if let Some(ref sink) = stream_sink {
1285 self.provider
1286 .complete_streaming_with_tools(
1287 &messages,
1288 &effective_tools,
1289 &self.config.reasoning,
1290 sink.clone(),
1291 )
1292 .await
1293 } else {
1294 self.provider
1295 .complete_with_tools(&messages, &effective_tools, &self.config.reasoning)
1296 .await
1297 };
1298 let chat_result = match provider_result {
1299 Ok(result) => result,
1300 Err(source) => {
1301 self.observe_histogram(
1302 "provider_latency_ms",
1303 provider_started.elapsed().as_secs_f64() * 1000.0,
1304 );
1305 self.increment_counter("provider_errors_total");
1306 return Err(AgentError::Provider { source });
1307 }
1308 };
1309 self.observe_histogram(
1310 "provider_latency_ms",
1311 provider_started.elapsed().as_secs_f64() * 1000.0,
1312 );
1313
1314 let iter_tokens = chat_result.input_tokens + chat_result.output_tokens;
1316 if iter_tokens > 0 {
1317 ctx.add_tokens(iter_tokens);
1318 }
1319
1320 if let Some(ref calc) = self.config.cost_calculator {
1322 let cost = calc(chat_result.input_tokens, chat_result.output_tokens);
1323 if cost > 0 {
1324 ctx.add_cost(cost);
1325 }
1326 }
1327
1328 if let Some(reason) = ctx.budget_exceeded() {
1330 warn!(
1331 request_id = %request_id,
1332 iteration = iteration,
1333 reason = %reason,
1334 "budget exceeded — force-completing run"
1335 );
1336 return Err(AgentError::BudgetExceeded { reason });
1337 }
1338
1339 info!(
1340 request_id = %request_id,
1341 iteration = iteration,
1342 stop_reason = ?chat_result.stop_reason,
1343 tool_calls = chat_result.tool_calls.len(),
1344 tokens_this_call = iter_tokens,
1345 total_tokens = ctx.current_tokens(),
1346 cost_microdollars = ctx.current_cost(),
1347 "structured provider call finished"
1348 );
1349 self.hook(
1350 "after_provider_call",
1351 json!({
1352 "request_id": request_id,
1353 "iteration": iteration,
1354 "response_len": chat_result.output_text.len(),
1355 "tool_calls": chat_result.tool_calls.len(),
1356 "status": "ok",
1357 }),
1358 )
1359 .await?;
1360
1361 let mut chat_result = chat_result;
1364 if chat_result.tool_calls.is_empty()
1365 || chat_result.stop_reason == Some(StopReason::EndTurn)
1366 {
1367 if chat_result.tool_calls.is_empty() && !chat_result.output_text.is_empty() {
1371 if let Some(extracted) =
1372 extract_tool_call_from_text(&chat_result.output_text, &effective_tools)
1373 {
1374 info!(
1375 request_id = %request_id,
1376 tool = %extracted.name,
1377 "extracted tool call from text output (local model fallback)"
1378 );
1379 chat_result.tool_calls = vec![extracted];
1380 chat_result.stop_reason = Some(StopReason::ToolUse);
1381 }
1383 }
1384
1385 if chat_result.tool_calls.is_empty() {
1387 let response_text = chat_result.output_text;
1388 self.write_to_memory(
1389 "assistant",
1390 &response_text,
1391 request_id,
1392 ctx.source_channel.as_deref(),
1393 ctx.conversation_id.as_deref().unwrap_or(""),
1394 ctx.agent_id.as_deref(),
1395 )
1396 .await?;
1397 self.audit("respond_success", json!({"request_id": request_id}))
1398 .await;
1399 self.hook(
1400 "before_response_emit",
1401 json!({"request_id": request_id, "response_len": response_text.len()}),
1402 )
1403 .await?;
1404 let response = AssistantMessage {
1405 text: response_text,
1406 };
1407 self.hook(
1408 "after_response_emit",
1409 json!({"request_id": request_id, "response_len": response.text.len()}),
1410 )
1411 .await?;
1412 return Ok(response);
1413 }
1414 }
1415
1416 messages.push(ConversationMessage::Assistant {
1418 content: if chat_result.output_text.is_empty() {
1419 None
1420 } else {
1421 Some(chat_result.output_text.clone())
1422 },
1423 tool_calls: chat_result.tool_calls.clone(),
1424 });
1425
1426 let tool_calls = &chat_result.tool_calls;
1428 let has_gated = tool_calls
1429 .iter()
1430 .any(|tc| self.config.gated_tools.contains(&tc.name));
1431 let use_parallel = self.config.parallel_tools && tool_calls.len() > 1 && !has_gated;
1432
1433 let mut tool_results: Vec<ToolResultMessage> = Vec::new();
1434
1435 if use_parallel {
1436 let futs: Vec<_> = tool_calls
1437 .iter()
1438 .map(|tc| {
1439 let tool = self.tools.iter().find(|t| t.name() == tc.name);
1440 async move {
1441 match tool {
1442 Some(tool) => {
1443 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1444 Ok(s) => s,
1445 Err(validation_err) => {
1446 return (
1447 tc.name.clone(),
1448 String::new(),
1449 ToolResultMessage {
1450 tool_use_id: tc.id.clone(),
1451 content: validation_err,
1452 is_error: true,
1453 },
1454 );
1455 }
1456 };
1457 match tool.execute(&input_str, ctx).await {
1458 Ok(result) => (
1459 tc.name.clone(),
1460 input_str,
1461 ToolResultMessage {
1462 tool_use_id: tc.id.clone(),
1463 content: result.output,
1464 is_error: false,
1465 },
1466 ),
1467 Err(e) => (
1468 tc.name.clone(),
1469 input_str,
1470 ToolResultMessage {
1471 tool_use_id: tc.id.clone(),
1472 content: format!("Error: {e}"),
1473 is_error: true,
1474 },
1475 ),
1476 }
1477 }
1478 None => (
1479 tc.name.clone(),
1480 String::new(),
1481 ToolResultMessage {
1482 tool_use_id: tc.id.clone(),
1483 content: format!("Tool '{}' not found", tc.name),
1484 is_error: true,
1485 },
1486 ),
1487 }
1488 }
1489 })
1490 .collect();
1491 let results = futures_util::future::join_all(futs).await;
1492 for (name, input_str, result_msg) in results {
1493 if result_msg.is_error {
1494 failure_streak += 1;
1495 } else {
1496 failure_streak = 0;
1497 tool_history.push((name, input_str, result_msg.content.clone()));
1498 }
1499 tool_results.push(result_msg);
1500 }
1501 } else {
1502 for tc in tool_calls {
1504 let result_msg = match self.tools.iter().find(|t| t.name() == tc.name) {
1505 Some(tool) => {
1506 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1507 Ok(s) => s,
1508 Err(validation_err) => {
1509 failure_streak += 1;
1510 tool_results.push(ToolResultMessage {
1511 tool_use_id: tc.id.clone(),
1512 content: validation_err,
1513 is_error: true,
1514 });
1515 continue;
1516 }
1517 };
1518 self.audit(
1519 "tool_requested",
1520 json!({
1521 "request_id": request_id,
1522 "iteration": iteration,
1523 "tool_name": tc.name,
1524 "tool_input_len": input_str.len(),
1525 }),
1526 )
1527 .await;
1528
1529 match self
1530 .execute_tool(
1531 &**tool, &tc.name, &input_str, ctx, request_id, iteration,
1532 )
1533 .await
1534 {
1535 Ok(result) => {
1536 failure_streak = 0;
1537 tool_history.push((
1538 tc.name.clone(),
1539 input_str,
1540 result.output.clone(),
1541 ));
1542 ToolResultMessage {
1543 tool_use_id: tc.id.clone(),
1544 content: result.output,
1545 is_error: false,
1546 }
1547 }
1548 Err(e) => {
1549 failure_streak += 1;
1550 ToolResultMessage {
1551 tool_use_id: tc.id.clone(),
1552 content: format!("Error: {e}"),
1553 is_error: true,
1554 }
1555 }
1556 }
1557 }
1558 None => {
1559 self.audit(
1560 "tool_not_found",
1561 json!({
1562 "request_id": request_id,
1563 "iteration": iteration,
1564 "tool_name": tc.name,
1565 }),
1566 )
1567 .await;
1568 failure_streak += 1;
1569 ToolResultMessage {
1570 tool_use_id: tc.id.clone(),
1571 content: format!(
1572 "Tool '{}' not found. Available tools: {}",
1573 tc.name,
1574 tool_definitions
1575 .iter()
1576 .map(|d| d.name.as_str())
1577 .collect::<Vec<_>>()
1578 .join(", ")
1579 ),
1580 is_error: true,
1581 }
1582 }
1583 };
1584 tool_results.push(result_msg);
1585 }
1586 }
1587
1588 for result in &tool_results {
1590 messages.push(ConversationMessage::ToolResult(result.clone()));
1591 }
1592
1593 if let Some(ref mut detector) = loop_detector {
1595 let tool_calls_this_iter = &chat_result.tool_calls;
1597 let mut worst_action = LoopAction::Continue;
1598 for tc in tool_calls_this_iter {
1599 let action = detector.check(
1600 &tc.name,
1601 &tc.input,
1602 ctx.current_tokens(),
1603 ctx.current_cost(),
1604 );
1605 if crate::loop_detection::severity(&action)
1606 > crate::loop_detection::severity(&worst_action)
1607 {
1608 worst_action = action;
1609 }
1610 }
1611
1612 match worst_action {
1613 LoopAction::Continue => {}
1614 LoopAction::InjectMessage(ref msg) => {
1615 warn!(
1616 request_id = %request_id,
1617 "tiered loop detection: injecting message"
1618 );
1619 self.audit(
1620 "loop_detection_inject",
1621 json!({
1622 "request_id": request_id,
1623 "iteration": iteration,
1624 "message": msg,
1625 }),
1626 )
1627 .await;
1628 messages.push(ConversationMessage::user(format!("SYSTEM NOTICE: {msg}")));
1629 }
1630 LoopAction::RestrictTools(ref tools) => {
1631 warn!(
1632 request_id = %request_id,
1633 tools = ?tools,
1634 "tiered loop detection: restricting tools"
1635 );
1636 self.audit(
1637 "loop_detection_restrict",
1638 json!({
1639 "request_id": request_id,
1640 "iteration": iteration,
1641 "restricted_tools": tools,
1642 }),
1643 )
1644 .await;
1645 restricted_tools.extend(tools.iter().cloned());
1646 messages.push(ConversationMessage::user(format!(
1647 "SYSTEM NOTICE: The following tools have been temporarily \
1648 restricted due to repetitive usage: {}. Try a different approach.",
1649 tools.join(", ")
1650 )));
1651 }
1652 LoopAction::ForceComplete(ref reason) => {
1653 warn!(
1654 request_id = %request_id,
1655 reason = %reason,
1656 "tiered loop detection: force completing"
1657 );
1658 self.audit(
1659 "loop_detection_force_complete",
1660 json!({
1661 "request_id": request_id,
1662 "iteration": iteration,
1663 "reason": reason,
1664 }),
1665 )
1666 .await;
1667 messages.push(ConversationMessage::user(format!(
1669 "SYSTEM NOTICE: {reason}. Provide your best answer now \
1670 without calling any more tools."
1671 )));
1672 truncate_messages(&mut messages, self.config.max_prompt_chars);
1673 let final_result = if let Some(ref sink) = stream_sink {
1674 self.provider
1675 .complete_streaming_with_tools(
1676 &messages,
1677 &[],
1678 &self.config.reasoning,
1679 sink.clone(),
1680 )
1681 .await
1682 .map_err(|source| AgentError::Provider { source })?
1683 } else {
1684 self.provider
1685 .complete_with_tools(&messages, &[], &self.config.reasoning)
1686 .await
1687 .map_err(|source| AgentError::Provider { source })?
1688 };
1689 let response_text = final_result.output_text;
1690 self.write_to_memory(
1691 "assistant",
1692 &response_text,
1693 request_id,
1694 ctx.source_channel.as_deref(),
1695 ctx.conversation_id.as_deref().unwrap_or(""),
1696 ctx.agent_id.as_deref(),
1697 )
1698 .await?;
1699 return Ok(AssistantMessage {
1700 text: response_text,
1701 });
1702 }
1703 }
1704 }
1705
1706 if self.config.loop_detection_no_progress_threshold > 0 {
1708 let threshold = self.config.loop_detection_no_progress_threshold;
1709 if tool_history.len() >= threshold {
1710 let recent = &tool_history[tool_history.len() - threshold..];
1711 let all_same = recent.iter().all(|entry| {
1712 entry.0 == recent[0].0 && entry.1 == recent[0].1 && entry.2 == recent[0].2
1713 });
1714 if all_same {
1715 warn!(
1716 tool_name = recent[0].0,
1717 threshold, "structured loop detection: no-progress threshold reached"
1718 );
1719 self.audit(
1720 "loop_detection_no_progress",
1721 json!({
1722 "request_id": request_id,
1723 "tool_name": recent[0].0,
1724 "threshold": threshold,
1725 }),
1726 )
1727 .await;
1728 messages.push(ConversationMessage::user(format!(
1729 "SYSTEM NOTICE: Loop detected — the tool `{}` was called \
1730 {} times with identical arguments and output. You are stuck \
1731 in a loop. Stop calling this tool and try a different approach, \
1732 or provide your best answer with the information you already have.",
1733 recent[0].0, threshold
1734 )));
1735 }
1736 }
1737 }
1738
1739 if self.config.loop_detection_ping_pong_cycles > 0 {
1741 let cycles = self.config.loop_detection_ping_pong_cycles;
1742 let needed = cycles * 2;
1743 if tool_history.len() >= needed {
1744 let recent = &tool_history[tool_history.len() - needed..];
1745 let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
1746 && recent[0].0 != recent[1].0;
1747 if is_ping_pong {
1748 warn!(
1749 tool_a = recent[0].0,
1750 tool_b = recent[1].0,
1751 cycles,
1752 "structured loop detection: ping-pong pattern detected"
1753 );
1754 self.audit(
1755 "loop_detection_ping_pong",
1756 json!({
1757 "request_id": request_id,
1758 "tool_a": recent[0].0,
1759 "tool_b": recent[1].0,
1760 "cycles": cycles,
1761 }),
1762 )
1763 .await;
1764 messages.push(ConversationMessage::user(format!(
1765 "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
1766 alternating for {} cycles in a ping-pong pattern. Stop this \
1767 alternation and try a different approach, or provide your best \
1768 answer with the information you already have.",
1769 recent[0].0, recent[1].0, cycles
1770 )));
1771 }
1772 }
1773 }
1774
1775 if self.config.loop_detection_failure_streak > 0
1777 && failure_streak >= self.config.loop_detection_failure_streak
1778 {
1779 warn!(
1780 failure_streak,
1781 "structured loop detection: consecutive failure streak"
1782 );
1783 self.audit(
1784 "loop_detection_failure_streak",
1785 json!({
1786 "request_id": request_id,
1787 "streak": failure_streak,
1788 }),
1789 )
1790 .await;
1791 messages.push(ConversationMessage::user(format!(
1792 "SYSTEM NOTICE: {} consecutive tool calls have failed. \
1793 Stop calling tools and provide your best answer with \
1794 the information you already have.",
1795 failure_streak
1796 )));
1797 }
1798 }
1799
1800 self.audit(
1802 "structured_tool_use_max_iterations",
1803 json!({
1804 "request_id": request_id,
1805 "max_tool_iterations": self.config.max_tool_iterations,
1806 }),
1807 )
1808 .await;
1809
1810 messages.push(ConversationMessage::user(
1811 "You have reached the maximum number of tool call iterations. \
1812 Please provide your best answer now without calling any more tools."
1813 .to_string(),
1814 ));
1815 truncate_messages(&mut messages, self.config.max_prompt_chars);
1816
1817 let final_result = if let Some(ref sink) = stream_sink {
1818 self.provider
1819 .complete_streaming_with_tools(&messages, &[], &self.config.reasoning, sink.clone())
1820 .await
1821 .map_err(|source| AgentError::Provider { source })?
1822 } else {
1823 self.provider
1824 .complete_with_tools(&messages, &[], &self.config.reasoning)
1825 .await
1826 .map_err(|source| AgentError::Provider { source })?
1827 };
1828
1829 let response_text = final_result.output_text;
1830 self.write_to_memory(
1831 "assistant",
1832 &response_text,
1833 request_id,
1834 ctx.source_channel.as_deref(),
1835 ctx.conversation_id.as_deref().unwrap_or(""),
1836 ctx.agent_id.as_deref(),
1837 )
1838 .await?;
1839 self.audit("respond_success", json!({"request_id": request_id}))
1840 .await;
1841 self.hook(
1842 "before_response_emit",
1843 json!({"request_id": request_id, "response_len": response_text.len()}),
1844 )
1845 .await?;
1846 let response = AssistantMessage {
1847 text: response_text,
1848 };
1849 self.hook(
1850 "after_response_emit",
1851 json!({"request_id": request_id, "response_len": response.text.len()}),
1852 )
1853 .await?;
1854 Ok(response)
1855 }
1856
1857 pub async fn respond(
1858 &self,
1859 user: UserMessage,
1860 ctx: &ToolContext,
1861 ) -> Result<AssistantMessage, AgentError> {
1862 let request_id = Self::next_request_id();
1863 let span = info_span!(
1864 "agent_run",
1865 request_id = %request_id,
1866 depth = ctx.depth,
1867 conversation_id = ctx.conversation_id.as_deref().unwrap_or(""),
1868 );
1869 self.respond_traced(&request_id, user, ctx, None)
1870 .instrument(span)
1871 .await
1872 }
1873
1874 pub async fn respond_streaming(
1878 &self,
1879 user: UserMessage,
1880 ctx: &ToolContext,
1881 sink: StreamSink,
1882 ) -> Result<AssistantMessage, AgentError> {
1883 let request_id = Self::next_request_id();
1884 let span = info_span!(
1885 "agent_run",
1886 request_id = %request_id,
1887 depth = ctx.depth,
1888 conversation_id = ctx.conversation_id.as_deref().unwrap_or(""),
1889 streaming = true,
1890 );
1891 self.respond_traced(&request_id, user, ctx, Some(sink))
1892 .instrument(span)
1893 .await
1894 }
1895
1896 async fn respond_traced(
1898 &self,
1899 request_id: &str,
1900 user: UserMessage,
1901 ctx: &ToolContext,
1902 stream_sink: Option<StreamSink>,
1903 ) -> Result<AssistantMessage, AgentError> {
1904 self.increment_counter("requests_total");
1905 let run_started = Instant::now();
1906 self.hook("before_run", json!({"request_id": request_id}))
1907 .await?;
1908 let timed = timeout(
1909 Duration::from_millis(self.config.request_timeout_ms),
1910 self.respond_inner(request_id, user, ctx, stream_sink),
1911 )
1912 .await;
1913 let result = match timed {
1914 Ok(result) => result,
1915 Err(_) => Err(AgentError::Timeout {
1916 timeout_ms: self.config.request_timeout_ms,
1917 }),
1918 };
1919
1920 let after_detail = match &result {
1921 Ok(response) => json!({
1922 "request_id": request_id,
1923 "status": "ok",
1924 "response_len": response.text.len(),
1925 "duration_ms": run_started.elapsed().as_millis(),
1926 }),
1927 Err(err) => json!({
1928 "request_id": request_id,
1929 "status": "error",
1930 "error": redact_text(&err.to_string()),
1931 "duration_ms": run_started.elapsed().as_millis(),
1932 }),
1933 };
1934 let total_cost = ctx.current_cost();
1935 let cost_usd = total_cost as f64 / 1_000_000.0;
1936 info!(
1937 request_id = %request_id,
1938 duration_ms = %run_started.elapsed().as_millis(),
1939 total_tokens = ctx.current_tokens(),
1940 cost_microdollars = total_cost,
1941 cost_usd = format!("{:.4}", cost_usd),
1942 "agent run completed"
1943 );
1944 self.hook("after_run", after_detail).await?;
1945 result
1946 }
1947
    /// Core request pipeline shared by `respond` and `respond_streaming`.
    ///
    /// Flow: audit the request, persist the user turn to memory, optionally
    /// run a research phase, then either delegate to the structured tool-use
    /// path (`respond_with_tools`) or drive the legacy text protocol in which
    /// the model requests a tool by returning a prompt of the form
    /// `tool:<name> <input>`.
    async fn respond_inner(
        &self,
        request_id: &str,
        user: UserMessage,
        ctx: &ToolContext,
        stream_sink: Option<StreamSink>,
    ) -> Result<AssistantMessage, AgentError> {
        // Record the request up front so later failures remain attributable
        // in the audit log.
        self.audit(
            "respond_start",
            json!({
                "request_id": request_id,
                "user_message_len": user.text.len(),
                "max_tool_iterations": self.config.max_tool_iterations,
                "request_timeout_ms": self.config.request_timeout_ms,
            }),
        )
        .await;
        // Persist the user turn before generating anything, so memory shows
        // the user message even if the run fails afterward.
        self.write_to_memory(
            "user",
            &user.text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
            ctx.agent_id.as_deref(),
        )
        .await?;

        // Optional research phase; empty string means "no findings".
        let research_context = if self.should_research(&user.text) {
            self.run_research_phase(&user.text, ctx, request_id).await?
        } else {
            String::new()
        };

        // Preferred path: native tool-use when the model supports it and
        // tools are registered. Everything below is the legacy fallback.
        if self.config.model_supports_tool_use && self.has_tool_definitions() {
            return self
                .respond_with_tools(request_id, &user.text, &research_context, ctx, stream_sink)
                .await;
        }

        let mut prompt = user.text;
        // (tool name, input, output) triples, used by the loop detectors.
        let mut tool_history: Vec<(String, String, String)> = Vec::new();

        // Legacy protocol loop: keep executing tools while the current
        // prompt is a `tool:` request, bounded by max_tool_iterations.
        for iteration in 0..self.config.max_tool_iterations {
            if !prompt.starts_with("tool:") {
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Gated tools must never run concurrently, so any gated call
            // forces the sequential path.
            let has_gated = calls
                .iter()
                .any(|(name, _)| self.config.gated_tools.contains(*name));
            if calls.len() > 1 && self.config.parallel_tools && !has_gated {
                // Resolve names to tools first; unknown names are audited
                // and dropped rather than failing the batch.
                let mut resolved: Vec<(&str, &str, &dyn Tool)> = Vec::new();
                for &(name, input) in &calls {
                    let tool = self.tools.iter().find(|t| t.name() == name);
                    match tool {
                        Some(t) => resolved.push((name, input, &**t)),
                        None => {
                            self.audit(
                                "tool_not_found",
                                json!({"request_id": request_id, "iteration": iteration, "tool_name": name}),
                            )
                            .await;
                        }
                    }
                }
                if resolved.is_empty() {
                    break;
                }

                // Run all resolved tools concurrently.
                let futs: Vec<_> = resolved
                    .iter()
                    .map(|&(name, input, tool)| async move {
                        let r = tool.execute(input, ctx).await;
                        (name, input, r)
                    })
                    .collect();
                let results = futures_util::future::join_all(futs).await;

                // Any single failure aborts the whole run with a Tool error;
                // successes are recorded in tool_history.
                let mut output_parts: Vec<String> = Vec::new();
                for (name, input, result) in results {
                    match result {
                        Ok(r) => {
                            tool_history.push((
                                name.to_string(),
                                input.to_string(),
                                r.output.clone(),
                            ));
                            output_parts.push(format!("Tool output from {name}: {}", r.output));
                        }
                        Err(source) => {
                            return Err(AgentError::Tool {
                                tool: name.to_string(),
                                source,
                            });
                        }
                    }
                }
                prompt = output_parts.join("\n");
                continue;
            }

            // Sequential path: only the first parsed call is executed this
            // iteration.
            let (tool_name, tool_input) = calls[0];
            self.audit(
                "tool_requested",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "tool_name": tool_name,
                    "tool_input_len": tool_input.len(),
                }),
            )
            .await;

            if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
                let result = self
                    .execute_tool(&**tool, tool_name, tool_input, ctx, request_id, iteration)
                    .await?;

                tool_history.push((
                    tool_name.to_string(),
                    tool_input.to_string(),
                    result.output.clone(),
                ));

                // Loop detector 1: the same tool called `threshold` times in
                // a row with identical input AND output — no progress.
                if self.config.loop_detection_no_progress_threshold > 0 {
                    let threshold = self.config.loop_detection_no_progress_threshold;
                    if tool_history.len() >= threshold {
                        let recent = &tool_history[tool_history.len() - threshold..];
                        let all_same = recent.iter().all(|entry| {
                            entry.0 == recent[0].0
                                && entry.1 == recent[0].1
                                && entry.2 == recent[0].2
                        });
                        if all_same {
                            warn!(
                                tool_name,
                                threshold, "loop detection: no-progress threshold reached"
                            );
                            self.audit(
                                "loop_detection_no_progress",
                                json!({
                                    "request_id": request_id,
                                    "tool_name": tool_name,
                                    "threshold": threshold,
                                }),
                            )
                            .await;
                            // Replace the prompt with a system notice and
                            // exit the tool loop; the provider call below
                            // produces the final answer.
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — the tool `{tool_name}` was called \
                                {threshold} times with identical arguments and output. You are stuck \
                                in a loop. Stop calling this tool and try a different approach, or \
                                provide your best answer with the information you already have."
                            );
                            break;
                        }
                    }
                }

                // Loop detector 2: two tools strictly alternating (A,B,A,B…)
                // for `cycles` full cycles — a ping-pong pattern.
                if self.config.loop_detection_ping_pong_cycles > 0 {
                    let cycles = self.config.loop_detection_ping_pong_cycles;
                    let needed = cycles * 2;
                    if tool_history.len() >= needed {
                        let recent = &tool_history[tool_history.len() - needed..];
                        let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
                            && recent[0].0 != recent[1].0;
                        if is_ping_pong {
                            warn!(
                                tool_a = recent[0].0,
                                tool_b = recent[1].0,
                                cycles,
                                "loop detection: ping-pong pattern detected"
                            );
                            self.audit(
                                "loop_detection_ping_pong",
                                json!({
                                    "request_id": request_id,
                                    "tool_a": recent[0].0,
                                    "tool_b": recent[1].0,
                                    "cycles": cycles,
                                }),
                            )
                            .await;
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
                                alternating for {} cycles in a ping-pong pattern. Stop this \
                                alternation and try a different approach, or provide your best \
                                answer with the information you already have.",
                                recent[0].0, recent[1].0, cycles
                            );
                            break;
                        }
                    }
                }

                // A tool may chain into another tool request; otherwise wrap
                // the output so the model sees where it came from.
                if result.output.starts_with("tool:") {
                    prompt = result.output.clone();
                } else {
                    prompt = format!("Tool output from {tool_name}: {}", result.output);
                }
                continue;
            }

            // Unknown tool in sequential mode: audit and fall through to the
            // provider with the prompt unchanged.
            self.audit(
                "tool_not_found",
                json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
            )
            .await;
            break;
        }

        // Prepend research findings (if any) to whatever prompt survived the
        // tool loop.
        if !research_context.is_empty() {
            prompt = format!("Research findings:\n{research_context}\n\nUser request:\n{prompt}");
        }

        // Final provider call, assistant-turn persistence, and the emit
        // hooks around the response.
        let response_text = self
            .call_provider_with_context(
                &prompt,
                request_id,
                stream_sink,
                ctx.source_channel.as_deref(),
            )
            .await?;
        self.write_to_memory(
            "assistant",
            &response_text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
            ctx.agent_id.as_deref(),
        )
        .await?;
        self.audit("respond_success", json!({"request_id": request_id}))
            .await;

        self.hook(
            "before_response_emit",
            json!({"request_id": request_id, "response_len": response_text.len()}),
        )
        .await?;
        let response = AssistantMessage {
            text: response_text,
        };
        self.hook(
            "after_response_emit",
            json!({"request_id": request_id, "response_len": response.text.len()}),
        )
        .await?;

        Ok(response)
    }
2210}
2211
2212#[cfg(test)]
2213mod tests {
2214 use super::*;
2215 use crate::types::{
2216 ChatResult, HookEvent, HookFailureMode, HookPolicy, ReasoningConfig, ResearchPolicy,
2217 ResearchTrigger, ToolResult,
2218 };
2219 use async_trait::async_trait;
2220 use serde_json::Value;
2221 use std::collections::HashMap;
2222 use std::sync::atomic::{AtomicUsize, Ordering};
2223 use std::sync::{Arc, Mutex};
2224 use tokio::time::{sleep, Duration};
2225
2226 #[derive(Default)]
2227 struct TestMemory {
2228 entries: Arc<Mutex<Vec<MemoryEntry>>>,
2229 }
2230
2231 #[async_trait]
2232 impl MemoryStore for TestMemory {
2233 async fn append(&self, entry: MemoryEntry) -> anyhow::Result<()> {
2234 self.entries
2235 .lock()
2236 .expect("memory lock poisoned")
2237 .push(entry);
2238 Ok(())
2239 }
2240
2241 async fn recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
2242 let entries = self.entries.lock().expect("memory lock poisoned");
2243 Ok(entries.iter().rev().take(limit).cloned().collect())
2244 }
2245 }
2246
2247 struct TestProvider {
2248 received_prompts: Arc<Mutex<Vec<String>>>,
2249 response_text: String,
2250 }
2251
2252 #[async_trait]
2253 impl Provider for TestProvider {
2254 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
2255 self.received_prompts
2256 .lock()
2257 .expect("provider lock poisoned")
2258 .push(prompt.to_string());
2259 Ok(ChatResult {
2260 output_text: self.response_text.clone(),
2261 ..Default::default()
2262 })
2263 }
2264 }
2265
2266 struct EchoTool;
2267
2268 #[async_trait]
2269 impl Tool for EchoTool {
2270 fn name(&self) -> &'static str {
2271 "echo"
2272 }
2273
2274 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2275 Ok(ToolResult {
2276 output: format!("echoed:{input}"),
2277 })
2278 }
2279 }
2280
2281 struct PluginEchoTool;
2282
2283 #[async_trait]
2284 impl Tool for PluginEchoTool {
2285 fn name(&self) -> &'static str {
2286 "plugin:echo"
2287 }
2288
2289 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2290 Ok(ToolResult {
2291 output: format!("plugin-echoed:{input}"),
2292 })
2293 }
2294 }
2295
2296 struct FailingTool;
2297
2298 #[async_trait]
2299 impl Tool for FailingTool {
2300 fn name(&self) -> &'static str {
2301 "boom"
2302 }
2303
2304 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2305 Err(anyhow::anyhow!("tool exploded"))
2306 }
2307 }
2308
2309 struct SlowTool;
2310
2311 #[async_trait]
2312 impl Tool for SlowTool {
2313 fn name(&self) -> &'static str {
2314 "slow"
2315 }
2316
2317 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2318 sleep(Duration::from_millis(500)).await;
2319 Ok(ToolResult {
2320 output: "finally done".to_string(),
2321 })
2322 }
2323 }
2324
2325 struct FailingProvider;
2326
2327 #[async_trait]
2328 impl Provider for FailingProvider {
2329 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
2330 Err(anyhow::anyhow!("provider boom"))
2331 }
2332 }
2333
2334 struct SlowProvider;
2335
2336 #[async_trait]
2337 impl Provider for SlowProvider {
2338 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
2339 sleep(Duration::from_millis(500)).await;
2340 Ok(ChatResult {
2341 output_text: "late".to_string(),
2342 ..Default::default()
2343 })
2344 }
2345 }
2346
2347 struct ScriptedProvider {
2348 responses: Vec<String>,
2349 call_count: AtomicUsize,
2350 received_prompts: Arc<Mutex<Vec<String>>>,
2351 }
2352
2353 impl ScriptedProvider {
2354 fn new(responses: Vec<&str>) -> Self {
2355 Self {
2356 responses: responses.into_iter().map(|s| s.to_string()).collect(),
2357 call_count: AtomicUsize::new(0),
2358 received_prompts: Arc::new(Mutex::new(Vec::new())),
2359 }
2360 }
2361 }
2362
2363 #[async_trait]
2364 impl Provider for ScriptedProvider {
2365 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
2366 self.received_prompts
2367 .lock()
2368 .expect("provider lock poisoned")
2369 .push(prompt.to_string());
2370 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
2371 let response = if idx < self.responses.len() {
2372 self.responses[idx].clone()
2373 } else {
2374 self.responses.last().cloned().unwrap_or_default()
2375 };
2376 Ok(ChatResult {
2377 output_text: response,
2378 ..Default::default()
2379 })
2380 }
2381 }
2382
2383 struct UpperTool;
2384
2385 #[async_trait]
2386 impl Tool for UpperTool {
2387 fn name(&self) -> &'static str {
2388 "upper"
2389 }
2390
2391 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2392 Ok(ToolResult {
2393 output: input.to_uppercase(),
2394 })
2395 }
2396 }
2397
2398 struct LoopTool;
2401
2402 #[async_trait]
2403 impl Tool for LoopTool {
2404 fn name(&self) -> &'static str {
2405 "loop_tool"
2406 }
2407
2408 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2409 Ok(ToolResult {
2410 output: "tool:loop_tool x".to_string(),
2411 })
2412 }
2413 }
2414
2415 struct PingTool;
2417
2418 #[async_trait]
2419 impl Tool for PingTool {
2420 fn name(&self) -> &'static str {
2421 "ping"
2422 }
2423
2424 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2425 Ok(ToolResult {
2426 output: "tool:pong x".to_string(),
2427 })
2428 }
2429 }
2430
2431 struct PongTool;
2433
2434 #[async_trait]
2435 impl Tool for PongTool {
2436 fn name(&self) -> &'static str {
2437 "pong"
2438 }
2439
2440 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2441 Ok(ToolResult {
2442 output: "tool:ping x".to_string(),
2443 })
2444 }
2445 }
2446
2447 struct RecordingHookSink {
2448 events: Arc<Mutex<Vec<String>>>,
2449 }
2450
2451 #[async_trait]
2452 impl HookSink for RecordingHookSink {
2453 async fn record(&self, event: HookEvent) -> anyhow::Result<()> {
2454 self.events
2455 .lock()
2456 .expect("hook lock poisoned")
2457 .push(event.stage);
2458 Ok(())
2459 }
2460 }
2461
2462 struct FailingHookSink;
2463
2464 #[async_trait]
2465 impl HookSink for FailingHookSink {
2466 async fn record(&self, _event: HookEvent) -> anyhow::Result<()> {
2467 Err(anyhow::anyhow!("hook sink failure"))
2468 }
2469 }
2470
2471 struct RecordingAuditSink {
2472 events: Arc<Mutex<Vec<AuditEvent>>>,
2473 }
2474
2475 #[async_trait]
2476 impl AuditSink for RecordingAuditSink {
2477 async fn record(&self, event: AuditEvent) -> anyhow::Result<()> {
2478 self.events.lock().expect("audit lock poisoned").push(event);
2479 Ok(())
2480 }
2481 }
2482
2483 struct RecordingMetricsSink {
2484 counters: Arc<Mutex<HashMap<&'static str, u64>>>,
2485 histograms: Arc<Mutex<HashMap<&'static str, usize>>>,
2486 }
2487
2488 impl MetricsSink for RecordingMetricsSink {
2489 fn increment_counter(&self, name: &'static str, value: u64) {
2490 let mut counters = self.counters.lock().expect("metrics lock poisoned");
2491 *counters.entry(name).or_insert(0) += value;
2492 }
2493
2494 fn observe_histogram(&self, name: &'static str, _value: f64) {
2495 let mut histograms = self.histograms.lock().expect("metrics lock poisoned");
2496 *histograms.entry(name).or_insert(0) += 1;
2497 }
2498 }
2499
2500 fn test_ctx() -> ToolContext {
2501 ToolContext::new(".".to_string())
2502 }
2503
2504 fn counter(counters: &Arc<Mutex<HashMap<&'static str, u64>>>, name: &'static str) -> u64 {
2505 counters
2506 .lock()
2507 .expect("metrics lock poisoned")
2508 .get(name)
2509 .copied()
2510 .unwrap_or(0)
2511 }
2512
2513 fn histogram_count(
2514 histograms: &Arc<Mutex<HashMap<&'static str, usize>>>,
2515 name: &'static str,
2516 ) -> usize {
2517 histograms
2518 .lock()
2519 .expect("metrics lock poisoned")
2520 .get(name)
2521 .copied()
2522 .unwrap_or(0)
2523 }
2524
2525 #[tokio::test]
2526 async fn respond_appends_user_then_assistant_memory_entries() {
2527 let entries = Arc::new(Mutex::new(Vec::new()));
2528 let prompts = Arc::new(Mutex::new(Vec::new()));
2529 let memory = TestMemory {
2530 entries: entries.clone(),
2531 };
2532 let provider = TestProvider {
2533 received_prompts: prompts,
2534 response_text: "assistant-output".to_string(),
2535 };
2536 let agent = Agent::new(
2537 AgentConfig::default(),
2538 Box::new(provider),
2539 Box::new(memory),
2540 vec![],
2541 );
2542
2543 let response = agent
2544 .respond(
2545 UserMessage {
2546 text: "hello world".to_string(),
2547 },
2548 &test_ctx(),
2549 )
2550 .await
2551 .expect("agent respond should succeed");
2552
2553 assert_eq!(response.text, "assistant-output");
2554 let stored = entries.lock().expect("memory lock poisoned");
2555 assert_eq!(stored.len(), 2);
2556 assert_eq!(stored[0].role, "user");
2557 assert_eq!(stored[0].content, "hello world");
2558 assert_eq!(stored[1].role, "assistant");
2559 assert_eq!(stored[1].content, "assistant-output");
2560 }
2561
2562 #[tokio::test]
2563 async fn respond_invokes_tool_for_tool_prefixed_prompt() {
2564 let entries = Arc::new(Mutex::new(Vec::new()));
2565 let prompts = Arc::new(Mutex::new(Vec::new()));
2566 let memory = TestMemory {
2567 entries: entries.clone(),
2568 };
2569 let provider = TestProvider {
2570 received_prompts: prompts.clone(),
2571 response_text: "assistant-after-tool".to_string(),
2572 };
2573 let agent = Agent::new(
2574 AgentConfig::default(),
2575 Box::new(provider),
2576 Box::new(memory),
2577 vec![Box::new(EchoTool)],
2578 );
2579
2580 let response = agent
2581 .respond(
2582 UserMessage {
2583 text: "tool:echo ping".to_string(),
2584 },
2585 &test_ctx(),
2586 )
2587 .await
2588 .expect("agent respond should succeed");
2589
2590 assert_eq!(response.text, "assistant-after-tool");
2591 let prompts = prompts.lock().expect("provider lock poisoned");
2592 assert_eq!(prompts.len(), 1);
2593 assert!(prompts[0].contains("Recent conversation:"));
2594 assert!(prompts[0].contains("Current input:\nTool output from echo: echoed:ping"));
2595 }
2596
2597 #[tokio::test]
2598 async fn respond_with_unknown_tool_falls_back_to_provider_prompt() {
2599 let entries = Arc::new(Mutex::new(Vec::new()));
2600 let prompts = Arc::new(Mutex::new(Vec::new()));
2601 let memory = TestMemory {
2602 entries: entries.clone(),
2603 };
2604 let provider = TestProvider {
2605 received_prompts: prompts.clone(),
2606 response_text: "assistant-without-tool".to_string(),
2607 };
2608 let agent = Agent::new(
2609 AgentConfig::default(),
2610 Box::new(provider),
2611 Box::new(memory),
2612 vec![Box::new(EchoTool)],
2613 );
2614
2615 let response = agent
2616 .respond(
2617 UserMessage {
2618 text: "tool:unknown payload".to_string(),
2619 },
2620 &test_ctx(),
2621 )
2622 .await
2623 .expect("agent respond should succeed");
2624
2625 assert_eq!(response.text, "assistant-without-tool");
2626 let prompts = prompts.lock().expect("provider lock poisoned");
2627 assert_eq!(prompts.len(), 1);
2628 assert!(prompts[0].contains("Recent conversation:"));
2629 assert!(prompts[0].contains("Current input:\ntool:unknown payload"));
2630 }
2631
2632 #[tokio::test]
2633 async fn respond_includes_bounded_recent_memory_in_provider_prompt() {
2634 let entries = Arc::new(Mutex::new(vec![
2635 MemoryEntry {
2636 role: "assistant".to_string(),
2637 content: "very-old".to_string(),
2638 ..Default::default()
2639 },
2640 MemoryEntry {
2641 role: "user".to_string(),
2642 content: "recent-before-request".to_string(),
2643 ..Default::default()
2644 },
2645 ]));
2646 let prompts = Arc::new(Mutex::new(Vec::new()));
2647 let memory = TestMemory {
2648 entries: entries.clone(),
2649 };
2650 let provider = TestProvider {
2651 received_prompts: prompts.clone(),
2652 response_text: "ok".to_string(),
2653 };
2654 let agent = Agent::new(
2655 AgentConfig {
2656 memory_window_size: 2,
2657 ..AgentConfig::default()
2658 },
2659 Box::new(provider),
2660 Box::new(memory),
2661 vec![],
2662 );
2663
2664 agent
2665 .respond(
2666 UserMessage {
2667 text: "latest-user".to_string(),
2668 },
2669 &test_ctx(),
2670 )
2671 .await
2672 .expect("respond should succeed");
2673
2674 let prompts = prompts.lock().expect("provider lock poisoned");
2675 let provider_prompt = prompts.first().expect("provider prompt should exist");
2676 assert!(provider_prompt.contains("- user: recent-before-request"));
2677 assert!(provider_prompt.contains("- user: latest-user"));
2678 assert!(!provider_prompt.contains("very-old"));
2679 }
2680
2681 #[tokio::test]
2682 async fn respond_caps_provider_prompt_to_configured_max_chars() {
2683 let entries = Arc::new(Mutex::new(vec![MemoryEntry {
2684 role: "assistant".to_string(),
2685 content: "historic context ".repeat(16),
2686 ..Default::default()
2687 }]));
2688 let prompts = Arc::new(Mutex::new(Vec::new()));
2689 let memory = TestMemory {
2690 entries: entries.clone(),
2691 };
2692 let provider = TestProvider {
2693 received_prompts: prompts.clone(),
2694 response_text: "ok".to_string(),
2695 };
2696 let agent = Agent::new(
2697 AgentConfig {
2698 max_prompt_chars: 64,
2699 ..AgentConfig::default()
2700 },
2701 Box::new(provider),
2702 Box::new(memory),
2703 vec![],
2704 );
2705
2706 agent
2707 .respond(
2708 UserMessage {
2709 text: "final-tail-marker".to_string(),
2710 },
2711 &test_ctx(),
2712 )
2713 .await
2714 .expect("respond should succeed");
2715
2716 let prompts = prompts.lock().expect("provider lock poisoned");
2717 let provider_prompt = prompts.first().expect("provider prompt should exist");
2718 assert!(provider_prompt.chars().count() <= 64);
2719 assert!(provider_prompt.contains("final-tail-marker"));
2720 }
2721
2722 #[tokio::test]
2723 async fn respond_returns_provider_typed_error() {
2724 let memory = TestMemory::default();
2725 let counters = Arc::new(Mutex::new(HashMap::new()));
2726 let histograms = Arc::new(Mutex::new(HashMap::new()));
2727 let agent = Agent::new(
2728 AgentConfig::default(),
2729 Box::new(FailingProvider),
2730 Box::new(memory),
2731 vec![],
2732 )
2733 .with_metrics(Box::new(RecordingMetricsSink {
2734 counters: counters.clone(),
2735 histograms: histograms.clone(),
2736 }));
2737
2738 let result = agent
2739 .respond(
2740 UserMessage {
2741 text: "hello".to_string(),
2742 },
2743 &test_ctx(),
2744 )
2745 .await;
2746
2747 match result {
2748 Err(AgentError::Provider { source }) => {
2749 assert!(source.to_string().contains("provider boom"));
2750 }
2751 other => panic!("expected provider error, got {other:?}"),
2752 }
2753 assert_eq!(counter(&counters, "requests_total"), 1);
2754 assert_eq!(counter(&counters, "provider_errors_total"), 1);
2755 assert_eq!(counter(&counters, "tool_errors_total"), 0);
2756 assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
2757 }
2758
2759 #[tokio::test]
2760 async fn respond_increments_tool_error_counter_on_tool_failure() {
2761 let memory = TestMemory::default();
2762 let prompts = Arc::new(Mutex::new(Vec::new()));
2763 let provider = TestProvider {
2764 received_prompts: prompts,
2765 response_text: "unused".to_string(),
2766 };
2767 let counters = Arc::new(Mutex::new(HashMap::new()));
2768 let histograms = Arc::new(Mutex::new(HashMap::new()));
2769 let agent = Agent::new(
2770 AgentConfig::default(),
2771 Box::new(provider),
2772 Box::new(memory),
2773 vec![Box::new(FailingTool)],
2774 )
2775 .with_metrics(Box::new(RecordingMetricsSink {
2776 counters: counters.clone(),
2777 histograms: histograms.clone(),
2778 }));
2779
2780 let result = agent
2781 .respond(
2782 UserMessage {
2783 text: "tool:boom ping".to_string(),
2784 },
2785 &test_ctx(),
2786 )
2787 .await;
2788
2789 match result {
2790 Err(AgentError::Tool { tool, source }) => {
2791 assert_eq!(tool, "boom");
2792 assert!(source.to_string().contains("tool exploded"));
2793 }
2794 other => panic!("expected tool error, got {other:?}"),
2795 }
2796 assert_eq!(counter(&counters, "requests_total"), 1);
2797 assert_eq!(counter(&counters, "provider_errors_total"), 0);
2798 assert_eq!(counter(&counters, "tool_errors_total"), 1);
2799 assert_eq!(histogram_count(&histograms, "tool_latency_ms"), 1);
2800 }
2801
2802 #[tokio::test]
2803 async fn respond_increments_requests_counter_on_success() {
2804 let memory = TestMemory::default();
2805 let prompts = Arc::new(Mutex::new(Vec::new()));
2806 let provider = TestProvider {
2807 received_prompts: prompts,
2808 response_text: "ok".to_string(),
2809 };
2810 let counters = Arc::new(Mutex::new(HashMap::new()));
2811 let histograms = Arc::new(Mutex::new(HashMap::new()));
2812 let agent = Agent::new(
2813 AgentConfig::default(),
2814 Box::new(provider),
2815 Box::new(memory),
2816 vec![],
2817 )
2818 .with_metrics(Box::new(RecordingMetricsSink {
2819 counters: counters.clone(),
2820 histograms: histograms.clone(),
2821 }));
2822
2823 let response = agent
2824 .respond(
2825 UserMessage {
2826 text: "hello".to_string(),
2827 },
2828 &test_ctx(),
2829 )
2830 .await
2831 .expect("response should succeed");
2832
2833 assert_eq!(response.text, "ok");
2834 assert_eq!(counter(&counters, "requests_total"), 1);
2835 assert_eq!(counter(&counters, "provider_errors_total"), 0);
2836 assert_eq!(counter(&counters, "tool_errors_total"), 0);
2837 assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
2838 }
2839
2840 #[tokio::test]
2841 async fn respond_returns_timeout_typed_error() {
2842 let memory = TestMemory::default();
2843 let agent = Agent::new(
2844 AgentConfig {
2845 max_tool_iterations: 1,
2846 request_timeout_ms: 10,
2847 ..AgentConfig::default()
2848 },
2849 Box::new(SlowProvider),
2850 Box::new(memory),
2851 vec![],
2852 );
2853
2854 let result = agent
2855 .respond(
2856 UserMessage {
2857 text: "hello".to_string(),
2858 },
2859 &test_ctx(),
2860 )
2861 .await;
2862
2863 match result {
2864 Err(AgentError::Timeout { timeout_ms }) => assert_eq!(timeout_ms, 10),
2865 other => panic!("expected timeout error, got {other:?}"),
2866 }
2867 }
2868
2869 #[tokio::test]
2870 async fn respond_emits_before_after_hook_events_when_enabled() {
2871 let events = Arc::new(Mutex::new(Vec::new()));
2872 let memory = TestMemory::default();
2873 let provider = TestProvider {
2874 received_prompts: Arc::new(Mutex::new(Vec::new())),
2875 response_text: "ok".to_string(),
2876 };
2877 let agent = Agent::new(
2878 AgentConfig {
2879 hooks: HookPolicy {
2880 enabled: true,
2881 timeout_ms: 50,
2882 fail_closed: true,
2883 ..HookPolicy::default()
2884 },
2885 ..AgentConfig::default()
2886 },
2887 Box::new(provider),
2888 Box::new(memory),
2889 vec![Box::new(EchoTool), Box::new(PluginEchoTool)],
2890 )
2891 .with_hooks(Box::new(RecordingHookSink {
2892 events: events.clone(),
2893 }));
2894
2895 agent
2896 .respond(
2897 UserMessage {
2898 text: "tool:plugin:echo ping".to_string(),
2899 },
2900 &test_ctx(),
2901 )
2902 .await
2903 .expect("respond should succeed");
2904
2905 let stages = events.lock().expect("hook lock poisoned");
2906 assert!(stages.contains(&"before_run".to_string()));
2907 assert!(stages.contains(&"after_run".to_string()));
2908 assert!(stages.contains(&"before_tool_call".to_string()));
2909 assert!(stages.contains(&"after_tool_call".to_string()));
2910 assert!(stages.contains(&"before_plugin_call".to_string()));
2911 assert!(stages.contains(&"after_plugin_call".to_string()));
2912 assert!(stages.contains(&"before_provider_call".to_string()));
2913 assert!(stages.contains(&"after_provider_call".to_string()));
2914 assert!(stages.contains(&"before_memory_write".to_string()));
2915 assert!(stages.contains(&"after_memory_write".to_string()));
2916 assert!(stages.contains(&"before_response_emit".to_string()));
2917 assert!(stages.contains(&"after_response_emit".to_string()));
2918 }
2919
2920 #[tokio::test]
2921 async fn respond_audit_events_include_request_id_and_durations() {
2922 let audit_events = Arc::new(Mutex::new(Vec::new()));
2923 let memory = TestMemory::default();
2924 let provider = TestProvider {
2925 received_prompts: Arc::new(Mutex::new(Vec::new())),
2926 response_text: "ok".to_string(),
2927 };
2928 let agent = Agent::new(
2929 AgentConfig::default(),
2930 Box::new(provider),
2931 Box::new(memory),
2932 vec![Box::new(EchoTool)],
2933 )
2934 .with_audit(Box::new(RecordingAuditSink {
2935 events: audit_events.clone(),
2936 }));
2937
2938 agent
2939 .respond(
2940 UserMessage {
2941 text: "tool:echo ping".to_string(),
2942 },
2943 &test_ctx(),
2944 )
2945 .await
2946 .expect("respond should succeed");
2947
2948 let events = audit_events.lock().expect("audit lock poisoned");
2949 let request_id = events
2950 .iter()
2951 .find_map(|e| {
2952 e.detail
2953 .get("request_id")
2954 .and_then(Value::as_str)
2955 .map(ToString::to_string)
2956 })
2957 .expect("request_id should exist on audit events");
2958
2959 let provider_event = events
2960 .iter()
2961 .find(|e| e.stage == "provider_call_success")
2962 .expect("provider success event should exist");
2963 assert_eq!(
2964 provider_event
2965 .detail
2966 .get("request_id")
2967 .and_then(Value::as_str),
2968 Some(request_id.as_str())
2969 );
2970 assert!(provider_event
2971 .detail
2972 .get("duration_ms")
2973 .and_then(Value::as_u64)
2974 .is_some());
2975
2976 let tool_event = events
2977 .iter()
2978 .find(|e| e.stage == "tool_execute_success")
2979 .expect("tool success event should exist");
2980 assert_eq!(
2981 tool_event.detail.get("request_id").and_then(Value::as_str),
2982 Some(request_id.as_str())
2983 );
2984 assert!(tool_event
2985 .detail
2986 .get("duration_ms")
2987 .and_then(Value::as_u64)
2988 .is_some());
2989 }
2990
2991 #[tokio::test]
2992 async fn hook_errors_respect_block_mode_for_high_tier_negative_path() {
2993 let memory = TestMemory::default();
2994 let provider = TestProvider {
2995 received_prompts: Arc::new(Mutex::new(Vec::new())),
2996 response_text: "ok".to_string(),
2997 };
2998 let agent = Agent::new(
2999 AgentConfig {
3000 hooks: HookPolicy {
3001 enabled: true,
3002 timeout_ms: 50,
3003 fail_closed: false,
3004 default_mode: HookFailureMode::Warn,
3005 low_tier_mode: HookFailureMode::Ignore,
3006 medium_tier_mode: HookFailureMode::Warn,
3007 high_tier_mode: HookFailureMode::Block,
3008 },
3009 ..AgentConfig::default()
3010 },
3011 Box::new(provider),
3012 Box::new(memory),
3013 vec![Box::new(EchoTool)],
3014 )
3015 .with_hooks(Box::new(FailingHookSink));
3016
3017 let result = agent
3018 .respond(
3019 UserMessage {
3020 text: "tool:echo ping".to_string(),
3021 },
3022 &test_ctx(),
3023 )
3024 .await;
3025
3026 match result {
3027 Err(AgentError::Hook { stage, .. }) => assert_eq!(stage, "before_tool_call"),
3028 other => panic!("expected hook block error, got {other:?}"),
3029 }
3030 }
3031
3032 #[tokio::test]
3033 async fn hook_errors_respect_warn_mode_for_high_tier_success_path() {
3034 let audit_events = Arc::new(Mutex::new(Vec::new()));
3035 let memory = TestMemory::default();
3036 let provider = TestProvider {
3037 received_prompts: Arc::new(Mutex::new(Vec::new())),
3038 response_text: "ok".to_string(),
3039 };
3040 let agent = Agent::new(
3041 AgentConfig {
3042 hooks: HookPolicy {
3043 enabled: true,
3044 timeout_ms: 50,
3045 fail_closed: false,
3046 default_mode: HookFailureMode::Warn,
3047 low_tier_mode: HookFailureMode::Ignore,
3048 medium_tier_mode: HookFailureMode::Warn,
3049 high_tier_mode: HookFailureMode::Warn,
3050 },
3051 ..AgentConfig::default()
3052 },
3053 Box::new(provider),
3054 Box::new(memory),
3055 vec![Box::new(EchoTool)],
3056 )
3057 .with_hooks(Box::new(FailingHookSink))
3058 .with_audit(Box::new(RecordingAuditSink {
3059 events: audit_events.clone(),
3060 }));
3061
3062 let response = agent
3063 .respond(
3064 UserMessage {
3065 text: "tool:echo ping".to_string(),
3066 },
3067 &test_ctx(),
3068 )
3069 .await
3070 .expect("warn mode should continue");
3071 assert!(!response.text.is_empty());
3072
3073 let events = audit_events.lock().expect("audit lock poisoned");
3074 assert!(events.iter().any(|event| event.stage == "hook_error_warn"));
3075 }
3076
3077 #[tokio::test]
3080 async fn self_correction_injected_on_no_progress() {
3081 let prompts = Arc::new(Mutex::new(Vec::new()));
3082 let provider = TestProvider {
3083 received_prompts: prompts.clone(),
3084 response_text: "I'll try a different approach".to_string(),
3085 };
3086 let agent = Agent::new(
3089 AgentConfig {
3090 loop_detection_no_progress_threshold: 3,
3091 max_tool_iterations: 20,
3092 ..AgentConfig::default()
3093 },
3094 Box::new(provider),
3095 Box::new(TestMemory::default()),
3096 vec![Box::new(LoopTool)],
3097 );
3098
3099 let response = agent
3100 .respond(
3101 UserMessage {
3102 text: "tool:loop_tool x".to_string(),
3103 },
3104 &test_ctx(),
3105 )
3106 .await
3107 .expect("self-correction should succeed");
3108
3109 assert_eq!(response.text, "I'll try a different approach");
3110
3111 let received = prompts.lock().expect("provider lock poisoned");
3113 let last_prompt = received.last().expect("provider should have been called");
3114 assert!(
3115 last_prompt.contains("Loop detected"),
3116 "provider prompt should contain self-correction notice, got: {last_prompt}"
3117 );
3118 assert!(last_prompt.contains("loop_tool"));
3119 }
3120
3121 #[tokio::test]
3122 async fn self_correction_injected_on_ping_pong() {
3123 let prompts = Arc::new(Mutex::new(Vec::new()));
3124 let provider = TestProvider {
3125 received_prompts: prompts.clone(),
3126 response_text: "I stopped the loop".to_string(),
3127 };
3128 let agent = Agent::new(
3130 AgentConfig {
3131 loop_detection_ping_pong_cycles: 2,
3132 loop_detection_no_progress_threshold: 0, max_tool_iterations: 20,
3134 ..AgentConfig::default()
3135 },
3136 Box::new(provider),
3137 Box::new(TestMemory::default()),
3138 vec![Box::new(PingTool), Box::new(PongTool)],
3139 );
3140
3141 let response = agent
3142 .respond(
3143 UserMessage {
3144 text: "tool:ping x".to_string(),
3145 },
3146 &test_ctx(),
3147 )
3148 .await
3149 .expect("self-correction should succeed");
3150
3151 assert_eq!(response.text, "I stopped the loop");
3152
3153 let received = prompts.lock().expect("provider lock poisoned");
3154 let last_prompt = received.last().expect("provider should have been called");
3155 assert!(
3156 last_prompt.contains("ping-pong"),
3157 "provider prompt should contain ping-pong notice, got: {last_prompt}"
3158 );
3159 }
3160
3161 #[tokio::test]
3162 async fn no_progress_disabled_when_threshold_zero() {
3163 let prompts = Arc::new(Mutex::new(Vec::new()));
3164 let provider = TestProvider {
3165 received_prompts: prompts.clone(),
3166 response_text: "final answer".to_string(),
3167 };
3168 let agent = Agent::new(
3171 AgentConfig {
3172 loop_detection_no_progress_threshold: 0,
3173 loop_detection_ping_pong_cycles: 0,
3174 max_tool_iterations: 5,
3175 ..AgentConfig::default()
3176 },
3177 Box::new(provider),
3178 Box::new(TestMemory::default()),
3179 vec![Box::new(LoopTool)],
3180 );
3181
3182 let response = agent
3183 .respond(
3184 UserMessage {
3185 text: "tool:loop_tool x".to_string(),
3186 },
3187 &test_ctx(),
3188 )
3189 .await
3190 .expect("should complete without error");
3191
3192 assert_eq!(response.text, "final answer");
3193
3194 let received = prompts.lock().expect("provider lock poisoned");
3196 let last_prompt = received.last().expect("provider should have been called");
3197 assert!(
3198 !last_prompt.contains("Loop detected"),
3199 "should not contain self-correction when detection is disabled"
3200 );
3201 }
3202
3203 #[tokio::test]
3206 async fn parallel_tools_executes_multiple_calls() {
3207 let prompts = Arc::new(Mutex::new(Vec::new()));
3208 let provider = TestProvider {
3209 received_prompts: prompts.clone(),
3210 response_text: "parallel done".to_string(),
3211 };
3212 let agent = Agent::new(
3213 AgentConfig {
3214 parallel_tools: true,
3215 max_tool_iterations: 5,
3216 ..AgentConfig::default()
3217 },
3218 Box::new(provider),
3219 Box::new(TestMemory::default()),
3220 vec![Box::new(EchoTool), Box::new(UpperTool)],
3221 );
3222
3223 let response = agent
3224 .respond(
3225 UserMessage {
3226 text: "tool:echo hello\ntool:upper world".to_string(),
3227 },
3228 &test_ctx(),
3229 )
3230 .await
3231 .expect("parallel tools should succeed");
3232
3233 assert_eq!(response.text, "parallel done");
3234
3235 let received = prompts.lock().expect("provider lock poisoned");
3236 let last_prompt = received.last().expect("provider should have been called");
3237 assert!(
3238 last_prompt.contains("echoed:hello"),
3239 "should contain echo output, got: {last_prompt}"
3240 );
3241 assert!(
3242 last_prompt.contains("WORLD"),
3243 "should contain upper output, got: {last_prompt}"
3244 );
3245 }
3246
3247 #[tokio::test]
3248 async fn parallel_tools_maintains_stable_ordering() {
3249 let prompts = Arc::new(Mutex::new(Vec::new()));
3250 let provider = TestProvider {
3251 received_prompts: prompts.clone(),
3252 response_text: "ordered".to_string(),
3253 };
3254 let agent = Agent::new(
3255 AgentConfig {
3256 parallel_tools: true,
3257 max_tool_iterations: 5,
3258 ..AgentConfig::default()
3259 },
3260 Box::new(provider),
3261 Box::new(TestMemory::default()),
3262 vec![Box::new(EchoTool), Box::new(UpperTool)],
3263 );
3264
3265 let response = agent
3267 .respond(
3268 UserMessage {
3269 text: "tool:echo aaa\ntool:upper bbb".to_string(),
3270 },
3271 &test_ctx(),
3272 )
3273 .await
3274 .expect("parallel tools should succeed");
3275
3276 assert_eq!(response.text, "ordered");
3277
3278 let received = prompts.lock().expect("provider lock poisoned");
3279 let last_prompt = received.last().expect("provider should have been called");
3280 let echo_pos = last_prompt
3281 .find("echoed:aaa")
3282 .expect("echo output should exist");
3283 let upper_pos = last_prompt.find("BBB").expect("upper output should exist");
3284 assert!(
3285 echo_pos < upper_pos,
3286 "echo output should come before upper output in the prompt"
3287 );
3288 }
3289
3290 #[tokio::test]
3291 async fn parallel_disabled_runs_first_call_only() {
3292 let prompts = Arc::new(Mutex::new(Vec::new()));
3293 let provider = TestProvider {
3294 received_prompts: prompts.clone(),
3295 response_text: "sequential".to_string(),
3296 };
3297 let agent = Agent::new(
3298 AgentConfig {
3299 parallel_tools: false,
3300 max_tool_iterations: 5,
3301 ..AgentConfig::default()
3302 },
3303 Box::new(provider),
3304 Box::new(TestMemory::default()),
3305 vec![Box::new(EchoTool), Box::new(UpperTool)],
3306 );
3307
3308 let response = agent
3309 .respond(
3310 UserMessage {
3311 text: "tool:echo hello\ntool:upper world".to_string(),
3312 },
3313 &test_ctx(),
3314 )
3315 .await
3316 .expect("sequential tools should succeed");
3317
3318 assert_eq!(response.text, "sequential");
3319
3320 let received = prompts.lock().expect("provider lock poisoned");
3322 let last_prompt = received.last().expect("provider should have been called");
3323 assert!(
3324 last_prompt.contains("echoed:hello"),
3325 "should contain first tool output, got: {last_prompt}"
3326 );
3327 assert!(
3328 !last_prompt.contains("WORLD"),
3329 "should NOT contain second tool output when parallel is disabled"
3330 );
3331 }
3332
3333 #[tokio::test]
3334 async fn single_call_parallel_matches_sequential() {
3335 let prompts = Arc::new(Mutex::new(Vec::new()));
3336 let provider = TestProvider {
3337 received_prompts: prompts.clone(),
3338 response_text: "single".to_string(),
3339 };
3340 let agent = Agent::new(
3341 AgentConfig {
3342 parallel_tools: true,
3343 max_tool_iterations: 5,
3344 ..AgentConfig::default()
3345 },
3346 Box::new(provider),
3347 Box::new(TestMemory::default()),
3348 vec![Box::new(EchoTool)],
3349 );
3350
3351 let response = agent
3352 .respond(
3353 UserMessage {
3354 text: "tool:echo ping".to_string(),
3355 },
3356 &test_ctx(),
3357 )
3358 .await
3359 .expect("single tool with parallel enabled should succeed");
3360
3361 assert_eq!(response.text, "single");
3362
3363 let received = prompts.lock().expect("provider lock poisoned");
3364 let last_prompt = received.last().expect("provider should have been called");
3365 assert!(
3366 last_prompt.contains("echoed:ping"),
3367 "single tool call should work same as sequential, got: {last_prompt}"
3368 );
3369 }
3370
3371 #[tokio::test]
3372 async fn parallel_falls_back_to_sequential_for_gated_tools() {
3373 let prompts = Arc::new(Mutex::new(Vec::new()));
3374 let provider = TestProvider {
3375 received_prompts: prompts.clone(),
3376 response_text: "gated sequential".to_string(),
3377 };
3378 let mut gated = std::collections::HashSet::new();
3379 gated.insert("upper".to_string());
3380 let agent = Agent::new(
3381 AgentConfig {
3382 parallel_tools: true,
3383 gated_tools: gated,
3384 max_tool_iterations: 5,
3385 ..AgentConfig::default()
3386 },
3387 Box::new(provider),
3388 Box::new(TestMemory::default()),
3389 vec![Box::new(EchoTool), Box::new(UpperTool)],
3390 );
3391
3392 let response = agent
3395 .respond(
3396 UserMessage {
3397 text: "tool:echo hello\ntool:upper world".to_string(),
3398 },
3399 &test_ctx(),
3400 )
3401 .await
3402 .expect("gated fallback should succeed");
3403
3404 assert_eq!(response.text, "gated sequential");
3405
3406 let received = prompts.lock().expect("provider lock poisoned");
3407 let last_prompt = received.last().expect("provider should have been called");
3408 assert!(
3409 last_prompt.contains("echoed:hello"),
3410 "first tool should execute, got: {last_prompt}"
3411 );
3412 assert!(
3413 !last_prompt.contains("WORLD"),
3414 "gated tool should NOT execute in parallel, got: {last_prompt}"
3415 );
3416 }
3417
3418 fn research_config(trigger: ResearchTrigger) -> ResearchPolicy {
3421 ResearchPolicy {
3422 enabled: true,
3423 trigger,
3424 keywords: vec!["search".to_string(), "find".to_string()],
3425 min_message_length: 10,
3426 max_iterations: 5,
3427 show_progress: true,
3428 }
3429 }
3430
3431 fn config_with_research(trigger: ResearchTrigger) -> AgentConfig {
3432 AgentConfig {
3433 research: research_config(trigger),
3434 ..Default::default()
3435 }
3436 }
3437
3438 #[tokio::test]
3439 async fn research_trigger_never() {
3440 let prompts = Arc::new(Mutex::new(Vec::new()));
3441 let provider = TestProvider {
3442 received_prompts: prompts.clone(),
3443 response_text: "final answer".to_string(),
3444 };
3445 let agent = Agent::new(
3446 config_with_research(ResearchTrigger::Never),
3447 Box::new(provider),
3448 Box::new(TestMemory::default()),
3449 vec![Box::new(EchoTool)],
3450 );
3451
3452 let response = agent
3453 .respond(
3454 UserMessage {
3455 text: "search for something".to_string(),
3456 },
3457 &test_ctx(),
3458 )
3459 .await
3460 .expect("should succeed");
3461
3462 assert_eq!(response.text, "final answer");
3463 let received = prompts.lock().expect("lock");
3464 assert_eq!(
3465 received.len(),
3466 1,
3467 "should have exactly 1 provider call (no research)"
3468 );
3469 }
3470
3471 #[tokio::test]
3472 async fn research_trigger_always() {
3473 let provider = ScriptedProvider::new(vec![
3474 "tool:echo gathering data",
3475 "Found relevant information",
3476 "Final answer with research context",
3477 ]);
3478 let prompts = provider.received_prompts.clone();
3479 let agent = Agent::new(
3480 config_with_research(ResearchTrigger::Always),
3481 Box::new(provider),
3482 Box::new(TestMemory::default()),
3483 vec![Box::new(EchoTool)],
3484 );
3485
3486 let response = agent
3487 .respond(
3488 UserMessage {
3489 text: "hello world".to_string(),
3490 },
3491 &test_ctx(),
3492 )
3493 .await
3494 .expect("should succeed");
3495
3496 assert_eq!(response.text, "Final answer with research context");
3497 let received = prompts.lock().expect("lock");
3498 let last = received.last().expect("should have provider calls");
3499 assert!(
3500 last.contains("Research findings:"),
3501 "final prompt should contain research findings, got: {last}"
3502 );
3503 }
3504
3505 #[tokio::test]
3506 async fn research_trigger_keywords_match() {
3507 let provider = ScriptedProvider::new(vec!["Summary of search", "Answer based on research"]);
3508 let prompts = provider.received_prompts.clone();
3509 let agent = Agent::new(
3510 config_with_research(ResearchTrigger::Keywords),
3511 Box::new(provider),
3512 Box::new(TestMemory::default()),
3513 vec![],
3514 );
3515
3516 let response = agent
3517 .respond(
3518 UserMessage {
3519 text: "please search for the config".to_string(),
3520 },
3521 &test_ctx(),
3522 )
3523 .await
3524 .expect("should succeed");
3525
3526 assert_eq!(response.text, "Answer based on research");
3527 let received = prompts.lock().expect("lock");
3528 assert!(received.len() >= 2, "should have at least 2 provider calls");
3529 assert!(
3530 received[0].contains("RESEARCH mode"),
3531 "first call should be research prompt"
3532 );
3533 }
3534
3535 #[tokio::test]
3536 async fn research_trigger_keywords_no_match() {
3537 let prompts = Arc::new(Mutex::new(Vec::new()));
3538 let provider = TestProvider {
3539 received_prompts: prompts.clone(),
3540 response_text: "direct answer".to_string(),
3541 };
3542 let agent = Agent::new(
3543 config_with_research(ResearchTrigger::Keywords),
3544 Box::new(provider),
3545 Box::new(TestMemory::default()),
3546 vec![],
3547 );
3548
3549 let response = agent
3550 .respond(
3551 UserMessage {
3552 text: "hello world".to_string(),
3553 },
3554 &test_ctx(),
3555 )
3556 .await
3557 .expect("should succeed");
3558
3559 assert_eq!(response.text, "direct answer");
3560 let received = prompts.lock().expect("lock");
3561 assert_eq!(received.len(), 1, "no research phase should fire");
3562 }
3563
3564 #[tokio::test]
3565 async fn research_trigger_length_short_skips() {
3566 let prompts = Arc::new(Mutex::new(Vec::new()));
3567 let provider = TestProvider {
3568 received_prompts: prompts.clone(),
3569 response_text: "short".to_string(),
3570 };
3571 let config = AgentConfig {
3572 research: ResearchPolicy {
3573 min_message_length: 20,
3574 ..research_config(ResearchTrigger::Length)
3575 },
3576 ..Default::default()
3577 };
3578 let agent = Agent::new(
3579 config,
3580 Box::new(provider),
3581 Box::new(TestMemory::default()),
3582 vec![],
3583 );
3584 agent
3585 .respond(
3586 UserMessage {
3587 text: "hi".to_string(),
3588 },
3589 &test_ctx(),
3590 )
3591 .await
3592 .expect("should succeed");
3593 let received = prompts.lock().expect("lock");
3594 assert_eq!(received.len(), 1, "short message should skip research");
3595 }
3596
3597 #[tokio::test]
3598 async fn research_trigger_length_long_triggers() {
3599 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
3600 let prompts = provider.received_prompts.clone();
3601 let config = AgentConfig {
3602 research: ResearchPolicy {
3603 min_message_length: 20,
3604 ..research_config(ResearchTrigger::Length)
3605 },
3606 ..Default::default()
3607 };
3608 let agent = Agent::new(
3609 config,
3610 Box::new(provider),
3611 Box::new(TestMemory::default()),
3612 vec![],
3613 );
3614 agent
3615 .respond(
3616 UserMessage {
3617 text: "this is a longer message that exceeds the threshold".to_string(),
3618 },
3619 &test_ctx(),
3620 )
3621 .await
3622 .expect("should succeed");
3623 let received = prompts.lock().expect("lock");
3624 assert!(received.len() >= 2, "long message should trigger research");
3625 }
3626
3627 #[tokio::test]
3628 async fn research_trigger_question_with_mark() {
3629 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
3630 let prompts = provider.received_prompts.clone();
3631 let agent = Agent::new(
3632 config_with_research(ResearchTrigger::Question),
3633 Box::new(provider),
3634 Box::new(TestMemory::default()),
3635 vec![],
3636 );
3637 agent
3638 .respond(
3639 UserMessage {
3640 text: "what is the meaning of life?".to_string(),
3641 },
3642 &test_ctx(),
3643 )
3644 .await
3645 .expect("should succeed");
3646 let received = prompts.lock().expect("lock");
3647 assert!(received.len() >= 2, "question should trigger research");
3648 }
3649
3650 #[tokio::test]
3651 async fn research_trigger_question_without_mark() {
3652 let prompts = Arc::new(Mutex::new(Vec::new()));
3653 let provider = TestProvider {
3654 received_prompts: prompts.clone(),
3655 response_text: "no research".to_string(),
3656 };
3657 let agent = Agent::new(
3658 config_with_research(ResearchTrigger::Question),
3659 Box::new(provider),
3660 Box::new(TestMemory::default()),
3661 vec![],
3662 );
3663 agent
3664 .respond(
3665 UserMessage {
3666 text: "do this thing".to_string(),
3667 },
3668 &test_ctx(),
3669 )
3670 .await
3671 .expect("should succeed");
3672 let received = prompts.lock().expect("lock");
3673 assert_eq!(received.len(), 1, "non-question should skip research");
3674 }
3675
    #[tokio::test]
    async fn research_respects_max_iterations() {
        // Script more tool-calling steps (7) than the policy allows, so the
        // agent must cut the research loop off at `max_iterations: 3`.
        let provider = ScriptedProvider::new(vec![
            "tool:echo step1",
            "tool:echo step2",
            "tool:echo step3",
            "tool:echo step4",
            "tool:echo step5",
            "tool:echo step6",
            "tool:echo step7",
            "answer after research",
        ]);
        let prompts = provider.received_prompts.clone();
        let config = AgentConfig {
            research: ResearchPolicy {
                max_iterations: 3,
                ..research_config(ResearchTrigger::Always)
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "test".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert!(!response.text.is_empty(), "should get a response");
        // Count only research-phase prompts by their marker substrings.
        // NOTE(review): assumes the agent's research prompt builder emits
        // "RESEARCH mode" / "Research iteration" — confirm against that code
        // if this assertion starts failing.
        let received = prompts.lock().expect("lock");
        let research_calls = received
            .iter()
            .filter(|p| p.contains("RESEARCH mode") || p.contains("Research iteration"))
            .count();
        assert_eq!(
            research_calls, 4,
            "should have 4 research provider calls (1 initial + 3 iterations)"
        );
    }
3725
3726 #[tokio::test]
3727 async fn research_disabled_skips_phase() {
3728 let prompts = Arc::new(Mutex::new(Vec::new()));
3729 let provider = TestProvider {
3730 received_prompts: prompts.clone(),
3731 response_text: "direct answer".to_string(),
3732 };
3733 let config = AgentConfig {
3734 research: ResearchPolicy {
3735 enabled: false,
3736 trigger: ResearchTrigger::Always,
3737 ..Default::default()
3738 },
3739 ..Default::default()
3740 };
3741 let agent = Agent::new(
3742 config,
3743 Box::new(provider),
3744 Box::new(TestMemory::default()),
3745 vec![Box::new(EchoTool)],
3746 );
3747
3748 let response = agent
3749 .respond(
3750 UserMessage {
3751 text: "search for something".to_string(),
3752 },
3753 &test_ctx(),
3754 )
3755 .await
3756 .expect("should succeed");
3757
3758 assert_eq!(response.text, "direct answer");
3759 let received = prompts.lock().expect("lock");
3760 assert_eq!(
3761 received.len(),
3762 1,
3763 "disabled research should not fire even with Always trigger"
3764 );
3765 }
3766
    // Provider test double that records every ReasoningConfig it is invoked
    // with, so tests can assert the agent forwards its reasoning settings.
    struct ReasoningCapturingProvider {
        // One entry per call; plain `complete` records a default sentinel value.
        captured_reasoning: Arc<Mutex<Vec<ReasoningConfig>>>,
        // Fixed text returned from every completion.
        response_text: String,
    }
3773
    #[async_trait]
    impl Provider for ReasoningCapturingProvider {
        // Plain completion: no reasoning config is available, so push a default
        // value to keep the capture vector's length equal to the call count.
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            self.captured_reasoning
                .lock()
                .expect("lock")
                .push(ReasoningConfig::default());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }

        // Reasoning-aware completion: capture the exact config the agent passed.
        async fn complete_with_reasoning(
            &self,
            _prompt: &str,
            reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            self.captured_reasoning
                .lock()
                .expect("lock")
                .push(reasoning.clone());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }
    }
3802
3803 #[tokio::test]
3804 async fn reasoning_config_passed_to_provider() {
3805 let captured = Arc::new(Mutex::new(Vec::new()));
3806 let provider = ReasoningCapturingProvider {
3807 captured_reasoning: captured.clone(),
3808 response_text: "ok".to_string(),
3809 };
3810 let config = AgentConfig {
3811 reasoning: ReasoningConfig {
3812 enabled: Some(true),
3813 level: Some("high".to_string()),
3814 },
3815 ..Default::default()
3816 };
3817 let agent = Agent::new(
3818 config,
3819 Box::new(provider),
3820 Box::new(TestMemory::default()),
3821 vec![],
3822 );
3823
3824 agent
3825 .respond(
3826 UserMessage {
3827 text: "test".to_string(),
3828 },
3829 &test_ctx(),
3830 )
3831 .await
3832 .expect("should succeed");
3833
3834 let configs = captured.lock().expect("lock");
3835 assert_eq!(configs.len(), 1, "provider should be called once");
3836 assert_eq!(configs[0].enabled, Some(true));
3837 assert_eq!(configs[0].level.as_deref(), Some("high"));
3838 }
3839
3840 #[tokio::test]
3841 async fn reasoning_disabled_passes_config_through() {
3842 let captured = Arc::new(Mutex::new(Vec::new()));
3843 let provider = ReasoningCapturingProvider {
3844 captured_reasoning: captured.clone(),
3845 response_text: "ok".to_string(),
3846 };
3847 let config = AgentConfig {
3848 reasoning: ReasoningConfig {
3849 enabled: Some(false),
3850 level: None,
3851 },
3852 ..Default::default()
3853 };
3854 let agent = Agent::new(
3855 config,
3856 Box::new(provider),
3857 Box::new(TestMemory::default()),
3858 vec![],
3859 );
3860
3861 agent
3862 .respond(
3863 UserMessage {
3864 text: "test".to_string(),
3865 },
3866 &test_ctx(),
3867 )
3868 .await
3869 .expect("should succeed");
3870
3871 let configs = captured.lock().expect("lock");
3872 assert_eq!(configs.len(), 1);
3873 assert_eq!(
3874 configs[0].enabled,
3875 Some(false),
3876 "disabled reasoning should still be passed to provider"
3877 );
3878 assert_eq!(configs[0].level, None);
3879 }
3880
    // Provider double for the structured tool-use path: replays a scripted
    // sequence of ChatResults and records each transcript it is called with.
    struct StructuredProvider {
        // Scripted results, consumed in call order; a default is used past the end.
        responses: Vec<ChatResult>,
        // Index of the next response; shared across Provider methods.
        call_count: AtomicUsize,
        // One message transcript per `complete_with_tools` invocation.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
3888
3889 impl StructuredProvider {
3890 fn new(responses: Vec<ChatResult>) -> Self {
3891 Self {
3892 responses,
3893 call_count: AtomicUsize::new(0),
3894 received_messages: Arc::new(Mutex::new(Vec::new())),
3895 }
3896 }
3897 }
3898
    #[async_trait]
    impl Provider for StructuredProvider {
        // Text-only path: replay the scripted response for this call index;
        // past the end of the script a default ChatResult is returned.
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            // Relaxed suffices: the counter only picks the next scripted response.
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        // Tool-use path: additionally record the full message transcript the
        // agent sent, so tests can inspect conversation structure per call.
        async fn complete_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }
    }
3920
    // Tool double: echoes its input back, prefixed with "echoed:".
    struct StructuredEchoTool;
3922
3923 #[async_trait]
3924 impl Tool for StructuredEchoTool {
3925 fn name(&self) -> &'static str {
3926 "echo"
3927 }
3928 fn description(&self) -> &'static str {
3929 "Echoes input back"
3930 }
3931 fn input_schema(&self) -> Option<serde_json::Value> {
3932 Some(json!({
3933 "type": "object",
3934 "properties": {
3935 "text": { "type": "string", "description": "Text to echo" }
3936 },
3937 "required": ["text"]
3938 }))
3939 }
3940 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3941 Ok(ToolResult {
3942 output: format!("echoed:{input}"),
3943 })
3944 }
3945 }
3946
    // Tool double whose execute always returns an error ("tool exploded").
    struct StructuredFailingTool;
3948
3949 #[async_trait]
3950 impl Tool for StructuredFailingTool {
3951 fn name(&self) -> &'static str {
3952 "boom"
3953 }
3954 fn description(&self) -> &'static str {
3955 "Always fails"
3956 }
3957 fn input_schema(&self) -> Option<serde_json::Value> {
3958 Some(json!({
3959 "type": "object",
3960 "properties": {},
3961 }))
3962 }
3963 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3964 Err(anyhow::anyhow!("tool exploded"))
3965 }
3966 }
3967
    // Tool double: returns its input uppercased.
    struct StructuredUpperTool;
3969
3970 #[async_trait]
3971 impl Tool for StructuredUpperTool {
3972 fn name(&self) -> &'static str {
3973 "upper"
3974 }
3975 fn description(&self) -> &'static str {
3976 "Uppercases input"
3977 }
3978 fn input_schema(&self) -> Option<serde_json::Value> {
3979 Some(json!({
3980 "type": "object",
3981 "properties": {
3982 "text": { "type": "string", "description": "Text to uppercase" }
3983 },
3984 "required": ["text"]
3985 }))
3986 }
3987 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3988 Ok(ToolResult {
3989 output: input.to_uppercase(),
3990 })
3991 }
3992 }
3993
3994 use crate::types::ToolUseRequest;
3995
3996 #[tokio::test]
3999 async fn structured_basic_tool_call_then_end_turn() {
4000 let provider = StructuredProvider::new(vec![
4001 ChatResult {
4002 output_text: "Let me echo that.".to_string(),
4003 tool_calls: vec![ToolUseRequest {
4004 id: "call_1".to_string(),
4005 name: "echo".to_string(),
4006 input: json!({"text": "hello"}),
4007 }],
4008 stop_reason: Some(StopReason::ToolUse),
4009 ..Default::default()
4010 },
4011 ChatResult {
4012 output_text: "The echo returned: hello".to_string(),
4013 tool_calls: vec![],
4014 stop_reason: Some(StopReason::EndTurn),
4015 ..Default::default()
4016 },
4017 ]);
4018 let received = provider.received_messages.clone();
4019
4020 let agent = Agent::new(
4021 AgentConfig::default(),
4022 Box::new(provider),
4023 Box::new(TestMemory::default()),
4024 vec![Box::new(StructuredEchoTool)],
4025 );
4026
4027 let response = agent
4028 .respond(
4029 UserMessage {
4030 text: "echo hello".to_string(),
4031 },
4032 &test_ctx(),
4033 )
4034 .await
4035 .expect("should succeed");
4036
4037 assert_eq!(response.text, "The echo returned: hello");
4038
4039 let msgs = received.lock().expect("lock");
4041 assert_eq!(msgs.len(), 2, "provider should be called twice");
4042 assert!(
4044 msgs[1].len() >= 3,
4045 "second call should have user + assistant + tool result"
4046 );
4047 }
4048
4049 #[tokio::test]
4050 async fn structured_no_tool_calls_returns_immediately() {
4051 let provider = StructuredProvider::new(vec![ChatResult {
4052 output_text: "Hello! No tools needed.".to_string(),
4053 tool_calls: vec![],
4054 stop_reason: Some(StopReason::EndTurn),
4055 ..Default::default()
4056 }]);
4057
4058 let agent = Agent::new(
4059 AgentConfig::default(),
4060 Box::new(provider),
4061 Box::new(TestMemory::default()),
4062 vec![Box::new(StructuredEchoTool)],
4063 );
4064
4065 let response = agent
4066 .respond(
4067 UserMessage {
4068 text: "hi".to_string(),
4069 },
4070 &test_ctx(),
4071 )
4072 .await
4073 .expect("should succeed");
4074
4075 assert_eq!(response.text, "Hello! No tools needed.");
4076 }
4077
    #[tokio::test]
    async fn structured_tool_not_found_sends_error_result() {
        // The provider asks for an unregistered tool; the agent should feed an
        // error ToolResult back to the provider instead of aborting the turn.
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "nonexistent".to_string(),
                    input: json!({}),
                }],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: "I see the tool was not found.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
                ..Default::default()
            },
        ]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "use nonexistent".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed, not abort");

        assert_eq!(response.text, "I see the tool was not found.");

        // The second provider call must carry an error-flagged ToolResult whose
        // content mentions that the tool was not found.
        let msgs = received.lock().expect("lock");
        let last_call = &msgs[1];
        let has_error_result = last_call.iter().any(|m| {
            matches!(m, ConversationMessage::ToolResult(r) if r.is_error && r.content.contains("not found"))
        });
        assert!(has_error_result, "should include error ToolResult");
    }
4127
4128 #[tokio::test]
4129 async fn structured_tool_error_does_not_abort() {
4130 let provider = StructuredProvider::new(vec![
4131 ChatResult {
4132 output_text: String::new(),
4133 tool_calls: vec![ToolUseRequest {
4134 id: "call_1".to_string(),
4135 name: "boom".to_string(),
4136 input: json!({}),
4137 }],
4138 stop_reason: Some(StopReason::ToolUse),
4139 ..Default::default()
4140 },
4141 ChatResult {
4142 output_text: "I handled the error gracefully.".to_string(),
4143 tool_calls: vec![],
4144 stop_reason: Some(StopReason::EndTurn),
4145 ..Default::default()
4146 },
4147 ]);
4148
4149 let agent = Agent::new(
4150 AgentConfig::default(),
4151 Box::new(provider),
4152 Box::new(TestMemory::default()),
4153 vec![Box::new(StructuredFailingTool)],
4154 );
4155
4156 let response = agent
4157 .respond(
4158 UserMessage {
4159 text: "boom".to_string(),
4160 },
4161 &test_ctx(),
4162 )
4163 .await
4164 .expect("should succeed, error becomes ToolResultMessage");
4165
4166 assert_eq!(response.text, "I handled the error gracefully.");
4167 }
4168
    #[tokio::test]
    async fn structured_multi_iteration_tool_calls() {
        // Two consecutive tool-use turns before the final end-turn answer:
        // the agent should loop, executing one tool call per iteration.
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "first"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_2".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "second"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: "Done with two tool calls.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
                ..Default::default()
            },
        ]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "echo twice".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Done with two tool calls.");
        // One recorded transcript per provider call: initial + after each tool result.
        let msgs = received.lock().expect("lock");
        assert_eq!(msgs.len(), 3, "three provider calls");
    }
4225
4226 #[tokio::test]
4227 async fn structured_max_iterations_forces_final_answer() {
4228 let responses = vec![
4232 ChatResult {
4233 output_text: String::new(),
4234 tool_calls: vec![ToolUseRequest {
4235 id: "call_0".to_string(),
4236 name: "echo".to_string(),
4237 input: json!({"text": "loop"}),
4238 }],
4239 stop_reason: Some(StopReason::ToolUse),
4240 ..Default::default()
4241 },
4242 ChatResult {
4243 output_text: String::new(),
4244 tool_calls: vec![ToolUseRequest {
4245 id: "call_1".to_string(),
4246 name: "echo".to_string(),
4247 input: json!({"text": "loop"}),
4248 }],
4249 stop_reason: Some(StopReason::ToolUse),
4250 ..Default::default()
4251 },
4252 ChatResult {
4253 output_text: String::new(),
4254 tool_calls: vec![ToolUseRequest {
4255 id: "call_2".to_string(),
4256 name: "echo".to_string(),
4257 input: json!({"text": "loop"}),
4258 }],
4259 stop_reason: Some(StopReason::ToolUse),
4260 ..Default::default()
4261 },
4262 ChatResult {
4264 output_text: "Forced final answer.".to_string(),
4265 tool_calls: vec![],
4266 stop_reason: Some(StopReason::EndTurn),
4267 ..Default::default()
4268 },
4269 ];
4270
4271 let agent = Agent::new(
4272 AgentConfig {
4273 max_tool_iterations: 3,
4274 loop_detection_no_progress_threshold: 0, ..AgentConfig::default()
4276 },
4277 Box::new(StructuredProvider::new(responses)),
4278 Box::new(TestMemory::default()),
4279 vec![Box::new(StructuredEchoTool)],
4280 );
4281
4282 let response = agent
4283 .respond(
4284 UserMessage {
4285 text: "loop forever".to_string(),
4286 },
4287 &test_ctx(),
4288 )
4289 .await
4290 .expect("should succeed with forced answer");
4291
4292 assert_eq!(response.text, "Forced final answer.");
4293 }
4294
4295 #[tokio::test]
4296 async fn structured_fallback_to_text_when_disabled() {
4297 let prompts = Arc::new(Mutex::new(Vec::new()));
4298 let provider = TestProvider {
4299 received_prompts: prompts.clone(),
4300 response_text: "text path used".to_string(),
4301 };
4302
4303 let agent = Agent::new(
4304 AgentConfig {
4305 model_supports_tool_use: false,
4306 ..AgentConfig::default()
4307 },
4308 Box::new(provider),
4309 Box::new(TestMemory::default()),
4310 vec![Box::new(StructuredEchoTool)],
4311 );
4312
4313 let response = agent
4314 .respond(
4315 UserMessage {
4316 text: "hello".to_string(),
4317 },
4318 &test_ctx(),
4319 )
4320 .await
4321 .expect("should succeed");
4322
4323 assert_eq!(response.text, "text path used");
4324 }
4325
4326 #[tokio::test]
4327 async fn structured_fallback_when_no_tool_schemas() {
4328 let prompts = Arc::new(Mutex::new(Vec::new()));
4329 let provider = TestProvider {
4330 received_prompts: prompts.clone(),
4331 response_text: "text path used".to_string(),
4332 };
4333
4334 let agent = Agent::new(
4336 AgentConfig::default(),
4337 Box::new(provider),
4338 Box::new(TestMemory::default()),
4339 vec![Box::new(EchoTool)],
4340 );
4341
4342 let response = agent
4343 .respond(
4344 UserMessage {
4345 text: "hello".to_string(),
4346 },
4347 &test_ctx(),
4348 )
4349 .await
4350 .expect("should succeed");
4351
4352 assert_eq!(response.text, "text path used");
4353 }
4354
    #[tokio::test]
    async fn structured_parallel_tool_calls() {
        // One turn requests two tool calls at once; with `parallel_tools: true`
        // both must execute and both results must appear in the next transcript.
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![
                    ToolUseRequest {
                        id: "call_1".to_string(),
                        name: "echo".to_string(),
                        input: json!({"text": "a"}),
                    },
                    ToolUseRequest {
                        id: "call_2".to_string(),
                        name: "upper".to_string(),
                        input: json!({"text": "b"}),
                    },
                ],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: "Both tools ran.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
                ..Default::default()
            },
        ]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig {
                parallel_tools: true,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "do both".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Both tools ran.");

        // The second provider call should carry one ToolResult per request.
        let msgs = received.lock().expect("lock");
        let tool_results: Vec<_> = msgs[1]
            .iter()
            .filter(|m| matches!(m, ConversationMessage::ToolResult(_)))
            .collect();
        assert_eq!(tool_results.len(), 2, "should have two tool results");
    }
4414
    #[tokio::test]
    async fn structured_memory_integration() {
        // Pre-seed memory with an older exchange; the first provider call should
        // include it ahead of the new user message, in chronological order.
        let memory = TestMemory::default();
        memory
            .append(MemoryEntry {
                role: "user".to_string(),
                content: "old question".to_string(),
                ..Default::default()
            })
            .await
            .unwrap();
        memory
            .append(MemoryEntry {
                role: "assistant".to_string(),
                content: "old answer".to_string(),
                ..Default::default()
            })
            .await
            .unwrap();

        let provider = StructuredProvider::new(vec![ChatResult {
            output_text: "I see your history.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        }]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(StructuredEchoTool)],
        );

        agent
            .respond(
                UserMessage {
                    text: "new question".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        let msgs = received.lock().expect("lock");
        assert!(msgs[0].len() >= 3, "should include memory messages");
        // The oldest memory entry must come first in the transcript.
        assert!(matches!(
            &msgs[0][0],
            ConversationMessage::User { content, .. } if content == "old question"
        ));
    }
4471
4472 #[tokio::test]
4473 async fn prepare_tool_input_extracts_single_string_field() {
4474 let tool = StructuredEchoTool;
4475 let input = json!({"text": "hello world"});
4476 let result = prepare_tool_input(&tool, &input).expect("valid input");
4477 assert_eq!(result, "hello world");
4478 }
4479
4480 #[tokio::test]
4481 async fn prepare_tool_input_serializes_multi_field_json() {
4482 struct MultiFieldTool;
4484 #[async_trait]
4485 impl Tool for MultiFieldTool {
4486 fn name(&self) -> &'static str {
4487 "multi"
4488 }
4489 fn input_schema(&self) -> Option<serde_json::Value> {
4490 Some(json!({
4491 "type": "object",
4492 "properties": {
4493 "path": { "type": "string" },
4494 "content": { "type": "string" }
4495 },
4496 "required": ["path", "content"]
4497 }))
4498 }
4499 async fn execute(
4500 &self,
4501 _input: &str,
4502 _ctx: &ToolContext,
4503 ) -> anyhow::Result<ToolResult> {
4504 unreachable!()
4505 }
4506 }
4507
4508 let tool = MultiFieldTool;
4509 let input = json!({"path": "a.txt", "content": "hello"});
4510 let result = prepare_tool_input(&tool, &input).expect("valid input");
4511 let parsed: serde_json::Value = serde_json::from_str(&result).expect("valid JSON");
4513 assert_eq!(parsed["path"], "a.txt");
4514 assert_eq!(parsed["content"], "hello");
4515 }
4516
4517 #[tokio::test]
4518 async fn prepare_tool_input_unwraps_bare_string() {
4519 let tool = StructuredEchoTool;
4520 let input = json!("plain text");
4521 let result = prepare_tool_input(&tool, &input).expect("valid input");
4522 assert_eq!(result, "plain text");
4523 }
4524
4525 #[tokio::test]
4526 async fn prepare_tool_input_rejects_schema_violating_input() {
4527 struct StrictTool;
4528 #[async_trait]
4529 impl Tool for StrictTool {
4530 fn name(&self) -> &'static str {
4531 "strict"
4532 }
4533 fn input_schema(&self) -> Option<serde_json::Value> {
4534 Some(json!({
4535 "type": "object",
4536 "required": ["path"],
4537 "properties": {
4538 "path": { "type": "string" }
4539 }
4540 }))
4541 }
4542 async fn execute(
4543 &self,
4544 _input: &str,
4545 _ctx: &ToolContext,
4546 ) -> anyhow::Result<ToolResult> {
4547 unreachable!("should not be called with invalid input")
4548 }
4549 }
4550
4551 let result = prepare_tool_input(&StrictTool, &json!({}));
4553 assert!(result.is_err());
4554 let err = result.unwrap_err();
4555 assert!(err.contains("Invalid input for tool 'strict'"));
4556 assert!(err.contains("missing required field"));
4557 }
4558
4559 #[tokio::test]
4560 async fn prepare_tool_input_rejects_wrong_type() {
4561 struct TypedTool;
4562 #[async_trait]
4563 impl Tool for TypedTool {
4564 fn name(&self) -> &'static str {
4565 "typed"
4566 }
4567 fn input_schema(&self) -> Option<serde_json::Value> {
4568 Some(json!({
4569 "type": "object",
4570 "required": ["count"],
4571 "properties": {
4572 "count": { "type": "integer" }
4573 }
4574 }))
4575 }
4576 async fn execute(
4577 &self,
4578 _input: &str,
4579 _ctx: &ToolContext,
4580 ) -> anyhow::Result<ToolResult> {
4581 unreachable!("should not be called with invalid input")
4582 }
4583 }
4584
4585 let result = prepare_tool_input(&TypedTool, &json!({"count": "not a number"}));
4587 assert!(result.is_err());
4588 let err = result.unwrap_err();
4589 assert!(err.contains("expected type \"integer\""));
4590 }
4591
4592 #[test]
4593 fn truncate_messages_preserves_small_conversation() {
4594 let mut msgs = vec![
4595 ConversationMessage::user("hello".to_string()),
4596 ConversationMessage::Assistant {
4597 content: Some("world".to_string()),
4598 tool_calls: vec![],
4599 },
4600 ];
4601 truncate_messages(&mut msgs, 1000);
4602 assert_eq!(msgs.len(), 2, "should not truncate small conversation");
4603 }
4604
4605 #[test]
4606 fn truncate_messages_drops_middle() {
4607 let mut msgs = vec![
4608 ConversationMessage::user("A".repeat(100)),
4609 ConversationMessage::Assistant {
4610 content: Some("B".repeat(100)),
4611 tool_calls: vec![],
4612 },
4613 ConversationMessage::user("C".repeat(100)),
4614 ConversationMessage::Assistant {
4615 content: Some("D".repeat(100)),
4616 tool_calls: vec![],
4617 },
4618 ];
4619 truncate_messages(&mut msgs, 250);
4622 assert!(msgs.len() < 4, "should have dropped some messages");
4623 assert!(matches!(
4625 &msgs[0],
4626 ConversationMessage::User { content, .. } if content.starts_with("AAA")
4627 ));
4628 }
4629
4630 #[test]
4631 fn memory_to_messages_reverses_order() {
4632 let entries = vec![
4633 MemoryEntry {
4634 role: "assistant".to_string(),
4635 content: "newest".to_string(),
4636 ..Default::default()
4637 },
4638 MemoryEntry {
4639 role: "user".to_string(),
4640 content: "oldest".to_string(),
4641 ..Default::default()
4642 },
4643 ];
4644 let msgs = memory_to_messages(&entries);
4645 assert_eq!(msgs.len(), 2);
4646 assert!(matches!(
4647 &msgs[0],
4648 ConversationMessage::User { content, .. } if content == "oldest"
4649 ));
4650 assert!(matches!(
4651 &msgs[1],
4652 ConversationMessage::Assistant { content: Some(c), .. } if c == "newest"
4653 ));
4654 }
4655
4656 #[test]
4657 fn single_required_string_field_detects_single() {
4658 let schema = json!({
4659 "type": "object",
4660 "properties": {
4661 "path": { "type": "string" }
4662 },
4663 "required": ["path"]
4664 });
4665 assert_eq!(
4666 single_required_string_field(&schema),
4667 Some("path".to_string())
4668 );
4669 }
4670
4671 #[test]
4672 fn single_required_string_field_none_for_multiple() {
4673 let schema = json!({
4674 "type": "object",
4675 "properties": {
4676 "path": { "type": "string" },
4677 "mode": { "type": "string" }
4678 },
4679 "required": ["path", "mode"]
4680 });
4681 assert_eq!(single_required_string_field(&schema), None);
4682 }
4683
4684 #[test]
4685 fn single_required_string_field_none_for_non_string() {
4686 let schema = json!({
4687 "type": "object",
4688 "properties": {
4689 "count": { "type": "integer" }
4690 },
4691 "required": ["count"]
4692 });
4693 assert_eq!(single_required_string_field(&schema), None);
4694 }
4695
4696 use crate::types::StreamChunk;
4699
    // Provider double for the streaming path: replays scripted ChatResults and
    // additionally emits their text as StreamChunks through the given sender.
    struct StreamingProvider {
        // Scripted results, consumed in call order; default past the end.
        responses: Vec<ChatResult>,
        // Index of the next response; shared across all Provider methods.
        call_count: AtomicUsize,
        // Transcripts received by the tool-aware entry points.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
4706
4707 impl StreamingProvider {
4708 fn new(responses: Vec<ChatResult>) -> Self {
4709 Self {
4710 responses,
4711 call_count: AtomicUsize::new(0),
4712 received_messages: Arc::new(Mutex::new(Vec::new())),
4713 }
4714 }
4715 }
4716
    #[async_trait]
    impl Provider for StreamingProvider {
        // Non-streaming path: just replay the next scripted result.
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        // Streams the scripted text word-by-word (each delta carries a trailing
        // space), then a final empty `done` chunk, and returns the full result.
        async fn complete_streaming(
            &self,
            _prompt: &str,
            sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
        ) -> anyhow::Result<ChatResult> {
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            let result = self.responses.get(idx).cloned().unwrap_or_default();
            for word in result.output_text.split_whitespace() {
                // Send errors are ignored: the receiver may already be dropped.
                let _ = sender.send(StreamChunk {
                    delta: format!("{word} "),
                    done: false,
                    tool_call_delta: None,
                });
            }
            let _ = sender.send(StreamChunk {
                delta: String::new(),
                done: true,
                tool_call_delta: None,
            });
            Ok(result)
        }

        // Tool-aware non-streaming path: records the transcript, then replays.
        async fn complete_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        // Tool-aware streaming: records the transcript, then streams the text
        // character-by-character before the terminal `done` chunk.
        async fn complete_streaming_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
            sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
        ) -> anyhow::Result<ChatResult> {
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            let result = self.responses.get(idx).cloned().unwrap_or_default();
            for ch in result.output_text.chars() {
                let _ = sender.send(StreamChunk {
                    delta: ch.to_string(),
                    done: false,
                    tool_call_delta: None,
                });
            }
            let _ = sender.send(StreamChunk {
                delta: String::new(),
                done: true,
                tool_call_delta: None,
            });
            Ok(result)
        }
    }
4790
4791 #[tokio::test]
4792 async fn streaming_text_only_sends_chunks() {
4793 let provider = StreamingProvider::new(vec![ChatResult {
4794 output_text: "Hello world".to_string(),
4795 ..Default::default()
4796 }]);
4797 let agent = Agent::new(
4798 AgentConfig {
4799 model_supports_tool_use: false,
4800 ..Default::default()
4801 },
4802 Box::new(provider),
4803 Box::new(TestMemory::default()),
4804 vec![],
4805 );
4806
4807 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4808 let response = agent
4809 .respond_streaming(
4810 UserMessage {
4811 text: "hi".to_string(),
4812 },
4813 &test_ctx(),
4814 tx,
4815 )
4816 .await
4817 .expect("should succeed");
4818
4819 assert_eq!(response.text, "Hello world");
4820
4821 let mut chunks = Vec::new();
4823 while let Ok(chunk) = rx.try_recv() {
4824 chunks.push(chunk);
4825 }
4826 assert!(chunks.len() >= 2, "should have at least text + done chunks");
4827 assert!(chunks.last().unwrap().done, "last chunk should be done");
4828 }
4829
    #[tokio::test]
    async fn streaming_single_tool_call_round_trip() {
        // Streaming path with one tool call: first turn requests `echo`,
        // second turn ends; chunks from both turns land on the channel.
        let provider = StreamingProvider::new(vec![
            ChatResult {
                output_text: "I'll echo that.".to_string(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "hello"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: "Done echoing.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
                ..Default::default()
            },
        ]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let response = agent
            .respond_streaming(
                UserMessage {
                    text: "echo hello".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Done echoing.");

        // Drain whatever was streamed: at minimum text deltas plus done markers.
        let mut chunks = Vec::new();
        while let Ok(chunk) = rx.try_recv() {
            chunks.push(chunk);
        }
        assert!(chunks.len() >= 2, "should have streaming chunks");
    }
4879
    #[tokio::test]
    async fn streaming_multi_iteration_tool_calls() {
        // Two streamed tool-use turns (echo, then upper) followed by a final
        // end-turn answer; each provider call ends with its own `done` chunk.
        let provider = StreamingProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "first"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_2".to_string(),
                    name: "upper".to_string(),
                    input: json!({"text": "second"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
                ..Default::default()
            },
            ChatResult {
                output_text: "All done.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
                ..Default::default()
            },
        ]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
        );

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let response = agent
            .respond_streaming(
                UserMessage {
                    text: "do two things".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "All done.");

        let mut chunks = Vec::new();
        while let Ok(chunk) = rx.try_recv() {
            chunks.push(chunk);
        }
        // One `done` marker per provider round trip (3 scripted turns).
        let done_count = chunks.iter().filter(|c| c.done).count();
        assert!(
            done_count >= 3,
            "should have done chunks from 3 provider calls, got {}",
            done_count
        );
    }
4944
4945 #[tokio::test]
4946 async fn streaming_timeout_returns_error() {
4947 struct StreamingSlowProvider;
4948
4949 #[async_trait]
4950 impl Provider for StreamingSlowProvider {
4951 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
4952 sleep(Duration::from_millis(500)).await;
4953 Ok(ChatResult::default())
4954 }
4955
4956 async fn complete_streaming(
4957 &self,
4958 _prompt: &str,
4959 _sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4960 ) -> anyhow::Result<ChatResult> {
4961 sleep(Duration::from_millis(500)).await;
4962 Ok(ChatResult::default())
4963 }
4964 }
4965
4966 let agent = Agent::new(
4967 AgentConfig {
4968 request_timeout_ms: 50,
4969 model_supports_tool_use: false,
4970 ..Default::default()
4971 },
4972 Box::new(StreamingSlowProvider),
4973 Box::new(TestMemory::default()),
4974 vec![],
4975 );
4976
4977 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
4978 let result = agent
4979 .respond_streaming(
4980 UserMessage {
4981 text: "hi".to_string(),
4982 },
4983 &test_ctx(),
4984 tx,
4985 )
4986 .await;
4987
4988 assert!(result.is_err());
4989 match result.unwrap_err() {
4990 AgentError::Timeout { timeout_ms } => assert_eq!(timeout_ms, 50),
4991 other => panic!("expected Timeout, got {other:?}"),
4992 }
4993 }
4994
4995 #[tokio::test]
4996 async fn streaming_no_schema_fallback_sends_chunks() {
4997 let provider = StreamingProvider::new(vec![ChatResult {
4999 output_text: "Fallback response".to_string(),
5000 ..Default::default()
5001 }]);
5002 let agent = Agent::new(
5003 AgentConfig::default(), Box::new(provider),
5005 Box::new(TestMemory::default()),
5006 vec![Box::new(EchoTool)], );
5008
5009 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
5010 let response = agent
5011 .respond_streaming(
5012 UserMessage {
5013 text: "hi".to_string(),
5014 },
5015 &test_ctx(),
5016 tx,
5017 )
5018 .await
5019 .expect("should succeed");
5020
5021 assert_eq!(response.text, "Fallback response");
5022
5023 let mut chunks = Vec::new();
5024 while let Ok(chunk) = rx.try_recv() {
5025 chunks.push(chunk);
5026 }
5027 assert!(!chunks.is_empty(), "should have streaming chunks");
5028 assert!(chunks.last().unwrap().done, "last chunk should be done");
5029 }
5030
5031 #[tokio::test]
5032 async fn streaming_tool_error_does_not_abort() {
5033 let provider = StreamingProvider::new(vec![
5034 ChatResult {
5035 output_text: String::new(),
5036 tool_calls: vec![ToolUseRequest {
5037 id: "call_1".to_string(),
5038 name: "boom".to_string(),
5039 input: json!({}),
5040 }],
5041 stop_reason: Some(StopReason::ToolUse),
5042 ..Default::default()
5043 },
5044 ChatResult {
5045 output_text: "Recovered from error.".to_string(),
5046 tool_calls: vec![],
5047 stop_reason: Some(StopReason::EndTurn),
5048 ..Default::default()
5049 },
5050 ]);
5051
5052 let agent = Agent::new(
5053 AgentConfig::default(),
5054 Box::new(provider),
5055 Box::new(TestMemory::default()),
5056 vec![Box::new(StructuredFailingTool)],
5057 );
5058
5059 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
5060 let response = agent
5061 .respond_streaming(
5062 UserMessage {
5063 text: "boom".to_string(),
5064 },
5065 &test_ctx(),
5066 tx,
5067 )
5068 .await
5069 .expect("should succeed despite tool error");
5070
5071 assert_eq!(response.text, "Recovered from error.");
5072 }
5073
5074 #[tokio::test]
5075 async fn streaming_parallel_tools() {
5076 let provider = StreamingProvider::new(vec![
5077 ChatResult {
5078 output_text: String::new(),
5079 tool_calls: vec![
5080 ToolUseRequest {
5081 id: "call_1".to_string(),
5082 name: "echo".to_string(),
5083 input: json!({"text": "a"}),
5084 },
5085 ToolUseRequest {
5086 id: "call_2".to_string(),
5087 name: "upper".to_string(),
5088 input: json!({"text": "b"}),
5089 },
5090 ],
5091 stop_reason: Some(StopReason::ToolUse),
5092 ..Default::default()
5093 },
5094 ChatResult {
5095 output_text: "Parallel done.".to_string(),
5096 tool_calls: vec![],
5097 stop_reason: Some(StopReason::EndTurn),
5098 ..Default::default()
5099 },
5100 ]);
5101
5102 let agent = Agent::new(
5103 AgentConfig {
5104 parallel_tools: true,
5105 ..Default::default()
5106 },
5107 Box::new(provider),
5108 Box::new(TestMemory::default()),
5109 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
5110 );
5111
5112 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
5113 let response = agent
5114 .respond_streaming(
5115 UserMessage {
5116 text: "parallel test".to_string(),
5117 },
5118 &test_ctx(),
5119 tx,
5120 )
5121 .await
5122 .expect("should succeed");
5123
5124 assert_eq!(response.text, "Parallel done.");
5125 }
5126
5127 #[tokio::test]
5128 async fn streaming_done_chunk_sentinel() {
5129 let provider = StreamingProvider::new(vec![ChatResult {
5130 output_text: "abc".to_string(),
5131 tool_calls: vec![],
5132 stop_reason: Some(StopReason::EndTurn),
5133 ..Default::default()
5134 }]);
5135
5136 let agent = Agent::new(
5137 AgentConfig::default(),
5138 Box::new(provider),
5139 Box::new(TestMemory::default()),
5140 vec![Box::new(StructuredEchoTool)],
5141 );
5142
5143 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
5144 let _response = agent
5145 .respond_streaming(
5146 UserMessage {
5147 text: "test".to_string(),
5148 },
5149 &test_ctx(),
5150 tx,
5151 )
5152 .await
5153 .expect("should succeed");
5154
5155 let mut chunks = Vec::new();
5156 while let Ok(chunk) = rx.try_recv() {
5157 chunks.push(chunk);
5158 }
5159
5160 let done_chunks: Vec<_> = chunks.iter().filter(|c| c.done).collect();
5162 assert!(
5163 !done_chunks.is_empty(),
5164 "must have at least one done sentinel chunk"
5165 );
5166 assert!(chunks.last().unwrap().done, "final chunk must be done=true");
5168 let content_chunks: Vec<_> = chunks
5170 .iter()
5171 .filter(|c| !c.done && !c.delta.is_empty())
5172 .collect();
5173 assert!(!content_chunks.is_empty(), "should have content chunks");
5174 }
5175
5176 #[tokio::test]
5179 async fn system_prompt_prepended_in_structured_path() {
5180 let provider = StructuredProvider::new(vec![ChatResult {
5181 output_text: "I understand.".to_string(),
5182 stop_reason: Some(StopReason::EndTurn),
5183 ..Default::default()
5184 }]);
5185 let received = provider.received_messages.clone();
5186 let memory = TestMemory::default();
5187
5188 let agent = Agent::new(
5189 AgentConfig {
5190 model_supports_tool_use: true,
5191 system_prompt: Some("You are a math tutor.".to_string()),
5192 ..AgentConfig::default()
5193 },
5194 Box::new(provider),
5195 Box::new(memory),
5196 vec![Box::new(StructuredEchoTool)],
5197 );
5198
5199 let result = agent
5200 .respond(
5201 UserMessage {
5202 text: "What is 2+2?".to_string(),
5203 },
5204 &test_ctx(),
5205 )
5206 .await
5207 .expect("respond should succeed");
5208
5209 assert_eq!(result.text, "I understand.");
5210
5211 let msgs = received.lock().expect("lock");
5212 assert!(!msgs.is_empty(), "provider should have been called");
5213 let first_call = &msgs[0];
5214 match &first_call[0] {
5216 ConversationMessage::System { content } => {
5217 assert_eq!(content, "You are a math tutor.");
5218 }
5219 other => panic!("expected System message first, got {other:?}"),
5220 }
5221 match &first_call[1] {
5223 ConversationMessage::User { content, .. } => {
5224 assert_eq!(content, "What is 2+2?");
5225 }
5226 other => panic!("expected User message second, got {other:?}"),
5227 }
5228 }
5229
5230 #[tokio::test]
5231 async fn no_system_prompt_omits_system_message() {
5232 let provider = StructuredProvider::new(vec![ChatResult {
5233 output_text: "ok".to_string(),
5234 stop_reason: Some(StopReason::EndTurn),
5235 ..Default::default()
5236 }]);
5237 let received = provider.received_messages.clone();
5238 let memory = TestMemory::default();
5239
5240 let agent = Agent::new(
5241 AgentConfig {
5242 model_supports_tool_use: true,
5243 system_prompt: None,
5244 ..AgentConfig::default()
5245 },
5246 Box::new(provider),
5247 Box::new(memory),
5248 vec![Box::new(StructuredEchoTool)],
5249 );
5250
5251 agent
5252 .respond(
5253 UserMessage {
5254 text: "hello".to_string(),
5255 },
5256 &test_ctx(),
5257 )
5258 .await
5259 .expect("respond should succeed");
5260
5261 let msgs = received.lock().expect("lock");
5262 let first_call = &msgs[0];
5263 match &first_call[0] {
5265 ConversationMessage::User { content, .. } => {
5266 assert_eq!(content, "hello");
5267 }
5268 other => panic!("expected User message first (no system prompt), got {other:?}"),
5269 }
5270 }
5271
5272 #[tokio::test]
5273 async fn system_prompt_persists_across_tool_iterations() {
5274 let provider = StructuredProvider::new(vec![
5276 ChatResult {
5277 output_text: String::new(),
5278 tool_calls: vec![ToolUseRequest {
5279 id: "call_1".to_string(),
5280 name: "echo".to_string(),
5281 input: serde_json::json!({"text": "ping"}),
5282 }],
5283 stop_reason: Some(StopReason::ToolUse),
5284 ..Default::default()
5285 },
5286 ChatResult {
5287 output_text: "pong".to_string(),
5288 stop_reason: Some(StopReason::EndTurn),
5289 ..Default::default()
5290 },
5291 ]);
5292 let received = provider.received_messages.clone();
5293 let memory = TestMemory::default();
5294
5295 let agent = Agent::new(
5296 AgentConfig {
5297 model_supports_tool_use: true,
5298 system_prompt: Some("Always be concise.".to_string()),
5299 ..AgentConfig::default()
5300 },
5301 Box::new(provider),
5302 Box::new(memory),
5303 vec![Box::new(StructuredEchoTool)],
5304 );
5305
5306 let result = agent
5307 .respond(
5308 UserMessage {
5309 text: "test".to_string(),
5310 },
5311 &test_ctx(),
5312 )
5313 .await
5314 .expect("respond should succeed");
5315 assert_eq!(result.text, "pong");
5316
5317 let msgs = received.lock().expect("lock");
5318 for (i, call_msgs) in msgs.iter().enumerate() {
5320 match &call_msgs[0] {
5321 ConversationMessage::System { content } => {
5322 assert_eq!(content, "Always be concise.", "call {i} system prompt");
5323 }
5324 other => panic!("call {i}: expected System first, got {other:?}"),
5325 }
5326 }
5327 }
5328
5329 #[test]
5330 fn agent_config_system_prompt_defaults_to_none() {
5331 let config = AgentConfig::default();
5332 assert!(config.system_prompt.is_none());
5333 }
5334
5335 #[tokio::test]
5336 async fn memory_write_propagates_source_channel_from_context() {
5337 let memory = TestMemory::default();
5338 let entries = memory.entries.clone();
5339 let provider = StructuredProvider::new(vec![ChatResult {
5340 output_text: "noted".to_string(),
5341 tool_calls: vec![],
5342 stop_reason: None,
5343 ..Default::default()
5344 }]);
5345 let config = AgentConfig {
5346 privacy_boundary: "encrypted_only".to_string(),
5347 ..AgentConfig::default()
5348 };
5349 let agent = Agent::new(config, Box::new(provider), Box::new(memory), vec![]);
5350
5351 let mut ctx = ToolContext::new(".".to_string());
5352 ctx.source_channel = Some("telegram".to_string());
5353 ctx.privacy_boundary = "encrypted_only".to_string();
5354
5355 agent
5356 .respond(
5357 UserMessage {
5358 text: "hello".to_string(),
5359 },
5360 &ctx,
5361 )
5362 .await
5363 .expect("respond should succeed");
5364
5365 let stored = entries.lock().expect("memory lock poisoned");
5366 assert_eq!(stored[0].source_channel.as_deref(), Some("telegram"));
5368 assert_eq!(stored[0].privacy_boundary, "encrypted_only");
5369 assert_eq!(stored[1].source_channel.as_deref(), Some("telegram"));
5370 assert_eq!(stored[1].privacy_boundary, "encrypted_only");
5371 }
5372
5373 #[tokio::test]
5374 async fn memory_write_source_channel_none_when_ctx_empty() {
5375 let memory = TestMemory::default();
5376 let entries = memory.entries.clone();
5377 let provider = StructuredProvider::new(vec![ChatResult {
5378 output_text: "ok".to_string(),
5379 tool_calls: vec![],
5380 stop_reason: None,
5381 ..Default::default()
5382 }]);
5383 let agent = Agent::new(
5384 AgentConfig::default(),
5385 Box::new(provider),
5386 Box::new(memory),
5387 vec![],
5388 );
5389
5390 agent
5391 .respond(
5392 UserMessage {
5393 text: "hi".to_string(),
5394 },
5395 &test_ctx(),
5396 )
5397 .await
5398 .expect("respond should succeed");
5399
5400 let stored = entries.lock().expect("memory lock poisoned");
5401 assert!(stored[0].source_channel.is_none());
5402 assert!(stored[1].source_channel.is_none());
5403 }
5404
5405 fn sample_tools() -> Vec<ToolDefinition> {
5410 vec![
5411 ToolDefinition {
5412 name: "web_search".to_string(),
5413 description: "Search the web".to_string(),
5414 input_schema: serde_json::json!({"type": "object"}),
5415 },
5416 ToolDefinition {
5417 name: "read_file".to_string(),
5418 description: "Read a file".to_string(),
5419 input_schema: serde_json::json!({"type": "object"}),
5420 },
5421 ]
5422 }
5423
5424 #[test]
5425 fn extract_tool_call_from_json_code_block() {
5426 let text = "I'll search for that.\n```json\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"AI regulation EU\"}}\n```";
5427 let tools = sample_tools();
5428 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5429 assert_eq!(result.name, "web_search");
5430 assert_eq!(result.input["query"], "AI regulation EU");
5431 }
5432
5433 #[test]
5434 fn extract_tool_call_from_bare_code_block() {
5435 let text = "```\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"test\"}}\n```";
5436 let tools = sample_tools();
5437 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5438 assert_eq!(result.name, "web_search");
5439 }
5440
5441 #[test]
5442 fn extract_tool_call_from_bare_json() {
5443 let text = "{\"name\": \"web_search\", \"arguments\": {\"query\": \"test\"}}";
5444 let tools = sample_tools();
5445 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5446 assert_eq!(result.name, "web_search");
5447 assert_eq!(result.input["query"], "test");
5448 }
5449
5450 #[test]
5451 fn extract_tool_call_with_parameters_key() {
5452 let text = "{\"name\": \"web_search\", \"parameters\": {\"query\": \"test\"}}";
5453 let tools = sample_tools();
5454 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5455 assert_eq!(result.name, "web_search");
5456 assert_eq!(result.input["query"], "test");
5457 }
5458
5459 #[test]
5460 fn extract_tool_call_ignores_unknown_tool() {
5461 let text = "{\"name\": \"unknown_tool\", \"arguments\": {}}";
5462 let tools = sample_tools();
5463 assert!(extract_tool_call_from_text(text, &tools).is_none());
5464 }
5465
5466 #[test]
5467 fn extract_tool_call_returns_none_for_plain_text() {
5468 let text = "I don't know how to help with that.";
5469 let tools = sample_tools();
5470 assert!(extract_tool_call_from_text(text, &tools).is_none());
5471 }
5472
5473 #[test]
5474 fn extract_tool_call_returns_none_for_non_tool_json() {
5475 let text = "{\"message\": \"hello\", \"status\": \"ok\"}";
5476 let tools = sample_tools();
5477 assert!(extract_tool_call_from_text(text, &tools).is_none());
5478 }
5479
5480 #[test]
5481 fn extract_tool_call_with_surrounding_text() {
5482 let text = "Sure, let me search for that.\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"AI regulation\"}}\nI'll get back to you.";
5483 let tools = sample_tools();
5484 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5485 assert_eq!(result.name, "web_search");
5486 }
5487
5488 #[test]
5489 fn extract_tool_call_no_arguments_field() {
5490 let text = "{\"name\": \"web_search\"}";
5491 let tools = sample_tools();
5492 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5493 assert_eq!(result.name, "web_search");
5494 assert!(result.input.is_object());
5495 }
5496
5497 #[test]
5498 fn extract_json_block_handles_nested_braces() {
5499 let text =
5500 "```json\n{\"name\": \"read_file\", \"arguments\": {\"path\": \"/tmp/{test}\"}}\n```";
5501 let tools = sample_tools();
5502 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5503 assert_eq!(result.name, "read_file");
5504 assert_eq!(result.input["path"], "/tmp/{test}");
5505 }
5506
5507 #[test]
5508 fn extract_bare_json_handles_escaped_quotes() {
5509 let text =
5510 "{\"name\": \"web_search\", \"arguments\": {\"query\": \"test \\\"quoted\\\" word\"}}";
5511 let tools = sample_tools();
5512 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5513 assert_eq!(result.name, "web_search");
5514 }
5515
5516 #[test]
5517 fn extract_tool_call_empty_string() {
5518 let tools = sample_tools();
5519 assert!(extract_tool_call_from_text("", &tools).is_none());
5520 }
5521
5522 struct TextToolCallProvider {
5529 call_count: AtomicUsize,
5530 }
5531
5532 impl TextToolCallProvider {
5533 fn new() -> Self {
5534 Self {
5535 call_count: AtomicUsize::new(0),
5536 }
5537 }
5538 }
5539
5540 #[async_trait]
5541 impl Provider for TextToolCallProvider {
5542 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
5543 Ok(ChatResult::default())
5544 }
5545
5546 async fn complete_with_tools(
5547 &self,
5548 messages: &[ConversationMessage],
5549 _tools: &[ToolDefinition],
5550 _reasoning: &ReasoningConfig,
5551 ) -> anyhow::Result<ChatResult> {
5552 let n = self.call_count.fetch_add(1, Ordering::Relaxed);
5553 if n == 0 {
5554 Ok(ChatResult {
5556 output_text: "```json\n{\"name\": \"echo\", \"arguments\": {\"message\": \"hello from text\"}}\n```".to_string(),
5557 tool_calls: vec![],
5558 stop_reason: Some(StopReason::EndTurn),
5559 ..Default::default()
5560 })
5561 } else {
5562 let tool_output = messages
5564 .iter()
5565 .rev()
5566 .find_map(|m| match m {
5567 ConversationMessage::ToolResult(r) => Some(r.content.as_str()),
5568 _ => None,
5569 })
5570 .unwrap_or("no tool result");
5571 Ok(ChatResult {
5572 output_text: format!("Got it: {tool_output}"),
5573 tool_calls: vec![],
5574 stop_reason: Some(StopReason::EndTurn),
5575 ..Default::default()
5576 })
5577 }
5578 }
5579 }
5580
5581 struct TestEchoTool;
5583
5584 #[async_trait]
5585 impl crate::types::Tool for TestEchoTool {
5586 fn name(&self) -> &'static str {
5587 "echo"
5588 }
5589 fn description(&self) -> &'static str {
5590 "Echo back the message"
5591 }
5592 fn input_schema(&self) -> Option<serde_json::Value> {
5593 Some(serde_json::json!({
5594 "type": "object",
5595 "required": ["message"],
5596 "properties": {
5597 "message": {"type": "string"}
5598 }
5599 }))
5600 }
5601 async fn execute(
5602 &self,
5603 input: &str,
5604 _ctx: &ToolContext,
5605 ) -> anyhow::Result<crate::types::ToolResult> {
5606 let v: serde_json::Value =
5607 serde_json::from_str(input).unwrap_or(Value::String(input.to_string()));
5608 let msg = v.get("message").and_then(|m| m.as_str()).unwrap_or(input);
5609 Ok(crate::types::ToolResult {
5610 output: format!("echoed:{msg}"),
5611 })
5612 }
5613 }
5614
5615 #[tokio::test]
5616 async fn text_tool_extraction_dispatches_through_agent_loop() {
5617 let agent = Agent::new(
5618 AgentConfig {
5619 model_supports_tool_use: true,
5620 max_tool_iterations: 5,
5621 request_timeout_ms: 10_000,
5622 ..AgentConfig::default()
5623 },
5624 Box::new(TextToolCallProvider::new()),
5625 Box::new(TestMemory::default()),
5626 vec![Box::new(TestEchoTool)],
5627 );
5628
5629 let response = agent
5630 .respond(
5631 UserMessage {
5632 text: "echo hello".to_string(),
5633 },
5634 &test_ctx(),
5635 )
5636 .await
5637 .expect("should succeed with text-extracted tool call");
5638
5639 assert!(
5640 response.text.contains("echoed:hello from text"),
5641 "expected tool result in response, got: {}",
5642 response.text
5643 );
5644 }
5645
5646 #[tokio::test]
5647 async fn tool_execution_timeout_fires() {
5648 let memory = TestMemory::default();
5649 let agent = Agent::new(
5650 AgentConfig {
5651 max_tool_iterations: 3,
5652 tool_timeout_ms: 50, ..AgentConfig::default()
5654 },
5655 Box::new(ScriptedProvider::new(vec!["done"])),
5656 Box::new(memory),
5657 vec![Box::new(SlowTool)],
5658 );
5659
5660 let result = agent
5661 .respond(
5662 UserMessage {
5663 text: "tool:slow go".to_string(),
5664 },
5665 &test_ctx(),
5666 )
5667 .await;
5668
5669 let err = result.expect_err("should time out");
5670 let msg = format!("{err}");
5671 assert!(
5672 msg.contains("timed out"),
5673 "expected timeout error, got: {msg}"
5674 );
5675 }
5676
5677 #[tokio::test]
5678 async fn tool_execution_no_timeout_when_disabled() {
5679 let memory = TestMemory::default();
5680 let agent = Agent::new(
5681 AgentConfig {
5682 max_tool_iterations: 3,
5683 tool_timeout_ms: 0, ..AgentConfig::default()
5685 },
5686 Box::new(ScriptedProvider::new(vec!["done"])),
5687 Box::new(memory),
5688 vec![Box::new(EchoTool)],
5689 );
5690
5691 let response = agent
5692 .respond(
5693 UserMessage {
5694 text: "tool:echo hello".to_string(),
5695 },
5696 &test_ctx(),
5697 )
5698 .await
5699 .expect("should succeed with timeout disabled");
5700
5701 assert_eq!(response.text, "done");
5704 }
5705
5706 #[tokio::test]
5709 async fn tool_execute_span_is_created_for_tool_call() {
5710 let subscriber = tracing_subscriber::fmt()
5713 .with_max_level(tracing::Level::TRACE)
5714 .with_writer(std::io::sink)
5715 .finish();
5716 let _guard = tracing::subscriber::set_default(subscriber);
5717
5718 let agent = Agent::new(
5719 AgentConfig {
5720 tool_timeout_ms: 5000,
5721 ..AgentConfig::default()
5722 },
5723 Box::new(ScriptedProvider::new(vec!["done"])),
5724 Box::new(TestMemory::default()),
5725 vec![Box::new(EchoTool)],
5726 );
5727
5728 let response = agent
5730 .respond(
5731 UserMessage {
5732 text: "tool:echo test-span".to_string(),
5733 },
5734 &test_ctx(),
5735 )
5736 .await
5737 .expect("should succeed");
5738
5739 assert_eq!(response.text, "done");
5742 }
5743
5744 #[tokio::test]
5745 async fn tool_execute_span_works_without_timeout() {
5746 let agent = Agent::new(
5748 AgentConfig {
5749 tool_timeout_ms: 0, ..AgentConfig::default()
5751 },
5752 Box::new(ScriptedProvider::new(vec!["no-timeout-done"])),
5753 Box::new(TestMemory::default()),
5754 vec![Box::new(EchoTool)],
5755 );
5756
5757 let response = agent
5758 .respond(
5759 UserMessage {
5760 text: "tool:echo span-test".to_string(),
5761 },
5762 &test_ctx(),
5763 )
5764 .await
5765 .expect("should succeed without timeout");
5766
5767 assert_eq!(response.text, "no-timeout-done");
5768 }
5769}