1use std::time::Instant;
4
5use futures::stream::{self, StreamExt};
6use serde_json::Value;
7
8use crate::events::{AgentEvent, ToolApprovalStatus};
9use crate::permission::{Authorization, AuthorizationResponse};
10use crate::tool::{box_tool, ToolResult};
11use crate::types::{Message, ToolResultBlock, ToolResultStatus, ToolUseBlock};
12
13use super::types::{AgentError, ToolCallInfo, ToolInfo};
14use super::Agent;
15
16#[cfg(feature = "session")]
17use crate::session::ToolCall;
18
19impl Agent {
20 pub fn add_tool<T: crate::tool::Tool + 'static>(&mut self, tool: T)
22 where
23 T::Input: serde::Serialize,
24 {
25 let tool_name = tool.name().to_string();
26
27 if self.tools.iter().any(|t| t.name() == tool_name) {
29 eprintln!(
30 "Warning: Tool '{}' is already registered. This will cause errors when calling the model.",
31 tool_name
32 );
33 eprintln!(" Consider using .with_namespace() on MCP servers to avoid conflicts.");
34 }
35
36 self.tools.push(box_tool(tool));
37 }
38
39 pub fn list_tools(&self) -> Vec<ToolInfo> {
41 self.tools
42 .iter()
43 .map(|t| ToolInfo {
44 name: t.name().to_string(),
45 description: t.description().to_string(),
46 })
47 .collect()
48 }
49
50 pub fn format_tool_input(
55 &self,
56 tool_name: &str,
57 params: &Value,
58 context: crate::presentation::Display,
59 ) -> Option<String> {
60 let tool = self.tools.iter().find(|t| t.name() == tool_name)?;
61
62 Some(match context {
63 crate::presentation::Display::Cli => tool.format_input_ansi(params),
64 })
65 }
66
67 pub fn format_tool_output(
71 &self,
72 tool_name: &str,
73 result: &crate::tool::ToolResult,
74 context: crate::presentation::Display,
75 ) -> Option<String> {
76 let tool = self.tools.iter().find(|t| t.name() == tool_name)?;
77
78 Some(match context {
79 crate::presentation::Display::Cli => tool.format_output_ansi(result),
80 })
81 }
82
83 pub(super) async fn execute_tool(
85 &self,
86 tool_use: &ToolUseBlock,
87 ) -> Result<ToolResult, AgentError> {
88 let tool_start = Instant::now();
89 let tool_id = tool_use.id.clone();
90 let tool_name = tool_use.name.clone();
91 let input = tool_use.input.clone();
92
93 if !input.is_object() {
95 let type_name = match &input {
96 Value::Null => "null",
97 Value::Bool(_) => "boolean",
98 Value::Number(_) => "number",
99 Value::String(_) => "string",
100 Value::Array(_) => "array",
101 Value::Object(_) => "object", };
103 let error_msg = format!("Tool input must be a JSON object, got: {}", type_name);
104 self.emit_event(AgentEvent::ToolFailed {
105 id: tool_id,
106 name: tool_name,
107 error: error_msg.clone(),
108 duration: tool_start.elapsed(),
109 });
110 return Err(AgentError::InvalidToolInput(error_msg));
111 }
112
113 let tool = self
114 .tools
115 .iter()
116 .find(|t| t.name() == tool_use.name)
117 .ok_or_else(|| {
118 self.emit_event(AgentEvent::ToolFailed {
120 id: tool_id.clone(),
121 name: tool_name.clone(),
122 error: format!("Tool not found: {}", tool_name),
123 duration: tool_start.elapsed(),
124 });
125 AgentError::ToolNotFound(tool_name.clone())
126 })?;
127
128 let approval_status = self
130 .check_tool_approval(&tool_id, &tool_name, &input, tool.as_ref(), tool_start)
131 .await?;
132
133 self.emit_event(AgentEvent::ToolStarted {
135 id: tool_id.clone(),
136 name: tool_name.clone(),
137 input: input.clone(),
138 approval_status,
139 timestamp: tool_start,
140 });
141
142 match tool.execute_raw(input).await {
144 Ok(result) => {
145 self.emit_event(AgentEvent::ToolCompleted {
147 id: tool_id,
148 name: tool_name,
149 output: result.clone(),
150 approval_status,
151 duration: tool_start.elapsed(),
152 });
153 Ok(result)
154 }
155 Err(e) => {
156 let error_msg = e.to_string();
158 self.emit_event(AgentEvent::ToolFailed {
159 id: tool_id,
160 name: tool_name,
161 error: error_msg,
162 duration: tool_start.elapsed(),
163 });
164 Err(AgentError::Tool(e))
165 }
166 }
167 }
168
169 async fn check_tool_approval(
171 &self,
172 tool_id: &str,
173 tool_name: &str,
174 input: &Value,
175 _tool: &dyn crate::tool::DynTool,
176 tool_start: Instant,
177 ) -> Result<ToolApprovalStatus, AgentError> {
178 let authorizer = self.authorizer.read().await;
179
180 match authorizer.check(tool_name, input).await {
181 Authorization::Granted { grant } => {
182 self.emit_event(AgentEvent::PermissionGranted {
184 proposal_id: format!("{}_{}", tool_name, tool_id),
185 scope: Some(grant.scope),
186 });
187 Ok(ToolApprovalStatus::AutoApproved)
188 }
189 Authorization::Denied { reason } => {
190 let proposal_id = format!("{}_{}", tool_name, tool_id);
192 self.emit_event(AgentEvent::PermissionDenied {
193 proposal_id,
194 reason: reason.clone(),
195 });
196 self.emit_event(AgentEvent::ToolFailed {
197 id: tool_id.to_string(),
198 name: tool_name.to_string(),
199 error: reason,
200 duration: tool_start.elapsed(),
201 });
202 Err(AgentError::ToolDenied(tool_name.to_string()))
203 }
204 Authorization::PendingApproval { params_hash } => {
205 drop(authorizer);
207
208 let proposal_id = format!("{}_{}", tool_name, tool_id);
210 self.request_authorization(
211 proposal_id,
212 tool_id,
213 tool_name,
214 input,
215 params_hash,
216 tool_start,
217 )
218 .await
219 }
220 }
221 }
222
223 async fn request_authorization(
225 &self,
226 proposal_id: String,
227 tool_id: &str,
228 tool_name: &str,
229 input: &Value,
230 params_hash: String,
231 tool_start: Instant,
232 ) -> Result<ToolApprovalStatus, AgentError> {
233 let (tx, mut rx) = tokio::sync::mpsc::channel::<AuthorizationResponse>(1);
235
236 {
238 let mut pending = self.pending_authorizations.write().await;
239 pending.insert(proposal_id.clone(), tx);
240 }
241
242 self.emit_event(AgentEvent::PermissionRequired {
244 proposal_id: proposal_id.clone(),
245 tool_name: tool_name.to_string(),
246 params: input.clone(),
247 params_hash: params_hash.clone(),
248 });
249
250 let response = match tokio::time::timeout(self.authorization_timeout, rx.recv()).await {
252 Ok(Some(response)) => response,
253 Ok(None) => AuthorizationResponse::Deny {
254 reason: Some("Channel closed".to_string()),
255 },
256 Err(_) => {
257 self.emit_event(AgentEvent::PermissionDenied {
258 proposal_id: proposal_id.clone(),
259 reason: "Authorization request timed out".to_string(),
260 });
261 AuthorizationResponse::Deny {
262 reason: Some("Timeout".to_string()),
263 }
264 }
265 };
266
267 {
269 let mut pending = self.pending_authorizations.write().await;
270 pending.remove(&proposal_id);
271 }
272
273 match response {
274 AuthorizationResponse::Once => {
275 self.emit_event(AgentEvent::PermissionGranted {
276 proposal_id,
277 scope: None,
278 });
279 Ok(ToolApprovalStatus::UserApproved)
280 }
281 AuthorizationResponse::Trust { grant } => {
282 let authorizer = self.authorizer.read().await;
284 let result = if grant.is_tool_wide() {
285 authorizer.grant_tool(&grant.tool).await
286 } else if let Some(ref hash) = grant.params_hash {
287 authorizer.grant_params_hash(&grant.tool, hash).await
288 } else {
289 authorizer.grant_tool(&grant.tool).await
290 };
291 if let Err(e) = result {
292 eprintln!("Warning: Failed to save grant: {}", e);
293 }
294 self.emit_event(AgentEvent::PermissionGranted {
295 proposal_id,
296 scope: Some(grant.scope),
297 });
298 Ok(ToolApprovalStatus::UserApproved)
299 }
300 AuthorizationResponse::Deny { reason } => {
301 let reason_str =
302 reason.unwrap_or_else(|| "Authorization denied by user".to_string());
303 self.emit_event(AgentEvent::PermissionDenied {
304 proposal_id,
305 reason: reason_str,
306 });
307 self.emit_event(AgentEvent::ToolFailed {
308 id: tool_id.to_string(),
309 name: tool_name.to_string(),
310 error: "Tool execution denied by user".to_string(),
311 duration: tool_start.elapsed(),
312 });
313 Err(AgentError::ToolDenied(tool_name.to_string()))
314 }
315 }
316 }
317
318 pub(super) async fn process_tool_calls(
323 &self,
324 message: &Message,
325 tool_call_infos: &mut Vec<ToolCallInfo>,
326 #[cfg(feature = "session")] session_tool_calls: &mut Vec<ToolCall>,
327 #[cfg(feature = "session")] session_tool_results: &mut Vec<crate::session::ToolResult>,
328 ) -> Vec<ToolResultBlock> {
329 let tool_uses = message.tool_uses();
330 let tool_use_blocks: Vec<_> = tool_uses.into_iter().cloned().collect();
331
332 let futures: Vec<_> = tool_use_blocks
334 .iter()
335 .map(|tool_use| {
336 let tool_use = tool_use.clone();
337 async move {
338 let start = Instant::now();
339 let result = self.execute_tool(&tool_use).await;
340 let duration = start.elapsed();
341 (tool_use, result, duration)
342 }
343 })
344 .collect();
345
346 let results: Vec<_> = stream::iter(futures)
347 .buffer_unordered(self.max_concurrent_tools)
348 .collect()
349 .await;
350
351 results
352 .into_iter()
353 .map(|(tool_use, result, duration)| {
354 #[cfg(feature = "session")]
356 {
357 session_tool_calls.push(ToolCall {
358 id: tool_use.id.clone(),
359 name: tool_use.name.clone(),
360 input: tool_use.input.to_string(),
361 });
362 }
363
364 match result {
365 Ok(ref tool_result) => {
366 tool_call_infos.push(ToolCallInfo {
368 name: tool_use.name.clone(),
369 input: tool_use.input.clone(),
370 output: tool_result.as_text(),
371 success: true,
372 duration,
373 });
374
375 #[cfg(feature = "session")]
377 {
378 session_tool_results.push(crate::session::ToolResult {
379 tool_use_id: tool_use.id.clone(),
380 success: true,
381 content: tool_result.as_text(),
382 });
383 }
384
385 ToolResultBlock {
386 tool_use_id: tool_use.id,
387 content: tool_result.clone(),
388 status: ToolResultStatus::Success,
389 }
390 }
391 Err(ref e) => {
392 let error_msg = format!("Error: {}", e);
393
394 tool_call_infos.push(ToolCallInfo {
396 name: tool_use.name.clone(),
397 input: tool_use.input.clone(),
398 output: error_msg.clone(),
399 success: false,
400 duration,
401 });
402
403 #[cfg(feature = "session")]
405 {
406 session_tool_results.push(crate::session::ToolResult {
407 tool_use_id: tool_use.id.clone(),
408 success: false,
409 content: error_msg.clone(),
410 });
411 }
412
413 ToolResultBlock {
414 tool_use_id: tool_use.id,
415 content: ToolResult::Text(error_msg),
416 status: ToolResultStatus::Error,
417 }
418 }
419 }
420 })
421 .collect()
422 }
423}
424
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::{ModelProvider, ProviderError};
    use crate::tool::{Tool, ToolError, ToolResult as MxToolResult};
    use crate::types::{ContentBlock, Message, Role, StopReason, ToolDefinition, ToolUseBlock};
    use crate::{Agent, ModelResponse};
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;

    /// Model provider that replays queued responses in FIFO order and returns
    /// an error once the queue is exhausted.
    #[derive(Clone)]
    struct MockProvider {
        responses: Arc<parking_lot::Mutex<Vec<ModelResponse>>>,
    }

    impl MockProvider {
        fn new() -> Self {
            Self {
                responses: Arc::new(parking_lot::Mutex::new(Vec::new())),
            }
        }

        /// Queues a plain-text assistant reply that ends the turn.
        fn with_text(self, text: impl Into<String>) -> Self {
            let message = Message {
                role: Role::Assistant,
                content: vec![ContentBlock::Text(text.into())],
            };
            let response = ModelResponse {
                message,
                stop_reason: StopReason::EndTurn,
                usage: None,
            };
            self.responses.lock().push(response);
            self
        }
    }

    #[async_trait::async_trait]
    impl ModelProvider for MockProvider {
        fn name(&self) -> &str {
            "MockProvider"
        }

        fn max_context_tokens(&self) -> usize {
            200_000
        }

        fn max_output_tokens(&self) -> usize {
            8_192
        }

        async fn generate(
            &self,
            _messages: Vec<Message>,
            _tools: Vec<ToolDefinition>,
            _system_prompt: Option<String>,
        ) -> Result<ModelResponse, ProviderError> {
            let mut responses = self.responses.lock();
            if responses.is_empty() {
                return Err(ProviderError::Other("No more responses".to_string()));
            }
            Ok(responses.remove(0))
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EchoInput {
        message: String,
    }

    /// Returns its `message` input unchanged.
    struct EchoTool;

    impl Tool for EchoTool {
        type Input = EchoInput;

        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the input back"
        }

        async fn execute(&self, input: Self::Input) -> Result<MxToolResult, ToolError> {
            Ok(MxToolResult::text(input.message))
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct AddInput {
        a: f64,
        b: f64,
    }

    /// Returns `a + b` as text.
    struct AddTool;

    impl Tool for AddTool {
        type Input = AddInput;

        fn name(&self) -> &str {
            "add"
        }

        fn description(&self) -> &str {
            "Adds two numbers"
        }

        async fn execute(&self, input: Self::Input) -> Result<MxToolResult, ToolError> {
            Ok(MxToolResult::text(format!("{}", input.a + input.b)))
        }
    }

    #[derive(Debug, Deserialize, Serialize, JsonSchema)]
    struct EmptyInput {}

    /// Always returns a `ToolError`.
    struct FailingTool;

    impl Tool for FailingTool {
        type Input = EmptyInput;

        fn name(&self) -> &str {
            "failing_tool"
        }

        fn description(&self) -> &str {
            "A tool that always fails"
        }

        async fn execute(&self, _input: Self::Input) -> Result<MxToolResult, ToolError> {
            Err(ToolError::Custom("Tool execution failed".to_string()))
        }
    }

    #[tokio::test]
    async fn test_add_tool() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        assert_eq!(agent.list_tools().len(), 0);

        agent.add_tool(EchoTool);

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[0].description, "Echoes the input back");
    }

    #[tokio::test]
    async fn test_add_multiple_tools() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);
        agent.add_tool(AddTool);

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);

        let names: Vec<&str> = tools.iter().map(|t| t.name.as_str()).collect();
        assert!(names.contains(&"echo"));
        assert!(names.contains(&"add"));
    }

    #[tokio::test]
    async fn test_add_duplicate_tool_is_still_registered() {
        // Duplicates only produce a warning; both registrations are kept.
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);
        agent.add_tool(EchoTool);

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);
        assert!(tools.iter().all(|t| t.name == "echo"));
    }

    #[tokio::test]
    async fn test_add_tool_with_builder() {
        let provider = MockProvider::new().with_text("ok");
        let agent = Agent::builder()
            .provider(provider)
            .add_tool(EchoTool)
            .add_tool(AddTool)
            .build()
            .await
            .unwrap();

        let tools = agent.list_tools();
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_list_tools_empty() {
        let provider = MockProvider::new().with_text("ok");
        let agent = Agent::builder().provider(provider).build().await.unwrap();

        let tools = agent.list_tools();
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn test_list_tools_preserves_order() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);
        agent.add_tool(AddTool);
        agent.add_tool(FailingTool);

        let tools = agent.list_tools();
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[1].name, "add");
        assert_eq!(tools[2].name, "failing_tool");
    }

    #[tokio::test]
    async fn test_execute_tool_success() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);

        // Pre-grant so execution does not block waiting for user approval.
        agent
            .authorizer()
            .write()
            .await
            .grant_tool("echo")
            .await
            .unwrap();

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::json!({"message": "Hello, world!"}),
        };

        let result = agent.execute_tool(&tool_use).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_text(), "Hello, world!");
    }

    #[tokio::test]
    async fn test_execute_tool_not_found() {
        let provider = MockProvider::new().with_text("ok");
        let agent = Agent::builder().provider(provider).build().await.unwrap();

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "nonexistent_tool".to_string(),
            input: serde_json::json!({}),
        };

        let result = agent.execute_tool(&tool_use).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AgentError::ToolNotFound(_)));
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_input_not_object() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::json!("not an object"),
        };

        match agent.execute_tool(&tool_use).await {
            Err(AgentError::InvalidToolInput(msg)) => assert!(msg.contains("string")),
            Err(_) => panic!("expected AgentError::InvalidToolInput"),
            Ok(_) => panic!("expected an error for string input"),
        }
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_input_array() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::json!([1, 2, 3]),
        };

        // Match the variant explicitly so a different error fails the test.
        match agent.execute_tool(&tool_use).await {
            Err(AgentError::InvalidToolInput(msg)) => assert!(msg.contains("array")),
            Err(_) => panic!("expected AgentError::InvalidToolInput"),
            Ok(_) => panic!("expected an error for array input"),
        }
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_input_null() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "echo".to_string(),
            input: serde_json::Value::Null,
        };

        match agent.execute_tool(&tool_use).await {
            Err(AgentError::InvalidToolInput(msg)) => assert!(msg.contains("null")),
            Err(_) => panic!("expected AgentError::InvalidToolInput"),
            Ok(_) => panic!("expected an error for null input"),
        }
    }

    #[tokio::test]
    async fn test_execute_tool_execution_failure() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(FailingTool);

        agent
            .authorizer()
            .write()
            .await
            .grant_tool("failing_tool")
            .await
            .unwrap();

        let tool_use = ToolUseBlock {
            id: "tool_123".to_string(),
            name: "failing_tool".to_string(),
            input: serde_json::json!({}),
        };

        let result = agent.execute_tool(&tool_use).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), AgentError::Tool(_)));
    }

    #[tokio::test]
    async fn test_format_tool_input_existing_tool() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);

        let params = serde_json::json!({"message": "test"});
        let formatted = agent.format_tool_input("echo", &params, crate::presentation::Display::Cli);

        assert!(formatted.is_some());
    }

    #[tokio::test]
    async fn test_format_tool_input_nonexistent_tool() {
        let provider = MockProvider::new().with_text("ok");
        let agent = Agent::builder().provider(provider).build().await.unwrap();

        let params = serde_json::json!({"message": "test"});
        let formatted =
            agent.format_tool_input("nonexistent", &params, crate::presentation::Display::Cli);

        assert!(formatted.is_none());
    }

    #[tokio::test]
    async fn test_format_tool_output_existing_tool() {
        let provider = MockProvider::new().with_text("ok");
        let mut agent = Agent::builder().provider(provider).build().await.unwrap();

        agent.add_tool(EchoTool);

        let result = crate::tool::ToolResult::text("output");
        let formatted =
            agent.format_tool_output("echo", &result, crate::presentation::Display::Cli);

        assert!(formatted.is_some());
    }

    #[tokio::test]
    async fn test_format_tool_output_nonexistent_tool() {
        let provider = MockProvider::new().with_text("ok");
        let agent = Agent::builder().provider(provider).build().await.unwrap();

        let result = crate::tool::ToolResult::text("output");
        let formatted =
            agent.format_tool_output("nonexistent", &result, crate::presentation::Display::Cli);

        assert!(formatted.is_none());
    }
}