1use crate::common::privacy_helpers::{is_network_tool, resolve_boundary};
2use crate::loop_detection::{LoopDetectionConfig, ToolLoopDetector};
3use crate::security::redaction::redact_text;
4use crate::types::{
5 AgentConfig, AgentError, AssistantMessage, AuditEvent, AuditSink, ConversationMessage,
6 HookEvent, HookFailureMode, HookRiskTier, HookSink, LoopAction, MemoryEntry, MemoryStore,
7 MetricsSink, Provider, ResearchTrigger, StopReason, StreamSink, Tool, ToolContext,
8 ToolDefinition, ToolResultMessage, UserMessage,
9};
10use crate::validation::validate_json;
11use serde_json::json;
12use std::sync::atomic::{AtomicU64, Ordering};
13use std::time::{Instant, SystemTime, UNIX_EPOCH};
14use tokio::time::{timeout, Duration};
15use tracing::{info, info_span, warn, Instrument};
16
// Monotonic per-process counter used to keep request ids unique even when
// several requests start within the same millisecond (see `Agent::next_request_id`).
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(0);
18
/// Caps `input` to at most `max_chars` Unicode scalar values, keeping the
/// *tail* of the string (the most recent content) when truncation is needed.
///
/// Returns the (possibly truncated) string and a flag indicating whether
/// truncation occurred. Counting is by `char`, not bytes, so multi-byte
/// UTF-8 input is never split mid-character.
fn cap_prompt(input: &str, max_chars: usize) -> (String, bool) {
    let len = input.chars().count();
    if len <= max_chars {
        return (input.to_string(), false);
    }
    // Keep the last `max_chars` chars. Skipping the head directly avoids the
    // double-reverse and intermediate Vec of the previous implementation.
    let truncated: String = input.chars().skip(len - max_chars).collect();
    (truncated, true)
}
34
35fn build_provider_prompt(
36 current_prompt: &str,
37 recent_memory: &[MemoryEntry],
38 max_prompt_chars: usize,
39) -> (String, bool) {
40 if recent_memory.is_empty() {
41 return cap_prompt(current_prompt, max_prompt_chars);
42 }
43
44 let mut prompt = String::from("Recent memory:\n");
45 for entry in recent_memory.iter().rev() {
46 prompt.push_str("- ");
47 prompt.push_str(&entry.role);
48 prompt.push_str(": ");
49 prompt.push_str(&entry.content);
50 prompt.push('\n');
51 }
52 prompt.push_str("\nCurrent input:\n");
53 prompt.push_str(current_prompt);
54
55 cap_prompt(&prompt, max_prompt_chars)
56}
57
/// Extracts `tool:` directives from a provider response.
///
/// Each line of the form `tool:<name> <input>` yields one `(name, input)`
/// pair; `input` is everything after the first space of the trimmed
/// remainder and may be empty. Non-matching lines are skipped.
fn parse_tool_calls(prompt: &str) -> Vec<(&str, &str)> {
    let mut calls = Vec::new();
    for line in prompt.lines() {
        if let Some(rest) = line.strip_prefix("tool:") {
            let rest = rest.trim();
            let (name, input) = match rest.split_once(' ') {
                Some(pair) => pair,
                None => (rest, ""),
            };
            calls.push((name, input));
        }
    }
    calls
}
73
74fn single_required_string_field(schema: &serde_json::Value) -> Option<String> {
76 let required = schema.get("required")?.as_array()?;
77 if required.len() != 1 {
78 return None;
79 }
80 let field_name = required[0].as_str()?;
81 let properties = schema.get("properties")?.as_object()?;
82 let field_schema = properties.get(field_name)?;
83 if field_schema.get("type")?.as_str()? == "string" {
84 Some(field_name.to_string())
85 } else {
86 None
87 }
88}
89
90fn prepare_tool_input(tool: &dyn Tool, raw_input: &serde_json::Value) -> Result<String, String> {
99 if let Some(s) = raw_input.as_str() {
101 return Ok(s.to_string());
102 }
103
104 if let Some(schema) = tool.input_schema() {
106 if let Err(errors) = validate_json(raw_input, &schema) {
107 return Err(format!(
108 "Invalid input for tool '{}': {}",
109 tool.name(),
110 errors.join("; ")
111 ));
112 }
113
114 if let Some(field) = single_required_string_field(&schema) {
115 if let Some(val) = raw_input.get(&field).and_then(|v| v.as_str()) {
116 return Ok(val.to_string());
117 }
118 }
119 }
120
121 Ok(serde_json::to_string(raw_input).unwrap_or_default())
122}
123
/// Shrinks `messages` so the total `char_count` fits within `max_chars`,
/// always preserving the first message (typically the system prompt) and as
/// many of the most recent messages as the remaining budget allows.
///
/// Middle messages are removed; the relative order of survivors is kept.
/// Lists of two or fewer messages, or lists already within budget, are left
/// untouched.
fn truncate_messages(messages: &mut Vec<ConversationMessage>, max_chars: usize) {
    let total_chars: usize = messages.iter().map(|m| m.char_count()).sum();
    if total_chars <= max_chars || messages.len() <= 2 {
        return;
    }

    // Reserve room for the first message up front; saturating_sub covers an
    // oversized first message (budget collapses to 0).
    let first_cost = messages[0].char_count();
    let mut budget = max_chars.saturating_sub(first_cost);

    // Walk backwards, keeping the longest suffix that fits in the budget.
    let mut keep_from_end = 0;
    for msg in messages.iter().rev() {
        let cost = msg.char_count();
        if cost > budget {
            break;
        }
        budget -= cost;
        keep_from_end += 1;
    }

    // Nothing from the tail fits: degrade to first + last message only.
    if keep_from_end == 0 {
        let last = messages
            .pop()
            .expect("messages must have at least 2 entries");
        messages.truncate(1);
        messages.push(last);
        return;
    }

    // Drop everything between the first message and the kept suffix.
    // (split_point <= 1 means there is no middle section to remove.)
    let split_point = messages.len() - keep_from_end;
    if split_point <= 1 {
        return;
    }

    messages.drain(1..split_point);
}
164
165fn memory_to_messages(entries: &[MemoryEntry]) -> Vec<ConversationMessage> {
168 entries
169 .iter()
170 .rev()
171 .map(|entry| {
172 if entry.role == "assistant" {
173 ConversationMessage::Assistant {
174 content: Some(entry.content.clone()),
175 tool_calls: vec![],
176 }
177 } else {
178 ConversationMessage::user(entry.content.clone())
179 }
180 })
181 .collect()
182}
183
/// Orchestrates a single agent run: provider calls, memory access, tool
/// execution, and the optional audit / hook / metrics / loop-detection
/// plumbing around them.
pub struct Agent {
    config: AgentConfig,
    provider: Box<dyn Provider>,
    memory: Box<dyn MemoryStore>,
    tools: Vec<Box<dyn Tool>>,
    // Optional sinks; `None` disables the corresponding feature.
    audit: Option<Box<dyn AuditSink>>,
    hooks: Option<Box<dyn HookSink>>,
    metrics: Option<Box<dyn MetricsSink>>,
    // When set, a `ToolLoopDetector` is instantiated per run.
    loop_detection_config: Option<LoopDetectionConfig>,
}
194
195impl Agent {
    /// Creates an agent from the required collaborators; audit, hooks,
    /// metrics, and loop detection are disabled until enabled via the
    /// `with_*` builder methods.
    pub fn new(
        config: AgentConfig,
        provider: Box<dyn Provider>,
        memory: Box<dyn MemoryStore>,
        tools: Vec<Box<dyn Tool>>,
    ) -> Self {
        Self {
            config,
            provider,
            memory,
            tools,
            audit: None,
            hooks: None,
            metrics: None,
            loop_detection_config: None,
        }
    }
213
    /// Enables tiered tool-loop detection for subsequent runs (builder style).
    pub fn with_loop_detection(mut self, config: LoopDetectionConfig) -> Self {
        self.loop_detection_config = Some(config);
        self
    }
219
    /// Attaches an audit sink; audit events are best-effort (see `audit`).
    pub fn with_audit(mut self, audit: Box<dyn AuditSink>) -> Self {
        self.audit = Some(audit);
        self
    }
224
    /// Attaches a hook sink; hooks also require `config.hooks.enabled`.
    pub fn with_hooks(mut self, hooks: Box<dyn HookSink>) -> Self {
        self.hooks = Some(hooks);
        self
    }
229
    /// Attaches a metrics sink for counters and latency histograms.
    pub fn with_metrics(mut self, metrics: Box<dyn MetricsSink>) -> Self {
        self.metrics = Some(metrics);
        self
    }
234
    /// Registers an additional tool after construction.
    pub fn add_tool(&mut self, tool: Box<dyn Tool>) {
        self.tools.push(tool);
    }
239
240 fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
242 self.tools
243 .iter()
244 .filter_map(|tool| ToolDefinition::from_tool(&**tool))
245 .collect()
246 }
247
    /// True when at least one registered tool publishes an input schema.
    // NOTE(review): this approximates "build_tool_definitions() is non-empty";
    // presumably `ToolDefinition::from_tool` succeeds exactly when a schema
    // exists — confirm the two stay in sync.
    fn has_tool_definitions(&self) -> bool {
        self.tools.iter().any(|t| t.input_schema().is_some())
    }
252
    /// Emits an audit event if an audit sink is configured.
    ///
    /// Sink errors are deliberately discarded (`let _`): auditing is
    /// best-effort and must never fail the run.
    async fn audit(&self, stage: &str, detail: serde_json::Value) {
        if let Some(sink) = &self.audit {
            let _ = sink
                .record(AuditEvent {
                    stage: stage.to_string(),
                    detail,
                })
                .await;
        }
    }
263
264 fn next_request_id() -> String {
265 let ts_ms = SystemTime::now()
266 .duration_since(UNIX_EPOCH)
267 .unwrap_or_default()
268 .as_millis();
269 let seq = REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed);
270 format!("req-{ts_ms}-{seq}")
271 }
272
273 fn hook_risk_tier(stage: &str) -> HookRiskTier {
274 if matches!(
275 stage,
276 "before_tool_call" | "after_tool_call" | "before_plugin_call" | "after_plugin_call"
277 ) {
278 HookRiskTier::High
279 } else if matches!(
280 stage,
281 "before_provider_call"
282 | "after_provider_call"
283 | "before_memory_write"
284 | "after_memory_write"
285 | "before_run"
286 | "after_run"
287 ) {
288 HookRiskTier::Medium
289 } else {
290 HookRiskTier::Low
291 }
292 }
293
    /// Resolves the failure mode applied when a hook for `stage` errors or
    /// times out.
    ///
    /// `fail_closed` overrides everything to `Block`; otherwise the mode is
    /// chosen per risk tier from the hook configuration.
    fn hook_failure_mode_for_stage(&self, stage: &str) -> HookFailureMode {
        if self.config.hooks.fail_closed {
            return HookFailureMode::Block;
        }

        match Self::hook_risk_tier(stage) {
            HookRiskTier::Low => self.config.hooks.low_tier_mode,
            HookRiskTier::Medium => self.config.hooks.medium_tier_mode,
            HookRiskTier::High => self.config.hooks.high_tier_mode,
        }
    }
