// aster/mcp/tool_manager.rs

//! MCP Tool Manager
//!
//! This module implements the tool manager for MCP servers.
//! It handles tool discovery, caching, argument validation, and tool invocation.
//!
//! # Features
//!
//! - Tool discovery and caching from connected servers
//! - Lightweight JSON Schema argument validation (required fields, property
//!   types, `additionalProperties: false`)
//! - Tool invocation with timeout support
//! - Call tracking and cancellation
//! - Batch tool calls with parallel execution
//! - Result format conversion
//!
//! # Requirements Coverage
//!
//! - 4.1: Tool caching from connected servers
//! - 4.2: Argument validation against input schema
//! - 4.3: Descriptive error on validation failure
//! - 4.4: Batch tool calls for parallel execution
//! - 4.5: Tool call cancellation support
//! - 4.6: Pending call tracking with unique IDs
//! - 4.7: Tool call timeout handling
//! - 4.8: MCP result to standardized format conversion

26use async_trait::async_trait;
27use chrono::{DateTime, Utc};
28use serde::{Deserialize, Serialize};
29use std::collections::HashMap;
30use std::sync::atomic::{AtomicU64, Ordering};
31use std::sync::Arc;
32use std::time::Duration;
33use tokio::sync::RwLock;
34use uuid::Uuid;
35
36use crate::mcp::connection_manager::ConnectionManager;
37use crate::mcp::error::{McpError, McpResult};
38use crate::mcp::transport::McpRequest;
39use crate::mcp::types::JsonObject;
40
41/// MCP tool definition
42///
43/// Represents a tool exposed by an MCP server, including its name,
44/// description, and input schema for argument validation.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct McpTool {
47    /// Tool name (unique within a server)
48    pub name: String,
49    /// Human-readable description
50    pub description: Option<String>,
51    /// JSON Schema for input validation
52    pub input_schema: serde_json::Value,
53    /// Server name that provides this tool
54    pub server_name: String,
55}
56
57impl McpTool {
58    /// Create a new MCP tool
59    pub fn new(
60        name: impl Into<String>,
61        server_name: impl Into<String>,
62        input_schema: serde_json::Value,
63    ) -> Self {
64        Self {
65            name: name.into(),
66            description: None,
67            input_schema,
68            server_name: server_name.into(),
69        }
70    }
71
72    /// Create a new MCP tool with description
73    pub fn with_description(
74        name: impl Into<String>,
75        server_name: impl Into<String>,
76        description: impl Into<String>,
77        input_schema: serde_json::Value,
78    ) -> Self {
79        Self {
80            name: name.into(),
81            description: Some(description.into()),
82            input_schema,
83            server_name: server_name.into(),
84        }
85    }
86}
87
88/// Tool result content types
89///
90/// MCP tools can return different types of content in their results.
91#[derive(Debug, Clone, Serialize, Deserialize)]
92#[serde(tag = "type", rename_all = "lowercase")]
93pub enum ToolResultContent {
94    /// Text content
95    Text {
96        /// The text content
97        text: String,
98    },
99    /// Image content (base64 encoded)
100    Image {
101        /// Base64 encoded image data
102        data: String,
103        /// MIME type (e.g., "image/png")
104        #[serde(rename = "mimeType")]
105        mime_type: String,
106    },
107    /// Resource reference
108    Resource {
109        /// Resource URI
110        uri: String,
111        /// Optional text content
112        text: Option<String>,
113        /// Optional binary data (base64)
114        #[serde(rename = "blob")]
115        data: Option<String>,
116        /// MIME type
117        #[serde(rename = "mimeType")]
118        mime_type: Option<String>,
119    },
120}
121
122impl ToolResultContent {
123    /// Create text content
124    pub fn text(text: impl Into<String>) -> Self {
125        Self::Text { text: text.into() }
126    }
127
128    /// Create image content
129    pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
130        Self::Image {
131            data: data.into(),
132            mime_type: mime_type.into(),
133        }
134    }
135
136    /// Create resource content
137    pub fn resource(uri: impl Into<String>) -> Self {
138        Self::Resource {
139            uri: uri.into(),
140            text: None,
141            data: None,
142            mime_type: None,
143        }
144    }
145}
146
147/// Tool call result
148///
149/// Represents the result of a tool invocation, containing the content
150/// and an error flag.
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ToolCallResult {
153    /// Result content (can be multiple items)
154    pub content: Vec<ToolResultContent>,
155    /// Whether the result represents an error
156    #[serde(rename = "isError", default)]
157    pub is_error: bool,
158}
159
160impl ToolCallResult {
161    /// Create a successful result with text content
162    pub fn success_text(text: impl Into<String>) -> Self {
163        Self {
164            content: vec![ToolResultContent::text(text)],
165            is_error: false,
166        }
167    }
168
169    /// Create a successful result with multiple content items
170    pub fn success(content: Vec<ToolResultContent>) -> Self {
171        Self {
172            content,
173            is_error: false,
174        }
175    }
176
177    /// Create an error result
178    pub fn error(message: impl Into<String>) -> Self {
179        Self {
180            content: vec![ToolResultContent::text(message)],
181            is_error: true,
182        }
183    }
184
185    /// Check if the result is empty
186    pub fn is_empty(&self) -> bool {
187        self.content.is_empty()
188    }
189
190    /// Get the first text content if available
191    pub fn first_text(&self) -> Option<&str> {
192        self.content.iter().find_map(|c| match c {
193            ToolResultContent::Text { text } => Some(text.as_str()),
194            _ => None,
195        })
196    }
197}
198
/// Outcome of checking tool-call arguments against a tool's input schema.
///
/// A result starts out valid (via [`ArgValidationResult::valid`] or
/// `Default`); each [`ArgValidationResult::add_error`] records a message and
/// marks the result invalid.
#[derive(Debug, Clone)]
pub struct ArgValidationResult {
    /// `false` once any error has been recorded, or when built via `invalid`
    pub valid: bool,
    /// Human-readable validation messages, in the order they were found
    pub errors: Vec<String>,
}

