1use crate::common::privacy_helpers::{is_network_tool, resolve_boundary};
2use crate::loop_detection::{LoopDetectionConfig, ToolLoopDetector};
3use crate::security::redaction::redact_text;
4use crate::types::{
5 AgentConfig, AgentError, AssistantMessage, AuditEvent, AuditSink, ConversationMessage,
6 HookEvent, HookFailureMode, HookRiskTier, HookSink, LoopAction, MemoryEntry, MemoryStore,
7 MetricsSink, Provider, ResearchTrigger, StopReason, StreamSink, Tool, ToolContext,
8 ToolDefinition, ToolResultMessage, ToolSelector, ToolSummary, ToolUseRequest, UserMessage,
9};
10use crate::validation::validate_json;
11use serde_json::json;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::{Instant, SystemTime, UNIX_EPOCH};
14use tokio::time::{timeout, Duration};
15use tracing::{info, info_span, instrument, warn, Instrument};
16
/// Process-wide monotonically increasing sequence used to keep request ids
/// unique when several are generated within the same millisecond
/// (see `Agent::next_request_id`).
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(0);
18
/// Caps `input` to at most `max_chars` characters, keeping the *last*
/// `max_chars` characters (the most recent text) when truncation occurs.
///
/// Returns the (possibly truncated) string and a flag indicating whether
/// truncation happened. Operates on `char` boundaries, so multi-byte UTF-8
/// input is never split mid-character.
fn cap_prompt(input: &str, max_chars: usize) -> (String, bool) {
    let len = input.chars().count();
    if len <= max_chars {
        return (input.to_string(), false);
    }
    // Skip the leading overflow directly instead of the original
    // rev/take/rev dance, which allocated an intermediate Vec.
    let truncated: String = input.chars().skip(len - max_chars).collect();
    (truncated, true)
}
34
35fn build_provider_prompt(
36 current_prompt: &str,
37 recent_memory: &[MemoryEntry],
38 max_prompt_chars: usize,
39 context_summary: Option<&str>,
40) -> (String, bool) {
41 if recent_memory.is_empty() && context_summary.is_none() {
42 return cap_prompt(current_prompt, max_prompt_chars);
43 }
44
45 let mut prompt = String::new();
46
47 if let Some(summary) = context_summary {
48 prompt.push_str("Context summary:\n");
49 prompt.push_str(summary);
50 prompt.push_str("\n\n");
51 }
52
53 if !recent_memory.is_empty() {
54 prompt.push_str("Recent conversation:\n");
55 for entry in recent_memory.iter().rev() {
56 prompt.push_str("- ");
57 prompt.push_str(&entry.role);
58 prompt.push_str(": ");
59 prompt.push_str(&entry.content);
60 prompt.push('\n');
61 }
62 prompt.push('\n');
63 }
64
65 prompt.push_str("Current input:\n");
66 prompt.push_str(current_prompt);
67
68 cap_prompt(&prompt, max_prompt_chars)
69}
70
/// Scans `prompt` line-by-line for `tool:<name> <input>` directives and
/// returns the `(name, input)` pairs; the input half may be empty when the
/// directive carries no argument.
fn parse_tool_calls(prompt: &str) -> Vec<(&str, &str)> {
    let mut calls = Vec::new();
    for line in prompt.lines() {
        let Some(rest) = line.strip_prefix("tool:") else {
            continue;
        };
        let rest = rest.trim();
        // Split off the tool name; everything after the first space is input.
        match rest.split_once(' ') {
            Some((name, input)) => calls.push((name, input)),
            None => calls.push((rest, "")),
        }
    }
    calls
}
86
87fn extract_tool_call_from_text(
100 text: &str,
101 known_tools: &[ToolDefinition],
102) -> Option<ToolUseRequest> {
103 let json_str = extract_json_block(text).or_else(|| extract_bare_json(text))?;
105 let obj: serde_json::Value = serde_json::from_str(json_str).ok()?;
106 let name = obj.get("name")?.as_str()?;
107
108 if !known_tools.iter().any(|t| t.name == name) {
110 return None;
111 }
112
113 let args = obj
114 .get("arguments")
115 .or_else(|| obj.get("parameters"))
116 .cloned()
117 .unwrap_or(serde_json::Value::Object(Default::default()));
118
119 Some(ToolUseRequest {
120 id: format!("text_extracted_{}", name),
121 name: name.to_string(),
122 input: args,
123 })
124}
125
/// Returns the trimmed contents of the first fenced (```) code block in
/// `text`, but only when that content looks like a JSON object (starts
/// with `{`). The remainder of the fence line (e.g. a `json` tag) is
/// skipped.
fn extract_json_block(text: &str) -> Option<&str> {
    let (_, after_open) = text.split_once("```")?;
    // Drop the rest of the opening fence line (language tag, etc.).
    let (_, body) = after_open.split_once('\n')?;
    let (inner, _) = body.split_once("```")?;
    let trimmed = inner.trim();
    trimmed.starts_with('{').then_some(trimmed)
}
141
/// Finds the first balanced top-level JSON object (`{…}`) in `text`,
/// honouring string literals and backslash escapes so braces inside quoted
/// strings do not affect nesting depth. Returns the matching slice, or
/// `None` when no balanced object exists.
fn extract_bare_json(text: &str) -> Option<&str> {
    let open = text.find('{')?;
    let tail = &text[open..];

    let mut nesting: i32 = 0;
    let mut inside_quotes = false;
    let mut skip_escaped = false;

    for (offset, c) in tail.char_indices() {
        if skip_escaped {
            // Previous char was a backslash inside a string: ignore this one.
            skip_escaped = false;
        } else if c == '\\' && inside_quotes {
            skip_escaped = true;
        } else if c == '"' {
            inside_quotes = !inside_quotes;
        } else if !inside_quotes && c == '{' {
            nesting += 1;
        } else if !inside_quotes && c == '}' {
            nesting -= 1;
            if nesting == 0 {
                return Some(&tail[..=offset]);
            }
        }
    }
    None
}
170
171fn single_required_string_field(schema: &serde_json::Value) -> Option<String> {
173 let required = schema.get("required")?.as_array()?;
174 if required.len() != 1 {
175 return None;
176 }
177 let field_name = required[0].as_str()?;
178 let properties = schema.get("properties")?.as_object()?;
179 let field_schema = properties.get(field_name)?;
180 if field_schema.get("type")?.as_str()? == "string" {
181 Some(field_name.to_string())
182 } else {
183 None
184 }
185}
186
187fn prepare_tool_input(tool: &dyn Tool, raw_input: &serde_json::Value) -> Result<String, String> {
196 if let Some(s) = raw_input.as_str() {
198 return Ok(s.to_string());
199 }
200
201 if let Some(schema) = tool.input_schema() {
203 if let Err(errors) = validate_json(raw_input, &schema) {
204 return Err(format!(
205 "Invalid input for tool '{}': {}",
206 tool.name(),
207 errors.join("; ")
208 ));
209 }
210
211 if let Some(field) = single_required_string_field(&schema) {
212 if let Some(val) = raw_input.get(&field).and_then(|v| v.as_str()) {
213 return Ok(val.to_string());
214 }
215 }
216 }
217
218 Ok(serde_json::to_string(raw_input).unwrap_or_default())
219}
220
/// Trims `messages` in place so the conversation fits within `max_chars`,
/// always preserving the leading prefix (through the first user message)
/// and preferring to drop the oldest middle messages first.
///
/// Tool traffic gets special handling:
/// * an assistant message that issued tool calls is kept (even over budget)
///   when its tool results survive, so results are never orphaned from the
///   call that produced them;
/// * any tool-result messages left dangling at the truncation seam are
///   removed afterwards.
fn truncate_messages(messages: &mut Vec<ConversationMessage>, max_chars: usize) {
    let total_chars: usize = messages.iter().map(|m| m.char_count()).sum();
    // Nothing to do when we already fit, or the conversation is too short
    // to trim meaningfully.
    if total_chars <= max_chars || messages.len() <= 2 {
        return;
    }

    // Protected prefix runs through the first user message (the default of
    // 1 presumably covers a leading system message — the code keeps index 0
    // unconditionally).
    let mut prefix_end = 1;
    for (i, msg) in messages.iter().enumerate() {
        if matches!(msg, ConversationMessage::User { .. }) {
            prefix_end = i + 1;
            break;
        }
    }

    let prefix_cost: usize = messages[..prefix_end].iter().map(|m| m.char_count()).sum();
    let mut budget = max_chars.saturating_sub(prefix_cost);

    // Walk backwards from the newest message, greedily keeping whatever
    // still fits in the remaining budget.
    let mut keep_from_end = 0;
    let mut in_tool_result_run = false;
    for msg in messages[prefix_end..].iter().rev() {
        let cost = msg.char_count();
        let is_tool_result = matches!(msg, ConversationMessage::ToolResult(_));
        let is_assistant_with_tools = matches!(
            msg,
            ConversationMessage::Assistant { tool_calls, .. } if !tool_calls.is_empty()
        );

        if is_tool_result {
            in_tool_result_run = true;
        }

        // The assistant message that triggered a kept run of tool results
        // is retained unconditionally (budget saturates at zero) so the
        // call/result pairing stays intact.
        if in_tool_result_run && is_assistant_with_tools {
            budget = budget.saturating_sub(cost);
            keep_from_end += 1;
            in_tool_result_run = false;
            continue;
        }

        if !is_tool_result {
            in_tool_result_run = false;
        }

        if cost > budget {
            break;
        }
        budget -= cost;
        keep_from_end += 1;
    }

    if keep_from_end == 0 {
        // Nothing from the tail fits: keep only the protected prefix.
        messages.truncate(prefix_end);
        return;
    }

    let split_point = messages.len() - keep_from_end;
    if split_point <= prefix_end {
        return;
    }

    // Remove the unkept middle span.
    messages.drain(prefix_end..split_point);

    // Drop tool results stranded at the seam whose originating assistant
    // message was truncated away.
    while messages.len() > prefix_end {
        if matches!(&messages[prefix_end], ConversationMessage::ToolResult(_)) {
            messages.remove(prefix_end);
        } else {
            break;
        }
    }
}
314
315fn memory_to_messages(entries: &[MemoryEntry]) -> Vec<ConversationMessage> {
318 entries
319 .iter()
320 .rev()
321 .map(|entry| {
322 if entry.role == "assistant" {
323 ConversationMessage::Assistant {
324 content: Some(entry.content.clone()),
325 tool_calls: vec![],
326 }
327 } else {
328 ConversationMessage::user(entry.content.clone())
329 }
330 })
331 .collect()
332}
333
/// Core orchestrator: drives provider calls, tool execution, and memory
/// persistence, with optional audit/hook/metrics/loop-detection/tool-selection
/// integrations enabled through the `with_*` builder methods.
pub struct Agent {
    config: AgentConfig,
    /// LLM backend used for completions (plain, streaming, and tool-aware).
    provider: Box<dyn Provider>,
    /// Conversation memory store, scoped by privacy boundary / conversation.
    memory: Box<dyn MemoryStore>,
    /// Tools the model may invoke during a run.
    tools: Vec<Box<dyn Tool>>,
    /// Optional sink for structured audit events (best-effort delivery).
    audit: Option<Box<dyn AuditSink>>,
    /// Optional lifecycle hooks; failures handled per configured tier mode.
    hooks: Option<Box<dyn HookSink>>,
    /// Optional counters/histograms sink.
    metrics: Option<Box<dyn MetricsSink>>,
    /// When set, enables detection of repeated tool-call loops.
    loop_detection_config: Option<LoopDetectionConfig>,
    /// Optional pre-filter that narrows the tool list per user request.
    tool_selector: Option<Box<dyn ToolSelector>>,
}
345
346impl Agent {
    /// Creates an agent with the required collaborators; all optional
    /// integrations (audit, hooks, metrics, loop detection, tool selection)
    /// start disabled and are enabled via the `with_*` builders.
    pub fn new(
        config: AgentConfig,
        provider: Box<dyn Provider>,
        memory: Box<dyn MemoryStore>,
        tools: Vec<Box<dyn Tool>>,
    ) -> Self {
        Self {
            config,
            provider,
            memory,
            tools,
            audit: None,
            hooks: None,
            metrics: None,
            loop_detection_config: None,
            tool_selector: None,
        }
    }
365
    /// Builder: enables tool-call loop detection with the given config.
    pub fn with_loop_detection(mut self, config: LoopDetectionConfig) -> Self {
        self.loop_detection_config = Some(config);
        self
    }
371
    /// Builder: attaches an audit sink for structured event recording.
    pub fn with_audit(mut self, audit: Box<dyn AuditSink>) -> Self {
        self.audit = Some(audit);
        self
    }
376
    /// Builder: attaches a hook sink for lifecycle callbacks.
    pub fn with_hooks(mut self, hooks: Box<dyn HookSink>) -> Self {
        self.hooks = Some(hooks);
        self
    }
381
    /// Builder: attaches a metrics sink for counters and histograms.
    pub fn with_metrics(mut self, metrics: Box<dyn MetricsSink>) -> Self {
        self.metrics = Some(metrics);
        self
    }
386
    /// Builder: attaches a selector that narrows the tool list per request.
    pub fn with_tool_selector(mut self, selector: Box<dyn ToolSelector>) -> Self {
        self.tool_selector = Some(selector);
        self
    }
391
    /// Registers an additional tool after construction.
    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
        self.tools.push(tool);
    }
396
    /// Builds provider-facing definitions for all registered tools,
    /// silently skipping tools for which no definition can be derived.
    fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
        self.tools
            .iter()
            .filter_map(|tool| ToolDefinition::from_tool(&**tool))
            .collect()
    }
404
    /// True when at least one registered tool declares an input schema.
    /// NOTE(review): this uses `input_schema()` as a proxy, while
    /// `build_tool_definitions` relies on `ToolDefinition::from_tool` —
    /// presumably equivalent; confirm the two cannot disagree.
    fn has_tool_definitions(&self) -> bool {
        self.tools.iter().any(|t| t.input_schema().is_some())
    }
409
    /// Emits a structured audit event to the configured sink, if any.
    /// Best-effort: sink errors are deliberately discarded so auditing can
    /// never fail the surrounding operation.
    async fn audit(&self, stage: &str, detail: serde_json::Value) {
        if let Some(sink) = &self.audit {
            let _ = sink
                .record(AuditEvent {
                    stage: stage.to_string(),
                    detail,
                })
                .await;
        }
    }
420
421 fn next_request_id() -> String {
422 let ts_ms = SystemTime::now()
423 .duration_since(UNIX_EPOCH)
424 .unwrap_or_default()
425 .as_millis();
426 let seq = REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed);
427 format!("req-{ts_ms}-{seq}")
428 }
429
430 fn hook_risk_tier(stage: &str) -> HookRiskTier {
431 if matches!(
432 stage,
433 "before_tool_call" | "after_tool_call" | "before_plugin_call" | "after_plugin_call"
434 ) {
435 HookRiskTier::High
436 } else if matches!(
437 stage,
438 "before_provider_call"
439 | "after_provider_call"
440 | "before_memory_write"
441 | "after_memory_write"
442 | "before_run"
443 | "after_run"
444 ) {
445 HookRiskTier::Medium
446 } else {
447 HookRiskTier::Low
448 }
449 }
450
    /// Resolves the failure mode to apply when a hook at `stage` errors or
    /// times out. The global `fail_closed` switch forces `Block` for every
    /// stage; otherwise the mode configured for the stage's risk tier
    /// applies.
    fn hook_failure_mode_for_stage(&self, stage: &str) -> HookFailureMode {
        if self.config.hooks.fail_closed {
            return HookFailureMode::Block;
        }

        match Self::hook_risk_tier(stage) {
            HookRiskTier::Low => self.config.hooks.low_tier_mode,
            HookRiskTier::Medium => self.config.hooks.medium_tier_mode,
            HookRiskTier::High => self.config.hooks.high_tier_mode,
        }
    }
