1use crate::common::privacy_helpers::{is_network_tool, resolve_boundary};
2use crate::security::redaction::redact_text;
3use crate::types::{
4 AgentConfig, AgentError, AssistantMessage, AuditEvent, AuditSink, ConversationMessage,
5 HookEvent, HookFailureMode, HookRiskTier, HookSink, MemoryEntry, MemoryStore, MetricsSink,
6 Provider, ResearchTrigger, StopReason, StreamSink, Tool, ToolContext, ToolDefinition,
7 ToolResultMessage, UserMessage,
8};
9use crate::validation::validate_json;
10use serde_json::json;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::time::{Instant, SystemTime, UNIX_EPOCH};
13use tokio::time::{timeout, Duration};
14use tracing::{info, warn};
15
/// Process-wide monotonically increasing sequence; combined with a millisecond
/// timestamp in `Agent::next_request_id` to keep ids unique within one process.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(0);
17
/// Caps `input` to at most `max_chars` characters (counted as `char`s, not bytes).
///
/// When truncation is needed the *tail* of the input is kept — callers put the
/// most important text (the current user input) at the end of the prompt.
///
/// Returns the possibly-shortened string and a flag that is `true` iff
/// truncation occurred.
fn cap_prompt(input: &str, max_chars: usize) -> (String, bool) {
    let len = input.chars().count();
    if len <= max_chars {
        return (input.to_string(), false);
    }
    // Skip the leading overflow directly. The previous implementation built an
    // intermediate Vec via a double reverse just to keep the last `max_chars`
    // characters; `skip` yields the same tail without the extra allocation.
    let truncated: String = input.chars().skip(len - max_chars).collect();
    (truncated, true)
}
33
34fn build_provider_prompt(
35 current_prompt: &str,
36 recent_memory: &[MemoryEntry],
37 max_prompt_chars: usize,
38) -> (String, bool) {
39 if recent_memory.is_empty() {
40 return cap_prompt(current_prompt, max_prompt_chars);
41 }
42
43 let mut prompt = String::from("Recent memory:\n");
44 for entry in recent_memory.iter().rev() {
45 prompt.push_str("- ");
46 prompt.push_str(&entry.role);
47 prompt.push_str(": ");
48 prompt.push_str(&entry.content);
49 prompt.push('\n');
50 }
51 prompt.push_str("\nCurrent input:\n");
52 prompt.push_str(current_prompt);
53
54 cap_prompt(&prompt, max_prompt_chars)
55}
56
57fn parse_tool_calls(prompt: &str) -> Vec<(&str, &str)> {
59 prompt
60 .lines()
61 .filter_map(|line| {
62 line.strip_prefix("tool:").map(|rest| {
63 let rest = rest.trim();
64 let mut parts = rest.splitn(2, ' ');
65 let name = parts.next().unwrap_or_default();
66 let input = parts.next().unwrap_or_default();
67 (name, input)
68 })
69 })
70 .collect()
71}
72
73fn single_required_string_field(schema: &serde_json::Value) -> Option<String> {
75 let required = schema.get("required")?.as_array()?;
76 if required.len() != 1 {
77 return None;
78 }
79 let field_name = required[0].as_str()?;
80 let properties = schema.get("properties")?.as_object()?;
81 let field_schema = properties.get(field_name)?;
82 if field_schema.get("type")?.as_str()? == "string" {
83 Some(field_name.to_string())
84 } else {
85 None
86 }
87}
88
89fn prepare_tool_input(tool: &dyn Tool, raw_input: &serde_json::Value) -> Result<String, String> {
98 if let Some(s) = raw_input.as_str() {
100 return Ok(s.to_string());
101 }
102
103 if let Some(schema) = tool.input_schema() {
105 if let Err(errors) = validate_json(raw_input, &schema) {
106 return Err(format!(
107 "Invalid input for tool '{}': {}",
108 tool.name(),
109 errors.join("; ")
110 ));
111 }
112
113 if let Some(field) = single_required_string_field(&schema) {
114 if let Some(val) = raw_input.get(&field).and_then(|v| v.as_str()) {
115 return Ok(val.to_string());
116 }
117 }
118 }
119
120 Ok(serde_json::to_string(raw_input).unwrap_or_default())
121}
122
/// Shrinks `messages` in place so the total character count fits `max_chars`.
///
/// The first message (typically the system prompt) is always preserved;
/// messages are dropped from the oldest end after it. No-op when the
/// conversation already fits or has two messages or fewer.
fn truncate_messages(messages: &mut Vec<ConversationMessage>, max_chars: usize) {
    let total_chars: usize = messages.iter().map(|m| m.char_count()).sum();
    if total_chars <= max_chars || messages.len() <= 2 {
        return;
    }

    // Remaining budget for the tail after reserving room for the pinned
    // first message.
    let first_cost = messages[0].char_count();
    let mut budget = max_chars.saturating_sub(first_cost);

    // Walk backwards, keeping as many of the newest messages as fit.
    let mut keep_from_end = 0;
    for msg in messages.iter().rev() {
        let cost = msg.char_count();
        if cost > budget {
            break;
        }
        budget -= cost;
        keep_from_end += 1;
    }

    if keep_from_end == 0 {
        // Even the newest message alone exceeds the budget: fall back to
        // keeping just the first and last messages (may still exceed
        // `max_chars`, but never drops the current turn).
        let last = messages
            .pop()
            .expect("messages must have at least 2 entries");
        messages.truncate(1);
        messages.push(last);
        return;
    }

    let split_point = messages.len() - keep_from_end;
    if split_point <= 1 {
        // Everything after the first message already fits — nothing to drop.
        return;
    }

    // Drop the middle, keeping messages[0] and the fitting tail.
    messages.drain(1..split_point);
}
163
164fn memory_to_messages(entries: &[MemoryEntry]) -> Vec<ConversationMessage> {
167 entries
168 .iter()
169 .rev()
170 .map(|entry| {
171 if entry.role == "assistant" {
172 ConversationMessage::Assistant {
173 content: Some(entry.content.clone()),
174 tool_calls: vec![],
175 }
176 } else {
177 ConversationMessage::User {
178 content: entry.content.clone(),
179 }
180 }
181 })
182 .collect()
183}
184
/// Orchestrates a conversational run: provider completions, tool execution,
/// memory persistence, and optional audit/hook/metrics instrumentation.
pub struct Agent {
    config: AgentConfig,
    /// Language-model backend used for all completions.
    provider: Box<dyn Provider>,
    /// Conversation memory, scoped by privacy boundary and source channel.
    memory: Box<dyn MemoryStore>,
    /// Tools the model may invoke during a run.
    tools: Vec<Box<dyn Tool>>,
    /// Optional sinks; each is a no-op when unset.
    audit: Option<Box<dyn AuditSink>>,
    hooks: Option<Box<dyn HookSink>>,
    metrics: Option<Box<dyn MetricsSink>>,
}
194
195impl Agent {
196 pub fn new(
197 config: AgentConfig,
198 provider: Box<dyn Provider>,
199 memory: Box<dyn MemoryStore>,
200 tools: Vec<Box<dyn Tool>>,
201 ) -> Self {
202 Self {
203 config,
204 provider,
205 memory,
206 tools,
207 audit: None,
208 hooks: None,
209 metrics: None,
210 }
211 }
212
213 pub fn with_audit(mut self, audit: Box<dyn AuditSink>) -> Self {
214 self.audit = Some(audit);
215 self
216 }
217
218 pub fn with_hooks(mut self, hooks: Box<dyn HookSink>) -> Self {
219 self.hooks = Some(hooks);
220 self
221 }
222
223 pub fn with_metrics(mut self, metrics: Box<dyn MetricsSink>) -> Self {
224 self.metrics = Some(metrics);
225 self
226 }
227
228 fn build_tool_definitions(&self) -> Vec<ToolDefinition> {
230 self.tools
231 .iter()
232 .filter_map(|tool| ToolDefinition::from_tool(&**tool))
233 .collect()
234 }
235
236 fn has_tool_definitions(&self) -> bool {
238 self.tools.iter().any(|t| t.input_schema().is_some())
239 }
240
241 async fn audit(&self, stage: &str, detail: serde_json::Value) {
242 if let Some(sink) = &self.audit {
243 let _ = sink
244 .record(AuditEvent {
245 stage: stage.to_string(),
246 detail,
247 })
248 .await;
249 }
250 }
251
252 fn next_request_id() -> String {
253 let ts_ms = SystemTime::now()
254 .duration_since(UNIX_EPOCH)
255 .unwrap_or_default()
256 .as_millis();
257 let seq = REQUEST_COUNTER.fetch_add(1, Ordering::Relaxed);
258 format!("req-{ts_ms}-{seq}")
259 }
260
261 fn hook_risk_tier(stage: &str) -> HookRiskTier {
262 if matches!(
263 stage,
264 "before_tool_call" | "after_tool_call" | "before_plugin_call" | "after_plugin_call"
265 ) {
266 HookRiskTier::High
267 } else if matches!(
268 stage,
269 "before_provider_call"
270 | "after_provider_call"
271 | "before_memory_write"
272 | "after_memory_write"
273 | "before_run"
274 | "after_run"
275 ) {
276 HookRiskTier::Medium
277 } else {
278 HookRiskTier::Low
279 }
280 }
281
282 fn hook_failure_mode_for_stage(&self, stage: &str) -> HookFailureMode {
283 if self.config.hooks.fail_closed {
284 return HookFailureMode::Block;
285 }
286
287 match Self::hook_risk_tier(stage) {
288 HookRiskTier::Low => self.config.hooks.low_tier_mode,
289 HookRiskTier::Medium => self.config.hooks.medium_tier_mode,
290 HookRiskTier::High => self.config.hooks.high_tier_mode,
291 }
292 }
293
    /// Dispatches a lifecycle event to the hook sink, enforcing the per-stage
    /// failure policy and a `config.hooks.timeout_ms` execution bound.
    ///
    /// On hook error or timeout, the resolved [`HookFailureMode`] decides:
    /// `Block` propagates an `AgentError::Hook`, `Warn` logs + audits and
    /// continues, `Ignore` only audits and continues. No-op when hooks are
    /// disabled or no sink is attached.
    async fn hook(&self, stage: &str, detail: serde_json::Value) -> Result<(), AgentError> {
        if !self.config.hooks.enabled {
            return Ok(());
        }
        let Some(sink) = &self.hooks else {
            return Ok(());
        };

        let event = HookEvent {
            stage: stage.to_string(),
            detail,
        };
        let hook_call = sink.record(event);
        let mode = self.hook_failure_mode_for_stage(stage);
        let tier = Self::hook_risk_tier(stage);
        match timeout(
            Duration::from_millis(self.config.hooks.timeout_ms),
            hook_call,
        )
        .await
        {
            Ok(Ok(())) => Ok(()),
            Ok(Err(err)) => {
                // Redact before the error text reaches logs or audit records.
                let redacted = redact_text(&err.to_string());
                match mode {
                    HookFailureMode::Block => Err(AgentError::Hook {
                        stage: stage.to_string(),
                        source: err,
                    }),
                    HookFailureMode::Warn => {
                        warn!(
                            stage = stage,
                            tier = ?tier,
                            mode = "warn",
                            "hook error (continuing): {redacted}"
                        );
                        self.audit(
                            "hook_error_warn",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                    HookFailureMode::Ignore => {
                        self.audit(
                            "hook_error_ignored",
                            json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "error": redacted}),
                        )
                        .await;
                        Ok(())
                    }
                }
            }
            // Timeout path: same policy split as hook errors.
            Err(_) => match mode {
                HookFailureMode::Block => Err(AgentError::Hook {
                    stage: stage.to_string(),
                    source: anyhow::anyhow!(
                        "hook execution timed out after {} ms",
                        self.config.hooks.timeout_ms
                    ),
                }),
                HookFailureMode::Warn => {
                    warn!(
                        stage = stage,
                        tier = ?tier,
                        mode = "warn",
                        timeout_ms = self.config.hooks.timeout_ms,
                        "hook timeout (continuing)"
                    );
                    self.audit(
                        "hook_timeout_warn",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
                HookFailureMode::Ignore => {
                    self.audit(
                        "hook_timeout_ignored",
                        json!({"stage": stage, "tier": format!("{tier:?}").to_ascii_lowercase(), "timeout_ms": self.config.hooks.timeout_ms}),
                    )
                    .await;
                    Ok(())
                }
            },
        }
    }
381
382 fn increment_counter(&self, name: &'static str) {
383 if let Some(metrics) = &self.metrics {
384 metrics.increment_counter(name, 1);
385 }
386 }
387
388 fn observe_histogram(&self, name: &'static str, value: f64) {
389 if let Some(metrics) = &self.metrics {
390 metrics.observe_histogram(name, value);
391 }
392 }
393
    /// Runs one tool invocation with privacy-boundary enforcement, hook
    /// notifications, audit records, and latency metrics.
    ///
    /// # Errors
    /// `AgentError::Tool` when the resolved boundary forbids a network tool
    /// under `local_only`, or when the tool itself fails. Hook errors in
    /// blocking mode also propagate.
    async fn execute_tool(
        &self,
        tool: &dyn Tool,
        tool_name: &str,
        tool_input: &str,
        ctx: &ToolContext,
        request_id: &str,
        iteration: usize,
    ) -> Result<crate::types::ToolResult, AgentError> {
        // A per-tool boundary entry overrides the agent-wide privacy boundary.
        let tool_specific = self
            .config
            .tool_boundaries
            .get(tool_name)
            .map(|s| s.as_str())
            .unwrap_or("");
        let resolved = resolve_boundary(tool_specific, &self.config.privacy_boundary);
        if resolved == "local_only" && is_network_tool(tool_name) {
            return Err(AgentError::Tool {
                tool: tool_name.to_string(),
                source: anyhow::anyhow!(
                    "tool '{}' requires network access but privacy boundary is 'local_only'",
                    tool_name
                ),
            });
        }

        // Plugin tools get an extra pair of plugin-specific hook events
        // around the generic tool-call hooks.
        let is_plugin_call = tool_name.starts_with("plugin:");
        self.hook(
            "before_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "before_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name}),
            )
            .await?;
        }
        self.audit(
            "tool_execute_start",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
        )
        .await;
        let tool_started = Instant::now();
        let result = match tool.execute(tool_input, ctx).await {
            Ok(result) => result,
            Err(source) => {
                // Latency and the error counter are recorded even on failure.
                self.observe_histogram(
                    "tool_latency_ms",
                    tool_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("tool_errors_total");
                return Err(AgentError::Tool {
                    tool: tool_name.to_string(),
                    source,
                });
            }
        };
        self.observe_histogram(
            "tool_latency_ms",
            tool_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "tool_execute_success",
            json!({
                "request_id": request_id,
                "iteration": iteration,
                "tool_name": tool_name,
                "tool_output_len": result.output.len(),
                "duration_ms": tool_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "tool",
            tool_name = %tool_name,
            duration_ms = %tool_started.elapsed().as_millis(),
            "tool execution finished"
        );
        self.hook(
            "after_tool_call",
            json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name, "status": "ok"}),
        )
        .await?;
        if is_plugin_call {
            self.hook(
                "after_plugin_call",
                json!({"request_id": request_id, "iteration": iteration, "plugin_tool": tool_name, "status": "ok"}),
            )
            .await?;
        }
        Ok(result)
    }
491
    /// Builds a memory-augmented prompt and performs a single provider
    /// completion, streaming through `stream_sink` when one is supplied.
    ///
    /// Returns the provider's output text.
    ///
    /// # Errors
    /// `AgentError::Memory` if loading recent memory fails,
    /// `AgentError::Provider` if the completion fails, plus hook errors in
    /// blocking mode.
    async fn call_provider_with_context(
        &self,
        prompt: &str,
        request_id: &str,
        stream_sink: Option<StreamSink>,
        source_channel: Option<&str>,
    ) -> Result<String, AgentError> {
        // Load the recent memory window visible to this boundary/channel.
        let recent_memory = self
            .memory
            .recent_for_boundary(
                self.config.memory_window_size,
                &self.config.privacy_boundary,
                source_channel,
            )
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.audit(
            "memory_recent_loaded",
            json!({"request_id": request_id, "items": recent_memory.len()}),
        )
        .await;
        let (provider_prompt, prompt_truncated) =
            build_provider_prompt(prompt, &recent_memory, self.config.max_prompt_chars);
        if prompt_truncated {
            self.audit(
                "provider_prompt_truncated",
                json!({
                    "request_id": request_id,
                    "max_prompt_chars": self.config.max_prompt_chars,
                }),
            )
            .await;
        }
        self.hook(
            "before_provider_call",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await?;
        self.audit(
            "provider_call_start",
            json!({
                "request_id": request_id,
                "prompt_len": provider_prompt.len(),
                "memory_items": recent_memory.len(),
                "prompt_truncated": prompt_truncated
            }),
        )
        .await;
        let provider_started = Instant::now();
        // Streaming and non-streaming share the same metrics/error handling.
        let provider_result = if let Some(sink) = stream_sink {
            self.provider
                .complete_streaming(&provider_prompt, sink)
                .await
        } else {
            self.provider
                .complete_with_reasoning(&provider_prompt, &self.config.reasoning)
                .await
        };
        let completion = match provider_result {
            Ok(result) => result,
            Err(source) => {
                // Latency and the error counter are recorded even on failure.
                self.observe_histogram(
                    "provider_latency_ms",
                    provider_started.elapsed().as_secs_f64() * 1000.0,
                );
                self.increment_counter("provider_errors_total");
                return Err(AgentError::Provider { source });
            }
        };
        self.observe_histogram(
            "provider_latency_ms",
            provider_started.elapsed().as_secs_f64() * 1000.0,
        );
        self.audit(
            "provider_call_success",
            json!({
                "request_id": request_id,
                "response_len": completion.output_text.len(),
                "duration_ms": provider_started.elapsed().as_millis(),
            }),
        )
        .await;
        info!(
            request_id = %request_id,
            stage = "provider",
            duration_ms = %provider_started.elapsed().as_millis(),
            "provider call finished"
        );
        self.hook(
            "after_provider_call",
            json!({"request_id": request_id, "response_len": completion.output_text.len(), "status": "ok"}),
        )
        .await?;
        Ok(completion.output_text)
    }
592
    /// Appends one entry to the memory store, bracketed by
    /// `before_memory_write` / `after_memory_write` hooks and an audit record.
    ///
    /// # Errors
    /// `AgentError::Memory` if the append fails; hook errors in blocking mode.
    async fn write_to_memory(
        &self,
        role: &str,
        content: &str,
        request_id: &str,
        source_channel: Option<&str>,
    ) -> Result<(), AgentError> {
        self.hook(
            "before_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        self.memory
            .append(MemoryEntry {
                role: role.to_string(),
                content: content.to_string(),
                privacy_boundary: self.config.privacy_boundary.clone(),
                source_channel: source_channel.map(String::from),
            })
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.hook(
            "after_memory_write",
            json!({"request_id": request_id, "role": role}),
        )
        .await?;
        // Audit stage name embeds the role, e.g. "memory_append_assistant".
        self.audit(
            &format!("memory_append_{role}"),
            json!({"request_id": request_id}),
        )
        .await;
        Ok(())
    }
626
627 fn should_research(&self, user_text: &str) -> bool {
628 if !self.config.research.enabled {
629 return false;
630 }
631 match self.config.research.trigger {
632 ResearchTrigger::Never => false,
633 ResearchTrigger::Always => true,
634 ResearchTrigger::Keywords => {
635 let lower = user_text.to_lowercase();
636 self.config
637 .research
638 .keywords
639 .iter()
640 .any(|kw| lower.contains(&kw.to_lowercase()))
641 }
642 ResearchTrigger::Length => user_text.len() >= self.config.research.min_message_length,
643 ResearchTrigger::Question => user_text.trim_end().ends_with('?'),
644 }
645 }
646
    /// Runs the iterative research loop: the provider is prompted to emit
    /// `tool:` directives, each directive is executed, and the result is fed
    /// back until the provider replies without a directive (its summary) or
    /// `research.max_iterations` is reached.
    ///
    /// Returns the accumulated findings joined by newlines.
    async fn run_research_phase(
        &self,
        user_text: &str,
        ctx: &ToolContext,
        request_id: &str,
    ) -> Result<String, AgentError> {
        self.audit(
            "research_phase_start",
            json!({
                "request_id": request_id,
                "max_iterations": self.config.research.max_iterations,
            }),
        )
        .await;
        self.increment_counter("research_phase_started");

        let research_prompt = format!(
            "You are in RESEARCH mode. The user asked: \"{user_text}\"\n\
             Gather relevant information using available tools. \
             Respond with tool: calls to collect data. \
             When done gathering, summarize your findings without a tool: prefix."
        );
        // Seed the loop with the provider's first reaction to the brief.
        let mut prompt = self
            .call_provider_with_context(
                &research_prompt,
                request_id,
                None,
                ctx.source_channel.as_deref(),
            )
            .await?;
        let mut findings: Vec<String> = Vec::new();

        for iteration in 0..self.config.research.max_iterations {
            // A response that is not a tool directive is the final summary.
            if !prompt.starts_with("tool:") {
                findings.push(prompt.clone());
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Only the first directive of each response is executed.
            let (name, input) = calls[0];
            if let Some(tool) = self.tools.iter().find(|t| t.name() == name) {
                let result = self
                    .execute_tool(&**tool, name, input, ctx, request_id, iteration)
                    .await?;
                findings.push(format!("{name}: {}", result.output));

                if self.config.research.show_progress {
                    info!(iteration, tool = name, "research phase: tool executed");
                    self.audit(
                        "research_phase_iteration",
                        json!({
                            "request_id": request_id,
                            "iteration": iteration,
                            "tool_name": name,
                        }),
                    )
                    .await;
                }

                let next_prompt = format!(
                    "Research iteration {iteration}: tool `{name}` returned: {}\n\
                     Continue researching or summarize findings.",
                    result.output
                );
                prompt = self
                    .call_provider_with_context(
                        &next_prompt,
                        request_id,
                        None,
                        ctx.source_channel.as_deref(),
                    )
                    .await?;
            } else {
                // Unknown tool requested: stop with whatever was gathered.
                break;
            }
        }

        self.audit(
            "research_phase_complete",
            json!({
                "request_id": request_id,
                "findings_count": findings.len(),
            }),
        )
        .await;
        self.increment_counter("research_phase_completed");

        Ok(findings.join("\n"))
    }
740
    /// Drives the structured tool-use loop: sends the conversation plus tool
    /// definitions to the provider, executes requested tool calls (in parallel
    /// when configured and safe), applies three loop-detection heuristics, and
    /// returns the final assistant message once the provider stops calling
    /// tools or the iteration budget is exhausted.
    ///
    /// The final text is persisted to memory before being returned.
    async fn respond_with_tools(
        &self,
        request_id: &str,
        user_text: &str,
        research_context: &str,
        ctx: &ToolContext,
        stream_sink: Option<StreamSink>,
    ) -> Result<AssistantMessage, AgentError> {
        self.audit(
            "structured_tool_use_start",
            json!({
                "request_id": request_id,
                "max_tool_iterations": self.config.max_tool_iterations,
            }),
        )
        .await;

        let tool_definitions = self.build_tool_definitions();

        let recent_memory = self
            .memory
            .recent_for_boundary(
                self.config.memory_window_size,
                &self.config.privacy_boundary,
                ctx.source_channel.as_deref(),
            )
            .await
            .map_err(|source| AgentError::Memory { source })?;
        self.audit(
            "memory_recent_loaded",
            json!({"request_id": request_id, "items": recent_memory.len()}),
        )
        .await;

        // Conversation layout: [system?] [memory...] [research?] [user].
        let mut messages: Vec<ConversationMessage> = Vec::new();

        if let Some(ref sp) = self.config.system_prompt {
            messages.push(ConversationMessage::System {
                content: sp.clone(),
            });
        }

        messages.extend(memory_to_messages(&recent_memory));

        if !research_context.is_empty() {
            messages.push(ConversationMessage::User {
                content: format!("Research findings:\n{research_context}"),
            });
        }

        messages.push(ConversationMessage::User {
            content: user_text.to_string(),
        });

        // (name, input, output) triples of *successful* calls, consumed by the
        // loop-detection heuristics below.
        let mut tool_history: Vec<(String, String, String)> = Vec::new();
        let mut failure_streak: usize = 0;

        for iteration in 0..self.config.max_tool_iterations {
            truncate_messages(&mut messages, self.config.max_prompt_chars);

            self.hook(
                "before_provider_call",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "message_count": messages.len(),
                    "tool_count": tool_definitions.len(),
                }),
            )
            .await?;
            self.audit(
                "provider_call_start",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "message_count": messages.len(),
                }),
            )
            .await;

            let provider_started = Instant::now();
            let provider_result = if let Some(ref sink) = stream_sink {
                self.provider
                    .complete_streaming_with_tools(
                        &messages,
                        &tool_definitions,
                        &self.config.reasoning,
                        sink.clone(),
                    )
                    .await
            } else {
                self.provider
                    .complete_with_tools(&messages, &tool_definitions, &self.config.reasoning)
                    .await
            };
            let chat_result = match provider_result {
                Ok(result) => result,
                Err(source) => {
                    self.observe_histogram(
                        "provider_latency_ms",
                        provider_started.elapsed().as_secs_f64() * 1000.0,
                    );
                    self.increment_counter("provider_errors_total");
                    return Err(AgentError::Provider { source });
                }
            };
            self.observe_histogram(
                "provider_latency_ms",
                provider_started.elapsed().as_secs_f64() * 1000.0,
            );
            info!(
                request_id = %request_id,
                iteration = iteration,
                stop_reason = ?chat_result.stop_reason,
                tool_calls = chat_result.tool_calls.len(),
                "structured provider call finished"
            );
            self.hook(
                "after_provider_call",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "response_len": chat_result.output_text.len(),
                    "tool_calls": chat_result.tool_calls.len(),
                    "status": "ok",
                }),
            )
            .await?;

            // No tool calls (or an explicit end-turn) means the model is done:
            // persist and emit the final answer.
            if chat_result.tool_calls.is_empty()
                || chat_result.stop_reason == Some(StopReason::EndTurn)
            {
                let response_text = chat_result.output_text;
                self.write_to_memory(
                    "assistant",
                    &response_text,
                    request_id,
                    ctx.source_channel.as_deref(),
                )
                .await?;
                self.audit("respond_success", json!({"request_id": request_id}))
                    .await;
                self.hook(
                    "before_response_emit",
                    json!({"request_id": request_id, "response_len": response_text.len()}),
                )
                .await?;
                let response = AssistantMessage {
                    text: response_text,
                };
                self.hook(
                    "after_response_emit",
                    json!({"request_id": request_id, "response_len": response.text.len()}),
                )
                .await?;
                return Ok(response);
            }

            messages.push(ConversationMessage::Assistant {
                content: if chat_result.output_text.is_empty() {
                    None
                } else {
                    Some(chat_result.output_text.clone())
                },
                tool_calls: chat_result.tool_calls.clone(),
            });

            let tool_calls = &chat_result.tool_calls;
            // Gated tools force sequential execution so each call is handled
            // in order through `execute_tool`.
            let has_gated = tool_calls
                .iter()
                .any(|tc| self.config.gated_tools.contains(&tc.name));
            let use_parallel = self.config.parallel_tools && tool_calls.len() > 1 && !has_gated;

            let mut tool_results: Vec<ToolResultMessage> = Vec::new();

            if use_parallel {
                // NOTE(review): the parallel path calls `tool.execute` directly
                // and skips the boundary check, hooks, audits, and metrics that
                // `execute_tool` performs on the sequential path — confirm this
                // asymmetry is intentional.
                let futs: Vec<_> = tool_calls
                    .iter()
                    .map(|tc| {
                        let tool = self.tools.iter().find(|t| t.name() == tc.name);
                        async move {
                            match tool {
                                Some(tool) => {
                                    let input_str = match prepare_tool_input(&**tool, &tc.input) {
                                        Ok(s) => s,
                                        Err(validation_err) => {
                                            return (
                                                tc.name.clone(),
                                                String::new(),
                                                ToolResultMessage {
                                                    tool_use_id: tc.id.clone(),
                                                    content: validation_err,
                                                    is_error: true,
                                                },
                                            );
                                        }
                                    };
                                    match tool.execute(&input_str, ctx).await {
                                        Ok(result) => (
                                            tc.name.clone(),
                                            input_str,
                                            ToolResultMessage {
                                                tool_use_id: tc.id.clone(),
                                                content: result.output,
                                                is_error: false,
                                            },
                                        ),
                                        Err(e) => (
                                            tc.name.clone(),
                                            input_str,
                                            ToolResultMessage {
                                                tool_use_id: tc.id.clone(),
                                                content: format!("Error: {e}"),
                                                is_error: true,
                                            },
                                        ),
                                    }
                                }
                                None => (
                                    tc.name.clone(),
                                    String::new(),
                                    ToolResultMessage {
                                        tool_use_id: tc.id.clone(),
                                        content: format!("Tool '{}' not found", tc.name),
                                        is_error: true,
                                    },
                                ),
                            }
                        }
                    })
                    .collect();
                let results = futures_util::future::join_all(futs).await;
                for (name, input_str, result_msg) in results {
                    if result_msg.is_error {
                        failure_streak += 1;
                    } else {
                        failure_streak = 0;
                        tool_history.push((name, input_str, result_msg.content.clone()));
                    }
                    tool_results.push(result_msg);
                }
            } else {
                for tc in tool_calls {
                    let result_msg = match self.tools.iter().find(|t| t.name() == tc.name) {
                        Some(tool) => {
                            let input_str = match prepare_tool_input(&**tool, &tc.input) {
                                Ok(s) => s,
                                Err(validation_err) => {
                                    // Validation failure counts toward the
                                    // failure streak but skips execution.
                                    failure_streak += 1;
                                    tool_results.push(ToolResultMessage {
                                        tool_use_id: tc.id.clone(),
                                        content: validation_err,
                                        is_error: true,
                                    });
                                    continue;
                                }
                            };
                            self.audit(
                                "tool_requested",
                                json!({
                                    "request_id": request_id,
                                    "iteration": iteration,
                                    "tool_name": tc.name,
                                    "tool_input_len": input_str.len(),
                                }),
                            )
                            .await;

                            match self
                                .execute_tool(
                                    &**tool, &tc.name, &input_str, ctx, request_id, iteration,
                                )
                                .await
                            {
                                Ok(result) => {
                                    failure_streak = 0;
                                    tool_history.push((
                                        tc.name.clone(),
                                        input_str,
                                        result.output.clone(),
                                    ));
                                    ToolResultMessage {
                                        tool_use_id: tc.id.clone(),
                                        content: result.output,
                                        is_error: false,
                                    }
                                }
                                Err(e) => {
                                    // Tool errors are surfaced to the model as
                                    // error results rather than aborting the run.
                                    failure_streak += 1;
                                    ToolResultMessage {
                                        tool_use_id: tc.id.clone(),
                                        content: format!("Error: {e}"),
                                        is_error: true,
                                    }
                                }
                            }
                        }
                        None => {
                            self.audit(
                                "tool_not_found",
                                json!({
                                    "request_id": request_id,
                                    "iteration": iteration,
                                    "tool_name": tc.name,
                                }),
                            )
                            .await;
                            failure_streak += 1;
                            ToolResultMessage {
                                tool_use_id: tc.id.clone(),
                                content: format!(
                                    "Tool '{}' not found. Available tools: {}",
                                    tc.name,
                                    tool_definitions
                                        .iter()
                                        .map(|d| d.name.as_str())
                                        .collect::<Vec<_>>()
                                        .join(", ")
                                ),
                                is_error: true,
                            }
                        }
                    };
                    tool_results.push(result_msg);
                }
            }

            for result in &tool_results {
                messages.push(ConversationMessage::ToolResult(result.clone()));
            }

            // Loop detection 1: the same tool called repeatedly with identical
            // input and output — inject a notice telling the model to stop.
            if self.config.loop_detection_no_progress_threshold > 0 {
                let threshold = self.config.loop_detection_no_progress_threshold;
                if tool_history.len() >= threshold {
                    let recent = &tool_history[tool_history.len() - threshold..];
                    let all_same = recent.iter().all(|entry| {
                        entry.0 == recent[0].0 && entry.1 == recent[0].1 && entry.2 == recent[0].2
                    });
                    if all_same {
                        warn!(
                            tool_name = recent[0].0,
                            threshold, "structured loop detection: no-progress threshold reached"
                        );
                        self.audit(
                            "loop_detection_no_progress",
                            json!({
                                "request_id": request_id,
                                "tool_name": recent[0].0,
                                "threshold": threshold,
                            }),
                        )
                        .await;
                        messages.push(ConversationMessage::User {
                            content: format!(
                                "SYSTEM NOTICE: Loop detected — the tool `{}` was called \
                                 {} times with identical arguments and output. You are stuck \
                                 in a loop. Stop calling this tool and try a different approach, \
                                 or provide your best answer with the information you already have.",
                                recent[0].0, threshold
                            ),
                        });
                    }
                }
            }

            // Loop detection 2: two tools alternating A,B,A,B for N cycles.
            if self.config.loop_detection_ping_pong_cycles > 0 {
                let cycles = self.config.loop_detection_ping_pong_cycles;
                let needed = cycles * 2;
                if tool_history.len() >= needed {
                    let recent = &tool_history[tool_history.len() - needed..];
                    let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
                        && recent[0].0 != recent[1].0;
                    if is_ping_pong {
                        warn!(
                            tool_a = recent[0].0,
                            tool_b = recent[1].0,
                            cycles,
                            "structured loop detection: ping-pong pattern detected"
                        );
                        self.audit(
                            "loop_detection_ping_pong",
                            json!({
                                "request_id": request_id,
                                "tool_a": recent[0].0,
                                "tool_b": recent[1].0,
                                "cycles": cycles,
                            }),
                        )
                        .await;
                        messages.push(ConversationMessage::User {
                            content: format!(
                                "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
                                 alternating for {} cycles in a ping-pong pattern. Stop this \
                                 alternation and try a different approach, or provide your best \
                                 answer with the information you already have.",
                                recent[0].0, recent[1].0, cycles
                            ),
                        });
                    }
                }
            }

            // Loop detection 3: too many consecutive failed tool calls.
            if self.config.loop_detection_failure_streak > 0
                && failure_streak >= self.config.loop_detection_failure_streak
            {
                warn!(
                    failure_streak,
                    "structured loop detection: consecutive failure streak"
                );
                self.audit(
                    "loop_detection_failure_streak",
                    json!({
                        "request_id": request_id,
                        "streak": failure_streak,
                    }),
                )
                .await;
                messages.push(ConversationMessage::User {
                    content: format!(
                        "SYSTEM NOTICE: {} consecutive tool calls have failed. \
                         Stop calling tools and provide your best answer with \
                         the information you already have.",
                        failure_streak
                    ),
                });
            }
        }

        // Iteration budget exhausted: ask for a final answer with no tools
        // offered, so the provider cannot keep calling them.
        self.audit(
            "structured_tool_use_max_iterations",
            json!({
                "request_id": request_id,
                "max_tool_iterations": self.config.max_tool_iterations,
            }),
        )
        .await;

        messages.push(ConversationMessage::User {
            content: "You have reached the maximum number of tool call iterations. \
                      Please provide your best answer now without calling any more tools."
                .to_string(),
        });
        truncate_messages(&mut messages, self.config.max_prompt_chars);

        let final_result = if let Some(ref sink) = stream_sink {
            self.provider
                .complete_streaming_with_tools(&messages, &[], &self.config.reasoning, sink.clone())
                .await
                .map_err(|source| AgentError::Provider { source })?
        } else {
            self.provider
                .complete_with_tools(&messages, &[], &self.config.reasoning)
                .await
                .map_err(|source| AgentError::Provider { source })?
        };

        let response_text = final_result.output_text;
        self.write_to_memory(
            "assistant",
            &response_text,
            request_id,
            ctx.source_channel.as_deref(),
        )
        .await?;
        self.audit("respond_success", json!({"request_id": request_id}))
            .await;
        self.hook(
            "before_response_emit",
            json!({"request_id": request_id, "response_len": response_text.len()}),
        )
        .await?;
        let response = AssistantMessage {
            text: response_text,
        };
        self.hook(
            "after_response_emit",
            json!({"request_id": request_id, "response_len": response.text.len()}),
        )
        .await?;
        Ok(response)
    }