impl Default for ArgValidationResult {
    /// Returns an empty, valid result (same as [`ArgValidationResult::valid`]).
    ///
    /// Implemented by hand because `#[derive(Default)]` would produce
    /// `valid: false` with no errors — an "invalid for no reason" state.
    fn default() -> Self {
        Self::valid()
    }
}

impl ArgValidationResult {
    /// Creates a valid result with no errors.
    pub fn valid() -> Self {
        Self {
            valid: true,
            errors: Vec::new(),
        }
    }

    /// Creates an invalid result carrying `errors`.
    pub fn invalid(errors: Vec<String>) -> Self {
        Self {
            valid: false,
            errors,
        }
    }

    /// Records `error` and marks the result invalid.
    pub fn add_error(&mut self, error: impl Into<String>) {
        self.valid = false;
        self.errors.push(error.into());
    }
}

234/// Call information for tracking pending calls
235///
236/// Tracks the state of an in-progress tool call for monitoring
237/// and cancellation purposes.
238#[derive(Debug, Clone)]
239pub struct CallInfo {
240    /// Unique call ID
241    pub call_id: String,
242    /// Server name
243    pub server_name: String,
244    /// Tool name
245    pub tool_name: String,
246    /// Call arguments
247    pub args: JsonObject,
248    /// Call start time
249    pub start_time: DateTime<Utc>,
250    /// Whether the call has completed
251    pub completed: bool,
252    /// Whether the call was cancelled
253    pub cancelled: bool,
254}
255
256impl CallInfo {
257    /// Create a new call info
258    pub fn new(
259        call_id: impl Into<String>,
260        server_name: impl Into<String>,
261        tool_name: impl Into<String>,
262        args: JsonObject,
263    ) -> Self {
264        Self {
265            call_id: call_id.into(),
266            server_name: server_name.into(),
267            tool_name: tool_name.into(),
268            args,
269            start_time: Utc::now(),
270            completed: false,
271            cancelled: false,
272        }
273    }
274
275    /// Mark the call as completed
276    pub fn mark_completed(&mut self) {
277        self.completed = true;
278    }
279
280    /// Mark the call as cancelled
281    pub fn mark_cancelled(&mut self) {
282        self.cancelled = true;
283    }
284
285    /// Get the elapsed time since the call started
286    pub fn elapsed(&self) -> chrono::Duration {
287        Utc::now() - self.start_time
288    }
289}
290
291/// Tool call definition for batch operations
292///
293/// Defines a single tool call in a batch operation.
294#[derive(Debug, Clone)]
295pub struct ToolCall {
296    /// Server name
297    pub server_name: String,
298    /// Tool name
299    pub tool_name: String,
300    /// Call arguments
301    pub args: JsonObject,
302}
303
304impl ToolCall {
305    /// Create a new tool call
306    pub fn new(
307        server_name: impl Into<String>,
308        tool_name: impl Into<String>,
309        args: JsonObject,
310    ) -> Self {
311        Self {
312            server_name: server_name.into(),
313            tool_name: tool_name.into(),
314            args,
315        }
316    }
317}
318
319/// Tool manager trait
320///
321/// Defines the interface for managing MCP tools, including discovery,
322/// caching, validation, and invocation.
323#[async_trait]
324pub trait ToolManager: Send + Sync {
325    /// List all available tools from connected servers
326    ///
327    /// If `server_name` is provided, only lists tools from that server.
328    /// Results are cached for subsequent calls.
329    async fn list_tools(&self, server_name: Option<&str>) -> McpResult<Vec<McpTool>>;
330
331    /// Get a specific tool by server and name
332    ///
333    /// Returns the cached tool definition if available.
334    async fn get_tool(&self, server_name: &str, tool_name: &str) -> McpResult<Option<McpTool>>;
335
336    /// Clear the tool cache
337    ///
338    /// If `server_name` is provided, only clears cache for that server.
339    fn clear_cache(&self, server_name: Option<&str>);
340
341    /// Call a tool on a server
342    ///
343    /// Validates arguments before calling and tracks the call.
344    async fn call_tool(
345        &self,
346        server_name: &str,
347        tool_name: &str,
348        args: JsonObject,
349    ) -> McpResult<ToolCallResult>;
350
351    /// Call a tool with a timeout
352    ///
353    /// Returns an error if the call doesn't complete within the timeout.
354    async fn call_tool_with_timeout(
355        &self,
356        server_name: &str,
357        tool_name: &str,
358        args: JsonObject,
359        timeout: Duration,
360    ) -> McpResult<ToolCallResult>;
361
362    /// Validate tool arguments against the schema
363    ///
364    /// Returns validation result without making the actual call.
365    fn validate_args(&self, tool: &McpTool, args: &JsonObject) -> ArgValidationResult;
366
367    /// Cancel a pending tool call
368    ///
369    /// Sends a cancellation notification to the server.
370    fn cancel_call(&self, call_id: &str);
371
372    /// Get all pending (in-progress) calls
373    fn get_pending_calls(&self) -> Vec<CallInfo>;
374
375    /// Execute multiple tool calls in parallel
376    ///
377    /// Returns results in the same order as the input calls.
378    async fn call_tools_batch(&self, calls: Vec<ToolCall>) -> Vec<McpResult<ToolCallResult>>;
379}
380
381/// Tool cache entry
382struct ToolCacheEntry {
383    /// Cached tools
384    tools: Vec<McpTool>,
385    /// Cache timestamp
386    cached_at: DateTime<Utc>,
387}
388
389/// Default implementation of the tool manager
390pub struct McpToolManager<C: ConnectionManager> {
391    /// Connection manager for sending requests
392    connection_manager: Arc<C>,
393    /// Tool cache by server name
394    tool_cache: Arc<RwLock<HashMap<String, ToolCacheEntry>>>,
395    /// Pending calls by call ID
396    pending_calls: Arc<RwLock<HashMap<String, CallInfo>>>,
397    /// Call ID counter for unique ID generation
398    call_counter: AtomicU64,
399    /// Default timeout for tool calls
400    default_timeout: Duration,
401    /// Cache TTL (time-to-live)
402    cache_ttl: Duration,
403}
404
405impl<C: ConnectionManager> McpToolManager<C> {
406    /// Create a new tool manager
407    pub fn new(connection_manager: Arc<C>) -> Self {
408        Self {
409            connection_manager,
410            tool_cache: Arc::new(RwLock::new(HashMap::new())),
411            pending_calls: Arc::new(RwLock::new(HashMap::new())),
412            call_counter: AtomicU64::new(1),
413            default_timeout: Duration::from_secs(30),
414            cache_ttl: Duration::from_secs(300), // 5 minutes
415        }
416    }
417
418    /// Create a new tool manager with custom settings
419    pub fn with_settings(
420        connection_manager: Arc<C>,
421        default_timeout: Duration,
422        cache_ttl: Duration,
423    ) -> Self {
424        Self {
425            connection_manager,
426            tool_cache: Arc::new(RwLock::new(HashMap::new())),
427            pending_calls: Arc::new(RwLock::new(HashMap::new())),
428            call_counter: AtomicU64::new(1),
429            default_timeout,
430            cache_ttl,
431        }
432    }
433
434    /// Generate a unique call ID
435    pub fn generate_call_id(&self) -> String {
436        let counter = self.call_counter.fetch_add(1, Ordering::SeqCst);
437        format!("call-{}-{}", Uuid::new_v4(), counter)
438    }
439
440    /// Check if cache is valid for a server
441    fn is_cache_valid(&self, entry: &ToolCacheEntry) -> bool {
442        let age = Utc::now() - entry.cached_at;
443        age.num_seconds() < self.cache_ttl.as_secs() as i64
444    }
445
446    /// Fetch tools from a server (bypassing cache)
447    async fn fetch_tools_from_server(&self, server_name: &str) -> McpResult<Vec<McpTool>> {
448        // Get connection for the server
449        let connection = self
450            .connection_manager
451            .get_connection_by_server(server_name)
452            .ok_or_else(|| {
453                McpError::connection(format!("No connection found for server: {}", server_name))
454            })?;
455
456        // Send tools/list request
457        let request = McpRequest::new(
458            serde_json::json!(format!("tools-list-{}", Uuid::new_v4())),
459            "tools/list",
460        );
461
462        let response = self
463            .connection_manager
464            .send(&connection.id, request)
465            .await?;
466
467        // Parse response
468        let result = response.into_result()?;
469
470        // Extract tools from response
471        let tools_value = result
472            .get("tools")
473            .ok_or_else(|| McpError::protocol("Response missing 'tools' field"))?;
474
475        let raw_tools: Vec<serde_json::Value> = serde_json::from_value(tools_value.clone())
476            .map_err(|e| McpError::protocol(format!("Failed to parse tools: {}", e)))?;
477
478        // Convert to McpTool
479        let tools: Vec<McpTool> = raw_tools
480            .into_iter()
481            .filter_map(|t| {
482                let name = t.get("name")?.as_str()?.to_string();
483                let description = t
484                    .get("description")
485                    .and_then(|d| d.as_str())
486                    .map(String::from);
487                let input_schema = t
488                    .get("inputSchema")
489                    .cloned()
490                    .unwrap_or(serde_json::json!({}));
491
492                Some(McpTool {
493                    name,
494                    description,
495                    input_schema,
496                    server_name: server_name.to_string(),
497                })
498            })
499            .collect();
500
501        Ok(tools)
502    }
503
504    /// Register a pending call
505    async fn register_call(&self, call_info: CallInfo) {
506        let mut calls = self.pending_calls.write().await;
507        calls.insert(call_info.call_id.clone(), call_info);
508    }
509
510    /// Complete a pending call
511    async fn complete_call(&self, call_id: &str) {
512        let mut calls = self.pending_calls.write().await;
513        if let Some(info) = calls.get_mut(call_id) {
514            info.mark_completed();
515        }
516        calls.remove(call_id);
517    }
518
519    /// Convert MCP tool result to standardized format
520    ///
521    /// This handles the conversion from raw MCP response to ToolCallResult.
522    fn convert_result(&self, result: serde_json::Value) -> McpResult<ToolCallResult> {
523        // Check if result has content array
524        if let Some(content) = result.get("content") {
525            let content_items: Vec<ToolResultContent> = serde_json::from_value(content.clone())
526                .map_err(|e| {
527                    McpError::protocol(format!("Failed to parse tool result content: {}", e))
528                })?;
529
530            let is_error = result
531                .get("isError")
532                .and_then(|v| v.as_bool())
533                .unwrap_or(false);
534
535            return Ok(ToolCallResult {
536                content: content_items,
537                is_error,
538            });
539        }
540
541        // Handle legacy format or simple text response
542        if let Some(text) = result.as_str() {
543            return Ok(ToolCallResult::success_text(text));
544        }
545
546        // Return the raw result as JSON text
547        Ok(ToolCallResult::success_text(result.to_string()))
548    }
549}
550
551#[async_trait]
552impl<C: ConnectionManager + 'static> ToolManager for McpToolManager<C> {
553    async fn list_tools(&self, server_name: Option<&str>) -> McpResult<Vec<McpTool>> {
554        match server_name {
555            Some(name) => {
556                // Check cache first
557                {
558                    let cache = self.tool_cache.read().await;
559                    if let Some(entry) = cache.get(name) {
560                        if self.is_cache_valid(entry) {
561                            return Ok(entry.tools.clone());
562                        }
563                    }
564                }
565
566                // Fetch from server
567                let tools = self.fetch_tools_from_server(name).await?;
568
569                // Update cache
570                {
571                    let mut cache = self.tool_cache.write().await;
572                    cache.insert(
573                        name.to_string(),
574                        ToolCacheEntry {
575                            tools: tools.clone(),
576                            cached_at: Utc::now(),
577                        },
578                    );
579                }
580
581                Ok(tools)
582            }
583            None => {
584                // List tools from all connected servers
585                let connections = self.connection_manager.get_all_connections();
586                let mut all_tools = Vec::new();
587
588                for conn in connections {
589                    match self.list_tools(Some(&conn.server_name)).await {
590                        Ok(tools) => all_tools.extend(tools),
591                        Err(e) => {
592                            tracing::warn!(
593                                "Failed to list tools from server {}: {}",
594                                conn.server_name,
595                                e
596                            );
597                        }
598                    }
599                }
600
601                Ok(all_tools)
602            }
603        }
604    }
605
606    async fn get_tool(&self, server_name: &str, tool_name: &str) -> McpResult<Option<McpTool>> {
607        let tools = self.list_tools(Some(server_name)).await?;
608        Ok(tools.into_iter().find(|t| t.name == tool_name))
609    }
610
611    fn clear_cache(&self, server_name: Option<&str>) {
612        // Convert to owned string for async move
613        let server_name_owned = server_name.map(|s| s.to_string());
614        let cache = self.tool_cache.clone();
615        tokio::spawn(async move {
616            let mut cache = cache.write().await;
617            match server_name_owned {
618                Some(name) => {
619                    cache.remove(&name);
620                }
621                None => {
622                    cache.clear();
623                }
624            }
625        });
626    }
627
628    async fn call_tool(
629        &self,
630        server_name: &str,
631        tool_name: &str,
632        args: JsonObject,
633    ) -> McpResult<ToolCallResult> {
634        self.call_tool_with_timeout(server_name, tool_name, args, self.default_timeout)
635            .await
636    }
637
638    async fn call_tool_with_timeout(
639        &self,
640        server_name: &str,
641        tool_name: &str,
642        args: JsonObject,
643        timeout: Duration,
644    ) -> McpResult<ToolCallResult> {
645        // Get the tool definition for validation
646        let tool = self
647            .get_tool(server_name, tool_name)
648            .await?
649            .ok_or_else(|| {
650                McpError::tool(
651                    format!("Tool not found: {}/{}", server_name, tool_name),
652                    Some(tool_name.to_string()),
653                )
654            })?;
655
656        // Validate arguments
657        let validation = self.validate_args(&tool, &args);
658        if !validation.valid {
659            return Err(McpError::validation(
660                format!(
661                    "Invalid arguments for tool {}: {}",
662                    tool_name,
663                    validation.errors.join(", ")
664                ),
665                validation.errors,
666            ));
667        }
668
669        // Get connection
670        let connection = self
671            .connection_manager
672            .get_connection_by_server(server_name)
673            .ok_or_else(|| {
674                McpError::connection(format!("No connection found for server: {}", server_name))
675            })?;
676
677        // Generate call ID and register
678        let call_id = self.generate_call_id();
679        let call_info = CallInfo::new(&call_id, server_name, tool_name, args.clone());
680        self.register_call(call_info).await;
681
682        // Build request
683        let request = McpRequest::with_params(
684            serde_json::json!(call_id.clone()),
685            "tools/call",
686            serde_json::json!({
687                "name": tool_name,
688                "arguments": args
689            }),
690        );
691
692        // Send request with timeout
693        let result = self
694            .connection_manager
695            .send_with_timeout(&connection.id, request, timeout)
696            .await;
697
698        // Complete the call
699        self.complete_call(&call_id).await;
700
701        // Handle result
702        match result {
703            Ok(response) => {
704                let result_value = response.into_result()?;
705                self.convert_result(result_value)
706            }
707            Err(e) => Err(e),
708        }
709    }
710
711    fn validate_args(&self, tool: &McpTool, args: &JsonObject) -> ArgValidationResult {
712        let schema = &tool.input_schema;
713
714        // If no schema or empty schema, accept any args
715        if schema.is_null()
716            || (schema.is_object() && schema.as_object().is_none_or(|o| o.is_empty()))
717        {
718            return ArgValidationResult::valid();
719        }
720
721        let mut result = ArgValidationResult::valid();
722
723        // Check required properties
724        if let Some(required) = schema.get("required").and_then(|r| r.as_array()) {
725            for req in required {
726                if let Some(field_name) = req.as_str() {
727                    if !args.contains_key(field_name) {
728                        result.add_error(format!("Missing required field: {}", field_name));
729                    }
730                }
731            }
732        }
733
734        // Check property types if properties are defined
735        if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
736            for (key, value) in args.iter() {
737                if let Some(prop_schema) = properties.get(key) {
738                    // Validate type
739                    if let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str()) {
740                        let actual_type = get_json_type(value);
741                        if !types_compatible(expected_type, &actual_type) {
742                            result.add_error(format!(
743                                "Field '{}' has wrong type: expected {}, got {}",
744                                key, expected_type, actual_type
745                            ));
746                        }
747                    }
748                }
749            }
750        }
751
752        // Check for additional properties if not allowed
753        if let Some(additional) = schema.get("additionalProperties") {
754            if additional == &serde_json::Value::Bool(false) {
755                if let Some(properties) = schema.get("properties").and_then(|p| p.as_object()) {
756                    for key in args.keys() {
757                        if !properties.contains_key(key) {
758                            result.add_error(format!("Unknown field: {}", key));
759                        }
760                    }
761                }
762            }
763        }
764
765        result
766    }
767
768    fn cancel_call(&self, call_id: &str) {
769        let pending_calls = self.pending_calls.clone();
770        let connection_manager = self.connection_manager.clone();
771        let call_id = call_id.to_string();
772
773        tokio::spawn(async move {
774            let mut calls = pending_calls.write().await;
775            if let Some(info) = calls.get_mut(&call_id) {
776                info.mark_cancelled();
777
778                // Send cancellation to server
779                if let Some(conn) = connection_manager.get_connection_by_server(&info.server_name) {
780                    let _ = connection_manager.cancel_request(&conn.id, &call_id).await;
781                }
782            }
783        });
784    }
785
786    fn get_pending_calls(&self) -> Vec<CallInfo> {
787        // Use try_read to avoid blocking
788        self.pending_calls
789            .try_read()
790            .map(|calls| calls.values().cloned().collect())
791            .unwrap_or_default()
792    }
793
794    async fn call_tools_batch(&self, calls: Vec<ToolCall>) -> Vec<McpResult<ToolCallResult>> {
795        use futures::future::join_all;
796
797        let futures: Vec<_> = calls
798            .into_iter()
799            .map(|call| {
800                let server_name = call.server_name.clone();
801                let tool_name = call.tool_name.clone();
802                let args = call.args;
803                async move { self.call_tool(&server_name, &tool_name, args).await }
804            })
805            .collect();
806
807        join_all(futures).await
808    }
809}
810
811/// Get the JSON type name for a value
812fn get_json_type(value: &serde_json::Value) -> String {
813    match value {
814        serde_json::Value::Null => "null".to_string(),
815        serde_json::Value::Bool(_) => "boolean".to_string(),
816        serde_json::Value::Number(n) => {
817            if n.is_i64() || n.is_u64() {
818                "integer".to_string()
819            } else {
820                "number".to_string()
821            }
822        }
823        serde_json::Value::String(_) => "string".to_string(),
824        serde_json::Value::Array(_) => "array".to_string(),
825        serde_json::Value::Object(_) => "object".to_string(),
826    }
827}
828
/// Reports whether a value of JSON type `actual` satisfies a schema `type`
/// of `expected`.
///
/// Type names must match exactly, with one widening: an expected `"number"`
/// also accepts an actual `"integer"`. The reverse does not hold.
fn types_compatible(expected: &str, actual: &str) -> bool {
    matches!((expected, actual), ("number", "integer")) || expected == actual
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_mcp_tool_new() {
        let tool = McpTool::new("test_tool", "test_server", json!({}));
        assert_eq!(
            (tool.name.as_str(), tool.server_name.as_str()),
            ("test_tool", "test_server")
        );
        assert_eq!(tool.description, None);
    }

    #[test]
    fn test_mcp_tool_with_description() {
        let tool =
            McpTool::with_description("test_tool", "test_server", "A test tool", json!({}));
        assert_eq!(tool.description.as_deref(), Some("A test tool"));
    }

    #[test]
    fn test_tool_result_content_text() {
        let ToolResultContent::Text { text } = ToolResultContent::text("Hello, world!") else {
            panic!("Expected Text content");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_tool_result_content_image() {
        let ToolResultContent::Image { data, mime_type } =
            ToolResultContent::image("base64data", "image/png")
        else {
            panic!("Expected Image content");
        };
        assert_eq!(data, "base64data");
        assert_eq!(mime_type, "image/png");
    }

    #[test]
    fn test_tool_call_result_success() {
        let outcome = ToolCallResult::success_text("Success!");
        assert!(!outcome.is_error);
        assert_eq!(outcome.first_text(), Some("Success!"));
    }

    #[test]
    fn test_tool_call_result_error() {
        let outcome = ToolCallResult::error("Something went wrong");
        assert!(outcome.is_error);
        assert_eq!(outcome.first_text(), Some("Something went wrong"));
    }

    #[test]
    fn test_arg_validation_result_valid() {
        let verdict = ArgValidationResult::valid();
        assert!(verdict.valid && verdict.errors.is_empty());
    }

    #[test]
    fn test_arg_validation_result_invalid() {
        let verdict = ArgValidationResult::invalid(vec!["Missing field".to_string()]);
        assert!(!verdict.valid);
        assert_eq!(verdict.errors.len(), 1);
    }

    #[test]
    fn test_call_info_new() {
        let info = CallInfo::new("call-1", "server", "tool", serde_json::Map::new());
        assert_eq!(info.call_id, "call-1");
        assert_eq!(info.server_name, "server");
        assert_eq!(info.tool_name, "tool");
        assert!(!info.completed && !info.cancelled);
    }

    #[test]
    fn test_call_info_mark_completed() {
        let mut info = CallInfo::new("call-1", "server", "tool", serde_json::Map::new());
        info.mark_completed();
        assert!(info.completed);
    }

    #[test]
    fn test_call_info_mark_cancelled() {
        let mut info = CallInfo::new("call-1", "server", "tool", serde_json::Map::new());
        info.mark_cancelled();
        assert!(info.cancelled);
    }

    #[test]
    fn test_tool_call_new() {
        let call = ToolCall::new("server", "tool", serde_json::Map::new());
        assert_eq!(
            (call.server_name.as_str(), call.tool_name.as_str()),
            ("server", "tool")
        );
    }

    #[test]
    fn test_get_json_type() {
        let cases = [
            (serde_json::Value::Null, "null"),
            (json!(true), "boolean"),
            (json!(42), "integer"),
            (json!(3.15), "number"),
            (json!("hello"), "string"),
            (json!([1, 2, 3]), "array"),
            (json!({"key": "value"}), "object"),
        ];
        for (value, expected) in &cases {
            assert_eq!(get_json_type(value), *expected);
        }
    }

    #[test]
    fn test_types_compatible() {
        assert!(types_compatible("string", "string"));
        assert!(types_compatible("number", "integer"));
        assert!(!types_compatible("string", "number"));
        assert!(!types_compatible("integer", "number"));
    }
}