1use crate::common::privacy_helpers::{is_network_tool, resolve_boundary};
2use crate::security::redaction::redact_text;
3use crate::types::{
4 AgentConfig, AgentError, AssistantMessage, AuditEvent, AuditSink, ConversationMessage,
5 HookEvent, HookFailureMode, HookRiskTier, HookSink, MemoryEntry, MemoryStore, MetricsSink,
6 Provider, ResearchTrigger, StopReason, StreamSink, Tool, ToolContext, ToolDefinition,
7 ToolResultMessage, UserMessage,
8};
9use crate::validation::validate_json;
10use serde_json::json;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Instant, SystemTime, UNIX_EPOCH};
13use tokio::time::{timeout, Duration};
14use tracing::{info, warn};
15
// Process-wide monotonically increasing sequence, combined with a millisecond
// timestamp in `Agent::next_request_id` to produce unique request ids.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(0);
17
/// Cap `input` to at most `max_chars` characters, keeping the TAIL of the
/// string (the most recent content — the current input sits at the end of the
/// assembled prompt). Returns the (possibly truncated) string and whether
/// truncation occurred.
///
/// Counts Unicode scalar values (`chars`), not bytes, so truncation never
/// splits a multi-byte character.
fn cap_prompt(input: &str, max_chars: usize) -> (String, bool) {
    let len = input.chars().count();
    if len <= max_chars {
        return (input.to_string(), false);
    }
    // Skip the leading overflow instead of the original rev-take-rev dance:
    // one pass, no intermediate Vec.
    let truncated: String = input.chars().skip(len - max_chars).collect();
    (truncated, true)
}
33
34fn build_provider_prompt(
35 current_prompt: &str,
36 recent_memory: &[MemoryEntry],
37 max_prompt_chars: usize,
38) -> (String, bool) {
39 if recent_memory.is_empty() {
40 return cap_prompt(current_prompt, max_prompt_chars);
41 }
42
43 let mut prompt = String::from("Recent memory:\n");
44 for entry in recent_memory.iter().rev() {
45 prompt.push_str("- ");
46 prompt.push_str(&entry.role);
47 prompt.push_str(": ");
48 prompt.push_str(&entry.content);
49 prompt.push('\n');
50 }
51 prompt.push_str("\nCurrent input:\n");
52 prompt.push_str(current_prompt);
53
54 cap_prompt(&prompt, max_prompt_chars)
55}
56
57fn parse_tool_calls(prompt: &str) -> Vec<(&str, &str)> {
59 prompt
60 .lines()
61 .filter_map(|line| {
62 line.strip_prefix("tool:").map(|rest| {
63 let rest = rest.trim();
64 let mut parts = rest.splitn(2, ' ');
65 let name = parts.next().unwrap_or_default();
66 let input = parts.next().unwrap_or_default();
67 (name, input)
68 })
69 })
70 .collect()
71}
72
73fn single_required_string_field(schema: &serde_json::Value) -> Option<String> {
75 let required = schema.get("required")?.as_array()?;
76 if required.len() != 1 {
77 return None;
78 }
79 let field_name = required[0].as_str()?;
80 let properties = schema.get("properties")?.as_object()?;
81 let field_schema = properties.get(field_name)?;
82 if field_schema.get("type")?.as_str()? == "string" {
83 Some(field_name.to_string())
84 } else {
85 None
86 }
87}
88
89fn prepare_tool_input(tool: &dyn Tool, raw_input: &serde_json::Value) -> Result<String, String> {
98 if let Some(s) = raw_input.as_str() {
100 return Ok(s.to_string());
101 }
102
103 if let Some(schema) = tool.input_schema() {
105 if let Err(errors) = validate_json(raw_input, &schema) {
106 return Err(format!(
107 "Invalid input for tool '{}': {}",
108 tool.name(),
109 errors.join("; ")
110 ));
111 }
112
113 if let Some(field) = single_required_string_field(&schema) {
114 if let Some(val) = raw_input.get(&field).and_then(|v| v.as_str()) {
115 return Ok(val.to_string());
116 }
117 }
118 }
119
120 Ok(serde_json::to_string(raw_input).unwrap_or_default())
121}
122
123fn truncate_messages(messages: &mut Vec<ConversationMessage>, max_chars: usize) {
127 let total_chars: usize = messages.iter().map(|m| m.char_count()).sum();
128 if total_chars <= max_chars || messages.len() <= 2 {
129 return;
130 }
131
132 let first_cost = messages[0].char_count();
133 let mut budget = max_chars.saturating_sub(first_cost);
134
135 let mut keep_from_end = 0;
137 for msg in messages.iter().rev() {
138 let cost = msg.char_count();
139 if cost > budget {
140 break;
141 }
142 budget -= cost;
143 keep_from_end += 1;
144 }
145
146 if keep_from_end == 0 {
147 let last = messages
149 .pop()
150 .expect("messages must have at least 2 entries");
151 messages.truncate(1);
152 messages.push(last);
153 return;
154 }
155
156 let split_point = messages.len() - keep_from_end;
157 if split_point <= 1 {
158 return;
159 }
160
161 messages.drain(1..split_point);
162}
163
164fn memory_to_messages(entries: &[MemoryEntry]) -> Vec<ConversationMessage> {
167 entries
168 .iter()
169 .rev()
170 .map(|entry| {
171 if entry.role == "assistant" {
172 ConversationMessage::Assistant {
173 content: Some(entry.content.clone()),
174 tool_calls: vec![],
175 }
176 } else {
177 ConversationMessage::user(entry.content.clone())
178 }
179 })
180 .collect()
181}
182
/// Orchestrates a single agent: provider calls, tool execution, memory
/// persistence, and optional audit/hook/metrics sinks.
pub struct Agent {
    // Behavioral configuration (timeouts, boundaries, loop detection, ...).
    config: AgentConfig,
    // LLM backend used for completions.
    provider: Box<dyn Provider>,
    // Conversation memory store.
    memory: Box<dyn MemoryStore>,
    // Tools the agent may invoke.
    tools: Vec<Box<dyn Tool>>,
    // Optional best-effort audit trail sink.
    audit: Option<Box<dyn AuditSink>>,
    // Optional lifecycle hook sink (may block execution, see `hook`).
    hooks: Option<Box<dyn HookSink>>,
    // Optional metrics sink for counters/histograms.
    metrics: Option<Box<dyn MetricsSink>>,
}
192
193impl Agent {
    /// Create an agent with the required collaborators; audit, hook, and
    /// metrics sinks start unset and can be attached with the `with_*`
    /// builder methods.
    pub fn new(
        config: AgentConfig,
        provider: Box<dyn Provider>,
        memory: Box<dyn MemoryStore>,
        tools: Vec<Box<dyn Tool>>,
    ) -> Self {
        Self {
            config,
            provider,
            memory,
            tools,
            audit: None,
            hooks: None,
            metrics: None,
        }
    }
210
    /// Attach an audit sink (builder style, consumes and returns `self`).
    pub fn with_audit(mut self, audit: Box<dyn AuditSink>) -> Self {
        self.audit = Some(audit);
        self
    }
215
    /// Attach a lifecycle hook sink (builder style, consumes and returns `self`).
    pub fn with_hooks(mut self, hooks: Box<dyn HookSink>) -> Self {
        self.hooks = Some(hooks);
        self
    }
220
    /// Attach a metrics sink (builder style, consumes and returns `self`).
    pub fn with_metrics(mut self, metrics: Box<dyn MetricsSink>) -> Self {
        self.metrics = Some(metrics);
        self
    }
225
226 fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
228 self.tools
229 .iter()
230 .filter_map(|tool| ToolDefinition::from_tool(&**tool))
231 .collect()
232 }
233
234 fn has_tool_definitions(&self) -> bool {
236 self.tools.iter().any(|t| t.input_schema().is_some())
237 }
238
239 async fn audit(&self, stage: &str, detail: serde_json::Value) {
240 if let Some(sink) = &self.audit {
241 let _ = sink
242 .record(AuditEvent {
243 stage: stage.to_string(),
244 detail,
245 })
246 .await;
247 }
248 }
249
250 fn next_request_id() -> String {
251 let ts_ms = SystemTime::now()
252 .duration_since(UNIX_EPOCH)
253 .unwrap_or_default()
254 .as_millis();
255 let seq = REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed);
256 format!("req-{ts_ms}-{seq}")
257 }
258
259 fn hook_risk_tier(stage: &str) -> HookRiskTier {
260 if matches!(
261 stage,
262 "before_tool_call" | "after_tool_call" | "before_plugin_call" | "after_plugin_call"
263 ) {
264 HookRiskTier::High
265 } else if matches!(
266 stage,
267 "before_provider_call"
268 | "after_provider_call"
269 | "before_memory_write"
270 | "after_memory_write"
271 | "before_run"
272 | "after_run"
273 ) {
274 HookRiskTier::Medium
275 } else {
276 HookRiskTier::Low
277 }
278 }
279
280 fn hook_failure_mode_for_stage(&self, stage: &str) -> HookFailureMode {
281 if self.config.hooks.fail_closed {
282 return HookFailureMode::Block;
283 }
284
285 match Self::hook_risk_tier(stage) {
286 HookRiskTier::Low => self.config.hooks.low_tier_mode,
287 HookRiskTier::Medium => self.config.hooks.medium_tier_mode,
288 HookRiskTier::High => self.config.hooks.high_tier_mode,
289 }
290 }
291
    /// Dispatch a lifecycle event to the hook sink, enforcing a hard timeout
    /// and the per-stage failure mode. Returns `Err` only when the mode for
    /// this stage is `Block`; `Warn` logs and audits but continues, `Ignore`
    /// audits only. No-op when hooks are disabled or no sink is attached.
    async fn hook(&self, stage: &str, detail: serde_json::Value) -> Result<(), AgentError> {
        if !self.config.hooks.enabled {
            return Ok(());
        }
        let Some(sink) = &self.hooks else {
            return Ok(());
        };

        let event = HookEvent {
            stage: stage.to_string(),
            detail,
        };
        let hook_call = sink.record(event);
        let mode = self.hook_failure_mode_for_stage(stage);
        let tier = Self::hook_risk_tier(stage);
        match timeout(
            Duration::from_millis(self.config.hooks.timeout_ms),
            hook_call,
        )
        .await
        {
            Ok(Ok(())) => Ok(()),
            // Hook ran but returned an error.
            Ok(Err(err)) => {
                // Redact before logging/auditing so secrets embedded in error
                // strings never reach the sinks.
                let redacted = redact_text(&err.to_string());
                match mode {
                    HookFailureMode::Block => Err(AgentError::Hook {
                        stage: stage.to_string(),
                        source: err,
                    }),
                    HookFailureMode::Warn => {
                        warn!(
                            stage = stage,
                            tier = ?tier,
                            mode = "warn",
                            "hook error (continuing): {redacted}"
                        );
                        self.audit(
                            "hook_error_warn",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                    HookFailureMode::Ignore => {
                        self.audit(
                            "hook_error_ignored",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                }
            }
            // Hook did not finish within `timeout_ms`.
            Err(_) => match mode {
                HookFailureMode::Block => Err(AgentError::Hook {
                    stage: stage.to_string(),
                    source: anyhow::anyhow!(
                        "hook execution timed out after {} ms",
                        self.config.hooks.timeout_ms
                    ),
                }),
                HookFailureMode::Warn => {
                    warn!(
                        stage = stage,
                        tier = ?tier,
                        mode = "warn",
                        timeout_ms = self.config.hooks.timeout_ms,
                        "hook timeout (continuing)"
                    );
                    self.audit(
                        "hook_timeout_warn",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
                HookFailureMode::Ignore => {
                    self.audit(
                        "hook_timeout_ignored",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
            },
        }
    }
379
380 fn increment_counter(&self, name: &'static str) {
381 if let Some(metrics) = &self.metrics {
382 metrics.increment_counter(name, 1);
383 }
384 }
385
386 fn observe_histogram(&self, name: &'static str, value: f64) {
387 if let Some(metrics) = &self.metrics {
388 metrics.observe_histogram(name, value);
389 }
390 }
391
    /// Execute a single tool call with the full lifecycle: privacy-boundary
    /// enforcement, before/after hooks (plus plugin-specific hooks for
    /// `plugin:`-prefixed tools), audit events, and latency/error metrics.
    /// Hook failures in Block mode abort via `?`.
    async fn execute_tool(
        &self,
        tool: &dyn Tool,
        tool_name: &str,
        tool_input: &str,
        ctx: &ToolContext,
        request_id: &str,
        iteration: usize,
    ) -> Result<crate::types::ToolResult, AgentError> {
        // Per-tool boundary override falls back to the agent-wide boundary.
        let tool_specific = self
            .config
            .tool_boundaries
            .get(tool_name)
            .map(|s| s.as_str())
            .unwrap_or("");
        let resolved = resolve_boundary(tool_specific, &self.config.privacy_boundary);
        // Refuse network tools outright under a local-only boundary.
        if resolved == "local_only" && is_network_tool(tool_name) {
            return Err(AgentError::Tool {
                tool: tool_name.to_string(),
                source: anyhow::anyhow!(
                    "tool '{}' requires network access but privacy boundary is 'local_only'",
                    tool_name
                ),
            });
        }

        let is_plugin_call = tool_name.starts_with("plugin:");
        self.hook(
            "before_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "before_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name}),
            )
            .await?;
        }
        self.audit(
            "tool_execute_start",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await;
        let tool_started = Instant::now();
        let result = match tool.execute(tool_input, ctx).await {
            Ok(result) => result,
            Err(source) => {
                // Latency is recorded even on failure, plus an error counter.
                self.observe_histogram(
                    "tool_latency_ms",
                    tool_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("tool_errors_total");
                return Err(AgentError::Tool {
                    tool: tool_name.to_string(),
                    source,
                });
            }
        };
        self.observe_histogram(
            "tool_latency_ms",
            tool_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "tool_execute_success",
            json!({
                "request_id": request_id,
                "iteration": iteration,
                "tool_name": tool_name,
                "tool_output_len": result.output.len(),
                "duration_ms": tool_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "tool",
            tool_name = %tool_name,
            duration_ms = %tool_started.elapsed().as_millis(),
            "tool execution finished"
        );
        self.hook(
            "after_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name, "status": "ok"}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "after_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name, "status": "ok"}),
            )
            .await?;
        }
        Ok(result)
    }
489
    /// One provider round-trip for the plain-text (non-structured) path:
    /// loads recent memory for the active boundary/channel, builds and caps
    /// the prompt, runs the before/after provider hooks and audit events, and
    /// records latency metrics. Streams through `stream_sink` when provided,
    /// otherwise uses the reasoning-enabled completion.
    async fn call_provider_with_context(
        &self,
        prompt: &str,
        request_id: &str,
        stream_sink: Option<StreamSink>,
        source_channel: Option<&str>,
    ) -> Result<String, AgentError> {
        let recent_memory = self
            .memory
            .recent_for_boundary(
                self.config.memory_window_size,
                &self.config.privacy_boundary,
                source_channel,
            )
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.audit(
            "memory_recent_loaded",
            json!({"request_id": request_id, "items": recent_memory.len()}),
        )
        .await;
        let (provider_prompt, prompt_truncated) =
            build_provider_prompt(prompt, &recent_memory, self.config.max_prompt_chars);
        if prompt_truncated {
            self.audit(
                "provider_prompt_truncated",
                json!({
                    "request_id": request_id,
                    "max_prompt_chars": self.config.max_prompt_chars,
                }),
            )
            .await;
        }
        self.hook(
            "before_provider_call",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await?;
        self.audit(
            "provider_call_start",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await;
        let provider_started = Instant::now();
        let provider_result = if let Some(sink) = stream_sink {
            self.provider
                .complete_streaming(&provider_prompt, sink)
                .await
        } else {
            self.provider
                .complete_with_reasoning(&provider_prompt, &self.config.reasoning)
                .await
        };
        let completion = match provider_result {
            Ok(result) => result,
            Err(source) => {
                // Latency is recorded even on failure, plus an error counter.
                self.observe_histogram(
                    "provider_latency_ms",
                    provider_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("provider_errors_total");
                return Err(AgentError::Provider { source });
            }
        };
        self.observe_histogram(
            "provider_latency_ms",
            provider_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "provider_call_success",
            json!({
                "request_id": request_id,
                "response_len": completion.output_text.len(),
                "duration_ms": provider_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "provider",
            duration_ms = %provider_started.elapsed().as_millis(),
            "provider call finished"
        );
        self.hook(
            "after_provider_call",
            json!({"request_id": request_id, "response_len": completion.output_text.len(), "status": "ok"}),
        )
        .await?;
        Ok(completion.output_text)
    }
590
    /// Append one entry to the memory store, bracketed by the
    /// before/after memory-write hooks and followed by a role-tagged audit
    /// event. The entry is stamped with the agent's privacy boundary.
    async fn write_to_memory(
        &self,
        role: &str,
        content: &str,
        request_id: &str,
        source_channel: Option<&str>,
        conversation_id: &str,
    ) -> Result<(), AgentError> {
        self.hook(
            "before_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        self.memory
            .append(MemoryEntry {
                role: role.to_string(),
                content: content.to_string(),
                privacy_boundary: self.config.privacy_boundary.clone(),
                source_channel: source_channel.map(String::from),
                conversation_id: conversation_id.to_string(),
            })
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.hook(
            "after_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        self.audit(
            &format!("memory_append_{role}"),
            json!({"request_id": request_id}),
        )
        .await;
        Ok(())
    }
626
627 fn should_research(&self, user_text: &str) -> bool {
628 if !self.config.research.enabled {
629 return false;
630 }
631 match self.config.research.trigger {
632 ResearchTrigger::Never => false,
633 ResearchTrigger::Always => true,
634 ResearchTrigger::Keywords => {
635 let lower = user_text.to_lowercase();
636 self.config
637 .research
638 .keywords
639 .iter()
640 .any(|kw| lower.contains(&kw.to_lowercase()))
641 }
642 ResearchTrigger::Length => user_text.len() >= self.config.research.min_message_length,
643 ResearchTrigger::Question => user_text.trim_end().ends_with('?'),
644 }
645 }
646
    /// Run the prompt-driven research loop: instruct the provider to emit
    /// `tool:` directives, execute the first tool call per response, feed the
    /// output back, and accumulate findings until the provider stops emitting
    /// `tool:` lines or `max_iterations` is reached. Returns the collected
    /// findings joined by newlines.
    async fn run_research_phase(
        &self,
        user_text: &str,
        ctx: &ToolContext,
        request_id: &str,
    ) -> Result<String, AgentError> {
        self.audit(
            "research_phase_start",
            json!({
                "request_id": request_id,
                "max_iterations": self.config.research.max_iterations,
            }),
        )
        .await;
        self.increment_counter("research_phase_started");

        let research_prompt = format!(
            "You are in RESEARCH mode. The user asked: \"{user_text}\"\n\
             Gather relevant information using available tools. \
             Respond with tool: calls to collect data. \
             When done gathering, summarize your findings without a tool: prefix."
        );
        let mut prompt = self
            .call_provider_with_context(
                &research_prompt,
                request_id,
                None,
                ctx.source_channel.as_deref(),
            )
            .await?;
        let mut findings: Vec<String> = Vec::new();

        for iteration in 0..self.config.research.max_iterations {
            // A response not starting with "tool:" is treated as the summary.
            if !prompt.starts_with("tool:") {
                findings.push(prompt.clone());
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Only the first tool call of each response is executed.
            let (name, input) = calls[0];
            if let Some(tool) = self.tools.iter().find(|t| t.name() == name) {
                let result = self
                    .execute_tool(&**tool, name, input, ctx, request_id, iteration)
                    .await?;
                findings.push(format!("{name}: {}", result.output));

                if self.config.research.show_progress {
                    info!(iteration, tool = name, "research phase: tool executed");
                    self.audit(
                        "research_phase_iteration",
                        json!({
                            "request_id": request_id,
                            "iteration": iteration,
                            "tool_name": name,
                        }),
                    )
                    .await;
                }

                // Feed the tool output back so the provider can continue or
                // summarize.
                let next_prompt = format!(
                    "Research iteration {iteration}: tool `{name}` returned: {}\n\
                     Continue researching or summarize findings.",
                    result.output
                );
                prompt = self
                    .call_provider_with_context(
                        &next_prompt,
                        request_id,
                        None,
                        ctx.source_channel.as_deref(),
                    )
                    .await?;
            } else {
                // Unknown tool requested: stop researching.
                break;
            }
        }

        self.audit(
            "research_phase_complete",
            json!({
                "request_id": request_id,
                "findings_count": findings.len(),
            }),
        )
        .await;
        self.increment_counter("research_phase_completed");

        Ok(findings.join("\n"))
    }
740
741 async fn respond_with_tools(
744 &self,
745 request_id: &str,
746 user_text: &str,
747 research_context: &str,
748 ctx: &ToolContext,
749 stream_sink: Option<StreamSink>,
750 ) -> Result<AssistantMessage, AgentError> {
751 self.audit(
752 "structured_tool_use_start",
753 json!({
754 "request_id": request_id,
755 "max_tool_iterations": self.config.max_tool_iterations,
756 }),
757 )
758 .await;
759
760 let tool_definitions = self.build_tool_definitions();
761
762 let recent_memory = if let Some(ref cid) = ctx.conversation_id {
764 self.memory
765 .recent_for_conversation(cid, self.config.memory_window_size)
766 .await
767 .map_err(|source| AgentError::Memory { source })?
768 } else {
769 self.memory
770 .recent_for_boundary(
771 self.config.memory_window_size,
772 &self.config.privacy_boundary,
773 ctx.source_channel.as_deref(),
774 )
775 .await
776 .map_err(|source| AgentError::Memory { source })?
777 };
778 self.audit(
779 "memory_recent_loaded",
780 json!({"request_id": request_id, "items": recent_memory.len()}),
781 )
782 .await;
783
784 let mut messages: Vec<ConversationMessage> = Vec::new();
785
786 if let Some(ref sp) = self.config.system_prompt {
788 messages.push(ConversationMessage::System {
789 content: sp.clone(),
790 });
791 }
792
793 messages.extend(memory_to_messages(&recent_memory));
794
795 if !research_context.is_empty() {
796 messages.push(ConversationMessage::user(format!(
797 "Research findings:\n{research_context}",
798 )));
799 }
800
801 messages.push(ConversationMessage::user(user_text.to_string()));
802
803 let mut tool_history: Vec<(String, String, String)> = Vec::new();
804 let mut failure_streak: usize = 0;
805
806 for iteration in 0..self.config.max_tool_iterations {
807 truncate_messages(&mut messages, self.config.max_prompt_chars);
808
809 self.hook(
810 "before_provider_call",
811 json!({
812 "request_id": request_id,
813 "iteration": iteration,
814 "message_count": messages.len(),
815 "tool_count": tool_definitions.len(),
816 }),
817 )
818 .await?;
819 self.audit(
820 "provider_call_start",
821 json!({
822 "request_id": request_id,
823 "iteration": iteration,
824 "message_count": messages.len(),
825 }),
826 )
827 .await;
828
829 let provider_started = Instant::now();
830 let provider_result = if let Some(ref sink) = stream_sink {
831 self.provider
832 .complete_streaming_with_tools(
833 &messages,
834 &tool_definitions,
835 &self.config.reasoning,
836 sink.clone(),
837 )
838 .await
839 } else {
840 self.provider
841 .complete_with_tools(&messages, &tool_definitions, &self.config.reasoning)
842 .await
843 };
844 let chat_result = match provider_result {
845 Ok(result) => result,
846 Err(source) => {
847 self.observe_histogram(
848 "provider_latency_ms",
849 provider_started.elapsed().as_secs_f64() * 1000.0,
850 );
851 self.increment_counter("provider_errors_total");
852 return Err(AgentError::Provider { source });
853 }
854 };
855 self.observe_histogram(
856 "provider_latency_ms",
857 provider_started.elapsed().as_secs_f64() * 1000.0,
858 );
859 info!(
860 request_id = %request_id,
861 iteration = iteration,
862 stop_reason = ?chat_result.stop_reason,
863 tool_calls = chat_result.tool_calls.len(),
864 "structured provider call finished"
865 );
866 self.hook(
867 "after_provider_call",
868 json!({
869 "request_id": request_id,
870 "iteration": iteration,
871 "response_len": chat_result.output_text.len(),
872 "tool_calls": chat_result.tool_calls.len(),
873 "status": "ok",
874 }),
875 )
876 .await?;
877
878 if chat_result.tool_calls.is_empty()
880 || chat_result.stop_reason == Some(StopReason::EndTurn)
881 {
882 let response_text = chat_result.output_text;
883 self.write_to_memory(
884 "assistant",
885 &response_text,
886 request_id,
887 ctx.source_channel.as_deref(),
888 ctx.conversation_id.as_deref().unwrap_or(""),
889 )
890 .await?;
891 self.audit("respond_success", json!({"request_id": request_id}))
892 .await;
893 self.hook(
894 "before_response_emit",
895 json!({"request_id": request_id, "response_len": response_text.len()}),
896 )
897 .await?;
898 let response = AssistantMessage {
899 text: response_text,
900 };
901 self.hook(
902 "after_response_emit",
903 json!({"request_id": request_id, "response_len": response.text.len()}),
904 )
905 .await?;
906 return Ok(response);
907 }
908
909 messages.push(ConversationMessage::Assistant {
911 content: if chat_result.output_text.is_empty() {
912 None
913 } else {
914 Some(chat_result.output_text.clone())
915 },
916 tool_calls: chat_result.tool_calls.clone(),
917 });
918
919 let tool_calls = &chat_result.tool_calls;
921 let has_gated = tool_calls
922 .iter()
923 .any(|tc| self.config.gated_tools.contains(&tc.name));
924 let use_parallel = self.config.parallel_tools && tool_calls.len() > 1 && !has_gated;
925
926 let mut tool_results: Vec<ToolResultMessage> = Vec::new();
927
928 if use_parallel {
929 let futs: Vec<_> = tool_calls
930 .iter()
931 .map(|tc| {
932 let tool = self.tools.iter().find(|t| t.name() == tc.name);
933 async move {
934 match tool {
935 Some(tool) => {
936 let input_str = match prepare_tool_input(&**tool, &tc.input) {
937 Ok(s) => s,
938 Err(validation_err) => {
939 return (
940 tc.name.clone(),
941 String::new(),
942 ToolResultMessage {
943 tool_use_id: tc.id.clone(),
944 content: validation_err,
945 is_error: true,
946 },
947 );
948 }
949 };
950 match tool.execute(&input_str, ctx).await {
951 Ok(result) => (
952 tc.name.clone(),
953 input_str,
954 ToolResultMessage {
955 tool_use_id: tc.id.clone(),
956 content: result.output,
957 is_error: false,
958 },
959 ),
960 Err(e) => (
961 tc.name.clone(),
962 input_str,
963 ToolResultMessage {
964 tool_use_id: tc.id.clone(),
965 content: format!("Error: {e}"),
966 is_error: true,
967 },
968 ),
969 }
970 }
971 None => (
972 tc.name.clone(),
973 String::new(),
974 ToolResultMessage {
975 tool_use_id: tc.id.clone(),
976 content: format!("Tool '{}' not found", tc.name),
977 is_error: true,
978 },
979 ),
980 }
981 }
982 })
983 .collect();
984 let results = futures_util::future::join_all(futs).await;
985 for (name, input_str, result_msg) in results {
986 if result_msg.is_error {
987 failure_streak += 1;
988 } else {
989 failure_streak = 0;
990 tool_history.push((name, input_str, result_msg.content.clone()));
991 }
992 tool_results.push(result_msg);
993 }
994 } else {
995 for tc in tool_calls {
997 let result_msg = match self.tools.iter().find(|t| t.name() == tc.name) {
998 Some(tool) => {
999 let input_str = match prepare_tool_input(&**tool, &tc.input) {
1000 Ok(s) => s,
1001 Err(validation_err) => {
1002 failure_streak += 1;
1003 tool_results.push(ToolResultMessage {
1004 tool_use_id: tc.id.clone(),
1005 content: validation_err,
1006 is_error: true,
1007 });
1008 continue;
1009 }
1010 };
1011 self.audit(
1012 "tool_requested",
1013 json!({
1014 "request_id": request_id,
1015 "iteration": iteration,
1016 "tool_name": tc.name,
1017 "tool_input_len": input_str.len(),
1018 }),
1019 )
1020 .await;
1021
1022 match self
1023 .execute_tool(
1024 &**tool, &tc.name, &input_str, ctx, request_id, iteration,
1025 )
1026 .await
1027 {
1028 Ok(result) => {
1029 failure_streak = 0;
1030 tool_history.push((
1031 tc.name.clone(),
1032 input_str,
1033 result.output.clone(),
1034 ));
1035 ToolResultMessage {
1036 tool_use_id: tc.id.clone(),
1037 content: result.output,
1038 is_error: false,
1039 }
1040 }
1041 Err(e) => {
1042 failure_streak += 1;
1043 ToolResultMessage {
1044 tool_use_id: tc.id.clone(),
1045 content: format!("Error: {e}"),
1046 is_error: true,
1047 }
1048 }
1049 }
1050 }
1051 None => {
1052 self.audit(
1053 "tool_not_found",
1054 json!({
1055 "request_id": request_id,
1056 "iteration": iteration,
1057 "tool_name": tc.name,
1058 }),
1059 )
1060 .await;
1061 failure_streak += 1;
1062 ToolResultMessage {
1063 tool_use_id: tc.id.clone(),
1064 content: format!(
1065 "Tool '{}' not found. Available tools: {}",
1066 tc.name,
1067 tool_definitions
1068 .iter()
1069 .map(|d| d.name.as_str())
1070 .collect::<Vec<_>>()
1071 .join(", ")
1072 ),
1073 is_error: true,
1074 }
1075 }
1076 };
1077 tool_results.push(result_msg);
1078 }
1079 }
1080
1081 for result in &tool_results {
1083 messages.push(ConversationMessage::ToolResult(result.clone()));
1084 }
1085
1086 if self.config.loop_detection_no_progress_threshold > 0 {
1088 let threshold = self.config.loop_detection_no_progress_threshold;
1089 if tool_history.len() >= threshold {
1090 let recent = &tool_history[tool_history.len() - threshold..];
1091 let all_same = recent.iter().all(|entry| {
1092 entry.0 == recent[0].0 && entry.1 == recent[0].1 && entry.2 == recent[0].2
1093 });
1094 if all_same {
1095 warn!(
1096 tool_name = recent[0].0,
1097 threshold, "structured loop detection: no-progress threshold reached"
1098 );
1099 self.audit(
1100 "loop_detection_no_progress",
1101 json!({
1102 "request_id": request_id,
1103 "tool_name": recent[0].0,
1104 "threshold": threshold,
1105 }),
1106 )
1107 .await;
1108 messages.push(ConversationMessage::user(format!(
1109 "SYSTEM NOTICE: Loop detected — the tool `{}` was called \
1110 {} times with identical arguments and output. You are stuck \
1111 in a loop. Stop calling this tool and try a different approach, \
1112 or provide your best answer with the information you already have.",
1113 recent[0].0, threshold
1114 )));
1115 }
1116 }
1117 }
1118
1119 if self.config.loop_detection_ping_pong_cycles > 0 {
1121 let cycles = self.config.loop_detection_ping_pong_cycles;
1122 let needed = cycles * 2;
1123 if tool_history.len() >= needed {
1124 let recent = &tool_history[tool_history.len() - needed..];
1125 let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
1126 && recent[0].0 != recent[1].0;
1127 if is_ping_pong {
1128 warn!(
1129 tool_a = recent[0].0,
1130 tool_b = recent[1].0,
1131 cycles,
1132 "structured loop detection: ping-pong pattern detected"
1133 );
1134 self.audit(
1135 "loop_detection_ping_pong",
1136 json!({
1137 "request_id": request_id,
1138 "tool_a": recent[0].0,
1139 "tool_b": recent[1].0,
1140 "cycles": cycles,
1141 }),
1142 )
1143 .await;
1144 messages.push(ConversationMessage::user(format!(
1145 "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
1146 alternating for {} cycles in a ping-pong pattern. Stop this \
1147 alternation and try a different approach, or provide your best \
1148 answer with the information you already have.",
1149 recent[0].0, recent[1].0, cycles
1150 )));
1151 }
1152 }
1153 }
1154
1155 if self.config.loop_detection_failure_streak > 0
1157 && failure_streak >= self.config.loop_detection_failure_streak
1158 {
1159 warn!(
1160 failure_streak,
1161 "structured loop detection: consecutive failure streak"
1162 );
1163 self.audit(
1164 "loop_detection_failure_streak",
1165 json!({
1166 "request_id": request_id,
1167 "streak": failure_streak,
1168 }),
1169 )
1170 .await;
1171 messages.push(ConversationMessage::user(format!(
1172 "SYSTEM NOTICE: {} consecutive tool calls have failed. \
1173 Stop calling tools and provide your best answer with \
1174 the information you already have.",
1175 failure_streak
1176 )));
1177 }
1178 }
1179
1180 self.audit(
1182 "structured_tool_use_max_iterations",
1183 json!({
1184 "request_id": request_id,
1185 "max_tool_iterations": self.config.max_tool_iterations,
1186 }),
1187 )
1188 .await;
1189
1190 messages.push(ConversationMessage::user(
1191 "You have reached the maximum number of tool call iterations. \
1192 Please provide your best answer now without calling any more tools."
1193 .to_string(),
1194 ));
1195 truncate_messages(&mut messages, self.config.max_prompt_chars);
1196
1197 let final_result = if let Some(ref sink) = stream_sink {
1198 self.provider
1199 .complete_streaming_with_tools(&messages, &[], &self.config.reasoning, sink.clone())
1200 .await
1201 .map_err(|source| AgentError::Provider { source })?
1202 } else {
1203 self.provider
1204 .complete_with_tools(&messages, &[], &self.config.reasoning)
1205 .await
1206 .map_err(|source| AgentError::Provider { source })?
1207 };
1208
1209 let response_text = final_result.output_text;
1210 self.write_to_memory(
1211 "assistant",
1212 &response_text,
1213 request_id,
1214 ctx.source_channel.as_deref(),
1215 ctx.conversation_id.as_deref().unwrap_or(""),
1216 )
1217 .await?;
1218 self.audit("respond_success", json!({"request_id": request_id}))
1219 .await;
1220 self.hook(
1221 "before_response_emit",
1222 json!({"request_id": request_id, "response_len": response_text.len()}),
1223 )
1224 .await?;
1225 let response = AssistantMessage {
1226 text: response_text,
1227 };
1228 self.hook(
1229 "after_response_emit",
1230 json!({"request_id": request_id, "response_len": response.text.len()}),
1231 )
1232 .await?;
1233 Ok(response)
1234 }
1235
    /// Public non-streaming entry point: wraps `respond_inner` with a request
    /// id, a hard request timeout, before/after run hooks, and run-level
    /// logging. Note the `after_run` hook runs with `?`, so in Block mode a
    /// hook failure replaces the run's result.
    pub async fn respond(
        &self,
        user: UserMessage,
        ctx: &ToolContext,
    ) -> Result<AssistantMessage, AgentError> {
        let request_id = Self::next_request_id();
        self.increment_counter("requests_total");
        let run_started = Instant::now();
        self.hook("before_run", json!({"request_id": request_id}))
            .await?;
        let timed = timeout(
            Duration::from_millis(self.config.request_timeout_ms),
            self.respond_inner(&request_id, user, ctx, None),
        )
        .await;
        let result = match timed {
            Ok(result) => result,
            Err(_) => Err(AgentError::Timeout {
                timeout_ms: self.config.request_timeout_ms,
            }),
        };

        // Summarize the outcome for the after_run hook; errors are redacted.
        let after_detail = match &result {
            Ok(response) => json!({
                "request_id": request_id,
                "status": "ok",
                "response_len": response.text.len(),
                "duration_ms": run_started.elapsed().as_millis(),
            }),
            Err(err) => json!({
                "request_id": request_id,
                "status": "error",
                "error": redact_text(&err.to_string()),
                "duration_ms": run_started.elapsed().as_millis(),
            }),
        };
        info!(
            request_id = %request_id,
            duration_ms = %run_started.elapsed().as_millis(),
            "agent run completed"
        );
        self.hook("after_run", after_detail).await?;
        result
    }
1280
1281 pub async fn respond_streaming(
1285 &self,
1286 user: UserMessage,
1287 ctx: &ToolContext,
1288 sink: StreamSink,
1289 ) -> Result<AssistantMessage, AgentError> {
1290 let request_id = Self::next_request_id();
1291 self.increment_counter("requests_total");
1292 let run_started = Instant::now();
1293 self.hook("before_run", json!({"request_id": request_id}))
1294 .await?;
1295 let timed = timeout(
1296 Duration::from_millis(self.config.request_timeout_ms),
1297 self.respond_inner(&request_id, user, ctx, Some(sink)),
1298 )
1299 .await;
1300 let result = match timed {
1301 Ok(result) => result,
1302 Err(_) => Err(AgentError::Timeout {
1303 timeout_ms: self.config.request_timeout_ms,
1304 }),
1305 };
1306
1307 let after_detail = match &result {
1308 Ok(response) => json!({
1309 "request_id": request_id,
1310 "status": "ok",
1311 "response_len": response.text.len(),
1312 "duration_ms": run_started.elapsed().as_millis(),
1313 }),
1314 Err(err) => json!({
1315 "request_id": request_id,
1316 "status": "error",
1317 "error": redact_text(&err.to_string()),
1318 "duration_ms": run_started.elapsed().as_millis(),
1319 }),
1320 };
1321 info!(
1322 request_id = %request_id,
1323 duration_ms = %run_started.elapsed().as_millis(),
1324 "streaming agent run completed"
1325 );
1326 self.hook("after_run", after_detail).await?;
1327 result
1328 }
1329
    /// Core request pipeline shared by the streaming and non-streaming entry
    /// points.
    ///
    /// Flow: audit the start, persist the user turn to memory, optionally run
    /// a research phase, then either delegate to the native tool-calling path
    /// (when the model supports tool use and definitions exist) or run the
    /// legacy `tool:`-prefixed prompt loop below, and finally call the
    /// provider, persist the assistant turn, and emit response hooks.
    async fn respond_inner(
        &self,
        request_id: &str,
        user: UserMessage,
        ctx: &ToolContext,
        stream_sink: Option<StreamSink>,
    ) -> Result<AssistantMessage, AgentError> {
        self.audit(
            "respond_start",
            json!({
                "request_id": request_id,
                "user_message_len": user.text.len(),
                "max_tool_iterations": self.config.max_tool_iterations,
                "request_timeout_ms": self.config.request_timeout_ms,
            }),
        )
        .await;
        // Persist the user turn before doing any work so memory reflects the
        // request even if a later step fails.
        self.write_to_memory(
            "user",
            &user.text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
        )
        .await?;

        // Optional pre-answer research phase; empty string means "no findings".
        let research_context = if self.should_research(&user.text) {
            self.run_research_phase(&user.text, ctx, request_id).await?
        } else {
            String::new()
        };

        // Native tool-use path: the model negotiates tool calls itself, so the
        // manual "tool:" loop below is bypassed entirely.
        if self.config.model_supports_tool_use && self.has_tool_definitions() {
            return self
                .respond_with_tools(request_id, &user.text, &research_context, ctx, stream_sink)
                .await;
        }

        // Legacy path: prompts beginning with "tool:" request tool execution;
        // tool_history is (name, input, output) per call, used by the loop
        // detectors below.
        let mut prompt = user.text;
        let mut tool_history: Vec<(String, String, String)> = Vec::new();

        for iteration in 0..self.config.max_tool_iterations {
            if !prompt.starts_with("tool:") {
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Parallel fan-out is only allowed when no requested tool is in
            // the gated set (gated tools must go through the serial path).
            let has_gated = calls
                .iter()
                .any(|(name, _)| self.config.gated_tools.contains(*name));
            if calls.len() > 1 && self.config.parallel_tools && !has_gated {
                // Resolve names to tool instances; unknown names are audited
                // and dropped rather than failing the whole batch.
                let mut resolved: Vec<(&str, &str, &dyn Tool)> = Vec::new();
                for &(name, input) in &calls {
                    let tool = self.tools.iter().find(|t| t.name() == name);
                    match tool {
                        Some(t) => resolved.push((name, input, &**t)),
                        None => {
                            self.audit(
                                "tool_not_found",
                                json!({"request_id": request_id, "iteration": iteration, "tool_name": name}),
                            )
                            .await;
                        }
                    }
                }
                if resolved.is_empty() {
                    break;
                }

                // Execute all resolved tools concurrently.
                let futs: Vec<_> = resolved
                    .iter()
                    .map(|&(name, input, tool)| async move {
                        let r = tool.execute(input, ctx).await;
                        (name, input, r)
                    })
                    .collect();
                let results = futures_util::future::join_all(futs).await;

                // First tool failure aborts the request; successes are folded
                // into the next prompt and recorded in tool_history.
                let mut output_parts: Vec<String> = Vec::new();
                for (name, input, result) in results {
                    match result {
                        Ok(r) => {
                            tool_history.push((
                                name.to_string(),
                                input.to_string(),
                                r.output.clone(),
                            ));
                            output_parts.push(format!("Tool output from {name}: {}", r.output));
                        }
                        Err(source) => {
                            return Err(AgentError::Tool {
                                tool: name.to_string(),
                                source,
                            });
                        }
                    }
                }
                prompt = output_parts.join("\n");
                continue;
            }

            // Serial path: only the first parsed call is executed this
            // iteration.
            let (tool_name, tool_input) = calls[0];
            self.audit(
                "tool_requested",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "tool_name": tool_name,
                    "tool_input_len": tool_input.len(),
                }),
            )
            .await;

            if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
                let result = self
                    .execute_tool(&**tool, tool_name, tool_input, ctx, request_id, iteration)
                    .await?;

                tool_history.push((
                    tool_name.to_string(),
                    tool_input.to_string(),
                    result.output.clone(),
                ));

                // Loop detector 1: the same tool called `threshold` times in a
                // row with identical input AND identical output — no progress.
                if self.config.loop_detection_no_progress_threshold > 0 {
                    let threshold = self.config.loop_detection_no_progress_threshold;
                    if tool_history.len() >= threshold {
                        let recent = &tool_history[tool_history.len() - threshold..];
                        let all_same = recent.iter().all(|entry| {
                            entry.0 == recent[0].0
                                && entry.1 == recent[0].1
                                && entry.2 == recent[0].2
                        });
                        if all_same {
                            warn!(
                                tool_name,
                                threshold, "loop detection: no-progress threshold reached"
                            );
                            self.audit(
                                "loop_detection_no_progress",
                                json!({
                                    "request_id": request_id,
                                    "tool_name": tool_name,
                                    "threshold": threshold,
                                }),
                            )
                            .await;
                            // Replace the prompt with a self-correction notice
                            // and fall through to the provider call.
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — the tool `{tool_name}` was called \
                                {threshold} times with identical arguments and output. You are stuck \
                                in a loop. Stop calling this tool and try a different approach, or \
                                provide your best answer with the information you already have."
                            );
                            break;
                        }
                    }
                }

                // Loop detector 2: two tools strictly alternating (A,B,A,B,…)
                // for `cycles` full cycles — a ping-pong pattern.
                if self.config.loop_detection_ping_pong_cycles > 0 {
                    let cycles = self.config.loop_detection_ping_pong_cycles;
                    let needed = cycles * 2;
                    if tool_history.len() >= needed {
                        let recent = &tool_history[tool_history.len() - needed..];
                        let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
                            && recent[0].0 != recent[1].0;
                        if is_ping_pong {
                            warn!(
                                tool_a = recent[0].0,
                                tool_b = recent[1].0,
                                cycles,
                                "loop detection: ping-pong pattern detected"
                            );
                            self.audit(
                                "loop_detection_ping_pong",
                                json!({
                                    "request_id": request_id,
                                    "tool_a": recent[0].0,
                                    "tool_b": recent[1].0,
                                    "cycles": cycles,
                                }),
                            )
                            .await;
                            // Replace the prompt with a self-correction notice
                            // and fall through to the provider call.
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
                                alternating for {} cycles in a ping-pong pattern. Stop this \
                                alternation and try a different approach, or provide your best \
                                answer with the information you already have.",
                                recent[0].0, recent[1].0, cycles
                            );
                            break;
                        }
                    }
                }

                // Tool output that itself starts with "tool:" is treated as a
                // chained tool request; otherwise wrap it for the provider.
                if result.output.starts_with("tool:") {
                    prompt = result.output.clone();
                } else {
                    prompt = format!("Tool output from {tool_name}: {}", result.output);
                }
                continue;
            }

            // Unknown tool on the serial path: audit and fall through to the
            // provider with the original "tool:" text as the prompt.
            self.audit(
                "tool_not_found",
                json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
            )
            .await;
            break;
        }

        // Prepend research findings (if any) to the final provider prompt.
        if !research_context.is_empty() {
            prompt = format!("Research findings:\n{research_context}\n\nUser request:\n{prompt}");
        }

        let response_text = self
            .call_provider_with_context(
                &prompt,
                request_id,
                stream_sink,
                ctx.source_channel.as_deref(),
            )
            .await?;
        // Persist the assistant turn, then emit the response-boundary hooks.
        self.write_to_memory(
            "assistant",
            &response_text,
            request_id,
            ctx.source_channel.as_deref(),
            ctx.conversation_id.as_deref().unwrap_or(""),
        )
        .await?;
        self.audit("respond_success", json!({"request_id": request_id}))
            .await;

        self.hook(
            "before_response_emit",
            json!({"request_id": request_id, "response_len": response_text.len()}),
        )
        .await?;
        let response = AssistantMessage {
            text: response_text,
        };
        self.hook(
            "after_response_emit",
            json!({"request_id": request_id, "response_len": response.text.len()}),
        )
        .await?;

        Ok(response)
    }
1590}
1591
1592#[cfg(test)]
1593mod tests {
1594 use super::*;
1595 use crate::types::{
1596 ChatResult, HookEvent, HookFailureMode, HookPolicy, ReasoningConfig, ResearchPolicy,
1597 ResearchTrigger, ToolResult,
1598 };
1599 use async_trait::async_trait;
1600 use serde_json::Value;
1601 use std::collections::HashMap;
1602 use std::sync::atomic::{AtomicUsize, Ordering};
1603 use std::sync::{Arc, Mutex};
1604 use tokio::time::{sleep, Duration};
1605
    // In-memory MemoryStore test double; `entries` is shared so tests can
    // inspect everything the agent wrote after the run.
    #[derive(Default)]
    struct TestMemory {
        entries: Arc<Mutex<Vec<MemoryEntry>>>,
    }
1610
1611 #[async_trait]
1612 impl MemoryStore for TestMemory {
1613 async fn append(&self, entry: MemoryEntry) -> anyhow::Result<()> {
1614 self.entries
1615 .lock()
1616 .expect("memory lock poisoned")
1617 .push(entry);
1618 Ok(())
1619 }
1620
1621 async fn recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
1622 let entries = self.entries.lock().expect("memory lock poisoned");
1623 Ok(entries.iter().rev().take(limit).cloned().collect())
1624 }
1625 }
1626
    // Provider test double: records every prompt it receives and always
    // answers with the fixed `response_text`.
    struct TestProvider {
        received_prompts: Arc<Mutex<Vec<String>>>,
        response_text: String,
    }
1631
1632 #[async_trait]
1633 impl Provider for TestProvider {
1634 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
1635 self.received_prompts
1636 .lock()
1637 .expect("provider lock poisoned")
1638 .push(prompt.to_string());
1639 Ok(ChatResult {
1640 output_text: self.response_text.clone(),
1641 ..Default::default()
1642 })
1643 }
1644 }
1645
1646 struct EchoTool;
1647
1648 #[async_trait]
1649 impl Tool for EchoTool {
1650 fn name(&self) -> &'static str {
1651 "echo"
1652 }
1653
1654 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1655 Ok(ToolResult {
1656 output: format!("echoed:{input}"),
1657 })
1658 }
1659 }
1660
1661 struct PluginEchoTool;
1662
1663 #[async_trait]
1664 impl Tool for PluginEchoTool {
1665 fn name(&self) -> &'static str {
1666 "plugin:echo"
1667 }
1668
1669 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1670 Ok(ToolResult {
1671 output: format!("plugin-echoed:{input}"),
1672 })
1673 }
1674 }
1675
1676 struct FailingTool;
1677
1678 #[async_trait]
1679 impl Tool for FailingTool {
1680 fn name(&self) -> &'static str {
1681 "boom"
1682 }
1683
1684 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1685 Err(anyhow::anyhow!("tool exploded"))
1686 }
1687 }
1688
1689 struct FailingProvider;
1690
1691 #[async_trait]
1692 impl Provider for FailingProvider {
1693 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
1694 Err(anyhow::anyhow!("provider boom"))
1695 }
1696 }
1697
1698 struct SlowProvider;
1699
1700 #[async_trait]
1701 impl Provider for SlowProvider {
1702 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
1703 sleep(Duration::from_millis(500)).await;
1704 Ok(ChatResult {
1705 output_text: "late".to_string(),
1706 ..Default::default()
1707 })
1708 }
1709 }
1710
1711 struct ScriptedProvider {
1712 responses: Vec<String>,
1713 call_count: AtomicUsize,
1714 received_prompts: Arc<Mutex<Vec<String>>>,
1715 }
1716
1717 impl ScriptedProvider {
1718 fn new(responses: Vec<&str>) -> Self {
1719 Self {
1720 responses: responses.into_iter().map(|s| s.to_string()).collect(),
1721 call_count: AtomicUsize::new(0),
1722 received_prompts: Arc::new(Mutex::new(Vec::new())),
1723 }
1724 }
1725 }
1726
1727 #[async_trait]
1728 impl Provider for ScriptedProvider {
1729 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
1730 self.received_prompts
1731 .lock()
1732 .expect("provider lock poisoned")
1733 .push(prompt.to_string());
1734 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
1735 let response = if idx < self.responses.len() {
1736 self.responses[idx].clone()
1737 } else {
1738 self.responses.last().cloned().unwrap_or_default()
1739 };
1740 Ok(ChatResult {
1741 output_text: response,
1742 ..Default::default()
1743 })
1744 }
1745 }
1746
1747 struct UpperTool;
1748
1749 #[async_trait]
1750 impl Tool for UpperTool {
1751 fn name(&self) -> &'static str {
1752 "upper"
1753 }
1754
1755 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1756 Ok(ToolResult {
1757 output: input.to_uppercase(),
1758 })
1759 }
1760 }
1761
1762 struct LoopTool;
1765
1766 #[async_trait]
1767 impl Tool for LoopTool {
1768 fn name(&self) -> &'static str {
1769 "loop_tool"
1770 }
1771
1772 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1773 Ok(ToolResult {
1774 output: "tool:loop_tool x".to_string(),
1775 })
1776 }
1777 }
1778
1779 struct PingTool;
1781
1782 #[async_trait]
1783 impl Tool for PingTool {
1784 fn name(&self) -> &'static str {
1785 "ping"
1786 }
1787
1788 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1789 Ok(ToolResult {
1790 output: "tool:pong x".to_string(),
1791 })
1792 }
1793 }
1794
1795 struct PongTool;
1797
1798 #[async_trait]
1799 impl Tool for PongTool {
1800 fn name(&self) -> &'static str {
1801 "pong"
1802 }
1803
1804 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1805 Ok(ToolResult {
1806 output: "tool:ping x".to_string(),
1807 })
1808 }
1809 }
1810
1811 struct RecordingHookSink {
1812 events: Arc<Mutex<Vec<String>>>,
1813 }
1814
1815 #[async_trait]
1816 impl HookSink for RecordingHookSink {
1817 async fn record(&self, event: HookEvent) -> anyhow::Result<()> {
1818 self.events
1819 .lock()
1820 .expect("hook lock poisoned")
1821 .push(event.stage);
1822 Ok(())
1823 }
1824 }
1825
1826 struct FailingHookSink;
1827
1828 #[async_trait]
1829 impl HookSink for FailingHookSink {
1830 async fn record(&self, _event: HookEvent) -> anyhow::Result<()> {
1831 Err(anyhow::anyhow!("hook sink failure"))
1832 }
1833 }
1834
1835 struct RecordingAuditSink {
1836 events: Arc<Mutex<Vec<AuditEvent>>>,
1837 }
1838
1839 #[async_trait]
1840 impl AuditSink for RecordingAuditSink {
1841 async fn record(&self, event: AuditEvent) -> anyhow::Result<()> {
1842 self.events.lock().expect("audit lock poisoned").push(event);
1843 Ok(())
1844 }
1845 }
1846
1847 struct RecordingMetricsSink {
1848 counters: Arc<Mutex<HashMap<&'static str, u64>>>,
1849 histograms: Arc<Mutex<HashMap<&'static str, usize>>>,
1850 }
1851
1852 impl MetricsSink for RecordingMetricsSink {
1853 fn increment_counter(&self, name: &'static str, value: u64) {
1854 let mut counters = self.counters.lock().expect("metrics lock poisoned");
1855 *counters.entry(name).or_insert(0) += value;
1856 }
1857
1858 fn observe_histogram(&self, name: &'static str, _value: f64) {
1859 let mut histograms = self.histograms.lock().expect("metrics lock poisoned");
1860 *histograms.entry(name).or_insert(0) += 1;
1861 }
1862 }
1863
1864 fn test_ctx() -> ToolContext {
1865 ToolContext::new(".".to_string())
1866 }
1867
1868 fn counter(counters: &Arc<Mutex<HashMap<&'static str, u64>>>, name: &'static str) -> u64 {
1869 counters
1870 .lock()
1871 .expect("metrics lock poisoned")
1872 .get(name)
1873 .copied()
1874 .unwrap_or(0)
1875 }
1876
1877 fn histogram_count(
1878 histograms: &Arc<Mutex<HashMap<&'static str, usize>>>,
1879 name: &'static str,
1880 ) -> usize {
1881 histograms
1882 .lock()
1883 .expect("metrics lock poisoned")
1884 .get(name)
1885 .copied()
1886 .unwrap_or(0)
1887 }
1888
    // One request should write exactly two memory entries in order: the user
    // turn first, then the assistant reply.
    #[tokio::test]
    async fn respond_appends_user_then_assistant_memory_entries() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts,
            response_text: "assistant-output".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "hello world".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("agent respond should succeed");

        assert_eq!(response.text, "assistant-output");
        let stored = entries.lock().expect("memory lock poisoned");
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].role, "user");
        assert_eq!(stored[0].content, "hello world");
        assert_eq!(stored[1].role, "assistant");
        assert_eq!(stored[1].content, "assistant-output");
    }