305
    /// Dispatches a hook event and applies the configured failure policy.
    ///
    /// No-op when hooks are disabled or no sink is attached. The sink call
    /// is bounded by `hooks.timeout_ms`; on error or timeout the stage's
    /// failure mode decides the outcome: Block propagates `AgentError::Hook`,
    /// Warn logs + audits and continues, Ignore audits only and continues.
    /// Error text is redacted before it reaches logs or audit.
    async fn hook(&self, stage: &str, detail: serde_json::Value) -> Result<(), AgentError> {
        if !self.config.hooks.enabled {
            return Ok(());
        }
        let Some(sink) = &self.hooks else {
            return Ok(());
        };

        let event = HookEvent {
            stage: stage.to_string(),
            detail,
        };
        let hook_call = sink.record(event);
        let mode = self.hook_failure_mode_for_stage(stage);
        let tier = Self::hook_risk_tier(stage);
        match timeout(
            Duration::from_millis(self.config.hooks.timeout_ms),
            hook_call,
        )
        .await
        {
            // Sink completed successfully within the timeout.
            Ok(Ok(())) => Ok(()),
            // Sink returned an error: apply the failure mode.
            Ok(Err(err)) => {
                // Redact before the message can reach logs/audit.
                let redacted = redact_text(&err.to_string());
                match mode {
                    HookFailureMode::Block => Err(AgentError::Hook {
                        stage: stage.to_string(),
                        source: err,
                    }),
                    HookFailureMode::Warn => {
                        warn!(
                            stage = stage,
                            tier = ?tier,
                            mode = "warn",
                            "hook error (continuing): {redacted}"
                        );
                        self.audit(
                            "hook_error_warn",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                    HookFailureMode::Ignore => {
                        self.audit(
                            "hook_error_ignored",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                }
            }
            // Sink did not finish in time: same policy, timeout flavor.
            Err(_) => match mode {
                HookFailureMode::Block => Err(AgentError::Hook {
                    stage: stage.to_string(),
                    source: anyhow::anyhow!(
                        "hook execution timed out after {} ms",
                        self.config.hooks.timeout_ms
                    ),
                }),
                HookFailureMode::Warn => {
                    warn!(
                        stage = stage,
                        tier = ?tier,
                        mode = "warn",
                        timeout_ms = self.config.hooks.timeout_ms,
                        "hook timeout (continuing)"
                    );
                    self.audit(
                        "hook_timeout_warn",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
                HookFailureMode::Ignore => {
                    self.audit(
                        "hook_timeout_ignored",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
            },
        }
    }
393
    /// Adds 1 to a counter metric when a metrics sink is configured;
    /// no-op otherwise.
    fn increment_counter(&self, name: &'static str) {
        if let Some(metrics) = &self.metrics {
            metrics.increment_counter(name, 1);
        }
    }
399
    /// Records one observation in a histogram metric when a metrics sink is
    /// configured; no-op otherwise.
    fn observe_histogram(&self, name: &'static str, value: f64) {
        if let Some(metrics) = &self.metrics {
            metrics.observe_histogram(name, value);
        }
    }
405
406 async fn execute_tool(
407 &self,
408 tool: &dyn Tool,
409 tool_name: &str,
410 tool_input: &str,
411 ctx: &ToolContext,
412 request_id: &str,
413 iteration: usize,
414 ) -> Result<crate::types::ToolResult, AgentError> {
415 let tool_specific = self
418 .config
419 .tool_boundaries
420 .get(tool_name)
421 .map(|s| s.as_str())
422 .unwrap_or("");
423 let resolved = resolve_boundary(tool_specific, &self.config.privacy_boundary);
424 if resolved == "local_only" && is_network_tool(tool_name) {
425 return Err(AgentError::Tool {
426 tool: tool_name.to_string(),
427 source: anyhow::anyhow!(
428 "tool '{}' requires network access but privacy boundary is 'local_only'",
429 tool_name
430 ),
431 });
432 }
433
434 let is_plugin_call = tool_name.starts_with("plugin:");
435 self.hook(
436 "before_tool_call",
437 json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
438 )
439 .await?;
440 if is_plugin_call {
441 self.hook(
442 "before_plugin_call",
443 json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name}),
444 )
445 .await?;
446 }
447 self.audit(
448 "tool_execute_start",
449 json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
450 )
451 .await;
452 let tool_span = info_span!(
453 "tool_call",
454 tool = %tool_name,
455 request_id = %request_id,
456 iteration = iteration,
457 );
458 let _tool_guard = tool_span.enter();
459 let tool_started = Instant::now();
460 let result = match tool.execute(tool_input, ctx).await {
461 Ok(result) => result,
462 Err(source) => {
463 self.observe_histogram(
464 "tool_latency_ms",
465 tool_started.elapsed().as_secs_f64() * 1000.0,
466 );
467 self.increment_counter("tool_errors_total");
468 return Err(AgentError::Tool {
469 tool: tool_name.to_string(),
470 source,
471 });
472 }
473 };
474 self.observe_histogram(
475 "tool_latency_ms",
476 tool_started.elapsed().as_secs_f64() * 1000.0,
477 );
478 self.audit(
479 "tool_execute_success",
480 json!({
481 "request_id": request_id,
482 "iteration": iteration,
483 "tool_name": tool_name,
484 "tool_output_len": result.output.len(),
485 "duration_ms": tool_started.elapsed().as_millis(),
486 }),
487 )
488 .await;
489 info!(
490 request_id = %request_id,
491 stage = "tool",
492 tool_name = %tool_name,
493 duration_ms = %tool_started.elapsed().as_millis(),
494 "tool execution finished"
495 );
496 self.hook(
497 "after_tool_call",
498 json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name, "status": "ok"}),
499 )
500 .await?;
501 if is_plugin_call {
502 self.hook(
503 "after_plugin_call",
504 json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name, "status": "ok"}),
505 )
506 .await?;
507 }
508 Ok(result)
509 }
510
    /// Single (non-tool) provider call: loads recent memory for the privacy
    /// boundary, renders the prompt, and returns the provider's output text.
    ///
    /// Streams through `stream_sink` when one is given, otherwise uses the
    /// reasoning-capable completion path. Hook/audit/metrics events are
    /// emitted around the call; `before_provider_call` /
    /// `after_provider_call` hook failures can abort per the configured
    /// failure mode.
    async fn call_provider_with_context(
        &self,
        prompt: &str,
        request_id: &str,
        stream_sink: Option<StreamSink>,
        source_channel: Option<&str>,
    ) -> Result<String, AgentError> {
        let recent_memory = self
            .memory
            .recent_for_boundary(
                self.config.memory_window_size,
                &self.config.privacy_boundary,
                source_channel,
            )
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.audit(
            "memory_recent_loaded",
            json!({"request_id": request_id, "items": recent_memory.len()}),
        )
        .await;
        let (provider_prompt, prompt_truncated) =
            build_provider_prompt(prompt, &recent_memory, self.config.max_prompt_chars);
        // Truncation is audited so oversized prompts can be traced later.
        if prompt_truncated {
            self.audit(
                "provider_prompt_truncated",
                json!({
                    "request_id": request_id,
                    "max_prompt_chars": self.config.max_prompt_chars,
                }),
            )
            .await;
        }
        self.hook(
            "before_provider_call",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await?;
        self.audit(
            "provider_call_start",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await;
        let provider_started = Instant::now();
        // Streaming and non-streaming take different provider entry points.
        let provider_result = if let Some(sink) = stream_sink {
            self.provider
                .complete_streaming(&provider_prompt, sink)
                .await
        } else {
            self.provider
                .complete_with_reasoning(&provider_prompt, &self.config.reasoning)
                .await
        };
        let completion = match provider_result {
            Ok(result) => result,
            Err(source) => {
                // Record latency for failures too, so the histogram is not
                // biased toward successful calls.
                self.observe_histogram(
                    "provider_latency_ms",
                    provider_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("provider_errors_total");
                return Err(AgentError::Provider { source });
            }
        };
        self.observe_histogram(
            "provider_latency_ms",
            provider_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "provider_call_success",
            json!({
                "request_id": request_id,
                "response_len": completion.output_text.len(),
                "duration_ms": provider_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "provider",
            duration_ms = %provider_started.elapsed().as_millis(),
            "provider call finished"
        );
        self.hook(
            "after_provider_call",
            json!({"request_id": request_id, "response_len": completion.output_text.len(), "status": "ok"}),
        )
        .await?;
        Ok(completion.output_text)
    }
611
    /// Appends one entry to the memory store, wrapped in
    /// `before_memory_write` / `after_memory_write` hooks and an audit event.
    ///
    /// The entry inherits the agent's privacy boundary; `created_at` and
    /// `expires_at` are left as `None` for the store to fill in if it
    /// chooses.
    async fn write_to_memory(
        &self,
        role: &str,
        content: &str,
        request_id: &str,
        source_channel: Option<&str>,
        conversation_id: &str,
    ) -> Result<(), AgentError> {
        self.hook(
            "before_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        self.memory
            .append(MemoryEntry {
                role: role.to_string(),
                content: content.to_string(),
                privacy_boundary: self.config.privacy_boundary.clone(),
                source_channel: source_channel.map(String::from),
                conversation_id: conversation_id.to_string(),
                created_at: None,
                expires_at: None,
            })
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.hook(
            "after_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        // Audit stage name embeds the role, e.g. `memory_append_assistant`.
        self.audit(
            &format!("memory_append_{role}"),
            json!({"request_id": request_id}),
        )
        .await;
        Ok(())
    }
649
650 fn should_research(&self, user_text: &str) -> bool {
651 if !self.config.research.enabled {
652 return false;
653 }
654 match self.config.research.trigger {
655 ResearchTrigger::Never => false,
656 ResearchTrigger::Always => true,
657 ResearchTrigger::Keywords => {
658 let lower = user_text.to_lowercase();
659 self.config
660 .research
661 .keywords
662 .iter()
663 .any(|kw| lower.contains(&kw.to_lowercase()))
664 }
665 ResearchTrigger::Length => user_text.len() >= self.config.research.min_message_length,
666 ResearchTrigger::Question => user_text.trim_end().ends_with('?'),
667 }
668 }
669
    /// Optional pre-answer research loop: puts the provider in RESEARCH
    /// mode and executes the `tool:` directives it emits, for at most
    /// `research.max_iterations` rounds, feeding each tool result back in.
    ///
    /// Returns all findings joined with newlines. A provider response that
    /// does not start with `tool:` is treated as the final summary.
    // NOTE(review): only the FIRST parsed call of each response is executed —
    // additional `tool:` lines in the same response are ignored; and a
    // response whose first line is not `tool:` ends the loop even if later
    // lines contain directives. Confirm both are intended.
    async fn run_research_phase(
        &self,
        user_text: &str,
        ctx: &ToolContext,
        request_id: &str,
    ) -> Result<String, AgentError> {
        self.audit(
            "research_phase_start",
            json!({
                "request_id": request_id,
                "max_iterations": self.config.research.max_iterations,
            }),
        )
        .await;
        self.increment_counter("research_phase_started");

        let research_prompt = format!(
            "You are in RESEARCH mode. The user asked: \"{user_text}\"\n\
             Gather relevant information using available tools. \
             Respond with tool: calls to collect data. \
             When done gathering, summarize your findings without a tool: prefix."
        );
        let mut prompt = self
            .call_provider_with_context(
                &research_prompt,
                request_id,
                None,
                ctx.source_channel.as_deref(),
            )
            .await?;
        let mut findings: Vec<String> = Vec::new();

        for iteration in 0..self.config.research.max_iterations {
            // A non-tool response is the model's summary: record and stop.
            if !prompt.starts_with("tool:") {
                findings.push(prompt.clone());
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            let (name, input) = calls[0];
            if let Some(tool) = self.tools.iter().find(|t| t.name() == name) {
                let result = self
                    .execute_tool(&**tool, name, input, ctx, request_id, iteration)
                    .await?;
                findings.push(format!("{name}: {}", result.output));

                if self.config.research.show_progress {
                    info!(iteration, tool = name, "research phase: tool executed");
                    self.audit(
                        "research_phase_iteration",
                        json!({
                            "request_id": request_id,
                            "iteration": iteration,
                            "tool_name": name,
                        }),
                    )
                    .await;
                }

                // Feed the tool output back so the model can continue or
                // decide to summarize.
                let next_prompt = format!(
                    "Research iteration {iteration}: tool `{name}` returned: {}\n\
                     Continue researching or summarize findings.",
                    result.output
                );
                prompt = self
                    .call_provider_with_context(
                        &next_prompt,
                        request_id,
                        None,
                        ctx.source_channel.as_deref(),
                    )
                    .await?;
            } else {
                // Unknown tool name: stop researching rather than error out.
                break;
            }
        }

        self.audit(
            "research_phase_complete",
            json!({
                "request_id": request_id,
                "findings_count": findings.len(),
            }),
        )
        .await;
        self.increment_counter("research_phase_completed");

        Ok(findings.join("\n"))
    }
763
764 async fn respond_with_tools(
767 &self,
768 request_id: &str,
769 user_text: &str,
770 research_context: &str,
771 ctx: &ToolContext,
772 stream_sink: Option<StreamSink>,
773 ) -> Result<AssistantMessage, AgentError> {
774 self.audit(
775 "structured_tool_use_start",
776 json!({
777 "request_id": request_id,
778 "max_tool_iterations": self.config.max_tool_iterations,
779 }),
780 )
781 .await;
782
783 let tool_definitions = self.build_tool_definitions();
784
785 let recent_memory = if let Some(ref cid) = ctx.conversation_id {
787 self.memory
788 .recent_for_conversation(cid, self.config.memory_window_size)
789 .await
790 .map_err(|source| AgentError::Memory { source })?
791 } else {
792 self.memory
793 .recent_for_boundary(
794 self.config.memory_window_size,
795 &self.config.privacy_boundary,
796 ctx.source_channel.as_deref(),
797 )
798 .await
799 .map_err(|source| AgentError::Memory { source })?
800 };
801 self.audit(
802 "memory_recent_loaded",
803 json!({"request_id": request_id, "items": recent_memory.len()}),
804 )
805 .await;
806
807 let mut messages: Vec<ConversationMessage> = Vec::new();
808
809 if let Some(ref sp) = self.config.system_prompt {
811 messages.push(ConversationMessage::System {
812 content: sp.clone(),
813 });
814 }
815
816 messages.extend(memory_to_messages(&recent_memory));
817
818 if !research_context.is_empty() {
819 messages.push(ConversationMessage::user(format!(
820 "Research findings:\n{research_context}",
821 )));
822 }
823
824 messages.push(ConversationMessage::user(user_text.to_string()));
825
826 let mut tool_history: Vec<(String, String, String)> = Vec::new();
827 let mut failure_streak: usize = 0;
828 let mut loop_detector = self
829 .loop_detection_config
830 .as_ref()
831 .map(|cfg| ToolLoopDetector::new(cfg.clone()));
832 let mut restricted_tools: Vec<String> = Vec::new();
833
834 for iteration in 0..self.config.max_tool_iterations {
835 if ctx.is_cancelled() {
837 warn!(request_id = %request_id, "agent execution cancelled");
838 self.audit(
839 "execution_cancelled",
840 json!({"request_id": request_id, "iteration": iteration}),
841 )
842 .await;
843 return Ok(AssistantMessage {
844 text: "[Execution cancelled]".to_string(),
845 });
846 }
847
848 truncate_messages(&mut messages, self.config.max_prompt_chars);
849
850 self.hook(
851 "before_provider_call",
852 json!({
853 "request_id": request_id,
854 "iteration": iteration,
855 "message_count": messages.len(),
856 "tool_count": tool_definitions.len(),
857 }),
858 )
859 .await?;
860 self.audit(
861 "provider_call_start",
862 json!({
863 "request_id": request_id,
864 "iteration": iteration,
865 "message_count": messages.len(),
866 }),
867 )
868 .await;
869
870 let effective_tools: Vec<ToolDefinition> = if restricted_tools.is_empty() {
872 tool_definitions.clone()
873 } else {
874 tool_definitions
875 .iter()
876 .filter(|td| !restricted_tools.contains(&td.name))
877 .cloned()
878 .collect()
879 };
880
881 let provider_span = info_span!(
882 "provider_call",
883 request_id = %request_id,
884 iteration = iteration,
885 tool_count = effective_tools.len(),
886 );
887 let _provider_guard = provider_span.enter();
888 let provider_started = Instant::now();
889 let provider_result = if let Some(ref sink) = stream_sink {
890 self.provider
891 .complete_streaming_with_tools(
892 &messages,
893 &effective_tools,
894 &self.config.reasoning,
895 sink.clone(),
896 )
897 .await
898 } else {
899 self.provider
900 .complete_with_tools(&messages, &effective_tools, &self.config.reasoning)
901 .await
902 };
903 let chat_result = match provider_result {
904 Ok(result) => result,
905 Err(source) => {
906 self.observe_histogram(
907 "provider_latency_ms",
908 provider_started.elapsed().as_secs_f64() * 1000.0,
909 );
910 self.increment_counter("provider_errors_total");
911 return Err(AgentError::Provider { source });
912 }
913 };
914 self.observe_histogram(
915 "provider_latency_ms",
916 provider_started.elapsed().as_secs_f64() * 1000.0,
917 );
918
919 let iter_tokens = chat_result.input_tokens + chat_result.output_tokens;
921 if iter_tokens > 0 {
922 ctx.add_tokens(iter_tokens);
923 }
924
925 if let Some(reason) = ctx.budget_exceeded() {
927 warn!(
928 request_id = %request_id,
929 iteration = iteration,
930 reason = %reason,
931 "budget exceeded — force-completing run"
932 );
933 return Err(AgentError::BudgetExceeded { reason });
934 }
935
936 info!(
937 request_id = %request_id,
938 iteration = iteration,
939 stop_reason = ?chat_result.stop_reason,
940 tool_calls = chat_result.tool_calls.len(),
941 tokens_this_call = iter_tokens,
942 total_tokens = ctx.current_tokens(),
943 "structured provider call finished"
944 );
945 self.hook(
946 "after_provider_call",
947 json!({
948 "request_id": request_id,
949 "iteration": iteration,
950 "response_len": chat_result.output_text.len(),
951 "tool_calls": chat_result.tool_calls.len(),
952 "status": "ok",
953 }),
954 )
955 .await?;
956
957 if chat_result.tool_calls.is_empty()
959 || chat_result.stop_reason == Some(StopReason::EndTurn)
960 {
961 let response_text = chat_result.output_text;
962 self.write_to_memory(
963 "assistant",
964 &response_text,
965 request_id,
966 ctx.source_channel.as_deref(),
967 ctx.conversation_id.as_deref().unwrap_or(""),
968 )
969 .await?;
970 self.audit("respond_success", json!({"request_id": request_id}))
971 .await;
972 self.hook(
973 "before_response_emit",
974 json!({"request_id": request_id, "response_len": response_text.len()}),
975 )
976 .await?;
977 let response = AssistantMessage {
978 text: response_text,
979 };
980 self.hook(
981 "after_response_emit",
982 json!({"request_id": request_id, "response_len": response.text.len()}),
983 )
984 .await?;
985 return Ok(response);
986 }
987
988 messages.push(ConversationMessage::Assistant {
990 content: if chat_result.output_text.is_empty() {
991 None
992 } else {
993 Some(chat_result.output_text.clone())
994 },
995 tool_calls: chat_result.tool_calls.clone(),
996 });
997
998 let tool_calls = &chat_result.tool_calls;
1000 let has_gated = tool_calls
1001 .iter()
1002 .any(|tc| self.config.gated_tools.contains(&tc.name));
1003 let use_parallel = self.config.parallel_tools && tool_calls.len() > 1 && !has_gated;
1004
1005 let mut tool_results: Vec<ToolResultMessage> = Vec::new();
1006
1007 if use_parallel {
1008 let futs: Vec<_> = tool_calls
1009 .iter()
1010 .map(|tc| {
1011 let tool = self.tools.iter().find(|t| t.name() == tc.name);
1012 async move {
1013 match tool {
1014 Some(tool) => {
1015 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1016 Ok(s) => s,
1017 Err(validation_err) => {
1018 return (
1019 tc.name.clone(),
1020 String::new(),
1021 ToolResultMessage {
1022 tool_use_id: tc.id.clone(),
1023 content: validation_err,
1024 is_error: true,
1025 },
1026 );
1027 }
1028 };
1029 match tool.execute(&input_str, ctx).await {
1030 Ok(result) => (
1031 tc.name.clone(),
1032 input_str,
1033 ToolResultMessage {
1034 tool_use_id: tc.id.clone(),
1035 content: result.output,
1036 is_error: false,
1037 },
1038 ),
1039 Err(e) => (
1040 tc.name.clone(),
1041 input_str,
1042 ToolResultMessage {
1043 tool_use_id: tc.id.clone(),
1044 content: format!("Error: {e}"),
1045 is_error: true,
1046 },
1047 ),
1048 }
1049 }
1050 None => (
1051 tc.name.clone(),
1052 String::new(),
1053 ToolResultMessage {
1054 tool_use_id: tc.id.clone(),
1055 content: format!("Tool '{}' not found", tc.name),
1056 is_error: true,
1057 },
1058 ),
1059 }
1060 }
1061 })
1062 .collect();
1063 let results = futures_util::future::join_all(futs).await;
1064 for (name, input_str, result_msg) in results {
1065 if result_msg.is_error {
1066 failure_streak += 1;
1067 } else {
1068 failure_streak = 0;
1069 tool_history.push((name, input_str, result_msg.content.clone()));
1070 }
1071 tool_results.push(result_msg);
1072 }
1073 } else {
1074 for tc in tool_calls {
1076 let result_msg = match self.tools.iter().find(|t| t.name() == tc.name) {
1077 Some(tool) => {
1078 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1079 Ok(s) => s,
1080 Err(validation_err) => {
1081 failure_streak += 1;
1082 tool_results.push(ToolResultMessage {
1083 tool_use_id: tc.id.clone(),
1084 content: validation_err,
1085 is_error: true,
1086 });
1087 continue;
1088 }
1089 };
1090 self.audit(
1091 "tool_requested",
1092 json!({
1093 "request_id": request_id,
1094 "iteration": iteration,
1095 "tool_name": tc.name,
1096 "tool_input_len": input_str.len(),
1097 }),
1098 )
1099 .await;
1100
1101 match self
1102 .execute_tool(
1103 &**tool, &tc.name, &input_str, ctx, request_id, iteration,
1104 )
1105 .await
1106 {
1107 Ok(result) => {
1108 failure_streak = 0;
1109 tool_history.push((
1110 tc.name.clone(),
1111 input_str,
1112 result.output.clone(),
1113 ));
1114 ToolResultMessage {
1115 tool_use_id: tc.id.clone(),
1116 content: result.output,
1117 is_error: false,
1118 }
1119 }
1120 Err(e) => {
1121 failure_streak += 1;
1122 ToolResultMessage {
1123 tool_use_id: tc.id.clone(),
1124 content: format!("Error: {e}"),
1125 is_error: true,
1126 }
1127 }
1128 }
1129 }
1130 None => {
1131 self.audit(
1132 "tool_not_found",
1133 json!({
1134 "request_id": request_id,
1135 "iteration": iteration,
1136 "tool_name": tc.name,
1137 }),
1138 )
1139 .await;
1140 failure_streak += 1;
1141 ToolResultMessage {
1142 tool_use_id: tc.id.clone(),
1143 content: format!(
1144 "Tool '{}' not found. Available tools: {}",
1145 tc.name,
1146 tool_definitions
1147 .iter()
1148 .map(|d| d.name.as_str())
1149 .collect::<Vec<_>>()
1150 .join(", ")
1151 ),
1152 is_error: true,
1153 }
1154 }
1155 };
1156 tool_results.push(result_msg);
1157 }
1158 }
1159
1160 for result in &tool_results {
1162 messages.push(ConversationMessage::ToolResult(result.clone()));
1163 }
1164
1165 if let Some(ref mut detector) = loop_detector {
1167 let tool_calls_this_iter = &chat_result.tool_calls;
1169 let mut worst_action = LoopAction::Continue;
1170 for tc in tool_calls_this_iter {
1171 let action = detector.check(
1172 &tc.name,
1173 &tc.input,
1174 ctx.current_tokens(),
1175 ctx.current_cost(),
1176 );
1177 if crate::loop_detection::severity(&action)
1178 > crate::loop_detection::severity(&worst_action)
1179 {
1180 worst_action = action;
1181 }
1182 }
1183
1184 match worst_action {
1185 LoopAction::Continue => {}
1186 LoopAction::InjectMessage(ref msg) => {
1187 warn!(
1188 request_id = %request_id,
1189 "tiered loop detection: injecting message"
1190 );
1191 self.audit(
1192 "loop_detection_inject",
1193 json!({
1194 "request_id": request_id,
1195 "iteration": iteration,
1196 "message": msg,
1197 }),
1198 )
1199 .await;
1200 messages.push(ConversationMessage::user(format!("SYSTEM NOTICE: {msg}")));
1201 }
1202 LoopAction::RestrictTools(ref tools) => {
1203 warn!(
1204 request_id = %request_id,
1205 tools = ?tools,
1206 "tiered loop detection: restricting tools"
1207 );
1208 self.audit(
1209 "loop_detection_restrict",
1210 json!({
1211 "request_id": request_id,
1212 "iteration": iteration,
1213 "restricted_tools": tools,
1214 }),
1215 )
1216 .await;
1217 restricted_tools.extend(tools.iter().cloned());
1218 messages.push(ConversationMessage::user(format!(
1219 "SYSTEM NOTICE: The following tools have been temporarily \
1220 restricted due to repetitive usage: {}. Try a different approach.",
1221 tools.join(", ")
1222 )));
1223 }
1224 LoopAction::ForceComplete(ref reason) => {
1225 warn!(
1226 request_id = %request_id,
1227 reason = %reason,
1228 "tiered loop detection: force completing"
1229 );
1230 self.audit(
1231 "loop_detection_force_complete",
1232 json!({
1233 "request_id": request_id,
1234 "iteration": iteration,
1235 "reason": reason,
1236 }),
1237 )
1238 .await;
1239 messages.push(ConversationMessage::user(format!(
1241 "SYSTEM NOTICE: {reason}. Provide your best answer now \
1242 without calling any more tools."
1243 )));
1244 truncate_messages(&mut messages, self.config.max_prompt_chars);
1245 let final_result = if let Some(ref sink) = stream_sink {
1246 self.provider
1247 .complete_streaming_with_tools(
1248 &messages,
1249 &[],
1250 &self.config.reasoning,
1251 sink.clone(),
1252 )
1253 .await
1254 .map_err(|source| AgentError::Provider { source })?
1255 } else {
1256 self.provider
1257 .complete_with_tools(&messages, &[], &self.config.reasoning)
1258 .await
1259 .map_err(|source| AgentError::Provider { source })?
1260 };
1261 let response_text = final_result.output_text;
1262 self.write_to_memory(
1263 "assistant",
1264 &response_text,
1265 request_id,
1266 ctx.source_channel.as_deref(),
1267 ctx.conversation_id.as_deref().unwrap_or(""),
1268 )
1269 .await?;
1270 return Ok(AssistantMessage {
1271 text: response_text,
1272 });
1273 }
1274 }
1275 }
1276
1277 if self.config.loop_detection_no_progress_threshold > 0 {
1279 let threshold = self.config.loop_detection_no_progress_threshold;
1280 if tool_history.len() >= threshold {
1281 let recent = &tool_history[tool_history.len() - threshold..];
1282 let all_same = recent.iter().all(|entry| {
1283 entry.0 == recent[0].0 && entry.1 == recent[0].1 && entry.2 == recent[0].2
1284 });
1285 if all_same {
1286 warn!(
1287 tool_name = recent[0].0,
1288 threshold, "structured loop detection: no-progress threshold reached"
1289 );
1290 self.audit(
1291 "loop_detection_no_progress",
1292 json!({
1293 "request_id": request_id,
1294 "tool_name": recent[0].0,
1295 "threshold": threshold,
1296 }),
1297 )
1298 .await;
1299 messages.push(ConversationMessage::user(format!(
1300 "SYSTEM NOTICE: Loop detected — the tool `{}` was called \
1301 {} times with identical arguments and output. You are stuck \
1302 in a loop. Stop calling this tool and try a different approach, \
1303 or provide your best answer with the information you already have.",
1304 recent[0].0, threshold
1305 )));
1306 }
1307 }
1308 }
1309
1310 if self.config.loop_detection_ping_pong_cycles > 0 {
1312 let cycles = self.config.loop_detection_ping_pong_cycles;
1313 let needed = cycles * 2;
1314 if tool_history.len() >= needed {
1315 let recent = &tool_history[tool_history.len() - needed..];
1316 let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
1317 && recent[0].0 != recent[1].0;
1318 if is_ping_pong {
1319 warn!(
1320 tool_a = recent[0].0,
1321 tool_b = recent[1].0,
1322 cycles,
1323 "structured loop detection: ping-pong pattern detected"
1324 );
1325 self.audit(
1326 "loop_detection_ping_pong",
1327 json!({
1328 "request_id": request_id,
1329 "tool_a": recent[0].0,
1330 "tool_b": recent[1].0,
1331 "cycles": cycles,
1332 }),
1333 )
1334 .await;
1335 messages.push(ConversationMessage::user(format!(
1336 "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
1337 alternating for {} cycles in a ping-pong pattern. Stop this \
1338 alternation and try a different approach, or provide your best \
1339 answer with the information you already have.",
1340 recent[0].0, recent[1].0, cycles
1341 )));
1342 }
1343 }
1344 }
1345
1346 if self.config.loop_detection_failure_streak > 0
1348 && failure_streak >= self.config.loop_detection_failure_streak
1349 {
1350 warn!(
1351 failure_streak,
1352 "structured loop detection: consecutive failure streak"
1353 );
1354 self.audit(
1355 "loop_detection_failure_streak",
1356 json!({
1357 "request_id": request_id,
1358 "streak": failure_streak,
1359 }),
1360 )
1361 .await;
1362 messages.push(ConversationMessage::user(format!(
1363 "SYSTEM NOTICE: {} consecutive tool calls have failed. \
1364 Stop calling tools and provide your best answer with \
1365 the information you already have.",
1366 failure_streak
1367 )));
1368 }
1369 }
1370
1371 self.audit(
1373 "structured_tool_use_max_iterations",
1374 json!({
1375 "request_id": request_id,
1376 "max_tool_iterations": self.config.max_tool_iterations,
1377 }),
1378 )
1379 .await;
1380
1381 messages.push(ConversationMessage::user(
1382 "You have reached the maximum number of tool call iterations. \
1383 Please provide your best answer now without calling any more tools."
1384 .to_string(),
1385 ));
1386 truncate_messages(&mut messages, self.config.max_prompt_chars);
1387
1388 let final_result = if let Some(ref sink) = stream_sink {
1389 self.provider
1390 .complete_streaming_with_tools(&messages, &[], &self.config.reasoning, sink.clone())
1391 .await
1392 .map_err(|source| AgentError::Provider { source })?
1393 } else {
1394 self.provider
1395 .complete_with_tools(&messages, &[], &self.config.reasoning)
1396 .await
1397 .map_err(|source| AgentError::Provider { source })?
1398 };
1399
1400 let response_text = final_result.output_text;
1401 self.write_to_memory(
1402 "assistant",
1403 &response_text,
1404 request_id,
1405 ctx.source_channel.as_deref(),
1406 ctx.conversation_id.as_deref().unwrap_or(""),
1407 )
1408 .await?;
1409 self.audit("respond_success", json!({"request_id": request_id}))
1410 .await;
1411 self.hook(
1412 "before_response_emit",
1413 json!({"request_id": request_id, "response_len": response_text.len()}),
1414 )
1415 .await?;
1416 let response = AssistantMessage {
1417 text: response_text,
1418 };
1419 self.hook(
1420 "after_response_emit",
1421 json!({"request_id": request_id, "response_len": response.text.len()}),
1422 )
1423 .await?;
1424 Ok(response)
1425 }
1426
1427 pub async fn respond(
1428 &self,
1429 user: UserMessage,
1430 ctx: &ToolContext,
1431 ) -> Result<AssistantMessage, AgentError> {
1432 let request_id = Self::next_request_id();
1433 let span = info_span!(
1434 "agent_run",
1435 request_id = %request_id,
1436 depth = ctx.depth,
1437 conversation_id = ctx.conversation_id.as_deref().unwrap_or(""),
1438 );
1439 self.respond_traced(&request_id, user, ctx, None)
1440 .instrument(span)
1441 .await
1442 }
1443
1444 pub async fn respond_streaming(
1448 &self,
1449 user: UserMessage,
1450 ctx: &ToolContext,
1451 sink: StreamSink,
1452 ) -> Result<AssistantMessage, AgentError> {
1453 let request_id = Self::next_request_id();
1454 let span = info_span!(
1455 "agent_run",
1456 request_id = %request_id,
1457 depth = ctx.depth,
1458 conversation_id = ctx.conversation_id.as_deref().unwrap_or(""),
1459 streaming = true,
1460 );
1461 self.respond_traced(&request_id, user, ctx, Some(sink))
1462 .instrument(span)
1463 .await
1464 }
1465
1466 async fn respond_traced(
1468 &self,
1469 request_id: &str,
1470 user: UserMessage,
1471 ctx: &ToolContext,
1472 stream_sink: Option<StreamSink>,
1473 ) -> Result<AssistantMessage, AgentError> {
1474 self.increment_counter("requests_total");
1475 let run_started = Instant::now();
1476 self.hook("before_run", json!({"request_id": request_id}))
1477 .await?;
1478 let timed = timeout(
1479 Duration::from_millis(self.config.request_timeout_ms),
1480 self.respond_inner(request_id, user, ctx, stream_sink),
1481 )
1482 .await;
1483 let result = match timed {
1484 Ok(result) => result,
1485 Err(_) => Err(AgentError::Timeout {
1486 timeout_ms: self.config.request_timeout_ms,
1487 }),
1488 };
1489
1490 let after_detail = match &result {
1491 Ok(response) => json!({
1492 "request_id": request_id,
1493 "status": "ok",
1494 "response_len": response.text.len(),
1495 "duration_ms": run_started.elapsed().as_millis(),
1496 }),
1497 Err(err) => json!({
1498 "request_id": request_id,
1499 "status": "error",
1500 "error": redact_text(&err.to_string()),
1501 "duration_ms": run_started.elapsed().as_millis(),
1502 }),
1503 };
1504 info!(
1505 request_id = %request_id,
1506 duration_ms = %run_started.elapsed().as_millis(),
1507 "agent run completed"
1508 );
1509 self.hook("after_run", after_detail).await?;
1510 result
1511 }
1512
    /// Core request handler (no tracing/timeout wrapper): audits the
    /// request, persists the user message, optionally runs a research
    /// phase, then either delegates to the structured tool-use path or
    /// drives the legacy `tool:`-prefix loop before producing the final
    /// provider answer.
    async fn respond_inner(
        &self,
        request_id: &str,
        user: UserMessage,
        ctx: &ToolContext,
        stream_sink: Option<StreamSink>,
    ) -> Result<AssistantMessage, AgentError> {
        self.audit(
            "respond_start",
            json!({
                "request_id": request_id,
                "user_message_len": user.text.len(),
                "max_tool_iterations": self.config.max_tool_iterations,
                "request_timeout_ms": self.config.request_timeout_ms,
            }),
        )
        .await;
        // Persist the incoming user message before any other processing.
        self.write_to_memory(
            "user",
            &user.text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
        )
        .await?;

        // Optional research phase; an empty string means no research ran.
        let research_context = if self.should_research(&user.text) {
            self.run_research_phase(&user.text, ctx, request_id).await?
        } else {
            String::new()
        };

        // Prefer provider-native structured tool use when the model and
        // registered tools support it; that path handles everything else.
        if self.config.model_supports_tool_use && self.has_tool_definitions() {
            return self
                .respond_with_tools(request_id, &user.text, &research_context, ctx, stream_sink)
                .await;
        }

        // Legacy path: prompts beginning with "tool:" drive an explicit
        // tool loop, bounded by max_tool_iterations.
        let mut prompt = user.text;
        // (tool_name, tool_input, tool_output) triples used for loop detection.
        let mut tool_history: Vec<(String, String, String)> = Vec::new();

        for iteration in 0..self.config.max_tool_iterations {
            if !prompt.starts_with("tool:") {
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Gated tools must never run via the parallel fan-out path.
            let has_gated = calls
                .iter()
                .any(|(name, _)| self.config.gated_tools.contains(*name));
            if calls.len() > 1 && self.config.parallel_tools && !has_gated {
                // Resolve each requested tool; unknown names are audited
                // and skipped rather than failing the whole batch.
                let mut resolved: Vec<(&str, &str, &dyn Tool)> = Vec::new();
                for &(name, input) in &calls {
                    let tool = self.tools.iter().find(|t| t.name() == name);
                    match tool {
                        Some(t) => resolved.push((name, input, &**t)),
                        None => {
                            self.audit(
                                "tool_not_found",
                                json!({"request_id": request_id, "iteration": iteration, "tool_name": name}),
                            )
                            .await;
                        }
                    }
                }
                if resolved.is_empty() {
                    break;
                }

                // Execute all resolved tools concurrently.
                let futs: Vec<_> = resolved
                    .iter()
                    .map(|&(name, input, tool)| async move {
                        let r = tool.execute(input, ctx).await;
                        (name, input, r)
                    })
                    .collect();
                let results = futures_util::future::join_all(futs).await;

                // Any single failure aborts the run with a typed tool error.
                let mut output_parts: Vec<String> = Vec::new();
                for (name, input, result) in results {
                    match result {
                        Ok(r) => {
                            tool_history.push((
                                name.to_string(),
                                input.to_string(),
                                r.output.clone(),
                            ));
                            output_parts.push(format!("Tool output from {name}: {}", r.output));
                        }
                        Err(source) => {
                            return Err(AgentError::Tool {
                                tool: name.to_string(),
                                source,
                            });
                        }
                    }
                }
                prompt = output_parts.join("\n");
                continue;
            }

            // Sequential path: only the first parsed call is executed.
            let (tool_name, tool_input) = calls[0];
            self.audit(
                "tool_requested",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "tool_name": tool_name,
                    "tool_input_len": tool_input.len(),
                }),
            )
            .await;

            if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
                let result = self
                    .execute_tool(&**tool, tool_name, tool_input, ctx, request_id, iteration)
                    .await?;

                tool_history.push((
                    tool_name.to_string(),
                    tool_input.to_string(),
                    result.output.clone(),
                ));

                // No-progress detection: the last `threshold` calls were
                // identical in name, input, and output.
                if self.config.loop_detection_no_progress_threshold > 0 {
                    let threshold = self.config.loop_detection_no_progress_threshold;
                    if tool_history.len() >= threshold {
                        let recent = &tool_history[tool_history.len() - threshold..];
                        let all_same = recent.iter().all(|entry| {
                            entry.0 == recent[0].0
                                && entry.1 == recent[0].1
                                && entry.2 == recent[0].2
                        });
                        if all_same {
                            warn!(
                                tool_name,
                                threshold, "loop detection: no-progress threshold reached"
                            );
                            self.audit(
                                "loop_detection_no_progress",
                                json!({
                                    "request_id": request_id,
                                    "tool_name": tool_name,
                                    "threshold": threshold,
                                }),
                            )
                            .await;
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — the tool `{tool_name}` was called \
                                {threshold} times with identical arguments and output. You are stuck \
                                in a loop. Stop calling this tool and try a different approach, or \
                                provide your best answer with the information you already have."
                            );
                            break;
                        }
                    }
                }

                // Ping-pong detection: two distinct tools strictly alternating
                // for `cycles` full cycles (2 * cycles calls).
                if self.config.loop_detection_ping_pong_cycles > 0 {
                    let cycles = self.config.loop_detection_ping_pong_cycles;
                    let needed = cycles * 2;
                    if tool_history.len() >= needed {
                        let recent = &tool_history[tool_history.len() - needed..];
                        let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
                            && recent[0].0 != recent[1].0;
                        if is_ping_pong {
                            warn!(
                                tool_a = recent[0].0,
                                tool_b = recent[1].0,
                                cycles,
                                "loop detection: ping-pong pattern detected"
                            );
                            self.audit(
                                "loop_detection_ping_pong",
                                json!({
                                    "request_id": request_id,
                                    "tool_a": recent[0].0,
                                    "tool_b": recent[1].0,
                                    "cycles": cycles,
                                }),
                            )
                            .await;
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
                                alternating for {} cycles in a ping-pong pattern. Stop this \
                                alternation and try a different approach, or provide your best \
                                answer with the information you already have.",
                                recent[0].0, recent[1].0, cycles
                            );
                            break;
                        }
                    }
                }

                // A tool may chain into another tool call by emitting a
                // "tool:" prefixed output; otherwise feed it back as context.
                if result.output.starts_with("tool:") {
                    prompt = result.output.clone();
                } else {
                    prompt = format!("Tool output from {tool_name}: {}", result.output);
                }
                continue;
            }

            // Unknown tool: audit it and fall through to the provider.
            self.audit(
                "tool_not_found",
                json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
            )
            .await;
            break;
        }

        // Prepend research findings (if any) to the final provider prompt.
        if !research_context.is_empty() {
            prompt = format!("Research findings:\n{research_context}\n\nUser request:\n{prompt}");
        }

        let response_text = self
            .call_provider_with_context(
                &prompt,
                request_id,
                stream_sink,
                ctx.source_channel.as_deref(),
            )
            .await?;
        self.write_to_memory(
            "assistant",
            &response_text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
        )
        .await?;
        self.audit("respond_success", json!({"request_id": request_id}))
            .await;

        self.hook(
            "before_response_emit",
            json!({"request_id": request_id, "response_len": response_text.len()}),
        )
        .await?;
        let response = AssistantMessage {
            text: response_text,
        };
        self.hook(
            "after_response_emit",
            json!({"request_id": request_id, "response_len": response.text.len()}),
        )
        .await?;

        Ok(response)
    }
1773}
1774
1775#[cfg(test)]
1776mod tests {
1777 use super::*;
1778 use crate::types::{
1779 ChatResult, HookEvent, HookFailureMode, HookPolicy, ReasoningConfig, ResearchPolicy,
1780 ResearchTrigger, ToolResult,
1781 };
1782 use async_trait::async_trait;
1783 use serde_json::Value;
1784 use std::collections::HashMap;
1785 use std::sync::atomic::{AtomicUsize, Ordering};
1786 use std::sync::{Arc, Mutex};
1787 use tokio::time::{sleep, Duration};
1788
1789 #[derive(Default)]
1790 struct TestMemory {
1791 entries: Arc<Mutex<Vec<MemoryEntry>>>,
1792 }
1793
1794 #[async_trait]
1795 impl MemoryStore for TestMemory {
1796 async fn append(&self, entry: MemoryEntry) -> anyhow::Result<()> {
1797 self.entries
1798 .lock()
1799 .expect("memory lock poisoned")
1800 .push(entry);
1801 Ok(())
1802 }
1803
1804 async fn recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
1805 let entries = self.entries.lock().expect("memory lock poisoned");
1806 Ok(entries.iter().rev().take(limit).cloned().collect())
1807 }
1808 }
1809
1810 struct TestProvider {
1811 received_prompts: Arc<Mutex<Vec<String>>>,
1812 response_text: String,
1813 }
1814
1815 #[async_trait]
1816 impl Provider for TestProvider {
1817 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
1818 self.received_prompts
1819 .lock()
1820 .expect("provider lock poisoned")
1821 .push(prompt.to_string());
1822 Ok(ChatResult {
1823 output_text: self.response_text.clone(),
1824 ..Default::default()
1825 })
1826 }
1827 }
1828
1829 struct EchoTool;
1830
1831 #[async_trait]
1832 impl Tool for EchoTool {
1833 fn name(&self) -> &'static str {
1834 "echo"
1835 }
1836
1837 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1838 Ok(ToolResult {
1839 output: format!("echoed:{input}"),
1840 })
1841 }
1842 }
1843
1844 struct PluginEchoTool;
1845
1846 #[async_trait]
1847 impl Tool for PluginEchoTool {
1848 fn name(&self) -> &'static str {
1849 "plugin:echo"
1850 }
1851
1852 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1853 Ok(ToolResult {
1854 output: format!("plugin-echoed:{input}"),
1855 })
1856 }
1857 }
1858
1859 struct FailingTool;
1860
1861 #[async_trait]
1862 impl Tool for FailingTool {
1863 fn name(&self) -> &'static str {
1864 "boom"
1865 }
1866
1867 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1868 Err(anyhow::anyhow!("tool exploded"))
1869 }
1870 }
1871
1872 struct FailingProvider;
1873
1874 #[async_trait]
1875 impl Provider for FailingProvider {
1876 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
1877 Err(anyhow::anyhow!("provider boom"))
1878 }
1879 }
1880
1881 struct SlowProvider;
1882
1883 #[async_trait]
1884 impl Provider for SlowProvider {
1885 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
1886 sleep(Duration::from_millis(500)).await;
1887 Ok(ChatResult {
1888 output_text: "late".to_string(),
1889 ..Default::default()
1890 })
1891 }
1892 }
1893
1894 struct ScriptedProvider {
1895 responses: Vec<String>,
1896 call_count: AtomicUsize,
1897 received_prompts: Arc<Mutex<Vec<String>>>,
1898 }
1899
1900 impl ScriptedProvider {
1901 fn new(responses: Vec<&str>) -> Self {
1902 Self {
1903 responses: responses.into_iter().map(|s| s.to_string()).collect(),
1904 call_count: AtomicUsize::new(0),
1905 received_prompts: Arc::new(Mutex::new(Vec::new())),
1906 }
1907 }
1908 }
1909
1910 #[async_trait]
1911 impl Provider for ScriptedProvider {
1912 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
1913 self.received_prompts
1914 .lock()
1915 .expect("provider lock poisoned")
1916 .push(prompt.to_string());
1917 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
1918 let response = if idx < self.responses.len() {
1919 self.responses[idx].clone()
1920 } else {
1921 self.responses.last().cloned().unwrap_or_default()
1922 };
1923 Ok(ChatResult {
1924 output_text: response,
1925 ..Default::default()
1926 })
1927 }
1928 }
1929
1930 struct UpperTool;
1931
1932 #[async_trait]
1933 impl Tool for UpperTool {
1934 fn name(&self) -> &'static str {
1935 "upper"
1936 }
1937
1938 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1939 Ok(ToolResult {
1940 output: input.to_uppercase(),
1941 })
1942 }
1943 }
1944
1945 struct LoopTool;
1948
1949 #[async_trait]
1950 impl Tool for LoopTool {
1951 fn name(&self) -> &'static str {
1952 "loop_tool"
1953 }
1954
1955 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1956 Ok(ToolResult {
1957 output: "tool:loop_tool x".to_string(),
1958 })
1959 }
1960 }
1961
1962 struct PingTool;
1964
1965 #[async_trait]
1966 impl Tool for PingTool {
1967 fn name(&self) -> &'static str {
1968 "ping"
1969 }
1970
1971 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1972 Ok(ToolResult {
1973 output: "tool:pong x".to_string(),
1974 })
1975 }
1976 }
1977
1978 struct PongTool;
1980
1981 #[async_trait]
1982 impl Tool for PongTool {
1983 fn name(&self) -> &'static str {
1984 "pong"
1985 }
1986
1987 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1988 Ok(ToolResult {
1989 output: "tool:ping x".to_string(),
1990 })
1991 }
1992 }
1993
1994 struct RecordingHookSink {
1995 events: Arc<Mutex<Vec<String>>>,
1996 }
1997
1998 #[async_trait]
1999 impl HookSink for RecordingHookSink {
2000 async fn record(&self, event: HookEvent) -> anyhow::Result<()> {
2001 self.events
2002 .lock()
2003 .expect("hook lock poisoned")
2004 .push(event.stage);
2005 Ok(())
2006 }
2007 }
2008
2009 struct FailingHookSink;
2010
2011 #[async_trait]
2012 impl HookSink for FailingHookSink {
2013 async fn record(&self, _event: HookEvent) -> anyhow::Result<()> {
2014 Err(anyhow::anyhow!("hook sink failure"))
2015 }
2016 }
2017
2018 struct RecordingAuditSink {
2019 events: Arc<Mutex<Vec<AuditEvent>>>,
2020 }
2021
2022 #[async_trait]
2023 impl AuditSink for RecordingAuditSink {
2024 async fn record(&self, event: AuditEvent) -> anyhow::Result<()> {
2025 self.events.lock().expect("audit lock poisoned").push(event);
2026 Ok(())
2027 }
2028 }
2029
2030 struct RecordingMetricsSink {
2031 counters: Arc<Mutex<HashMap<&'static str, u64>>>,
2032 histograms: Arc<Mutex<HashMap<&'static str, usize>>>,
2033 }
2034
2035 impl MetricsSink for RecordingMetricsSink {
2036 fn increment_counter(&self, name: &'static str, value: u64) {
2037 let mut counters = self.counters.lock().expect("metrics lock poisoned");
2038 *counters.entry(name).or_insert(0) += value;
2039 }
2040
2041 fn observe_histogram(&self, name: &'static str, _value: f64) {
2042 let mut histograms = self.histograms.lock().expect("metrics lock poisoned");
2043 *histograms.entry(name).or_insert(0) += 1;
2044 }
2045 }
2046
2047 fn test_ctx() -> ToolContext {
2048 ToolContext::new(".".to_string())
2049 }
2050
2051 fn counter(counters: &Arc<Mutex<HashMap<&'static str, u64>>>, name: &'static str) -> u64 {
2052 counters
2053 .lock()
2054 .expect("metrics lock poisoned")
2055 .get(name)
2056 .copied()
2057 .unwrap_or(0)
2058 }
2059
2060 fn histogram_count(
2061 histograms: &Arc<Mutex<HashMap<&'static str, usize>>>,
2062 name: &'static str,
2063 ) -> usize {
2064 histograms
2065 .lock()
2066 .expect("metrics lock poisoned")
2067 .get(name)
2068 .copied()
2069 .unwrap_or(0)
2070 }
2071
2072 #[tokio::test]
2073 async fn respond_appends_user_then_assistant_memory_entries() {
2074 let entries = Arc::new(Mutex::new(Vec::new()));
2075 let prompts = Arc::new(Mutex::new(Vec::new()));
2076 let memory = TestMemory {
2077 entries: entries.clone(),
2078 };
2079 let provider = TestProvider {
2080 received_prompts: prompts,
2081 response_text: "assistant-output".to_string(),
2082 };
2083 let agent = Agent::new(
2084 AgentConfig::default(),
2085 Box::new(provider),
2086 Box::new(memory),
2087 vec![],
2088 );
2089
2090 let response = agent
2091 .respond(
2092 UserMessage {
2093 text: "hello world".to_string(),
2094 },
2095 &test_ctx(),
2096 )
2097 .await
2098 .expect("agent respond should succeed");
2099
2100 assert_eq!(response.text, "assistant-output");
2101 let stored = entries.lock().expect("memory lock poisoned");
2102 assert_eq!(stored.len(), 2);
2103 assert_eq!(stored[0].role, "user");
2104 assert_eq!(stored[0].content, "hello world");
2105 assert_eq!(stored[1].role, "assistant");
2106 assert_eq!(stored[1].content, "assistant-output");
2107 }
2108
2109 #[tokio::test]
2110 async fn respond_invokes_tool_for_tool_prefixed_prompt() {
2111 let entries = Arc::new(Mutex::new(Vec::new()));
2112 let prompts = Arc::new(Mutex::new(Vec::new()));
2113 let memory = TestMemory {
2114 entries: entries.clone(),
2115 };
2116 let provider = TestProvider {
2117 received_prompts: prompts.clone(),
2118 response_text: "assistant-after-tool".to_string(),
2119 };
2120 let agent = Agent::new(
2121 AgentConfig::default(),
2122 Box::new(provider),
2123 Box::new(memory),
2124 vec![Box::new(EchoTool)],
2125 );
2126
2127 let response = agent
2128 .respond(
2129 UserMessage {
2130 text: "tool:echo ping".to_string(),
2131 },
2132 &test_ctx(),
2133 )
2134 .await
2135 .expect("agent respond should succeed");
2136
2137 assert_eq!(response.text, "assistant-after-tool");
2138 let prompts = prompts.lock().expect("provider lock poisoned");
2139 assert_eq!(prompts.len(), 1);
2140 assert!(prompts[0].contains("Recent memory:"));
2141 assert!(prompts[0].contains("Current input:\nTool output from echo: echoed:ping"));
2142 }
2143
2144 #[tokio::test]
2145 async fn respond_with_unknown_tool_falls_back_to_provider_prompt() {
2146 let entries = Arc::new(Mutex::new(Vec::new()));
2147 let prompts = Arc::new(Mutex::new(Vec::new()));
2148 let memory = TestMemory {
2149 entries: entries.clone(),
2150 };
2151 let provider = TestProvider {
2152 received_prompts: prompts.clone(),
2153 response_text: "assistant-without-tool".to_string(),
2154 };
2155 let agent = Agent::new(
2156 AgentConfig::default(),
2157 Box::new(provider),
2158 Box::new(memory),
2159 vec![Box::new(EchoTool)],
2160 );
2161
2162 let response = agent
2163 .respond(
2164 UserMessage {
2165 text: "tool:unknown payload".to_string(),
2166 },
2167 &test_ctx(),
2168 )
2169 .await
2170 .expect("agent respond should succeed");
2171
2172 assert_eq!(response.text, "assistant-without-tool");
2173 let prompts = prompts.lock().expect("provider lock poisoned");
2174 assert_eq!(prompts.len(), 1);
2175 assert!(prompts[0].contains("Recent memory:"));
2176 assert!(prompts[0].contains("Current input:\ntool:unknown payload"));
2177 }
2178
2179 #[tokio::test]
2180 async fn respond_includes_bounded_recent_memory_in_provider_prompt() {
2181 let entries = Arc::new(Mutex::new(vec![
2182 MemoryEntry {
2183 role: "assistant".to_string(),
2184 content: "very-old".to_string(),
2185 ..Default::default()
2186 },
2187 MemoryEntry {
2188 role: "user".to_string(),
2189 content: "recent-before-request".to_string(),
2190 ..Default::default()
2191 },
2192 ]));
2193 let prompts = Arc::new(Mutex::new(Vec::new()));
2194 let memory = TestMemory {
2195 entries: entries.clone(),
2196 };
2197 let provider = TestProvider {
2198 received_prompts: prompts.clone(),
2199 response_text: "ok".to_string(),
2200 };
2201 let agent = Agent::new(
2202 AgentConfig {
2203 memory_window_size: 2,
2204 ..AgentConfig::default()
2205 },
2206 Box::new(provider),
2207 Box::new(memory),
2208 vec![],
2209 );
2210
2211 agent
2212 .respond(
2213 UserMessage {
2214 text: "latest-user".to_string(),
2215 },
2216 &test_ctx(),
2217 )
2218 .await
2219 .expect("respond should succeed");
2220
2221 let prompts = prompts.lock().expect("provider lock poisoned");
2222 let provider_prompt = prompts.first().expect("provider prompt should exist");
2223 assert!(provider_prompt.contains("- user: recent-before-request"));
2224 assert!(provider_prompt.contains("- user: latest-user"));
2225 assert!(!provider_prompt.contains("very-old"));
2226 }
2227
2228 #[tokio::test]
2229 async fn respond_caps_provider_prompt_to_configured_max_chars() {
2230 let entries = Arc::new(Mutex::new(vec![MemoryEntry {
2231 role: "assistant".to_string(),
2232 content: "historic context ".repeat(16),
2233 ..Default::default()
2234 }]));
2235 let prompts = Arc::new(Mutex::new(Vec::new()));
2236 let memory = TestMemory {
2237 entries: entries.clone(),
2238 };
2239 let provider = TestProvider {
2240 received_prompts: prompts.clone(),
2241 response_text: "ok".to_string(),
2242 };
2243 let agent = Agent::new(
2244 AgentConfig {
2245 max_prompt_chars: 64,
2246 ..AgentConfig::default()
2247 },
2248 Box::new(provider),
2249 Box::new(memory),
2250 vec![],
2251 );
2252
2253 agent
2254 .respond(
2255 UserMessage {
2256 text: "final-tail-marker".to_string(),
2257 },
2258 &test_ctx(),
2259 )
2260 .await
2261 .expect("respond should succeed");
2262
2263 let prompts = prompts.lock().expect("provider lock poisoned");
2264 let provider_prompt = prompts.first().expect("provider prompt should exist");
2265 assert!(provider_prompt.chars().count() <= 64);
2266 assert!(provider_prompt.contains("final-tail-marker"));
2267 }
2268
2269 #[tokio::test]
2270 async fn respond_returns_provider_typed_error() {
2271 let memory = TestMemory::default();
2272 let counters = Arc::new(Mutex::new(HashMap::new()));
2273 let histograms = Arc::new(Mutex::new(HashMap::new()));
2274 let agent = Agent::new(
2275 AgentConfig::default(),
2276 Box::new(FailingProvider),
2277 Box::new(memory),
2278 vec![],
2279 )
2280 .with_metrics(Box::new(RecordingMetricsSink {
2281 counters: counters.clone(),
2282 histograms: histograms.clone(),
2283 }));
2284
2285 let result = agent
2286 .respond(
2287 UserMessage {
2288 text: "hello".to_string(),
2289 },
2290 &test_ctx(),
2291 )
2292 .await;
2293
2294 match result {
2295 Err(AgentError::Provider { source }) => {
2296 assert!(source.to_string().contains("provider boom"));
2297 }
2298 other => panic!("expected provider error, got {other:?}"),
2299 }
2300 assert_eq!(counter(&counters, "requests_total"), 1);
2301 assert_eq!(counter(&counters, "provider_errors_total"), 1);
2302 assert_eq!(counter(&counters, "tool_errors_total"), 0);
2303 assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
2304 }
2305
2306 #[tokio::test]
2307 async fn respond_increments_tool_error_counter_on_tool_failure() {
2308 let memory = TestMemory::default();
2309 let prompts = Arc::new(Mutex::new(Vec::new()));
2310 let provider = TestProvider {
2311 received_prompts: prompts,
2312 response_text: "unused".to_string(),
2313 };
2314 let counters = Arc::new(Mutex::new(HashMap::new()));
2315 let histograms = Arc::new(Mutex::new(HashMap::new()));
2316 let agent = Agent::new(
2317 AgentConfig::default(),
2318 Box::new(provider),
2319 Box::new(memory),
2320 vec![Box::new(FailingTool)],
2321 )
2322 .with_metrics(Box::new(RecordingMetricsSink {
2323 counters: counters.clone(),
2324 histograms: histograms.clone(),
2325 }));
2326
2327 let result = agent
2328 .respond(
2329 UserMessage {
2330 text: "tool:boom ping".to_string(),
2331 },
2332 &test_ctx(),
2333 )
2334 .await;
2335
2336 match result {
2337 Err(AgentError::Tool { tool, source }) => {
2338 assert_eq!(tool, "boom");
2339 assert!(source.to_string().contains("tool exploded"));
2340 }
2341 other => panic!("expected tool error, got {other:?}"),
2342 }
2343 assert_eq!(counter(&counters, "requests_total"), 1);
2344 assert_eq!(counter(&counters, "provider_errors_total"), 0);
2345 assert_eq!(counter(&counters, "tool_errors_total"), 1);
2346 assert_eq!(histogram_count(&histograms, "tool_latency_ms"), 1);
2347 }
2348
2349 #[tokio::test]
2350 async fn respond_increments_requests_counter_on_success() {
2351 let memory = TestMemory::default();
2352 let prompts = Arc::new(Mutex::new(Vec::new()));
2353 let provider = TestProvider {
2354 received_prompts: prompts,
2355 response_text: "ok".to_string(),
2356 };
2357 let counters = Arc::new(Mutex::new(HashMap::new()));
2358 let histograms = Arc::new(Mutex::new(HashMap::new()));
2359 let agent = Agent::new(
2360 AgentConfig::default(),
2361 Box::new(provider),
2362 Box::new(memory),
2363 vec![],
2364 )
2365 .with_metrics(Box::new(RecordingMetricsSink {
2366 counters: counters.clone(),
2367 histograms: histograms.clone(),
2368 }));
2369
2370 let response = agent
2371 .respond(
2372 UserMessage {
2373 text: "hello".to_string(),
2374 },
2375 &test_ctx(),
2376 )
2377 .await
2378 .expect("response should succeed");
2379
2380 assert_eq!(response.text, "ok");
2381 assert_eq!(counter(&counters, "requests_total"), 1);
2382 assert_eq!(counter(&counters, "provider_errors_total"), 0);
2383 assert_eq!(counter(&counters, "tool_errors_total"), 0);
2384 assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
2385 }
2386
2387 #[tokio::test]
2388 async fn respond_returns_timeout_typed_error() {
2389 let memory = TestMemory::default();
2390 let agent = Agent::new(
2391 AgentConfig {
2392 max_tool_iterations: 1,
2393 request_timeout_ms: 10,
2394 ..AgentConfig::default()
2395 },
2396 Box::new(SlowProvider),
2397 Box::new(memory),
2398 vec![],
2399 );
2400
2401 let result = agent
2402 .respond(
2403 UserMessage {
2404 text: "hello".to_string(),
2405 },
2406 &test_ctx(),
2407 )
2408 .await;
2409
2410 match result {
2411 Err(AgentError::Timeout { timeout_ms }) => assert_eq!(timeout_ms, 10),
2412 other => panic!("expected timeout error, got {other:?}"),
2413 }
2414 }
2415
2416 #[tokio::test]
2417 async fn respond_emits_before_after_hook_events_when_enabled() {
2418 let events = Arc::new(Mutex::new(Vec::new()));
2419 let memory = TestMemory::default();
2420 let provider = TestProvider {
2421 received_prompts: Arc::new(Mutex::new(Vec::new())),
2422 response_text: "ok".to_string(),
2423 };
2424 let agent = Agent::new(
2425 AgentConfig {
2426 hooks: HookPolicy {
2427 enabled: true,
2428 timeout_ms: 50,
2429 fail_closed: true,
2430 ..HookPolicy::default()
2431 },
2432 ..AgentConfig::default()
2433 },
2434 Box::new(provider),
2435 Box::new(memory),
2436 vec![Box::new(EchoTool), Box::new(PluginEchoTool)],
2437 )
2438 .with_hooks(Box::new(RecordingHookSink {
2439 events: events.clone(),
2440 }));
2441
2442 agent
2443 .respond(
2444 UserMessage {
2445 text: "tool:plugin:echo ping".to_string(),
2446 },
2447 &test_ctx(),
2448 )
2449 .await
2450 .expect("respond should succeed");
2451
2452 let stages = events.lock().expect("hook lock poisoned");
2453 assert!(stages.contains(&"before_run".to_string()));
2454 assert!(stages.contains(&"after_run".to_string()));
2455 assert!(stages.contains(&"before_tool_call".to_string()));
2456 assert!(stages.contains(&"after_tool_call".to_string()));
2457 assert!(stages.contains(&"before_plugin_call".to_string()));
2458 assert!(stages.contains(&"after_plugin_call".to_string()));
2459 assert!(stages.contains(&"before_provider_call".to_string()));
2460 assert!(stages.contains(&"after_provider_call".to_string()));
2461 assert!(stages.contains(&"before_memory_write".to_string()));
2462 assert!(stages.contains(&"after_memory_write".to_string()));
2463 assert!(stages.contains(&"before_response_emit".to_string()));
2464 assert!(stages.contains(&"after_response_emit".to_string()));
2465 }
2466
2467 #[tokio::test]
2468 async fn respond_audit_events_include_request_id_and_durations() {
2469 let audit_events = Arc::new(Mutex::new(Vec::new()));
2470 let memory = TestMemory::default();
2471 let provider = TestProvider {
2472 received_prompts: Arc::new(Mutex::new(Vec::new())),
2473 response_text: "ok".to_string(),
2474 };
2475 let agent = Agent::new(
2476 AgentConfig::default(),
2477 Box::new(provider),
2478 Box::new(memory),
2479 vec![Box::new(EchoTool)],
2480 )
2481 .with_audit(Box::new(RecordingAuditSink {
2482 events: audit_events.clone(),
2483 }));
2484
2485 agent
2486 .respond(
2487 UserMessage {
2488 text: "tool:echo ping".to_string(),
2489 },
2490 &test_ctx(),
2491 )
2492 .await
2493 .expect("respond should succeed");
2494
2495 let events = audit_events.lock().expect("audit lock poisoned");
2496 let request_id = events
2497 .iter()
2498 .find_map(|e| {
2499 e.detail
2500 .get("request_id")
2501 .and_then(Value::as_str)
2502 .map(ToString::to_string)
2503 })
2504 .expect("request_id should exist on audit events");
2505
2506 let provider_event = events
2507 .iter()
2508 .find(|e| e.stage == "provider_call_success")
2509 .expect("provider success event should exist");
2510 assert_eq!(
2511 provider_event
2512 .detail
2513 .get("request_id")
2514 .and_then(Value::as_str),
2515 Some(request_id.as_str())
2516 );
2517 assert!(provider_event
2518 .detail
2519 .get("duration_ms")
2520 .and_then(Value::as_u64)
2521 .is_some());
2522
2523 let tool_event = events
2524 .iter()
2525 .find(|e| e.stage == "tool_execute_success")
2526 .expect("tool success event should exist");
2527 assert_eq!(
2528 tool_event.detail.get("request_id").and_then(Value::as_str),
2529 Some(request_id.as_str())
2530 );
2531 assert!(tool_event
2532 .detail
2533 .get("duration_ms")
2534 .and_then(Value::as_u64)
2535 .is_some());
2536 }
2537
2538 #[tokio::test]
2539 async fn hook_errors_respect_block_mode_for_high_tier_negative_path() {
2540 let memory = TestMemory::default();
2541 let provider = TestProvider {
2542 received_prompts: Arc::new(Mutex::new(Vec::new())),
2543 response_text: "ok".to_string(),
2544 };
2545 let agent = Agent::new(
2546 AgentConfig {
2547 hooks: HookPolicy {
2548 enabled: true,
2549 timeout_ms: 50,
2550 fail_closed: false,
2551 default_mode: HookFailureMode::Warn,
2552 low_tier_mode: HookFailureMode::Ignore,
2553 medium_tier_mode: HookFailureMode::Warn,
2554 high_tier_mode: HookFailureMode::Block,
2555 },
2556 ..AgentConfig::default()
2557 },
2558 Box::new(provider),
2559 Box::new(memory),
2560 vec![Box::new(EchoTool)],
2561 )
2562 .with_hooks(Box::new(FailingHookSink));
2563
2564 let result = agent
2565 .respond(
2566 UserMessage {
2567 text: "tool:echo ping".to_string(),
2568 },
2569 &test_ctx(),
2570 )
2571 .await;
2572
2573 match result {
2574 Err(AgentError::Hook { stage, .. }) => assert_eq!(stage, "before_tool_call"),
2575 other => panic!("expected hook block error, got {other:?}"),
2576 }
2577 }
2578
2579 #[tokio::test]
2580 async fn hook_errors_respect_warn_mode_for_high_tier_success_path() {
2581 let audit_events = Arc::new(Mutex::new(Vec::new()));
2582 let memory = TestMemory::default();
2583 let provider = TestProvider {
2584 received_prompts: Arc::new(Mutex::new(Vec::new())),
2585 response_text: "ok".to_string(),
2586 };
2587 let agent = Agent::new(
2588 AgentConfig {
2589 hooks: HookPolicy {
2590 enabled: true,
2591 timeout_ms: 50,
2592 fail_closed: false,
2593 default_mode: HookFailureMode::Warn,
2594 low_tier_mode: HookFailureMode::Ignore,
2595 medium_tier_mode: HookFailureMode::Warn,
2596 high_tier_mode: HookFailureMode::Warn,
2597 },
2598 ..AgentConfig::default()
2599 },
2600 Box::new(provider),
2601 Box::new(memory),
2602 vec![Box::new(EchoTool)],
2603 )
2604 .with_hooks(Box::new(FailingHookSink))
2605 .with_audit(Box::new(RecordingAuditSink {
2606 events: audit_events.clone(),
2607 }));
2608
2609 let response = agent
2610 .respond(
2611 UserMessage {
2612 text: "tool:echo ping".to_string(),
2613 },
2614 &test_ctx(),
2615 )
2616 .await
2617 .expect("warn mode should continue");
2618 assert!(!response.text.is_empty());
2619
2620 let events = audit_events.lock().expect("audit lock poisoned");
2621 assert!(events.iter().any(|event| event.stage == "hook_error_warn"));
2622 }
2623
2624 #[tokio::test]
2627 async fn self_correction_injected_on_no_progress() {
2628 let prompts = Arc::new(Mutex::new(Vec::new()));
2629 let provider = TestProvider {
2630 received_prompts: prompts.clone(),
2631 response_text: "I'll try a different approach".to_string(),
2632 };
2633 let agent = Agent::new(
2636 AgentConfig {
2637 loop_detection_no_progress_threshold: 3,
2638 max_tool_iterations: 20,
2639 ..AgentConfig::default()
2640 },
2641 Box::new(provider),
2642 Box::new(TestMemory::default()),
2643 vec![Box::new(LoopTool)],
2644 );
2645
2646 let response = agent
2647 .respond(
2648 UserMessage {
2649 text: "tool:loop_tool x".to_string(),
2650 },
2651 &test_ctx(),
2652 )
2653 .await
2654 .expect("self-correction should succeed");
2655
2656 assert_eq!(response.text, "I'll try a different approach");
2657
2658 let received = prompts.lock().expect("provider lock poisoned");
2660 let last_prompt = received.last().expect("provider should have been called");
2661 assert!(
2662 last_prompt.contains("Loop detected"),
2663 "provider prompt should contain self-correction notice, got: {last_prompt}"
2664 );
2665 assert!(last_prompt.contains("loop_tool"));
2666 }
2667
2668 #[tokio::test]
2669 async fn self_correction_injected_on_ping_pong() {
2670 let prompts = Arc::new(Mutex::new(Vec::new()));
2671 let provider = TestProvider {
2672 received_prompts: prompts.clone(),
2673 response_text: "I stopped the loop".to_string(),
2674 };
2675 let agent = Agent::new(
2677 AgentConfig {
2678 loop_detection_ping_pong_cycles: 2,
2679 loop_detection_no_progress_threshold: 0, max_tool_iterations: 20,
2681 ..AgentConfig::default()
2682 },
2683 Box::new(provider),
2684 Box::new(TestMemory::default()),
2685 vec![Box::new(PingTool), Box::new(PongTool)],
2686 );
2687
2688 let response = agent
2689 .respond(
2690 UserMessage {
2691 text: "tool:ping x".to_string(),
2692 },
2693 &test_ctx(),
2694 )
2695 .await
2696 .expect("self-correction should succeed");
2697
2698 assert_eq!(response.text, "I stopped the loop");
2699
2700 let received = prompts.lock().expect("provider lock poisoned");
2701 let last_prompt = received.last().expect("provider should have been called");
2702 assert!(
2703 last_prompt.contains("ping-pong"),
2704 "provider prompt should contain ping-pong notice, got: {last_prompt}"
2705 );
2706 }
2707
2708 #[tokio::test]
2709 async fn no_progress_disabled_when_threshold_zero() {
2710 let prompts = Arc::new(Mutex::new(Vec::new()));
2711 let provider = TestProvider {
2712 received_prompts: prompts.clone(),
2713 response_text: "final answer".to_string(),
2714 };
2715 let agent = Agent::new(
2718 AgentConfig {
2719 loop_detection_no_progress_threshold: 0,
2720 loop_detection_ping_pong_cycles: 0,
2721 max_tool_iterations: 5,
2722 ..AgentConfig::default()
2723 },
2724 Box::new(provider),
2725 Box::new(TestMemory::default()),
2726 vec![Box::new(LoopTool)],
2727 );
2728
2729 let response = agent
2730 .respond(
2731 UserMessage {
2732 text: "tool:loop_tool x".to_string(),
2733 },
2734 &test_ctx(),
2735 )
2736 .await
2737 .expect("should complete without error");
2738
2739 assert_eq!(response.text, "final answer");
2740
2741 let received = prompts.lock().expect("provider lock poisoned");
2743 let last_prompt = received.last().expect("provider should have been called");
2744 assert!(
2745 !last_prompt.contains("Loop detected"),
2746 "should not contain self-correction when detection is disabled"
2747 );
2748 }
2749
2750 #[tokio::test]
2753 async fn parallel_tools_executes_multiple_calls() {
2754 let prompts = Arc::new(Mutex::new(Vec::new()));
2755 let provider = TestProvider {
2756 received_prompts: prompts.clone(),
2757 response_text: "parallel done".to_string(),
2758 };
2759 let agent = Agent::new(
2760 AgentConfig {
2761 parallel_tools: true,
2762 max_tool_iterations: 5,
2763 ..AgentConfig::default()
2764 },
2765 Box::new(provider),
2766 Box::new(TestMemory::default()),
2767 vec![Box::new(EchoTool), Box::new(UpperTool)],
2768 );
2769
2770 let response = agent
2771 .respond(
2772 UserMessage {
2773 text: "tool:echo hello\ntool:upper world".to_string(),
2774 },
2775 &test_ctx(),
2776 )
2777 .await
2778 .expect("parallel tools should succeed");
2779
2780 assert_eq!(response.text, "parallel done");
2781
2782 let received = prompts.lock().expect("provider lock poisoned");
2783 let last_prompt = received.last().expect("provider should have been called");
2784 assert!(
2785 last_prompt.contains("echoed:hello"),
2786 "should contain echo output, got: {last_prompt}"
2787 );
2788 assert!(
2789 last_prompt.contains("WORLD"),
2790 "should contain upper output, got: {last_prompt}"
2791 );
2792 }
2793
2794 #[tokio::test]
2795 async fn parallel_tools_maintains_stable_ordering() {
2796 let prompts = Arc::new(Mutex::new(Vec::new()));
2797 let provider = TestProvider {
2798 received_prompts: prompts.clone(),
2799 response_text: "ordered".to_string(),
2800 };
2801 let agent = Agent::new(
2802 AgentConfig {
2803 parallel_tools: true,
2804 max_tool_iterations: 5,
2805 ..AgentConfig::default()
2806 },
2807 Box::new(provider),
2808 Box::new(TestMemory::default()),
2809 vec![Box::new(EchoTool), Box::new(UpperTool)],
2810 );
2811
2812 let response = agent
2814 .respond(
2815 UserMessage {
2816 text: "tool:echo aaa\ntool:upper bbb".to_string(),
2817 },
2818 &test_ctx(),
2819 )
2820 .await
2821 .expect("parallel tools should succeed");
2822
2823 assert_eq!(response.text, "ordered");
2824
2825 let received = prompts.lock().expect("provider lock poisoned");
2826 let last_prompt = received.last().expect("provider should have been called");
2827 let echo_pos = last_prompt
2828 .find("echoed:aaa")
2829 .expect("echo output should exist");
2830 let upper_pos = last_prompt.find("BBB").expect("upper output should exist");
2831 assert!(
2832 echo_pos < upper_pos,
2833 "echo output should come before upper output in the prompt"
2834 );
2835 }
2836
2837 #[tokio::test]
2838 async fn parallel_disabled_runs_first_call_only() {
2839 let prompts = Arc::new(Mutex::new(Vec::new()));
2840 let provider = TestProvider {
2841 received_prompts: prompts.clone(),
2842 response_text: "sequential".to_string(),
2843 };
2844 let agent = Agent::new(
2845 AgentConfig {
2846 parallel_tools: false,
2847 max_tool_iterations: 5,
2848 ..AgentConfig::default()
2849 },
2850 Box::new(provider),
2851 Box::new(TestMemory::default()),
2852 vec![Box::new(EchoTool), Box::new(UpperTool)],
2853 );
2854
2855 let response = agent
2856 .respond(
2857 UserMessage {
2858 text: "tool:echo hello\ntool:upper world".to_string(),
2859 },
2860 &test_ctx(),
2861 )
2862 .await
2863 .expect("sequential tools should succeed");
2864
2865 assert_eq!(response.text, "sequential");
2866
2867 let received = prompts.lock().expect("provider lock poisoned");
2869 let last_prompt = received.last().expect("provider should have been called");
2870 assert!(
2871 last_prompt.contains("echoed:hello"),
2872 "should contain first tool output, got: {last_prompt}"
2873 );
2874 assert!(
2875 !last_prompt.contains("WORLD"),
2876 "should NOT contain second tool output when parallel is disabled"
2877 );
2878 }
2879
2880 #[tokio::test]
2881 async fn single_call_parallel_matches_sequential() {
2882 let prompts = Arc::new(Mutex::new(Vec::new()));
2883 let provider = TestProvider {
2884 received_prompts: prompts.clone(),
2885 response_text: "single".to_string(),
2886 };
2887 let agent = Agent::new(
2888 AgentConfig {
2889 parallel_tools: true,
2890 max_tool_iterations: 5,
2891 ..AgentConfig::default()
2892 },
2893 Box::new(provider),
2894 Box::new(TestMemory::default()),
2895 vec![Box::new(EchoTool)],
2896 );
2897
2898 let response = agent
2899 .respond(
2900 UserMessage {
2901 text: "tool:echo ping".to_string(),
2902 },
2903 &test_ctx(),
2904 )
2905 .await
2906 .expect("single tool with parallel enabled should succeed");
2907
2908 assert_eq!(response.text, "single");
2909
2910 let received = prompts.lock().expect("provider lock poisoned");
2911 let last_prompt = received.last().expect("provider should have been called");
2912 assert!(
2913 last_prompt.contains("echoed:ping"),
2914 "single tool call should work same as sequential, got: {last_prompt}"
2915 );
2916 }
2917
2918 #[tokio::test]
2919 async fn parallel_falls_back_to_sequential_for_gated_tools() {
2920 let prompts = Arc::new(Mutex::new(Vec::new()));
2921 let provider = TestProvider {
2922 received_prompts: prompts.clone(),
2923 response_text: "gated sequential".to_string(),
2924 };
2925 let mut gated = std::collections::HashSet::new();
2926 gated.insert("upper".to_string());
2927 let agent = Agent::new(
2928 AgentConfig {
2929 parallel_tools: true,
2930 gated_tools: gated,
2931 max_tool_iterations: 5,
2932 ..AgentConfig::default()
2933 },
2934 Box::new(provider),
2935 Box::new(TestMemory::default()),
2936 vec![Box::new(EchoTool), Box::new(UpperTool)],
2937 );
2938
2939 let response = agent
2942 .respond(
2943 UserMessage {
2944 text: "tool:echo hello\ntool:upper world".to_string(),
2945 },
2946 &test_ctx(),
2947 )
2948 .await
2949 .expect("gated fallback should succeed");
2950
2951 assert_eq!(response.text, "gated sequential");
2952
2953 let received = prompts.lock().expect("provider lock poisoned");
2954 let last_prompt = received.last().expect("provider should have been called");
2955 assert!(
2956 last_prompt.contains("echoed:hello"),
2957 "first tool should execute, got: {last_prompt}"
2958 );
2959 assert!(
2960 !last_prompt.contains("WORLD"),
2961 "gated tool should NOT execute in parallel, got: {last_prompt}"
2962 );
2963 }
2964
2965 fn research_config(trigger: ResearchTrigger) -> ResearchPolicy {
2968 ResearchPolicy {
2969 enabled: true,
2970 trigger,
2971 keywords: vec!["search".to_string(), "find".to_string()],
2972 min_message_length: 10,
2973 max_iterations: 5,
2974 show_progress: true,
2975 }
2976 }
2977
2978 fn config_with_research(trigger: ResearchTrigger) -> AgentConfig {
2979 AgentConfig {
2980 research: research_config(trigger),
2981 ..Default::default()
2982 }
2983 }
2984
2985 #[tokio::test]
2986 async fn research_trigger_never() {
2987 let prompts = Arc::new(Mutex::new(Vec::new()));
2988 let provider = TestProvider {
2989 received_prompts: prompts.clone(),
2990 response_text: "final answer".to_string(),
2991 };
2992 let agent = Agent::new(
2993 config_with_research(ResearchTrigger::Never),
2994 Box::new(provider),
2995 Box::new(TestMemory::default()),
2996 vec![Box::new(EchoTool)],
2997 );
2998
2999 let response = agent
3000 .respond(
3001 UserMessage {
3002 text: "search for something".to_string(),
3003 },
3004 &test_ctx(),
3005 )
3006 .await
3007 .expect("should succeed");
3008
3009 assert_eq!(response.text, "final answer");
3010 let received = prompts.lock().expect("lock");
3011 assert_eq!(
3012 received.len(),
3013 1,
3014 "should have exactly 1 provider call (no research)"
3015 );
3016 }
3017
3018 #[tokio::test]
3019 async fn research_trigger_always() {
3020 let provider = ScriptedProvider::new(vec![
3021 "tool:echo gathering data",
3022 "Found relevant information",
3023 "Final answer with research context",
3024 ]);
3025 let prompts = provider.received_prompts.clone();
3026 let agent = Agent::new(
3027 config_with_research(ResearchTrigger::Always),
3028 Box::new(provider),
3029 Box::new(TestMemory::default()),
3030 vec![Box::new(EchoTool)],
3031 );
3032
3033 let response = agent
3034 .respond(
3035 UserMessage {
3036 text: "hello world".to_string(),
3037 },
3038 &test_ctx(),
3039 )
3040 .await
3041 .expect("should succeed");
3042
3043 assert_eq!(response.text, "Final answer with research context");
3044 let received = prompts.lock().expect("lock");
3045 let last = received.last().expect("should have provider calls");
3046 assert!(
3047 last.contains("Research findings:"),
3048 "final prompt should contain research findings, got: {last}"
3049 );
3050 }
3051
3052 #[tokio::test]
3053 async fn research_trigger_keywords_match() {
3054 let provider = ScriptedProvider::new(vec!["Summary of search", "Answer based on research"]);
3055 let prompts = provider.received_prompts.clone();
3056 let agent = Agent::new(
3057 config_with_research(ResearchTrigger::Keywords),
3058 Box::new(provider),
3059 Box::new(TestMemory::default()),
3060 vec![],
3061 );
3062
3063 let response = agent
3064 .respond(
3065 UserMessage {
3066 text: "please search for the config".to_string(),
3067 },
3068 &test_ctx(),
3069 )
3070 .await
3071 .expect("should succeed");
3072
3073 assert_eq!(response.text, "Answer based on research");
3074 let received = prompts.lock().expect("lock");
3075 assert!(received.len() >= 2, "should have at least 2 provider calls");
3076 assert!(
3077 received[0].contains("RESEARCH mode"),
3078 "first call should be research prompt"
3079 );
3080 }
3081
3082 #[tokio::test]
3083 async fn research_trigger_keywords_no_match() {
3084 let prompts = Arc::new(Mutex::new(Vec::new()));
3085 let provider = TestProvider {
3086 received_prompts: prompts.clone(),
3087 response_text: "direct answer".to_string(),
3088 };
3089 let agent = Agent::new(
3090 config_with_research(ResearchTrigger::Keywords),
3091 Box::new(provider),
3092 Box::new(TestMemory::default()),
3093 vec![],
3094 );
3095
3096 let response = agent
3097 .respond(
3098 UserMessage {
3099 text: "hello world".to_string(),
3100 },
3101 &test_ctx(),
3102 )
3103 .await
3104 .expect("should succeed");
3105
3106 assert_eq!(response.text, "direct answer");
3107 let received = prompts.lock().expect("lock");
3108 assert_eq!(received.len(), 1, "no research phase should fire");
3109 }
3110
3111 #[tokio::test]
3112 async fn research_trigger_length_short_skips() {
3113 let prompts = Arc::new(Mutex::new(Vec::new()));
3114 let provider = TestProvider {
3115 received_prompts: prompts.clone(),
3116 response_text: "short".to_string(),
3117 };
3118 let config = AgentConfig {
3119 research: ResearchPolicy {
3120 min_message_length: 20,
3121 ..research_config(ResearchTrigger::Length)
3122 },
3123 ..Default::default()
3124 };
3125 let agent = Agent::new(
3126 config,
3127 Box::new(provider),
3128 Box::new(TestMemory::default()),
3129 vec![],
3130 );
3131 agent
3132 .respond(
3133 UserMessage {
3134 text: "hi".to_string(),
3135 },
3136 &test_ctx(),
3137 )
3138 .await
3139 .expect("should succeed");
3140 let received = prompts.lock().expect("lock");
3141 assert_eq!(received.len(), 1, "short message should skip research");
3142 }
3143
3144 #[tokio::test]
3145 async fn research_trigger_length_long_triggers() {
3146 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
3147 let prompts = provider.received_prompts.clone();
3148 let config = AgentConfig {
3149 research: ResearchPolicy {
3150 min_message_length: 20,
3151 ..research_config(ResearchTrigger::Length)
3152 },
3153 ..Default::default()
3154 };
3155 let agent = Agent::new(
3156 config,
3157 Box::new(provider),
3158 Box::new(TestMemory::default()),
3159 vec![],
3160 );
3161 agent
3162 .respond(
3163 UserMessage {
3164 text: "this is a longer message that exceeds the threshold".to_string(),
3165 },
3166 &test_ctx(),
3167 )
3168 .await
3169 .expect("should succeed");
3170 let received = prompts.lock().expect("lock");
3171 assert!(received.len() >= 2, "long message should trigger research");
3172 }
3173
3174 #[tokio::test]
3175 async fn research_trigger_question_with_mark() {
3176 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
3177 let prompts = provider.received_prompts.clone();
3178 let agent = Agent::new(
3179 config_with_research(ResearchTrigger::Question),
3180 Box::new(provider),
3181 Box::new(TestMemory::default()),
3182 vec![],
3183 );
3184 agent
3185 .respond(
3186 UserMessage {
3187 text: "what is the meaning of life?".to_string(),
3188 },
3189 &test_ctx(),
3190 )
3191 .await
3192 .expect("should succeed");
3193 let received = prompts.lock().expect("lock");
3194 assert!(received.len() >= 2, "question should trigger research");
3195 }
3196
3197 #[tokio::test]
3198 async fn research_trigger_question_without_mark() {
3199 let prompts = Arc::new(Mutex::new(Vec::new()));
3200 let provider = TestProvider {
3201 received_prompts: prompts.clone(),
3202 response_text: "no research".to_string(),
3203 };
3204 let agent = Agent::new(
3205 config_with_research(ResearchTrigger::Question),
3206 Box::new(provider),
3207 Box::new(TestMemory::default()),
3208 vec![],
3209 );
3210 agent
3211 .respond(
3212 UserMessage {
3213 text: "do this thing".to_string(),
3214 },
3215 &test_ctx(),
3216 )
3217 .await
3218 .expect("should succeed");
3219 let received = prompts.lock().expect("lock");
3220 assert_eq!(received.len(), 1, "non-question should skip research");
3221 }
3222
3223 #[tokio::test]
3224 async fn research_respects_max_iterations() {
3225 let provider = ScriptedProvider::new(vec![
3226 "tool:echo step1",
3227 "tool:echo step2",
3228 "tool:echo step3",
3229 "tool:echo step4",
3230 "tool:echo step5",
3231 "tool:echo step6",
3232 "tool:echo step7",
3233 "answer after research",
3234 ]);
3235 let prompts = provider.received_prompts.clone();
3236 let config = AgentConfig {
3237 research: ResearchPolicy {
3238 max_iterations: 3,
3239 ..research_config(ResearchTrigger::Always)
3240 },
3241 ..Default::default()
3242 };
3243 let agent = Agent::new(
3244 config,
3245 Box::new(provider),
3246 Box::new(TestMemory::default()),
3247 vec![Box::new(EchoTool)],
3248 );
3249
3250 let response = agent
3251 .respond(
3252 UserMessage {
3253 text: "test".to_string(),
3254 },
3255 &test_ctx(),
3256 )
3257 .await
3258 .expect("should succeed");
3259
3260 assert!(!response.text.is_empty(), "should get a response");
3261 let received = prompts.lock().expect("lock");
3262 let research_calls = received
3263 .iter()
3264 .filter(|p| p.contains("RESEARCH mode") || p.contains("Research iteration"))
3265 .count();
3266 assert_eq!(
3268 research_calls, 4,
3269 "should have 4 research provider calls (1 initial + 3 iterations)"
3270 );
3271 }
3272
3273 #[tokio::test]
3274 async fn research_disabled_skips_phase() {
3275 let prompts = Arc::new(Mutex::new(Vec::new()));
3276 let provider = TestProvider {
3277 received_prompts: prompts.clone(),
3278 response_text: "direct answer".to_string(),
3279 };
3280 let config = AgentConfig {
3281 research: ResearchPolicy {
3282 enabled: false,
3283 trigger: ResearchTrigger::Always,
3284 ..Default::default()
3285 },
3286 ..Default::default()
3287 };
3288 let agent = Agent::new(
3289 config,
3290 Box::new(provider),
3291 Box::new(TestMemory::default()),
3292 vec![Box::new(EchoTool)],
3293 );
3294
3295 let response = agent
3296 .respond(
3297 UserMessage {
3298 text: "search for something".to_string(),
3299 },
3300 &test_ctx(),
3301 )
3302 .await
3303 .expect("should succeed");
3304
3305 assert_eq!(response.text, "direct answer");
3306 let received = prompts.lock().expect("lock");
3307 assert_eq!(
3308 received.len(),
3309 1,
3310 "disabled research should not fire even with Always trigger"
3311 );
3312 }
3313
/// Test provider that records the reasoning configuration each call received,
/// so tests can assert the agent forwards its `ReasoningConfig` verbatim.
struct ReasoningCapturingProvider {
    // One entry per provider call; the plain `complete` path pushes a default
    // config so call counts stay comparable across both entry points.
    captured_reasoning: Arc<Mutex<Vec<ReasoningConfig>>>,
    // Fixed text returned from every completion.
    response_text: String,
}

