1use crate::{
2 error::{Error, ErrorCode},
3 protocol::{Notification, Request, RequestId},
4 transport::Message,
5 types::{ClientCapabilities, Implementation, LoggingLevel, LoggingMessage, ServerCapabilities},
6};
7use async_trait::async_trait;
8use serde_json::{json, Value};
9use std::{collections::HashMap, sync::Arc};
10use tokio::sync::{
11 mpsc::{UnboundedReceiver, UnboundedSender},
12 Mutex, RwLock,
13};
14
15#[async_trait]
17pub trait ClientHandler: Send + Sync {
18 async fn shutdown(&self) -> Result<(), Error>;
20
21 async fn handle_request(
23 &self,
24 method: String,
25 params: Option<serde_json::Value>,
26 ) -> Result<serde_json::Value, Error>;
27
28 async fn handle_notification(
30 &self,
31 method: String,
32 params: Option<serde_json::Value>,
33 ) -> Result<(), Error>;
34}
35
36#[derive(Clone, Default)]
37pub struct DefaultClientHandler;
38#[async_trait]
39impl ClientHandler for DefaultClientHandler {
40 async fn handle_request(&self, method: String, _params: Option<Value>) -> Result<Value, Error> {
42 match method.as_str() {
43 "sampling/createMessage" => {
44 log::debug!("Got sampling/createMessage");
45 Ok(json!({}))
46 }
47 _ => Err(Error::Other("unknown method".to_string())),
48 }
49 }
50
51 async fn handle_notification(
53 &self,
54 method: String,
55 params: Option<Value>,
56 ) -> Result<(), Error> {
57 match method.as_str() {
58 "notifications/message" => {
59 if let Some(p) = params {
61 let message: LoggingMessage = serde_json::from_value(p)?;
62 log::log!(message.level.into(), "{}", message.data);
63 }
64 Ok(())
65 }
66 "notifications/resources/updated" => {
67 if let Some(p) = params {
68 let update_params: HashMap<String, Value> = serde_json::from_value(p)?;
69 if let Some(uri_val) = update_params.get("uri") {
70 let uri = uri_val.as_str().ok_or("some file").unwrap();
71 log::debug!("resource updated: {uri}");
72 }
73 }
74 Ok(())
75 }
76 _ => Err(Error::Other("unknown notification".to_string())),
77 }
78 }
79
80 async fn shutdown(&self) -> Result<(), Error> {
81 log::debug!("Client shutting down");
82 Ok(())
83 }
84}
85
86#[derive(Clone, Debug)]
88pub struct Client {
89 server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
90 request_counter: Arc<RwLock<i64>>,
91 #[allow(dead_code)]
92 sender: Arc<UnboundedSender<Message>>,
93 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
94}
95
96impl Client {
97 pub fn new(sender: UnboundedSender<Message>, receiver: UnboundedReceiver<Message>) -> Self {
99 Self {
100 server_capabilities: Arc::new(RwLock::new(None)),
101 request_counter: Arc::new(RwLock::new(0)),
102 sender: Arc::new(sender),
103 receiver: Arc::new(Mutex::new(receiver)),
104 }
105 }
106
107 pub async fn initialize(
109 &self,
110 implementation: Implementation,
111 capabilities: Option<ClientCapabilities>,
112 ) -> Result<ServerCapabilities, Error> {
113 let params = serde_json::json!({
114 "clientInfo": implementation,
115 "capabilities": capabilities.unwrap_or_default(),
116 "protocolVersion": crate::LATEST_PROTOCOL_VERSION,
117 });
118 log::debug!("initializing client with capabilities: {}", params);
119 let response = self.request("initialize", Some(params)).await?;
120 let mut caps = ServerCapabilities::default();
121 if let Some(resp_obj) = response.as_object() {
122 if let Some(protocol_version) = resp_obj.get("protocolVersion") {
123 if let Some(v) = protocol_version.as_str() {
124 if v != crate::LATEST_PROTOCOL_VERSION {
125 log::warn!(
126 "mcp server does not support the latest protocol version {}",
127 crate::LATEST_PROTOCOL_VERSION
128 );
129 log::warn!("latest supported protocol version is {v}");
130 }
131 }
132 }
133 if let Some(server_caps) = resp_obj.get("capabilities") {
134 caps = serde_json::from_value(server_caps.clone())?;
135 }
136 }
137 *self.server_capabilities.write().await = Some(caps.clone());
138 self.notify("initialized", None).await?;
140 Ok(caps)
141 }
142
143 pub async fn request(
148 &self,
149 method: &str,
150 params: Option<serde_json::Value>,
151 ) -> Result<serde_json::Value, Error> {
152 let mut counter = self.request_counter.write().await;
153 *counter += 1;
154 let id = RequestId::Number(*counter);
155 let request = Request::new(method, params, id.clone());
156 self.sender
157 .send(Message::Request(request))
158 .map_err(|_| Error::Transport("failed to send request message".to_string()))?;
159 let mut receiver = self.receiver.lock().await;
161 while let Some(message) = receiver.recv().await {
162 if let Message::Response(response) = message {
163 if response.id == id {
164 if let Some(error) = response.error {
165 return Err(Error::protocol(
166 match error.code {
167 -32700 => ErrorCode::ParseError,
168 -32600 => ErrorCode::InvalidRequest,
169 -32601 => ErrorCode::MethodNotFound,
170 -32602 => ErrorCode::InvalidParams,
171 -32603 => ErrorCode::InternalError,
172 -32002 => ErrorCode::ServerNotInitialized,
173 -32001 => ErrorCode::UnknownErrorCode,
174 -32000 => ErrorCode::RequestFailed,
175 _ => ErrorCode::UnknownErrorCode,
176 },
177 &error.message,
178 ));
179 }
180 return response.result.ok_or_else(|| {
181 Error::protocol(ErrorCode::InternalError, "Response missing result")
182 });
183 }
184 }
185 }
186
187 Err(Error::protocol(
188 ErrorCode::InternalError,
189 "Connection closed while waiting for response",
190 ))
191 }
192
193 pub async fn subscribe(&self, uri: &str) -> Result<(), Error> {
194 let mut counter = self.request_counter.write().await;
195 *counter += 1;
196 let id = RequestId::Number(*counter);
197
198 let request = Request::new("resources/subscribe", Some(json!({"uri": uri})), id.clone());
199 self.sender.send(Message::Request(request)).map_err(|_| {
200 Error::Transport("failed to send subscribe request message".to_string())
201 })?;
202 Ok(())
203 }
204
205 pub async fn set_log_level(&self, level: LoggingLevel) -> Result<(), Error> {
206 let mut counter = self.request_counter.write().await;
207 *counter += 1;
208 let id = RequestId::Number(*counter);
209 let request = Request::new(
210 "logging/setLevel",
211 Some(json!({"level": level})),
212 id.clone(),
213 );
214 self.sender
215 .send(Message::Request(request))
216 .map_err(|_| Error::Transport("failed to set logging level".to_string()))?;
217 Ok(())
218 }
219
220 pub async fn notify(
222 &self,
223 method: &str,
224 params: Option<serde_json::Value>,
225 ) -> Result<(), Error> {
226 let path = format!("notifications/{method}");
227 let notification = Notification::new(path, params);
228 self.sender
229 .send(Message::Notification(notification))
230 .map_err(|_| Error::Transport("failed to send notification message".to_string()))
231 }
232
233 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
235 self.server_capabilities.read().await.clone()
236 }
237
238 pub async fn shutdown(&self) -> Result<(), Error> {
240 self.request("shutdown", None).await?;
242 self.notify("exit", None).await?;
244 Ok(())
245 }
246}