1925
    // A "tool:echo ..." prompt should execute the tool and then pass the tool
    // output (not the raw user text) to the provider as the current input.
    #[tokio::test]
    async fn respond_invokes_tool_for_tool_prefixed_prompt() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "assistant-after-tool".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("agent respond should succeed");

        assert_eq!(response.text, "assistant-after-tool");
        let prompts = prompts.lock().expect("provider lock poisoned");
        assert_eq!(prompts.len(), 1);
        assert!(prompts[0].contains("Recent memory:"));
        assert!(prompts[0].contains("Current input:\nTool output from echo: echoed:ping"));
    }
1960
    // An unknown tool name must not fail the request: the raw "tool:" text is
    // forwarded to the provider unchanged.
    #[tokio::test]
    async fn respond_with_unknown_tool_falls_back_to_provider_prompt() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "assistant-without-tool".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:unknown payload".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("agent respond should succeed");

        assert_eq!(response.text, "assistant-without-tool");
        let prompts = prompts.lock().expect("provider lock poisoned");
        assert_eq!(prompts.len(), 1);
        assert!(prompts[0].contains("Recent memory:"));
        assert!(prompts[0].contains("Current input:\ntool:unknown payload"));
    }
1995
    // With memory_window_size = 2, the provider prompt should include only the
    // two most recent memory entries and drop older history.
    #[tokio::test]
    async fn respond_includes_bounded_recent_memory_in_provider_prompt() {
        let entries = Arc::new(Mutex::new(vec![
            MemoryEntry {
                role: "assistant".to_string(),
                content: "very-old".to_string(),
                ..Default::default()
            },
            MemoryEntry {
                role: "user".to_string(),
                content: "recent-before-request".to_string(),
                ..Default::default()
            },
        ]));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                memory_window_size: 2,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![],
        );

        agent
            .respond(
                UserMessage {
                    text: "latest-user".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let prompts = prompts.lock().expect("provider lock poisoned");
        let provider_prompt = prompts.first().expect("provider prompt should exist");
        assert!(provider_prompt.contains("- user: recent-before-request"));
        assert!(provider_prompt.contains("- user: latest-user"));
        assert!(!provider_prompt.contains("very-old"));
    }
2044
    // The provider prompt must be truncated to max_prompt_chars while keeping
    // the tail (the current user input survives; old memory is cut).
    #[tokio::test]
    async fn respond_caps_provider_prompt_to_configured_max_chars() {
        let entries = Arc::new(Mutex::new(vec![MemoryEntry {
            role: "assistant".to_string(),
            content: "historic context ".repeat(16),
            ..Default::default()
        }]));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                max_prompt_chars: 64,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![],
        );

        agent
            .respond(
                UserMessage {
                    text: "final-tail-marker".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let prompts = prompts.lock().expect("provider lock poisoned");
        let provider_prompt = prompts.first().expect("provider prompt should exist");
        assert!(provider_prompt.chars().count() <= 64);
        assert!(provider_prompt.contains("final-tail-marker"));
    }
2085
    // A provider failure should surface as AgentError::Provider and bump the
    // provider error/latency metrics (but not the tool error counter).
    #[tokio::test]
    async fn respond_returns_provider_typed_error() {
        let memory = TestMemory::default();
        let counters = Arc::new(Mutex::new(HashMap::new()));
        let histograms = Arc::new(Mutex::new(HashMap::new()));
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(FailingProvider),
            Box::new(memory),
            vec![],
        )
        .with_metrics(Box::new(RecordingMetricsSink {
            counters: counters.clone(),
            histograms: histograms.clone(),
        }));

        let result = agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Provider { source }) => {
                assert!(source.to_string().contains("provider boom"));
            }
            other => panic!("expected provider error, got {other:?}"),
        }
        assert_eq!(counter(&counters, "requests_total"), 1);
        assert_eq!(counter(&counters, "provider_errors_total"), 1);
        assert_eq!(counter(&counters, "tool_errors_total"), 0);
        assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
    }