#[async_trait]
impl Provider for ReasoningCapturingProvider {
    // Non-reasoning path: records a default config as a call marker and
    // returns the canned response.
    async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
        self.captured_reasoning
            .lock()
            .expect("lock")
            .push(ReasoningConfig::default());
        Ok(ChatResult {
            output_text: self.response_text.clone(),
            ..Default::default()
        })
    }

    // Reasoning path: records the exact config the agent passed in, which is
    // what the reasoning tests assert on.
    async fn complete_with_reasoning(
        &self,
        _prompt: &str,
        reasoning: &ReasoningConfig,
    ) -> anyhow::Result<ChatResult> {
        self.captured_reasoning
            .lock()
            .expect("lock")
            .push(reasoning.clone());
        Ok(ChatResult {
            output_text: self.response_text.clone(),
            ..Default::default()
        })
    }
}
3349
3350 #[tokio::test]
3351 async fn reasoning_config_passed_to_provider() {
3352 let captured = Arc::new(Mutex::new(Vec::new()));
3353 let provider = ReasoningCapturingProvider {
3354 captured_reasoning: captured.clone(),
3355 response_text: "ok".to_string(),
3356 };
3357 let config = AgentConfig {
3358 reasoning: ReasoningConfig {
3359 enabled: Some(true),
3360 level: Some("high".to_string()),
3361 },
3362 ..Default::default()
3363 };
3364 let agent = Agent::new(
3365 config,
3366 Box::new(provider),
3367 Box::new(TestMemory::default()),
3368 vec![],
3369 );
3370
3371 agent
3372 .respond(
3373 UserMessage {
3374 text: "test".to_string(),
3375 },
3376 &test_ctx(),
3377 )
3378 .await
3379 .expect("should succeed");
3380
3381 let configs = captured.lock().expect("lock");
3382 assert_eq!(configs.len(), 1, "provider should be called once");
3383 assert_eq!(configs[0].enabled, Some(true));
3384 assert_eq!(configs[0].level.as_deref(), Some("high"));
3385 }
3386
3387 #[tokio::test]
3388 async fn reasoning_disabled_passes_config_through() {
3389 let captured = Arc::new(Mutex::new(Vec::new()));
3390 let provider = ReasoningCapturingProvider {
3391 captured_reasoning: captured.clone(),
3392 response_text: "ok".to_string(),
3393 };
3394 let config = AgentConfig {
3395 reasoning: ReasoningConfig {
3396 enabled: Some(false),
3397 level: None,
3398 },
3399 ..Default::default()
3400 };
3401 let agent = Agent::new(
3402 config,
3403 Box::new(provider),
3404 Box::new(TestMemory::default()),
3405 vec![],
3406 );
3407
3408 agent
3409 .respond(
3410 UserMessage {
3411 text: "test".to_string(),
3412 },
3413 &test_ctx(),
3414 )
3415 .await
3416 .expect("should succeed");
3417
3418 let configs = captured.lock().expect("lock");
3419 assert_eq!(configs.len(), 1);
3420 assert_eq!(
3421 configs[0].enabled,
3422 Some(false),
3423 "disabled reasoning should still be passed to provider"
3424 );
3425 assert_eq!(configs[0].level, None);
3426 }
3427
/// Test provider for the structured tool-use API: replays a scripted sequence
/// of `ChatResult`s and records the message lists it was called with.
struct StructuredProvider {
    // Scripted replies, returned in order; a default ChatResult once exhausted.
    responses: Vec<ChatResult>,
    // Index of the next scripted reply; atomic so `&self` methods can advance it.
    call_count: AtomicUsize,
    // Snapshot of the conversation passed to each complete_with_tools call.
    received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
}