1235
    /// Entry point for a non-streaming run: mints a request id, applies the
    /// overall request timeout, and brackets the run with
    /// `before_run` / `after_run` hooks.
    ///
    /// # Errors
    /// `AgentError::Timeout` when the whole run exceeds
    /// `config.request_timeout_ms`; otherwise whatever `respond_inner`
    /// returns, or a blocking hook error.
    pub async fn respond(
        &self,
        user: UserMessage,
        ctx: &ToolContext,
    ) -> Result<AssistantMessage, AgentError> {
        let request_id = Self::next_request_id();
        self.increment_counter("requests_total");
        let run_started = Instant::now();
        self.hook("before_run", json!({"request_id": request_id}))
            .await?;
        let timed = timeout(
            Duration::from_millis(self.config.request_timeout_ms),
            self.respond_inner(&request_id, user, ctx, None),
        )
        .await;
        let result = match timed {
            Ok(result) => result,
            Err(_) => Err(AgentError::Timeout {
                timeout_ms: self.config.request_timeout_ms,
            }),
        };

        // Error text is redacted before it reaches the hook payload.
        let after_detail = match &result {
            Ok(response) => json!({
                "request_id": request_id,
                "status": "ok",
                "response_len": response.text.len(),
                "duration_ms": run_started.elapsed().as_millis(),
            }),
            Err(err) => json!({
                "request_id": request_id,
                "status": "error",
                "error": redact_text(&err.to_string()),
                "duration_ms": run_started.elapsed().as_millis(),
            }),
        };
        info!(
            request_id = %request_id,
            duration_ms = %run_started.elapsed().as_millis(),
            "agent run completed"
        );
        // NOTE(review): a blocking `after_run` hook failure replaces `result`
        // — confirm that is the intended precedence.
        self.hook("after_run", after_detail).await?;
        result
    }
