Skip to main content

thulp_mcp/
client.rs

1//! MCP client implementation.
2
3use crate::{McpTransport, Result};
4use std::collections::HashMap;
5use thulp_core::{
6    GetPromptResult, PromptListResult, ResourceContents, ResourceListResult,
7    ResourceTemplateListResult, ToolCall, ToolDefinition, ToolResult, Transport,
8};
9
10/// MCP client wrapper.
11pub struct McpClient {
12    transport: McpTransport,
13    tool_cache: HashMap<String, ToolDefinition>,
14    session_id: String,
15}
16
17impl McpClient {
18    /// Create a new MCP client.
19    pub fn new(transport: McpTransport) -> Self {
20        Self {
21            transport,
22            tool_cache: HashMap::new(),
23            session_id: uuid::Uuid::new_v4().to_string(),
24        }
25    }
26
27    /// Create a client builder.
28    pub fn builder() -> McpClientBuilder {
29        McpClientBuilder::new()
30    }
31
32    /// Connect to the MCP server.
33    pub async fn connect(&mut self) -> Result<()> {
34        self.transport.connect().await?;
35        Ok(())
36    }
37
38    /// Disconnect from the MCP server.
39    pub async fn disconnect(&mut self) -> Result<()> {
40        self.transport.disconnect().await?;
41        self.tool_cache.clear();
42        Ok(())
43    }
44
45    /// Check if connected.
46    pub fn is_connected(&self) -> bool {
47        self.transport.is_connected()
48    }
49
50    /// List available tools.
51    pub async fn list_tools(&mut self) -> Result<Vec<ToolDefinition>> {
52        if self.tool_cache.is_empty() {
53            let tools = self.transport.list_tools().await?;
54            for tool in &tools {
55                self.tool_cache.insert(tool.name.clone(), tool.clone());
56            }
57        }
58
59        Ok(self.tool_cache.values().cloned().collect())
60    }
61
62    /// Get a specific tool definition.
63    pub async fn get_tool(&mut self, name: &str) -> Result<Option<ToolDefinition>> {
64        if !self.tool_cache.contains_key(name) {
65            // Refresh cache if tool not found
66            self.list_tools().await?;
67        }
68        Ok(self.tool_cache.get(name).cloned())
69    }
70
71    /// Execute a tool call.
72    pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<ToolResult> {
73        let call = ToolCall {
74            tool: name.to_string(),
75            arguments,
76        };
77        self.transport.call(&call).await
78    }
79
80    /// Get the session ID.
81    pub fn session_id(&self) -> &str {
82        &self.session_id
83    }
84
85    /// Clear the tool cache.
86    pub fn clear_cache(&mut self) {
87        self.tool_cache.clear();
88    }
89}
90
91/// Builder for [`McpClient`].
92pub struct McpClientBuilder {
93    transport: Option<McpTransport>,
94}
95
96impl McpClientBuilder {
97    /// Create a new builder.
98    pub fn new() -> Self {
99        Self { transport: None }
100    }
101
102    /// Set the transport.
103    pub fn transport(mut self, transport: McpTransport) -> Self {
104        self.transport = Some(transport);
105        self
106    }
107
108    /// Build the client.
109    pub fn build(self) -> Result<McpClient> {
110        use thulp_core::Error;
111        let transport = self
112            .transport
113            .ok_or_else(|| Error::InvalidConfig("transport not set".to_string()))?;
114
115        Ok(McpClient::new(transport))
116    }
117}
118
119impl Default for McpClientBuilder {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125/// Convenience functions for common connection patterns.
126impl McpClient {
127    /// List resources from the connected MCP server.
128    pub async fn list_resources(&self) -> Result<ResourceListResult> {
129        self.transport.list_resources().await
130    }
131
132    /// Read a resource by URI from the connected MCP server.
133    pub async fn read_resource(&self, uri: &str) -> Result<ResourceContents> {
134        self.transport.read_resource(uri).await
135    }
136
137    /// List resource templates from the connected MCP server.
138    pub async fn list_resource_templates(&self) -> Result<ResourceTemplateListResult> {
139        self.transport.list_resource_templates().await
140    }
141
142    /// Subscribe to resource change notifications.
143    pub async fn subscribe_resource(&self, uri: &str) -> Result<()> {
144        self.transport.subscribe_resource(uri).await
145    }
146
147    /// Unsubscribe from resource change notifications.
148    pub async fn unsubscribe_resource(&self, uri: &str) -> Result<()> {
149        self.transport.unsubscribe_resource(uri).await
150    }
151
152    /// List prompts from the connected MCP server.
153    pub async fn list_prompts(&self) -> Result<PromptListResult> {
154        self.transport.list_prompts().await
155    }
156
157    /// Get a rendered prompt by name with arguments.
158    pub async fn get_prompt(
159        &self,
160        name: &str,
161        arguments: HashMap<String, String>,
162    ) -> Result<GetPromptResult> {
163        self.transport.get_prompt(name, arguments).await
164    }
165
166    /// Connect to an MCP server via HTTP.
167    pub async fn connect_http(name: String, url: String) -> Result<McpClient> {
168        let transport = McpTransport::new_http(name, url);
169        let mut client = McpClient::new(transport);
170
171        client.connect().await?;
172        Ok(client)
173    }
174
175    /// Connect to an MCP server via STDIO.
176    pub async fn connect_stdio(
177        name: String,
178        command: String,
179        args: Option<Vec<String>>,
180    ) -> Result<McpClient> {
181        let transport = McpTransport::new_stdio(name, command, args);
182        let mut client = McpClient::new(transport);
183
184        client.connect().await?;
185        Ok(client)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an HTTP transport for tests; none of these tests ever connect it.
    fn http_transport(name: &str) -> McpTransport {
        McpTransport::new_http(name.to_string(), "http://localhost:8080".to_string())
    }

    #[tokio::test]
    async fn client_creation() {
        let client = McpClient::new(http_transport("test"));
        assert!(!client.is_connected());
    }

    #[tokio::test]
    async fn client_builder() {
        let client = McpClient::builder()
            .transport(http_transport("test"))
            .build()
            .unwrap();
        assert!(!client.is_connected());
    }

    #[test]
    fn builder_without_transport_fails() {
        // Both entry points to the builder must reject a missing transport.
        assert!(McpClient::builder().build().is_err());
        assert!(McpClientBuilder::default().build().is_err());
    }

    #[tokio::test]
    async fn session_ids_are_unique_per_client() {
        let first = McpClient::new(http_transport("a"));
        let second = McpClient::new(http_transport("b"));
        assert!(!first.session_id().is_empty());
        assert_ne!(first.session_id(), second.session_id());
    }
}