2122
    // A tool failure should surface as AgentError::Tool with the tool name
    // and bump the tool error/latency metrics (but not the provider counter).
    #[tokio::test]
    async fn respond_increments_tool_error_counter_on_tool_failure() {
        let memory = TestMemory::default();
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts,
            response_text: "unused".to_string(),
        };
        let counters = Arc::new(Mutex::new(HashMap::new()));
        let histograms = Arc::new(Mutex::new(HashMap::new()));
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(FailingTool)],
        )
        .with_metrics(Box::new(RecordingMetricsSink {
            counters: counters.clone(),
            histograms: histograms.clone(),
        }));

        let result = agent
            .respond(
                UserMessage {
                    text: "tool:boom ping".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Tool { tool, source }) => {
                assert_eq!(tool, "boom");
                assert!(source.to_string().contains("tool exploded"));
            }
            other => panic!("expected tool error, got {other:?}"),
        }
        assert_eq!(counter(&counters, "requests_total"), 1);
        assert_eq!(counter(&counters, "provider_errors_total"), 0);
        assert_eq!(counter(&counters, "tool_errors_total"), 1);
        assert_eq!(histogram_count(&histograms, "tool_latency_ms"), 1);
    }
2165
    // A successful run should count one request and one provider latency
    // observation, with no error counters touched.
    #[tokio::test]
    async fn respond_increments_requests_counter_on_success() {
        let memory = TestMemory::default();
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts,
            response_text: "ok".to_string(),
        };
        let counters = Arc::new(Mutex::new(HashMap::new()));
        let histograms = Arc::new(Mutex::new(HashMap::new()));
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![],
        )
        .with_metrics(Box::new(RecordingMetricsSink {
            counters: counters.clone(),
            histograms: histograms.clone(),
        }));

        let response = agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("response should succeed");

        assert_eq!(response.text, "ok");
        assert_eq!(counter(&counters, "requests_total"), 1);
        assert_eq!(counter(&counters, "provider_errors_total"), 0);
        assert_eq!(counter(&counters, "tool_errors_total"), 0);
        assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
    }