1280
1281 pub async fn respond_streaming(
1285 &self,
1286 user: UserMessage,
1287 ctx: &ToolContext,
1288 sink: StreamSink,
1289 ) -> Result<AssistantMessage, AgentError> {
1290 let request_id = Self::next_request_id();
1291 self.increment_counter("requests_total");
1292 let run_started = Instant::now();
1293 self.hook("before_run", json!({"request_id": request_id}))
1294 .await?;
1295 let timed = timeout(
1296 Duration::from_millis(self.config.request_timeout_ms),
1297 self.respond_inner(&request_id, user, ctx, Some(sink)),
1298 )
1299 .await;
1300 let result = match timed {
1301 Ok(result) => result,
1302 Err(_) => Err(AgentError::Timeout {
1303 timeout_ms: self.config.request_timeout_ms,
1304 }),
1305 };
1306
1307 let after_detail = match &result {
1308 Ok(response) => json!({
1309 "request_id": request_id,
1310 "status": "ok",
1311 "response_len": response.text.len(),
1312 "duration_ms": run_started.elapsed().as_millis(),
1313 }),
1314 Err(err) => json!({
1315 "request_id": request_id,
1316 "status": "error",
1317 "error": redact_text(&err.to_string()),
1318 "duration_ms": run_started.elapsed().as_millis(),
1319 }),
1320 };
1321 info!(
1322 request_id = %request_id,
1323 duration_ms = %run_started.elapsed().as_millis(),
1324 "streaming agent run completed"
1325 );
1326 self.hook("after_run", after_detail).await?;
1327 result
1328 }
1329
    /// Core single-turn pipeline shared by `respond` and `respond_streaming`.
    ///
    /// Order of operations: audit the request, persist the user message to
    /// memory, optionally run a research phase, then either delegate to the
    /// provider-native tool-use path or drive the legacy `tool:`-prefixed
    /// protocol loop (with loop detection) before the final provider call.
    async fn respond_inner(
        &self,
        request_id: &str,
        user: UserMessage,
        ctx: &ToolContext,
        stream_sink: Option<StreamSink>,
    ) -> Result<AssistantMessage, AgentError> {
        self.audit(
            "respond_start",
            json!({
                "request_id": request_id,
                "user_message_len": user.text.len(),
                "max_tool_iterations": self.config.max_tool_iterations,
                "request_timeout_ms": self.config.request_timeout_ms,
            }),
        )
        .await;
        self.write_to_memory(
            "user",
            &user.text,
            request_id,
            ctx.source_channel.as_deref(),
        )
        .await?;

        // Research runs before any tool work so findings can be prepended to
        // the final provider prompt below.
        let research_context = if self.should_research(&user.text) {
            self.run_research_phase(&user.text, ctx, request_id).await?
        } else {
            String::new()
        };

        // Prefer the provider's native tool-use protocol when both the model
        // and this agent's tool definitions support it.
        if self.config.model_supports_tool_use && self.has_tool_definitions() {
            return self
                .respond_with_tools(request_id, &user.text, &research_context, ctx, stream_sink)
                .await;
        }

        let mut prompt = user.text;
        // (tool_name, tool_input, tool_output) triples; consumed by the
        // no-progress and ping-pong loop detectors below.
        let mut tool_history: Vec<(String, String, String)> = Vec::new();

        for iteration in 0..self.config.max_tool_iterations {
            // Legacy protocol: only a prompt starting with "tool:" requests
            // tool execution; anything else goes straight to the provider.
            if !prompt.starts_with("tool:") {
                break;
            }

            let calls = parse_tool_calls(&prompt);
            if calls.is_empty() {
                break;
            }

            // Gated tools must take the sequential path (which applies the
            // approval gate), so their presence disables parallel fan-out.
            let has_gated = calls
                .iter()
                .any(|(name, _)| self.config.gated_tools.contains(*name));
            if calls.len() > 1 && self.config.parallel_tools && !has_gated {
                let mut resolved: Vec<(&str, &str, &dyn Tool)> = Vec::new();
                for &(name, input) in &calls {
                    let tool = self.tools.iter().find(|t| t.name() == name);
                    match tool {
                        Some(t) => resolved.push((name, input, &**t)),
                        None => {
                            // Unknown tools are audited and skipped; the rest
                            // of the batch still runs.
                            self.audit(
                                "tool_not_found",
                                json!({"request_id": request_id, "iteration": iteration, "tool_name": name}),
                            )
                            .await;
                        }
                    }
                }
                if resolved.is_empty() {
                    break;
                }

                // Run the whole batch concurrently and collect in order.
                let futs: Vec<_> = resolved
                    .iter()
                    .map(|&(name, input, tool)| async move {
                        let r = tool.execute(input, ctx).await;
                        (name, input, r)
                    })
                    .collect();
                let results = futures_util::future::join_all(futs).await;

                let mut output_parts: Vec<String> = Vec::new();
                for (name, input, result) in results {
                    match result {
                        Ok(r) => {
                            tool_history.push((
                                name.to_string(),
                                input.to_string(),
                                r.output.clone(),
                            ));
                            output_parts.push(format!("Tool output from {name}: {}", r.output));
                        }
                        Err(source) => {
                            // First failure aborts the turn with a typed error.
                            return Err(AgentError::Tool {
                                tool: name.to_string(),
                                source,
                            });
                        }
                    }
                }
                prompt = output_parts.join("\n");
                continue;
            }

            // Sequential path: only the first parsed call is executed;
            // any additional calls on this prompt are dropped.
            let (tool_name, tool_input) = calls[0];
            self.audit(
                "tool_requested",
                json!({
                    "request_id": request_id,
                    "iteration": iteration,
                    "tool_name": tool_name,
                    "tool_input_len": tool_input.len(),
                }),
            )
            .await;

            if let Some(tool) = self.tools.iter().find(|t| t.name() == tool_name) {
                let result = self
                    .execute_tool(&**tool, tool_name, tool_input, ctx, request_id, iteration)
                    .await?;

                tool_history.push((
                    tool_name.to_string(),
                    tool_input.to_string(),
                    result.output.clone(),
                ));

                // No-progress detector: the last `threshold` calls were the
                // same tool with identical input AND output.
                if self.config.loop_detection_no_progress_threshold > 0 {
                    let threshold = self.config.loop_detection_no_progress_threshold;
                    if tool_history.len() >= threshold {
                        let recent = &tool_history[tool_history.len() - threshold..];
                        let all_same = recent.iter().all(|entry| {
                            entry.0 == recent[0].0
                                && entry.1 == recent[0].1
                                && entry.2 == recent[0].2
                        });
                        if all_same {
                            warn!(
                                tool_name,
                                threshold, "loop detection: no-progress threshold reached"
                            );
                            self.audit(
                                "loop_detection_no_progress",
                                json!({
                                    "request_id": request_id,
                                    "tool_name": tool_name,
                                    "threshold": threshold,
                                }),
                            )
                            .await;
                            // Replace the prompt with a self-correction notice
                            // and fall through to the final provider call.
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — the tool `{tool_name}` was called \
                                 {threshold} times with identical arguments and output. You are stuck \
                                 in a loop. Stop calling this tool and try a different approach, or \
                                 provide your best answer with the information you already have."
                            );
                            break;
                        }
                    }
                }

                // Ping-pong detector: the last `cycles * 2` calls strictly
                // alternate between two distinct tools (A,B,A,B,...).
                if self.config.loop_detection_ping_pong_cycles > 0 {
                    let cycles = self.config.loop_detection_ping_pong_cycles;
                    let needed = cycles * 2;
                    if tool_history.len() >= needed {
                        let recent = &tool_history[tool_history.len() - needed..];
                        let is_ping_pong = (0..needed).all(|i| recent[i].0 == recent[i % 2].0)
                            && recent[0].0 != recent[1].0;
                        if is_ping_pong {
                            warn!(
                                tool_a = recent[0].0,
                                tool_b = recent[1].0,
                                cycles,
                                "loop detection: ping-pong pattern detected"
                            );
                            self.audit(
                                "loop_detection_ping_pong",
                                json!({
                                    "request_id": request_id,
                                    "tool_a": recent[0].0,
                                    "tool_b": recent[1].0,
                                    "cycles": cycles,
                                }),
                            )
                            .await;
                            prompt = format!(
                                "SYSTEM NOTICE: Loop detected — tools `{}` and `{}` have been \
                                 alternating for {} cycles in a ping-pong pattern. Stop this \
                                 alternation and try a different approach, or provide your best \
                                 answer with the information you already have.",
                                recent[0].0, recent[1].0, cycles
                            );
                            break;
                        }
                    }
                }

                // A tool may chain another tool call by emitting a "tool:"
                // prompt of its own; otherwise wrap its output for the model.
                if result.output.starts_with("tool:") {
                    prompt = result.output.clone();
                } else {
                    prompt = format!("Tool output from {tool_name}: {}", result.output);
                }
                continue;
            }

            // Unknown tool on the sequential path: audit and give up on the
            // tool loop, leaving the raw prompt for the provider.
            self.audit(
                "tool_not_found",
                json!({"request_id": request_id, "iteration": iteration, "tool_name": tool_name}),
            )
            .await;
            break;
        }

        if !research_context.is_empty() {
            prompt = format!("Research findings:\n{research_context}\n\nUser request:\n{prompt}");
        }

        let response_text = self
            .call_provider_with_context(
                &prompt,
                request_id,
                stream_sink,
                ctx.source_channel.as_deref(),
            )
            .await?;
        self.write_to_memory(
            "assistant",
            &response_text,
            request_id,
            ctx.source_channel.as_deref(),
        )
        .await?;
        self.audit("respond_success", json!({"request_id": request_id}))
            .await;

        self.hook(
            "before_response_emit",
            json!({"request_id": request_id, "response_len": response_text.len()}),
        )
        .await?;
        let response = AssistantMessage {
            text: response_text,
        };
        self.hook(
            "after_response_emit",
            json!({"request_id": request_id, "response_len": response.text.len()}),
        )
        .await?;

        Ok(response)
    }
1588}
1589
1590#[cfg(test)]
1591mod tests {
1592 use super::*;
1593 use crate::types::{
1594 ChatResult, HookEvent, HookFailureMode, HookPolicy, ReasoningConfig, ResearchPolicy,
1595 ResearchTrigger, ToolResult,
1596 };
1597 use async_trait::async_trait;
1598 use serde_json::Value;
1599 use std::collections::HashMap;
1600 use std::sync::atomic::{AtomicUsize, Ordering};
1601 use std::sync::{Arc, Mutex};
1602 use tokio::time::{sleep, Duration};
1603
    /// In-memory `MemoryStore` whose backing vector is shared with the test
    /// via `Arc`, so entries written by the agent can be asserted on directly.
    #[derive(Default)]
    struct TestMemory {
        entries: Arc<Mutex<Vec<MemoryEntry>>>,
    }

    #[async_trait]
    impl MemoryStore for TestMemory {
        async fn append(&self, entry: MemoryEntry) -> anyhow::Result<()> {
            self.entries
                .lock()
                .expect("memory lock poisoned")
                .push(entry);
            Ok(())
        }

        async fn recent(&self, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
            // Newest-first, which is the ordering build_provider_prompt reverses.
            let entries = self.entries.lock().expect("memory lock poisoned");
            Ok(entries.iter().rev().take(limit).cloned().collect())
        }
    }
1624
    /// Provider returning a fixed response while capturing every prompt it
    /// receives, so tests can assert on prompt assembly.
    struct TestProvider {
        received_prompts: Arc<Mutex<Vec<String>>>,
        response_text: String,
    }

    #[async_trait]
    impl Provider for TestProvider {
        async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
            // Record the exact prompt the agent built before replying.
            self.received_prompts
                .lock()
                .expect("provider lock poisoned")
                .push(prompt.to_string());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }
    }