impl StructuredProvider {
    // Builds a provider that will replay `responses` one per call.
    fn new(responses: Vec<ChatResult>) -> Self {
        Self {
            responses,
            call_count: AtomicUsize::new(0),
            received_messages: Arc::new(Mutex::new(Vec::new())),
        }
    }
}

#[async_trait]
impl Provider for StructuredProvider {
    // Prompt-string path: consumes the next scripted reply without recording
    // the prompt.
    async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
        let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
        Ok(self.responses.get(idx).cloned().unwrap_or_default())
    }

    // Structured path: records the full message list for later assertions,
    // then consumes the next scripted reply (default once the script runs out).
    async fn complete_with_tools(
        &self,
        messages: &[ConversationMessage],
        _tools: &[ToolDefinition],
        _reasoning: &ReasoningConfig,
    ) -> anyhow::Result<ChatResult> {
        self.received_messages
            .lock()
            .expect("lock")
            .push(messages.to_vec());
        let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
        Ok(self.responses.get(idx).cloned().unwrap_or_default())
    }
}
3467
/// Structured-API echo tool: advertises a JSON schema and echoes its input.
struct StructuredEchoTool;

#[async_trait]
impl Tool for StructuredEchoTool {
    fn name(&self) -> &'static str {
        "echo"
    }
    fn description(&self) -> &'static str {
        "Echoes input back"
    }
    // Schema advertised to the provider: a single required "text" string.
    fn input_schema(&self) -> Option<serde_json::Value> {
        Some(json!({
            "type": "object",
            "properties": {
                "text": { "type": "string", "description": "Text to echo" }
            },
            "required": ["text"]
        }))
    }
    // Prefixes the raw input with "echoed:" so tests can spot the output.
    async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
        Ok(ToolResult {
            output: format!("echoed:{input}"),
        })
    }
}
3493
/// Structured-API tool that always errors, for exercising tool-failure paths.
struct StructuredFailingTool;