2203
    // A provider slower than request_timeout_ms should yield
    // AgentError::Timeout carrying the configured deadline.
    #[tokio::test]
    async fn respond_returns_timeout_typed_error() {
        let memory = TestMemory::default();
        let agent = Agent::new(
            AgentConfig {
                max_tool_iterations: 1,
                request_timeout_ms: 10,
                ..AgentConfig::default()
            },
            Box::new(SlowProvider),
            Box::new(memory),
            vec![],
        );

        let result = agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Timeout { timeout_ms }) => assert_eq!(timeout_ms, 10),
            other => panic!("expected timeout error, got {other:?}"),
        }
    }
2232
    // With hooks enabled, a plugin-tool run should emit the full set of
    // before/after hook stages (run, tool, plugin, provider, memory, emit).
    #[tokio::test]
    async fn respond_emits_before_after_hook_events_when_enabled() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                hooks: HookPolicy {
                    enabled: true,
                    timeout_ms: 50,
                    fail_closed: true,
                    ..HookPolicy::default()
                },
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool), Box::new(PluginEchoTool)],
        )
        .with_hooks(Box::new(RecordingHookSink {
            events: events.clone(),
        }));

        agent
            .respond(
                UserMessage {
                    text: "tool:plugin:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let stages = events.lock().expect("hook lock poisoned");
        assert!(stages.contains(&"before_run".to_string()));
        assert!(stages.contains(&"after_run".to_string()));
        assert!(stages.contains(&"before_tool_call".to_string()));
        assert!(stages.contains(&"after_tool_call".to_string()));
        assert!(stages.contains(&"before_plugin_call".to_string()));
        assert!(stages.contains(&"after_plugin_call".to_string()));
        assert!(stages.contains(&"before_provider_call".to_string()));
        assert!(stages.contains(&"after_provider_call".to_string()));
        assert!(stages.contains(&"before_memory_write".to_string()));
        assert!(stages.contains(&"after_memory_write".to_string()));
        assert!(stages.contains(&"before_response_emit".to_string()));
        assert!(stages.contains(&"after_response_emit".to_string()));
    }
2283
    // Audit events for provider and tool success must all share the run's
    // request_id and carry a numeric duration_ms.
    #[tokio::test]
    async fn respond_audit_events_include_request_id_and_durations() {
        let audit_events = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        )
        .with_audit(Box::new(RecordingAuditSink {
            events: audit_events.clone(),
        }));

        agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let events = audit_events.lock().expect("audit lock poisoned");
        // Grab the request_id from any event, then require the per-stage
        // events to echo the same id.
        let request_id = events
            .iter()
            .find_map(|e| {
                e.detail
                    .get("request_id")
                    .and_then(Value::as_str)
                    .map(ToString::to_string)
            })
            .expect("request_id should exist on audit events");

        let provider_event = events
            .iter()
            .find(|e| e.stage == "provider_call_success")
            .expect("provider success event should exist");
        assert_eq!(
            provider_event
                .detail
                .get("request_id")
                .and_then(Value::as_str),
            Some(request_id.as_str())
        );
        assert!(provider_event
            .detail
            .get("duration_ms")
            .and_then(Value::as_u64)
            .is_some());

        let tool_event = events
            .iter()
            .find(|e| e.stage == "tool_execute_success")
            .expect("tool success event should exist");
        assert_eq!(
            tool_event.detail.get("request_id").and_then(Value::as_str),
            Some(request_id.as_str())
        );
        assert!(tool_event
            .detail
            .get("duration_ms")
            .and_then(Value::as_u64)
            .is_some());
    }