1643
1644 struct EchoTool;
1645
1646 #[async_trait]
1647 impl Tool for EchoTool {
1648 fn name(&self) -> &'static str {
1649 "echo"
1650 }
1651
1652 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1653 Ok(ToolResult {
1654 output: format!("echoed:{input}"),
1655 })
1656 }
1657 }
1658
1659 struct PluginEchoTool;
1660
1661 #[async_trait]
1662 impl Tool for PluginEchoTool {
1663 fn name(&self) -> &'static str {
1664 "plugin:echo"
1665 }
1666
1667 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1668 Ok(ToolResult {
1669 output: format!("plugin-echoed:{input}"),
1670 })
1671 }
1672 }
1673
1674 struct FailingTool;
1675
1676 #[async_trait]
1677 impl Tool for FailingTool {
1678 fn name(&self) -> &'static str {
1679 "boom"
1680 }
1681
1682 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1683 Err(anyhow::anyhow!("tool exploded"))
1684 }
1685 }
1686
1687 struct FailingProvider;
1688
1689 #[async_trait]
1690 impl Provider for FailingProvider {
1691 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
1692 Err(anyhow::anyhow!("provider boom"))
1693 }
1694 }
1695
1696 struct SlowProvider;
1697
1698 #[async_trait]
1699 impl Provider for SlowProvider {
1700 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
1701 sleep(Duration::from_millis(500)).await;
1702 Ok(ChatResult {
1703 output_text: "late".to_string(),
1704 ..Default::default()
1705 })
1706 }
1707 }
1708
1709 struct ScriptedProvider {
1710 responses: Vec<String>,
1711 call_count: AtomicUsize,
1712 received_prompts: Arc<Mutex<Vec<String>>>,
1713 }
1714
1715 impl ScriptedProvider {
1716 fn new(responses: Vec<&str>) -> Self {
1717 Self {
1718 responses: responses.into_iter().map(|s| s.to_string()).collect(),
1719 call_count: AtomicUsize::new(0),
1720 received_prompts: Arc::new(Mutex::new(Vec::new())),
1721 }
1722 }
1723 }
1724
1725 #[async_trait]
1726 impl Provider for ScriptedProvider {
1727 async fn complete(&self, prompt: &str) -> anyhow::Result<ChatResult> {
1728 self.received_prompts
1729 .lock()
1730 .expect("provider lock poisoned")
1731 .push(prompt.to_string());
1732 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
1733 let response = if idx < self.responses.len() {
1734 self.responses[idx].clone()
1735 } else {
1736 self.responses.last().cloned().unwrap_or_default()
1737 };
1738 Ok(ChatResult {
1739 output_text: response,
1740 ..Default::default()
1741 })
1742 }
1743 }
1744
1745 struct UpperTool;
1746
1747 #[async_trait]
1748 impl Tool for UpperTool {
1749 fn name(&self) -> &'static str {
1750 "upper"
1751 }
1752
1753 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1754 Ok(ToolResult {
1755 output: input.to_uppercase(),
1756 })
1757 }
1758 }
1759
1760 struct LoopTool;
1763
1764 #[async_trait]
1765 impl Tool for LoopTool {
1766 fn name(&self) -> &'static str {
1767 "loop_tool"
1768 }
1769
1770 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1771 Ok(ToolResult {
1772 output: "tool:loop_tool x".to_string(),
1773 })
1774 }
1775 }
1776
1777 struct PingTool;
1779
1780 #[async_trait]
1781 impl Tool for PingTool {
1782 fn name(&self) -> &'static str {
1783 "ping"
1784 }
1785
1786 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1787 Ok(ToolResult {
1788 output: "tool:pong x".to_string(),
1789 })
1790 }
1791 }
1792
1793 struct PongTool;
1795
1796 #[async_trait]
1797 impl Tool for PongTool {
1798 fn name(&self) -> &'static str {
1799 "pong"
1800 }
1801
1802 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
1803 Ok(ToolResult {
1804 output: "tool:ping x".to_string(),
1805 })
1806 }
1807 }
1808
    /// Hook sink that captures only the stage name of each event, which is
    /// all the lifecycle tests assert on.
    struct RecordingHookSink {
        events: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl HookSink for RecordingHookSink {
        async fn record(&self, event: HookEvent) -> anyhow::Result<()> {
            // Only the stage label is kept; the rest of the event is dropped.
            self.events
                .lock()
                .expect("hook lock poisoned")
                .push(event.stage);
            Ok(())
        }
    }
1823
1824 struct FailingHookSink;
1825
1826 #[async_trait]
1827 impl HookSink for FailingHookSink {
1828 async fn record(&self, _event: HookEvent) -> anyhow::Result<()> {
1829 Err(anyhow::anyhow!("hook sink failure"))
1830 }
1831 }
1832
    /// Audit sink that stores every event verbatim for later inspection.
    struct RecordingAuditSink {
        events: Arc<Mutex<Vec<AuditEvent>>>,
    }

    #[async_trait]
    impl AuditSink for RecordingAuditSink {
        async fn record(&self, event: AuditEvent) -> anyhow::Result<()> {
            self.events.lock().expect("audit lock poisoned").push(event);
            Ok(())
        }
    }
1844
    /// Metrics sink recording counter totals and, for histograms, only the
    /// number of observations (the observed values are discarded).
    struct RecordingMetricsSink {
        counters: Arc<Mutex<HashMap<&'static str, u64>>>,
        histograms: Arc<Mutex<HashMap<&'static str, usize>>>,
    }

    impl MetricsSink for RecordingMetricsSink {
        fn increment_counter(&self, name: &'static str, value: u64) {
            let mut counters = self.counters.lock().expect("metrics lock poisoned");
            *counters.entry(name).or_insert(0) += value;
        }

        fn observe_histogram(&self, name: &'static str, _value: f64) {
            // Tests only check that an observation happened, not its value.
            let mut histograms = self.histograms.lock().expect("metrics lock poisoned");
            *histograms.entry(name).or_insert(0) += 1;
        }
    }
1861
    /// Builds a minimal tool context rooted at the current directory.
    fn test_ctx() -> ToolContext {
        ToolContext::new(".".to_string())
    }
1865
1866 fn counter(counters: &Arc<Mutex<HashMap<&'static str, u64>>>, name: &'static str) -> u64 {
1867 counters
1868 .lock()
1869 .expect("metrics lock poisoned")
1870 .get(name)
1871 .copied()
1872 .unwrap_or(0)
1873 }
1874
1875 fn histogram_count(
1876 histograms: &Arc<Mutex<HashMap<&'static str, usize>>>,
1877 name: &'static str,
1878 ) -> usize {
1879 histograms
1880 .lock()
1881 .expect("metrics lock poisoned")
1882 .get(name)
1883 .copied()
1884 .unwrap_or(0)
1885 }
1886
    // Happy path: one turn should persist the user message then the assistant
    // reply, in that order, and surface the provider's text unchanged.
    #[tokio::test]
    async fn respond_appends_user_then_assistant_memory_entries() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts,
            response_text: "assistant-output".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "hello world".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("agent respond should succeed");

        assert_eq!(response.text, "assistant-output");
        let stored = entries.lock().expect("memory lock poisoned");
        assert_eq!(stored.len(), 2);
        assert_eq!(stored[0].role, "user");
        assert_eq!(stored[0].content, "hello world");
        assert_eq!(stored[1].role, "assistant");
        assert_eq!(stored[1].content, "assistant-output");
    }

    // A "tool:"-prefixed prompt should run the matching tool and feed its
    // wrapped output into the provider prompt.
    #[tokio::test]
    async fn respond_invokes_tool_for_tool_prefixed_prompt() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "assistant-after-tool".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("agent respond should succeed");

        assert_eq!(response.text, "assistant-after-tool");
        let prompts = prompts.lock().expect("provider lock poisoned");
        assert_eq!(prompts.len(), 1);
        assert!(prompts[0].contains("Recent memory:"));
        assert!(prompts[0].contains("Current input:\nTool output from echo: echoed:ping"));
    }
1958
    // An unrecognized tool name should be skipped (audited) and the raw
    // prompt passed through to the provider untouched.
    #[tokio::test]
    async fn respond_with_unknown_tool_falls_back_to_provider_prompt() {
        let entries = Arc::new(Mutex::new(Vec::new()));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "assistant-without-tool".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:unknown payload".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("agent respond should succeed");

        assert_eq!(response.text, "assistant-without-tool");
        let prompts = prompts.lock().expect("provider lock poisoned");
        assert_eq!(prompts.len(), 1);
        assert!(prompts[0].contains("Recent memory:"));
        assert!(prompts[0].contains("Current input:\ntool:unknown payload"));
    }

    // memory_window_size should bound how much history reaches the provider:
    // entries beyond the window must not appear in the prompt.
    #[tokio::test]
    async fn respond_includes_bounded_recent_memory_in_provider_prompt() {
        let entries = Arc::new(Mutex::new(vec![
            MemoryEntry {
                role: "assistant".to_string(),
                content: "very-old".to_string(),
                ..Default::default()
            },
            MemoryEntry {
                role: "user".to_string(),
                content: "recent-before-request".to_string(),
                ..Default::default()
            },
        ]));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                memory_window_size: 2,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![],
        );

        agent
            .respond(
                UserMessage {
                    text: "latest-user".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let prompts = prompts.lock().expect("provider lock poisoned");
        let provider_prompt = prompts.first().expect("provider prompt should exist");
        assert!(provider_prompt.contains("- user: recent-before-request"));
        assert!(provider_prompt.contains("- user: latest-user"));
        assert!(!provider_prompt.contains("very-old"));
    }
2042
    // The assembled provider prompt must respect max_prompt_chars while the
    // tail (the most recent input) survives truncation.
    #[tokio::test]
    async fn respond_caps_provider_prompt_to_configured_max_chars() {
        let entries = Arc::new(Mutex::new(vec![MemoryEntry {
            role: "assistant".to_string(),
            content: "historic context ".repeat(16),
            ..Default::default()
        }]));
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory {
            entries: entries.clone(),
        };
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                max_prompt_chars: 64,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![],
        );

        agent
            .respond(
                UserMessage {
                    text: "final-tail-marker".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let prompts = prompts.lock().expect("provider lock poisoned");
        let provider_prompt = prompts.first().expect("provider prompt should exist");
        assert!(provider_prompt.chars().count() <= 64);
        assert!(provider_prompt.contains("final-tail-marker"));
    }

    // Provider failures must surface as the typed Provider error and be
    // reflected in the error counters and latency histogram.
    #[tokio::test]
    async fn respond_returns_provider_typed_error() {
        let memory = TestMemory::default();
        let counters = Arc::new(Mutex::new(HashMap::new()));
        let histograms = Arc::new(Mutex::new(HashMap::new()));
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(FailingProvider),
            Box::new(memory),
            vec![],
        )
        .with_metrics(Box::new(RecordingMetricsSink {
            counters: counters.clone(),
            histograms: histograms.clone(),
        }));

        let result = agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Provider { source }) => {
                assert!(source.to_string().contains("provider boom"));
            }
            other => panic!("expected provider error, got {other:?}"),
        }
        assert_eq!(counter(&counters, "requests_total"), 1);
        assert_eq!(counter(&counters, "provider_errors_total"), 1);
        assert_eq!(counter(&counters, "tool_errors_total"), 0);
        assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
    }
2120
    // A failing tool should produce the typed Tool error, bump the tool
    // error counter, and still record tool latency.
    #[tokio::test]
    async fn respond_increments_tool_error_counter_on_tool_failure() {
        let memory = TestMemory::default();
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts,
            response_text: "unused".to_string(),
        };
        let counters = Arc::new(Mutex::new(HashMap::new()));
        let histograms = Arc::new(Mutex::new(HashMap::new()));
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(FailingTool)],
        )
        .with_metrics(Box::new(RecordingMetricsSink {
            counters: counters.clone(),
            histograms: histograms.clone(),
        }));

        let result = agent
            .respond(
                UserMessage {
                    text: "tool:boom ping".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Tool { tool, source }) => {
                assert_eq!(tool, "boom");
                assert!(source.to_string().contains("tool exploded"));
            }
            other => panic!("expected tool error, got {other:?}"),
        }
        assert_eq!(counter(&counters, "requests_total"), 1);
        assert_eq!(counter(&counters, "provider_errors_total"), 0);
        assert_eq!(counter(&counters, "tool_errors_total"), 1);
        assert_eq!(histogram_count(&histograms, "tool_latency_ms"), 1);
    }

    // Success-path metrics: one request counted, no errors, and exactly one
    // provider latency observation.
    #[tokio::test]
    async fn respond_increments_requests_counter_on_success() {
        let memory = TestMemory::default();
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts,
            response_text: "ok".to_string(),
        };
        let counters = Arc::new(Mutex::new(HashMap::new()));
        let histograms = Arc::new(Mutex::new(HashMap::new()));
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![],
        )
        .with_metrics(Box::new(RecordingMetricsSink {
            counters: counters.clone(),
            histograms: histograms.clone(),
        }));

        let response = agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("response should succeed");

        assert_eq!(response.text, "ok");
        assert_eq!(counter(&counters, "requests_total"), 1);
        assert_eq!(counter(&counters, "provider_errors_total"), 0);
        assert_eq!(counter(&counters, "tool_errors_total"), 0);
        assert_eq!(histogram_count(&histograms, "provider_latency_ms"), 1);
    }
2201
    // A provider slower than request_timeout_ms must yield the typed Timeout
    // error carrying the configured budget.
    #[tokio::test]
    async fn respond_returns_timeout_typed_error() {
        let memory = TestMemory::default();
        let agent = Agent::new(
            AgentConfig {
                max_tool_iterations: 1,
                request_timeout_ms: 10,
                ..AgentConfig::default()
            },
            Box::new(SlowProvider),
            Box::new(memory),
            vec![],
        );

        let result = agent
            .respond(
                UserMessage {
                    text: "hello".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Timeout { timeout_ms }) => assert_eq!(timeout_ms, 10),
            other => panic!("expected timeout error, got {other:?}"),
        }
    }

    // With hooks enabled, a plugin-tool run should emit the full set of
    // lifecycle stages, including the plugin-specific before/after pair.
    #[tokio::test]
    async fn respond_emits_before_after_hook_events_when_enabled() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                hooks: HookPolicy {
                    enabled: true,
                    timeout_ms: 50,
                    fail_closed: true,
                    ..HookPolicy::default()
                },
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool), Box::new(PluginEchoTool)],
        )
        .with_hooks(Box::new(RecordingHookSink {
            events: events.clone(),
        }));

        agent
            .respond(
                UserMessage {
                    text: "tool:plugin:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let stages = events.lock().expect("hook lock poisoned");
        assert!(stages.contains(&"before_run".to_string()));
        assert!(stages.contains(&"after_run".to_string()));
        assert!(stages.contains(&"before_tool_call".to_string()));
        assert!(stages.contains(&"after_tool_call".to_string()));
        assert!(stages.contains(&"before_plugin_call".to_string()));
        assert!(stages.contains(&"after_plugin_call".to_string()));
        assert!(stages.contains(&"before_provider_call".to_string()));
        assert!(stages.contains(&"after_provider_call".to_string()));
        assert!(stages.contains(&"before_memory_write".to_string()));
        assert!(stages.contains(&"after_memory_write".to_string()));
        assert!(stages.contains(&"before_response_emit".to_string()));
        assert!(stages.contains(&"after_response_emit".to_string()));
    }
2281
    // Every audit event in a run should share a single request_id, and the
    // provider/tool success events should each carry a duration_ms field.
    #[tokio::test]
    async fn respond_audit_events_include_request_id_and_durations() {
        let audit_events = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig::default(),
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        )
        .with_audit(Box::new(RecordingAuditSink {
            events: audit_events.clone(),
        }));

        agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("respond should succeed");

        let events = audit_events.lock().expect("audit lock poisoned");
        // Take the first request_id seen and require the later success
        // events to match it.
        let request_id = events
            .iter()
            .find_map(|e| {
                e.detail
                    .get("request_id")
                    .and_then(Value::as_str)
                    .map(ToString::to_string)
            })
            .expect("request_id should exist on audit events");

        let provider_event = events
            .iter()
            .find(|e| e.stage == "provider_call_success")
            .expect("provider success event should exist");
        assert_eq!(
            provider_event
                .detail
                .get("request_id")
                .and_then(Value::as_str),
            Some(request_id.as_str())
        );
        assert!(provider_event
            .detail
            .get("duration_ms")
            .and_then(Value::as_u64)
            .is_some());

        let tool_event = events
            .iter()
            .find(|e| e.stage == "tool_execute_success")
            .expect("tool success event should exist");
        assert_eq!(
            tool_event.detail.get("request_id").and_then(Value::as_str),
            Some(request_id.as_str())
        );
        assert!(tool_event
            .detail
            .get("duration_ms")
            .and_then(Value::as_u64)
            .is_some());
    }
2352
    // Block mode on the high tier should convert a hook sink failure into a
    // hard Hook error at the first high-tier stage (before_tool_call).
    #[tokio::test]
    async fn hook_errors_respect_block_mode_for_high_tier_negative_path() {
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                hooks: HookPolicy {
                    enabled: true,
                    timeout_ms: 50,
                    fail_closed: false,
                    default_mode: HookFailureMode::Warn,
                    low_tier_mode: HookFailureMode::Ignore,
                    medium_tier_mode: HookFailureMode::Warn,
                    high_tier_mode: HookFailureMode::Block,
                },
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        )
        .with_hooks(Box::new(FailingHookSink));

        let result = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await;

        match result {
            Err(AgentError::Hook { stage, .. }) => assert_eq!(stage, "before_tool_call"),
            other => panic!("expected hook block error, got {other:?}"),
        }
    }

    // Warn mode should let the run succeed despite hook failures while
    // recording hook_error_warn audit events.
    #[tokio::test]
    async fn hook_errors_respect_warn_mode_for_high_tier_success_path() {
        let audit_events = Arc::new(Mutex::new(Vec::new()));
        let memory = TestMemory::default();
        let provider = TestProvider {
            received_prompts: Arc::new(Mutex::new(Vec::new())),
            response_text: "ok".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                hooks: HookPolicy {
                    enabled: true,
                    timeout_ms: 50,
                    fail_closed: false,
                    default_mode: HookFailureMode::Warn,
                    low_tier_mode: HookFailureMode::Ignore,
                    medium_tier_mode: HookFailureMode::Warn,
                    high_tier_mode: HookFailureMode::Warn,
                },
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(memory),
            vec![Box::new(EchoTool)],
        )
        .with_hooks(Box::new(FailingHookSink))
        .with_audit(Box::new(RecordingAuditSink {
            events: audit_events.clone(),
        }));

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:echo ping".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("warn mode should continue");
        assert!(!response.text.is_empty());

        let events = audit_events.lock().expect("audit lock poisoned");
        assert!(events.iter().any(|event| event.stage == "hook_error_warn"));
    }