462
    /// Delivers a lifecycle event to the hook sink (when hooks are enabled
    /// and a sink is configured) and enforces the configured failure policy.
    ///
    /// The hook call is bounded by `hooks.timeout_ms`. On error or timeout
    /// the stage's failure mode decides the outcome:
    /// * `Block`  — propagate as `AgentError::Hook` (aborts the operation);
    /// * `Warn`   — log + audit, then continue;
    /// * `Ignore` — audit only, then continue.
    ///
    /// Hook error text is redacted before it reaches logs or audit events.
    async fn hook(&self, stage: &str, detail: serde_json::Value) -> Result<(), AgentError> {
        if !self.config.hooks.enabled {
            return Ok(());
        }
        let Some(sink) = &self.hooks else {
            return Ok(());
        };

        let event = HookEvent {
            stage: stage.to_string(),
            detail,
        };
        let hook_call = sink.record(event);
        let mode = self.hook_failure_mode_for_stage(stage);
        let tier = Self::hook_risk_tier(stage);
        match timeout(
            Duration::from_millis(self.config.hooks.timeout_ms),
            hook_call,
        )
        .await
        {
            Ok(Ok(())) => Ok(()),
            Ok(Err(err)) => {
                // Never leak raw hook error text into logs/audit.
                let redacted = redact_text(&err.to_string());
                match mode {
                    HookFailureMode::Block => Err(AgentError::Hook {
                        stage: stage.to_string(),
                        source: err,
                    }),
                    HookFailureMode::Warn => {
                        warn!(
                            stage = stage,
                            tier = ?tier,
                            mode = "warn",
                            "hook error (continuing): {redacted}"
                        );
                        self.audit(
                            "hook_error_warn",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                    HookFailureMode::Ignore => {
                        self.audit(
                            "hook_error_ignored",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                }
            }
            // Timeout path mirrors the error path, with a synthesized error.
            Err(_) => match mode {
                HookFailureMode::Block => Err(AgentError::Hook {
                    stage: stage.to_string(),
                    source: anyhow::anyhow!(
                        "hook execution timed out after {} ms",
                        self.config.hooks.timeout_ms
                    ),
                }),
                HookFailureMode::Warn => {
                    warn!(
                        stage = stage,
                        tier = ?tier,
                        mode = "warn",
                        timeout_ms = self.config.hooks.timeout_ms,
                        "hook timeout (continuing)"
                    );
                    self.audit(
                        "hook_timeout_warn",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
                HookFailureMode::Ignore => {
                    self.audit(
                        "hook_timeout_ignored",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
            },
        }
    }
550
    /// Bumps a counter on the metrics sink by 1 (no-op without a sink).
    fn increment_counter(&self, name: &'static str) {
        if let Some(metrics) = &self.metrics {
            metrics.increment_counter(name, 1);
        }
    }
556
    /// Records a histogram observation on the metrics sink (no-op without
    /// a sink).
    fn observe_histogram(&self, name: &'static str, value: f64) {
        if let Some(metrics) = &self.metrics {
            metrics.observe_histogram(name, value);
        }
    }
562
    /// Executes a single tool call with privacy-boundary enforcement,
    /// hooks, auditing, latency metrics, and an optional timeout.
    ///
    /// Flow: reject network tools under a `local_only` boundary → fire
    /// `before_tool_call` (plus `before_plugin_call` for `plugin:`-prefixed
    /// names) → run the tool, bounded by `tool_timeout_ms` when non-zero →
    /// record metrics/audit → fire the matching `after_*` hooks.
    ///
    /// Failures are wrapped as `AgentError::Tool`; timeouts additionally
    /// bump `tool_timeouts_total`. `tool_latency_ms` is observed on both
    /// success and failure paths.
    #[instrument(skip(self, tool, tool_input, ctx), fields(tool = tool_name, request_id, iteration))]
    async fn execute_tool(
        &self,
        tool: &dyn Tool,
        tool_name: &str,
        tool_input: &str,
        ctx: &ToolContext,
        request_id: &str,
        iteration: usize,
    ) -> Result<crate::types::ToolResult, AgentError> {
        // Per-tool boundary overrides the agent-wide one when present.
        let tool_specific = self
            .config
            .tool_boundaries
            .get(tool_name)
            .map(|s| s.as_str())
            .unwrap_or("");
        let resolved = resolve_boundary(tool_specific, &self.config.privacy_boundary);
        if resolved == "local_only" && is_network_tool(tool_name) {
            return Err(AgentError::Tool {
                tool: tool_name.to_string(),
                source: anyhow::anyhow!(
                    "tool '{}' requires network access but privacy boundary is 'local_only'",
                    tool_name
                ),
            });
        }

        let is_plugin_call = tool_name.starts_with("plugin:");
        self.hook(
            "before_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "before_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name}),
            )
            .await?;
        }
        self.audit(
            "tool_execute_start",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await;
        let tool_started = Instant::now();
        let tool_timeout_ms = self.config.tool_timeout_ms;
        // Span factory: both execution paths below need an identical span.
        let make_tool_span = || {
            info_span!(
                "tool_execute",
                tool_name = %tool_name,
                request_id = %request_id,
                iteration = iteration,
            )
        };
        // A timeout of 0 means "no timeout" and takes the unbounded branch.
        let result = if tool_timeout_ms > 0 {
            match timeout(
                Duration::from_millis(tool_timeout_ms),
                tool.execute(tool_input, ctx).instrument(make_tool_span()),
            )
            .await
            {
                Ok(Ok(result)) => result,
                Ok(Err(source)) => {
                    // Latency/error metrics are recorded on failure too.
                    self.observe_histogram(
                        "tool_latency_ms",
                        tool_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("tool_errors_total");
                    return Err(AgentError::Tool {
                        tool: tool_name.to_string(),
                        source,
                    });
                }
                Err(_elapsed) => {
                    self.observe_histogram(
                        "tool_latency_ms",
                        tool_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("tool_errors_total");
                    self.increment_counter("tool_timeouts_total");
                    warn!(
                        tool = %tool_name,
                        timeout_ms = tool_timeout_ms,
                        "tool execution timed out"
                    );
                    return Err(AgentError::Tool {
                        tool: tool_name.to_string(),
                        source: anyhow::anyhow!(
                            "tool '{}' timed out after {}ms",
                            tool_name,
                            tool_timeout_ms
                        ),
                    });
                }
            }
        } else {
            match tool
                .execute(tool_input, ctx)
                .instrument(make_tool_span())
                .await
            {
                Ok(result) => result,
                Err(source) => {
                    self.observe_histogram(
                        "tool_latency_ms",
                        tool_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("tool_errors_total");
                    return Err(AgentError::Tool {
                        tool: tool_name.to_string(),
                        source,
                    });
                }
            }
        };
        self.observe_histogram(
            "tool_latency_ms",
            tool_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "tool_execute_success",
            json!({
                "request_id": request_id,
                "iteration": iteration,
                "tool_name": tool_name,
                "tool_output_len": result.output.len(),
                "duration_ms": tool_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "tool",
            tool_name = %tool_name,
            duration_ms = %tool_started.elapsed().as_millis(),
            "tool execution finished"
        );
        self.hook(
            "after_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name, "status": "ok"}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "after_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name, "status": "ok"}),
            )
            .await?;
        }
        Ok(result)
    }
717
    /// Optionally condenses older memory entries into a short textual
    /// summary by asking the provider itself.
    ///
    /// Returns `None` (callers fall back to plain truncation) when
    /// summarization is disabled, there are too few entries, there is no
    /// "older" slice beyond `keep_recent`, or the provider call fails or
    /// times out.
    async fn maybe_summarize_context(&self, entries: &[MemoryEntry]) -> Option<String> {
        let cfg = &self.config.summarization;
        if !cfg.enabled || entries.len() < cfg.min_entries_for_summarization {
            return None;
        }

        // The first `keep_recent` entries are kept verbatim; the remainder
        // (iterated in reverse below) is what gets summarized. Assumes
        // entries are ordered newest-first — TODO(review): confirm.
        let keep = cfg.keep_recent.min(entries.len());
        let older = &entries[keep..];
        if older.is_empty() {
            return None;
        }

        let mut text = String::new();
        for entry in older.iter().rev() {
            text.push_str(&entry.role);
            text.push_str(": ");
            text.push_str(&entry.content);
            text.push('\n');
        }

        let summarization_prompt = format!(
            "Summarize the following conversation history in {} characters or less. \
             Preserve key facts, decisions, and context that would be important for \
             continuing the conversation. Be concise.\n\n{}",
            cfg.max_summary_chars, text
        );

        // Hard-coded 10s cap so a slow summarizer can't stall the main
        // request — NOTE(review): consider making this configurable.
        let summary_timeout = Duration::from_secs(10);
        match timeout(
            summary_timeout,
            self.provider.complete(&summarization_prompt),
        )
        .await
        {
            Ok(Ok(chat_result)) => {
                let text = chat_result.output_text;
                // Enforce the configured cap even if the model overshoots.
                let summary = if text.chars().count() > cfg.max_summary_chars {
                    text.chars().take(cfg.max_summary_chars).collect::<String>()
                } else {
                    text
                };
                Some(summary)
            }
            Ok(Err(e)) => {
                warn!("context summarization failed, falling back to truncation: {e}");
                None
            }
            Err(_) => {
                warn!("context summarization timed out, falling back to truncation");
                None
            }
        }
    }
776
    /// Runs one provider completion with conversational context attached.
    ///
    /// Loads recent memory for the privacy boundary, optionally summarizes
    /// older entries, assembles and (if needed) caps the prompt, then calls
    /// the provider — streaming when `stream_sink` is given, otherwise the
    /// reasoning-aware completion. Emits hook/audit/metric events around
    /// the call and returns the provider's output text.
    async fn call_provider_with_context(
        &self,
        prompt: &str,
        request_id: &str,
        stream_sink: Option<StreamSink>,
        source_channel: Option<&str>,
    ) -> Result<String, AgentError> {
        let recent_memory = self
            .memory
            .recent_for_boundary(
                self.config.memory_window_size,
                &self.config.privacy_boundary,
                source_channel,
            )
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.audit(
            "memory_recent_loaded",
            json!({"request_id": request_id, "items": recent_memory.len()}),
        )
        .await;
        // When summarization kicks in, only the first `keep_recent` entries
        // stay verbatim; the rest are represented by the summary text.
        let context_summary = self.maybe_summarize_context(&recent_memory).await;
        let (effective_memory, summary_ref) = if let Some(ref summary) = context_summary {
            let keep = self
                .config
                .summarization
                .keep_recent
                .min(recent_memory.len());
            (&recent_memory[..keep], Some(summary.as_str()))
        } else {
            (recent_memory.as_slice(), None)
        };

        let (provider_prompt, prompt_truncated) = build_provider_prompt(
            prompt,
            effective_memory,
            self.config.max_prompt_chars,
            summary_ref,
        );
        if prompt_truncated {
            self.audit(
                "provider_prompt_truncated",
                json!({
                    "request_id": request_id,
                    "max_prompt_chars": self.config.max_prompt_chars,
                }),
            )
            .await;
        }
        self.hook(
            "before_provider_call",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await?;
        self.audit(
            "provider_call_start",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await;
        let provider_started = Instant::now();
        // Streaming and non-streaming share all the bookkeeping below.
        let provider_result = if let Some(sink) = stream_sink {
            self.provider
                .complete_streaming(&provider_prompt, sink)
                .await
        } else {
            self.provider
                .complete_with_reasoning(&provider_prompt, &self.config.reasoning)
                .await
        };
        let completion = match provider_result {
            Ok(result) => result,
            Err(source) => {
                // Latency/error metrics are recorded on failure too.
                self.observe_histogram(
                    "provider_latency_ms",
                    provider_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("provider_errors_total");
                return Err(AgentError::Provider { source });
            }
        };
        self.observe_histogram(
            "provider_latency_ms",
            provider_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "provider_call_success",
            json!({
                "request_id": request_id,
                "response_len": completion.output_text.len(),
                "duration_ms": provider_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "provider",
            duration_ms = %provider_started.elapsed().as_millis(),
            "provider call finished"
        );
        self.hook(
            "after_provider_call",
            json!({"request_id": request_id, "response_len": completion.output_text.len(), "status": "ok"}),
        )
        .await?;
        Ok(completion.output_text)
    }
895
    /// Appends one entry to the memory store, bracketed by the
    /// `before_memory_write` / `after_memory_write` hooks and followed by a
    /// role-tagged audit event.
    ///
    /// The entry inherits the agent's privacy boundary; `org_id` is stored
    /// empty and a missing `agent_id` becomes an empty string.
    async fn write_to_memory(
        &self,
        role: &str,
        content: &str,
        request_id: &str,
        source_channel: Option<&str>,
        conversation_id: &str,
        agent_id: Option<&str>,
    ) -> Result<(), AgentError> {
        self.hook(
            "before_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        self.memory
            .append(MemoryEntry {
                role: role.to_string(),
                content: content.to_string(),
                privacy_boundary: self.config.privacy_boundary.clone(),
                source_channel: source_channel.map(String::from),
                conversation_id: conversation_id.to_string(),
                created_at: None,
                expires_at: None,
                org_id: String::new(),
                agent_id: agent_id.unwrap_or("").to_string(),
                embedding: None,
            })
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.hook(
            "after_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        self.audit(
            &format!("memory_append_{role}"),
            json!({"request_id": request_id}),
        )
        .await;
        Ok(())
    }
937
    /// Decides whether the research phase should run for this user input,
    /// based on the configured trigger strategy.
    fn should_research(&self, user_text: &str) -> bool {
        if !self.config.research.enabled {
            return false;
        }
        match self.config.research.trigger {
            ResearchTrigger::Never => false,
            ResearchTrigger::Always => true,
            ResearchTrigger::Keywords => {
                // Case-insensitive substring match against configured keywords.
                let lower = user_text.to_lowercase();
                self.config
                    .research
                    .keywords
                    .iter()
                    .any(|kw| lower.contains(&kw.to_lowercase()))
            }
            // NOTE(review): `len()` counts bytes, not characters, so
            // multi-byte input reaches the threshold sooner — confirm this
            // is intended.
            ResearchTrigger::Length => user_text.len() >= self.config.research.min_message_length,
            ResearchTrigger::Question => user_text.trim_end().ends_with('?'),
        }
    }
957
    /// Runs the bounded research loop: prompts the provider in RESEARCH
    /// mode, executes at most one `tool:` call per iteration (up to
    /// `research.max_iterations`), and returns the accumulated findings
    /// joined by newlines.
    ///
    /// The loop ends early when the model replies without a `tool:` prefix
    /// (that reply is treated as the final summary), when no parsable call
    /// is found, or when the named tool is unknown.
    async fn run_research_phase(
        &self,
        user_text: &str,
        ctx: &ToolContext,
        request_id: &str,
    ) -> Result<String, AgentError> {
        self.audit(
            "research_phase_start",
            json!({
                "request_id": request_id,
                "max_iterations": self.config.research.max_iterations,
            }),
        )
        .await;
        self.increment_counter("research_phase_started");

        let research_prompt = format!(
            "You are in RESEARCH mode. The user asked: \"{user_text}\"\n\
             Gather relevant information using available tools. \
             Respond with tool: calls to collect data. \
             When done gathering, summarize your findings without a tool: prefix."
        );
        let mut prompt = self
            .call_provider_with_context(
                &research_prompt,
                request_id,
                None,
                ctx.source_channel.as_deref(),
            )
            .await?;
        let mut findings: Vec<String> = Vec::new();

        for iteration in 0..self.config.research.max_iterations {
            // A reply without a `tool:` prefix is the model's final summary.
            if !prompt.starts_with("tool:") {
                findings.push(prompt.clone());
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Only the first parsed call is executed per iteration.
            let (name, input) = calls[0];
            if let Some(tool) = self.tools.iter().find(|t| t.name() == name) {
                let result = self
                    .execute_tool(&**tool, name, input, ctx, request_id, iteration)
                    .await?;
                findings.push(format!("{name}: {}", result.output));

                if self.config.research.show_progress {
                    info!(iteration, tool = name, "research phase: tool executed");
                    self.audit(
                        "research_phase_iteration",
                        json!({
                            "request_id": request_id,
                            "iteration": iteration,
                            "tool_name": name,
                        }),
                    )
                    .await;
                }

                // Feed the tool output back so the model can continue or
                // wrap up.
                let next_prompt = format!(
                    "Research iteration {iteration}: tool `{name}` returned: {}\n\
                     Continue researching or summarize findings.",
                    result.output
                );
                prompt = self
                    .call_provider_with_context(
                        &next_prompt,
                        request_id,
                        None,
                        ctx.source_channel.as_deref(),
                    )
                    .await?;
            } else {
                // Unknown tool name: stop researching rather than guess.
                break;
            }
        }

        self.audit(
            "research_phase_complete",
            json!({
                "request_id": request_id,
                "findings_count": findings.len(),
            }),
        )
        .await;
        self.increment_counter("research_phase_completed");

        Ok(findings.join("\n"))
    }
1051
1052 #[instrument(
1055 skip(self, user_text, research_context, ctx, stream_sink),
1056 fields(request_id)
1057 )]
1058 async fn respond_with_tools(
1059 &self,
1060 request_id: &str,
1061 user_text: &str,
1062 research_context: &str,
1063 ctx: &ToolContext,
1064 stream_sink: Option<StreamSink>,
1065 ) -> Result<AssistantMessage, AgentError> {
1066 self.audit(
1067 "structured_tool_use_start",
1068 json!({
1069 "request_id": request_id,
1070 "max_tool_iterations": self.config.max_tool_iterations,
1071 }),
1072 )
1073 .await;
1074
1075 let all_tool_definitions = self.build_tool_definitions();
1076
1077 let tool_definitions = if let Some(ref selector) = self.tool_selector {
1079 let summaries: Vec<ToolSummary> = all_tool_definitions
1080 .iter()
1081 .map(|td| ToolSummary {
1082 name: td.name.clone(),
1083 description: td.description.clone(),
1084 })
1085 .collect();
1086 match selector.select(user_text, &summaries).await {
1087 Ok(selected_names) => {
1088 let selected: Vec<ToolDefinition> = all_tool_definitions
1089 .iter()
1090 .filter(|td| selected_names.contains(&td.name))
1091 .cloned()
1092 .collect();
1093 info!(
1094 total = all_tool_definitions.len(),
1095 selected = selected.len(),
1096 mode = %self.config.tool_selection,
1097 "tool selection applied"
1098 );
1099 selected
1100 }
1101 Err(e) => {
1102 warn!(error = %e, "tool selection failed, falling back to all tools");
1103 all_tool_definitions
1104 }
1105 }
1106 } else {
1107 all_tool_definitions
1108 };
1109
1110 let recent_memory = if let Some(ref cid) = ctx.conversation_id {
1112 self.memory
1113 .recent_for_conversation(cid, self.config.memory_window_size)
1114 .await
1115 .map_err(|source| AgentError::Memory { source })?
1116 } else {
1117 self.memory
1118 .recent_for_boundary(
1119 self.config.memory_window_size,
1120 &self.config.privacy_boundary,
1121 ctx.source_channel.as_deref(),
1122 )
1123 .await
1124 .map_err(|source| AgentError::Memory { source })?
1125 };
1126 self.audit(
1127 "memory_recent_loaded",
1128 json!({"request_id": request_id, "items": recent_memory.len()}),
1129 )
1130 .await;
1131
1132 let mut messages: Vec<ConversationMessage> = Vec::new();
1133
1134 if let Some(ref sp) = self.config.system_prompt {
1136 messages.push(ConversationMessage::System {
1137 content: sp.clone(),
1138 });
1139 }
1140
1141 messages.extend(memory_to_messages(&recent_memory));
1142
1143 if !research_context.is_empty() {
1144 messages.push(ConversationMessage::user(format!(
1145 "Research findings:\n{research_context}",
1146 )));
1147 }
1148
1149 messages.push(ConversationMessage::user(user_text.to_string()));
1150
1151 let mut tool_history: Vec<(String, String, String)> = Vec::new();
1152 let mut failure_streak: usize = 0;
1153 let mut loop_detector = self
1154 .loop_detection_config
1155 .as_ref()
1156 .map(|cfg| ToolLoopDetector::new(cfg.clone()));
1157 let mut restricted_tools: Vec<String> = Vec::new();
1158
1159 for iteration in 0..self.config.max_tool_iterations {
1160 if ctx.is_cancelled() {
1162 warn!(request_id = %request_id, "agent execution cancelled");
1163 self.audit(
1164 "execution_cancelled",
1165 json!({"request_id": request_id, "iteration": iteration}),
1166 )
1167 .await;
1168 return Ok(AssistantMessage {
1169 text: "[Execution cancelled]".to_string(),
1170 });
1171 }
1172
1173 let last_tool_result_chars = messages.last().and_then(|m| {
1178 if matches!(m, ConversationMessage::ToolResult(_)) {
1179 Some(m.char_count())
1180 } else {
1181 None
1182 }
1183 });
1184 let pre_truncate_len = messages.len();
1185 truncate_messages(&mut messages, self.config.max_prompt_chars);
1186 if let Some(result_chars) = last_tool_result_chars {
1187 let still_present = messages
1188 .last()
1189 .is_some_and(|m| matches!(m, ConversationMessage::ToolResult(_)));
1190 if !still_present || messages.len() < pre_truncate_len {
1191 warn!(
1192 request_id = %request_id,
1193 iteration = iteration,
1194 tool_result_chars = result_chars,
1195 max_prompt_chars = self.config.max_prompt_chars,
1196 "tool result truncated from context: tool result ({} chars) exceeds \
1197 available budget within max_prompt_chars={}; the model will likely \
1198 call the same tool again — raise max_prompt_chars in your config",
1199 result_chars,
1200 self.config.max_prompt_chars,
1201 );
1202 }
1203 }
1204
1205 self.hook(
1206 "before_provider_call",
1207 json!({
1208 "request_id": request_id,
1209 "iteration": iteration,
1210 "message_count": messages.len(),
1211 "tool_count": tool_definitions.len(),
1212 }),
1213 )
1214 .await?;
1215 self.audit(
1216 "provider_call_start",
1217 json!({
1218 "request_id": request_id,
1219 "iteration": iteration,
1220 "message_count": messages.len(),
1221 }),
1222 )
1223 .await;
1224
1225 let effective_tools: Vec<ToolDefinition> = if restricted_tools.is_empty() {
1227 tool_definitions.clone()
1228 } else {
1229 tool_definitions
1230 .iter()
1231 .filter(|td| !restricted_tools.contains(&td.name))
1232 .cloned()
1233 .collect()
1234 };
1235
1236 let provider_span = info_span!(
1237 "provider_call",
1238 request_id = %request_id,
1239 iteration = iteration,
1240 tool_count = effective_tools.len(),
1241 );
1242 let _provider_guard = provider_span.enter();
1243 let provider_started = Instant::now();
1244 let provider_result = if let Some(ref sink) = stream_sink {
1245 self.provider
1246 .complete_streaming_with_tools(
1247 &messages,
1248 &effective_tools,
1249 &self.config.reasoning,
1250 sink.clone(),
1251 )
1252 .await
1253 } else {
1254 self.provider
1255 .complete_with_tools(&messages, &effective_tools, &self.config.reasoning)
1256 .await
1257 };
1258 let chat_result = match provider_result {
1259 Ok(result) => result,
1260 Err(source) => {
1261 self.observe_histogram(
1262 "provider_latency_ms",
1263 provider_started.elapsed().as_secs_f64() * 1000.0,
1264 );
1265 self.increment_counter("provider_errors_total");
1266 return Err(AgentError::Provider { source });
1267 }
1268 };
1269 self.observe_histogram(
1270 "provider_latency_ms",
1271 provider_started.elapsed().as_secs_f64() * 1000.0,
1272 );
1273
1274 let iter_tokens = chat_result.input_tokens + chat_result.output_tokens;
1276 if iter_tokens > 0 {
1277 ctx.add_tokens(iter_tokens);
1278 }
1279
1280 if let Some(ref calc) = self.config.cost_calculator {
1282 let cost = calc(chat_result.input_tokens, chat_result.output_tokens);
1283 if cost > 0 {
1284 ctx.add_cost(cost);
1285 }
1286 }
1287
1288 if let Some(reason) = ctx.budget_exceeded() {
1290 warn!(
1291 request_id = %request_id,
1292 iteration = iteration,
1293 reason = %reason,
1294 "budget exceeded — force-completing run"
1295 );
1296 return Err(AgentError::BudgetExceeded { reason });
1297 }
1298
1299 info!(
1300 request_id = %request_id,
1301 iteration = iteration,
1302 stop_reason = ?chat_result.stop_reason,
1303 tool_calls = chat_result.tool_calls.len(),
1304 tokens_this_call = iter_tokens,
1305 total_tokens = ctx.current_tokens(),
1306 cost_microdollars = ctx.current_cost(),
1307 "structured provider call finished"
1308 );
1309 self.hook(
1310 "after_provider_call",
1311 json!({
1312 "request_id": request_id,
1313 "iteration": iteration,
1314 "response_len": chat_result.output_text.len(),
1315 "tool_calls": chat_result.tool_calls.len(),
1316 "status": "ok",
1317 }),
1318 )
1319 .await?;
1320
1321 let mut chat_result = chat_result;
1324 if chat_result.tool_calls.is_empty()
1325 || chat_result.stop_reason == Some(StopReason::EndTurn)
1326 {
1327 if chat_result.tool_calls.is_empty() && !chat_result.output_text.is_empty() {
1331 if let Some(extracted) =
1332 extract_tool_call_from_text(&chat_result.output_text, &effective_tools)
1333 {
1334 info!(
1335 request_id = %request_id,
1336 tool = %extracted.name,
1337 "extracted tool call from text output (local model fallback)"
1338 );
1339 chat_result.tool_calls = vec![extracted];
1340 chat_result.stop_reason = Some(StopReason::ToolUse);
1341 }
1343 }
1344
1345 if chat_result.tool_calls.is_empty() {
1347 let response_text = chat_result.output_text;
1348 self.write_to_memory(
1349 "assistant",
1350 &response_text,
1351 request_id,
1352 ctx.source_channel.as_deref(),
1353 ctx.conversation_id.as_deref().unwrap_or(""),
1354 ctx.agent_id.as_deref(),
1355 )
1356 .await?;
1357 self.audit("respond_success", json!({"request_id": request_id}))
1358 .await;
1359 self.hook(
1360 "before_response_emit",
1361 json!({"request_id": request_id, "response_len": response_text.len()}),
1362 )
1363 .await?;
1364 let response = AssistantMessage {
1365 text: response_text,
1366 };
1367 self.hook(
1368 "after_response_emit",
1369 json!({"request_id": request_id, "response_len": response.text.len()}),
1370 )
1371 .await?;
1372 return Ok(response);
1373 }
1374 }
1375
1376 messages.push(ConversationMessage::Assistant {
1378 content: if chat_result.output_text.is_empty() {
1379 None
1380 } else {
1381 Some(chat_result.output_text.clone())
1382 },
1383 tool_calls: chat_result.tool_calls.clone(),
1384 });
1385
1386 let tool_calls = &chat_result.tool_calls;
1388 let has_gated = tool_calls
1389 .iter()
1390 .any(|tc| self.config.gated_tools.contains(&tc.name));
1391 let use_parallel = self.config.parallel_tools && tool_calls.len() > 1 && !has_gated;
1392
1393 let mut tool_results: Vec<ToolResultMessage> = Vec::new();
1394
1395 if use_parallel {
1396 let futs: Vec<_> = tool_calls
1397 .iter()
1398 .map(|tc| {
1399 let tool = self.tools.iter().find(|t| t.name() == tc.name);
1400 async move {
1401 match tool {
1402 Some(tool) => {
1403 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1404 Ok(s) => s,
1405 Err(validation_err) => {
1406 return (
1407 tc.name.clone(),
1408 String::new(),
1409 ToolResultMessage {
1410 tool_use_id: tc.id.clone(),
1411 content: validation_err,
1412 is_error: true,
1413 },
1414 );
1415 }
1416 };
1417 match tool.execute(&input_str, ctx).await {
1418 Ok(result) => (
1419 tc.name.clone(),
1420 input_str,
1421 ToolResultMessage {
1422 tool_use_id: tc.id.clone(),
1423 content: result.output,
1424 is_error: false,
1425 },
1426 ),
1427 Err(e) => (
1428 tc.name.clone(),
1429 input_str,
1430 ToolResultMessage {
1431 tool_use_id: tc.id.clone(),
1432 content: format!("Error: {e}"),
1433 is_error: true,
1434 },
1435 ),
1436 }
1437 }
1438 None => (
1439 tc.name.clone(),
1440 String::new(),
1441 ToolResultMessage {
1442 tool_use_id: tc.id.clone(),
1443 content: format!("Tool '{}' not found", tc.name),
1444 is_error: true,
1445 },
1446 ),
1447 }
1448 }
1449 })
1450 .collect();
1451 let results = futures_util::future::join_all(futs).await;
1452 for (name, input_str, result_msg) in results {
1453 if result_msg.is_error {
1454 failure_streak += 1;
1455 } else {
1456 failure_streak = 0;
1457 tool_history.push((name, input_str, result_msg.content.clone()));
1458 }
1459 tool_results.push(result_msg);
1460 }
1461 } else {
1462 for tc in tool_calls {
1464 let result_msg = match self.tools.iter().find(|t| t.name() == tc.name) {
1465 Some(tool) => {
1466 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1467 Ok(s) => s,
1468 Err(validation_err) => {
1469 failure_streak += 1;
1470 tool_results.push(ToolResultMessage {
1471 tool_use_id: tc.id.clone(),
1472 content: validation_err,
1473 is_error: true,
1474 });
1475 continue;
1476 }
1477 };
1478 self.audit(
1479 "tool_requested",
1480 json!({
1481 "request_id": request_id,
1482 "iteration": iteration,
1483 "tool_name": tc.name,
1484 "tool_input_len": input_str.len(),
1485 }),
1486 )
1487 .await;
1488
1489 match self
1490 .execute_tool(
1491 &**tool, &tc.name, &input_str, ctx, request_id, iteration,
1492 )
1493 .await
1494 {
1495 Ok(result) => {
1496 failure_streak = 0;
1497 tool_history.push((
1498 tc.name.clone(),
1499 input_str,
1500 result.output.clone(),
1501 ));
1502 ToolResultMessage {
1503 tool_use_id: tc.id.clone(),
1504 content: result.output,
1505 is_error: false,
1506 }
1507 }
1508 Err(e) => {
1509 failure_streak += 1;
1510 ToolResultMessage {
1511 tool_use_id: tc.id.clone(),
1512 content: format!("Error: {e}"),
1513 is_error: true,
1514 }
1515 }
1516 }
1517 }
1518 None => {
1519 self.audit(
1520 "tool_not_found",
1521 json!({
1522 "request_id": request_id,
1523 "iteration": iteration,
1524 "tool_name": tc.name,
1525 }),
1526 )
1527 .await;
1528 failure_streak += 1;
1529 ToolResultMessage {
1530 tool_use_id: tc.id.clone(),
1531 content: format!(
1532 "Tool '{}' not found. Available tools: {}",
1533 tc.name,
1534 tool_definitions
1535 .iter()
1536 .map(|d| d.name.as_str())
1537 .collect::<Vec<_>>()
1538 .join(", ")
1539 ),
1540 is_error: true,
1541 }
1542 }
1543 };
1544 tool_results.push(result_msg);
1545 }
1546 }
1547
1548 for result in &tool_results {
1550 messages.push(ConversationMessage::ToolResult(result.clone()));
1551 }
1552
1553 if let Some(ref mut detector) = loop_detector {
1555 let tool_calls_this_iter = &chat_result.tool_calls;
1557 let mut worst_action = LoopAction::Continue;
1558 for tc in tool_calls_this_iter {
1559 let action = detector.check(
1560 &tc.name,
1561 &tc.input,
1562 ctx.current_tokens(),
1563 ctx.current_cost(),
1564 );
1565 if crate::loop_detection::severity(&action)
1566 > crate::loop_detection::severity(&worst_action)
1567 {
1568 worst_action = action;
1569 }
1570 }
1571
1572 match worst_action {
1573 LoopAction::Continue => {}
1574 LoopAction::InjectMessage(ref msg) => {
1575 warn!(
1576 request_id = %request_id,
1577 "tiered loop detection: injecting message"
1578 );
1579 self.audit(
1580 "loop_detection_inject",
1581 json!({
1582 "request_id": request_id,
1583 "iteration": iteration,
1584 "message": msg,
1585 }),
1586 )
1587 .await;
1588 messages.push(ConversationMessage::user(format!("SYSTEM NOTICE: {msg}")));
1589 }
1590 LoopAction::RestrictTools(ref tools) => {
1591 warn!(
1592 request_id = %request_id,
1593 tools = ?tools,
1594 "tiered loop detection: restricting tools"
1595 );
1596 self.audit(
1597 "loop_detection_restrict",
1598 json!({
1599 "request_id": request_id,
1600 "iteration": iteration,
1601 "restricted_tools": tools,
1602 }),
1603 )
1604 .await;
1605 restricted_tools.extend(tools.iter().cloned());
1606 messages.push(ConversationMessage::user(format!(
1607 "SYSTEM NOTICE: The following tools have been temporarily \
1608 restricted due to repetitive usage: {}. Try a different approach.",
1609 tools.join(", ")
1610 )));
1611 }
1612 LoopAction::ForceComplete(ref reason) => {
1613 warn!(
1614 request_id = %request_id,
1615 reason = %reason,
1616 "tiered loop detection: force completing"
1617 );
1618 self.audit(
1619 "loop_detection_force_complete",
1620 json!({
1621 "request_id": request_id,
1622 "iteration": iteration,
1623 "reason": reason,
1624 }),
1625 )
1626 .await;
1627 messages.push(ConversationMessage::user(format!(
1629 "SYSTEM NOTICE: {reason}. Provide your best answer now \
1630 without calling any more tools."
1631 )));
1632 truncate_messages(&mut messages, self.config.max_prompt_chars);
1633 let final_result = if let Some(ref sink) = stream_sink {
1634 self.provider
1635 .complete_streaming_with_tools(
1636 &messages,
1637 &[],
1638 &self.config.reasoning,
1639 sink.clone(),
1640 )
1641 .await
1642 .map_err(|source| AgentError::Provider { source })?
1643 } else {
1644 self.provider
1645 .complete_with_tools(&messages, &[], &self.config.reasoning)
1646 .await
1647 .map_err(|source| AgentError::Provider { source })?
1648 };
1649 let response_text = final_result.output_text;
1650 self.write_to_memory(
1651 "assistant",
1652 &response_text,
1653 request_id,
1654 ctx.source_channel.as_deref(),
1655 ctx.conversation_id.as_deref().unwrap_or(""),
1656 ctx.agent_id.as_deref(),
1657 )
1658 .await?;
1659 return Ok(AssistantMessage {
1660 text: response_text,
1661 });
1662 }
1663 }
1664 }
1665
1666 if self.config.loop_detection_no_progress_threshold > 0 {
1668 let threshold = self.config.loop_detection_no_progress_threshold;
1669 if tool_history.len() >= threshold {
1670 let recent = &tool_history[tool_history.len() - threshold..];
1671 let all_same = recent.iter().all(|entry| {
1672 entry.0 == recent[0].0 && entry.1 == recent[0].1 && entry.2 == recent[0].2
1673 });
1674 if all_same {
1675 warn!(
1676 tool_name = recent[0].0,
1677 threshold, "structured loop detection: no-progress threshold reached"
1678 );
1679 self.audit(
1680 "loop_detection_no_progress",
1681 json!({
1682 "request_id": request_id,
1683 "tool_name": recent[0].0,
1684 "threshold": threshold,
1685 }),
1686 )
1687 .await;
1688 messages.push(ConversationMessage::user(format!(
1689 "SYSTEM NOTICE: Loop detected — the tool `{}` was called \
1690 {} times with identical arguments and output. You are stuck \
1691 in a loop. Stop calling this tool and try a different approach, \
1692 or provide your best answer with the information you already have.",
1693 recent[0].0, threshold
1694 )));
1695 }
1696 }
1697 }
1698
1699 if self.config.loop_detection_ping_pong_cycles > 0 {
1701 let cycles = self.config.loop_detection_ping_pong_cycles;
1702 let needed = cycles * 2;
1703 if tool_history.len() >= needed {
1704 let recent = &tool_history[tool_history.len() - needed..];
1705 let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
1706 && recent[0].0 != recent[1].0;
1707 if is_ping_pong {
1708 warn!(
1709 tool_a = recent[0].0,
1710 tool_b = recent[1].0,
1711 cycles,
1712 "structured loop detection: ping-pong pattern detected"
1713 );
1714 self.audit(
1715 "loop_detection_ping_pong",
1716 json!({
1717 "request_id": request_id,
1718 "tool_a": recent[0].0,
1719 "tool_b": recent[1].0,
1720 "cycles": cycles,
1721 }),
1722 )
1723 .await;
1724 messages.push(ConversationMessage::user(format!(
1725 "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
1726 alternating for {} cycles in a ping-pong pattern. Stop this \
1727 alternation and try a different approach, or provide your best \
1728 answer with the information you already have.",
1729 recent[0].0, recent[1].0, cycles
1730 )));
1731 }
1732 }
1733 }
1734
1735 if self.config.loop_detection_failure_streak > 0
1737 && failure_streak >= self.config.loop_detection_failure_streak
1738 {
1739 warn!(
1740 failure_streak,
1741 "structured loop detection: consecutive failure streak"
1742 );
1743 self.audit(
1744 "loop_detection_failure_streak",
1745 json!({
1746 "request_id": request_id,
1747 "streak": failure_streak,
1748 }),
1749 )
1750 .await;
1751 messages.push(ConversationMessage::user(format!(
1752 "SYSTEM NOTICE: {} consecutive tool calls have failed. \
1753 Stop calling tools and provide your best answer with \
1754 the information you already have.",
1755 failure_streak
1756 )));
1757 }
1758 }
1759
1760 self.audit(
1762 "structured_tool_use_max_iterations",
1763 json!({
1764 "request_id": request_id,
1765 "max_tool_iterations": self.config.max_tool_iterations,
1766 }),
1767 )
1768 .await;
1769
1770 messages.push(ConversationMessage::user(
1771 "You have reached the maximum number of tool call iterations. \
1772 Please provide your best answer now without calling any more tools."
1773 .to_string(),
1774 ));
1775 truncate_messages(&mut messages, self.config.max_prompt_chars);
1776
1777 let final_result = if let Some(ref sink) = stream_sink {
1778 self.provider
1779 .complete_streaming_with_tools(&messages, &[], &self.config.reasoning, sink.clone())
1780 .await
1781 .map_err(|source| AgentError::Provider { source })?
1782 } else {
1783 self.provider
1784 .complete_with_tools(&messages, &[], &self.config.reasoning)
1785 .await
1786 .map_err(|source| AgentError::Provider { source })?
1787 };
1788
1789 let response_text = final_result.output_text;
1790 self.write_to_memory(
1791 "assistant",
1792 &response_text,
1793 request_id,
1794 ctx.source_channel.as_deref(),
1795 ctx.conversation_id.as_deref().unwrap_or(""),
1796 ctx.agent_id.as_deref(),
1797 )
1798 .await?;
1799 self.audit("respond_success", json!({"request_id": request_id}))
1800 .await;
1801 self.hook(
1802 "before_response_emit",
1803 json!({"request_id": request_id, "response_len": response_text.len()}),
1804 )
1805 .await?;
1806 let response = AssistantMessage {
1807 text: response_text,
1808 };
1809 self.hook(
1810 "after_response_emit",
1811 json!({"request_id": request_id, "response_len": response.text.len()}),
1812 )
1813 .await?;
1814 Ok(response)
1815 }
1816
1817 pub async fn respond(
1818 &self,
1819 user: UserMessage,
1820 ctx: &ToolContext,
1821 ) -> Result<AssistantMessage, AgentError> {
1822 let request_id = Self::next_request_id();
1823 let span = info_span!(
1824 "agent_run",
1825 request_id = %request_id,
1826 depth = ctx.depth,
1827 conversation_id = ctx.conversation_id.as_deref().unwrap_or(""),
1828 );
1829 self.respond_traced(&request_id, user, ctx, None)
1830 .instrument(span)
1831 .await
1832 }
1833
1834 pub async fn respond_streaming(
1838 &self,
1839 user: UserMessage,
1840 ctx: &ToolContext,
1841 sink: StreamSink,
1842 ) -> Result<AssistantMessage, AgentError> {
1843 let request_id = Self::next_request_id();
1844 let span = info_span!(
1845 "agent_run",
1846 request_id = %request_id,
1847 depth = ctx.depth,
1848 conversation_id = ctx.conversation_id.as_deref().unwrap_or(""),
1849 streaming = true,
1850 );
1851 self.respond_traced(&request_id, user, ctx, Some(sink))
1852 .instrument(span)
1853 .await
1854 }
1855
1856 async fn respond_traced(
1858 &self,
1859 request_id: &str,
1860 user: UserMessage,
1861 ctx: &ToolContext,
1862 stream_sink: Option<StreamSink>,
1863 ) -> Result<AssistantMessage, AgentError> {
1864 self.increment_counter("requests_total");
1865 let run_started = Instant::now();
1866 self.hook("before_run", json!({"request_id": request_id}))
1867 .await?;
1868 let timed = timeout(
1869 Duration::from_millis(self.config.request_timeout_ms),
1870 self.respond_inner(request_id, user, ctx, stream_sink),
1871 )
1872 .await;
1873 let result = match timed {
1874 Ok(result) => result,
1875 Err(_) => Err(AgentError::Timeout {
1876 timeout_ms: self.config.request_timeout_ms,
1877 }),
1878 };
1879
1880 let after_detail = match &result {
1881 Ok(response) => json!({
1882 "request_id": request_id,
1883 "status": "ok",
1884 "response_len": response.text.len(),
1885 "duration_ms": run_started.elapsed().as_millis(),
1886 }),
1887 Err(err) => json!({
1888 "request_id": request_id,
1889 "status": "error",
1890 "error": redact_text(&err.to_string()),
1891 "duration_ms": run_started.elapsed().as_millis(),
1892 }),
1893 };
1894 let total_cost = ctx.current_cost();
1895 let cost_usd = total_cost as f64 / 1_000_000.0;
1896 info!(
1897 request_id = %request_id,
1898 duration_ms = %run_started.elapsed().as_millis(),
1899 total_tokens = ctx.current_tokens(),
1900 cost_microdollars = total_cost,
1901 cost_usd = format!("{:.4}", cost_usd),
1902 "agent run completed"
1903 );
1904 self.hook("after_run", after_detail).await?;
1905 result
1906 }
1907
    /// Core request pipeline; runs inside the timeout applied by
    /// `respond_traced`.
    ///
    /// Flow: audit the incoming request, persist the user message, run the
    /// optional research pre-pass, then either delegate to the structured
    /// tool-use path (`respond_with_tools`) or drive the legacy
    /// `tool:`-prefixed prompt loop below before the final provider call.
    async fn respond_inner(
        &self,
        request_id: &str,
        user: UserMessage,
        ctx: &ToolContext,
        stream_sink: Option<StreamSink>,
    ) -> Result<AssistantMessage, AgentError> {
        self.audit(
            "respond_start",
            json!({
                "request_id": request_id,
                "user_message_len": user.text.len(),
                "max_tool_iterations": self.config.max_tool_iterations,
                "request_timeout_ms": self.config.request_timeout_ms,
            }),
        )
        .await;
        // Persist the user's message before doing any other work, so it is
        // recorded even if the rest of the pipeline errors out.
        self.write_to_memory(
            "user",
            &user.text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
            ctx.agent_id.as_deref(),
        )
        .await?;

        // Optional research pre-pass; its findings are folded into the final
        // provider prompt further down.
        let research_context = if self.should_research(&user.text) {
            self.run_research_phase(&user.text, ctx, request_id).await?
        } else {
            String::new()
        };

        // Preferred path: the model natively supports tool calling and tools
        // are registered — hand the whole turn to the structured loop.
        if self.config.model_supports_tool_use && self.has_tool_definitions() {
            return self
                .respond_with_tools(request_id, &user.text, &research_context, ctx, stream_sink)
                .await;
        }

        // Legacy path: tool invocations are expressed as "tool:<name> <input>"
        // prompts, chained until the prompt no longer starts with "tool:".
        let mut prompt = user.text;
        // (tool name, input, output) triples consumed by the loop detectors.
        let mut tool_history: Vec<(String, String, String)> = Vec::new();

        for iteration in 0..self.config.max_tool_iterations {
            if !prompt.starts_with("tool:") {
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Gated tools must run one at a time, so their presence disables
            // the parallel branch for this batch.
            let has_gated = calls
                .iter()
                .any(|(name, _)| self.config.gated_tools.contains(*name));
            if calls.len() > 1 && self.config.parallel_tools && !has_gated {
                // Resolve every requested tool up front; unknown names are
                // audited and dropped from the batch.
                let mut resolved: Vec<(&str, &str, &dyn Tool)> = Vec::new();
                for &(name, input) in &calls {
                    let tool = self.tools.iter().find(|t| t.name() == name);
                    match tool {
                        Some(t) => resolved.push((name, input, &**t)),
                        None => {
                            self.audit(
                                "tool_not_found",
                                json!({"request_id": request_id, "iteration": iteration, "tool_name": name}),
                            )
                            .await;
                        }
                    }
                }
                if resolved.is_empty() {
                    break;
                }

                // Execute the whole batch concurrently and join on all results.
                let futs: Vec<_> = resolved
                    .iter()
                    .map(|&(name, input, tool)| async move {
                        let r = tool.execute(input, ctx).await;
                        (name, input, r)
                    })
                    .collect();
                let results = futures_util::future::join_all(futs).await;

                let mut output_parts: Vec<String> = Vec::new();
                for (name, input, result) in results {
                    match result {
                        Ok(r) => {
                            tool_history.push((
                                name.to_string(),
                                input.to_string(),
                                r.output.clone(),
                            ));
                            output_parts.push(format!("Tool output from {name}: {}", r.output));
                        }
                        Err(source) => {
                            // Any failure inside a parallel batch aborts the run.
                            return Err(AgentError::Tool {
                                tool: name.to_string(),
                                source,
                            });
                        }
                    }
                }
                prompt = output_parts.join("\n");
                continue;
            }

            // Sequential branch: only the first parsed call is executed.
            let (tool_name, tool_input) = calls[0];
            self.audit(
                "tool_requested",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "tool_name": tool_name,
                    "tool_input_len": tool_input.len(),
                }),
            )
            .await;

            if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
                let result = self
                    .execute_tool(&**tool, tool_name, tool_input, ctx, request_id, iteration)
                    .await?;

                tool_history.push((
                    tool_name.to_string(),
                    tool_input.to_string(),
                    result.output.clone(),
                ));

                // No-progress detector: the same tool called `threshold`
                // consecutive times with identical input AND output.
                if self.config.loop_detection_no_progress_threshold > 0 {
                    let threshold = self.config.loop_detection_no_progress_threshold;
                    if tool_history.len() >= threshold {
                        let recent = &tool_history[tool_history.len() - threshold..];
                        let all_same = recent.iter().all(|entry| {
                            entry.0 == recent[0].0
                                && entry.1 == recent[0].1
                                && entry.2 == recent[0].2
                        });
                        if all_same {
                            warn!(
                                tool_name,
                                threshold, "loop detection: no-progress threshold reached"
                            );
                            self.audit(
                                "loop_detection_no_progress",
                                json!({
                                    "request_id": request_id,
                                    "tool_name": tool_name,
                                    "threshold": threshold,
                                }),
                            )
                            .await;
                            // Replace the prompt with a notice and fall
                            // through to the final provider call.
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — the tool `{tool_name}` was called \
                                 {threshold} times with identical arguments and output. You are stuck \
                                 in a loop. Stop calling this tool and try a different approach, or \
                                 provide your best answer with the information you already have."
                            );
                            break;
                        }
                    }
                }

                // Ping-pong detector: two distinct tools strictly alternating
                // for the configured number of cycles.
                if self.config.loop_detection_ping_pong_cycles > 0 {
                    let cycles = self.config.loop_detection_ping_pong_cycles;
                    let needed = cycles * 2;
                    if tool_history.len() >= needed {
                        let recent = &tool_history[tool_history.len() - needed..];
                        let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
                            && recent[0].0 != recent[1].0;
                        if is_ping_pong {
                            warn!(
                                tool_a = recent[0].0,
                                tool_b = recent[1].0,
                                cycles,
                                "loop detection: ping-pong pattern detected"
                            );
                            self.audit(
                                "loop_detection_ping_pong",
                                json!({
                                    "request_id": request_id,
                                    "tool_a": recent[0].0,
                                    "tool_b": recent[1].0,
                                    "cycles": cycles,
                                }),
                            )
                            .await;
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
                                 alternating for {} cycles in a ping-pong pattern. Stop this \
                                 alternation and try a different approach, or provide your best \
                                 answer with the information you already have.",
                                recent[0].0, recent[1].0, cycles
                            );
                            break;
                        }
                    }
                }

                // A "tool:"-prefixed output chains straight into the next
                // iteration; anything else becomes the provider prompt.
                if result.output.starts_with("tool:") {
                    prompt = result.output.clone();
                } else {
                    prompt = format!("Tool output from {tool_name}: {}", result.output);
                }
                continue;
            }

            // Unknown tool in the sequential branch: audit and stop iterating.
            self.audit(
                "tool_not_found",
                json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
            )
            .await;
            break;
        }

        if !research_context.is_empty() {
            prompt = format!("Research findings:\n{research_context}\n\nUser request:\n{prompt}");
        }

        // Final provider call, then persist/audit the assistant reply and run
        // the response-emit hooks.
        let response_text = self
            .call_provider_with_context(
                &prompt,
                request_id,
                stream_sink,
                ctx.source_channel.as_deref(),
            )
            .await?;
        self.write_to_memory(
            "assistant",
            &response_text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
            ctx.agent_id.as_deref(),
        )
        .await?;
        self.audit("respond_success", json!({"request_id": request_id}))
            .await;

        self.hook(
            "before_response_emit",
            json!({"request_id": request_id, "response_len": response_text.len()}),
        )
        .await?;
        let response = AssistantMessage {
            text: response_text,
        };
        self.hook(
            "after_response_emit",
            json!({"request_id": request_id, "response_len": response.text.len()}),
        )
        .await?;

        Ok(response)
    }
2170}
2171
2172#[cfg(test)]
2173mod tests {
2174 use super::*;
2175 use crate::types::{
2176 ChatResult, HookEvent, HookFailureMode, HookPolicy, ReasoningConfig, ResearchPolicy,
2177 ResearchTrigger, ToolResult,
2178 };
2179 use async_trait::async_trait;
2180 use serde_json::Value;
2181 use std::collections::HashMap;
2182 use std::sync::atomic::{AtomicUsize, Ordering};
2183 use std::sync::{Arc, Mutex};
2184 use tokio::time::{sleep, Duration};
2185
2186 #[derive(Default)]
2187 struct TestMemory {
2188 entries: Arc<Mutex<Vec<MemoryEntry>>>,
2189 }
2190
2191 #[async_trait]
2192 impl MemoryStore for TestMemory {
2193 async fn append(&self, entry: MemoryEntry) -> anyhow::Result<()> {
2194 self.entries
2195 .lock()
2196 .expect("memory lock poisoned")
2197 .push(entry);
2198 Ok(())
2199 }
2200
2201 async fn recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
2202 let entries = self.entries.lock().expect("memory lock poisoned");
2203 Ok(entries.iter().rev().take(limit).cloned().collect())
2204 }
2205 }
2206
2207 struct TestProvider {
2208 received_prompts: Arc<Mutex<Vec<String>>>,
2209 response_text: String,
2210 }
2211
2212 #[async_trait]
2213 impl Provider for TestProvider {
2214 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
2215 self.received_prompts
2216 .lock()
2217 .expect("provider lock poisoned")
2218 .push(prompt.to_string());
2219 Ok(ChatResult {
2220 output_text: self.response_text.clone(),
2221 ..Default::default()
2222 })
2223 }
2224 }
2225
2226 struct EchoTool;
2227
2228 #[async_trait]
2229 impl Tool for EchoTool {
2230 fn name(&self) -> &'static str {
2231 "echo"
2232 }
2233
2234 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2235 Ok(ToolResult {
2236 output: format!("echoed:{input}"),
2237 })
2238 }
2239 }
2240
2241 struct PluginEchoTool;
2242
2243 #[async_trait]
2244 impl Tool for PluginEchoTool {
2245 fn name(&self) -> &'static str {
2246 "plugin:echo"
2247 }
2248
2249 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2250 Ok(ToolResult {
2251 output: format!("plugin-echoed:{input}"),
2252 })
2253 }
2254 }
2255
2256 struct FailingTool;
2257
2258 #[async_trait]
2259 impl Tool for FailingTool {
2260 fn name(&self) -> &'static str {
2261 "boom"
2262 }
2263
2264 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2265 Err(anyhow::anyhow!("tool exploded"))
2266 }
2267 }
2268
2269 struct SlowTool;
2270
2271 #[async_trait]
2272 impl Tool for SlowTool {
2273 fn name(&self) -> &'static str {
2274 "slow"
2275 }
2276
2277 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2278 sleep(Duration::from_millis(500)).await;
2279 Ok(ToolResult {
2280 output: "finally done".to_string(),
2281 })
2282 }
2283 }
2284
2285 struct FailingProvider;
2286
2287 #[async_trait]
2288 impl Provider for FailingProvider {
2289 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
2290 Err(anyhow::anyhow!("provider boom"))
2291 }
2292 }
2293
2294 struct SlowProvider;
2295
2296 #[async_trait]
2297 impl Provider for SlowProvider {
2298 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
2299 sleep(Duration::from_millis(500)).await;
2300 Ok(ChatResult {
2301 output_text: "late".to_string(),
2302 ..Default::default()
2303 })
2304 }
2305 }
2306
2307 struct ScriptedProvider {
2308 responses: Vec<String>,
2309 call_count: AtomicUsize,
2310 received_prompts: Arc<Mutex<Vec<String>>>,
2311 }
2312
2313 impl ScriptedProvider {
2314 fn new(responses: Vec<&str>) -> Self {
2315 Self {
2316 responses: responses.into_iter().map(|s| s.to_string()).collect(),
2317 call_count: AtomicUsize::new(0),
2318 received_prompts: Arc::new(Mutex::new(Vec::new())),
2319 }
2320 }
2321 }
2322
2323 #[async_trait]
2324 impl Provider for ScriptedProvider {
2325 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
2326 self.received_prompts
2327 .lock()
2328 .expect("provider lock poisoned")
2329 .push(prompt.to_string());
2330 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
2331 let response = if idx < self.responses.len() {
2332 self.responses[idx].clone()
2333 } else {
2334 self.responses.last().cloned().unwrap_or_default()
2335 };
2336 Ok(ChatResult {
2337 output_text: response,
2338 ..Default::default()
2339 })
2340 }
2341 }
2342
2343 struct UpperTool;
2344
2345 #[async_trait]
2346 impl Tool for UpperTool {
2347 fn name(&self) -> &'static str {
2348 "upper"
2349 }
2350
2351 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2352 Ok(ToolResult {
2353 output: input.to_uppercase(),
2354 })
2355 }
2356 }
2357
2358 struct LoopTool;
2361
2362 #[async_trait]
2363 impl Tool for LoopTool {
2364 fn name(&self) -> &'static str {
2365 "loop_tool"
2366 }
2367
2368 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2369 Ok(ToolResult {
2370 output: "tool:loop_tool x".to_string(),
2371 })
2372 }
2373 }
2374
2375 struct PingTool;
2377
2378 #[async_trait]
2379 impl Tool for PingTool {
2380 fn name(&self) -> &'static str {
2381 "ping"
2382 }
2383
2384 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2385 Ok(ToolResult {
2386 output: "tool:pong x".to_string(),
2387 })
2388 }
2389 }
2390
2391 struct PongTool;
2393
2394 #[async_trait]
2395 impl Tool for PongTool {
2396 fn name(&self) -> &'static str {
2397 "pong"
2398 }
2399
2400 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
2401 Ok(ToolResult {
2402 output: "tool:ping x".to_string(),
2403 })
2404 }
2405 }
2406
2407 struct RecordingHookSink {
2408 events: Arc<Mutex<Vec<String>>>,
2409 }
2410
2411 #[async_trait]
2412 impl HookSink for RecordingHookSink {
2413 async fn record(&self, event: HookEvent) -> anyhow::Result<()> {
2414 self.events
2415 .lock()
2416 .expect("hook lock poisoned")
2417 .push(event.stage);
2418 Ok(())
2419 }
2420 }
2421
2422 struct FailingHookSink;
2423
2424 #[async_trait]
2425 impl HookSink for FailingHookSink {
2426 async fn record(&self, _event: HookEvent) -> anyhow::Result<()> {
2427 Err(anyhow::anyhow!("hook sink failure"))
2428 }
2429 }
2430
2431 struct RecordingAuditSink {
2432 events: Arc<Mutex<Vec<AuditEvent>>>,
2433 }
2434
2435 #[async_trait]
2436 impl AuditSink for RecordingAuditSink {
2437 async fn record(&self, event: AuditEvent) -> anyhow::Result<()> {
2438 self.events.lock().expect("audit lock poisoned").push(event);
2439 Ok(())
2440 }
2441 }
2442
2443 struct RecordingMetricsSink {
2444 counters: Arc<Mutex<HashMap<&'static str, u64>>>,
2445 histograms: Arc<Mutex<HashMap<&'static str, usize>>>,
2446 }
2447
2448 impl MetricsSink for RecordingMetricsSink {
2449 fn increment_counter(&self, name: &'static str, value: u64) {
2450 let mut counters = self.counters.lock().expect("metrics lock poisoned");
2451 *counters.entry(name).or_insert(0) += value;
2452 }
2453
2454 fn observe_histogram(&self, name: &'static str, _value: f64) {
2455 let mut histograms = self.histograms.lock().expect("metrics lock poisoned");
2456 *histograms.entry(name).or_insert(0) += 1;
2457 }
2458 }
2459
2460 fn test_ctx() -> ToolContext {
2461 ToolContext::new(".".to_string())
2462 }
2463
2464 fn counter(counters: &Arc<Mutex<HashMap<&'static str, u64>>>, name: &'static str) -> u64 {
2465 counters
2466 .lock()
2467 .expect("metrics lock poisoned")
2468 .get(name)
2469 .copied()
2470 .unwrap_or(0)
2471 }
2472
2473 fn histogram_count(
2474 histograms: &Arc<Mutex<HashMap<&'static str, usize>>>,
2475 name: &'static str,
2476 ) -> usize {
2477 histograms
2478 .lock()
2479 .expect("metrics lock poisoned")
2480 .get(name)
2481 .copied()
2482 .unwrap_or(0)
2483 }
2484
2485 #[tokio::test]
2486 async fn respond_appends_user_then_assistant_memory_entries() {
2487 let entries = Arc::new(Mutex::new(Vec::new()));
2488 let prompts = Arc::new(Mutex::new(Vec::new()));
2489 let memory = TestMemory {
2490 entries: entries.clone(),
2491 };
2492 let provider = TestProvider {
2493 received_prompts: prompts,
2494 response_text: "assistant-output".to_string(),
2495 };
2496 let agent = Agent::new(
2497 AgentConfig::default(),
2498 Box::new(provider),
2499 Box::new(memory),
2500 vec![],
2501 );
2502
2503 let response = agent
2504 .respond(
2505 UserMessage {
2506 text: "hello world".to_string(),
2507 },
2508 &test_ctx(),
2509 )
2510 .await
2511 .expect("agent respond should succeed");
2512
2513 assert_eq!(response.text, "assistant-output");
2514 let stored = entries.lock().expect("memory lock poisoned");
2515 assert_eq!(stored.len(), 2);
2516 assert_eq!(stored[0].role, "user");
2517 assert_eq!(stored[0].content, "hello world");
2518 assert_eq!(stored[1].role, "assistant");
2519 assert_eq!(stored[1].content, "assistant-output");
2520 }
2521
/// A `tool:<name> <args>` prompt must run the matching tool and feed its
/// output into the provider prompt as the current input.
#[tokio::test]
async fn respond_invokes_tool_for_tool_prefixed_prompt() {
    let entries = Arc::new(Mutex::new(Vec::new()));
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory {
        entries: entries.clone(),
    };
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "assistant-after-tool".to_string(),
    };
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(EchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "tool:echo ping".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("agent respond should succeed");

    assert_eq!(response.text, "assistant-after-tool");
    let prompts = prompts.lock().expect("provider lock poisoned");
    assert_eq!(prompts.len(), 1);
    // The provider prompt carries the tool result, not the raw user text.
    assert!(prompts[0].contains("Recent conversation:"));
    assert!(prompts[0].contains("Current input:\nTool output from echo: echoed:ping"));
}
2556
/// An unknown tool name must not fail the request: the raw user text is
/// forwarded to the provider unchanged.
#[tokio::test]
async fn respond_with_unknown_tool_falls_back_to_provider_prompt() {
    let entries = Arc::new(Mutex::new(Vec::new()));
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory {
        entries: entries.clone(),
    };
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "assistant-without-tool".to_string(),
    };
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(EchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "tool:unknown payload".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("agent respond should succeed");

    assert_eq!(response.text, "assistant-without-tool");
    let prompts = prompts.lock().expect("provider lock poisoned");
    assert_eq!(prompts.len(), 1);
    // Fallback keeps the original "tool:unknown payload" text as current input.
    assert!(prompts[0].contains("Recent conversation:"));
    assert!(prompts[0].contains("Current input:\ntool:unknown payload"));
}
2591
/// With `memory_window_size: 2`, only the newest entries (the pre-existing
/// "recent-before-request" plus the just-appended user message) reach the
/// provider prompt; older memory is excluded.
#[tokio::test]
async fn respond_includes_bounded_recent_memory_in_provider_prompt() {
    let entries = Arc::new(Mutex::new(vec![
        MemoryEntry {
            role: "assistant".to_string(),
            content: "very-old".to_string(),
            ..Default::default()
        },
        MemoryEntry {
            role: "user".to_string(),
            content: "recent-before-request".to_string(),
            ..Default::default()
        },
    ]));
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory {
        entries: entries.clone(),
    };
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "ok".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            memory_window_size: 2,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![],
    );

    agent
        .respond(
            UserMessage {
                text: "latest-user".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");

    let prompts = prompts.lock().expect("provider lock poisoned");
    let provider_prompt = prompts.first().expect("provider prompt should exist");
    assert!(provider_prompt.contains("- user: recent-before-request"));
    assert!(provider_prompt.contains("- user: latest-user"));
    // "very-old" falls outside the 2-entry window and must be dropped.
    assert!(!provider_prompt.contains("very-old"));
}
2640
/// The assembled provider prompt is capped at `max_prompt_chars`, and the cap
/// keeps the tail of the prompt — so the latest user input survives truncation.
#[tokio::test]
async fn respond_caps_provider_prompt_to_configured_max_chars() {
    let entries = Arc::new(Mutex::new(vec![MemoryEntry {
        role: "assistant".to_string(),
        // 16 * 17 chars of history guarantees the 64-char cap is exceeded.
        content: "historic context ".repeat(16),
        ..Default::default()
    }]));
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory {
        entries: entries.clone(),
    };
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "ok".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            max_prompt_chars: 64,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![],
    );

    agent
        .respond(
            UserMessage {
                text: "final-tail-marker".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");

    let prompts = prompts.lock().expect("provider lock poisoned");
    let provider_prompt = prompts.first().expect("provider prompt should exist");
    assert!(provider_prompt.chars().count() <= 64);
    assert!(provider_prompt.contains("final-tail-marker"));
}
2681
/// A failing provider surfaces as `AgentError::Provider`, and metrics record
/// the request, the provider error, and one provider-latency observation.
#[tokio::test]
async fn respond_returns_provider_typed_error() {
    let memory = TestMemory::default();
    let counters = Arc::new(Mutex::new(HashMap::new()));
    let histograms = Arc::new(Mutex::new(HashMap::new()));
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(FailingProvider),
        Box::new(memory),
        vec![],
    )
    .with_metrics(Box::new(RecordingMetricsSink {
        counters: counters.clone(),
        histograms: histograms.clone(),
    }));

    let result = agent
        .respond(
            UserMessage {
                text: "hello".to_string(),
            },
            &test_ctx(),
        )
        .await;

    match result {
        Err(AgentError::Provider { source }) => {
            assert!(source.to_string().contains("provider boom"));
        }
        other => panic!("expected provider error, got {other:?}"),
    }
    assert_eq!(counter(&counters, "requests_total"), 1);
    assert_eq!(counter(&counters, "provider_errors_total"), 1);
    assert_eq!(counter(&counters, "tool_errors_total"), 0);
    // Latency is still observed even when the provider call fails.
    assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
}
2718
/// A failing tool surfaces as `AgentError::Tool` carrying the tool name, and
/// metrics record a tool error (not a provider error) plus tool latency.
#[tokio::test]
async fn respond_increments_tool_error_counter_on_tool_failure() {
    let memory = TestMemory::default();
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let provider = TestProvider {
        received_prompts: prompts,
        response_text: "unused".to_string(),
    };
    let counters = Arc::new(Mutex::new(HashMap::new()));
    let histograms = Arc::new(Mutex::new(HashMap::new()));
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(FailingTool)],
    )
    .with_metrics(Box::new(RecordingMetricsSink {
        counters: counters.clone(),
        histograms: histograms.clone(),
    }));

    let result = agent
        .respond(
            UserMessage {
                text: "tool:boom ping".to_string(),
            },
            &test_ctx(),
        )
        .await;

    match result {
        Err(AgentError::Tool { tool, source }) => {
            assert_eq!(tool, "boom");
            assert!(source.to_string().contains("tool exploded"));
        }
        other => panic!("expected tool error, got {other:?}"),
    }
    assert_eq!(counter(&counters, "requests_total"), 1);
    assert_eq!(counter(&counters, "provider_errors_total"), 0);
    assert_eq!(counter(&counters, "tool_errors_total"), 1);
    // Tool latency is observed even on the failing path.
    assert_eq!(histogram_count(&histograms, "tool_latency_ms"), 1);
}
2761
/// The happy path records exactly one request, no error counters, and one
/// provider-latency observation.
#[tokio::test]
async fn respond_increments_requests_counter_on_success() {
    let memory = TestMemory::default();
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let provider = TestProvider {
        received_prompts: prompts,
        response_text: "ok".to_string(),
    };
    let counters = Arc::new(Mutex::new(HashMap::new()));
    let histograms = Arc::new(Mutex::new(HashMap::new()));
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(memory),
        vec![],
    )
    .with_metrics(Box::new(RecordingMetricsSink {
        counters: counters.clone(),
        histograms: histograms.clone(),
    }));

    let response = agent
        .respond(
            UserMessage {
                text: "hello".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("response should succeed");

    assert_eq!(response.text, "ok");
    assert_eq!(counter(&counters, "requests_total"), 1);
    assert_eq!(counter(&counters, "provider_errors_total"), 0);
    assert_eq!(counter(&counters, "tool_errors_total"), 0);
    assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
}
2799
2800 #[tokio::test]
2801 async fn respond_returns_timeout_typed_error() {
2802 let memory = TestMemory::default();
2803 let agent = Agent::new(
2804 AgentConfig {
2805 max_tool_iterations: 1,
2806 request_timeout_ms: 10,
2807 ..AgentConfig::default()
2808 },
2809 Box::new(SlowProvider),
2810 Box::new(memory),
2811 vec![],
2812 );
2813
2814 let result = agent
2815 .respond(
2816 UserMessage {
2817 text: "hello".to_string(),
2818 },
2819 &test_ctx(),
2820 )
2821 .await;
2822
2823 match result {
2824 Err(AgentError::Timeout { timeout_ms }) => assert_eq!(timeout_ms, 10),
2825 other => panic!("expected timeout error, got {other:?}"),
2826 }
2827 }
2828
/// With hooks enabled, a `tool:plugin:echo` request must emit the full
/// before/after pair for every lifecycle stage: run, tool call, plugin call,
/// provider call, memory write, and response emit.
#[tokio::test]
async fn respond_emits_before_after_hook_events_when_enabled() {
    let events = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory::default();
    let provider = TestProvider {
        received_prompts: Arc::new(Mutex::new(Vec::new())),
        response_text: "ok".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            hooks: HookPolicy {
                enabled: true,
                timeout_ms: 50,
                fail_closed: true,
                ..HookPolicy::default()
            },
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(EchoTool), Box::new(PluginEchoTool)],
    )
    .with_hooks(Box::new(RecordingHookSink {
        events: events.clone(),
    }));

    agent
        .respond(
            UserMessage {
                text: "tool:plugin:echo ping".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");

    let stages = events.lock().expect("hook lock poisoned");
    assert!(stages.contains(&"before_run".to_string()));
    assert!(stages.contains(&"after_run".to_string()));
    assert!(stages.contains(&"before_tool_call".to_string()));
    assert!(stages.contains(&"after_tool_call".to_string()));
    assert!(stages.contains(&"before_plugin_call".to_string()));
    assert!(stages.contains(&"after_plugin_call".to_string()));
    assert!(stages.contains(&"before_provider_call".to_string()));
    assert!(stages.contains(&"after_provider_call".to_string()));
    assert!(stages.contains(&"before_memory_write".to_string()));
    assert!(stages.contains(&"after_memory_write".to_string()));
    assert!(stages.contains(&"before_response_emit".to_string()));
    assert!(stages.contains(&"after_response_emit".to_string()));
}
2879
/// Audit events for provider and tool successes must share one request_id and
/// each carry a `duration_ms` value.
#[tokio::test]
async fn respond_audit_events_include_request_id_and_durations() {
    let audit_events = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory::default();
    let provider = TestProvider {
        received_prompts: Arc::new(Mutex::new(Vec::new())),
        response_text: "ok".to_string(),
    };
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(EchoTool)],
    )
    .with_audit(Box::new(RecordingAuditSink {
        events: audit_events.clone(),
    }));

    agent
        .respond(
            UserMessage {
                text: "tool:echo ping".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");

    let events = audit_events.lock().expect("audit lock poisoned");
    // Any event's request_id serves as the reference value for the others.
    let request_id = events
        .iter()
        .find_map(|e| {
            e.detail
                .get("request_id")
                .and_then(Value::as_str)
                .map(ToString::to_string)
        })
        .expect("request_id should exist on audit events");

    let provider_event = events
        .iter()
        .find(|e| e.stage == "provider_call_success")
        .expect("provider success event should exist");
    assert_eq!(
        provider_event
            .detail
            .get("request_id")
            .and_then(Value::as_str),
        Some(request_id.as_str())
    );
    assert!(provider_event
        .detail
        .get("duration_ms")
        .and_then(Value::as_u64)
        .is_some());

    let tool_event = events
        .iter()
        .find(|e| e.stage == "tool_execute_success")
        .expect("tool success event should exist");
    assert_eq!(
        tool_event.detail.get("request_id").and_then(Value::as_str),
        Some(request_id.as_str())
    );
    assert!(tool_event
        .detail
        .get("duration_ms")
        .and_then(Value::as_u64)
        .is_some());
}
2950
/// When the high-tier failure mode is `Block`, a failing hook sink must abort
/// the request with `AgentError::Hook` at the first high-tier stage
/// ("before_tool_call").
#[tokio::test]
async fn hook_errors_respect_block_mode_for_high_tier_negative_path() {
    let memory = TestMemory::default();
    let provider = TestProvider {
        received_prompts: Arc::new(Mutex::new(Vec::new())),
        response_text: "ok".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            hooks: HookPolicy {
                enabled: true,
                timeout_ms: 50,
                fail_closed: false,
                default_mode: HookFailureMode::Warn,
                low_tier_mode: HookFailureMode::Ignore,
                medium_tier_mode: HookFailureMode::Warn,
                high_tier_mode: HookFailureMode::Block,
            },
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(EchoTool)],
    )
    .with_hooks(Box::new(FailingHookSink));

    let result = agent
        .respond(
            UserMessage {
                text: "tool:echo ping".to_string(),
            },
            &test_ctx(),
        )
        .await;

    match result {
        Err(AgentError::Hook { stage, .. }) => assert_eq!(stage, "before_tool_call"),
        other => panic!("expected hook block error, got {other:?}"),
    }
}
2991
/// When the high-tier failure mode is `Warn`, hook failures do not abort the
/// request; instead a "hook_error_warn" audit event is recorded and the
/// response still completes.
#[tokio::test]
async fn hook_errors_respect_warn_mode_for_high_tier_success_path() {
    let audit_events = Arc::new(Mutex::new(Vec::new()));
    let memory = TestMemory::default();
    let provider = TestProvider {
        received_prompts: Arc::new(Mutex::new(Vec::new())),
        response_text: "ok".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            hooks: HookPolicy {
                enabled: true,
                timeout_ms: 50,
                fail_closed: false,
                default_mode: HookFailureMode::Warn,
                low_tier_mode: HookFailureMode::Ignore,
                medium_tier_mode: HookFailureMode::Warn,
                high_tier_mode: HookFailureMode::Warn,
            },
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(EchoTool)],
    )
    .with_hooks(Box::new(FailingHookSink))
    .with_audit(Box::new(RecordingAuditSink {
        events: audit_events.clone(),
    }));

    let response = agent
        .respond(
            UserMessage {
                text: "tool:echo ping".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("warn mode should continue");
    assert!(!response.text.is_empty());

    let events = audit_events.lock().expect("audit lock poisoned");
    assert!(events.iter().any(|event| event.stage == "hook_error_warn"));
}
3036
3037 #[tokio::test]
3040 async fn self_correction_injected_on_no_progress() {
3041 let prompts = Arc::new(Mutex::new(Vec::new()));
3042 let provider = TestProvider {
3043 received_prompts: prompts.clone(),
3044 response_text: "I'll try a different approach".to_string(),
3045 };
3046 let agent = Agent::new(
3049 AgentConfig {
3050 loop_detection_no_progress_threshold: 3,
3051 max_tool_iterations: 20,
3052 ..AgentConfig::default()
3053 },
3054 Box::new(provider),
3055 Box::new(TestMemory::default()),
3056 vec![Box::new(LoopTool)],
3057 );
3058
3059 let response = agent
3060 .respond(
3061 UserMessage {
3062 text: "tool:loop_tool x".to_string(),
3063 },
3064 &test_ctx(),
3065 )
3066 .await
3067 .expect("self-correction should succeed");
3068
3069 assert_eq!(response.text, "I'll try a different approach");
3070
3071 let received = prompts.lock().expect("provider lock poisoned");
3073 let last_prompt = received.last().expect("provider should have been called");
3074 assert!(
3075 last_prompt.contains("Loop detected"),
3076 "provider prompt should contain self-correction notice, got: {last_prompt}"
3077 );
3078 assert!(last_prompt.contains("loop_tool"));
3079 }
3080
/// Alternating ping/pong tool calls for `loop_detection_ping_pong_cycles`
/// cycles must inject a "ping-pong" self-correction notice into the prompt.
#[tokio::test]
async fn self_correction_injected_on_ping_pong() {
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "I stopped the loop".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            loop_detection_ping_pong_cycles: 2,
            // Disable no-progress detection so only ping-pong detection fires.
            loop_detection_no_progress_threshold: 0,
            max_tool_iterations: 20,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(PingTool), Box::new(PongTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "tool:ping x".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("self-correction should succeed");

    assert_eq!(response.text, "I stopped the loop");

    let received = prompts.lock().expect("provider lock poisoned");
    let last_prompt = received.last().expect("provider should have been called");
    assert!(
        last_prompt.contains("ping-pong"),
        "provider prompt should contain ping-pong notice, got: {last_prompt}"
    );
}
3120
3121 #[tokio::test]
3122 async fn no_progress_disabled_when_threshold_zero() {
3123 let prompts = Arc::new(Mutex::new(Vec::new()));
3124 let provider = TestProvider {
3125 received_prompts: prompts.clone(),
3126 response_text: "final answer".to_string(),
3127 };
3128 let agent = Agent::new(
3131 AgentConfig {
3132 loop_detection_no_progress_threshold: 0,
3133 loop_detection_ping_pong_cycles: 0,
3134 max_tool_iterations: 5,
3135 ..AgentConfig::default()
3136 },
3137 Box::new(provider),
3138 Box::new(TestMemory::default()),
3139 vec![Box::new(LoopTool)],
3140 );
3141
3142 let response = agent
3143 .respond(
3144 UserMessage {
3145 text: "tool:loop_tool x".to_string(),
3146 },
3147 &test_ctx(),
3148 )
3149 .await
3150 .expect("should complete without error");
3151
3152 assert_eq!(response.text, "final answer");
3153
3154 let received = prompts.lock().expect("provider lock poisoned");
3156 let last_prompt = received.last().expect("provider should have been called");
3157 assert!(
3158 !last_prompt.contains("Loop detected"),
3159 "should not contain self-correction when detection is disabled"
3160 );
3161 }
3162
3163 #[tokio::test]
3166 async fn parallel_tools_executes_multiple_calls() {
3167 let prompts = Arc::new(Mutex::new(Vec::new()));
3168 let provider = TestProvider {
3169 received_prompts: prompts.clone(),
3170 response_text: "parallel done".to_string(),
3171 };
3172 let agent = Agent::new(
3173 AgentConfig {
3174 parallel_tools: true,
3175 max_tool_iterations: 5,
3176 ..AgentConfig::default()
3177 },
3178 Box::new(provider),
3179 Box::new(TestMemory::default()),
3180 vec![Box::new(EchoTool), Box::new(UpperTool)],
3181 );
3182
3183 let response = agent
3184 .respond(
3185 UserMessage {
3186 text: "tool:echo hello\ntool:upper world".to_string(),
3187 },
3188 &test_ctx(),
3189 )
3190 .await
3191 .expect("parallel tools should succeed");
3192
3193 assert_eq!(response.text, "parallel done");
3194
3195 let received = prompts.lock().expect("provider lock poisoned");
3196 let last_prompt = received.last().expect("provider should have been called");
3197 assert!(
3198 last_prompt.contains("echoed:hello"),
3199 "should contain echo output, got: {last_prompt}"
3200 );
3201 assert!(
3202 last_prompt.contains("WORLD"),
3203 "should contain upper output, got: {last_prompt}"
3204 );
3205 }
3206
/// Parallel execution must preserve request order in the assembled prompt:
/// the first call's output appears before the second's, regardless of which
/// tool finished first.
#[tokio::test]
async fn parallel_tools_maintains_stable_ordering() {
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "ordered".to_string(),
    };
    let agent = Agent::new(
        AgentConfig {
            parallel_tools: true,
            max_tool_iterations: 5,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(EchoTool), Box::new(UpperTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "tool:echo aaa\ntool:upper bbb".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("parallel tools should succeed");

    assert_eq!(response.text, "ordered");

    let received = prompts.lock().expect("provider lock poisoned");
    let last_prompt = received.last().expect("provider should have been called");
    let echo_pos = last_prompt
        .find("echoed:aaa")
        .expect("echo output should exist");
    let upper_pos = last_prompt.find("BBB").expect("upper output should exist");
    assert!(
        echo_pos < upper_pos,
        "echo output should come before upper output in the prompt"
    );
}
3249
3250 #[tokio::test]
3251 async fn parallel_disabled_runs_first_call_only() {
3252 let prompts = Arc::new(Mutex::new(Vec::new()));
3253 let provider = TestProvider {
3254 received_prompts: prompts.clone(),
3255 response_text: "sequential".to_string(),
3256 };
3257 let agent = Agent::new(
3258 AgentConfig {
3259 parallel_tools: false,
3260 max_tool_iterations: 5,
3261 ..AgentConfig::default()
3262 },
3263 Box::new(provider),
3264 Box::new(TestMemory::default()),
3265 vec![Box::new(EchoTool), Box::new(UpperTool)],
3266 );
3267
3268 let response = agent
3269 .respond(
3270 UserMessage {
3271 text: "tool:echo hello\ntool:upper world".to_string(),
3272 },
3273 &test_ctx(),
3274 )
3275 .await
3276 .expect("sequential tools should succeed");
3277
3278 assert_eq!(response.text, "sequential");
3279
3280 let received = prompts.lock().expect("provider lock poisoned");
3282 let last_prompt = received.last().expect("provider should have been called");
3283 assert!(
3284 last_prompt.contains("echoed:hello"),
3285 "should contain first tool output, got: {last_prompt}"
3286 );
3287 assert!(
3288 !last_prompt.contains("WORLD"),
3289 "should NOT contain second tool output when parallel is disabled"
3290 );
3291 }
3292
3293 #[tokio::test]
3294 async fn single_call_parallel_matches_sequential() {
3295 let prompts = Arc::new(Mutex::new(Vec::new()));
3296 let provider = TestProvider {
3297 received_prompts: prompts.clone(),
3298 response_text: "single".to_string(),
3299 };
3300 let agent = Agent::new(
3301 AgentConfig {
3302 parallel_tools: true,
3303 max_tool_iterations: 5,
3304 ..AgentConfig::default()
3305 },
3306 Box::new(provider),
3307 Box::new(TestMemory::default()),
3308 vec![Box::new(EchoTool)],
3309 );
3310
3311 let response = agent
3312 .respond(
3313 UserMessage {
3314 text: "tool:echo ping".to_string(),
3315 },
3316 &test_ctx(),
3317 )
3318 .await
3319 .expect("single tool with parallel enabled should succeed");
3320
3321 assert_eq!(response.text, "single");
3322
3323 let received = prompts.lock().expect("provider lock poisoned");
3324 let last_prompt = received.last().expect("provider should have been called");
3325 assert!(
3326 last_prompt.contains("echoed:ping"),
3327 "single tool call should work same as sequential, got: {last_prompt}"
3328 );
3329 }
3330
3331 #[tokio::test]
3332 async fn parallel_falls_back_to_sequential_for_gated_tools() {
3333 let prompts = Arc::new(Mutex::new(Vec::new()));
3334 let provider = TestProvider {
3335 received_prompts: prompts.clone(),
3336 response_text: "gated sequential".to_string(),
3337 };
3338 let mut gated = std::collections::HashSet::new();
3339 gated.insert("upper".to_string());
3340 let agent = Agent::new(
3341 AgentConfig {
3342 parallel_tools: true,
3343 gated_tools: gated,
3344 max_tool_iterations: 5,
3345 ..AgentConfig::default()
3346 },
3347 Box::new(provider),
3348 Box::new(TestMemory::default()),
3349 vec![Box::new(EchoTool), Box::new(UpperTool)],
3350 );
3351
3352 let response = agent
3355 .respond(
3356 UserMessage {
3357 text: "tool:echo hello\ntool:upper world".to_string(),
3358 },
3359 &test_ctx(),
3360 )
3361 .await
3362 .expect("gated fallback should succeed");
3363
3364 assert_eq!(response.text, "gated sequential");
3365
3366 let received = prompts.lock().expect("provider lock poisoned");
3367 let last_prompt = received.last().expect("provider should have been called");
3368 assert!(
3369 last_prompt.contains("echoed:hello"),
3370 "first tool should execute, got: {last_prompt}"
3371 );
3372 assert!(
3373 !last_prompt.contains("WORLD"),
3374 "gated tool should NOT execute in parallel, got: {last_prompt}"
3375 );
3376 }
3377
3378 fn research_config(trigger: ResearchTrigger) -> ResearchPolicy {
3381 ResearchPolicy {
3382 enabled: true,
3383 trigger,
3384 keywords: vec!["search".to_string(), "find".to_string()],
3385 min_message_length: 10,
3386 max_iterations: 5,
3387 show_progress: true,
3388 }
3389 }
3390
3391 fn config_with_research(trigger: ResearchTrigger) -> AgentConfig {
3392 AgentConfig {
3393 research: research_config(trigger),
3394 ..Default::default()
3395 }
3396 }
3397
3398 #[tokio::test]
3399 async fn research_trigger_never() {
3400 let prompts = Arc::new(Mutex::new(Vec::new()));
3401 let provider = TestProvider {
3402 received_prompts: prompts.clone(),
3403 response_text: "final answer".to_string(),
3404 };
3405 let agent = Agent::new(
3406 config_with_research(ResearchTrigger::Never),
3407 Box::new(provider),
3408 Box::new(TestMemory::default()),
3409 vec![Box::new(EchoTool)],
3410 );
3411
3412 let response = agent
3413 .respond(
3414 UserMessage {
3415 text: "search for something".to_string(),
3416 },
3417 &test_ctx(),
3418 )
3419 .await
3420 .expect("should succeed");
3421
3422 assert_eq!(response.text, "final answer");
3423 let received = prompts.lock().expect("lock");
3424 assert_eq!(
3425 received.len(),
3426 1,
3427 "should have exactly 1 provider call (no research)"
3428 );
3429 }
3430
3431 #[tokio::test]
3432 async fn research_trigger_always() {
3433 let provider = ScriptedProvider::new(vec![
3434 "tool:echo gathering data",
3435 "Found relevant information",
3436 "Final answer with research context",
3437 ]);
3438 let prompts = provider.received_prompts.clone();
3439 let agent = Agent::new(
3440 config_with_research(ResearchTrigger::Always),
3441 Box::new(provider),
3442 Box::new(TestMemory::default()),
3443 vec![Box::new(EchoTool)],
3444 );
3445
3446 let response = agent
3447 .respond(
3448 UserMessage {
3449 text: "hello world".to_string(),
3450 },
3451 &test_ctx(),
3452 )
3453 .await
3454 .expect("should succeed");
3455
3456 assert_eq!(response.text, "Final answer with research context");
3457 let received = prompts.lock().expect("lock");
3458 let last = received.last().expect("should have provider calls");
3459 assert!(
3460 last.contains("Research findings:"),
3461 "final prompt should contain research findings, got: {last}"
3462 );
3463 }
3464
3465 #[tokio::test]
3466 async fn research_trigger_keywords_match() {
3467 let provider = ScriptedProvider::new(vec!["Summary of search", "Answer based on research"]);
3468 let prompts = provider.received_prompts.clone();
3469 let agent = Agent::new(
3470 config_with_research(ResearchTrigger::Keywords),
3471 Box::new(provider),
3472 Box::new(TestMemory::default()),
3473 vec![],
3474 );
3475
3476 let response = agent
3477 .respond(
3478 UserMessage {
3479 text: "please search for the config".to_string(),
3480 },
3481 &test_ctx(),
3482 )
3483 .await
3484 .expect("should succeed");
3485
3486 assert_eq!(response.text, "Answer based on research");
3487 let received = prompts.lock().expect("lock");
3488 assert!(received.len() >= 2, "should have at least 2 provider calls");
3489 assert!(
3490 received[0].contains("RESEARCH mode"),
3491 "first call should be research prompt"
3492 );
3493 }
3494
3495 #[tokio::test]
3496 async fn research_trigger_keywords_no_match() {
3497 let prompts = Arc::new(Mutex::new(Vec::new()));
3498 let provider = TestProvider {
3499 received_prompts: prompts.clone(),
3500 response_text: "direct answer".to_string(),
3501 };
3502 let agent = Agent::new(
3503 config_with_research(ResearchTrigger::Keywords),
3504 Box::new(provider),
3505 Box::new(TestMemory::default()),
3506 vec![],
3507 );
3508
3509 let response = agent
3510 .respond(
3511 UserMessage {
3512 text: "hello world".to_string(),
3513 },
3514 &test_ctx(),
3515 )
3516 .await
3517 .expect("should succeed");
3518
3519 assert_eq!(response.text, "direct answer");
3520 let received = prompts.lock().expect("lock");
3521 assert_eq!(received.len(), 1, "no research phase should fire");
3522 }
3523
3524 #[tokio::test]
3525 async fn research_trigger_length_short_skips() {
3526 let prompts = Arc::new(Mutex::new(Vec::new()));
3527 let provider = TestProvider {
3528 received_prompts: prompts.clone(),
3529 response_text: "short".to_string(),
3530 };
3531 let config = AgentConfig {
3532 research: ResearchPolicy {
3533 min_message_length: 20,
3534 ..research_config(ResearchTrigger::Length)
3535 },
3536 ..Default::default()
3537 };
3538 let agent = Agent::new(
3539 config,
3540 Box::new(provider),
3541 Box::new(TestMemory::default()),
3542 vec![],
3543 );
3544 agent
3545 .respond(
3546 UserMessage {
3547 text: "hi".to_string(),
3548 },
3549 &test_ctx(),
3550 )
3551 .await
3552 .expect("should succeed");
3553 let received = prompts.lock().expect("lock");
3554 assert_eq!(received.len(), 1, "short message should skip research");
3555 }
3556
/// Length trigger: a message longer than `min_message_length` starts the
/// research phase, producing at least two provider calls.
#[tokio::test]
async fn research_trigger_length_long_triggers() {
    let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
    let prompts = provider.received_prompts.clone();
    let config = AgentConfig {
        research: ResearchPolicy {
            min_message_length: 20,
            ..research_config(ResearchTrigger::Length)
        },
        ..Default::default()
    };
    let agent = Agent::new(
        config,
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![],
    );

    agent
        .respond(
            UserMessage {
                text: "this is a longer message that exceeds the threshold".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed");

    let received = prompts.lock().expect("lock");
    assert!(received.len() >= 2, "long message should trigger research");
}
3586
3587 #[tokio::test]
3588 async fn research_trigger_question_with_mark() {
3589 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
3590 let prompts = provider.received_prompts.clone();
3591 let agent = Agent::new(
3592 config_with_research(ResearchTrigger::Question),
3593 Box::new(provider),
3594 Box::new(TestMemory::default()),
3595 vec![],
3596 );
3597 agent
3598 .respond(
3599 UserMessage {
3600 text: "what is the meaning of life?".to_string(),
3601 },
3602 &test_ctx(),
3603 )
3604 .await
3605 .expect("should succeed");
3606 let received = prompts.lock().expect("lock");
3607 assert!(received.len() >= 2, "question should trigger research");
3608 }
3609
3610 #[tokio::test]
3611 async fn research_trigger_question_without_mark() {
3612 let prompts = Arc::new(Mutex::new(Vec::new()));
3613 let provider = TestProvider {
3614 received_prompts: prompts.clone(),
3615 response_text: "no research".to_string(),
3616 };
3617 let agent = Agent::new(
3618 config_with_research(ResearchTrigger::Question),
3619 Box::new(provider),
3620 Box::new(TestMemory::default()),
3621 vec![],
3622 );
3623 agent
3624 .respond(
3625 UserMessage {
3626 text: "do this thing".to_string(),
3627 },
3628 &test_ctx(),
3629 )
3630 .await
3631 .expect("should succeed");
3632 let received = prompts.lock().expect("lock");
3633 assert_eq!(received.len(), 1, "non-question should skip research");
3634 }
3635
3636 #[tokio::test]
3637 async fn research_respects_max_iterations() {
3638 let provider = ScriptedProvider::new(vec![
3639 "tool:echo step1",
3640 "tool:echo step2",
3641 "tool:echo step3",
3642 "tool:echo step4",
3643 "tool:echo step5",
3644 "tool:echo step6",
3645 "tool:echo step7",
3646 "answer after research",
3647 ]);
3648 let prompts = provider.received_prompts.clone();
3649 let config = AgentConfig {
3650 research: ResearchPolicy {
3651 max_iterations: 3,
3652 ..research_config(ResearchTrigger::Always)
3653 },
3654 ..Default::default()
3655 };
3656 let agent = Agent::new(
3657 config,
3658 Box::new(provider),
3659 Box::new(TestMemory::default()),
3660 vec![Box::new(EchoTool)],
3661 );
3662
3663 let response = agent
3664 .respond(
3665 UserMessage {
3666 text: "test".to_string(),
3667 },
3668 &test_ctx(),
3669 )
3670 .await
3671 .expect("should succeed");
3672
3673 assert!(!response.text.is_empty(), "should get a response");
3674 let received = prompts.lock().expect("lock");
3675 let research_calls = received
3676 .iter()
3677 .filter(|p| p.contains("RESEARCH mode") || p.contains("Research iteration"))
3678 .count();
3679 assert_eq!(
3681 research_calls, 4,
3682 "should have 4 research provider calls (1 initial + 3 iterations)"
3683 );
3684 }
3685
3686 #[tokio::test]
3687 async fn research_disabled_skips_phase() {
3688 let prompts = Arc::new(Mutex::new(Vec::new()));
3689 let provider = TestProvider {
3690 received_prompts: prompts.clone(),
3691 response_text: "direct answer".to_string(),
3692 };
3693 let config = AgentConfig {
3694 research: ResearchPolicy {
3695 enabled: false,
3696 trigger: ResearchTrigger::Always,
3697 ..Default::default()
3698 },
3699 ..Default::default()
3700 };
3701 let agent = Agent::new(
3702 config,
3703 Box::new(provider),
3704 Box::new(TestMemory::default()),
3705 vec![Box::new(EchoTool)],
3706 );
3707
3708 let response = agent
3709 .respond(
3710 UserMessage {
3711 text: "search for something".to_string(),
3712 },
3713 &test_ctx(),
3714 )
3715 .await
3716 .expect("should succeed");
3717
3718 assert_eq!(response.text, "direct answer");
3719 let received = prompts.lock().expect("lock");
3720 assert_eq!(
3721 received.len(),
3722 1,
3723 "disabled research should not fire even with Always trigger"
3724 );
3725 }
3726
    /// Test provider that records every `ReasoningConfig` it is invoked with,
    /// so tests can assert which reasoning settings the agent passed through.
    struct ReasoningCapturingProvider {
        // Shared with the test body so captured configs can be inspected after
        // the agent call returns.
        captured_reasoning: Arc<Mutex<Vec<ReasoningConfig>>>,
        // Fixed text returned from every completion call.
        response_text: String,
    }
3733
3734 #[async_trait]
3735 impl Provider for ReasoningCapturingProvider {
3736 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
3737 self.captured_reasoning
3738 .lock()
3739 .expect("lock")
3740 .push(ReasoningConfig::default());
3741 Ok(ChatResult {
3742 output_text: self.response_text.clone(),
3743 ..Default::default()
3744 })
3745 }
3746
3747 async fn complete_with_reasoning(
3748 &self,
3749 _prompt: &str,
3750 reasoning: &ReasoningConfig,
3751 ) -> anyhow::Result<ChatResult> {
3752 self.captured_reasoning
3753 .lock()
3754 .expect("lock")
3755 .push(reasoning.clone());
3756 Ok(ChatResult {
3757 output_text: self.response_text.clone(),
3758 ..Default::default()
3759 })
3760 }
3761 }
3762
3763 #[tokio::test]
3764 async fn reasoning_config_passed_to_provider() {
3765 let captured = Arc::new(Mutex::new(Vec::new()));
3766 let provider = ReasoningCapturingProvider {
3767 captured_reasoning: captured.clone(),
3768 response_text: "ok".to_string(),
3769 };
3770 let config = AgentConfig {
3771 reasoning: ReasoningConfig {
3772 enabled: Some(true),
3773 level: Some("high".to_string()),
3774 },
3775 ..Default::default()
3776 };
3777 let agent = Agent::new(
3778 config,
3779 Box::new(provider),
3780 Box::new(TestMemory::default()),
3781 vec![],
3782 );
3783
3784 agent
3785 .respond(
3786 UserMessage {
3787 text: "test".to_string(),
3788 },
3789 &test_ctx(),
3790 )
3791 .await
3792 .expect("should succeed");
3793
3794 let configs = captured.lock().expect("lock");
3795 assert_eq!(configs.len(), 1, "provider should be called once");
3796 assert_eq!(configs[0].enabled, Some(true));
3797 assert_eq!(configs[0].level.as_deref(), Some("high"));
3798 }
3799
3800 #[tokio::test]
3801 async fn reasoning_disabled_passes_config_through() {
3802 let captured = Arc::new(Mutex::new(Vec::new()));
3803 let provider = ReasoningCapturingProvider {
3804 captured_reasoning: captured.clone(),
3805 response_text: "ok".to_string(),
3806 };
3807 let config = AgentConfig {
3808 reasoning: ReasoningConfig {
3809 enabled: Some(false),
3810 level: None,
3811 },
3812 ..Default::default()
3813 };
3814 let agent = Agent::new(
3815 config,
3816 Box::new(provider),
3817 Box::new(TestMemory::default()),
3818 vec![],
3819 );
3820
3821 agent
3822 .respond(
3823 UserMessage {
3824 text: "test".to_string(),
3825 },
3826 &test_ctx(),
3827 )
3828 .await
3829 .expect("should succeed");
3830
3831 let configs = captured.lock().expect("lock");
3832 assert_eq!(configs.len(), 1);
3833 assert_eq!(
3834 configs[0].enabled,
3835 Some(false),
3836 "disabled reasoning should still be passed to provider"
3837 );
3838 assert_eq!(configs[0].level, None);
3839 }
3840
    /// Scripted provider for structured tool-use tests: replays `responses`
    /// in call order and records every message list it was handed.
    struct StructuredProvider {
        // Replies returned one per completion call, in order.
        responses: Vec<ChatResult>,
        // Index of the next response; bumped on every completion call.
        call_count: AtomicUsize,
        // Conversations captured from complete_with_tools, for assertions.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
3848
3849 impl StructuredProvider {
3850 fn new(responses: Vec<ChatResult>) -> Self {
3851 Self {
3852 responses,
3853 call_count: AtomicUsize::new(0),
3854 received_messages: Arc::new(Mutex::new(Vec::new())),
3855 }
3856 }
3857 }
3858
3859 #[async_trait]
3860 impl Provider for StructuredProvider {
3861 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
3862 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
3863 Ok(self.responses.get(idx).cloned().unwrap_or_default())
3864 }
3865
3866 async fn complete_with_tools(
3867 &self,
3868 messages: &[ConversationMessage],
3869 _tools: &[ToolDefinition],
3870 _reasoning: &ReasoningConfig,
3871 ) -> anyhow::Result<ChatResult> {
3872 self.received_messages
3873 .lock()
3874 .expect("lock")
3875 .push(messages.to_vec());
3876 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
3877 Ok(self.responses.get(idx).cloned().unwrap_or_default())
3878 }
3879 }
3880
    /// Stateless tool used by structured tool-use tests; echoes its input.
    struct StructuredEchoTool;
3882
3883 #[async_trait]
3884 impl Tool for StructuredEchoTool {
3885 fn name(&self) -> &'static str {
3886 "echo"
3887 }
3888 fn description(&self) -> &'static str {
3889 "Echoes input back"
3890 }
3891 fn input_schema(&self) -> Option<serde_json::Value> {
3892 Some(json!({
3893 "type": "object",
3894 "properties": {
3895 "text": { "type": "string", "description": "Text to echo" }
3896 },
3897 "required": ["text"]
3898 }))
3899 }
3900 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3901 Ok(ToolResult {
3902 output: format!("echoed:{input}"),
3903 })
3904 }
3905 }
3906
    /// Tool whose execute always errors; used to test error-path handling.
    struct StructuredFailingTool;
3908
3909 #[async_trait]
3910 impl Tool for StructuredFailingTool {
3911 fn name(&self) -> &'static str {
3912 "boom"
3913 }
3914 fn description(&self) -> &'static str {
3915 "Always fails"
3916 }
3917 fn input_schema(&self) -> Option<serde_json::Value> {
3918 Some(json!({
3919 "type": "object",
3920 "properties": {},
3921 }))
3922 }
3923 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3924 Err(anyhow::anyhow!("tool exploded"))
3925 }
3926 }
3927
    /// Second test tool (uppercases input) so tests can exercise multiple
    /// distinct tools in one turn.
    struct StructuredUpperTool;