2354
    // With the high tier set to Block, a failing hook sink should abort the
    // run at the first high-tier stage (before_tool_call).
    #[tokio::test]
    async fn hook_errors_respect_block_mode_for_high_tier_negative_path() {
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                hooks: HookPolicy {
                    enabled: true,
                    timeout_ms: 50,
                    fail_closed: false,
                    default_mode: HookFailureMode::Warn,
                    low_tier_mode: HookFailureMode::Ignore,
                    medium_tier_mode: HookFailureMode::Warn,
                    high_tier_mode: HookFailureMode::Block,
                },
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        )
        .with_hooks(Box::new(FailingHookSink));

        let result = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Hook { stage, .. }) => assert_eq!(stage, "before_tool_call"),
            other => panic!("expected hook block error, got {other:?}"),
        }
    }
2395
    // With the high tier set to Warn, a failing hook sink must not abort the
    // run; the failure is recorded as a hook_error_warn audit event instead.
    #[tokio::test]
    async fn hook_errors_respect_warn_mode_for_high_tier_success_path() {
        let audit_events = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                hooks: HookPolicy {
                    enabled: true,
                    timeout_ms: 50,
                    fail_closed: false,
                    default_mode: HookFailureMode::Warn,
                    low_tier_mode: HookFailureMode::Ignore,
                    medium_tier_mode: HookFailureMode::Warn,
                    high_tier_mode: HookFailureMode::Warn,
                },
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        )
        .with_hooks(Box::new(FailingHookSink))
        .with_audit(Box::new(RecordingAuditSink {
            events: audit_events.clone(),
        }));

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("warn mode should continue");
        assert!(!response.text.is_empty());

        let events = audit_events.lock().expect("audit lock poisoned");
        assert!(events.iter().any(|event| event.stage == "hook_error_warn"));
    }
2440
    // When a tool keeps being called without making progress, the agent should
    // inject a "Loop detected" self-correction notice (naming the looping
    // tool) into the next provider prompt instead of looping forever.
    #[tokio::test]
    async fn self_correction_injected_on_no_progress() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "I'll try a different approach".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                // Trip the no-progress detector well before the iteration cap.
                loop_detection_no_progress_threshold: 3,
                max_tool_iterations: 20,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(LoopTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:loop_tool x".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("self-correction should succeed");

        assert_eq!(response.text, "I'll try a different approach");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            last_prompt.contains("Loop detected"),
            "provider prompt should contain self-correction notice, got: {last_prompt}"
        );
        assert!(last_prompt.contains("loop_tool"));
    }
2484
    // Two tools alternating (ping -> pong -> ping ...) should trigger the
    // ping-pong loop detector and inject a "ping-pong" notice into the prompt.
    #[tokio::test]
    async fn self_correction_injected_on_ping_pong() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "I stopped the loop".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                loop_detection_ping_pong_cycles: 2,
                // No-progress detection is disabled (0) so only the ping-pong
                // detector can fire here.
                loop_detection_no_progress_threshold: 0, max_tool_iterations: 20,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(PingTool), Box::new(PongTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:ping x".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("self-correction should succeed");

        assert_eq!(response.text, "I stopped the loop");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            last_prompt.contains("ping-pong"),
            "provider prompt should contain ping-pong notice, got: {last_prompt}"
        );
    }
2524
    // A threshold of 0 disables no-progress detection entirely: even a tool
    // that loops must run to the iteration cap without any self-correction
    // notice being injected.
    #[tokio::test]
    async fn no_progress_disabled_when_threshold_zero() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "final answer".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                loop_detection_no_progress_threshold: 0,
                loop_detection_ping_pong_cycles: 0,
                max_tool_iterations: 5,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(LoopTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:loop_tool x".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should complete without error");

        assert_eq!(response.text, "final answer");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            !last_prompt.contains("Loop detected"),
            "should not contain self-correction when detection is disabled"
        );
    }
2566
2567 #[tokio::test]
2570 async fn parallel_tools_executes_multiple_calls() {
2571 let prompts = Arc::new(Mutex::new(Vec::new()));
2572 let provider = TestProvider {
2573 received_prompts: prompts.clone(),
2574 response_text: "parallel done".to_string(),
2575 };
2576 let agent = Agent::new(
2577 AgentConfig {
2578 parallel_tools: true,
2579 max_tool_iterations: 5,
2580 ..AgentConfig::default()
2581 },
2582 Box::new(provider),
2583 Box::new(TestMemory::default()),
2584 vec![Box::new(EchoTool), Box::new(UpperTool)],
2585 );
2586
2587 let response = agent
2588 .respond(
2589 UserMessage {
2590 text: "tool:echo hello\ntool:upper world".to_string(),
2591 },
2592 &test_ctx(),
2593 )
2594 .await
2595 .expect("parallel tools should succeed");
2596
2597 assert_eq!(response.text, "parallel done");
2598
2599 let received = prompts.lock().expect("provider lock poisoned");
2600 let last_prompt = received.last().expect("provider should have been called");
2601 assert!(
2602 last_prompt.contains("echoed:hello"),
2603 "should contain echo output, got: {last_prompt}"
2604 );
2605 assert!(
2606 last_prompt.contains("WORLD"),
2607 "should contain upper output, got: {last_prompt}"
2608 );
2609 }
2610
    // Parallel execution must preserve the order of tool calls in the prompt:
    // results appear in call order even though execution may interleave.
    #[tokio::test]
    async fn parallel_tools_maintains_stable_ordering() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "ordered".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                parallel_tools: true,
                max_tool_iterations: 5,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool), Box::new(UpperTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo aaa\ntool:upper bbb".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("parallel tools should succeed");

        assert_eq!(response.text, "ordered");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        // Compare byte positions of the two outputs in the final prompt.
        let echo_pos = last_prompt
            .find("echoed:aaa")
            .expect("echo output should exist");
        let upper_pos = last_prompt.find("BBB").expect("upper output should exist");
        assert!(
            echo_pos < upper_pos,
            "echo output should come before upper output in the prompt"
        );
    }
2653
    // With `parallel_tools` disabled, only the first tool directive in a
    // message executes per iteration; the second tool's output must be absent.
    #[tokio::test]
    async fn parallel_disabled_runs_first_call_only() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "sequential".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                parallel_tools: false,
                max_tool_iterations: 5,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool), Box::new(UpperTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo hello\ntool:upper world".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("sequential tools should succeed");

        assert_eq!(response.text, "sequential");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            last_prompt.contains("echoed:hello"),
            "should contain first tool output, got: {last_prompt}"
        );
        assert!(
            !last_prompt.contains("WORLD"),
            "should NOT contain second tool output when parallel is disabled"
        );
    }
2696
    // A single tool call with `parallel_tools` enabled should behave exactly
    // like the sequential path (no special-casing for the 1-call case).
    #[tokio::test]
    async fn single_call_parallel_matches_sequential() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "single".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                parallel_tools: true,
                max_tool_iterations: 5,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("single tool with parallel enabled should succeed");

        assert_eq!(response.text, "single");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            last_prompt.contains("echoed:ping"),
            "single tool call should work same as sequential, got: {last_prompt}"
        );
    }
2734
    // Gated tools must never run inside the parallel batch: when one of the
    // requested tools is gated, execution falls back to sequential and only
    // the non-gated first call runs in this iteration.
    #[tokio::test]
    async fn parallel_falls_back_to_sequential_for_gated_tools() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "gated sequential".to_string(),
        };
        // Mark "upper" as gated so the parallel path must refuse it.
        let mut gated = std::collections::HashSet::new();
        gated.insert("upper".to_string());
        let agent = Agent::new(
            AgentConfig {
                parallel_tools: true,
                gated_tools: gated,
                max_tool_iterations: 5,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool), Box::new(UpperTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo hello\ntool:upper world".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("gated fallback should succeed");

        assert_eq!(response.text, "gated sequential");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            last_prompt.contains("echoed:hello"),
            "first tool should execute, got: {last_prompt}"
        );
        assert!(
            !last_prompt.contains("WORLD"),
            "gated tool should NOT execute in parallel, got: {last_prompt}"
        );
    }