2438
    // When the same tool repeats with identical args and output, the agent
    // should inject a loop-detection notice into the provider prompt instead
    // of iterating until max_tool_iterations.
    #[tokio::test]
    async fn self_correction_injected_on_no_progress() {
        let prompts = Arc::new(Mutex::new(Vec::new()));
        let provider = TestProvider {
            received_prompts: prompts.clone(),
            response_text: "I'll try a different approach".to_string(),
        };
        let agent = Agent::new(
            AgentConfig {
                loop_detection_no_progress_threshold: 3,
                max_tool_iterations: 20,
                ..AgentConfig::default()
            },
            Box::new(provider),
            Box::new(TestMemory::default()),
            vec![Box::new(LoopTool)],
        );

        let response = agent
            .respond(
                UserMessage {
                    text: "tool:loop_tool x".to_string(),
                },
                &test_ctx(),
            )
            .await
            .expect("self-correction should succeed");

        assert_eq!(response.text, "I'll try a different approach");

        let received = prompts.lock().expect("provider lock poisoned");
        let last_prompt = received.last().expect("provider should have been called");
        assert!(
            last_prompt.contains("Loop detected"),
            "provider prompt should contain self-correction notice, got: {last_prompt}"
        );
        assert!(last_prompt.contains("loop_tool"));
    }
2482
2483 #[tokio::test]
2484 async fn self_correction_injected_on_ping_pong() {
2485 let prompts = Arc::new(Mutex::new(Vec::new()));
2486 let provider = TestProvider {
2487 received_prompts: prompts.clone(),
2488 response_text: "I stopped the loop".to_string(),
2489 };
2490 let agent = Agent::new(
2492 AgentConfig {
2493 loop_detection_ping_pong_cycles: 2,
2494 loop_detection_no_progress_threshold: 0, max_tool_iterations: 20,
2496 ..AgentConfig::default()
2497 },
2498 Box::new(provider),
2499 Box::new(TestMemory::default()),
2500 vec![Box::new(PingTool), Box::new(PongTool)],
2501 );
2502
2503 let response = agent
2504 .respond(
2505 UserMessage {
2506 text: "tool:ping x".to_string(),
2507 },
2508 &test_ctx(),
2509 )
2510 .await
2511 .expect("self-correction should succeed");
2512
2513 assert_eq!(response.text, "I stopped the loop");
2514
2515 let received = prompts.lock().expect("provider lock poisoned");
2516 let last_prompt = received.last().expect("provider should have been called");
2517 assert!(
2518 last_prompt.contains("ping-pong"),
2519 "provider prompt should contain ping-pong notice, got: {last_prompt}"
2520 );
2521 }
2522
2523 #[tokio::test]
2524 async fn no_progress_disabled_when_threshold_zero() {
2525 let prompts = Arc::new(Mutex::new(Vec::new()));
2526 let provider = TestProvider {
2527 received_prompts: prompts.clone(),
2528 response_text: "final answer".to_string(),
2529 };
2530 let agent = Agent::new(
2533 AgentConfig {
2534 loop_detection_no_progress_threshold: 0,
2535 loop_detection_ping_pong_cycles: 0,
2536 max_tool_iterations: 5,
2537 ..AgentConfig::default()
2538 },
2539 Box::new(provider),
2540 Box::new(TestMemory::default()),
2541 vec![Box::new(LoopTool)],
2542 );
2543
2544 let response = agent
2545 .respond(
2546 UserMessage {
2547 text: "tool:loop_tool x".to_string(),
2548 },
2549 &test_ctx(),
2550 )
2551 .await
2552 .expect("should complete without error");
2553
2554 assert_eq!(response.text, "final answer");
2555
2556 let received = prompts.lock().expect("provider lock poisoned");
2558 let last_prompt = received.last().expect("provider should have been called");
2559 assert!(
2560 !last_prompt.contains("Loop detected"),
2561 "should not contain self-correction when detection is disabled"
2562 );
2563 }
2564
2565 #[tokio::test]
2568 async fn parallel_tools_executes_multiple_calls() {
2569 let prompts = Arc::new(Mutex::new(Vec::new()));
2570 let provider = TestProvider {
2571 received_prompts: prompts.clone(),
2572 response_text: "parallel done".to_string(),
2573 };
2574 let agent = Agent::new(
2575 AgentConfig {
2576 parallel_tools: true,
2577 max_tool_iterations: 5,
2578 ..AgentConfig::default()
2579 },
2580 Box::new(provider),
2581 Box::new(TestMemory::default()),
2582 vec![Box::new(EchoTool), Box::new(UpperTool)],
2583 );
2584
2585 let response = agent
2586 .respond(
2587 UserMessage {
2588 text: "tool:echo hello\ntool:upper world".to_string(),
2589 },
2590 &test_ctx(),
2591 )
2592 .await
2593 .expect("parallel tools should succeed");
2594
2595 assert_eq!(response.text, "parallel done");
2596
2597 let received = prompts.lock().expect("provider lock poisoned");
2598 let last_prompt = received.last().expect("provider should have been called");
2599 assert!(
2600 last_prompt.contains("echoed:hello"),
2601 "should contain echo output, got: {last_prompt}"
2602 );
2603 assert!(
2604 last_prompt.contains("WORLD"),
2605 "should contain upper output, got: {last_prompt}"
2606 );
2607 }
2608
2609 #[tokio::test]
2610 async fn parallel_tools_maintains_stable_ordering() {
2611 let prompts = Arc::new(Mutex::new(Vec::new()));
2612 let provider = TestProvider {
2613 received_prompts: prompts.clone(),
2614 response_text: "ordered".to_string(),
2615 };
2616 let agent = Agent::new(
2617 AgentConfig {
2618 parallel_tools: true,
2619 max_tool_iterations: 5,
2620 ..AgentConfig::default()
2621 },
2622 Box::new(provider),
2623 Box::new(TestMemory::default()),
2624 vec![Box::new(EchoTool), Box::new(UpperTool)],
2625 );
2626
2627 let response = agent
2629 .respond(
2630 UserMessage {
2631 text: "tool:echo aaa\ntool:upper bbb".to_string(),
2632 },
2633 &test_ctx(),
2634 )
2635 .await
2636 .expect("parallel tools should succeed");
2637
2638 assert_eq!(response.text, "ordered");
2639
2640 let received = prompts.lock().expect("provider lock poisoned");
2641 let last_prompt = received.last().expect("provider should have been called");
2642 let echo_pos = last_prompt
2643 .find("echoed:aaa")
2644 .expect("echo output should exist");
2645 let upper_pos = last_prompt.find("BBB").expect("upper output should exist");
2646 assert!(
2647 echo_pos < upper_pos,
2648 "echo output should come before upper output in the prompt"
2649 );
2650 }
2651
2652 #[tokio::test]
2653 async fn parallel_disabled_runs_first_call_only() {
2654 let prompts = Arc::new(Mutex::new(Vec::new()));
2655 let provider = TestProvider {
2656 received_prompts: prompts.clone(),
2657 response_text: "sequential".to_string(),
2658 };
2659 let agent = Agent::new(
2660 AgentConfig {
2661 parallel_tools: false,
2662 max_tool_iterations: 5,
2663 ..AgentConfig::default()
2664 },
2665 Box::new(provider),
2666 Box::new(TestMemory::default()),
2667 vec![Box::new(EchoTool), Box::new(UpperTool)],
2668 );
2669
2670 let response = agent
2671 .respond(
2672 UserMessage {
2673 text: "tool:echo hello\ntool:upper world".to_string(),
2674 },
2675 &test_ctx(),
2676 )
2677 .await
2678 .expect("sequential tools should succeed");
2679
2680 assert_eq!(response.text, "sequential");
2681
2682 let received = prompts.lock().expect("provider lock poisoned");
2684 let last_prompt = received.last().expect("provider should have been called");
2685 assert!(
2686 last_prompt.contains("echoed:hello"),
2687 "should contain first tool output, got: {last_prompt}"
2688 );
2689 assert!(
2690 !last_prompt.contains("WORLD"),
2691 "should NOT contain second tool output when parallel is disabled"
2692 );
2693 }
2694
2695 #[tokio::test]
2696 async fn single_call_parallel_matches_sequential() {
2697 let prompts = Arc::new(Mutex::new(Vec::new()));
2698 let provider = TestProvider {
2699 received_prompts: prompts.clone(),
2700 response_text: "single".to_string(),
2701 };
2702 let agent = Agent::new(
2703 AgentConfig {
2704 parallel_tools: true,
2705 max_tool_iterations: 5,
2706 ..AgentConfig::default()
2707 },
2708 Box::new(provider),
2709 Box::new(TestMemory::default()),
2710 vec![Box::new(EchoTool)],
2711 );
2712
2713 let response = agent
2714 .respond(
2715 UserMessage {
2716 text: "tool:echo ping".to_string(),
2717 },
2718 &test_ctx(),
2719 )
2720 .await
2721 .expect("single tool with parallel enabled should succeed");
2722
2723 assert_eq!(response.text, "single");
2724
2725 let received = prompts.lock().expect("provider lock poisoned");
2726 let last_prompt = received.last().expect("provider should have been called");
2727 assert!(
2728 last_prompt.contains("echoed:ping"),
2729 "single tool call should work same as sequential, got: {last_prompt}"
2730 );
2731 }
2732
2733 #[tokio::test]
2734 async fn parallel_falls_back_to_sequential_for_gated_tools() {
2735 let prompts = Arc::new(Mutex::new(Vec::new()));
2736 let provider = TestProvider {
2737 received_prompts: prompts.clone(),
2738 response_text: "gated sequential".to_string(),
2739 };
2740 let mut gated = std::collections::HashSet::new();
2741 gated.insert("upper".to_string());
2742 let agent = Agent::new(
2743 AgentConfig {
2744 parallel_tools: true,
2745 gated_tools: gated,
2746 max_tool_iterations: 5,
2747 ..AgentConfig::default()
2748 },
2749 Box::new(provider),
2750 Box::new(TestMemory::default()),
2751 vec![Box::new(EchoTool), Box::new(UpperTool)],
2752 );
2753
2754 let response = agent
2757 .respond(
2758 UserMessage {
2759 text: "tool:echo hello\ntool:upper world".to_string(),
2760 },
2761 &test_ctx(),
2762 )
2763 .await
2764 .expect("gated fallback should succeed");
2765
2766 assert_eq!(response.text, "gated sequential");
2767
2768 let received = prompts.lock().expect("provider lock poisoned");
2769 let last_prompt = received.last().expect("provider should have been called");
2770 assert!(
2771 last_prompt.contains("echoed:hello"),
2772 "first tool should execute, got: {last_prompt}"
2773 );
2774 assert!(
2775 !last_prompt.contains("WORLD"),
2776 "gated tool should NOT execute in parallel, got: {last_prompt}"
2777 );
2778 }
2779
2780 fn research_config(trigger: ResearchTrigger) -> ResearchPolicy {
2783 ResearchPolicy {
2784 enabled: true,
2785 trigger,
2786 keywords: vec!["search".to_string(), "find".to_string()],
2787 min_message_length: 10,
2788 max_iterations: 5,
2789 show_progress: true,
2790 }
2791 }
2792
2793 fn config_with_research(trigger: ResearchTrigger) -> AgentConfig {
2794 AgentConfig {
2795 research: research_config(trigger),
2796 ..Default::default()
2797 }
2798 }
2799
2800 #[tokio::test]
2801 async fn research_trigger_never() {
2802 let prompts = Arc::new(Mutex::new(Vec::new()));
2803 let provider = TestProvider {
2804 received_prompts: prompts.clone(),
2805 response_text: "final answer".to_string(),
2806 };
2807 let agent = Agent::new(
2808 config_with_research(ResearchTrigger::Never),
2809 Box::new(provider),
2810 Box::new(TestMemory::default()),
2811 vec![Box::new(EchoTool)],
2812 );
2813
2814 let response = agent
2815 .respond(
2816 UserMessage {
2817 text: "search for something".to_string(),
2818 },
2819 &test_ctx(),
2820 )
2821 .await
2822 .expect("should succeed");
2823
2824 assert_eq!(response.text, "final answer");
2825 let received = prompts.lock().expect("lock");
2826 assert_eq!(
2827 received.len(),
2828 1,
2829 "should have exactly 1 provider call (no research)"
2830 );
2831 }
2832
2833 #[tokio::test]
2834 async fn research_trigger_always() {
2835 let provider = ScriptedProvider::new(vec![
2836 "tool:echo gathering data",
2837 "Found relevant information",
2838 "Final answer with research context",
2839 ]);
2840 let prompts = provider.received_prompts.clone();
2841 let agent = Agent::new(
2842 config_with_research(ResearchTrigger::Always),
2843 Box::new(provider),
2844 Box::new(TestMemory::default()),
2845 vec![Box::new(EchoTool)],
2846 );
2847
2848 let response = agent
2849 .respond(
2850 UserMessage {
2851 text: "hello world".to_string(),
2852 },
2853 &test_ctx(),
2854 )
2855 .await
2856 .expect("should succeed");
2857
2858 assert_eq!(response.text, "Final answer with research context");
2859 let received = prompts.lock().expect("lock");
2860 let last = received.last().expect("should have provider calls");
2861 assert!(
2862 last.contains("Research findings:"),
2863 "final prompt should contain research findings, got: {last}"
2864 );
2865 }
2866
2867 #[tokio::test]
2868 async fn research_trigger_keywords_match() {
2869 let provider = ScriptedProvider::new(vec!["Summary of search", "Answer based on research"]);
2870 let prompts = provider.received_prompts.clone();
2871 let agent = Agent::new(
2872 config_with_research(ResearchTrigger::Keywords),
2873 Box::new(provider),
2874 Box::new(TestMemory::default()),
2875 vec![],
2876 );
2877
2878 let response = agent
2879 .respond(
2880 UserMessage {
2881 text: "please search for the config".to_string(),
2882 },
2883 &test_ctx(),
2884 )
2885 .await
2886 .expect("should succeed");
2887
2888 assert_eq!(response.text, "Answer based on research");
2889 let received = prompts.lock().expect("lock");
2890 assert!(received.len() >= 2, "should have at least 2 provider calls");
2891 assert!(
2892 received[0].contains("RESEARCH mode"),
2893 "first call should be research prompt"
2894 );
2895 }
2896
2897 #[tokio::test]
2898 async fn research_trigger_keywords_no_match() {
2899 let prompts = Arc::new(Mutex::new(Vec::new()));
2900 let provider = TestProvider {
2901 received_prompts: prompts.clone(),
2902 response_text: "direct answer".to_string(),
2903 };
2904 let agent = Agent::new(
2905 config_with_research(ResearchTrigger::Keywords),
2906 Box::new(provider),
2907 Box::new(TestMemory::default()),
2908 vec![],
2909 );
2910
2911 let response = agent
2912 .respond(
2913 UserMessage {
2914 text: "hello world".to_string(),
2915 },
2916 &test_ctx(),
2917 )
2918 .await
2919 .expect("should succeed");
2920
2921 assert_eq!(response.text, "direct answer");
2922 let received = prompts.lock().expect("lock");
2923 assert_eq!(received.len(), 1, "no research phase should fire");
2924 }
2925
2926 #[tokio::test]
2927 async fn research_trigger_length_short_skips() {
2928 let prompts = Arc::new(Mutex::new(Vec::new()));
2929 let provider = TestProvider {
2930 received_prompts: prompts.clone(),
2931 response_text: "short".to_string(),
2932 };
2933 let config = AgentConfig {
2934 research: ResearchPolicy {
2935 min_message_length: 20,
2936 ..research_config(ResearchTrigger::Length)
2937 },
2938 ..Default::default()
2939 };
2940 let agent = Agent::new(
2941 config,
2942 Box::new(provider),
2943 Box::new(TestMemory::default()),
2944 vec![],
2945 );
2946 agent
2947 .respond(
2948 UserMessage {
2949 text: "hi".to_string(),
2950 },
2951 &test_ctx(),
2952 )
2953 .await
2954 .expect("should succeed");
2955 let received = prompts.lock().expect("lock");
2956 assert_eq!(received.len(), 1, "short message should skip research");
2957 }
2958
2959 #[tokio::test]
2960 async fn research_trigger_length_long_triggers() {
2961 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
2962 let prompts = provider.received_prompts.clone();
2963 let config = AgentConfig {
2964 research: ResearchPolicy {
2965 min_message_length: 20,
2966 ..research_config(ResearchTrigger::Length)
2967 },
2968 ..Default::default()
2969 };
2970 let agent = Agent::new(
2971 config,
2972 Box::new(provider),
2973 Box::new(TestMemory::default()),
2974 vec![],
2975 );
2976 agent
2977 .respond(
2978 UserMessage {
2979 text: "this is a longer message that exceeds the threshold".to_string(),
2980 },
2981 &test_ctx(),
2982 )
2983 .await
2984 .expect("should succeed");
2985 let received = prompts.lock().expect("lock");
2986 assert!(received.len() >= 2, "long message should trigger research");
2987 }
2988
2989 #[tokio::test]
2990 async fn research_trigger_question_with_mark() {
2991 let provider = ScriptedProvider::new(vec!["research summary", "answer with research"]);
2992 let prompts = provider.received_prompts.clone();
2993 let agent = Agent::new(
2994 config_with_research(ResearchTrigger::Question),
2995 Box::new(provider),
2996 Box::new(TestMemory::default()),
2997 vec![],
2998 );
2999 agent
3000 .respond(
3001 UserMessage {
3002 text: "what is the meaning of life?".to_string(),
3003 },
3004 &test_ctx(),
3005 )
3006 .await
3007 .expect("should succeed");
3008 let received = prompts.lock().expect("lock");
3009 assert!(received.len() >= 2, "question should trigger research");
3010 }
3011
3012 #[tokio::test]
3013 async fn research_trigger_question_without_mark() {
3014 let prompts = Arc::new(Mutex::new(Vec::new()));
3015 let provider = TestProvider {
3016 received_prompts: prompts.clone(),
3017 response_text: "no research".to_string(),
3018 };
3019 let agent = Agent::new(
3020 config_with_research(ResearchTrigger::Question),
3021 Box::new(provider),
3022 Box::new(TestMemory::default()),
3023 vec![],
3024 );
3025 agent
3026 .respond(
3027 UserMessage {
3028 text: "do this thing".to_string(),
3029 },
3030 &test_ctx(),
3031 )
3032 .await
3033 .expect("should succeed");
3034 let received = prompts.lock().expect("lock");
3035 assert_eq!(received.len(), 1, "non-question should skip research");
3036 }
3037
3038 #[tokio::test]
3039 async fn research_respects_max_iterations() {
3040 let provider = ScriptedProvider::new(vec![
3041 "tool:echo step1",
3042 "tool:echo step2",
3043 "tool:echo step3",
3044 "tool:echo step4",
3045 "tool:echo step5",
3046 "tool:echo step6",
3047 "tool:echo step7",
3048 "answer after research",
3049 ]);
3050 let prompts = provider.received_prompts.clone();
3051 let config = AgentConfig {
3052 research: ResearchPolicy {
3053 max_iterations: 3,
3054 ..research_config(ResearchTrigger::Always)
3055 },
3056 ..Default::default()
3057 };
3058 let agent = Agent::new(
3059 config,
3060 Box::new(provider),
3061 Box::new(TestMemory::default()),
3062 vec![Box::new(EchoTool)],
3063 );
3064
3065 let response = agent
3066 .respond(
3067 UserMessage {
3068 text: "test".to_string(),
3069 },
3070 &test_ctx(),
3071 )
3072 .await
3073 .expect("should succeed");
3074
3075 assert!(!response.text.is_empty(), "should get a response");
3076 let received = prompts.lock().expect("lock");
3077 let research_calls = received
3078 .iter()
3079 .filter(|p| p.contains("RESEARCH mode") || p.contains("Research iteration"))
3080 .count();
3081 assert_eq!(
3083 research_calls, 4,
3084 "should have 4 research provider calls (1 initial + 3 iterations)"
3085 );
3086 }
3087
3088 #[tokio::test]
3089 async fn research_disabled_skips_phase() {
3090 let prompts = Arc::new(Mutex::new(Vec::new()));
3091 let provider = TestProvider {
3092 received_prompts: prompts.clone(),
3093 response_text: "direct answer".to_string(),
3094 };
3095 let config = AgentConfig {
3096 research: ResearchPolicy {
3097 enabled: false,
3098 trigger: ResearchTrigger::Always,
3099 ..Default::default()
3100 },
3101 ..Default::default()
3102 };
3103 let agent = Agent::new(
3104 config,
3105 Box::new(provider),
3106 Box::new(TestMemory::default()),
3107 vec![Box::new(EchoTool)],
3108 );
3109
3110 let response = agent
3111 .respond(
3112 UserMessage {
3113 text: "search for something".to_string(),
3114 },
3115 &test_ctx(),
3116 )
3117 .await
3118 .expect("should succeed");
3119
3120 assert_eq!(response.text, "direct answer");
3121 let received = prompts.lock().expect("lock");
3122 assert_eq!(
3123 received.len(),
3124 1,
3125 "disabled research should not fire even with Always trigger"
3126 );
3127 }
3128
    /// Test provider that records the `ReasoningConfig` it is invoked with, so
    /// tests can assert which reasoning settings the agent forwarded.
    struct ReasoningCapturingProvider {
        // One entry is pushed per provider call; inspected after `respond` returns.
        captured_reasoning: Arc<Mutex<Vec<ReasoningConfig>>>,
        // Fixed text returned from every completion.
        response_text: String,
    }

    #[async_trait]
    impl Provider for ReasoningCapturingProvider {
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            // The reasoning-less entry point records a default config so the
            // call is still counted by tests that only check call totals.
            self.captured_reasoning
                .lock()
                .expect("lock")
                .push(ReasoningConfig::default());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }

        async fn complete_with_reasoning(
            &self,
            _prompt: &str,
            reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            // Capture exactly what the agent passed down, verbatim.
            self.captured_reasoning
                .lock()
                .expect("lock")
                .push(reasoning.clone());
            Ok(ChatResult {
                output_text: self.response_text.clone(),
                ..Default::default()
            })
        }
    }