#[async_trait]
impl Tool for StructuredFailingTool {
    fn name(&self) -> &'static str {
        "boom"
    }
    fn description(&self) -> &'static str {
        "Always fails"
    }
    // Empty object schema: the tool takes no arguments.
    fn input_schema(&self) -> Option<serde_json::Value> {
        Some(json!({
            "type": "object",
            "properties": {},
        }))
    }
    // Unconditionally fails so tests can verify errors don't abort the run.
    async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
        Err(anyhow::anyhow!("tool exploded"))
    }
}
3514
/// Structured-API tool that uppercases its input; a second distinct tool for
/// multi-tool tests.
struct StructuredUpperTool;

#[async_trait]
impl Tool for StructuredUpperTool {
    fn name(&self) -> &'static str {
        "upper"
    }
    fn description(&self) -> &'static str {
        "Uppercases input"
    }
    // Schema advertised to the provider: a single required "text" string.
    fn input_schema(&self) -> Option<serde_json::Value> {
        Some(json!({
            "type": "object",
            "properties": {
                "text": { "type": "string", "description": "Text to uppercase" }
            },
            "required": ["text"]
        }))
    }
    // Returns the input uppercased, making its output easy to assert on.
    async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
        Ok(ToolResult {
            output: input.to_uppercase(),
        })
    }
}
3540
3541 use crate::types::ToolUseRequest;
3542
3543 #[tokio::test]
3546 async fn structured_basic_tool_call_then_end_turn() {
3547 let provider = StructuredProvider::new(vec![
3548 ChatResult {
3549 output_text: "Let me echo that.".to_string(),
3550 tool_calls: vec![ToolUseRequest {
3551 id: "call_1".to_string(),
3552 name: "echo".to_string(),
3553 input: json!({"text": "hello"}),
3554 }],
3555 stop_reason: Some(StopReason::ToolUse),
3556 ..Default::default()
3557 },
3558 ChatResult {
3559 output_text: "The echo returned: hello".to_string(),
3560 tool_calls: vec![],
3561 stop_reason: Some(StopReason::EndTurn),
3562 ..Default::default()
3563 },
3564 ]);
3565 let received = provider.received_messages.clone();
3566
3567 let agent = Agent::new(
3568 AgentConfig::default(),
3569 Box::new(provider),
3570 Box::new(TestMemory::default()),
3571 vec![Box::new(StructuredEchoTool)],
3572 );
3573
3574 let response = agent
3575 .respond(
3576 UserMessage {
3577 text: "echo hello".to_string(),
3578 },
3579 &test_ctx(),
3580 )
3581 .await
3582 .expect("should succeed");
3583
3584 assert_eq!(response.text, "The echo returned: hello");
3585
3586 let msgs = received.lock().expect("lock");
3588 assert_eq!(msgs.len(), 2, "provider should be called twice");
3589 assert!(
3591 msgs[1].len() >= 3,
3592 "second call should have user + assistant + tool result"
3593 );
3594 }
3595
3596 #[tokio::test]
3597 async fn structured_no_tool_calls_returns_immediately() {
3598 let provider = StructuredProvider::new(vec![ChatResult {
3599 output_text: "Hello! No tools needed.".to_string(),
3600 tool_calls: vec![],
3601 stop_reason: Some(StopReason::EndTurn),
3602 ..Default::default()
3603 }]);
3604
3605 let agent = Agent::new(
3606 AgentConfig::default(),
3607 Box::new(provider),
3608 Box::new(TestMemory::default()),
3609 vec![Box::new(StructuredEchoTool)],
3610 );
3611
3612 let response = agent
3613 .respond(
3614 UserMessage {
3615 text: "hi".to_string(),
3616 },
3617 &test_ctx(),
3618 )
3619 .await
3620 .expect("should succeed");
3621
3622 assert_eq!(response.text, "Hello! No tools needed.");
3623 }
3624
#[tokio::test]
async fn structured_tool_not_found_sends_error_result() {
    // The provider first requests a tool that is not registered; the agent
    // must feed back an is_error ToolResult (rather than aborting) and let
    // the provider produce a final answer on the second call.
    let provider = StructuredProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "nonexistent".to_string(),
                input: json!({}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "I see the tool was not found.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);
    let received = provider.received_messages.clone();

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "use nonexistent".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed, not abort");

    assert_eq!(response.text, "I see the tool was not found.");

    // The second provider call must carry the error ToolResult for the
    // unknown tool.
    let msgs = received.lock().expect("lock");
    let last_call = &msgs[1];
    let has_error_result = last_call.iter().any(|m| {
        matches!(m, ConversationMessage::ToolResult(r) if r.is_error && r.content.contains("not found"))
    });
    assert!(has_error_result, "should include error ToolResult");
}
3674
#[tokio::test]
async fn structured_tool_error_does_not_abort() {
    // The requested tool exists but its execute() fails; the error is
    // converted into a ToolResultMessage and the loop continues to the
    // provider's second (final) response.
    let provider = StructuredProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "boom".to_string(),
                input: json!({}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "I handled the error gracefully.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);

    // StructuredFailingTool always returns Err from execute().
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredFailingTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "boom".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed, error becomes ToolResultMessage");

    assert_eq!(response.text, "I handled the error gracefully.");
}
3715
#[tokio::test]
async fn structured_multi_iteration_tool_calls() {
    // Two sequential tool-use iterations before the final answer: the agent
    // must call the provider three times total (tool, tool, final).
    let provider = StructuredProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "first"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_2".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "second"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "Done with two tool calls.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);
    let received = provider.received_messages.clone();

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "echo twice".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "Done with two tool calls.");
    let msgs = received.lock().expect("lock");
    assert_eq!(msgs.len(), 3, "three provider calls");
}
3772
#[tokio::test]
async fn structured_max_iterations_forces_final_answer() {
    // The provider keeps requesting tools indefinitely; with
    // max_tool_iterations = 3 the agent must stop iterating and force a
    // final (text-only) answer on the fourth provider call.
    let responses = vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_0".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "loop"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "loop"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_2".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "loop"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "Forced final answer.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ];

    let agent = Agent::new(
        AgentConfig {
            max_tool_iterations: 3,
            // Disabled so loop detection doesn't fire before the iteration cap.
            loop_detection_no_progress_threshold: 0,
            ..AgentConfig::default()
        },
        Box::new(StructuredProvider::new(responses)),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "loop forever".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed with forced answer");

    assert_eq!(response.text, "Forced final answer.");
}
3841
#[tokio::test]
async fn structured_fallback_to_text_when_disabled() {
    // With model_supports_tool_use = false the agent must use the plain
    // text completion path (TestProvider) even though a structured tool
    // is registered.
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "text path used".to_string(),
    };

    let agent = Agent::new(
        AgentConfig {
            model_supports_tool_use: false,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "hello".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "text path used");
}
3872
#[tokio::test]
async fn structured_fallback_when_no_tool_schemas() {
    // EchoTool (unlike StructuredEchoTool) exposes no input schema, so the
    // agent should fall back to the plain text path even though structured
    // tool use is enabled by default.
    let prompts = Arc::new(Mutex::new(Vec::new()));
    let provider = TestProvider {
        received_prompts: prompts.clone(),
        response_text: "text path used".to_string(),
    };

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(EchoTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "hello".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "text path used");
}
3901
#[tokio::test]
async fn structured_parallel_tool_calls() {
    // One provider turn requests two tools at once; with parallel_tools
    // enabled both must run and both ToolResults must appear in the second
    // provider call.
    let provider = StructuredProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![
                ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "a"}),
                },
                ToolUseRequest {
                    id: "call_2".to_string(),
                    name: "upper".to_string(),
                    input: json!({"text": "b"}),
                },
            ],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "Both tools ran.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);
    let received = provider.received_messages.clone();

    let agent = Agent::new(
        AgentConfig {
            parallel_tools: true,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
    );

    let response = agent
        .respond(
            UserMessage {
                text: "do both".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "Both tools ran.");

    // The second provider call carries the results of both parallel tools.
    let msgs = received.lock().expect("lock");
    let tool_results: Vec<_> = msgs[1]
        .iter()
        .filter(|m| matches!(m, ConversationMessage::ToolResult(_)))
        .collect();
    assert_eq!(tool_results.len(), 2, "should have two tool results");
}
3961
#[tokio::test]
async fn structured_memory_integration() {
    // Pre-seed memory with a prior exchange; the very first provider call
    // must include those entries as conversation messages, oldest first.
    let memory = TestMemory::default();
    memory
        .append(MemoryEntry {
            role: "user".to_string(),
            content: "old question".to_string(),
            ..Default::default()
        })
        .await
        .unwrap();
    memory
        .append(MemoryEntry {
            role: "assistant".to_string(),
            content: "old answer".to_string(),
            ..Default::default()
        })
        .await
        .unwrap();

    let provider = StructuredProvider::new(vec![ChatResult {
        output_text: "I see your history.".to_string(),
        tool_calls: vec![],
        stop_reason: Some(StopReason::EndTurn),
        ..Default::default()
    }]);
    let received = provider.received_messages.clone();

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(StructuredEchoTool)],
    );

    agent
        .respond(
            UserMessage {
                text: "new question".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("should succeed");

    // Two memory messages plus the new user message; the oldest memory
    // entry comes first.
    let msgs = received.lock().expect("lock");
    assert!(msgs[0].len() >= 3, "should include memory messages");
    assert!(matches!(
        &msgs[0][0],
        ConversationMessage::User { content, .. } if content == "old question"
    ));
}
4018
4019 #[tokio::test]
4020 async fn prepare_tool_input_extracts_single_string_field() {
4021 let tool = StructuredEchoTool;
4022 let input = json!({"text": "hello world"});
4023 let result = prepare_tool_input(&tool, &input).expect("valid input");
4024 assert_eq!(result, "hello world");
4025 }
4026
4027 #[tokio::test]
4028 async fn prepare_tool_input_serializes_multi_field_json() {
4029 struct MultiFieldTool;
4031 #[async_trait]
4032 impl Tool for MultiFieldTool {
4033 fn name(&self) -> &'static str {
4034 "multi"
4035 }
4036 fn input_schema(&self) -> Option<serde_json::Value> {
4037 Some(json!({
4038 "type": "object",
4039 "properties": {
4040 "path": { "type": "string" },
4041 "content": { "type": "string" }
4042 },
4043 "required": ["path", "content"]
4044 }))
4045 }
4046 async fn execute(
4047 &self,
4048 _input: &str,
4049 _ctx: &ToolContext,
4050 ) -> anyhow::Result<ToolResult> {
4051 unreachable!()
4052 }
4053 }
4054
4055 let tool = MultiFieldTool;
4056 let input = json!({"path": "a.txt", "content": "hello"});
4057 let result = prepare_tool_input(&tool, &input).expect("valid input");
4058 let parsed: serde_json::Value = serde_json::from_str(&result).expect("valid JSON");
4060 assert_eq!(parsed["path"], "a.txt");
4061 assert_eq!(parsed["content"], "hello");
4062 }
4063
4064 #[tokio::test]
4065 async fn prepare_tool_input_unwraps_bare_string() {
4066 let tool = StructuredEchoTool;
4067 let input = json!("plain text");
4068 let result = prepare_tool_input(&tool, &input).expect("valid input");
4069 assert_eq!(result, "plain text");
4070 }
4071
4072 #[tokio::test]
4073 async fn prepare_tool_input_rejects_schema_violating_input() {
4074 struct StrictTool;
4075 #[async_trait]
4076 impl Tool for StrictTool {
4077 fn name(&self) -> &'static str {
4078 "strict"
4079 }
4080 fn input_schema(&self) -> Option<serde_json::Value> {
4081 Some(json!({
4082 "type": "object",
4083 "required": ["path"],
4084 "properties": {
4085 "path": { "type": "string" }
4086 }
4087 }))
4088 }
4089 async fn execute(
4090 &self,
4091 _input: &str,
4092 _ctx: &ToolContext,
4093 ) -> anyhow::Result<ToolResult> {
4094 unreachable!("should not be called with invalid input")
4095 }
4096 }
4097
4098 let result = prepare_tool_input(&StrictTool, &json!({}));
4100 assert!(result.is_err());
4101 let err = result.unwrap_err();
4102 assert!(err.contains("Invalid input for tool 'strict'"));
4103 assert!(err.contains("missing required field"));
4104 }
4105
4106 #[tokio::test]
4107 async fn prepare_tool_input_rejects_wrong_type() {
4108 struct TypedTool;
4109 #[async_trait]
4110 impl Tool for TypedTool {
4111 fn name(&self) -> &'static str {
4112 "typed"
4113 }
4114 fn input_schema(&self) -> Option<serde_json::Value> {
4115 Some(json!({
4116 "type": "object",
4117 "required": ["count"],
4118 "properties": {
4119 "count": { "type": "integer" }
4120 }
4121 }))
4122 }
4123 async fn execute(
4124 &self,
4125 _input: &str,
4126 _ctx: &ToolContext,
4127 ) -> anyhow::Result<ToolResult> {
4128 unreachable!("should not be called with invalid input")
4129 }
4130 }
4131
4132 let result = prepare_tool_input(&TypedTool, &json!({"count": "not a number"}));
4134 assert!(result.is_err());
4135 let err = result.unwrap_err();
4136 assert!(err.contains("expected type \"integer\""));
4137 }
4138
#[test]
fn truncate_messages_preserves_small_conversation() {
    // Well under the character budget: nothing should be dropped.
    let mut conversation = vec![
        ConversationMessage::user("hello".to_string()),
        ConversationMessage::Assistant {
            content: Some("world".to_string()),
            tool_calls: vec![],
        },
    ];

    truncate_messages(&mut conversation, 1000);

    assert_eq!(
        conversation.len(),
        2,
        "should not truncate small conversation"
    );
}
4151
#[test]
fn truncate_messages_drops_middle() {
    // Four ~100-char messages against a 250-char budget: some must be
    // dropped, but the very first user message is kept intact.
    let mut conversation: Vec<ConversationMessage> = Vec::new();
    conversation.push(ConversationMessage::user("A".repeat(100)));
    conversation.push(ConversationMessage::Assistant {
        content: Some("B".repeat(100)),
        tool_calls: vec![],
    });
    conversation.push(ConversationMessage::user("C".repeat(100)));
    conversation.push(ConversationMessage::Assistant {
        content: Some("D".repeat(100)),
        tool_calls: vec![],
    });

    truncate_messages(&mut conversation, 250);

    assert!(conversation.len() < 4, "should have dropped some messages");
    assert!(matches!(
        &conversation[0],
        ConversationMessage::User { content, .. } if content.starts_with("AAA")
    ));
}
4176
#[test]
fn memory_to_messages_reverses_order() {
    // Memory entries arrive newest-first; the conversion must flip them
    // into chronological (oldest-first) conversation order.
    let newest = MemoryEntry {
        role: "assistant".to_string(),
        content: "newest".to_string(),
        ..Default::default()
    };
    let oldest = MemoryEntry {
        role: "user".to_string(),
        content: "oldest".to_string(),
        ..Default::default()
    };

    let converted = memory_to_messages(&[newest, oldest]);

    assert_eq!(converted.len(), 2);
    assert!(matches!(
        &converted[0],
        ConversationMessage::User { content, .. } if content == "oldest"
    ));
    assert!(matches!(
        &converted[1],
        ConversationMessage::Assistant { content: Some(c), .. } if c == "newest"
    ));
}
4202
#[test]
fn single_required_string_field_detects_single() {
    // Exactly one required property, typed as a string: it is detected.
    let schema = json!({
        "type": "object",
        "properties": {
            "path": { "type": "string" }
        },
        "required": ["path"]
    });

    let detected = single_required_string_field(&schema);
    assert_eq!(detected, Some("path".to_string()));
}
4217
#[test]
fn single_required_string_field_none_for_multiple() {
    // Two required fields: there is no single field to extract, so None.
    let schema = json!({
        "type": "object",
        "properties": {
            "path": { "type": "string" },
            "mode": { "type": "string" }
        },
        "required": ["path", "mode"]
    });

    assert!(single_required_string_field(&schema).is_none());
}
4230
#[test]
fn single_required_string_field_none_for_non_string() {
    // The lone required field is an integer, not a string: None.
    let schema = json!({
        "type": "object",
        "properties": {
            "count": { "type": "integer" }
        },
        "required": ["count"]
    });

    assert!(single_required_string_field(&schema).is_none());
}
4242
4243 use crate::types::StreamChunk;
4246
// Scripted provider for the streaming tests: serves canned ChatResults in
// order and records every message list it receives.
struct StreamingProvider {
    // Canned results, consumed positionally across ALL completion methods.
    responses: Vec<ChatResult>,
    // Shared index into `responses`; incremented on every call.
    call_count: AtomicUsize,
    // Message lists passed to the *_with_tools methods, for later inspection.
    received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
}