2781
2782 fn research_config(trigger: ResearchTrigger) -> ResearchPolicy {
2785 ResearchPolicy {
2786 enabled: true,
2787 trigger,
2788 keywords: vec!["search".to_string(), "find".to_string()],
2789 min_message_length: 10,
2790 max_iterations: 5,
2791 show_progress: true,
2792 }
2793 }
2794
2795 fn config_with_research(trigger: ResearchTrigger) -> AgentConfig {
2796 AgentConfig {
2797 research: research_config(trigger),
2798 ..Default::default()
2799 }
2800 }
2801
2802 #[tokio::test]
2803 async fn research_trigger_never() {
2804 let prompts = Arc::new(Mutex::new(Vec::new()));
2805 let provider = TestProvider {
2806 received_prompts: prompts.clone(),
2807 response_text: "final answer".to_string(),
2808 };
2809 let agent = Agent::new(
2810 config_with_research(ResearchTrigger::Never),
2811 Box::new(provider),
2812 Box::new(TestMemory::default()),
2813 vec![Box::new(EchoTool)],
2814 );
2815
2816 let response = agent
2817 .respond(
2818 UserMessage {
2819 text: "search for something".to_string(),
2820 },
2821 &test_ctx(),
2822 )
2823 .await
2824 .expect("should succeed");
2825
2826 assert_eq!(response.text, "final answer");
2827 let received = prompts.lock().expect("lock");
2828 assert_eq!(
2829 received.len(),
2830 1,
2831 "should have exactly 1 provider call (no research)"
2832 );
2833 }
2834
    // `ResearchTrigger::Always` fires research for any message; the findings
    // gathered during the research phase must be injected into the final
    // answer prompt under a "Research findings:" header.
    #[tokio::test]
    async fn research_trigger_always() {
        let provider = ScriptedProvider::new(vec![
            "tool:echo gathering data",
            "Found relevant information",
            "Final answer with research context",
        ]);
        let prompts = provider.received_prompts.clone();
        let agent = Agent::new(
            config_with_research(ResearchTrigger::Always),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "hello world".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Final answer with research context");
        let received = prompts.lock().expect("lock");
        let last = received.last().expect("should have provider calls");
        assert!(
            last.contains("Research findings:"),
            "final prompt should contain research findings, got: {last}"
        );
    }
2868
    // Keyword trigger: a message containing a configured keyword ("search")
    // should start a research phase, whose first provider call carries the
    // "RESEARCH mode" prompt.
    #[tokio::test]
    async fn research_trigger_keywords_match() {
        let provider = ScriptedProvider::new(vec!["Summary of search", "Answer based on research"]);
        let prompts = provider.received_prompts.clone();
        let agent = Agent::new(
            config_with_research(ResearchTrigger::Keywords),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "please search for the config".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Answer based on research");
        let received = prompts.lock().expect("lock");
        assert!(received.len() >= 2, "should have at least 2 provider calls");
        assert!(
            received[0].contains("RESEARCH mode"),
            "first call should be research prompt"
        );
    }
2898
    // Keyword trigger, no keyword present: the research phase must be skipped
    // and the provider called exactly once.
    #[tokio::test]
    async fn research_trigger_keywords_no_match() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "direct answer".to_string(),
        };
        let agent = Agent::new(
            config_with_research(ResearchTrigger::Keywords),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "hello world".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "direct answer");
        let received = prompts.lock().expect("lock");
        assert_eq!(received.len(), 1, "no research phase should fire");
    }