3164
3165 #[tokio::test]
3166 async fn reasoning_config_passed_to_provider() {
3167 let captured = Arc::new(Mutex::new(Vec::new()));
3168 let provider = ReasoningCapturingProvider {
3169 captured_reasoning: captured.clone(),
3170 response_text: "ok".to_string(),
3171 };
3172 let config = AgentConfig {
3173 reasoning: ReasoningConfig {
3174 enabled: Some(true),
3175 level: Some("high".to_string()),
3176 },
3177 ..Default::default()
3178 };
3179 let agent = Agent::new(
3180 config,
3181 Box::new(provider),
3182 Box::new(TestMemory::default()),
3183 vec![],
3184 );
3185
3186 agent
3187 .respond(
3188 UserMessage {
3189 text: "test".to_string(),
3190 },
3191 &test_ctx(),
3192 )
3193 .await
3194 .expect("should succeed");
3195
3196 let configs = captured.lock().expect("lock");
3197 assert_eq!(configs.len(), 1, "provider should be called once");
3198 assert_eq!(configs[0].enabled, Some(true));
3199 assert_eq!(configs[0].level.as_deref(), Some("high"));
3200 }
3201
3202 #[tokio::test]
3203 async fn reasoning_disabled_passes_config_through() {
3204 let captured = Arc::new(Mutex::new(Vec::new()));
3205 let provider = ReasoningCapturingProvider {
3206 captured_reasoning: captured.clone(),
3207 response_text: "ok".to_string(),
3208 };
3209 let config = AgentConfig {
3210 reasoning: ReasoningConfig {
3211 enabled: Some(false),
3212 level: None,
3213 },
3214 ..Default::default()
3215 };
3216 let agent = Agent::new(
3217 config,
3218 Box::new(provider),
3219 Box::new(TestMemory::default()),
3220 vec![],
3221 );
3222
3223 agent
3224 .respond(
3225 UserMessage {
3226 text: "test".to_string(),
3227 },
3228 &test_ctx(),
3229 )
3230 .await
3231 .expect("should succeed");
3232
3233 let configs = captured.lock().expect("lock");
3234 assert_eq!(configs.len(), 1);
3235 assert_eq!(
3236 configs[0].enabled,
3237 Some(false),
3238 "disabled reasoning should still be passed to provider"
3239 );
3240 assert_eq!(configs[0].level, None);
3241 }
3242
    /// Scripted provider for the structured (tool-use) path: hands out queued
    /// `ChatResult`s in order and records the message lists it receives.
    struct StructuredProvider {
        // Responses served one per call; falls back to a default ChatResult
        // once the queue is exhausted.
        responses: Vec<ChatResult>,
        // Number of completions served so far; doubles as the index into
        // `responses` (shared across both entry points).
        call_count: AtomicUsize,
        // Snapshot of the conversation passed to each `complete_with_tools`
        // call, for post-hoc assertions.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }

    impl StructuredProvider {
        // Builds a provider that will serve `responses` in order.
        fn new(responses: Vec<ChatResult>) -> Self {
            Self {
                responses,
                call_count: AtomicUsize::new(0),
                received_messages: Arc::new(Mutex::new(Vec::new())),
            }
        }
    }

    #[async_trait]
    impl Provider for StructuredProvider {
        async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
            // Plain-text path: consume the next scripted response but do not
            // record messages (only the structured path records them).
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }

        async fn complete_with_tools(
            &self,
            messages: &[ConversationMessage],
            _tools: &[ToolDefinition],
            _reasoning: &ReasoningConfig,
        ) -> anyhow::Result<ChatResult> {
            // Record the full conversation for this call before serving the
            // next scripted response.
            self.received_messages
                .lock()
                .expect("lock")
                .push(messages.to_vec());
            let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
            Ok(self.responses.get(idx).cloned().unwrap_or_default())
        }
    }
