mcp_sdk_rs/client/mod.rs

1use crate::{
2    error::{Error, ErrorCode},
3    protocol::{Notification, Request, RequestId, Response},
4    transport::{Message, Transport},
5    types::{ClientCapabilities, Implementation, LoggingLevel, LoggingMessage, ServerCapabilities},
6};
7use async_trait::async_trait;
8use futures::StreamExt;
9use serde_json::{json, Value};
10use std::{collections::HashMap, sync::Arc};
11use tokio::sync::{
12    mpsc::{UnboundedReceiver, UnboundedSender},
13    Mutex, RwLock,
14};
15
16#[derive(Clone)]
17pub struct Session {
18    handler: Option<Arc<dyn ClientHandler>>,
19    transport: Arc<dyn Transport>,
20    receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
21    sender: Arc<UnboundedSender<Message>>,
22}
23impl Session {
24    /// Create a new session
25    pub fn new(
26        transport: Arc<dyn Transport>,
27        sender: UnboundedSender<Message>,
28        receiver: UnboundedReceiver<Message>,
29        handler: Option<Arc<dyn ClientHandler>>,
30    ) -> Self {
31        Self {
32            handler,
33            transport,
34            sender: Arc::new(sender),
35            receiver: Arc::new(Mutex::new(receiver)),
36        }
37    }
38
39    /// Start the session and listen for messages
40    pub async fn start(self) -> Result<(), Error> {
41        let transport = self.transport.clone();
42        let handler = self.handler.unwrap_or(Arc::new(DefaultClientHandler));
43        // listen for messages from the server
44        tokio::spawn(async move {
45            let mut stream = transport.receive();
46            while let Some(result) = stream.next().await {
47                match result {
48                    Ok(message) => match &message {
49                        Message::Request(r) => {
50                            let res = handler
51                                .handle_request(r.method.clone(), r.params.clone())
52                                .await;
53                            if transport
54                                .send(Message::Response(Response::success(
55                                    r.id.clone(),
56                                    Some(res.unwrap()),
57                                )))
58                                .await
59                                .is_err()
60                            {
61                                break;
62                            }
63                        }
64                        Message::Response(_) => {
65                            if self.sender.send(message).is_err() {
66                                break;
67                            }
68                        }
69                        Message::Notification(n) => {
70                            if handler
71                                .handle_notification(n.method.clone(), n.params.clone())
72                                .await
73                                .is_err()
74                            {
75                                break;
76                            }
77                        }
78                    },
79                    Err(_) => break,
80                }
81            }
82        });
83        // listen for requests from the client
84        let rx_clone = self.receiver.clone();
85        let tx_clone = self.transport.clone();
86        tokio::spawn(async move {
87            let mut stream = rx_clone.lock().await;
88            while let Some(message) = stream.recv().await {
89                tx_clone.send(message).await.unwrap();
90            }
91        });
92        Ok(())
93    }
94}
95
96/// Trait for implementing MCP client handlers
97#[async_trait]
98pub trait ClientHandler: Send + Sync {
99    /// Handle shutdown request
100    async fn shutdown(&self) -> Result<(), Error>;
101
102    /// Handle requests
103    async fn handle_request(
104        &self,
105        method: String,
106        params: Option<serde_json::Value>,
107    ) -> Result<serde_json::Value, Error>;
108
109    /// Handle notifications
110    async fn handle_notification(
111        &self,
112        method: String,
113        params: Option<serde_json::Value>,
114    ) -> Result<(), Error>;
115}
116
117#[derive(Clone, Default)]
118pub struct DefaultClientHandler;
119#[async_trait]
120impl ClientHandler for DefaultClientHandler {
121    /// Handle an incoming request
122    async fn handle_request(&self, method: String, _params: Option<Value>) -> Result<Value, Error> {
123        match method.as_str() {
124            "sampling/createMessage" => {
125                log::debug!("Got sampling/createMessage");
126                Ok(json!({}))
127            }
128            _ => Err(Error::Other("unknown method".to_string())),
129        }
130    }
131
132    /// Handle an incoming notification
133    async fn handle_notification(
134        &self,
135        method: String,
136        params: Option<Value>,
137    ) -> Result<(), Error> {
138        match method.as_str() {
139            "notifications/message" => {
140                // handle logging messages
141                if let Some(p) = params {
142                    let message: LoggingMessage = serde_json::from_value(p)?;
143                    log::log!(message.level.into(), "{}", message.data);
144                }
145                Ok(())
146            }
147            "notifications/resources/updated" => {
148                if let Some(p) = params {
149                    let update_params: HashMap<String, Value> = serde_json::from_value(p)?;
150                    if let Some(uri_val) = update_params.get("uri") {
151                        let uri = uri_val.as_str().ok_or("some file").unwrap();
152                        log::debug!("resource updated: {uri}");
153                    }
154                }
155                Ok(())
156            }
157            _ => Err(Error::Other("unknown notification".to_string())),
158        }
159    }
160
161    async fn shutdown(&self) -> Result<(), Error> {
162        log::debug!("Client shutting down");
163        Ok(())
164    }
165}
166
167/// MCP client state
168#[derive(Clone)]
169pub struct Client {
170    server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
171    request_counter: Arc<RwLock<i64>>,
172    #[allow(dead_code)]
173    sender: Arc<UnboundedSender<Message>>,
174    receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
175}
176
177impl Client {
178    /// Create a new MCP client
179    pub fn new(sender: UnboundedSender<Message>, receiver: UnboundedReceiver<Message>) -> Self {
180        Self {
181            server_capabilities: Arc::new(RwLock::new(None)),
182            request_counter: Arc::new(RwLock::new(0)),
183            sender: Arc::new(sender),
184            receiver: Arc::new(Mutex::new(receiver)),
185        }
186    }
187
188    /// Initialize the client
189    pub async fn initialize(
190        &self,
191        implementation: Implementation,
192        capabilities: Option<ClientCapabilities>,
193    ) -> Result<ServerCapabilities, Error> {
194        let params = serde_json::json!({
195            "clientInfo": implementation,
196            "capabilities": capabilities.unwrap_or_default(),
197            "protocolVersion": crate::LATEST_PROTOCOL_VERSION,
198        });
199        log::debug!("initializing client with capabilities: {}", params);
200        let response = self.request("initialize", Some(params)).await?;
201        let mut caps = ServerCapabilities::default();
202        if let Some(resp_obj) = response.as_object() {
203            if let Some(protocol_version) = resp_obj.get("protocolVersion") {
204                if let Some(v) = protocol_version.as_str() {
205                    if v != crate::LATEST_PROTOCOL_VERSION {
206                        log::error!("incorrect protocol version");
207                        self.shutdown().await?;
208                        return Err(Error::Other("incorrect protocol version".to_string()));
209                    }
210                }
211            }
212            if let Some(server_caps) = resp_obj.get("capabilities") {
213                caps = serde_json::from_value(server_caps.clone())?;
214            }
215        }
216        *self.server_capabilities.write().await = Some(caps.clone());
217        // Send initialized notification
218        self.notify("initialized", None).await?;
219        Ok(caps)
220    }
221
222    /// Send a request to the server and wait for the response.
223    ///
224    /// This method will block until a response is received from the server.
225    /// If the server returns an error, it will be propagated as an `Error`.
226    pub async fn request(
227        &self,
228        method: &str,
229        params: Option<serde_json::Value>,
230    ) -> Result<serde_json::Value, Error> {
231        let mut counter = self.request_counter.write().await;
232        *counter += 1;
233        let id = RequestId::Number(*counter);
234
235        let request = Request::new(method, params, id.clone());
236        self.sender
237            .send(Message::Request(request))
238            .map_err(|_| Error::Transport("failed to send request message".to_string()))?;
239        // Wait for matching response
240        let mut receiver = self.receiver.lock().await;
241        while let Some(message) = receiver.recv().await {
242            if let Message::Response(response) = message {
243                if response.id == id {
244                    if let Some(error) = response.error {
245                        return Err(Error::protocol(
246                            match error.code {
247                                -32700 => ErrorCode::ParseError,
248                                -32600 => ErrorCode::InvalidRequest,
249                                -32601 => ErrorCode::MethodNotFound,
250                                -32602 => ErrorCode::InvalidParams,
251                                -32603 => ErrorCode::InternalError,
252                                -32002 => ErrorCode::ServerNotInitialized,
253                                -32001 => ErrorCode::UnknownErrorCode,
254                                -32000 => ErrorCode::RequestFailed,
255                                _ => ErrorCode::UnknownErrorCode,
256                            },
257                            &error.message,
258                        ));
259                    }
260                    return response.result.ok_or_else(|| {
261                        Error::protocol(ErrorCode::InternalError, "Response missing result")
262                    });
263                }
264            }
265        }
266
267        Err(Error::protocol(
268            ErrorCode::InternalError,
269            "Connection closed while waiting for response",
270        ))
271    }
272
273    pub async fn subscribe(&self, uri: &str) -> Result<(), Error> {
274        let mut counter = self.request_counter.write().await;
275        *counter += 1;
276        let id = RequestId::Number(*counter);
277
278        let request = Request::new("resources/subscribe", Some(json!({"uri": uri})), id.clone());
279        self.sender.send(Message::Request(request)).map_err(|_| {
280            Error::Transport("failed to send subscribe request message".to_string())
281        })?;
282        Ok(())
283    }
284
285    pub async fn set_log_level(&self, level: LoggingLevel) -> Result<(), Error> {
286        let mut counter = self.request_counter.write().await;
287        *counter += 1;
288        let id = RequestId::Number(*counter);
289        let request = Request::new(
290            "logging/setLevel",
291            Some(json!({"level": level})),
292            id.clone(),
293        );
294        self.sender
295            .send(Message::Request(request))
296            .map_err(|_| Error::Transport("failed to set logging level".to_string()))?;
297        Ok(())
298    }
299
300    /// Send a notification to the server
301    pub async fn notify(
302        &self,
303        method: &str,
304        params: Option<serde_json::Value>,
305    ) -> Result<(), Error> {
306        let notification = Notification::new(method, params);
307        self.sender
308            .send(Message::Notification(notification))
309            .map_err(|_| Error::Transport("failed to send notification message".to_string()))
310    }
311
312    /// Get the server capabilities
313    pub async fn capabilities(&self) -> Option<ServerCapabilities> {
314        self.server_capabilities.read().await.clone()
315    }
316
317    /// Close the client connection
318    pub async fn shutdown(&self) -> Result<(), Error> {
319        // Send shutdown request
320        self.request("shutdown", None).await?;
321        // Send exit notification
322        self.notify("exit", None).await?;
323        Ok(())
324    }
325}