2927
    // Length trigger: a message shorter than `min_message_length` (20 here)
    // must not start research.
    #[tokio::test]
    async fn research_trigger_length_short_skips() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "short".to_string(),
        };
        let config = AgentConfig {
            research: ResearchPolicy {
                min_message_length: 20,
                ..research_config(ResearchTrigger::Length)
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );
        agent
            .respond(
                UserMessage {
                    text: "hi".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");
        let received = prompts.lock().expect("lock");
        assert_eq!(received.len(), 1, "short message should skip research");
    }
2960
    // Length trigger: a message longer than `min_message_length` must start
    // the research phase (at least one extra provider call).
    #[tokio::test]
    async fn research_trigger_length_long_triggers() {
        let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
        let prompts = provider.received_prompts.clone();
        let config = AgentConfig {
            research: ResearchPolicy {
                min_message_length: 20,
                ..research_config(ResearchTrigger::Length)
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );
        agent
            .respond(
                UserMessage {
                    text: "this is a longer message that exceeds the threshold".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");
        let received = prompts.lock().expect("lock");
        assert!(received.len() >= 2, "long message should trigger research");
    }
2990
    // Question trigger: a message ending in "?" should start research.
    #[tokio::test]
    async fn research_trigger_question_with_mark() {
        let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
        let prompts = provider.received_prompts.clone();
        let agent = Agent::new(
            config_with_research(ResearchTrigger::Question),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );
        agent
            .respond(
                UserMessage {
                    text: "what is the meaning of life?".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");
        let received = prompts.lock().expect("lock");
        assert!(received.len() >= 2, "question should trigger research");
    }
3013
3014 #[tokio::test]
3015 async fn research_trigger_question_without_mark() {
3016 let prompts = Arc::new(Mutex::new(Vec::new()));
3017 let provider = TestProvider {
3018 received_prompts: prompts.clone(),
3019 response_text: "no research".to_string(),
3020 };
3021 let agent = Agent::new(
3022 config_with_research(ResearchTrigger::Question),
3023 Box::new(provider),
3024 Box::new(TestMemory::default()),
3025 vec![],
3026 );
3027 agent
3028 .respond(
3029 UserMessage {
3030 text: "do this thing".to_string(),
3031 },
3032 &test_ctx(),
3033 )
3034 .await
3035 .expect("should succeed");
3036 let received = prompts.lock().expect("lock");
3037 assert_eq!(received.len(), 1, "non-question should skip research");
3038 }
3039
    // The research loop must stop at `max_iterations` even when the provider
    // keeps requesting tools: with max_iterations = 3 the research phase makes
    // exactly 4 provider calls (the initial research prompt plus 3 iterations).
    #[tokio::test]
    async fn research_respects_max_iterations() {
        // More scripted tool-call turns than the cap allows, so the loop can
        // only stop because of the iteration limit.
        let provider = ScriptedProvider::new(vec![
            "tool:echo step1",
            "tool:echo step2",
            "tool:echo step3",
            "tool:echo step4",
            "tool:echo step5",
            "tool:echo step6",
            "tool:echo step7",
            "answer after research",
        ]);
        let prompts = provider.received_prompts.clone();
        let config = AgentConfig {
            research: ResearchPolicy {
                max_iterations: 3,
                ..research_config(ResearchTrigger::Always)
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "test".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert!(!response.text.is_empty(), "should get a response");
        let received = prompts.lock().expect("lock");
        // Count only the research-phase provider calls by their prompt markers.
        let research_calls = received
            .iter()
            .filter(|p| p.contains("RESEARCH mode") || p.contains("Research iteration"))
            .count();
        assert_eq!(
            research_calls, 4,
            "should have 4 research provider calls (1 initial + 3 iterations)"
        );
    }
3089
    // `enabled: false` must win over any trigger: even `Always` fires nothing
    // when the research feature itself is switched off.
    #[tokio::test]
    async fn research_disabled_skips_phase() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "direct answer".to_string(),
        };
        let config = AgentConfig {
            research: ResearchPolicy {
                enabled: false,
                trigger: ResearchTrigger::Always,
                ..Default::default()
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "search for something".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "direct answer");
        let received = prompts.lock().expect("lock");
        assert_eq!(
            received.len(),
            1,
            "disabled research should not fire even with Always trigger"
        );
    }
3130
    // Test provider that records which `ReasoningConfig` each completion call
    // received, so tests can assert the agent forwards reasoning settings.
    struct ReasoningCapturingProvider {
        // Every config seen, in call order (a default is pushed for the
        // non-reasoning `complete` path so tests can tell which path ran).
        captured_reasoning: Arc<Mutex<Vec<ReasoningConfig>>>,
        // Canned text returned by every completion.
        response_text: String,
    }
3137
    #[async_trait]
    impl Provider for ReasoningCapturingProvider {
        // Plain completion path: records a *default* config so a test can
        // detect that the agent took this path instead of the reasoning one.
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            self.captured_reasoning
                .lock()
                .expect("lock")
                .push(ReasoningConfig::default());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }

        // Reasoning path: records the exact config the agent passed through.
        async fn complete_with_reasoning(
            &self,
            _prompt: &str,
            reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            self.captured_reasoning
                .lock()
                .expect("lock")
                .push(reasoning.clone());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }
    }
3166
    // The agent must pass its configured `ReasoningConfig` (enabled + level)
    // through to the provider unchanged.
    #[tokio::test]
    async fn reasoning_config_passed_to_provider() {
        let captured = Arc::new(Mutex::new(Vec::new()));
        let provider = ReasoningCapturingProvider {
            captured_reasoning: captured.clone(),
            response_text: "ok".to_string(),
        };
        let config = AgentConfig {
            reasoning: ReasoningConfig {
                enabled: Some(true),
                level: Some("high".to_string()),
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );

        agent
            .respond(
                UserMessage {
                    text: "test".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        let configs = captured.lock().expect("lock");
        assert_eq!(configs.len(), 1, "provider should be called once");
        assert_eq!(configs[0].enabled, Some(true));
        assert_eq!(configs[0].level.as_deref(), Some("high"));
    }
3203
    // Even an explicitly *disabled* reasoning config must reach the provider
    // as-is (the provider decides what "disabled" means, not the agent).
    #[tokio::test]
    async fn reasoning_disabled_passes_config_through() {
        let captured = Arc::new(Mutex::new(Vec::new()));
        let provider = ReasoningCapturingProvider {
            captured_reasoning: captured.clone(),
            response_text: "ok".to_string(),
        };
        let config = AgentConfig {
            reasoning: ReasoningConfig {
                enabled: Some(false),
                level: None,
            },
            ..Default::default()
        };
        let agent = Agent::new(
            config,
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![],
        );

        agent
            .respond(
                UserMessage {
                    text: "test".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        let configs = captured.lock().expect("lock");
        assert_eq!(configs.len(), 1);
        assert_eq!(
            configs[0].enabled,
            Some(false),
            "disabled reasoning should still be passed to provider"
        );
        assert_eq!(configs[0].level, None);
    }
3244
    // Scripted provider for the structured tool-use path: replays canned
    // `ChatResult`s in order and records the message lists it is handed.
    struct StructuredProvider {
        // Responses returned one per call; defaults once exhausted.
        responses: Vec<ChatResult>,
        // Index of the next response (shared across &self calls).
        call_count: AtomicUsize,
        // Snapshot of the conversation passed to each complete_with_tools call.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
3252
3253 impl StructuredProvider {
3254 fn new(responses: Vec<ChatResult>) -> Self {
3255 Self {
3256 responses,
3257 call_count: AtomicUsize::new(0),
3258 received_messages: Arc::new(Mutex::new(Vec::new())),
3259 }
3260 }
3261 }
3262
    #[async_trait]
    impl Provider for StructuredProvider {
        // Text path: advance the script and return the next canned result
        // (or a default `ChatResult` once the script is exhausted).
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        // Structured path: snapshot the conversation for later assertions,
        // then advance the same script counter.
        async fn complete_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }
    }
3284
3285 struct StructuredEchoTool;
3286
3287 #[async_trait]
3288 impl Tool for StructuredEchoTool {
3289 fn name(&self) -> &'static str {
3290 "echo"
3291 }
3292 fn description(&self) -> &'static str {
3293 "Echoes input back"
3294 }
3295 fn input_schema(&self) -> Option<serde_json::Value> {
3296 Some(json!({
3297 "type": "object",
3298 "properties": {
3299 "text": { "type": "string", "description": "Text to echo" }
3300 },
3301 "required": ["text"]
3302 }))
3303 }
3304 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3305 Ok(ToolResult {
3306 output: format!("echoed:{input}"),
3307 })
3308 }
3309 }
3310
3311 struct StructuredFailingTool;
3312
3313 #[async_trait]
3314 impl Tool for StructuredFailingTool {
3315 fn name(&self) -> &'static str {
3316 "boom"
3317 }
3318 fn description(&self) -> &'static str {
3319 "Always fails"
3320 }
3321 fn input_schema(&self) -> Option<serde_json::Value> {
3322 Some(json!({
3323 "type": "object",
3324 "properties": {},
3325 }))
3326 }
3327 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3328 Err(anyhow::anyhow!("tool exploded"))
3329 }
3330 }
3331
3332 struct StructuredUpperTool;
3333
3334 #[async_trait]
3335 impl Tool for StructuredUpperTool {
3336 fn name(&self) -> &'static str {
3337 "upper"
3338 }
3339 fn description(&self) -> &'static str {
3340 "Uppercases input"
3341 }
3342 fn input_schema(&self) -> Option<serde_json::Value> {
3343 Some(json!({
3344 "type": "object",
3345 "properties": {
3346 "text": { "type": "string", "description": "Text to uppercase" }
3347 },
3348 "required": ["text"]
3349 }))
3350 }
3351 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3352 Ok(ToolResult {
3353 output: input.to_uppercase(),
3354 })
3355 }
3356 }
3357
3358 use crate::types::ToolUseRequest;
3359
    // Happy path of the structured loop: one ToolUse result, the tool runs,
    // its result is appended to the conversation, and the second provider
    // call (now holding user + assistant + tool-result) ends the turn.
    #[tokio::test]
    async fn structured_basic_tool_call_then_end_turn() {
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: "Let me echo that.".to_string(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "hello"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "The echo returned: hello".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "echo hello".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "The echo returned: hello");

        let msgs = received.lock().expect("lock");
        assert_eq!(msgs.len(), 2, "provider should be called twice");
        assert!(
            msgs[1].len() >= 3,
            "second call should have user + assistant + tool result"
        );
    }
3410
3411 #[tokio::test]
3412 async fn structured_no_tool_calls_returns_immediately() {
3413 let provider = StructuredProvider::new(vec![ChatResult {
3414 output_text: "Hello! No tools needed.".to_string(),
3415 tool_calls: vec![],
3416 stop_reason: Some(StopReason::EndTurn),
3417 }]);
3418
3419 let agent = Agent::new(
3420 AgentConfig::default(),
3421 Box::new(provider),
3422 Box::new(TestMemory::default()),
3423 vec![Box::new(StructuredEchoTool)],
3424 );
3425
3426 let response = agent
3427 .respond(
3428 UserMessage {
3429 text: "hi".to_string(),
3430 },
3431 &test_ctx(),
3432 )
3433 .await
3434 .expect("should succeed");
3435
3436 assert_eq!(response.text, "Hello! No tools needed.");
3437 }
3438
3439 #[tokio::test]
3440 async fn structured_tool_not_found_sends_error_result() {
3441 let provider = StructuredProvider::new(vec![
3442 ChatResult {
3443 output_text: String::new(),
3444 tool_calls: vec![ToolUseRequest {
3445 id: "call_1".to_string(),
3446 name: "nonexistent".to_string(),
3447 input: json!({}),
3448 }],
3449 stop_reason: Some(StopReason::ToolUse),
3450 },
3451 ChatResult {
3452 output_text: "I see the tool was not found.".to_string(),
3453 tool_calls: vec![],
3454 stop_reason: Some(StopReason::EndTurn),
3455 },
3456 ]);
3457 let received = provider.received_messages.clone();
3458
3459 let agent = Agent::new(
3460 AgentConfig::default(),
3461 Box::new(provider),
3462 Box::new(TestMemory::default()),
3463 vec![Box::new(StructuredEchoTool)],
3464 );
3465
3466 let response = agent
3467 .respond(
3468 UserMessage {
3469 text: "use nonexistent".to_string(),
3470 },
3471 &test_ctx(),
3472 )
3473 .await
3474 .expect("should succeed, not abort");
3475
3476 assert_eq!(response.text, "I see the tool was not found.");
3477
3478 let msgs = received.lock().expect("lock");
3480 let last_call = &msgs[1];
3481 let has_error_result = last_call.iter().any(|m| {
3482 matches!(m, ConversationMessage::ToolResult(r) if r.is_error && r.content.contains("not found"))
3483 });
3484 assert!(has_error_result, "should include error ToolResult");
3485 }
3486
    // A tool that returns Err must be converted into a ToolResultMessage and
    // fed back to the provider, not propagated as a turn-level failure.
    #[tokio::test]
    async fn structured_tool_error_does_not_abort() {
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "boom".to_string(),
                    input: json!({}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "I handled the error gracefully.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredFailingTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "boom".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed, error becomes ToolResultMessage");

        assert_eq!(response.text, "I handled the error gracefully.");
    }
3525
    // Two consecutive ToolUse turns before EndTurn: the structured loop must
    // iterate, making three provider calls in total.
    #[tokio::test]
    async fn structured_multi_iteration_tool_calls() {
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "first"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_2".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "second"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "Done with two tool calls.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "echo twice".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Done with two tool calls.");
        let msgs = received.lock().expect("lock");
        assert_eq!(msgs.len(), 3, "three provider calls");
    }
3579
3580 #[tokio::test]
3581 async fn structured_max_iterations_forces_final_answer() {
3582 let responses = vec![
3586 ChatResult {
3587 output_text: String::new(),
3588 tool_calls: vec![ToolUseRequest {
3589 id: "call_0".to_string(),
3590 name: "echo".to_string(),
3591 input: json!({"text": "loop"}),
3592 }],
3593 stop_reason: Some(StopReason::ToolUse),
3594 },
3595 ChatResult {
3596 output_text: String::new(),
3597 tool_calls: vec![ToolUseRequest {
3598 id: "call_1".to_string(),
3599 name: "echo".to_string(),
3600 input: json!({"text": "loop"}),
3601 }],
3602 stop_reason: Some(StopReason::ToolUse),
3603 },
3604 ChatResult {
3605 output_text: String::new(),
3606 tool_calls: vec![ToolUseRequest {
3607 id: "call_2".to_string(),
3608 name: "echo".to_string(),
3609 input: json!({"text": "loop"}),
3610 }],
3611 stop_reason: Some(StopReason::ToolUse),
3612 },
3613 ChatResult {
3615 output_text: "Forced final answer.".to_string(),
3616 tool_calls: vec![],
3617 stop_reason: Some(StopReason::EndTurn),
3618 },
3619 ];
3620
3621 let agent = Agent::new(
3622 AgentConfig {
3623 max_tool_iterations: 3,
3624 loop_detection_no_progress_threshold: 0, ..AgentConfig::default()
3626 },
3627 Box::new(StructuredProvider::new(responses)),
3628 Box::new(TestMemory::default()),
3629 vec![Box::new(StructuredEchoTool)],
3630 );
3631
3632 let response = agent
3633 .respond(
3634 UserMessage {
3635 text: "loop forever".to_string(),
3636 },
3637 &test_ctx(),
3638 )
3639 .await
3640 .expect("should succeed with forced answer");
3641
3642 assert_eq!(response.text, "Forced final answer.");
3643 }
3644
3645 #[tokio::test]
3646 async fn structured_fallback_to_text_when_disabled() {
3647 let prompts = Arc::new(Mutex::new(Vec::new()));
3648 let provider = TestProvider {
3649 received_prompts: prompts.clone(),
3650 response_text: "text path used".to_string(),
3651 };
3652
3653 let agent = Agent::new(
3654 AgentConfig {
3655 model_supports_tool_use: false,
3656 ..AgentConfig::default()
3657 },
3658 Box::new(provider),
3659 Box::new(TestMemory::default()),
3660 vec![Box::new(StructuredEchoTool)],
3661 );
3662
3663 let response = agent
3664 .respond(
3665 UserMessage {
3666 text: "hello".to_string(),
3667 },
3668 &test_ctx(),
3669 )
3670 .await
3671 .expect("should succeed");
3672
3673 assert_eq!(response.text, "text path used");
3674 }
3675
3676 #[tokio::test]
3677 async fn structured_fallback_when_no_tool_schemas() {
3678 let prompts = Arc::new(Mutex::new(Vec::new()));
3679 let provider = TestProvider {
3680 received_prompts: prompts.clone(),
3681 response_text: "text path used".to_string(),
3682 };
3683
3684 let agent = Agent::new(
3686 AgentConfig::default(),
3687 Box::new(provider),
3688 Box::new(TestMemory::default()),
3689 vec![Box::new(EchoTool)],
3690 );
3691
3692 let response = agent
3693 .respond(
3694 UserMessage {
3695 text: "hello".to_string(),
3696 },
3697 &test_ctx(),
3698 )
3699 .await
3700 .expect("should succeed");
3701
3702 assert_eq!(response.text, "text path used");
3703 }
3704
    #[tokio::test]
    async fn structured_parallel_tool_calls() {
        // One provider turn requests two different tools at once; with
        // `parallel_tools` enabled both must execute and both results must be
        // sent back before the final answer.
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![
                    ToolUseRequest {
                        id: "call_1".to_string(),
                        name: "echo".to_string(),
                        input: json!({"text": "a"}),
                    },
                    ToolUseRequest {
                        id: "call_2".to_string(),
                        name: "upper".to_string(),
                        input: json!({"text": "b"}),
                    },
                ],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "Both tools ran.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig {
                parallel_tools: true,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "do both".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Both tools ran.");

        // The second provider call must carry the tool results from round one.
        let msgs = received.lock().expect("lock");
        let tool_results: Vec<_> = msgs[1]
            .iter()
            .filter(|m| matches!(m, ConversationMessage::ToolResult(_)))
            .collect();
        assert_eq!(tool_results.len(), 2, "should have two tool results");
    }
3762
    #[tokio::test]
    async fn structured_memory_integration() {
        // Pre-populate memory with an earlier exchange; the agent should
        // prepend it to the conversation it sends, oldest entry first.
        let memory = TestMemory::default();
        memory
            .append(MemoryEntry {
                role: "user".to_string(),
                content: "old question".to_string(),
                ..Default::default()
            })
            .await
            .unwrap();
        memory
            .append(MemoryEntry {
                role: "assistant".to_string(),
                content: "old answer".to_string(),
                ..Default::default()
            })
            .await
            .unwrap();

        let provider = StructuredProvider::new(vec![ChatResult {
            output_text: "I see your history.".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
        }]);
        let received = provider.received_messages.clone();

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(StructuredEchoTool)],
        );

        agent
            .respond(
                UserMessage {
                    text: "new question".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("should succeed");

        // First provider call: two memory messages plus the new user message,
        // with the oldest memory entry leading.
        let msgs = received.lock().expect("lock");
        assert!(msgs[0].len() >= 3, "should include memory messages");
        assert!(matches!(
            &msgs[0][0],
            ConversationMessage::User { content, .. } if content == "old question"
        ));
    }
3818
3819 #[tokio::test]
3820 async fn prepare_tool_input_extracts_single_string_field() {
3821 let tool = StructuredEchoTool;
3822 let input = json!({"text": "hello world"});
3823 let result = prepare_tool_input(&tool, &input).expect("valid input");
3824 assert_eq!(result, "hello world");
3825 }
3826
3827 #[tokio::test]
3828 async fn prepare_tool_input_serializes_multi_field_json() {
3829 struct MultiFieldTool;
3831 #[async_trait]
3832 impl Tool for MultiFieldTool {
3833 fn name(&self) -> &'static str {
3834 "multi"
3835 }
3836 fn input_schema(&self) -> Option<serde_json::Value> {
3837 Some(json!({
3838 "type": "object",
3839 "properties": {
3840 "path": { "type": "string" },
3841 "content": { "type": "string" }
3842 },
3843 "required": ["path", "content"]
3844 }))
3845 }
3846 async fn execute(
3847 &self,
3848 _input: &str,
3849 _ctx: &ToolContext,
3850 ) -> anyhow::Result<ToolResult> {
3851 unreachable!()
3852 }
3853 }
3854
3855 let tool = MultiFieldTool;
3856 let input = json!({"path": "a.txt", "content": "hello"});
3857 let result = prepare_tool_input(&tool, &input).expect("valid input");
3858 let parsed: serde_json::Value = serde_json::from_str(&result).expect("valid JSON");
3860 assert_eq!(parsed["path"], "a.txt");
3861 assert_eq!(parsed["content"], "hello");
3862 }
3863
3864 #[tokio::test]
3865 async fn prepare_tool_input_unwraps_bare_string() {
3866 let tool = StructuredEchoTool;
3867 let input = json!("plain text");
3868 let result = prepare_tool_input(&tool, &input).expect("valid input");
3869 assert_eq!(result, "plain text");
3870 }
3871
3872 #[tokio::test]
3873 async fn prepare_tool_input_rejects_schema_violating_input() {
3874 struct StrictTool;
3875 #[async_trait]
3876 impl Tool for StrictTool {
3877 fn name(&self) -> &'static str {
3878 "strict"
3879 }
3880 fn input_schema(&self) -> Option<serde_json::Value> {
3881 Some(json!({
3882 "type": "object",
3883 "required": ["path"],
3884 "properties": {
3885 "path": { "type": "string" }
3886 }
3887 }))
3888 }
3889 async fn execute(
3890 &self,
3891 _input: &str,
3892 _ctx: &ToolContext,
3893 ) -> anyhow::Result<ToolResult> {
3894 unreachable!("should not be called with invalid input")
3895 }
3896 }
3897
3898 let result = prepare_tool_input(&StrictTool, &json!({}));
3900 assert!(result.is_err());
3901 let err = result.unwrap_err();
3902 assert!(err.contains("Invalid input for tool 'strict'"));
3903 assert!(err.contains("missing required field"));
3904 }
3905
3906 #[tokio::test]
3907 async fn prepare_tool_input_rejects_wrong_type() {
3908 struct TypedTool;
3909 #[async_trait]
3910 impl Tool for TypedTool {
3911 fn name(&self) -> &'static str {
3912 "typed"
3913 }
3914 fn input_schema(&self) -> Option<serde_json::Value> {
3915 Some(json!({
3916 "type": "object",
3917 "required": ["count"],
3918 "properties": {
3919 "count": { "type": "integer" }
3920 }
3921 }))
3922 }
3923 async fn execute(
3924 &self,
3925 _input: &str,
3926 _ctx: &ToolContext,
3927 ) -> anyhow::Result<ToolResult> {
3928 unreachable!("should not be called with invalid input")
3929 }
3930 }
3931
3932 let result = prepare_tool_input(&TypedTool, &json!({"count": "not a number"}));
3934 assert!(result.is_err());
3935 let err = result.unwrap_err();
3936 assert!(err.contains("expected type \"integer\""));
3937 }
3938
3939 #[test]
3940 fn truncate_messages_preserves_small_conversation() {
3941 let mut msgs = vec![
3942 ConversationMessage::user("hello".to_string()),
3943 ConversationMessage::Assistant {
3944 content: Some("world".to_string()),
3945 tool_calls: vec![],
3946 },
3947 ];
3948 truncate_messages(&mut msgs, 1000);
3949 assert_eq!(msgs.len(), 2, "should not truncate small conversation");
3950 }
3951
3952 #[test]
3953 fn truncate_messages_drops_middle() {
3954 let mut msgs = vec![
3955 ConversationMessage::user("A".repeat(100)),
3956 ConversationMessage::Assistant {
3957 content: Some("B".repeat(100)),
3958 tool_calls: vec![],
3959 },
3960 ConversationMessage::user("C".repeat(100)),
3961 ConversationMessage::Assistant {
3962 content: Some("D".repeat(100)),
3963 tool_calls: vec![],
3964 },
3965 ];
3966 truncate_messages(&mut msgs, 250);
3969 assert!(msgs.len() < 4, "should have dropped some messages");
3970 assert!(matches!(
3972 &msgs[0],
3973 ConversationMessage::User { content, .. } if content.starts_with("AAA")
3974 ));
3975 }
3976
3977 #[test]
3978 fn memory_to_messages_reverses_order() {
3979 let entries = vec![
3980 MemoryEntry {
3981 role: "assistant".to_string(),
3982 content: "newest".to_string(),
3983 ..Default::default()
3984 },
3985 MemoryEntry {
3986 role: "user".to_string(),
3987 content: "oldest".to_string(),
3988 ..Default::default()
3989 },
3990 ];
3991 let msgs = memory_to_messages(&entries);
3992 assert_eq!(msgs.len(), 2);
3993 assert!(matches!(
3994 &msgs[0],
3995 ConversationMessage::User { content, .. } if content == "oldest"
3996 ));
3997 assert!(matches!(
3998 &msgs[1],
3999 ConversationMessage::Assistant { content: Some(c), .. } if c == "newest"
4000 ));
4001 }
4002
4003 #[test]
4004 fn single_required_string_field_detects_single() {
4005 let schema = json!({
4006 "type": "object",
4007 "properties": {
4008 "path": { "type": "string" }
4009 },
4010 "required": ["path"]
4011 });
4012 assert_eq!(
4013 single_required_string_field(&schema),
4014 Some("path".to_string())
4015 );
4016 }
4017
4018 #[test]
4019 fn single_required_string_field_none_for_multiple() {
4020 let schema = json!({
4021 "type": "object",
4022 "properties": {
4023 "path": { "type": "string" },
4024 "mode": { "type": "string" }
4025 },
4026 "required": ["path", "mode"]
4027 });
4028 assert_eq!(single_required_string_field(&schema), None);
4029 }
4030
4031 #[test]
4032 fn single_required_string_field_none_for_non_string() {
4033 let schema = json!({
4034 "type": "object",
4035 "properties": {
4036 "count": { "type": "integer" }
4037 },
4038 "required": ["count"]
4039 });
4040 assert_eq!(single_required_string_field(&schema), None);
4041 }
4042
4043 use crate::types::StreamChunk;
4046
    /// Test double for `Provider` that replays a fixed script of `ChatResult`s
    /// and, on the streaming entry points, also emits the text as chunks.
    struct StreamingProvider {
        // Scripted results, returned in call order.
        responses: Vec<ChatResult>,
        // Number of provider calls so far; shared index across entry points.
        call_count: AtomicUsize,
        // Message lists passed to the *_with_tools methods, for assertions.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
4053
4054 impl StreamingProvider {
4055 fn new(responses: Vec<ChatResult>) -> Self {
4056 Self {
4057 responses,
4058 call_count: AtomicUsize::new(0),
4059 received_messages: Arc::new(Mutex::new(Vec::new())),
4060 }
4061 }
4062 }
4063
    #[async_trait]
    impl Provider for StreamingProvider {
        // Non-streaming text path: replay the next scripted result, or a
        // default `ChatResult` once the script is exhausted.
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        // Streaming text path: emit the scripted output word by word, then a
        // final empty chunk with `done: true` as the end-of-stream sentinel.
        async fn complete_streaming(
            &self,
            _prompt: &str,
            sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
        ) -> anyhow::Result<ChatResult> {
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            let result = self.responses.get(idx).cloned().unwrap_or_default();
            for word in result.output_text.split_whitespace() {
                // Send errors are deliberately ignored: a test may drop its rx.
                let _ = sender.send(StreamChunk {
                    delta: format!("{word} "),
                    done: false,
                    tool_call_delta: None,
                });
            }
            let _ = sender.send(StreamChunk {
                delta: String::new(),
                done: true,
                tool_call_delta: None,
            });
            Ok(result)
        }

        // Structured path: record the exact conversation the agent sent, then
        // replay the next scripted result.
        async fn complete_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        // Structured streaming path: record the conversation, stream the
        // scripted text character by character, then send the done sentinel.
        async fn complete_streaming_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
            sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
        ) -> anyhow::Result<ChatResult> {
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            let result = self.responses.get(idx).cloned().unwrap_or_default();
            for ch in result.output_text.chars() {
                let _ = sender.send(StreamChunk {
                    delta: ch.to_string(),
                    done: false,
                    tool_call_delta: None,
                });
            }
            let _ = sender.send(StreamChunk {
                delta: String::new(),
                done: true,
                tool_call_delta: None,
            });
            Ok(result)
        }
    }