3282
3283 struct StructuredEchoTool;
3284
3285 #[async_trait]
3286 impl Tool for StructuredEchoTool {
3287 fn name(&self) -> &'static str {
3288 "echo"
3289 }
3290 fn description(&self) -> &'static str {
3291 "Echoes input back"
3292 }
3293 fn input_schema(&self) -> Option<serde_json::Value> {
3294 Some(json!({
3295 "type": "object",
3296 "properties": {
3297 "text": { "type": "string", "description": "Text to echo" }
3298 },
3299 "required": ["text"]
3300 }))
3301 }
3302 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3303 Ok(ToolResult {
3304 output: format!("echoed:{input}"),
3305 })
3306 }
3307 }
3308
3309 struct StructuredFailingTool;
3310
3311 #[async_trait]
3312 impl Tool for StructuredFailingTool {
3313 fn name(&self) -> &'static str {
3314 "boom"
3315 }
3316 fn description(&self) -> &'static str {
3317 "Always fails"
3318 }
3319 fn input_schema(&self) -> Option<serde_json::Value> {
3320 Some(json!({
3321 "type": "object",
3322 "properties": {},
3323 }))
3324 }
3325 async fn execute(&self, _input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3326 Err(anyhow::anyhow!("tool exploded"))
3327 }
3328 }
3329
3330 struct StructuredUpperTool;
3331
3332 #[async_trait]
3333 impl Tool for StructuredUpperTool {
3334 fn name(&self) -> &'static str {
3335 "upper"
3336 }
3337 fn description(&self) -> &'static str {
3338 "Uppercases input"
3339 }
3340 fn input_schema(&self) -> Option<serde_json::Value> {
3341 Some(json!({
3342 "type": "object",
3343 "properties": {
3344 "text": { "type": "string", "description": "Text to uppercase" }
3345 },
3346 "required": ["text"]
3347 }))
3348 }
3349 async fn execute(&self, input: &str, _ctx: &ToolContext) -> anyhow::Result<ToolResult> {
3350 Ok(ToolResult {
3351 output: input.to_uppercase(),
3352 })
3353 }
3354 }
3355
3356 use crate::types::ToolUseRequest;
3357
3358 #[tokio::test]
3361 async fn structured_basic_tool_call_then_end_turn() {
3362 let provider = StructuredProvider::new(vec![
3363 ChatResult {
3364 output_text: "Let me echo that.".to_string(),
3365 tool_calls: vec![ToolUseRequest {
3366 id: "call_1".to_string(),
3367 name: "echo".to_string(),
3368 input: json!({"text": "hello"}),
3369 }],
3370 stop_reason: Some(StopReason::ToolUse),
3371 },
3372 ChatResult {
3373 output_text: "The echo returned: hello".to_string(),
3374 tool_calls: vec![],
3375 stop_reason: Some(StopReason::EndTurn),
3376 },
3377 ]);
3378 let received = provider.received_messages.clone();
3379
3380 let agent = Agent::new(
3381 AgentConfig::default(),
3382 Box::new(provider),
3383 Box::new(TestMemory::default()),
3384 vec![Box::new(StructuredEchoTool)],
3385 );
3386
3387 let response = agent
3388 .respond(
3389 UserMessage {
3390 text: "echo hello".to_string(),
3391 },
3392 &test_ctx(),
3393 )
3394 .await
3395 .expect("should succeed");
3396
3397 assert_eq!(response.text, "The echo returned: hello");
3398
3399 let msgs = received.lock().expect("lock");
3401 assert_eq!(msgs.len(), 2, "provider should be called twice");
3402 assert!(
3404 msgs[1].len() >= 3,
3405 "second call should have user + assistant + tool result"
3406 );
3407 }
3408
3409 #[tokio::test]
3410 async fn structured_no_tool_calls_returns_immediately() {
3411 let provider = StructuredProvider::new(vec![ChatResult {
3412 output_text: "Hello! No tools needed.".to_string(),
3413 tool_calls: vec![],
3414 stop_reason: Some(StopReason::EndTurn),
3415 }]);
3416
3417 let agent = Agent::new(
3418 AgentConfig::default(),
3419 Box::new(provider),
3420 Box::new(TestMemory::default()),
3421 vec![Box::new(StructuredEchoTool)],
3422 );
3423
3424 let response = agent
3425 .respond(
3426 UserMessage {
3427 text: "hi".to_string(),
3428 },
3429 &test_ctx(),
3430 )
3431 .await
3432 .expect("should succeed");
3433
3434 assert_eq!(response.text, "Hello! No tools needed.");
3435 }
3436
3437 #[tokio::test]
3438 async fn structured_tool_not_found_sends_error_result() {
3439 let provider = StructuredProvider::new(vec![
3440 ChatResult {
3441 output_text: String::new(),
3442 tool_calls: vec![ToolUseRequest {
3443 id: "call_1".to_string(),
3444 name: "nonexistent".to_string(),
3445 input: json!({}),
3446 }],
3447 stop_reason: Some(StopReason::ToolUse),
3448 },
3449 ChatResult {
3450 output_text: "I see the tool was not found.".to_string(),
3451 tool_calls: vec![],
3452 stop_reason: Some(StopReason::EndTurn),
3453 },
3454 ]);
3455 let received = provider.received_messages.clone();
3456
3457 let agent = Agent::new(
3458 AgentConfig::default(),
3459 Box::new(provider),
3460 Box::new(TestMemory::default()),
3461 vec![Box::new(StructuredEchoTool)],
3462 );
3463
3464 let response = agent
3465 .respond(
3466 UserMessage {
3467 text: "use nonexistent".to_string(),
3468 },
3469 &test_ctx(),
3470 )
3471 .await
3472 .expect("should succeed, not abort");
3473
3474 assert_eq!(response.text, "I see the tool was not found.");
3475
3476 let msgs = received.lock().expect("lock");
3478 let last_call = &msgs[1];
3479 let has_error_result = last_call.iter().any(|m| {
3480 matches!(m, ConversationMessage::ToolResult(r) if r.is_error && r.content.contains("not found"))
3481 });
3482 assert!(has_error_result, "should include error ToolResult");
3483 }
3484
3485 #[tokio::test]
3486 async fn structured_tool_error_does_not_abort() {
3487 let provider = StructuredProvider::new(vec![
3488 ChatResult {
3489 output_text: String::new(),
3490 tool_calls: vec![ToolUseRequest {
3491 id: "call_1".to_string(),
3492 name: "boom".to_string(),
3493 input: json!({}),
3494 }],
3495 stop_reason: Some(StopReason::ToolUse),
3496 },
3497 ChatResult {
3498 output_text: "I handled the error gracefully.".to_string(),
3499 tool_calls: vec![],
3500 stop_reason: Some(StopReason::EndTurn),
3501 },
3502 ]);
3503
3504 let agent = Agent::new(
3505 AgentConfig::default(),
3506 Box::new(provider),
3507 Box::new(TestMemory::default()),
3508 vec![Box::new(StructuredFailingTool)],
3509 );
3510
3511 let response = agent
3512 .respond(
3513 UserMessage {
3514 text: "boom".to_string(),
3515 },
3516 &test_ctx(),
3517 )
3518 .await
3519 .expect("should succeed, error becomes ToolResultMessage");
3520
3521 assert_eq!(response.text, "I handled the error gracefully.");
3522 }
3523
3524 #[tokio::test]
3525 async fn structured_multi_iteration_tool_calls() {
3526 let provider = StructuredProvider::new(vec![
3527 ChatResult {
3529 output_text: String::new(),
3530 tool_calls: vec![ToolUseRequest {
3531 id: "call_1".to_string(),
3532 name: "echo".to_string(),
3533 input: json!({"text": "first"}),
3534 }],
3535 stop_reason: Some(StopReason::ToolUse),
3536 },
3537 ChatResult {
3539 output_text: String::new(),
3540 tool_calls: vec![ToolUseRequest {
3541 id: "call_2".to_string(),
3542 name: "echo".to_string(),
3543 input: json!({"text": "second"}),
3544 }],
3545 stop_reason: Some(StopReason::ToolUse),
3546 },
3547 ChatResult {
3549 output_text: "Done with two tool calls.".to_string(),
3550 tool_calls: vec![],
3551 stop_reason: Some(StopReason::EndTurn),
3552 },
3553 ]);
3554 let received = provider.received_messages.clone();
3555
3556 let agent = Agent::new(
3557 AgentConfig::default(),
3558 Box::new(provider),
3559 Box::new(TestMemory::default()),
3560 vec![Box::new(StructuredEchoTool)],
3561 );
3562
3563 let response = agent
3564 .respond(
3565 UserMessage {
3566 text: "echo twice".to_string(),
3567 },
3568 &test_ctx(),
3569 )
3570 .await
3571 .expect("should succeed");
3572
3573 assert_eq!(response.text, "Done with two tool calls.");
3574 let msgs = received.lock().expect("lock");
3575 assert_eq!(msgs.len(), 3, "three provider calls");
3576 }
3577
3578 #[tokio::test]
3579 async fn structured_max_iterations_forces_final_answer() {
3580 let responses = vec![
3584 ChatResult {
3585 output_text: String::new(),
3586 tool_calls: vec![ToolUseRequest {
3587 id: "call_0".to_string(),
3588 name: "echo".to_string(),
3589 input: json!({"text": "loop"}),
3590 }],
3591 stop_reason: Some(StopReason::ToolUse),
3592 },
3593 ChatResult {
3594 output_text: String::new(),
3595 tool_calls: vec![ToolUseRequest {
3596 id: "call_1".to_string(),
3597 name: "echo".to_string(),
3598 input: json!({"text": "loop"}),
3599 }],
3600 stop_reason: Some(StopReason::ToolUse),
3601 },
3602 ChatResult {
3603 output_text: String::new(),
3604 tool_calls: vec![ToolUseRequest {
3605 id: "call_2".to_string(),
3606 name: "echo".to_string(),
3607 input: json!({"text": "loop"}),
3608 }],
3609 stop_reason: Some(StopReason::ToolUse),
3610 },
3611 ChatResult {
3613 output_text: "Forced final answer.".to_string(),
3614 tool_calls: vec![],
3615 stop_reason: Some(StopReason::EndTurn),
3616 },
3617 ];
3618
3619 let agent = Agent::new(
3620 AgentConfig {
3621 max_tool_iterations: 3,
3622 loop_detection_no_progress_threshold: 0, ..AgentConfig::default()
3624 },
3625 Box::new(StructuredProvider::new(responses)),
3626 Box::new(TestMemory::default()),
3627 vec![Box::new(StructuredEchoTool)],
3628 );
3629
3630 let response = agent
3631 .respond(
3632 UserMessage {
3633 text: "loop forever".to_string(),
3634 },
3635 &test_ctx(),
3636 )
3637 .await
3638 .expect("should succeed with forced answer");
3639
3640 assert_eq!(response.text, "Forced final answer.");
3641 }
3642
3643 #[tokio::test]
3644 async fn structured_fallback_to_text_when_disabled() {
3645 let prompts = Arc::new(Mutex::new(Vec::new()));
3646 let provider = TestProvider {
3647 received_prompts: prompts.clone(),
3648 response_text: "text path used".to_string(),
3649 };
3650
3651 let agent = Agent::new(
3652 AgentConfig {
3653 model_supports_tool_use: false,
3654 ..AgentConfig::default()
3655 },
3656 Box::new(provider),
3657 Box::new(TestMemory::default()),
3658 vec![Box::new(StructuredEchoTool)],
3659 );
3660
3661 let response = agent
3662 .respond(
3663 UserMessage {
3664 text: "hello".to_string(),
3665 },
3666 &test_ctx(),
3667 )
3668 .await
3669 .expect("should succeed");
3670
3671 assert_eq!(response.text, "text path used");
3672 }
3673
3674 #[tokio::test]
3675 async fn structured_fallback_when_no_tool_schemas() {
3676 let prompts = Arc::new(Mutex::new(Vec::new()));
3677 let provider = TestProvider {
3678 received_prompts: prompts.clone(),
3679 response_text: "text path used".to_string(),
3680 };
3681
3682 let agent = Agent::new(
3684 AgentConfig::default(),
3685 Box::new(provider),
3686 Box::new(TestMemory::default()),
3687 vec![Box::new(EchoTool)],
3688 );
3689
3690 let response = agent
3691 .respond(
3692 UserMessage {
3693 text: "hello".to_string(),
3694 },
3695 &test_ctx(),
3696 )
3697 .await
3698 .expect("should succeed");
3699
3700 assert_eq!(response.text, "text path used");
3701 }
3702
3703 #[tokio::test]
3704 async fn structured_parallel_tool_calls() {
3705 let provider = StructuredProvider::new(vec![
3706 ChatResult {
3707 output_text: String::new(),
3708 tool_calls: vec![
3709 ToolUseRequest {
3710 id: "call_1".to_string(),
3711 name: "echo".to_string(),
3712 input: json!({"text": "a"}),
3713 },
3714 ToolUseRequest {
3715 id: "call_2".to_string(),
3716 name: "upper".to_string(),
3717 input: json!({"text": "b"}),
3718 },
3719 ],
3720 stop_reason: Some(StopReason::ToolUse),
3721 },
3722 ChatResult {
3723 output_text: "Both tools ran.".to_string(),
3724 tool_calls: vec![],
3725 stop_reason: Some(StopReason::EndTurn),
3726 },
3727 ]);
3728 let received = provider.received_messages.clone();
3729
3730 let agent = Agent::new(
3731 AgentConfig {
3732 parallel_tools: true,
3733 ..AgentConfig::default()
3734 },
3735 Box::new(provider),
3736 Box::new(TestMemory::default()),
3737 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
3738 );
3739
3740 let response = agent
3741 .respond(
3742 UserMessage {
3743 text: "do both".to_string(),
3744 },
3745 &test_ctx(),
3746 )
3747 .await
3748 .expect("should succeed");
3749
3750 assert_eq!(response.text, "Both tools ran.");
3751
3752 let msgs = received.lock().expect("lock");
3754 let tool_results: Vec<_> = msgs[1]
3755 .iter()
3756 .filter(|m| matches!(m, ConversationMessage::ToolResult(_)))
3757 .collect();
3758 assert_eq!(tool_results.len(), 2, "should have two tool results");
3759 }
3760
3761 #[tokio::test]
3762 async fn structured_memory_integration() {
3763 let memory = TestMemory::default();
3764 memory
3766 .append(MemoryEntry {
3767 role: "user".to_string(),
3768 content: "old question".to_string(),
3769 ..Default::default()
3770 })
3771 .await
3772 .unwrap();
3773 memory
3774 .append(MemoryEntry {
3775 role: "assistant".to_string(),
3776 content: "old answer".to_string(),
3777 ..Default::default()
3778 })
3779 .await
3780 .unwrap();
3781
3782 let provider = StructuredProvider::new(vec![ChatResult {
3783 output_text: "I see your history.".to_string(),
3784 tool_calls: vec![],
3785 stop_reason: Some(StopReason::EndTurn),
3786 }]);
3787 let received = provider.received_messages.clone();
3788
3789 let agent = Agent::new(
3790 AgentConfig::default(),
3791 Box::new(provider),
3792 Box::new(memory),
3793 vec![Box::new(StructuredEchoTool)],
3794 );
3795
3796 agent
3797 .respond(
3798 UserMessage {
3799 text: "new question".to_string(),
3800 },
3801 &test_ctx(),
3802 )
3803 .await
3804 .expect("should succeed");
3805
3806 let msgs = received.lock().expect("lock");
3807 assert!(msgs[0].len() >= 3, "should include memory messages");
3810 assert!(matches!(
3812 &msgs[0][0],
3813 ConversationMessage::User { content } if content == "old question"
3814 ));
3815 }
3816
3817 #[tokio::test]
3818 async fn prepare_tool_input_extracts_single_string_field() {
3819 let tool = StructuredEchoTool;
3820 let input = json!({"text": "hello world"});
3821 let result = prepare_tool_input(&tool, &input).expect("valid input");
3822 assert_eq!(result, "hello world");
3823 }
3824
3825 #[tokio::test]
3826 async fn prepare_tool_input_serializes_multi_field_json() {
3827 struct MultiFieldTool;
3829 #[async_trait]
3830 impl Tool for MultiFieldTool {
3831 fn name(&self) -> &'static str {
3832 "multi"
3833 }
3834 fn input_schema(&self) -> Option<serde_json::Value> {
3835 Some(json!({
3836 "type": "object",
3837 "properties": {
3838 "path": { "type": "string" },
3839 "content": { "type": "string" }
3840 },
3841 "required": ["path", "content"]
3842 }))
3843 }
3844 async fn execute(
3845 &self,
3846 _input: &str,
3847 _ctx: &ToolContext,
3848 ) -> anyhow::Result<ToolResult> {
3849 unreachable!()
3850 }
3851 }
3852
3853 let tool = MultiFieldTool;
3854 let input = json!({"path": "a.txt", "content": "hello"});
3855 let result = prepare_tool_input(&tool, &input).expect("valid input");
3856 let parsed: serde_json::Value = serde_json::from_str(&result).expect("valid JSON");
3858 assert_eq!(parsed["path"], "a.txt");
3859 assert_eq!(parsed["content"], "hello");
3860 }
3861
3862 #[tokio::test]
3863 async fn prepare_tool_input_unwraps_bare_string() {
3864 let tool = StructuredEchoTool;
3865 let input = json!("plain text");
3866 let result = prepare_tool_input(&tool, &input).expect("valid input");
3867 assert_eq!(result, "plain text");
3868 }
3869
3870 #[tokio::test]
3871 async fn prepare_tool_input_rejects_schema_violating_input() {
3872 struct StrictTool;
3873 #[async_trait]
3874 impl Tool for StrictTool {
3875 fn name(&self) -> &'static str {
3876 "strict"
3877 }
3878 fn input_schema(&self) -> Option<serde_json::Value> {
3879 Some(json!({
3880 "type": "object",
3881 "required": ["path"],
3882 "properties": {
3883 "path": { "type": "string" }
3884 }
3885 }))
3886 }
3887 async fn execute(
3888 &self,
3889 _input: &str,
3890 _ctx: &ToolContext,
3891 ) -> anyhow::Result<ToolResult> {
3892 unreachable!("should not be called with invalid input")
3893 }
3894 }
3895
3896 let result = prepare_tool_input(&StrictTool, &json!({}));
3898 assert!(result.is_err());
3899 let err = result.unwrap_err();
3900 assert!(err.contains("Invalid input for tool 'strict'"));
3901 assert!(err.contains("missing required field"));
3902 }
3903
3904 #[tokio::test]
3905 async fn prepare_tool_input_rejects_wrong_type() {
3906 struct TypedTool;
3907 #[async_trait]
3908 impl Tool for TypedTool {
3909 fn name(&self) -> &'static str {
3910 "typed"
3911 }
3912 fn input_schema(&self) -> Option<serde_json::Value> {
3913 Some(json!({
3914 "type": "object",
3915 "required": ["count"],
3916 "properties": {
3917 "count": { "type": "integer" }
3918 }
3919 }))
3920 }
3921 async fn execute(
3922 &self,
3923 _input: &str,
3924 _ctx: &ToolContext,
3925 ) -> anyhow::Result<ToolResult> {
3926 unreachable!("should not be called with invalid input")
3927 }
3928 }
3929
3930 let result = prepare_tool_input(&TypedTool, &json!({"count": "not a number"}));
3932 assert!(result.is_err());
3933 let err = result.unwrap_err();
3934 assert!(err.contains("expected type \"integer\""));
3935 }
3936
3937 #[test]
3938 fn truncate_messages_preserves_small_conversation() {
3939 let mut msgs = vec![
3940 ConversationMessage::User {
3941 content: "hello".to_string(),
3942 },
3943 ConversationMessage::Assistant {
3944 content: Some("world".to_string()),
3945 tool_calls: vec![],
3946 },
3947 ];
3948 truncate_messages(&mut msgs, 1000);
3949 assert_eq!(msgs.len(), 2, "should not truncate small conversation");
3950 }
3951
3952 #[test]
3953 fn truncate_messages_drops_middle() {
3954 let mut msgs = vec![
3955 ConversationMessage::User {
3956 content: "A".repeat(100),
3957 },
3958 ConversationMessage::Assistant {
3959 content: Some("B".repeat(100)),
3960 tool_calls: vec![],
3961 },
3962 ConversationMessage::User {
3963 content: "C".repeat(100),
3964 },
3965 ConversationMessage::Assistant {
3966 content: Some("D".repeat(100)),
3967 tool_calls: vec![],
3968 },
3969 ];
3970 truncate_messages(&mut msgs, 250);
3973 assert!(msgs.len() < 4, "should have dropped some messages");
3974 assert!(matches!(
3976 &msgs[0],
3977 ConversationMessage::User { content } if content.starts_with("AAA")
3978 ));
3979 }
3980
3981 #[test]
3982 fn memory_to_messages_reverses_order() {
3983 let entries = vec![
3984 MemoryEntry {
3985 role: "assistant".to_string(),
3986 content: "newest".to_string(),
3987 ..Default::default()
3988 },
3989 MemoryEntry {
3990 role: "user".to_string(),
3991 content: "oldest".to_string(),
3992 ..Default::default()
3993 },
3994 ];
3995 let msgs = memory_to_messages(&entries);
3996 assert_eq!(msgs.len(), 2);
3997 assert!(matches!(
3998 &msgs[0],
3999 ConversationMessage::User { content } if content == "oldest"
4000 ));
4001 assert!(matches!(
4002 &msgs[1],
4003 ConversationMessage::Assistant { content: Some(c), .. } if c == "newest"
4004 ));
4005 }
4006
4007 #[test]
4008 fn single_required_string_field_detects_single() {
4009 let schema = json!({
4010 "type": "object",
4011 "properties": {
4012 "path": { "type": "string" }
4013 },
4014 "required": ["path"]
4015 });
4016 assert_eq!(
4017 single_required_string_field(&schema),
4018 Some("path".to_string())
4019 );
4020 }
4021
4022 #[test]
4023 fn single_required_string_field_none_for_multiple() {
4024 let schema = json!({
4025 "type": "object",
4026 "properties": {
4027 "path": { "type": "string" },
4028 "mode": { "type": "string" }
4029 },
4030 "required": ["path", "mode"]
4031 });
4032 assert_eq!(single_required_string_field(&schema), None);
4033 }
4034
4035 #[test]
4036 fn single_required_string_field_none_for_non_string() {
4037 let schema = json!({
4038 "type": "object",
4039 "properties": {
4040 "count": { "type": "integer" }
4041 },
4042 "required": ["count"]
4043 });
4044 assert_eq!(single_required_string_field(&schema), None);
4045 }
4046
4047 use crate::types::StreamChunk;
4050
    /// Scripted provider for the streaming tests: replays `responses` in call
    /// order and records the message lists passed to the *_with_tools methods.
    struct StreamingProvider {
        // Canned results, returned one per provider call (default past the end).
        responses: Vec<ChatResult>,
        // Calls made so far; each call's index selects from `responses`.
        call_count: AtomicUsize,
        // Message lists received by complete_with_tools /
        // complete_streaming_with_tools, shared so tests can assert on them.
        received_messages: Arc<Mutex<Vec<Vec<ConversationMessage>>>>,
    }
