1use std::time::Instant;
4
5use futures::stream::{self, StreamExt};
6use serde_json::Value;
7
8use crate::events::AgentEvent;
9use crate::permission::{Authorization, AuthorizationResponse};
10use crate::tool::{box_tool, ToolResult};
11use crate::types::{Message, ToolResultBlock, ToolResultStatus, ToolUseBlock};
12
13use super::types::{AgentError, ToolCallInfo, ToolInfo};
14use super::Agent;
15
16#[cfg(feature = "session")]
17use crate::session::ToolCall;
18
19impl Agent {
20 pub fn add_tool<T: crate::tool::Tool + 'static>(&mut self, tool: T)
22 where
23 T::Input: serde::Serialize,
24 {
25 let tool_name = tool.name().to_string();
26
27 if self.tools.iter().any(|t| t.name() == tool_name) {
29 eprintln!(
30 "Warning: Tool '{}' is already registered. This will cause errors when calling the model.",
31 tool_name
32 );
33 eprintln!(" Consider using .with_namespace() on MCP servers to avoid conflicts.");
34 }
35
36 self.tools.push(box_tool(tool));
37 }
38
39 pub fn list_tools(&self) -> Vec<ToolInfo> {
41 self.tools
42 .iter()
43 .map(|t| ToolInfo {
44 name: t.name().to_string(),
45 description: t.description().to_string(),
46 })
47 .collect()
48 }
49
50 pub fn format_tool_input(
55 &self,
56 tool_name: &str,
57 params: &Value,
58 context: crate::presentation::Display,
59 ) -> Option<String> {
60 let tool = self.tools.iter().find(|t| t.name() == tool_name)?;
61
62 Some(match context {
63 crate::presentation::Display::Cli => tool.format_input_ansi(params),
64 })
65 }
66
67 pub fn format_tool_output(
71 &self,
72 tool_name: &str,
73 result: &crate::tool::ToolResult,
74 context: crate::presentation::Display,
75 ) -> Option<String> {
76 let tool = self.tools.iter().find(|t| t.name() == tool_name)?;
77
78 Some(match context {
79 crate::presentation::Display::Cli => tool.format_output_ansi(result),
80 })
81 }
82
83 pub(super) async fn execute_tool(
85 &self,
86 tool_use: &ToolUseBlock,
87 ) -> Result<ToolResult, AgentError> {
88 let tool_start = Instant::now();
89 let tool_id = tool_use.id.clone();
90 let tool_name = tool_use.name.clone();
91 let input = tool_use.input.clone();
92
93 self.emit_event(AgentEvent::ToolRequested {
95 tool_use_id: tool_id.clone(),
96 name: tool_name.clone(),
97 input: input.clone(),
98 });
99
100 if !input.is_object() {
102 let type_name = match &input {
103 Value::Null => "null",
104 Value::Bool(_) => "boolean",
105 Value::Number(_) => "number",
106 Value::String(_) => "string",
107 Value::Array(_) => "array",
108 Value::Object(_) => "object", };
110 let error_msg = format!("Tool input must be a JSON object, got: {}", type_name);
111 self.emit_event(AgentEvent::ToolFailed {
112 tool_use_id: tool_id,
113 name: tool_name,
114 error: error_msg.clone(),
115 duration: tool_start.elapsed(),
116 });
117 return Err(AgentError::InvalidToolInput(error_msg));
118 }
119
120 let tool = self
121 .tools
122 .iter()
123 .find(|t| t.name() == tool_use.name)
124 .ok_or_else(|| {
125 self.emit_event(AgentEvent::ToolFailed {
126 tool_use_id: tool_id.clone(),
127 name: tool_name.clone(),
128 error: format!("Tool not found: {}", tool_name),
129 duration: tool_start.elapsed(),
130 });
131 AgentError::ToolNotFound(tool_name.clone())
132 })?;
133
134 self.check_tool_approval(&tool_id, &tool_name, &input, tool_start)
136 .await?;
137
138 self.emit_event(AgentEvent::ToolExecuting {
140 tool_use_id: tool_id.clone(),
141 name: tool_name.clone(),
142 });
143
144 match tool.execute_raw(input).await {
146 Ok(result) => {
147 self.emit_event(AgentEvent::ToolCompleted {
148 tool_use_id: tool_id,
149 name: tool_name,
150 output: result.clone(),
151 duration: tool_start.elapsed(),
152 });
153 Ok(result)
154 }
155 Err(e) => {
156 let error_msg = e.to_string();
157 self.emit_event(AgentEvent::ToolFailed {
158 tool_use_id: tool_id,
159 name: tool_name,
160 error: error_msg,
161 duration: tool_start.elapsed(),
162 });
163 Err(AgentError::Tool(e))
164 }
165 }
166 }
167
168 async fn check_tool_approval(
170 &self,
171 tool_id: &str,
172 tool_name: &str,
173 input: &Value,
174 tool_start: Instant,
175 ) -> Result<(), AgentError> {
176 let authorizer = self.authorizer.read().await;
177
178 match authorizer.check(tool_name, input).await {
179 Authorization::Granted { grant } => {
180 self.emit_event(AgentEvent::PermissionGranted {
181 tool_use_id: tool_id.to_string(),
182 tool_name: tool_name.to_string(),
183 scope: Some(grant.scope),
184 });
185 Ok(())
186 }
187 Authorization::Denied { reason } => {
188 self.emit_event(AgentEvent::PermissionDenied {
189 tool_use_id: tool_id.to_string(),
190 tool_name: tool_name.to_string(),
191 reason: reason.clone(),
192 });
193 self.emit_event(AgentEvent::ToolFailed {
194 tool_use_id: tool_id.to_string(),
195 name: tool_name.to_string(),
196 error: reason,
197 duration: tool_start.elapsed(),
198 });
199 Err(AgentError::ToolDenied(tool_name.to_string()))
200 }
201 Authorization::PendingApproval { params_hash } => {
202 drop(authorizer);
204
205 self.request_authorization(tool_id, tool_name, input, params_hash, tool_start)
206 .await
207 }
208 }
209 }
210
211 async fn request_authorization(
213 &self,
214 tool_id: &str,
215 tool_name: &str,
216 input: &Value,
217 params_hash: String,
218 tool_start: Instant,
219 ) -> Result<(), AgentError> {
220 let (tx, mut rx) = tokio::sync::mpsc::channel::<AuthorizationResponse>(1);
221
222 let proposal_id = tool_id.to_string();
224
225 {
227 let mut pending = self.pending_authorizations.write().await;
228 pending.insert(proposal_id.clone(), tx);
229 }
230
231 self.emit_event(AgentEvent::PermissionRequired {
233 proposal_id: proposal_id.clone(),
234 tool_name: tool_name.to_string(),
235 params: input.clone(),
236 params_hash: params_hash.clone(),
237 });
238
239 let response = match tokio::time::timeout(self.authorization_timeout, rx.recv()).await {
241 Ok(Some(response)) => response,
242 Ok(None) => AuthorizationResponse::Deny {
243 reason: Some("Channel closed".to_string()),
244 },
245 Err(_) => {
246 self.emit_event(AgentEvent::PermissionDenied {
247 tool_use_id: tool_id.to_string(),
248 tool_name: tool_name.to_string(),
249 reason: "Authorization request timed out".to_string(),
250 });
251 AuthorizationResponse::Deny {
252 reason: Some("Timeout".to_string()),
253 }
254 }
255 };
256
257 {
259 let mut pending = self.pending_authorizations.write().await;
260 pending.remove(&proposal_id);
261 }
262
263 match response {
264 AuthorizationResponse::Once => {
265 self.emit_event(AgentEvent::PermissionGranted {
266 tool_use_id: tool_id.to_string(),
267 tool_name: tool_name.to_string(),
268 scope: None,
269 });
270 Ok(())
271 }
272 AuthorizationResponse::Trust { grant } => {
273 let authorizer = self.authorizer.read().await;
275 let result = if grant.is_tool_wide() {
276 authorizer.grant_tool(&grant.tool).await
277 } else if let Some(ref hash) = grant.params_hash {
278 authorizer.grant_params_hash(&grant.tool, hash).await
279 } else {
280 authorizer.grant_tool(&grant.tool).await
281 };
282 if let Err(e) = result {
283 eprintln!("Warning: Failed to save grant: {}", e);
284 }
285 self.emit_event(AgentEvent::PermissionGranted {
286 tool_use_id: tool_id.to_string(),
287 tool_name: tool_name.to_string(),
288 scope: Some(grant.scope),
289 });
290 Ok(())
291 }
292 AuthorizationResponse::Deny { reason } => {
293 let reason_str =
294 reason.unwrap_or_else(|| "Authorization denied by user".to_string());
295 self.emit_event(AgentEvent::PermissionDenied {
296 tool_use_id: tool_id.to_string(),
297 tool_name: tool_name.to_string(),
298 reason: reason_str,
299 });
300 self.emit_event(AgentEvent::ToolFailed {
301 tool_use_id: tool_id.to_string(),
302 name: tool_name.to_string(),
303 error: "Tool execution denied by user".to_string(),
304 duration: tool_start.elapsed(),
305 });
306 Err(AgentError::ToolDenied(tool_name.to_string()))
307 }
308 }
309 }
310
311 pub(super) async fn process_tool_calls(
316 &self,
317 message: &Message,
318 tool_call_infos: &mut Vec<ToolCallInfo>,
319 #[cfg(feature = "session")] session_tool_calls: &mut Vec<ToolCall>,
320 #[cfg(feature = "session")] session_tool_results: &mut Vec<crate::session::ToolResult>,
321 ) -> Vec<ToolResultBlock> {
322 let tool_uses = message.tool_uses();
323 let tool_use_blocks: Vec<_> = tool_uses.into_iter().cloned().collect();
324
325 let futures: Vec<_> = tool_use_blocks
327 .iter()
328 .map(|tool_use| {
329 let tool_use = tool_use.clone();
330 async move {
331 let start = Instant::now();
332 let result = self.execute_tool(&tool_use).await;
333 let duration = start.elapsed();
334 (tool_use, result, duration)
335 }
336 })
337 .collect();
338
339 let results: Vec<_> = stream::iter(futures)
340 .buffer_unordered(self.max_concurrent_tools)
341 .collect()
342 .await;
343
344 results
345 .into_iter()
346 .map(|(tool_use, result, duration)| {
347 #[cfg(feature = "session")]
349 {
350 session_tool_calls.push(ToolCall {
351 id: tool_use.id.clone(),
352 name: tool_use.name.clone(),
353 input: tool_use.input.to_string(),
354 });
355 }
356
357 match result {
358 Ok(ref tool_result) => {
359 tool_call_infos.push(ToolCallInfo {
361 name: tool_use.name.clone(),
362 input: tool_use.input.clone(),
363 output: tool_result.as_text(),
364 success: true,
365 duration,
366 });
367
368 #[cfg(feature = "session")]
370 {
371 session_tool_results.push(crate::session::ToolResult {
372 tool_use_id: tool_use.id.clone(),
373 success: true,
374 content: tool_result.as_text(),
375 });
376 }
377
378 ToolResultBlock {
379 tool_use_id: tool_use.id,
380 content: tool_result.clone(),
381 status: ToolResultStatus::Success,
382 }
383 }
384 Err(ref e) => {
385 let error_msg = format!("Error: {}", e);
386
387 tool_call_infos.push(ToolCallInfo {
389 name: tool_use.name.clone(),
390 input: tool_use.input.clone(),
391 output: error_msg.clone(),
392 success: false,
393 duration,
394 });
395
396 #[cfg(feature = "session")]
398 {
399 session_tool_results.push(crate::session::ToolResult {
400 tool_use_id: tool_use.id.clone(),
401 success: false,
402 content: error_msg.clone(),
403 });
404 }
405
406 ToolResultBlock {
407 tool_use_id: tool_use.id,
408 content: ToolResult::Text(error_msg),
409 status: ToolResultStatus::Error,
410 }
411 }
412 }
413 })
414 .collect()
415 }
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ModelProvider, ProviderError};
    use crate::tool::{Tool, ToolError, ToolResult as MxToolResult};
    use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition, ToolUseBlock};
    use crate::{Agent, ModelResponse};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;

    /// Provider that replays canned responses in FIFO order and errors once they run out.
    #[derive(Clone)]
    struct MockProvider {
        responses: Arc<parking_lot::Mutex<Vec<ModelResponse>>>,
    }

    impl MockProvider {
        fn new() -> Self {
            Self {
                responses: Arc::new(parking_lot::Mutex::new(Vec::new())),
            }
        }

        /// Queues a plain-text assistant reply that ends the turn.
        fn with_text(self, text: impl Into<String>) -> Self {
            let message = Message {
                role: Role::Assistant,
                content: vec![ContentBlock::Text(text.into())],
            };
            let response = ModelResponse {
                message,
                stop_reason: StopReason::EndTurn,
                usage: None,
            };
            self.responses.lock().push(response);
            self
        }
    }

    #[async_trait::async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "MockProvider"
        }

        fn max_context_tokens(&self) -> usize {
            200_000
        }

        fn max_output_tokens(&self) -> usize {
            8_192
        }

        async fn generate(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _system_prompt: Option<String>,
        ) -> Result<ModelResponse, ProviderError> {
            let mut responses = self.responses.lock();
            if responses.is_empty() {
                return Err(ProviderError::Other("No more responses".to_string()));
            }
            Ok(responses.remove(0))
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoInput {
        message: String,
    }

    /// Returns its `message` input unchanged.
    struct EchoTool;

    impl Tool for EchoTool {
        type Input = EchoInput;

        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the input back"
        }

        async fn execute(&self, input: Self::Input) -> Result<MxToolResult, ToolError> {
            Ok(MxToolResult::text(input.message))
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct AddInput {
        a: f64,
        b: f64,
    }

    /// Returns `a + b` as text.
    struct AddTool;

    impl Tool for AddTool {
        type Input = AddInput;

        fn name(&self) -> &str {
            "add"
        }

        fn description(&self) -> &str {
            "Adds two numbers"
        }

        async fn execute(&self, input: Self::Input) -> Result<MxToolResult, ToolError> {
            Ok(MxToolResult::text(format!("{}", input.a + input.b)))
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EmptyInput {}

    /// Always returns a `ToolError::Custom`.
    struct FailingTool;

    impl Tool for FailingTool {
        type Input = EmptyInput;

        fn name(&self) -> &str {
            "failing_tool"
        }

        fn description(&self) -> &str {
            "A tool that always fails"
        }

        async fn execute(&self, _input: Self::Input) -> Result<MxToolResult, ToolError> {
            Err(ToolError::Custom("Tool execution failed".to_string()))
        }
    }

    /// Builds an agent with no tools, backed by a provider holding one text reply.
    async fn new_agent() -> Agent {
        let provider = MockProvider::new().with_text("ok");
        Agent::builder().provider(provider).build().await.unwrap()
    }

    /// Asserts that `result` is specifically an `InvalidToolInput` error whose message
    /// mentions `type_name`. Any other outcome fails the test instead of passing silently.
    fn assert_invalid_input(result: Result<MxToolResult, AgentError>, type_name: &str) {
        match result {
            Err(AgentError::InvalidToolInput(msg)) => assert!(
                msg.contains(type_name),
                "message {msg:?} should mention {type_name:?}"
            ),
            Err(other) => panic!("expected InvalidToolInput, got error: {}", other),
            Ok(_) => panic!("expected InvalidToolInput, got Ok"),
        }
    }

    #[tokio::test]
    async fn test_add_tool() {
        let mut agent = new_agent().await;

        assert_eq!(agent.list_tools().len(), 0);

        agent.add_tool(EchoTool);

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[0].description, "Echoes the input back");
    }

    #[tokio::test]
    async fn test_add_multiple_tools() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);
        agent.add_tool(AddTool);

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);

        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"echo"));
        assert!(names.contains(&"add"));
    }

    #[tokio::test]
    async fn test_add_duplicate_tool_is_still_registered() {
        let mut agent = new_agent().await;

        // A duplicate only triggers a warning; it is not rejected.
        agent.add_tool(EchoTool);
        agent.add_tool(EchoTool);

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().all(|t| t.name == "echo"));
    }

    #[tokio::test]
    async fn test_add_tool_with_builder() {
        let provider = MockProvider::new().with_text("ok");
        let agent = Agent::builder()
            .provider(provider)
            .add_tool(EchoTool)
            .add_tool(AddTool)
            .build()
            .await
            .unwrap();

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_list_tools_empty() {
        let agent = new_agent().await;

        let tools = agent.list_tools();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_list_tools_preserves_order() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);
        agent.add_tool(AddTool);
        agent.add_tool(FailingTool);

        let tools = agent.list_tools();
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[1].name, "add");
        assert_eq!(tools[2].name, "failing_tool");
    }

    #[tokio::test]
    async fn test_execute_tool_success() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);

        agent
            .authorizer()
            .write()
            .await
            .grant_tool("echo")
            .await
            .unwrap();

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::json!({"message": "Hello, world!"}),
        };

        let result = agent.execute_tool(&tool_use).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_text(), "Hello, world!");
    }

    #[tokio::test]
    async fn test_execute_tool_not_found() {
        let agent = new_agent().await;

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "nonexistent_tool".to_string(),
            input: serde_json::json!({}),
        };

        let result = agent.execute_tool(&tool_use).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AgentError::ToolNotFound(_)));
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_input_not_object() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::json!("not an object"),
        };

        assert_invalid_input(agent.execute_tool(&tool_use).await, "string");
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_input_array() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::json!([1, 2, 3]),
        };

        assert_invalid_input(agent.execute_tool(&tool_use).await, "array");
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_input_null() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::Value::Null,
        };

        assert_invalid_input(agent.execute_tool(&tool_use).await, "null");
    }

    #[tokio::test]
    async fn test_execute_tool_execution_failure() {
        let mut agent = new_agent().await;

        agent.add_tool(FailingTool);

        agent
            .authorizer()
            .write()
            .await
            .grant_tool("failing_tool")
            .await
            .unwrap();

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "failing_tool".to_string(),
            input: serde_json::json!({}),
        };

        let result = agent.execute_tool(&tool_use).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AgentError::Tool(_)));
    }

    #[tokio::test]
    async fn test_format_tool_input_existing_tool() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);

        let params = serde_json::json!({"message": "test"});
        let formatted = agent.format_tool_input("echo", &params, crate::presentation::Display::Cli);

        assert!(formatted.is_some());
    }

    #[tokio::test]
    async fn test_format_tool_input_nonexistent_tool() {
        let agent = new_agent().await;

        let params = serde_json::json!({"message": "test"});
        let formatted =
            agent.format_tool_input("nonexistent", &params, crate::presentation::Display::Cli);

        assert!(formatted.is_none());
    }

    #[tokio::test]
    async fn test_format_tool_output_existing_tool() {
        let mut agent = new_agent().await;

        agent.add_tool(EchoTool);

        let result = crate::tool::ToolResult::text("output");
        let formatted =
            agent.format_tool_output("echo", &result, crate::presentation::Display::Cli);

        assert!(formatted.is_some());
    }

    #[tokio::test]
    async fn test_format_tool_output_nonexistent_tool() {
        let agent = new_agent().await;

        let result = crate::tool::ToolResult::text("output");
        let formatted =
            agent.format_tool_output("nonexistent", &result, crate::presentation::Display::Cli);

        assert!(formatted.is_none());
    }
}
829}