// thulp_mcp/transport.rs

//! MCP transport implementations using rs-utcp.
2
3use crate::Result;
4use async_trait::async_trait;
5use rs_utcp::providers::base::Provider;
6use rs_utcp::providers::mcp::McpProvider;
7use rs_utcp::transports::mcp::McpTransport as RsUtcpMcpTransport;
8use rs_utcp::transports::ClientTransport;
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12use thulp_core::{Error, ToolCall, ToolDefinition, ToolResult, Transport as CoreTransport};
13
14/// Wrapper around rs-utcp's MCP transport
15pub struct McpTransport {
16    /// The underlying rs-utcp transport
17    inner: RsUtcpMcpTransport,
18    /// The MCP provider
19    provider: Arc<dyn Provider>,
20    /// Connection status
21    connected: bool,
22}
23
24impl McpTransport {
25    /// Create a new MCP transport for HTTP connection
26    pub fn new_http(name: String, url: String) -> Self {
27        let provider = Arc::new(McpProvider::new(name, url, None));
28        let inner = RsUtcpMcpTransport::new();
29
30        Self {
31            inner,
32            provider,
33            connected: false,
34        }
35    }
36
37    /// Create a new MCP transport for STDIO connection
38    pub fn new_stdio(name: String, command: String, args: Option<Vec<String>>) -> Self {
39        let provider = Arc::new(McpProvider::new_stdio(name, command, args, None));
40        let inner = RsUtcpMcpTransport::new();
41
42        Self {
43            inner,
44            provider,
45            connected: false,
46        }
47    }
48
49    /// Create a new MCP transport with default configuration
50    pub fn new() -> Self {
51        Self::new_http("default".to_string(), "http://localhost:8080".to_string())
52    }
53}
54
55impl Default for McpTransport {
56    fn default() -> Self {
57        Self::new()
58    }
59}
60
61#[async_trait]
62impl CoreTransport for McpTransport {
63    async fn connect(&mut self) -> Result<()> {
64        // Register the provider with the transport
65        let _tools = self
66            .inner
67            .register_tool_provider(&*self.provider)
68            .await
69            .map_err(|e| Error::ExecutionFailed(format!("Failed to register provider: {}", e)))?;
70
71        self.connected = true;
72        Ok(())
73    }
74
75    async fn disconnect(&mut self) -> Result<()> {
76        self.inner
77            .deregister_tool_provider(&*self.provider)
78            .await
79            .map_err(|e| Error::ExecutionFailed(format!("Failed to deregister provider: {}", e)))?;
80        self.connected = false;
81        Ok(())
82    }
83
84    fn is_connected(&self) -> bool {
85        self.connected
86    }
87
88    async fn list_tools(&self) -> Result<Vec<ToolDefinition>> {
89        if !self.connected {
90            return Err(Error::ExecutionFailed("not connected".to_string()));
91        }
92
93        // Use rs-utcp's register_tool_provider which calls tools/list internally
94        let tools = self
95            .inner
96            .register_tool_provider(&*self.provider)
97            .await
98            .map_err(|e| Error::ExecutionFailed(format!("Failed to list tools: {}", e)))?;
99
100        // Convert rs-utcp Tool to ToolDefinition
101        let mut definitions = Vec::new();
102        for tool in tools {
103            // Convert the ToolInputOutputSchema to JSON for parsing
104            let inputs_json = serde_json::to_value(&tool.inputs).map_err(|e| {
105                Error::ExecutionFailed(format!("Failed to serialize inputs: {}", e))
106            })?;
107
108            // Extract parameters from tool.inputs
109            let parameters =
110                ToolDefinition::parse_mcp_input_schema(&inputs_json).unwrap_or_default();
111
112            definitions.push(ToolDefinition {
113                name: tool.name,
114                description: tool.description,
115                parameters,
116            });
117        }
118
119        Ok(definitions)
120    }
121
122    async fn call(&self, call: &ToolCall) -> Result<ToolResult> {
123        if !self.connected {
124            return Err(Error::ExecutionFailed("not connected".to_string()));
125        }
126
127        // Convert arguments to the format expected by rs-utcp
128        let args: HashMap<String, Value> = match &call.arguments {
129            Value::Object(map) => map.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
130            _ => HashMap::new(),
131        };
132
133        // Call the tool through the transport
134        let result = self
135            .inner
136            .call_tool(&call.tool, args, &*self.provider)
137            .await
138            .map_err(|e| Error::ExecutionFailed(format!("Tool call failed: {}", e)))?;
139
140        Ok(ToolResult::success(result))
141    }
142}
143
#[cfg(test)]
mod tests {
    //! Unit tests for `McpTransport` construction, disconnected-state errors,
    //! and the shape of `ToolCall` arguments handed to the transport.
    //! No test talks to a real MCP server.