4057
4058 impl StreamingProvider {
4059 fn new(responses: Vec<ChatResult>) -> Self {
4060 Self {
4061 responses,
4062 call_count: AtomicUsize::new(0),
4063 received_messages: Arc::new(Mutex::new(Vec::new())),
4064 }
4065 }
4066 }
4067
4068 #[async_trait]
4069 impl Provider for StreamingProvider {
4070 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
4071 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4072 Ok(self.responses.get(idx).cloned().unwrap_or_default())
4073 }
4074
4075 async fn complete_streaming(
4076 &self,
4077 _prompt: &str,
4078 sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4079 ) -> anyhow::Result<ChatResult> {
4080 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4081 let result = self.responses.get(idx).cloned().unwrap_or_default();
4082 for word in result.output_text.split_whitespace() {
4084 let _ = sender.send(StreamChunk {
4085 delta: format!("{word} "),
4086 done: false,
4087 tool_call_delta: None,
4088 });
4089 }
4090 let _ = sender.send(StreamChunk {
4091 delta: String::new(),
4092 done: true,
4093 tool_call_delta: None,
4094 });
4095 Ok(result)
4096 }
4097
4098 async fn complete_with_tools(
4099 &self,
4100 messages: &[ConversationMessage],
4101 _tools: &[ToolDefinition],
4102 _reasoning: &ReasoningConfig,
4103 ) -> anyhow::Result<ChatResult> {
4104 self.received_messages
4105 .lock()
4106 .expect("lock")
4107 .push(messages.to_vec());
4108 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4109 Ok(self.responses.get(idx).cloned().unwrap_or_default())
4110 }
4111
4112 async fn complete_streaming_with_tools(
4113 &self,
4114 messages: &[ConversationMessage],
4115 _tools: &[ToolDefinition],
4116 _reasoning: &ReasoningConfig,
4117 sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4118 ) -> anyhow::Result<ChatResult> {
4119 self.received_messages
4120 .lock()
4121 .expect("lock")
4122 .push(messages.to_vec());
4123 let idx = self.call_count.fetch_add(1, Ordering::Relaxed);
4124 let result = self.responses.get(idx).cloned().unwrap_or_default();
4125 for ch in result.output_text.chars() {
4127 let _ = sender.send(StreamChunk {
4128 delta: ch.to_string(),
4129 done: false,
4130 tool_call_delta: None,
4131 });
4132 }
4133 let _ = sender.send(StreamChunk {
4134 delta: String::new(),
4135 done: true,
4136 tool_call_delta: None,
4137 });
4138 Ok(result)
4139 }
4140 }
4141
4142 #[tokio::test]
4143 async fn streaming_text_only_sends_chunks() {
4144 let provider = StreamingProvider::new(vec![ChatResult {
4145 output_text: "Hello world".to_string(),
4146 ..Default::default()
4147 }]);
4148 let agent = Agent::new(
4149 AgentConfig {
4150 model_supports_tool_use: false,
4151 ..Default::default()
4152 },
4153 Box::new(provider),
4154 Box::new(TestMemory::default()),
4155 vec![],
4156 );
4157
4158 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4159 let response = agent
4160 .respond_streaming(
4161 UserMessage {
4162 text: "hi".to_string(),
4163 },
4164 &test_ctx(),
4165 tx,
4166 )
4167 .await
4168 .expect("should succeed");
4169
4170 assert_eq!(response.text, "Hello world");
4171
4172 let mut chunks = Vec::new();
4174 while let Ok(chunk) = rx.try_recv() {
4175 chunks.push(chunk);
4176 }
4177 assert!(chunks.len() >= 2, "should have at least text + done chunks");
4178 assert!(chunks.last().unwrap().done, "last chunk should be done");
4179 }
4180
4181 #[tokio::test]
4182 async fn streaming_single_tool_call_round_trip() {
4183 let provider = StreamingProvider::new(vec![
4184 ChatResult {
4185 output_text: "I'll echo that.".to_string(),
4186 tool_calls: vec![ToolUseRequest {
4187 id: "call_1".to_string(),
4188 name: "echo".to_string(),
4189 input: json!({"text": "hello"}),
4190 }],
4191 stop_reason: Some(StopReason::ToolUse),
4192 },
4193 ChatResult {
4194 output_text: "Done echoing.".to_string(),
4195 tool_calls: vec![],
4196 stop_reason: Some(StopReason::EndTurn),
4197 },
4198 ]);
4199
4200 let agent = Agent::new(
4201 AgentConfig::default(),
4202 Box::new(provider),
4203 Box::new(TestMemory::default()),
4204 vec![Box::new(StructuredEchoTool)],
4205 );
4206
4207 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4208 let response = agent
4209 .respond_streaming(
4210 UserMessage {
4211 text: "echo hello".to_string(),
4212 },
4213 &test_ctx(),
4214 tx,
4215 )
4216 .await
4217 .expect("should succeed");
4218
4219 assert_eq!(response.text, "Done echoing.");
4220
4221 let mut chunks = Vec::new();
4222 while let Ok(chunk) = rx.try_recv() {
4223 chunks.push(chunk);
4224 }
4225 assert!(chunks.len() >= 2, "should have streaming chunks");
4227 }
4228
4229 #[tokio::test]
4230 async fn streaming_multi_iteration_tool_calls() {
4231 let provider = StreamingProvider::new(vec![
4232 ChatResult {
4233 output_text: String::new(),
4234 tool_calls: vec![ToolUseRequest {
4235 id: "call_1".to_string(),
4236 name: "echo".to_string(),
4237 input: json!({"text": "first"}),
4238 }],
4239 stop_reason: Some(StopReason::ToolUse),
4240 },
4241 ChatResult {
4242 output_text: String::new(),
4243 tool_calls: vec![ToolUseRequest {
4244 id: "call_2".to_string(),
4245 name: "upper".to_string(),
4246 input: json!({"text": "second"}),
4247 }],
4248 stop_reason: Some(StopReason::ToolUse),
4249 },
4250 ChatResult {
4251 output_text: "All done.".to_string(),
4252 tool_calls: vec![],
4253 stop_reason: Some(StopReason::EndTurn),
4254 },
4255 ]);
4256
4257 let agent = Agent::new(
4258 AgentConfig::default(),
4259 Box::new(provider),
4260 Box::new(TestMemory::default()),
4261 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
4262 );
4263
4264 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4265 let response = agent
4266 .respond_streaming(
4267 UserMessage {
4268 text: "do two things".to_string(),
4269 },
4270 &test_ctx(),
4271 tx,
4272 )
4273 .await
4274 .expect("should succeed");
4275
4276 assert_eq!(response.text, "All done.");
4277
4278 let mut chunks = Vec::new();
4279 while let Ok(chunk) = rx.try_recv() {
4280 chunks.push(chunk);
4281 }
4282 let done_count = chunks.iter().filter(|c| c.done).count();
4284 assert!(
4285 done_count >= 3,
4286 "should have done chunks from 3 provider calls, got {}",
4287 done_count
4288 );
4289 }
4290
4291 #[tokio::test]
4292 async fn streaming_timeout_returns_error() {
4293 struct StreamingSlowProvider;
4294
4295 #[async_trait]
4296 impl Provider for StreamingSlowProvider {
4297 async fn complete(&self, _prompt: &str) -> anyhow::Result<ChatResult> {
4298 sleep(Duration::from_millis(500)).await;
4299 Ok(ChatResult::default())
4300 }
4301
4302 async fn complete_streaming(
4303 &self,
4304 _prompt: &str,
4305 _sender: tokio::sync::mpsc::UnboundedSender<StreamChunk>,
4306 ) -> anyhow::Result<ChatResult> {
4307 sleep(Duration::from_millis(500)).await;
4308 Ok(ChatResult::default())
4309 }
4310 }
4311
4312 let agent = Agent::new(
4313 AgentConfig {
4314 request_timeout_ms: 50,
4315 model_supports_tool_use: false,
4316 ..Default::default()
4317 },
4318 Box::new(StreamingSlowProvider),
4319 Box::new(TestMemory::default()),
4320 vec![],
4321 );
4322
4323 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
4324 let result = agent
4325 .respond_streaming(
4326 UserMessage {
4327 text: "hi".to_string(),
4328 },
4329 &test_ctx(),
4330 tx,
4331 )
4332 .await;
4333
4334 assert!(result.is_err());
4335 match result.unwrap_err() {
4336 AgentError::Timeout { timeout_ms } => assert_eq!(timeout_ms, 50),
4337 other => panic!("expected Timeout, got {other:?}"),
4338 }
4339 }
4340
4341 #[tokio::test]
4342 async fn streaming_no_schema_fallback_sends_chunks() {
4343 let provider = StreamingProvider::new(vec![ChatResult {
4345 output_text: "Fallback response".to_string(),
4346 ..Default::default()
4347 }]);
4348 let agent = Agent::new(
4349 AgentConfig::default(), Box::new(provider),
4351 Box::new(TestMemory::default()),
4352 vec![Box::new(EchoTool)], );
4354
4355 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4356 let response = agent
4357 .respond_streaming(
4358 UserMessage {
4359 text: "hi".to_string(),
4360 },
4361 &test_ctx(),
4362 tx,
4363 )
4364 .await
4365 .expect("should succeed");
4366
4367 assert_eq!(response.text, "Fallback response");
4368
4369 let mut chunks = Vec::new();
4370 while let Ok(chunk) = rx.try_recv() {
4371 chunks.push(chunk);
4372 }
4373 assert!(!chunks.is_empty(), "should have streaming chunks");
4374 assert!(chunks.last().unwrap().done, "last chunk should be done");
4375 }
4376
4377 #[tokio::test]
4378 async fn streaming_tool_error_does_not_abort() {
4379 let provider = StreamingProvider::new(vec![
4380 ChatResult {
4381 output_text: String::new(),
4382 tool_calls: vec![ToolUseRequest {
4383 id: "call_1".to_string(),
4384 name: "boom".to_string(),
4385 input: json!({}),
4386 }],
4387 stop_reason: Some(StopReason::ToolUse),
4388 },
4389 ChatResult {
4390 output_text: "Recovered from error.".to_string(),
4391 tool_calls: vec![],
4392 stop_reason: Some(StopReason::EndTurn),
4393 },
4394 ]);
4395
4396 let agent = Agent::new(
4397 AgentConfig::default(),
4398 Box::new(provider),
4399 Box::new(TestMemory::default()),
4400 vec![Box::new(StructuredFailingTool)],
4401 );
4402
4403 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
4404 let response = agent
4405 .respond_streaming(
4406 UserMessage {
4407 text: "boom".to_string(),
4408 },
4409 &test_ctx(),
4410 tx,
4411 )
4412 .await
4413 .expect("should succeed despite tool error");
4414
4415 assert_eq!(response.text, "Recovered from error.");
4416 }
4417
4418 #[tokio::test]
4419 async fn streaming_parallel_tools() {
4420 let provider = StreamingProvider::new(vec![
4421 ChatResult {
4422 output_text: String::new(),
4423 tool_calls: vec![
4424 ToolUseRequest {
4425 id: "call_1".to_string(),
4426 name: "echo".to_string(),
4427 input: json!({"text": "a"}),
4428 },
4429 ToolUseRequest {
4430 id: "call_2".to_string(),
4431 name: "upper".to_string(),
4432 input: json!({"text": "b"}),
4433 },
4434 ],
4435 stop_reason: Some(StopReason::ToolUse),
4436 },
4437 ChatResult {
4438 output_text: "Parallel done.".to_string(),
4439 tool_calls: vec![],
4440 stop_reason: Some(StopReason::EndTurn),
4441 },
4442 ]);
4443
4444 let agent = Agent::new(
4445 AgentConfig {
4446 parallel_tools: true,
4447 ..Default::default()
4448 },
4449 Box::new(provider),
4450 Box::new(TestMemory::default()),
4451 vec![Box::new(StructuredEchoTool), Box::new(StructuredUpperTool)],
4452 );
4453
4454 let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
4455 let response = agent
4456 .respond_streaming(
4457 UserMessage {
4458 text: "parallel test".to_string(),
4459 },
4460 &test_ctx(),
4461 tx,
4462 )
4463 .await
4464 .expect("should succeed");
4465
4466 assert_eq!(response.text, "Parallel done.");
4467 }
4468
4469 #[tokio::test]
4470 async fn streaming_done_chunk_sentinel() {
4471 let provider = StreamingProvider::new(vec![ChatResult {
4472 output_text: "abc".to_string(),
4473 tool_calls: vec![],
4474 stop_reason: Some(StopReason::EndTurn),
4475 }]);
4476
4477 let agent = Agent::new(
4478 AgentConfig::default(),
4479 Box::new(provider),
4480 Box::new(TestMemory::default()),
4481 vec![Box::new(StructuredEchoTool)],
4482 );
4483
4484 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
4485 let _response = agent
4486 .respond_streaming(
4487 UserMessage {
4488 text: "test".to_string(),
4489 },
4490 &test_ctx(),
4491 tx,
4492 )
4493 .await
4494 .expect("should succeed");
4495
4496 let mut chunks = Vec::new();
4497 while let Ok(chunk) = rx.try_recv() {
4498 chunks.push(chunk);
4499 }
4500
4501 let done_chunks: Vec<_> = chunks.iter().filter(|c| c.done).collect();
4503 assert!(
4504 !done_chunks.is_empty(),
4505 "must have at least one done sentinel chunk"
4506 );
4507 assert!(chunks.last().unwrap().done, "final chunk must be done=true");
4509 let content_chunks: Vec<_> = chunks
4511 .iter()
4512 .filter(|c| !c.done && !c.delta.is_empty())
4513 .collect();
4514 assert!(!content_chunks.is_empty(), "should have content chunks");
4515 }
4516
4517 #[tokio::test]
4520 async fn system_prompt_prepended_in_structured_path() {
4521 let provider = StructuredProvider::new(vec![ChatResult {
4522 output_text: "I understand.".to_string(),
4523 stop_reason: Some(StopReason::EndTurn),
4524 ..Default::default()
4525 }]);
4526 let received = provider.received_messages.clone();
4527 let memory = TestMemory::default();
4528
4529 let agent = Agent::new(
4530 AgentConfig {
4531 model_supports_tool_use: true,
4532 system_prompt: Some("You are a math tutor.".to_string()),
4533 ..AgentConfig::default()
4534 },
4535 Box::new(provider),
4536 Box::new(memory),
4537 vec![Box::new(StructuredEchoTool)],
4538 );
4539
4540 let result = agent
4541 .respond(
4542 UserMessage {
4543 text: "What is 2+2?".to_string(),
4544 },
4545 &test_ctx(),
4546 )
4547 .await
4548 .expect("respond should succeed");
4549
4550 assert_eq!(result.text, "I understand.");
4551
4552 let msgs = received.lock().expect("lock");
4553 assert!(!msgs.is_empty(), "provider should have been called");
4554 let first_call = &msgs[0];
4555 match &first_call[0] {
4557 ConversationMessage::System { content } => {
4558 assert_eq!(content, "You are a math tutor.");
4559 }
4560 other => panic!("expected System message first, got {other:?}"),
4561 }
4562 match &first_call[1] {
4564 ConversationMessage::User { content } => {
4565 assert_eq!(content, "What is 2+2?");
4566 }
4567 other => panic!("expected User message second, got {other:?}"),
4568 }
4569 }
4570
4571 #[tokio::test]
4572 async fn no_system_prompt_omits_system_message() {
4573 let provider = StructuredProvider::new(vec![ChatResult {
4574 output_text: "ok".to_string(),
4575 stop_reason: Some(StopReason::EndTurn),
4576 ..Default::default()
4577 }]);
4578 let received = provider.received_messages.clone();
4579 let memory = TestMemory::default();
4580
4581 let agent = Agent::new(
4582 AgentConfig {
4583 model_supports_tool_use: true,
4584 system_prompt: None,
4585 ..AgentConfig::default()
4586 },
4587 Box::new(provider),
4588 Box::new(memory),
4589 vec![Box::new(StructuredEchoTool)],
4590 );
4591
4592 agent
4593 .respond(
4594 UserMessage {
4595 text: "hello".to_string(),
4596 },
4597 &test_ctx(),
4598 )
4599 .await
4600 .expect("respond should succeed");
4601
4602 let msgs = received.lock().expect("lock");
4603 let first_call = &msgs[0];
4604 match &first_call[0] {
4606 ConversationMessage::User { content } => {
4607 assert_eq!(content, "hello");
4608 }
4609 other => panic!("expected User message first (no system prompt), got {other:?}"),
4610 }
4611 }
4612
4613 #[tokio::test]
4614 async fn system_prompt_persists_across_tool_iterations() {
4615 let provider = StructuredProvider::new(vec![
4617 ChatResult {
4618 output_text: String::new(),
4619 tool_calls: vec![ToolUseRequest {
4620 id: "call_1".to_string(),
4621 name: "echo".to_string(),
4622 input: serde_json::json!({"text": "ping"}),
4623 }],
4624 stop_reason: Some(StopReason::ToolUse),
4625 },
4626 ChatResult {
4627 output_text: "pong".to_string(),
4628 stop_reason: Some(StopReason::EndTurn),
4629 ..Default::default()
4630 },
4631 ]);
4632 let received = provider.received_messages.clone();
4633 let memory = TestMemory::default();
4634
4635 let agent = Agent::new(
4636 AgentConfig {
4637 model_supports_tool_use: true,
4638 system_prompt: Some("Always be concise.".to_string()),
4639 ..AgentConfig::default()
4640 },
4641 Box::new(provider),
4642 Box::new(memory),
4643 vec![Box::new(StructuredEchoTool)],
4644 );
4645
4646 let result = agent
4647 .respond(
4648 UserMessage {
4649 text: "test".to_string(),
4650 },
4651 &test_ctx(),
4652 )
4653 .await
4654 .expect("respond should succeed");
4655 assert_eq!(result.text, "pong");
4656
4657 let msgs = received.lock().expect("lock");
4658 for (i, call_msgs) in msgs.iter().enumerate() {
4660 match &call_msgs[0] {
4661 ConversationMessage::System { content } => {
4662 assert_eq!(content, "Always be concise.", "call {i} system prompt");
4663 }
4664 other => panic!("call {i}: expected System first, got {other:?}"),
4665 }
4666 }
4667 }
4668
4669 #[test]
4670 fn agent_config_system_prompt_defaults_to_none() {
4671 let config = AgentConfig::default();
4672 assert!(config.system_prompt.is_none());
4673 }
4674
4675 #[tokio::test]
4676 async fn memory_write_propagates_source_channel_from_context() {
4677 let memory = TestMemory::default();
4678 let entries = memory.entries.clone();
4679 let provider = StructuredProvider::new(vec![ChatResult {
4680 output_text: "noted".to_string(),
4681 tool_calls: vec![],
4682 stop_reason: None,
4683 }]);
4684 let config = AgentConfig {
4685 privacy_boundary: "encrypted_only".to_string(),
4686 ..AgentConfig::default()
4687 };
4688 let agent = Agent::new(config, Box::new(provider), Box::new(memory), vec![]);
4689
4690 let mut ctx = ToolContext::new(".".to_string());
4691 ctx.source_channel = Some("telegram".to_string());
4692 ctx.privacy_boundary = "encrypted_only".to_string();
4693
4694 agent
4695 .respond(
4696 UserMessage {
4697 text: "hello".to_string(),
4698 },
4699 &ctx,
4700 )
4701 .await
4702 .expect("respond should succeed");
4703
4704 let stored = entries.lock().expect("memory lock poisoned");
4705 assert_eq!(stored[0].source_channel.as_deref(), Some("telegram"));
4707 assert_eq!(stored[0].privacy_boundary, "encrypted_only");
4708 assert_eq!(stored[1].source_channel.as_deref(), Some("telegram"));
4709 assert_eq!(stored[1].privacy_boundary, "encrypted_only");
4710 }
4711
4712 #[tokio::test]
4713 async fn memory_write_source_channel_none_when_ctx_empty() {
4714 let memory = TestMemory::default();
4715 let entries = memory.entries.clone();
4716 let provider = StructuredProvider::new(vec![ChatResult {
4717 output_text: "ok".to_string(),
4718 tool_calls: vec![],
4719 stop_reason: None,
4720 }]);
4721 let agent = Agent::new(
4722 AgentConfig::default(),
4723 Box::new(provider),
4724 Box::new(memory),
4725 vec![],
4726 );
4727
4728 agent
4729 .respond(
4730 UserMessage {
4731 text: "hi".to_string(),
4732 },
4733 &test_ctx(),
4734 )
4735 .await
4736 .expect("respond should succeed");
4737
4738 let stored = entries.lock().expect("memory lock poisoned");
4739 assert!(stored[0].source_channel.is_none());
4740 assert!(stored[1].source_channel.is_none());
4741 }
4742}