4137
4138 #[tokio::test]
4139 async fn streaming_text_only_sends_chunks() {
4140 let provider = StreamingProvider::new(vec![ChatResult {
4141 output_text: "Hello world".to_string(),
4142 ..Default::default()
4143 }]);
4144 let agent = Agent::new(
4145 AgentConfig {
4146 model_supports_tool_use: false,
4147 ..Default::default()
4148 },
4149 Box::new(provider),
4150 Box::new(TestMemory::default()),
4151 vec![],
4152 );
4153
4154 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4155 let response = agent
4156 .respond_streaming(
4157 UserMessage {
4158 text: "hi".to_string(),
4159 },
4160 &test_ctx(),
4161 tx,
4162 )
4163 .await
4164 .expect("should succeed");
4165
4166 assert_eq!(response.text, "Hello world");
4167
4168 let mut chunks = Vec::new();
4170 while let Ok(chunk) = rx.try_recv() {
4171 chunks.push(chunk);
4172 }
4173 assert!(chunks.len() >= 2, "should have at least text + done chunks");
4174 assert!(chunks.last().unwrap().done, "last chunk should be done");
4175 }
4176
    #[tokio::test]
    async fn streaming_single_tool_call_round_trip() {
        // Turn one requests the echo tool; turn two gives the final answer.
        // The streaming path must complete the full tool round trip.
        let provider = StreamingProvider::new(vec![
            ChatResult {
                output_text: "I'll echo that.".to_string(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "hello"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "Done echoing.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let response = agent
            .respond_streaming(
                UserMessage {
                    text: "echo hello".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Done echoing.");

        // Both provider turns stream text, so chunks must have been sent.
        let mut chunks = Vec::new();
        while let Ok(chunk) = rx.try_recv() {
            chunks.push(chunk);
        }
        assert!(chunks.len() >= 2, "should have streaming chunks");
    }
4224
    #[tokio::test]
    async fn streaming_multi_iteration_tool_calls() {
        // Two tool-use turns (different tools) followed by a final answer;
        // each provider call streams and ends with its own done sentinel.
        let provider = StreamingProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: json!({"text": "first"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_2".to_string(),
                    name: "upper".to_string(),
                    input: json!({"text": "second"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "All done.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
        );

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let response = agent
            .respond_streaming(
                UserMessage {
                    text: "do two things".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "All done.");

        // The mock sends one done chunk per provider call: expect three.
        let mut chunks = Vec::new();
        while let Ok(chunk) = rx.try_recv() {
            chunks.push(chunk);
        }
        let done_count = chunks.iter().filter(|c| c.done).count();
        assert!(
            done_count >= 3,
            "should have done chunks from 3 provider calls, got {}",
            done_count
        );
    }
4286
    #[tokio::test]
    async fn streaming_timeout_returns_error() {
        // Provider that always takes 500 ms, well past the 50 ms budget below.
        struct StreamingSlowProvider;

        #[async_trait]
        impl Provider for StreamingSlowProvider {
            async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
                sleep(Duration::from_millis(500)).await;
                Ok(ChatResult::default())
            }

            async fn complete_streaming(
                &self,
                _prompt: &str,
                _sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
            ) -> anyhow::Result<ChatResult> {
                sleep(Duration::from_millis(500)).await;
                Ok(ChatResult::default())
            }
        }

        let agent = Agent::new(
            AgentConfig {
                request_timeout_ms: 50,
                model_supports_tool_use: false,
                ..Default::default()
            },
            Box::new(StreamingSlowProvider),
            Box::new(TestMemory::default()),
            vec![],
        );

        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let result = agent
            .respond_streaming(
                UserMessage {
                    text: "hi".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await;

        // The timeout error must report the configured budget, not elapsed time.
        assert!(result.is_err());
        match result.unwrap_err() {
            AgentError::Timeout { timeout_ms } => assert_eq!(timeout_ms, 50),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }
4336
4337 #[tokio::test]
4338 async fn streaming_no_schema_fallback_sends_chunks() {
4339 let provider = StreamingProvider::new(vec![ChatResult {
4341 output_text: "Fallback response".to_string(),
4342 ..Default::default()
4343 }]);
4344 let agent = Agent::new(
4345 AgentConfig::default(), Box::new(provider),
4347 Box::new(TestMemory::default()),
4348 vec![Box::new(EchoTool)], );
4350
4351 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4352 let response = agent
4353 .respond_streaming(
4354 UserMessage {
4355 text: "hi".to_string(),
4356 },
4357 &test_ctx(),
4358 tx,
4359 )
4360 .await
4361 .expect("should succeed");
4362
4363 assert_eq!(response.text, "Fallback response");
4364
4365 let mut chunks = Vec::new();
4366 while let Ok(chunk) = rx.try_recv() {
4367 chunks.push(chunk);
4368 }
4369 assert!(!chunks.is_empty(), "should have streaming chunks");
4370 assert!(chunks.last().unwrap().done, "last chunk should be done");
4371 }
4372
    #[tokio::test]
    async fn streaming_tool_error_does_not_abort() {
        // The first turn calls a tool that fails; the agent must feed the
        // error back and still reach the final answer on the next turn.
        let provider = StreamingProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "boom".to_string(),
                    input: json!({}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "Recovered from error.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredFailingTool)],
        );

        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let response = agent
            .respond_streaming(
                UserMessage {
                    text: "boom".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed despite tool error");

        assert_eq!(response.text, "Recovered from error.");
    }
4413
    #[tokio::test]
    async fn streaming_parallel_tools() {
        // One streamed turn requests two tools at once; with `parallel_tools`
        // enabled both run and the final answer follows.
        let provider = StreamingProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![
                    ToolUseRequest {
                        id: "call_1".to_string(),
                        name: "echo".to_string(),
                        input: json!({"text": "a"}),
                    },
                    ToolUseRequest {
                        id: "call_2".to_string(),
                        name: "upper".to_string(),
                        input: json!({"text": "b"}),
                    },
                ],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "Parallel done.".to_string(),
                tool_calls: vec![],
                stop_reason: Some(StopReason::EndTurn),
            },
        ]);

        let agent = Agent::new(
            AgentConfig {
                parallel_tools: true,
                ..Default::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
        );

        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let response = agent
            .respond_streaming(
                UserMessage {
                    text: "parallel test".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed");

        assert_eq!(response.text, "Parallel done.");
    }
4464
    #[tokio::test]
    async fn streaming_done_chunk_sentinel() {
        // The chunk stream must contain non-empty content chunks and must end
        // with a `done: true` sentinel as its final element.
        let provider = StreamingProvider::new(vec![ChatResult {
            output_text: "abc".to_string(),
            tool_calls: vec![],
            stop_reason: Some(StopReason::EndTurn),
        }]);

        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(StructuredEchoTool)],
        );

        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let _response = agent
            .respond_streaming(
                UserMessage {
                    text: "test".to_string(),
                },
                &test_ctx(),
                tx,
            )
            .await
            .expect("should succeed");

        let mut chunks = Vec::new();
        while let Ok(chunk) = rx.try_recv() {
            chunks.push(chunk);
        }

        let done_chunks: Vec<_> = chunks.iter().filter(|c| c.done).collect();
        assert!(
            !done_chunks.is_empty(),
            "must have at least one done sentinel chunk"
        );
        assert!(chunks.last().unwrap().done, "final chunk must be done=true");
        // Content chunks carry text and are not sentinels.
        let content_chunks: Vec<_> = chunks
            .iter()
            .filter(|c| !c.done && !c.delta.is_empty())
            .collect();
        assert!(!content_chunks.is_empty(), "should have content chunks");
    }
4512
    #[tokio::test]
    async fn system_prompt_prepended_in_structured_path() {
        // With a configured system prompt, the structured path must send a
        // System message first, followed by the user's message.
        let provider = StructuredProvider::new(vec![ChatResult {
            output_text: "I understand.".to_string(),
            stop_reason: Some(StopReason::EndTurn),
            ..Default::default()
        }]);
        let received = provider.received_messages.clone();
        let memory = TestMemory::default();

        let agent = Agent::new(
            AgentConfig {
                model_supports_tool_use: true,
                system_prompt: Some("You are a math tutor.".to_string()),
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(StructuredEchoTool)],
        );

        let result = agent
            .respond(
                UserMessage {
                    text: "What is 2+2?".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        assert_eq!(result.text, "I understand.");

        let msgs = received.lock().expect("lock");
        assert!(!msgs.is_empty(), "provider should have been called");
        let first_call = &msgs[0];
        // Position 0: the configured system prompt, verbatim.
        match &first_call[0] {
            ConversationMessage::System { content } => {
                assert_eq!(content, "You are a math tutor.");
            }
            other => panic!("expected System message first, got {other:?}"),
        }
        // Position 1: the user's input.
        match &first_call[1] {
            ConversationMessage::User { content, .. } => {
                assert_eq!(content, "What is 2+2?");
            }
            other => panic!("expected User message second, got {other:?}"),
        }
    }
4566
4567 #[tokio::test]
4568 async fn no_system_prompt_omits_system_message() {
4569 let provider = StructuredProvider::new(vec![ChatResult {
4570 output_text: "ok".to_string(),
4571 stop_reason: Some(StopReason::EndTurn),
4572 ..Default::default()
4573 }]);
4574 let received = provider.received_messages.clone();
4575 let memory = TestMemory::default();
4576
4577 let agent = Agent::new(
4578 AgentConfig {
4579 model_supports_tool_use: true,
4580 system_prompt: None,
4581 ..AgentConfig::default()
4582 },
4583 Box::new(provider),
4584 Box::new(memory),
4585 vec![Box::new(StructuredEchoTool)],
4586 );
4587
4588 agent
4589 .respond(
4590 UserMessage {
4591 text: "hello".to_string(),
4592 },
4593 &test_ctx(),
4594 )
4595 .await
4596 .expect("respond should succeed");
4597
4598 let msgs = received.lock().expect("lock");
4599 let first_call = &msgs[0];
4600 match &first_call[0] {
4602 ConversationMessage::User { content, .. } => {
4603 assert_eq!(content, "hello");
4604 }
4605 other => panic!("expected User message first (no system prompt), got {other:?}"),
4606 }
4607 }
4608
    #[tokio::test]
    async fn system_prompt_persists_across_tool_iterations() {
        // Two provider calls (tool round + final answer): each must begin
        // with the same System message, not just the first.
        let provider = StructuredProvider::new(vec![
            ChatResult {
                output_text: String::new(),
                tool_calls: vec![ToolUseRequest {
                    id: "call_1".to_string(),
                    name: "echo".to_string(),
                    input: serde_json::json!({"text": "ping"}),
                }],
                stop_reason: Some(StopReason::ToolUse),
            },
            ChatResult {
                output_text: "pong".to_string(),
                stop_reason: Some(StopReason::EndTurn),
                ..Default::default()
            },
        ]);
        let received = provider.received_messages.clone();
        let memory = TestMemory::default();

        let agent = Agent::new(
            AgentConfig {
                model_supports_tool_use: true,
                system_prompt: Some("Always be concise.".to_string()),
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(StructuredEchoTool)],
        );

        let result = agent
            .respond(
                UserMessage {
                    text: "test".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");
        assert_eq!(result.text, "pong");

        // Every recorded call starts with the identical system prompt.
        let msgs = received.lock().expect("lock");
        for (i, call_msgs) in msgs.iter().enumerate() {
            match &call_msgs[0] {
                ConversationMessage::System { content } => {
                    assert_eq!(content, "Always be concise.", "call {i} system prompt");
                }
                other => panic!("call {i}: expected System first, got {other:?}"),
            }
        }
    }
4664
4665 #[test]
4666 fn agent_config_system_prompt_defaults_to_none() {
4667 let config = AgentConfig::default();
4668 assert!(config.system_prompt.is_none());
4669 }
4670
    #[tokio::test]
    async fn memory_write_propagates_source_channel_from_context() {
        // Both stored entries (user input and assistant reply) must carry the
        // context's source channel and privacy boundary.
        let memory = TestMemory::default();
        let entries = memory.entries.clone();
        let provider = StructuredProvider::new(vec![ChatResult {
            output_text: "noted".to_string(),
            tool_calls: vec![],
            stop_reason: None,
        }]);
        let config = AgentConfig {
            privacy_boundary: "encrypted_only".to_string(),
            ..AgentConfig::default()
        };
        let agent = Agent::new(config, Box::new(provider), Box::new(memory), vec![]);

        // Context tagged with a channel and a matching privacy boundary.
        let mut ctx = ToolContext::new(".".to_string());
        ctx.source_channel = Some("telegram".to_string());
        ctx.privacy_boundary = "encrypted_only".to_string();

        agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &ctx,
            )
            .await
            .expect("respond should succeed");

        let stored = entries.lock().expect("memory lock poisoned");
        assert_eq!(stored[0].source_channel.as_deref(), Some("telegram"));
        assert_eq!(stored[0].privacy_boundary, "encrypted_only");
        assert_eq!(stored[1].source_channel.as_deref(), Some("telegram"));
        assert_eq!(stored[1].privacy_boundary, "encrypted_only");
    }
4707
4708 #[tokio::test]
4709 async fn memory_write_source_channel_none_when_ctx_empty() {
4710 let memory = TestMemory::default();
4711 let entries = memory.entries.clone();
4712 let provider = StructuredProvider::new(vec![ChatResult {
4713 output_text: "ok".to_string(),
4714 tool_calls: vec![],
4715 stop_reason: None,
4716 }]);
4717 let agent = Agent::new(
4718 AgentConfig::default(),
4719 Box::new(provider),
4720 Box::new(memory),
4721 vec![],
4722 );
4723
4724 agent
4725 .respond(
4726 UserMessage {
4727 text: "hi".to_string(),
4728 },
4729 &test_ctx(),
4730 )
4731 .await
4732 .expect("respond should succeed");
4733
4734 let stored = entries.lock().expect("memory lock poisoned");
4735 assert!(stored[0].source_channel.is_none());
4736 assert!(stored[1].source_channel.is_none());
4737 }
4738}