    use super::*;
    use serde_json::json;

    #[test]
    fn transport_new_http() {
        let transport =
            McpTransport::new_http("test".to_string(), "http://localhost:8080".to_string());
        assert!(!transport.is_connected());
    }

    #[test]
    fn transport_new_stdio() {
        let transport = McpTransport::new_stdio(
            "test".to_string(),
            "test-cmd".to_string(),
            Some(vec!["--arg1".to_string()]),
        );
        assert!(!transport.is_connected());
    }

    #[test]
    fn transport_new_default() {
        let transport = McpTransport::new();
        assert!(!transport.is_connected());
    }

    #[test]
    fn transport_default_impl() {
        let transport1 = McpTransport::default();
        let transport2 = McpTransport::new();
        assert!(!transport1.is_connected());
        assert!(!transport2.is_connected());
    }

    #[test]
    fn transport_connect_disconnect() {
        let transport =
            McpTransport::new_http("test".to_string(), "http://localhost:9999".to_string());

        // Connecting needs a live MCP server, which unit tests do not have, so
        // this only checks that a freshly built transport starts disconnected.
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_argument_conversion() {
        // Test argument conversion for tool calls
        let call = ToolCall::builder("test_tool")
            .arg_str("string_param", "value")
            .arg_int("int_param", 42)
            .arg_bool("bool_param", true)
            .build();

        // Verify arguments are in expected format
        assert!(call.arguments.is_object());
        assert_eq!(call.arguments["string_param"], "value");
        assert_eq!(call.arguments["int_param"], 42);
        assert_eq!(call.arguments["bool_param"], true);
    }

    #[test]
    fn test_argument_conversion_nested() {
        // Test with nested object arguments
        let call = ToolCall::builder("test_tool")
            .arg("nested", json!({"key": "value", "number": 123}))
            .build();

        assert!(call.arguments.is_object());
        assert_eq!(call.arguments["nested"]["key"], "value");
        assert_eq!(call.arguments["nested"]["number"], 123);
    }

    #[test]
    fn test_argument_conversion_array() {
        // Test with array arguments
        let call = ToolCall::builder("test_tool")
            .arg("items", json!([1, 2, 3, 4, 5]))
            .build();

        assert!(call.arguments.is_object());
        assert_eq!(call.arguments["items"], json!([1, 2, 3, 4, 5]));
    }

    // Edge case tests
    #[tokio::test]
    async fn test_list_tools_when_disconnected() {
        let transport = McpTransport::new();

        // Should fail when not connected
        let result = transport.list_tools().await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[tokio::test]
    async fn test_call_when_disconnected() {
        let transport = McpTransport::new();
        let call = ToolCall::new("test_tool");

        // Should fail when not connected
        let result = transport.call(&call).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not connected"));
    }

    #[test]
    fn test_empty_arguments() {
        let call = ToolCall::new("test_tool");
        assert!(call.arguments.is_object());
        assert_eq!(call.arguments.as_object().unwrap().len(), 0);
    }

    #[test]
    fn test_special_characters_in_tool_name() {
        let call = ToolCall::new("test-tool_v2.0");
        assert_eq!(call.tool, "test-tool_v2.0");
    }

    #[test]
    fn test_unicode_in_arguments() {
        let call = ToolCall::builder("test_tool")
            .arg_str("message", "Hello 世界 🌍")
            .build();

        assert_eq!(call.arguments["message"], "Hello 世界 🌍");
    }

    #[test]
    fn test_large_argument_values() {
        // Test with large string
        let large_string = "x".repeat(10000);
        let call = ToolCall::builder("test_tool")
            .arg_str("data", &large_string)
            .build();

        assert_eq!(call.arguments["data"].as_str().unwrap().len(), 10000);
    }

    #[test]
    fn test_null_arguments() {
        let call = ToolCall::builder("test_tool")
            .arg("null_param", json!(null))
            .build();

        assert!(call.arguments["null_param"].is_null());
    }

    #[test]
    fn test_mixed_type_arguments() {
        let call = ToolCall::builder("test_tool")
            .arg_str("string", "value")
            .arg_int("int", 42)
            .arg_bool("bool", true)
            .arg("null", json!(null))
            .arg("array", json!([1, 2, 3]))
            .arg("object", json!({"key": "value"}))
            .build();

        assert_eq!(call.arguments["string"], "value");
        assert_eq!(call.arguments["int"], 42);
        assert_eq!(call.arguments["bool"], true);
        assert!(call.arguments["null"].is_null());
        assert!(call.arguments["array"].is_array());
        assert!(call.arguments["object"].is_object());
    }

    #[test]
    fn test_stdio_transport_creation() {
        let transport = McpTransport::new_stdio(
            "echo-server".to_string(),
            "npx".to_string(),
            Some(vec![
                "-y".to_string(),
                "@modelcontextprotocol/server-echo".to_string(),
            ]),
        );

        assert!(!transport.is_connected());
    }

    #[test]
    fn test_http_transport_with_https() {
        let transport = McpTransport::new_http(
            "secure".to_string(),
            "https://api.example.com/mcp".to_string(),
        );

        assert!(!transport.is_connected());
    }

    #[test]
    fn test_transport_creation_with_empty_name() {
        let transport = McpTransport::new_http("".to_string(), "http://localhost:8080".to_string());
        assert!(!transport.is_connected());
    }

    #[test]
    fn test_deeply_nested_arguments() {
        let nested = json!({
            "level1": {
                "level2": {
                    "level3": {
                        "level4": {
                            "level5": "deep"
                        }
                    }
                }
            }
        });

        let call = ToolCall::builder("test_tool")
            .arg("nested", nested.clone())
            .build();

        assert_eq!(call.arguments["nested"], nested);
        assert_eq!(
            call.arguments["nested"]["level1"]["level2"]["level3"]["level4"]["level5"],
            "deep"
        );
    }

    #[test]
    fn test_argument_with_numbers() {
        // 2.5 is exactly representable and is not flagged by clippy's
        // approx_constant lint (unlike 3.14159, which approximates PI).
        let call = ToolCall::builder("test_tool")
            .arg_int("positive", 42)
            .arg_int("negative", -42)
            .arg_int("zero", 0)
            .arg("float", json!(2.5))
            .arg("scientific", json!(1.5e10))
            .build();

        assert_eq!(call.arguments["positive"], 42);
        assert_eq!(call.arguments["negative"], -42);
        assert_eq!(call.arguments["zero"], 0);
        assert_eq!(call.arguments["float"], 2.5);
        assert_eq!(call.arguments["scientific"], 1.5e10);
    }

    #[test]
    fn test_argument_with_special_json_values() {
        let call = ToolCall::builder("test_tool")
            .arg("empty_string", json!(""))
            .arg("empty_array", json!([]))
            .arg("empty_object", json!({}))
            .arg("boolean_true", json!(true))
            .arg("boolean_false", json!(false))
            .build();

        assert_eq!(call.arguments["empty_string"], "");
        assert_eq!(call.arguments["empty_array"], json!([]));
        assert_eq!(call.arguments["empty_object"], json!({}));
        assert_eq!(call.arguments["boolean_true"], true);
        assert_eq!(call.arguments["boolean_false"], false);
    }
}