3929
3930 #[async_trait]
3931 impl Tool for StructuredUpperTool {
3932 fn name(&self) -> &'static str {
3933 "upper"
3934 }
3935 fn description(&self) -> &'static str {
3936 "Uppercases input"
3937 }
3938 fn input_schema(&self) -> Option<serde_json::Value> {
3939 Some(json!({
3940 "type": "object",
3941 "properties": {
3942 "text": { "type": "string", "description": "Text to uppercase" }
3943 },
3944 "required": ["text"]
3945 }))
3946 }
3947 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3948 Ok(ToolResult {
3949 output: input.to_uppercase(),
3950 })
3951 }
3952 }
3953
3954 use crate::types::ToolUseRequest;
3955
3956 #[tokio::test]
3959 async fn structured_basic_tool_call_then_end_turn() {
3960 let provider = StructuredProvider::new(vec![
3961 ChatResult {
3962 output_text: "Let me echo that.".to_string(),
3963 tool_calls: vec![ToolUseRequest {
3964 id: "call_1".to_string(),
3965 name: "echo".to_string(),
3966 input: json!({"text": "hello"}),
3967 }],
3968 stop_reason: Some(StopReason::ToolUse),
3969 ..Default::default()
3970 },
3971 ChatResult {
3972 output_text: "The echo returned: hello".to_string(),
3973 tool_calls: vec![],
3974 stop_reason: Some(StopReason::EndTurn),
3975 ..Default::default()
3976 },
3977 ]);
3978 let received = provider.received_messages.clone();
3979
3980 let agent = Agent::new(
3981 AgentConfig::default(),
3982 Box::new(provider),
3983 Box::new(TestMemory::default()),
3984 vec![Box::new(StructuredEchoTool)],
3985 );
3986
3987 let response = agent
3988 .respond(
3989 UserMessage {
3990 text: "echo hello".to_string(),
3991 },
3992 &test_ctx(),
3993 )
3994 .await
3995 .expect("should succeed");
3996
3997 assert_eq!(response.text, "The echo returned: hello");
3998
3999 let msgs = received.lock().expect("lock");
4001 assert_eq!(msgs.len(), 2, "provider should be called twice");
4002 assert!(
4004 msgs[1].len() >= 3,
4005 "second call should have user + assistant + tool result"
4006 );
4007 }
4008
4009 #[tokio::test]
4010 async fn structured_no_tool_calls_returns_immediately() {
4011 let provider = StructuredProvider::new(vec![ChatResult {
4012 output_text: "Hello! No tools needed.".to_string(),
4013 tool_calls: vec![],
4014 stop_reason: Some(StopReason::EndTurn),
4015 ..Default::default()
4016 }]);
4017
4018 let agent = Agent::new(
4019 AgentConfig::default(),
4020 Box::new(provider),
4021 Box::new(TestMemory::default()),
4022 vec![Box::new(StructuredEchoTool)],
4023 );
4024
4025 let response = agent
4026 .respond(
4027 UserMessage {
4028 text: "hi".to_string(),
4029 },
4030 &test_ctx(),
4031 )
4032 .await
4033 .expect("should succeed");
4034
4035 assert_eq!(response.text, "Hello! No tools needed.");
4036 }
4037
4038 #[tokio::test]
4039 async fn structured_tool_not_found_sends_error_result() {
4040 let provider = StructuredProvider::new(vec![
4041 ChatResult {
4042 output_text: String::new(),
4043 tool_calls: vec![ToolUseRequest {
4044 id: "call_1".to_string(),
4045 name: "nonexistent".to_string(),
4046 input: json!({}),
4047 }],
4048 stop_reason: Some(StopReason::ToolUse),
4049 ..Default::default()
4050 },
4051 ChatResult {
4052 output_text: "I see the tool was not found.".to_string(),
4053 tool_calls: vec![],
4054 stop_reason: Some(StopReason::EndTurn),
4055 ..Default::default()
4056 },
4057 ]);
4058 let received = provider.received_messages.clone();
4059
4060 let agent = Agent::new(
4061 AgentConfig::default(),
4062 Box::new(provider),
4063 Box::new(TestMemory::default()),
4064 vec![Box::new(StructuredEchoTool)],
4065 );
4066
4067 let response = agent
4068 .respond(
4069 UserMessage {
4070 text: "use nonexistent".to_string(),
4071 },
4072 &test_ctx(),
4073 )
4074 .await
4075 .expect("should succeed, not abort");
4076
4077 assert_eq!(response.text, "I see the tool was not found.");
4078
4079 let msgs = received.lock().expect("lock");
4081 let last_call = &msgs[1];
4082 let has_error_result = last_call.iter().any(|m| {
4083 matches!(m, ConversationMessage::ToolResult(r) if r.is_error && r.content.contains("not found"))
4084 });
4085 assert!(has_error_result, "should include error ToolResult");
4086 }
4087
4088 #[tokio::test]
4089 async fn structured_tool_error_does_not_abort() {
4090 let provider = StructuredProvider::new(vec![
4091 ChatResult {
4092 output_text: String::new(),
4093 tool_calls: vec![ToolUseRequest {
4094 id: "call_1".to_string(),
4095 name: "boom".to_string(),
4096 input: json!({}),
4097 }],
4098 stop_reason: Some(StopReason::ToolUse),
4099 ..Default::default()
4100 },
4101 ChatResult {
4102 output_text: "I handled the error gracefully.".to_string(),
4103 tool_calls: vec![],
4104 stop_reason: Some(StopReason::EndTurn),
4105 ..Default::default()
4106 },
4107 ]);
4108
4109 let agent = Agent::new(
4110 AgentConfig::default(),
4111 Box::new(provider),
4112 Box::new(TestMemory::default()),
4113 vec![Box::new(StructuredFailingTool)],
4114 );
4115
4116 let response = agent
4117 .respond(
4118 UserMessage {
4119 text: "boom".to_string(),
4120 },
4121 &test_ctx(),
4122 )
4123 .await
4124 .expect("should succeed, error becomes ToolResultMessage");
4125
4126 assert_eq!(response.text, "I handled the error gracefully.");
4127 }
4128
4129 #[tokio::test]
4130 async fn structured_multi_iteration_tool_calls() {
4131 let provider = StructuredProvider::new(vec![
4132 ChatResult {
4134 output_text: String::new(),
4135 tool_calls: vec![ToolUseRequest {
4136 id: "call_1".to_string(),
4137 name: "echo".to_string(),
4138 input: json!({"text": "first"}),
4139 }],
4140 stop_reason: Some(StopReason::ToolUse),
4141 ..Default::default()
4142 },
4143 ChatResult {
4145 output_text: String::new(),
4146 tool_calls: vec![ToolUseRequest {
4147 id: "call_2".to_string(),
4148 name: "echo".to_string(),
4149 input: json!({"text": "second"}),
4150 }],
4151 stop_reason: Some(StopReason::ToolUse),
4152 ..Default::default()
4153 },
4154 ChatResult {
4156 output_text: "Done with two tool calls.".to_string(),
4157 tool_calls: vec![],
4158 stop_reason: Some(StopReason::EndTurn),
4159 ..Default::default()
4160 },
4161 ]);
4162 let received = provider.received_messages.clone();
4163
4164 let agent = Agent::new(
4165 AgentConfig::default(),
4166 Box::new(provider),
4167 Box::new(TestMemory::default()),
4168 vec![Box::new(StructuredEchoTool)],
4169 );
4170
4171 let response = agent
4172 .respond(
4173 UserMessage {
4174 text: "echo twice".to_string(),
4175 },
4176 &test_ctx(),
4177 )
4178 .await
4179 .expect("should succeed");
4180
4181 assert_eq!(response.text, "Done with two tool calls.");
4182 let msgs = received.lock().expect("lock");
4183 assert_eq!(msgs.len(), 3, "three provider calls");
4184 }
4185
4186 #[tokio::test]
4187 async fn structured_max_iterations_forces_final_answer() {
4188 let responses = vec![
4192 ChatResult {
4193 output_text: String::new(),
4194 tool_calls: vec![ToolUseRequest {
4195 id: "call_0".to_string(),
4196 name: "echo".to_string(),
4197 input: json!({"text": "loop"}),
4198 }],
4199 stop_reason: Some(StopReason::ToolUse),
4200 ..Default::default()
4201 },
4202 ChatResult {
4203 output_text: String::new(),
4204 tool_calls: vec![ToolUseRequest {
4205 id: "call_1".to_string(),
4206 name: "echo".to_string(),
4207 input: json!({"text": "loop"}),
4208 }],
4209 stop_reason: Some(StopReason::ToolUse),
4210 ..Default::default()
4211 },
4212 ChatResult {
4213 output_text: String::new(),
4214 tool_calls: vec![ToolUseRequest {
4215 id: "call_2".to_string(),
4216 name: "echo".to_string(),
4217 input: json!({"text": "loop"}),
4218 }],
4219 stop_reason: Some(StopReason::ToolUse),
4220 ..Default::default()
4221 },
4222 ChatResult {
4224 output_text: "Forced final answer.".to_string(),
4225 tool_calls: vec![],
4226 stop_reason: Some(StopReason::EndTurn),
4227 ..Default::default()
4228 },
4229 ];
4230
4231 let agent = Agent::new(
4232 AgentConfig {
4233 max_tool_iterations: 3,
4234 loop_detection_no_progress_threshold: 0, ..AgentConfig::default()
4236 },
4237 Box::new(StructuredProvider::new(responses)),
4238 Box::new(TestMemory::default()),
4239 vec![Box::new(StructuredEchoTool)],
4240 );
4241
4242 let response = agent
4243 .respond(
4244 UserMessage {
4245 text: "loop forever".to_string(),
4246 },
4247 &test_ctx(),
4248 )
4249 .await
4250 .expect("should succeed with forced answer");
4251
4252 assert_eq!(response.text, "Forced final answer.");
4253 }
4254
4255 #[tokio::test]
4256 async fn structured_fallback_to_text_when_disabled() {
4257 let prompts = Arc::new(Mutex::new(Vec::new()));
4258 let provider = TestProvider {
4259 received_prompts: prompts.clone(),
4260 response_text: "text path used".to_string(),
4261 };
4262
4263 let agent = Agent::new(
4264 AgentConfig {
4265 model_supports_tool_use: false,
4266 ..AgentConfig::default()
4267 },
4268 Box::new(provider),
4269 Box::new(TestMemory::default()),
4270 vec![Box::new(StructuredEchoTool)],
4271 );
4272
4273 let response = agent
4274 .respond(
4275 UserMessage {
4276 text: "hello".to_string(),
4277 },
4278 &test_ctx(),
4279 )
4280 .await
4281 .expect("should succeed");
4282
4283 assert_eq!(response.text, "text path used");
4284 }
4285
4286 #[tokio::test]
4287 async fn structured_fallback_when_no_tool_schemas() {
4288 let prompts = Arc::new(Mutex::new(Vec::new()));
4289 let provider = TestProvider {
4290 received_prompts: prompts.clone(),
4291 response_text: "text path used".to_string(),
4292 };
4293
4294 let agent = Agent::new(
4296 AgentConfig::default(),
4297 Box::new(provider),
4298 Box::new(TestMemory::default()),
4299 vec![Box::new(EchoTool)],
4300 );
4301
4302 let response = agent
4303 .respond(
4304 UserMessage {
4305 text: "hello".to_string(),
4306 },
4307 &test_ctx(),
4308 )
4309 .await
4310 .expect("should succeed");
4311
4312 assert_eq!(response.text, "text path used");
4313 }
4314
4315 #[tokio::test]
4316 async fn structured_parallel_tool_calls() {
4317 let provider = StructuredProvider::new(vec![
4318 ChatResult {
4319 output_text: String::new(),
4320 tool_calls: vec![
4321 ToolUseRequest {
4322 id: "call_1".to_string(),
4323 name: "echo".to_string(),
4324 input: json!({"text": "a"}),
4325 },
4326 ToolUseRequest {
4327 id: "call_2".to_string(),
4328 name: "upper".to_string(),
4329 input: json!({"text": "b"}),
4330 },
4331 ],
4332 stop_reason: Some(StopReason::ToolUse),
4333 ..Default::default()
4334 },
4335 ChatResult {
4336 output_text: "Both tools ran.".to_string(),
4337 tool_calls: vec![],
4338 stop_reason: Some(StopReason::EndTurn),
4339 ..Default::default()
4340 },
4341 ]);
4342 let received = provider.received_messages.clone();
4343
4344 let agent = Agent::new(
4345 AgentConfig {
4346 parallel_tools: true,
4347 ..AgentConfig::default()
4348 },
4349 Box::new(provider),
4350 Box::new(TestMemory::default()),
4351 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
4352 );
4353
4354 let response = agent
4355 .respond(
4356 UserMessage {
4357 text: "do both".to_string(),
4358 },
4359 &test_ctx(),
4360 )
4361 .await
4362 .expect("should succeed");
4363
4364 assert_eq!(response.text, "Both tools ran.");
4365
4366 let msgs = received.lock().expect("lock");
4368 let tool_results: Vec<_> = msgs[1]
4369 .iter()
4370 .filter(|m| matches!(m, ConversationMessage::ToolResult(_)))
4371 .collect();
4372 assert_eq!(tool_results.len(), 2, "should have two tool results");
4373 }
4374
4375 #[tokio::test]
4376 async fn structured_memory_integration() {
4377 let memory = TestMemory::default();
4378 memory
4380 .append(MemoryEntry {
4381 role: "user".to_string(),
4382 content: "old question".to_string(),
4383 ..Default::default()
4384 })
4385 .await
4386 .unwrap();
4387 memory
4388 .append(MemoryEntry {
4389 role: "assistant".to_string(),
4390 content: "old answer".to_string(),
4391 ..Default::default()
4392 })
4393 .await
4394 .unwrap();
4395
4396 let provider = StructuredProvider::new(vec![ChatResult {
4397 output_text: "I see your history.".to_string(),
4398 tool_calls: vec![],
4399 stop_reason: Some(StopReason::EndTurn),
4400 ..Default::default()
4401 }]);
4402 let received = provider.received_messages.clone();
4403
4404 let agent = Agent::new(
4405 AgentConfig::default(),
4406 Box::new(provider),
4407 Box::new(memory),
4408 vec![Box::new(StructuredEchoTool)],
4409 );
4410
4411 agent
4412 .respond(
4413 UserMessage {
4414 text: "new question".to_string(),
4415 },
4416 &test_ctx(),
4417 )
4418 .await
4419 .expect("should succeed");
4420
4421 let msgs = received.lock().expect("lock");
4422 assert!(msgs[0].len() >= 3, "should include memory messages");
4425 assert!(matches!(
4427 &msgs[0][0],
4428 ConversationMessage::User { content, .. } if content == "old question"
4429 ));
4430 }
4431
4432 #[tokio::test]
4433 async fn prepare_tool_input_extracts_single_string_field() {
4434 let tool = StructuredEchoTool;
4435 let input = json!({"text": "hello world"});
4436 let result = prepare_tool_input(&tool, &input).expect("valid input");
4437 assert_eq!(result, "hello world");
4438 }
4439
4440 #[tokio::test]
4441 async fn prepare_tool_input_serializes_multi_field_json() {
4442 struct MultiFieldTool;
4444 #[async_trait]
4445 impl Tool for MultiFieldTool {
4446 fn name(&self) -> &'static str {
4447 "multi"
4448 }
4449 fn input_schema(&self) -> Option<serde_json::Value> {
4450 Some(json!({
4451 "type": "object",
4452 "properties": {
4453 "path": { "type": "string" },
4454 "content": { "type": "string" }
4455 },
4456 "required": ["path", "content"]
4457 }))
4458 }
4459 async fn execute(
4460 &self,
4461 _input: &str,
4462 _ctx: &ToolContext,
4463 ) -> anyhow::Result<ToolResult> {
4464 unreachable!()
4465 }
4466 }
4467
4468 let tool = MultiFieldTool;
4469 let input = json!({"path": "a.txt", "content": "hello"});
4470 let result = prepare_tool_input(&tool, &input).expect("valid input");
4471 let parsed: serde_json::Value = serde_json::from_str(&result).expect("valid JSON");
4473 assert_eq!(parsed["path"], "a.txt");
4474 assert_eq!(parsed["content"], "hello");
4475 }
4476
4477 #[tokio::test]
4478 async fn prepare_tool_input_unwraps_bare_string() {
4479 let tool = StructuredEchoTool;
4480 let input = json!("plain text");
4481 let result = prepare_tool_input(&tool, &input).expect("valid input");
4482 assert_eq!(result, "plain text");
4483 }
4484
4485 #[tokio::test]
4486 async fn prepare_tool_input_rejects_schema_violating_input() {
4487 struct StrictTool;
4488 #[async_trait]
4489 impl Tool for StrictTool {
4490 fn name(&self) -> &'static str {
4491 "strict"
4492 }
4493 fn input_schema(&self) -> Option<serde_json::Value> {
4494 Some(json!({
4495 "type": "object",
4496 "required": ["path"],
4497 "properties": {
4498 "path": { "type": "string" }
4499 }
4500 }))
4501 }
4502 async fn execute(
4503 &self,
4504 _input: &str,
4505 _ctx: &ToolContext,
4506 ) -> anyhow::Result<ToolResult> {
4507 unreachable!("should not be called with invalid input")
4508 }
4509 }
4510
4511 let result = prepare_tool_input(&StrictTool, &json!({}));
4513 assert!(result.is_err());
4514 let err = result.unwrap_err();
4515 assert!(err.contains("Invalid input for tool 'strict'"));
4516 assert!(err.contains("missing required field"));
4517 }
4518
4519 #[tokio::test]
4520 async fn prepare_tool_input_rejects_wrong_type() {
4521 struct TypedTool;
4522 #[async_trait]
4523 impl Tool for TypedTool {
4524 fn name(&self) -> &'static str {
4525 "typed"
4526 }
4527 fn input_schema(&self) -> Option<serde_json::Value> {
4528 Some(json!({
4529 "type": "object",
4530 "required": ["count"],
4531 "properties": {
4532 "count": { "type": "integer" }
4533 }
4534 }))
4535 }
4536 async fn execute(
4537 &self,
4538 _input: &str,
4539 _ctx: &ToolContext,
4540 ) -> anyhow::Result<ToolResult> {
4541 unreachable!("should not be called with invalid input")
4542 }
4543 }
4544
4545 let result = prepare_tool_input(&TypedTool, &json!({"count": "not a number"}));
4547 assert!(result.is_err());
4548 let err = result.unwrap_err();
4549 assert!(err.contains("expected type \"integer\""));
4550 }
4551
4552 #[test]
4553 fn truncate_messages_preserves_small_conversation() {
4554 let mut msgs = vec![
4555 ConversationMessage::user("hello".to_string()),
4556 ConversationMessage::Assistant {
4557 content: Some("world".to_string()),
4558 tool_calls: vec![],
4559 },
4560 ];
4561 truncate_messages(&mut msgs, 1000);
4562 assert_eq!(msgs.len(), 2, "should not truncate small conversation");
4563 }
4564
4565 #[test]
4566 fn truncate_messages_drops_middle() {
4567 let mut msgs = vec![
4568 ConversationMessage::user("A".repeat(100)),
4569 ConversationMessage::Assistant {
4570 content: Some("B".repeat(100)),
4571 tool_calls: vec![],
4572 },
4573 ConversationMessage::user("C".repeat(100)),
4574 ConversationMessage::Assistant {
4575 content: Some("D".repeat(100)),
4576 tool_calls: vec![],
4577 },
4578 ];
4579 truncate_messages(&mut msgs, 250);
4582 assert!(msgs.len() < 4, "should have dropped some messages");
4583 assert!(matches!(
4585 &msgs[0],
4586 ConversationMessage::User { content, .. } if content.starts_with("AAA")
4587 ));
4588 }
4589
4590 #[test]
4591 fn memory_to_messages_reverses_order() {
4592 let entries = vec![
4593 MemoryEntry {
4594 role: "assistant".to_string(),
4595 content: "newest".to_string(),
4596 ..Default::default()
4597 },
4598 MemoryEntry {
4599 role: "user".to_string(),
4600 content: "oldest".to_string(),
4601 ..Default::default()
4602 },
4603 ];
4604 let msgs = memory_to_messages(&entries);
4605 assert_eq!(msgs.len(), 2);
4606 assert!(matches!(
4607 &msgs[0],
4608 ConversationMessage::User { content, .. } if content == "oldest"
4609 ));
4610 assert!(matches!(
4611 &msgs[1],
4612 ConversationMessage::Assistant { content: Some(c), .. } if c == "newest"
4613 ));
4614 }
4615
4616 #[test]
4617 fn single_required_string_field_detects_single() {
4618 let schema = json!({
4619 "type": "object",
4620 "properties": {
4621 "path": { "type": "string" }
4622 },
4623 "required": ["path"]
4624 });
4625 assert_eq!(
4626 single_required_string_field(&schema),
4627 Some("path".to_string())
4628 );
4629 }
4630
4631 #[test]
4632 fn single_required_string_field_none_for_multiple() {
4633 let schema = json!({
4634 "type": "object",
4635 "properties": {
4636 "path": { "type": "string" },
4637 "mode": { "type": "string" }
4638 },
4639 "required": ["path", "mode"]
4640 });
4641 assert_eq!(single_required_string_field(&schema), None);
4642 }
4643
4644 #[test]
4645 fn single_required_string_field_none_for_non_string() {
4646 let schema = json!({
4647 "type": "object",
4648 "properties": {
4649 "count": { "type": "integer" }
4650 },
4651 "required": ["count"]
4652 });
4653 assert_eq!(single_required_string_field(&schema), None);
4654 }
4655
4656 use crate::types::StreamChunk;
4659
    /// Scripted provider with streaming support: replays `responses` in call
    /// order, emitting chunked deltas, and records the messages it was handed.
    struct StreamingProvider {
        // Replies returned one per completion call, in order.
        responses: Vec<ChatResult>,
        // Index of the next response; bumped on every completion call.
        call_count: AtomicUsize,
        // Conversations captured from the tool-aware methods, for assertions.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
4666
4667 impl StreamingProvider {
4668 fn new(responses: Vec<ChatResult>) -> Self {
4669 Self {
4670 responses,
4671 call_count: AtomicUsize::new(0),
4672 received_messages: Arc::new(Mutex::new(Vec::new())),
4673 }
4674 }
4675 }
4676
4677 #[async_trait]
4678 impl Provider for StreamingProvider {
4679 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
4680 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4681 Ok(self.responses.get(idx).cloned().unwrap_or_default())
4682 }
4683
4684 async fn complete_streaming(
4685 &self,
4686 _prompt: &str,
4687 sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4688 ) -> anyhow::Result<ChatResult> {
4689 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4690 let result = self.responses.get(idx).cloned().unwrap_or_default();
4691 for word in result.output_text.split_whitespace() {
4693 let _ = sender.send(StreamChunk {
4694 delta: format!("{word} "),
4695 done: false,
4696 tool_call_delta: None,
4697 });
4698 }
4699 let _ = sender.send(StreamChunk {
4700 delta: String::new(),
4701 done: true,
4702 tool_call_delta: None,
4703 });
4704 Ok(result)
4705 }
4706
4707 async fn complete_with_tools(
4708 &self,
4709 messages: &[ConversationMessage],
4710 _tools: &[ToolDefinition],
4711 _reasoning: &ReasoningConfig,
4712 ) -> anyhow::Result<ChatResult> {
4713 self.received_messages
4714 .lock()
4715 .expect("lock")
4716 .push(messages.to_vec());
4717 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4718 Ok(self.responses.get(idx).cloned().unwrap_or_default())
4719 }
4720
4721 async fn complete_streaming_with_tools(
4722 &self,
4723 messages: &[ConversationMessage],
4724 _tools: &[ToolDefinition],
4725 _reasoning: &ReasoningConfig,
4726 sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4727 ) -> anyhow::Result<ChatResult> {
4728 self.received_messages
4729 .lock()
4730 .expect("lock")
4731 .push(messages.to_vec());
4732 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4733 let result = self.responses.get(idx).cloned().unwrap_or_default();
4734 for ch in result.output_text.chars() {
4736 let _ = sender.send(StreamChunk {
4737 delta: ch.to_string(),
4738 done: false,
4739 tool_call_delta: None,
4740 });
4741 }
4742 let _ = sender.send(StreamChunk {
4743 delta: String::new(),
4744 done: true,
4745 tool_call_delta: None,
4746 });
4747 Ok(result)
4748 }
4749 }
4750
4751 #[tokio::test]
4752 async fn streaming_text_only_sends_chunks() {
4753 let provider = StreamingProvider::new(vec![ChatResult {
4754 output_text: "Hello world".to_string(),
4755 ..Default::default()
4756 }]);
4757 let agent = Agent::new(
4758 AgentConfig {
4759 model_supports_tool_use: false,
4760 ..Default::default()
4761 },
4762 Box::new(provider),
4763 Box::new(TestMemory::default()),
4764 vec![],
4765 );
4766
4767 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4768 let response = agent
4769 .respond_streaming(
4770 UserMessage {
4771 text: "hi".to_string(),
4772 },
4773 &test_ctx(),
4774 tx,
4775 )
4776 .await
4777 .expect("should succeed");
4778
4779 assert_eq!(response.text, "Hello world");
4780
4781 let mut chunks = Vec::new();
4783 while let Ok(chunk) = rx.try_recv() {
4784 chunks.push(chunk);
4785 }
4786 assert!(chunks.len() >= 2, "should have at least text + done chunks");
4787 assert!(chunks.last().unwrap().done, "last chunk should be done");
4788 }
4789
4790 #[tokio::test]
4791 async fn streaming_single_tool_call_round_trip() {
4792 let provider = StreamingProvider::new(vec![
4793 ChatResult {
4794 output_text: "I'll echo that.".to_string(),
4795 tool_calls: vec![ToolUseRequest {
4796 id: "call_1".to_string(),
4797 name: "echo".to_string(),
4798 input: json!({"text": "hello"}),
4799 }],
4800 stop_reason: Some(StopReason::ToolUse),
4801 ..Default::default()
4802 },
4803 ChatResult {
4804 output_text: "Done echoing.".to_string(),
4805 tool_calls: vec![],
4806 stop_reason: Some(StopReason::EndTurn),
4807 ..Default::default()
4808 },
4809 ]);
4810
4811 let agent = Agent::new(
4812 AgentConfig::default(),
4813 Box::new(provider),
4814 Box::new(TestMemory::default()),
4815 vec![Box::new(StructuredEchoTool)],
4816 );
4817
4818 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4819 let response = agent
4820 .respond_streaming(
4821 UserMessage {
4822 text: "echo hello".to_string(),
4823 },
4824 &test_ctx(),
4825 tx,
4826 )
4827 .await
4828 .expect("should succeed");
4829
4830 assert_eq!(response.text, "Done echoing.");
4831
4832 let mut chunks = Vec::new();
4833 while let Ok(chunk) = rx.try_recv() {
4834 chunks.push(chunk);
4835 }
4836 assert!(chunks.len() >= 2, "should have streaming chunks");
4838 }
4839
4840 #[tokio::test]
4841 async fn streaming_multi_iteration_tool_calls() {
4842 let provider = StreamingProvider::new(vec![
4843 ChatResult {
4844 output_text: String::new(),
4845 tool_calls: vec![ToolUseRequest {
4846 id: "call_1".to_string(),
4847 name: "echo".to_string(),
4848 input: json!({"text": "first"}),
4849 }],
4850 stop_reason: Some(StopReason::ToolUse),
4851 ..Default::default()
4852 },
4853 ChatResult {
4854 output_text: String::new(),
4855 tool_calls: vec![ToolUseRequest {
4856 id: "call_2".to_string(),
4857 name: "upper".to_string(),
4858 input: json!({"text": "second"}),
4859 }],
4860 stop_reason: Some(StopReason::ToolUse),
4861 ..Default::default()
4862 },
4863 ChatResult {
4864 output_text: "All done.".to_string(),
4865 tool_calls: vec![],
4866 stop_reason: Some(StopReason::EndTurn),
4867 ..Default::default()
4868 },
4869 ]);
4870
4871 let agent = Agent::new(
4872 AgentConfig::default(),
4873 Box::new(provider),
4874 Box::new(TestMemory::default()),
4875 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
4876 );
4877
4878 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4879 let response = agent
4880 .respond_streaming(
4881 UserMessage {
4882 text: "do two things".to_string(),
4883 },
4884 &test_ctx(),
4885 tx,
4886 )
4887 .await
4888 .expect("should succeed");
4889
4890 assert_eq!(response.text, "All done.");
4891
4892 let mut chunks = Vec::new();
4893 while let Ok(chunk) = rx.try_recv() {
4894 chunks.push(chunk);
4895 }
4896 let done_count = chunks.iter().filter(|c| c.done).count();
4898 assert!(
4899 done_count >= 3,
4900 "should have done chunks from 3 provider calls, got {}",
4901 done_count
4902 );
4903 }
4904
4905 #[tokio::test]
4906 async fn streaming_timeout_returns_error() {
4907 struct StreamingSlowProvider;
4908
4909 #[async_trait]
4910 impl Provider for StreamingSlowProvider {
4911 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
4912 sleep(Duration::from_millis(500)).await;
4913 Ok(ChatResult::default())
4914 }
4915
4916 async fn complete_streaming(
4917 &self,
4918 _prompt: &str,
4919 _sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4920 ) -> anyhow::Result<ChatResult> {
4921 sleep(Duration::from_millis(500)).await;
4922 Ok(ChatResult::default())
4923 }
4924 }
4925
4926 let agent = Agent::new(
4927 AgentConfig {
4928 request_timeout_ms: 50,
4929 model_supports_tool_use: false,
4930 ..Default::default()
4931 },
4932 Box::new(StreamingSlowProvider),
4933 Box::new(TestMemory::default()),
4934 vec![],
4935 );
4936
4937 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
4938 let result = agent
4939 .respond_streaming(
4940 UserMessage {
4941 text: "hi".to_string(),
4942 },
4943 &test_ctx(),
4944 tx,
4945 )
4946 .await;
4947
4948 assert!(result.is_err());
4949 match result.unwrap_err() {
4950 AgentError::Timeout { timeout_ms } => assert_eq!(timeout_ms, 50),
4951 other => panic!("expected Timeout, got {other:?}"),
4952 }
4953 }
4954
4955 #[tokio::test]
4956 async fn streaming_no_schema_fallback_sends_chunks() {
4957 let provider = StreamingProvider::new(vec![ChatResult {
4959 output_text: "Fallback response".to_string(),
4960 ..Default::default()
4961 }]);
4962 let agent = Agent::new(
4963 AgentConfig::default(), Box::new(provider),
4965 Box::new(TestMemory::default()),
4966 vec![Box::new(EchoTool)], );
4968
4969 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4970 let response = agent
4971 .respond_streaming(
4972 UserMessage {
4973 text: "hi".to_string(),
4974 },
4975 &test_ctx(),
4976 tx,
4977 )
4978 .await
4979 .expect("should succeed");
4980
4981 assert_eq!(response.text, "Fallback response");
4982
4983 let mut chunks = Vec::new();
4984 while let Ok(chunk) = rx.try_recv() {
4985 chunks.push(chunk);
4986 }
4987 assert!(!chunks.is_empty(), "should have streaming chunks");
4988 assert!(chunks.last().unwrap().done, "last chunk should be done");
4989 }
4990
4991 #[tokio::test]
4992 async fn streaming_tool_error_does_not_abort() {
4993 let provider = StreamingProvider::new(vec![
4994 ChatResult {
4995 output_text: String::new(),
4996 tool_calls: vec![ToolUseRequest {
4997 id: "call_1".to_string(),
4998 name: "boom".to_string(),
4999 input: json!({}),
5000 }],
5001 stop_reason: Some(StopReason::ToolUse),
5002 ..Default::default()
5003 },
5004 ChatResult {
5005 output_text: "Recovered from error.".to_string(),
5006 tool_calls: vec![],
5007 stop_reason: Some(StopReason::EndTurn),
5008 ..Default::default()
5009 },
5010 ]);
5011
5012 let agent = Agent::new(
5013 AgentConfig::default(),
5014 Box::new(provider),
5015 Box::new(TestMemory::default()),
5016 vec![Box::new(StructuredFailingTool)],
5017 );
5018
5019 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
5020 let response = agent
5021 .respond_streaming(
5022 UserMessage {
5023 text: "boom".to_string(),
5024 },
5025 &test_ctx(),
5026 tx,
5027 )
5028 .await
5029 .expect("should succeed despite tool error");
5030
5031 assert_eq!(response.text, "Recovered from error.");
5032 }
5033
5034 #[tokio::test]
5035 async fn streaming_parallel_tools() {
5036 let provider = StreamingProvider::new(vec![
5037 ChatResult {
5038 output_text: String::new(),
5039 tool_calls: vec![
5040 ToolUseRequest {
5041 id: "call_1".to_string(),
5042 name: "echo".to_string(),
5043 input: json!({"text": "a"}),
5044 },
5045 ToolUseRequest {
5046 id: "call_2".to_string(),
5047 name: "upper".to_string(),
5048 input: json!({"text": "b"}),
5049 },
5050 ],
5051 stop_reason: Some(StopReason::ToolUse),
5052 ..Default::default()
5053 },
5054 ChatResult {
5055 output_text: "Parallel done.".to_string(),
5056 tool_calls: vec![],
5057 stop_reason: Some(StopReason::EndTurn),
5058 ..Default::default()
5059 },
5060 ]);
5061
5062 let agent = Agent::new(
5063 AgentConfig {
5064 parallel_tools: true,
5065 ..Default::default()
5066 },
5067 Box::new(provider),
5068 Box::new(TestMemory::default()),
5069 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
5070 );
5071
5072 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
5073 let response = agent
5074 .respond_streaming(
5075 UserMessage {
5076 text: "parallel test".to_string(),
5077 },
5078 &test_ctx(),
5079 tx,
5080 )
5081 .await
5082 .expect("should succeed");
5083
5084 assert_eq!(response.text, "Parallel done.");
5085 }
5086
5087 #[tokio::test]
5088 async fn streaming_done_chunk_sentinel() {
5089 let provider = StreamingProvider::new(vec![ChatResult {
5090 output_text: "abc".to_string(),
5091 tool_calls: vec![],
5092 stop_reason: Some(StopReason::EndTurn),
5093 ..Default::default()
5094 }]);
5095
5096 let agent = Agent::new(
5097 AgentConfig::default(),
5098 Box::new(provider),
5099 Box::new(TestMemory::default()),
5100 vec![Box::new(StructuredEchoTool)],
5101 );
5102
5103 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
5104 let _response = agent
5105 .respond_streaming(
5106 UserMessage {
5107 text: "test".to_string(),
5108 },
5109 &test_ctx(),
5110 tx,
5111 )
5112 .await
5113 .expect("should succeed");
5114
5115 let mut chunks = Vec::new();
5116 while let Ok(chunk) = rx.try_recv() {
5117 chunks.push(chunk);
5118 }
5119
5120 let done_chunks: Vec<_> = chunks.iter().filter(|c| c.done).collect();
5122 assert!(
5123 !done_chunks.is_empty(),
5124 "must have at least one done sentinel chunk"
5125 );
5126 assert!(chunks.last().unwrap().done, "final chunk must be done=true");
5128 let content_chunks: Vec<_> = chunks
5130 .iter()
5131 .filter(|c| !c.done && !c.delta.is_empty())
5132 .collect();
5133 assert!(!content_chunks.is_empty(), "should have content chunks");
5134 }
5135
5136 #[tokio::test]
5139 async fn system_prompt_prepended_in_structured_path() {
5140 let provider = StructuredProvider::new(vec![ChatResult {
5141 output_text: "I understand.".to_string(),
5142 stop_reason: Some(StopReason::EndTurn),
5143 ..Default::default()
5144 }]);
5145 let received = provider.received_messages.clone();
5146 let memory = TestMemory::default();
5147
5148 let agent = Agent::new(
5149 AgentConfig {
5150 model_supports_tool_use: true,
5151 system_prompt: Some("You are a math tutor.".to_string()),
5152 ..AgentConfig::default()
5153 },
5154 Box::new(provider),
5155 Box::new(memory),
5156 vec![Box::new(StructuredEchoTool)],
5157 );
5158
5159 let result = agent
5160 .respond(
5161 UserMessage {
5162 text: "What is 2+2?".to_string(),
5163 },
5164 &test_ctx(),
5165 )
5166 .await
5167 .expect("respond should succeed");
5168
5169 assert_eq!(result.text, "I understand.");
5170
5171 let msgs = received.lock().expect("lock");
5172 assert!(!msgs.is_empty(), "provider should have been called");
5173 let first_call = &msgs[0];
5174 match &first_call[0] {
5176 ConversationMessage::System { content } => {
5177 assert_eq!(content, "You are a math tutor.");
5178 }
5179 other => panic!("expected System message first, got {other:?}"),
5180 }
5181 match &first_call[1] {
5183 ConversationMessage::User { content, .. } => {
5184 assert_eq!(content, "What is 2+2?");
5185 }
5186 other => panic!("expected User message second, got {other:?}"),
5187 }
5188 }
5189
5190 #[tokio::test]
5191 async fn no_system_prompt_omits_system_message() {
5192 let provider = StructuredProvider::new(vec![ChatResult {
5193 output_text: "ok".to_string(),
5194 stop_reason: Some(StopReason::EndTurn),
5195 ..Default::default()
5196 }]);
5197 let received = provider.received_messages.clone();
5198 let memory = TestMemory::default();
5199
5200 let agent = Agent::new(
5201 AgentConfig {
5202 model_supports_tool_use: true,
5203 system_prompt: None,
5204 ..AgentConfig::default()
5205 },
5206 Box::new(provider),
5207 Box::new(memory),
5208 vec![Box::new(StructuredEchoTool)],
5209 );
5210
5211 agent
5212 .respond(
5213 UserMessage {
5214 text: "hello".to_string(),
5215 },
5216 &test_ctx(),
5217 )
5218 .await
5219 .expect("respond should succeed");
5220
5221 let msgs = received.lock().expect("lock");
5222 let first_call = &msgs[0];
5223 match &first_call[0] {
5225 ConversationMessage::User { content, .. } => {
5226 assert_eq!(content, "hello");
5227 }
5228 other => panic!("expected User message first (no system prompt), got {other:?}"),
5229 }
5230 }
5231
5232 #[tokio::test]
5233 async fn system_prompt_persists_across_tool_iterations() {
5234 let provider = StructuredProvider::new(vec![
5236 ChatResult {
5237 output_text: String::new(),
5238 tool_calls: vec![ToolUseRequest {
5239 id: "call_1".to_string(),
5240 name: "echo".to_string(),
5241 input: serde_json::json!({"text": "ping"}),
5242 }],
5243 stop_reason: Some(StopReason::ToolUse),
5244 ..Default::default()
5245 },
5246 ChatResult {
5247 output_text: "pong".to_string(),
5248 stop_reason: Some(StopReason::EndTurn),
5249 ..Default::default()
5250 },
5251 ]);
5252 let received = provider.received_messages.clone();
5253 let memory = TestMemory::default();
5254
5255 let agent = Agent::new(
5256 AgentConfig {
5257 model_supports_tool_use: true,
5258 system_prompt: Some("Always be concise.".to_string()),
5259 ..AgentConfig::default()
5260 },
5261 Box::new(provider),
5262 Box::new(memory),
5263 vec![Box::new(StructuredEchoTool)],
5264 );
5265
5266 let result = agent
5267 .respond(
5268 UserMessage {
5269 text: "test".to_string(),
5270 },
5271 &test_ctx(),
5272 )
5273 .await
5274 .expect("respond should succeed");
5275 assert_eq!(result.text, "pong");
5276
5277 let msgs = received.lock().expect("lock");
5278 for (i, call_msgs) in msgs.iter().enumerate() {
5280 match &call_msgs[0] {
5281 ConversationMessage::System { content } => {
5282 assert_eq!(content, "Always be concise.", "call {i} system prompt");
5283 }
5284 other => panic!("call {i}: expected System first, got {other:?}"),
5285 }
5286 }
5287 }
5288
5289 #[test]
5290 fn agent_config_system_prompt_defaults_to_none() {
5291 let config = AgentConfig::default();
5292 assert!(config.system_prompt.is_none());
5293 }
5294
5295 #[tokio::test]
5296 async fn memory_write_propagates_source_channel_from_context() {
5297 let memory = TestMemory::default();
5298 let entries = memory.entries.clone();
5299 let provider = StructuredProvider::new(vec![ChatResult {
5300 output_text: "noted".to_string(),
5301 tool_calls: vec![],
5302 stop_reason: None,
5303 ..Default::default()
5304 }]);
5305 let config = AgentConfig {
5306 privacy_boundary: "encrypted_only".to_string(),
5307 ..AgentConfig::default()
5308 };
5309 let agent = Agent::new(config, Box::new(provider), Box::new(memory), vec![]);
5310
5311 let mut ctx = ToolContext::new(".".to_string());
5312 ctx.source_channel = Some("telegram".to_string());
5313 ctx.privacy_boundary = "encrypted_only".to_string();
5314
5315 agent
5316 .respond(
5317 UserMessage {
5318 text: "hello".to_string(),
5319 },
5320 &ctx,
5321 )
5322 .await
5323 .expect("respond should succeed");
5324
5325 let stored = entries.lock().expect("memory lock poisoned");
5326 assert_eq!(stored[0].source_channel.as_deref(), Some("telegram"));
5328 assert_eq!(stored[0].privacy_boundary, "encrypted_only");
5329 assert_eq!(stored[1].source_channel.as_deref(), Some("telegram"));
5330 assert_eq!(stored[1].privacy_boundary, "encrypted_only");
5331 }
5332
5333 #[tokio::test]
5334 async fn memory_write_source_channel_none_when_ctx_empty() {
5335 let memory = TestMemory::default();
5336 let entries = memory.entries.clone();
5337 let provider = StructuredProvider::new(vec![ChatResult {
5338 output_text: "ok".to_string(),
5339 tool_calls: vec![],
5340 stop_reason: None,
5341 ..Default::default()
5342 }]);
5343 let agent = Agent::new(
5344 AgentConfig::default(),
5345 Box::new(provider),
5346 Box::new(memory),
5347 vec![],
5348 );
5349
5350 agent
5351 .respond(
5352 UserMessage {
5353 text: "hi".to_string(),
5354 },
5355 &test_ctx(),
5356 )
5357 .await
5358 .expect("respond should succeed");
5359
5360 let stored = entries.lock().expect("memory lock poisoned");
5361 assert!(stored[0].source_channel.is_none());
5362 assert!(stored[1].source_channel.is_none());
5363 }
5364
5365 fn sample_tools() -> Vec<ToolDefinition> {
5370 vec![
5371 ToolDefinition {
5372 name: "web_search".to_string(),
5373 description: "Search the web".to_string(),
5374 input_schema: serde_json::json!({"type": "object"}),
5375 },
5376 ToolDefinition {
5377 name: "read_file".to_string(),
5378 description: "Read a file".to_string(),
5379 input_schema: serde_json::json!({"type": "object"}),
5380 },
5381 ]
5382 }
5383
5384 #[test]
5385 fn extract_tool_call_from_json_code_block() {
5386 let text = "I'll search for that.\n```json\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"AI regulation EU\"}}\n```";
5387 let tools = sample_tools();
5388 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5389 assert_eq!(result.name, "web_search");
5390 assert_eq!(result.input["query"], "AI regulation EU");
5391 }
5392
5393 #[test]
5394 fn extract_tool_call_from_bare_code_block() {
5395 let text = "```\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"test\"}}\n```";
5396 let tools = sample_tools();
5397 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5398 assert_eq!(result.name, "web_search");
5399 }
5400
5401 #[test]
5402 fn extract_tool_call_from_bare_json() {
5403 let text = "{\"name\": \"web_search\", \"arguments\": {\"query\": \"test\"}}";
5404 let tools = sample_tools();
5405 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5406 assert_eq!(result.name, "web_search");
5407 assert_eq!(result.input["query"], "test");
5408 }
5409
5410 #[test]
5411 fn extract_tool_call_with_parameters_key() {
5412 let text = "{\"name\": \"web_search\", \"parameters\": {\"query\": \"test\"}}";
5413 let tools = sample_tools();
5414 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5415 assert_eq!(result.name, "web_search");
5416 assert_eq!(result.input["query"], "test");
5417 }
5418
5419 #[test]
5420 fn extract_tool_call_ignores_unknown_tool() {
5421 let text = "{\"name\": \"unknown_tool\", \"arguments\": {}}";
5422 let tools = sample_tools();
5423 assert!(extract_tool_call_from_text(text, &tools).is_none());
5424 }
5425
5426 #[test]
5427 fn extract_tool_call_returns_none_for_plain_text() {
5428 let text = "I don't know how to help with that.";
5429 let tools = sample_tools();
5430 assert!(extract_tool_call_from_text(text, &tools).is_none());
5431 }
5432
5433 #[test]
5434 fn extract_tool_call_returns_none_for_non_tool_json() {
5435 let text = "{\"message\": \"hello\", \"status\": \"ok\"}";
5436 let tools = sample_tools();
5437 assert!(extract_tool_call_from_text(text, &tools).is_none());
5438 }
5439
5440 #[test]
5441 fn extract_tool_call_with_surrounding_text() {
5442 let text = "Sure, let me search for that.\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"AI regulation\"}}\nI'll get back to you.";
5443 let tools = sample_tools();
5444 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5445 assert_eq!(result.name, "web_search");
5446 }
5447
5448 #[test]
5449 fn extract_tool_call_no_arguments_field() {
5450 let text = "{\"name\": \"web_search\"}";
5451 let tools = sample_tools();
5452 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5453 assert_eq!(result.name, "web_search");
5454 assert!(result.input.is_object());
5455 }
5456
5457 #[test]
5458 fn extract_json_block_handles_nested_braces() {
5459 let text =
5460 "```json\n{\"name\": \"read_file\", \"arguments\": {\"path\": \"/tmp/{test}\"}}\n```";
5461 let tools = sample_tools();
5462 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5463 assert_eq!(result.name, "read_file");
5464 assert_eq!(result.input["path"], "/tmp/{test}");
5465 }
5466
5467 #[test]
5468 fn extract_bare_json_handles_escaped_quotes() {
5469 let text =
5470 "{\"name\": \"web_search\", \"arguments\": {\"query\": \"test \\\"quoted\\\" word\"}}";
5471 let tools = sample_tools();
5472 let result = extract_tool_call_from_text(text, &tools).expect("should extract");
5473 assert_eq!(result.name, "web_search");
5474 }
5475
5476 #[test]
5477 fn extract_tool_call_empty_string() {
5478 let tools = sample_tools();
5479 assert!(extract_tool_call_from_text("", &tools).is_none());
5480 }
5481
5482 struct TextToolCallProvider {
5489 call_count: AtomicUsize,
5490 }
5491
5492 impl TextToolCallProvider {
5493 fn new() -> Self {
5494 Self {
5495 call_count: AtomicUsize::new(0),
5496 }
5497 }
5498 }
5499
5500 #[async_trait]
5501 impl Provider for TextToolCallProvider {
5502 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
5503 Ok(ChatResult::default())
5504 }
5505
5506 async fn complete_with_tools(
5507 &self,
5508 messages: &[ConversationMessage],
5509 _tools: &[ToolDefinition],
5510 _reasoning: &ReasoningConfig,
5511 ) -> anyhow::Result<ChatResult> {
5512 let n = self.call_count.fetch_add(1, Ordering::Relaxed);
5513 if n == 0 {
5514 Ok(ChatResult {
5516 output_text: "```json\n{\"name\": \"echo\", \"arguments\": {\"message\": \"hello from text\"}}\n```".to_string(),
5517 tool_calls: vec![],
5518 stop_reason: Some(StopReason::EndTurn),
5519 ..Default::default()
5520 })
5521 } else {
5522 let tool_output = messages
5524 .iter()
5525 .rev()
5526 .find_map(|m| match m {
5527 ConversationMessage::ToolResult(r) => Some(r.content.as_str()),
5528 _ => None,
5529 })
5530 .unwrap_or("no tool result");
5531 Ok(ChatResult {
5532 output_text: format!("Got it: {tool_output}"),
5533 tool_calls: vec![],
5534 stop_reason: Some(StopReason::EndTurn),
5535 ..Default::default()
5536 })
5537 }
5538 }
5539 }
5540
5541 struct TestEchoTool;
5543
5544 #[async_trait]
5545 impl crate::types::Tool for TestEchoTool {
5546 fn name(&self) -> &'static str {
5547 "echo"
5548 }
5549 fn description(&self) -> &'static str {
5550 "Echo back the message"
5551 }
5552 fn input_schema(&self) -> Option<serde_json::Value> {
5553 Some(serde_json::json!({
5554 "type": "object",
5555 "required": ["message"],
5556 "properties": {
5557 "message": {"type": "string"}
5558 }
5559 }))
5560 }
5561 async fn execute(
5562 &self,
5563 input: &str,
5564 _ctx: &ToolContext,
5565 ) -> anyhow::Result<crate::types::ToolResult> {
5566 let v: serde_json::Value =
5567 serde_json::from_str(input).unwrap_or(Value::String(input.to_string()));
5568 let msg = v.get("message").and_then(|m| m.as_str()).unwrap_or(input);
5569 Ok(crate::types::ToolResult {
5570 output: format!("echoed:{msg}"),
5571 })
5572 }
5573 }
5574
5575 #[tokio::test]
5576 async fn text_tool_extraction_dispatches_through_agent_loop() {
5577 let agent = Agent::new(
5578 AgentConfig {
5579 model_supports_tool_use: true,
5580 max_tool_iterations: 5,
5581 request_timeout_ms: 10_000,
5582 ..AgentConfig::default()
5583 },
5584 Box::new(TextToolCallProvider::new()),
5585 Box::new(TestMemory::default()),
5586 vec![Box::new(TestEchoTool)],
5587 );
5588
5589 let response = agent
5590 .respond(
5591 UserMessage {
5592 text: "echo hello".to_string(),
5593 },
5594 &test_ctx(),
5595 )
5596 .await
5597 .expect("should succeed with text-extracted tool call");
5598
5599 assert!(
5600 response.text.contains("echoed:hello from text"),
5601 "expected tool result in response, got: {}",
5602 response.text
5603 );
5604 }
5605
5606 #[tokio::test]
5607 async fn tool_execution_timeout_fires() {
5608 let memory = TestMemory::default();
5609 let agent = Agent::new(
5610 AgentConfig {
5611 max_tool_iterations: 3,
5612 tool_timeout_ms: 50, ..AgentConfig::default()
5614 },
5615 Box::new(ScriptedProvider::new(vec!["done"])),
5616 Box::new(memory),
5617 vec![Box::new(SlowTool)],
5618 );
5619
5620 let result = agent
5621 .respond(
5622 UserMessage {
5623 text: "tool:slow go".to_string(),
5624 },
5625 &test_ctx(),
5626 )
5627 .await;
5628
5629 let err = result.expect_err("should time out");
5630 let msg = format!("{err}");
5631 assert!(
5632 msg.contains("timed out"),
5633 "expected timeout error, got: {msg}"
5634 );
5635 }
5636
5637 #[tokio::test]
5638 async fn tool_execution_no_timeout_when_disabled() {
5639 let memory = TestMemory::default();
5640 let agent = Agent::new(
5641 AgentConfig {
5642 max_tool_iterations: 3,
5643 tool_timeout_ms: 0, ..AgentConfig::default()
5645 },
5646 Box::new(ScriptedProvider::new(vec!["done"])),
5647 Box::new(memory),
5648 vec![Box::new(EchoTool)],
5649 );
5650
5651 let response = agent
5652 .respond(
5653 UserMessage {
5654 text: "tool:echo hello".to_string(),
5655 },
5656 &test_ctx(),
5657 )
5658 .await
5659 .expect("should succeed with timeout disabled");
5660
5661 assert_eq!(response.text, "done");
5664 }
5665
5666 #[tokio::test]
5669 async fn tool_execute_span_is_created_for_tool_call() {
5670 let subscriber = tracing_subscriber::fmt()
5673 .with_max_level(tracing::Level::TRACE)
5674 .with_writer(std::io::sink)
5675 .finish();
5676 let _guard = tracing::subscriber::set_default(subscriber);
5677
5678 let agent = Agent::new(
5679 AgentConfig {
5680 tool_timeout_ms: 5000,
5681 ..AgentConfig::default()
5682 },
5683 Box::new(ScriptedProvider::new(vec!["done"])),
5684 Box::new(TestMemory::default()),
5685 vec![Box::new(EchoTool)],
5686 );
5687
5688 let response = agent
5690 .respond(
5691 UserMessage {
5692 text: "tool:echo test-span".to_string(),
5693 },
5694 &test_ctx(),
5695 )
5696 .await
5697 .expect("should succeed");
5698
5699 assert_eq!(response.text, "done");
5702 }
5703
5704 #[tokio::test]
5705 async fn tool_execute_span_works_without_timeout() {
5706 let agent = Agent::new(
5708 AgentConfig {
5709 tool_timeout_ms: 0, ..AgentConfig::default()
5711 },
5712 Box::new(ScriptedProvider::new(vec!["no-timeout-done"])),
5713 Box::new(TestMemory::default()),
5714 vec![Box::new(EchoTool)],
5715 );
5716
5717 let response = agent
5718 .respond(
5719 UserMessage {
5720 text: "tool:echo span-test".to_string(),
5721 },
5722 &test_ctx(),
5723 )
5724 .await
5725 .expect("should succeed without timeout");
5726
5727 assert_eq!(response.text, "no-timeout-done");
5728 }
5729}