impl StreamingProvider {
    // Creates a provider that will serve `responses` in order, starting empty
    // of recorded messages.
    fn new(responses: Vec<ChatResult>) -> Self {
        Self {
            responses,
            call_count: AtomicUsize::new(0),
            received_messages: Arc::new(Mutex::new(Vec::new())),
        }
    }
}
4263
#[async_trait]
impl Provider for StreamingProvider {
    // Non-streaming text path: serve the next scripted result (or a default
    // once the script is exhausted).
    async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
        let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
        Ok(self.responses.get(idx).cloned().unwrap_or_default())
    }

    // Streaming text path: emit one chunk per whitespace-separated word
    // (each with a trailing space), then a done sentinel.
    async fn complete_streaming(
        &self,
        _prompt: &str,
        sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
    ) -> anyhow::Result<ChatResult> {
        let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
        let result = self.responses.get(idx).cloned().unwrap_or_default();
        for word in result.output_text.split_whitespace() {
            // Send errors are ignored: the receiver may have been dropped.
            let _ = sender.send(StreamChunk {
                delta: format!("{word} "),
                done: false,
                tool_call_delta: None,
            });
        }
        let _ = sender.send(StreamChunk {
            delta: String::new(),
            done: true,
            tool_call_delta: None,
        });
        Ok(result)
    }

    // Structured (tool-use) path: record the message list, then serve the
    // next scripted result.
    async fn complete_with_tools(
        &self,
        messages: &[ConversationMessage],
        _tools: &[ToolDefinition],
        _reasoning: &ReasoningConfig,
    ) -> anyhow::Result<ChatResult> {
        self.received_messages
            .lock()
            .expect("lock")
            .push(messages.to_vec());
        let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
        Ok(self.responses.get(idx).cloned().unwrap_or_default())
    }

    // Streaming structured path: record messages, emit one chunk per
    // character, then a done sentinel.
    async fn complete_streaming_with_tools(
        &self,
        messages: &[ConversationMessage],
        _tools: &[ToolDefinition],
        _reasoning: &ReasoningConfig,
        sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
    ) -> anyhow::Result<ChatResult> {
        self.received_messages
            .lock()
            .expect("lock")
            .push(messages.to_vec());
        let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
        let result = self.responses.get(idx).cloned().unwrap_or_default();
        for ch in result.output_text.chars() {
            let _ = sender.send(StreamChunk {
                delta: ch.to_string(),
                done: false,
                tool_call_delta: None,
            });
        }
        let _ = sender.send(StreamChunk {
            delta: String::new(),
            done: true,
            tool_call_delta: None,
        });
        Ok(result)
    }
}
4337
#[tokio::test]
async fn streaming_text_only_sends_chunks() {
    // Plain text streaming (no tool support): chunks arrive on the channel
    // and the final chunk is the done sentinel.
    let provider = StreamingProvider::new(vec![ChatResult {
        output_text: "Hello world".to_string(),
        ..Default::default()
    }]);
    let agent = Agent::new(
        AgentConfig {
            model_supports_tool_use: false,
            ..Default::default()
        },
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![],
    );

    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    let response = agent
        .respond_streaming(
            UserMessage {
                text: "hi".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "Hello world");

    // Drain whatever was buffered on the channel.
    let mut chunks = Vec::new();
    while let Ok(chunk) = rx.try_recv() {
        chunks.push(chunk);
    }
    assert!(chunks.len() >= 2, "should have at least text + done chunks");
    assert!(chunks.last().unwrap().done, "last chunk should be done");
}
4376
#[tokio::test]
async fn streaming_single_tool_call_round_trip() {
    // Streaming with one tool-use iteration: the tool result is fed back
    // and the second (final) streamed response wins.
    let provider = StreamingProvider::new(vec![
        ChatResult {
            output_text: "I'll echo that.".to_string(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "hello"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "Done echoing.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool)],
    );

    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    let response = agent
        .respond_streaming(
            UserMessage {
                text: "echo hello".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "Done echoing.");

    let mut chunks = Vec::new();
    while let Ok(chunk) = rx.try_recv() {
        chunks.push(chunk);
    }
    assert!(chunks.len() >= 2, "should have streaming chunks");
}
4426
#[tokio::test]
async fn streaming_multi_iteration_tool_calls() {
    // Two streamed tool-use iterations followed by a final answer: each of
    // the three provider calls emits its own done sentinel.
    let provider = StreamingProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "echo".to_string(),
                input: json!({"text": "first"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_2".to_string(),
                name: "upper".to_string(),
                input: json!({"text": "second"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "All done.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
    );

    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    let response = agent
        .respond_streaming(
            UserMessage {
                text: "do two things".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "All done.");

    let mut chunks = Vec::new();
    while let Ok(chunk) = rx.try_recv() {
        chunks.push(chunk);
    }
    // One done sentinel per provider call (the mock sends done after each).
    let done_count = chunks.iter().filter(|c| c.done).count();
    assert!(
        done_count >= 3,
        "should have done chunks from 3 provider calls, got {}",
        done_count
    );
}
4491
#[tokio::test]
async fn streaming_timeout_returns_error() {
    // A provider slower (500ms) than the configured request timeout (50ms)
    // must surface AgentError::Timeout carrying the configured value.
    struct StreamingSlowProvider;

    #[async_trait]
    impl Provider for StreamingSlowProvider {
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            sleep(Duration::from_millis(500)).await;
            Ok(ChatResult::default())
        }

        async fn complete_streaming(
            &self,
            _prompt: &str,
            _sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
        ) -> anyhow::Result<ChatResult> {
            sleep(Duration::from_millis(500)).await;
            Ok(ChatResult::default())
        }
    }

    let agent = Agent::new(
        AgentConfig {
            request_timeout_ms: 50,
            model_supports_tool_use: false,
            ..Default::default()
        },
        Box::new(StreamingSlowProvider),
        Box::new(TestMemory::default()),
        vec![],
    );

    let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
    let result = agent
        .respond_streaming(
            UserMessage {
                text: "hi".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await;

    assert!(result.is_err());
    match result.unwrap_err() {
        AgentError::Timeout { timeout_ms } => assert_eq!(timeout_ms, 50),
        other => panic!("expected Timeout, got {other:?}"),
    }
}
4541
#[tokio::test]
async fn streaming_no_schema_fallback_sends_chunks() {
    // EchoTool has no input schema, so the streaming request falls back to
    // the plain text streaming path — chunks must still arrive.
    let provider = StreamingProvider::new(vec![ChatResult {
        output_text: "Fallback response".to_string(),
        ..Default::default()
    }]);
    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(EchoTool)],
    );

    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    let response = agent
        .respond_streaming(
            UserMessage {
                text: "hi".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "Fallback response");

    let mut chunks = Vec::new();
    while let Ok(chunk) = rx.try_recv() {
        chunks.push(chunk);
    }
    assert!(!chunks.is_empty(), "should have streaming chunks");
    assert!(chunks.last().unwrap().done, "last chunk should be done");
}
4577
#[tokio::test]
async fn streaming_tool_error_does_not_abort() {
    // Same as the structured variant, but on the streaming path: a failing
    // tool becomes an error ToolResult and the second response completes.
    let provider = StreamingProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "boom".to_string(),
                input: json!({}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "Recovered from error.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredFailingTool)],
    );

    let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
    let response = agent
        .respond_streaming(
            UserMessage {
                text: "boom".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed despite tool error");

    assert_eq!(response.text, "Recovered from error.");
}
4620
#[tokio::test]
async fn streaming_parallel_tools() {
    // Two tools requested in one streamed turn with parallel_tools enabled:
    // both run and the final streamed answer is returned.
    let provider = StreamingProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![
                ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "a"}),
                },
                ToolUseRequest {
                    id: "call_2".to_string(),
                    name: "upper".to_string(),
                    input: json!({"text": "b"}),
                },
            ],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "Parallel done.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);

    let agent = Agent::new(
        AgentConfig {
            parallel_tools: true,
            ..Default::default()
        },
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
    );

    let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
    let response = agent
        .respond_streaming(
            UserMessage {
                text: "parallel test".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed");

    assert_eq!(response.text, "Parallel done.");
}
4673
#[tokio::test]
async fn streaming_done_chunk_sentinel() {
    // The chunk stream contract: non-empty content chunks, at least one
    // done sentinel, and the very last chunk carries done=true.
    let provider = StreamingProvider::new(vec![ChatResult {
        output_text: "abc".to_string(),
        tool_calls: vec![],
        stop_reason: Some(StopReason::EndTurn),
        ..Default::default()
    }]);

    let agent = Agent::new(
        AgentConfig::default(),
        Box::new(provider),
        Box::new(TestMemory::default()),
        vec![Box::new(StructuredEchoTool)],
    );

    let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
    let _response = agent
        .respond_streaming(
            UserMessage {
                text: "test".to_string(),
            },
            &test_ctx(),
            tx,
        )
        .await
        .expect("should succeed");

    let mut chunks = Vec::new();
    while let Ok(chunk) = rx.try_recv() {
        chunks.push(chunk);
    }

    let done_chunks: Vec<_> = chunks.iter().filter(|c| c.done).collect();
    assert!(
        !done_chunks.is_empty(),
        "must have at least one done sentinel chunk"
    );
    assert!(chunks.last().unwrap().done, "final chunk must be done=true");

    // Content chunks are the non-done, non-empty deltas.
    let content_chunks: Vec<_> = chunks
        .iter()
        .filter(|c| !c.done && !c.delta.is_empty())
        .collect();
    assert!(!content_chunks.is_empty(), "should have content chunks");
}
4722
#[tokio::test]
async fn system_prompt_prepended_in_structured_path() {
    // With a configured system prompt, the first message of the provider
    // call must be a System message, followed by the User message.
    let provider = StructuredProvider::new(vec![ChatResult {
        output_text: "I understand.".to_string(),
        stop_reason: Some(StopReason::EndTurn),
        ..Default::default()
    }]);
    let received = provider.received_messages.clone();
    let memory = TestMemory::default();

    let agent = Agent::new(
        AgentConfig {
            model_supports_tool_use: true,
            system_prompt: Some("You are a math tutor.".to_string()),
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(StructuredEchoTool)],
    );

    let result = agent
        .respond(
            UserMessage {
                text: "What is 2+2?".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");

    assert_eq!(result.text, "I understand.");

    let msgs = received.lock().expect("lock");
    assert!(!msgs.is_empty(), "provider should have been called");
    let first_call = &msgs[0];
    // Position 0: the system prompt.
    match &first_call[0] {
        ConversationMessage::System { content } => {
            assert_eq!(content, "You are a math tutor.");
        }
        other => panic!("expected System message first, got {other:?}"),
    }
    // Position 1: the user's input.
    match &first_call[1] {
        ConversationMessage::User { content, .. } => {
            assert_eq!(content, "What is 2+2?");
        }
        other => panic!("expected User message second, got {other:?}"),
    }
}
4776
#[tokio::test]
async fn no_system_prompt_omits_system_message() {
    // With system_prompt = None the provider call must begin directly with
    // the User message — no System message is inserted.
    let provider = StructuredProvider::new(vec![ChatResult {
        output_text: "ok".to_string(),
        stop_reason: Some(StopReason::EndTurn),
        ..Default::default()
    }]);
    let received = provider.received_messages.clone();
    let memory = TestMemory::default();

    let agent = Agent::new(
        AgentConfig {
            model_supports_tool_use: true,
            system_prompt: None,
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(StructuredEchoTool)],
    );

    agent
        .respond(
            UserMessage {
                text: "hello".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");

    let msgs = received.lock().expect("lock");
    let first_call = &msgs[0];
    match &first_call[0] {
        ConversationMessage::User { content, .. } => {
            assert_eq!(content, "hello");
        }
        other => panic!("expected User message first (no system prompt), got {other:?}"),
    }
}
4818
#[tokio::test]
async fn system_prompt_persists_across_tool_iterations() {
    // Across a tool-use round trip (two provider calls), every call must
    // still start with the configured System message.
    let provider = StructuredProvider::new(vec![
        ChatResult {
            output_text: String::new(),
            tool_calls: vec![ToolUseRequest {
                id: "call_1".to_string(),
                name: "echo".to_string(),
                input: serde_json::json!({"text": "ping"}),
            }],
            stop_reason: Some(StopReason::ToolUse),
            ..Default::default()
        },
        ChatResult {
            output_text: "pong".to_string(),
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        },
    ]);
    let received = provider.received_messages.clone();
    let memory = TestMemory::default();

    let agent = Agent::new(
        AgentConfig {
            model_supports_tool_use: true,
            system_prompt: Some("Always be concise.".to_string()),
            ..AgentConfig::default()
        },
        Box::new(provider),
        Box::new(memory),
        vec![Box::new(StructuredEchoTool)],
    );

    let result = agent
        .respond(
            UserMessage {
                text: "test".to_string(),
            },
            &test_ctx(),
        )
        .await
        .expect("respond should succeed");
    assert_eq!(result.text, "pong");

    // Check position 0 of each recorded provider call.
    let msgs = received.lock().expect("lock");
    for (i, call_msgs) in msgs.iter().enumerate() {
        match &call_msgs[0] {
            ConversationMessage::System { content } => {
                assert_eq!(content, "Always be concise.", "call {i} system prompt");
            }
            other => panic!("call {i}: expected System first, got {other:?}"),
        }
    }
}
4875
4876 #[test]
4877 fn agent_config_system_prompt_defaults_to_none() {
4878 let config = AgentConfig::default();
4879 assert!(config.system_prompt.is_none());
4880 }
4881
4882 #[tokio::test]
4883 async fn memory_write_propagates_source_channel_from_context() {
4884 let memory = TestMemory::default();
4885 let entries = memory.entries.clone();
4886 let provider = StructuredProvider::new(vec![ChatResult {
4887 output_text: "noted".to_string(),
4888 tool_calls: vec![],
4889 stop_reason: None,
4890 ..Default::default()
4891 }]);
4892 let config = AgentConfig {
4893 privacy_boundary: "encrypted_only".to_string(),
4894 ..AgentConfig::default()
4895 };
4896 let agent = Agent::new(config, Box::new(provider), Box::new(memory), vec![]);
4897
4898 let mut ctx = ToolContext::new(".".to_string());
4899 ctx.source_channel = Some("telegram".to_string());
4900 ctx.privacy_boundary = "encrypted_only".to_string();
4901
4902 agent
4903 .respond(
4904 UserMessage {
4905 text: "hello".to_string(),
4906 },
4907 &ctx,
4908 )
4909 .await
4910 .expect("respond should succeed");
4911
4912 let stored = entries.lock().expect("memory lock poisoned");
4913 assert_eq!(stored[0].source_channel.as_deref(), Some("telegram"));
4915 assert_eq!(stored[0].privacy_boundary, "encrypted_only");
4916 assert_eq!(stored[1].source_channel.as_deref(), Some("telegram"));
4917 assert_eq!(stored[1].privacy_boundary, "encrypted_only");
4918 }
4919
4920 #[tokio::test]
4921 async fn memory_write_source_channel_none_when_ctx_empty() {
4922 let memory = TestMemory::default();
4923 let entries = memory.entries.clone();
4924 let provider = StructuredProvider::new(vec![ChatResult {
4925 output_text: "ok".to_string(),
4926 tool_calls: vec![],
4927 stop_reason: None,
4928 ..Default::default()
4929 }]);
4930 let agent = Agent::new(
4931 AgentConfig::default(),
4932 Box::new(provider),
4933 Box::new(memory),
4934 vec![],
4935 );
4936
4937 agent
4938 .respond(
4939 UserMessage {
4940 text: "hi".to_string(),
4941 },
4942 &test_ctx(),
4943 )
4944 .await
4945 .expect("respond should succeed");
4946
4947 let stored = entries.lock().expect("memory lock poisoned");
4948 assert!(stored[0].source_channel.is_none());
4949 assert!(stored[1].source_channel.is_none());
4950 }
4951}