mcp_sdk_rs/client/
mod.rs

1use crate::{
2    error::{Error, ErrorCode},
3    protocol::{Notification, Request, RequestId},
4    transport::Message,
5    types::{ClientCapabilities, Implementation, LoggingLevel, LoggingMessage, ServerCapabilities},
6};
7use async_trait::async_trait;
8use serde_json::{json, Value};
9use std::{collections::HashMap, sync::Arc};
10use tokio::sync::{
11    mpsc::{UnboundedReceiver, UnboundedSender},
12    Mutex, RwLock,
13};
14
15/// Trait for implementing MCP client handlers
16#[async_trait]
17pub trait ClientHandler: Send + Sync {
18    /// Handle shutdown request
19    async fn shutdown(&self) -> Result<(), Error>;
20
21    /// Handle requests
22    async fn handle_request(
23        &self,
24        method: String,
25        params: Option<serde_json::Value>,
26    ) -> Result<serde_json::Value, Error>;
27
28    /// Handle notifications
29    async fn handle_notification(
30        &self,
31        method: String,
32        params: Option<serde_json::Value>,
33    ) -> Result<(), Error>;
34}
35
36#[derive(Clone, Default)]
37pub struct DefaultClientHandler;
38#[async_trait]
39impl ClientHandler for DefaultClientHandler {
40    /// Handle an incoming request
41    async fn handle_request(&self, method: String, _params: Option<Value>) -> Result<Value, Error> {
42        match method.as_str() {
43            "sampling/createMessage" => {
44                log::debug!("Got sampling/createMessage");
45                Ok(json!({}))
46            }
47            _ => Err(Error::Other("unknown method".to_string())),
48        }
49    }
50
51    /// Handle an incoming notification
52    async fn handle_notification(
53        &self,
54        method: String,
55        params: Option<Value>,
56    ) -> Result<(), Error> {
57        match method.as_str() {
58            "notifications/message" => {
59                // handle logging messages
60                if let Some(p) = params {
61                    let message: LoggingMessage = serde_json::from_value(p)?;
62                    log::log!(message.level.into(), "{}", message.data);
63                }
64                Ok(())
65            }
66            "notifications/resources/updated" => {
67                if let Some(p) = params {
68                    let update_params: HashMap<String, Value> = serde_json::from_value(p)?;
69                    if let Some(uri_val) = update_params.get("uri") {
70                        let uri = uri_val.as_str().ok_or("some file").unwrap();
71                        log::debug!("resource updated: {uri}");
72                    }
73                }
74                Ok(())
75            }
76            _ => Err(Error::Other("unknown notification".to_string())),
77        }
78    }
79
80    async fn shutdown(&self) -> Result<(), Error> {
81        log::debug!("Client shutting down");
82        Ok(())
83    }
84}
85
86/// MCP client state
87#[derive(Clone, Debug)]
88pub struct Client {
89    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
90    request_counter: Arc<RwLock<i64>>,
91    #[allow(dead_code)]
92    sender: Arc<UnboundedSender<Message>>,
93    receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
94}
95
96impl Client {
97    /// Create a new MCP client
98    pub fn new(sender: UnboundedSender<Message>, receiver: UnboundedReceiver<Message>) -> Self {
99        Self {
100            server_capabilities: Arc::new(RwLock::new(None)),
101            request_counter: Arc::new(RwLock::new(0)),
102            sender: Arc::new(sender),
103            receiver: Arc::new(Mutex::new(receiver)),
104        }
105    }
106
107    /// Initialize the client
108    pub async fn initialize(
109        &self,
110        implementation: Implementation,
111        capabilities: Option<ClientCapabilities>,
112    ) -> Result<ServerCapabilities, Error> {
113        let params = serde_json::json!({
114            "clientInfo": implementation,
115            "capabilities": capabilities.unwrap_or_default(),
116            "protocolVersion": crate::LATEST_PROTOCOL_VERSION,
117        });
118        log::debug!("initializing client with capabilities: {}", params);
119        let response = self.request("initialize", Some(params)).await?;
120        let mut caps = ServerCapabilities::default();
121        if let Some(resp_obj) = response.as_object() {
122            if let Some(protocol_version) = resp_obj.get("protocolVersion") {
123                if let Some(v) = protocol_version.as_str() {
124                    if v != crate::LATEST_PROTOCOL_VERSION {
125                        log::warn!(
126                            "mcp server does not support the latest protocol version {}",
127                            crate::LATEST_PROTOCOL_VERSION
128                        );
129                        log::warn!("latest supported protocol version is {v}");
130                    }
131                }
132            }
133            if let Some(server_caps) = resp_obj.get("capabilities") {
134                caps = serde_json::from_value(server_caps.clone())?;
135            }
136        }
137        *self.server_capabilities.write().await = Some(caps.clone());
138        // Send initialized notification
139        self.notify("initialized", None).await?;
140        Ok(caps)
141    }
142
143    /// Send a request to the server and wait for the response.
144    ///
145    /// This method will block until a response is received from the server.
146    /// If the server returns an error, it will be propagated as an `Error`.
147    pub async fn request(
148        &self,
149        method: &str,
150        params: Option<serde_json::Value>,
151    ) -> Result<serde_json::Value, Error> {
152        let mut counter = self.request_counter.write().await;
153        *counter += 1;
154        let id = RequestId::Number(*counter);
155        let request = Request::new(method, params, id.clone());
156        self.sender
157            .send(Message::Request(request))
158            .map_err(|_| Error::Transport("failed to send request message".to_string()))?;
159        // Wait for matching response
160        let mut receiver = self.receiver.lock().await;
161        while let Some(message) = receiver.recv().await {
162            if let Message::Response(response) = message {
163                if response.id == id {
164                    if let Some(error) = response.error {
165                        return Err(Error::protocol(
166                            match error.code {
167                                -32700 => ErrorCode::ParseError,
168                                -32600 => ErrorCode::InvalidRequest,
169                                -32601 => ErrorCode::MethodNotFound,
170                                -32602 => ErrorCode::InvalidParams,
171                                -32603 => ErrorCode::InternalError,
172                                -32002 => ErrorCode::ServerNotInitialized,
173                                -32001 => ErrorCode::UnknownErrorCode,
174                                -32000 => ErrorCode::RequestFailed,
175                                _ => ErrorCode::UnknownErrorCode,
176                            },
177                            &error.message,
178                        ));
179                    }
180                    return response.result.ok_or_else(|| {
181                        Error::protocol(ErrorCode::InternalError, "Response missing result")
182                    });
183                }
184            }
185        }
186
187        Err(Error::protocol(
188            ErrorCode::InternalError,
189            "Connection closed while waiting for response",
190        ))
191    }
192
193    pub async fn subscribe(&self, uri: &str) -> Result<(), Error> {
194        let mut counter = self.request_counter.write().await;
195        *counter += 1;
196        let id = RequestId::Number(*counter);
197
198        let request = Request::new("resources/subscribe", Some(json!({"uri": uri})), id.clone());
199        self.sender.send(Message::Request(request)).map_err(|_| {
200            Error::Transport("failed to send subscribe request message".to_string())
201        })?;
202        Ok(())
203    }
204
205    pub async fn set_log_level(&self, level: LoggingLevel) -> Result<(), Error> {
206        let mut counter = self.request_counter.write().await;
207        *counter += 1;
208        let id = RequestId::Number(*counter);
209        let request = Request::new(
210            "logging/setLevel",
211            Some(json!({"level": level})),
212            id.clone(),
213        );
214        self.sender
215            .send(Message::Request(request))
216            .map_err(|_| Error::Transport("failed to set logging level".to_string()))?;
217        Ok(())
218    }
219
220    /// Send a notification to the server
221    pub async fn notify(
222        &self,
223        method: &str,
224        params: Option<serde_json::Value>,
225    ) -> Result<(), Error> {
226        let path = format!("notifications/{method}");
227        let notification = Notification::new(path, params);
228        self.sender
229            .send(Message::Notification(notification))
230            .map_err(|_| Error::Transport("failed to send notification message".to_string()))
231    }
232
233    /// Get the server capabilities
234    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
235        self.server_capabilities.read().await.clone()
236    }
237
238    /// Close the client connection
239    pub async fn shutdown(&self) -> Result<(), Error> {
240        // Send shutdown request
241        self.request("shutdown", None).await?;
242        // Send exit notification
243        self.notify("exit", None).await?;
244        Ok(())
245    }
246}