1use crate::{
2 error::{Error, ErrorCode},
3 protocol::{Notification, Request, RequestId, Response},
4 transport::{Message, Transport},
5 types::{ClientCapabilities, Implementation, LoggingLevel, LoggingMessage, ServerCapabilities},
6};
7use async_trait::async_trait;
8use futures::StreamExt;
9use serde_json::{json, Value};
10use std::{collections::HashMap, sync::Arc};
11use tokio::sync::{
12 mpsc::{UnboundedReceiver, UnboundedSender},
13 Mutex, RwLock,
14};
15
16#[derive(Clone)]
17pub struct Session {
18 handler: Option<Arc<dyn ClientHandler>>,
19 transport: Arc<dyn Transport>,
20 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
21 sender: Arc<UnboundedSender<Message>>,
22}
23impl Session {
24 pub fn new(
26 transport: Arc<dyn Transport>,
27 sender: UnboundedSender<Message>,
28 receiver: UnboundedReceiver<Message>,
29 handler: Option<Arc<dyn ClientHandler>>,
30 ) -> Self {
31 Self {
32 handler,
33 transport,
34 sender: Arc::new(sender),
35 receiver: Arc::new(Mutex::new(receiver)),
36 }
37 }
38
39 pub async fn start(self) -> Result<(), Error> {
41 let transport = self.transport.clone();
42 let handler = self.handler.unwrap_or(Arc::new(DefaultClientHandler));
43 tokio::spawn(async move {
45 let mut stream = transport.receive();
46 while let Some(result) = stream.next().await {
47 match result {
48 Ok(message) => match &message {
49 Message::Request(r) => {
50 let res = handler
51 .handle_request(r.method.clone(), r.params.clone())
52 .await;
53 if transport
54 .send(Message::Response(Response::success(
55 r.id.clone(),
56 Some(res.unwrap()),
57 )))
58 .await
59 .is_err()
60 {
61 break;
62 }
63 }
64 Message::Response(_) => {
65 if self.sender.send(message).is_err() {
66 break;
67 }
68 }
69 Message::Notification(n) => {
70 if handler
71 .handle_notification(n.method.clone(), n.params.clone())
72 .await
73 .is_err()
74 {
75 break;
76 }
77 }
78 },
79 Err(_) => break,
80 }
81 }
82 });
83 let rx_clone = self.receiver.clone();
85 let tx_clone = self.transport.clone();
86 tokio::spawn(async move {
87 let mut stream = rx_clone.lock().await;
88 while let Some(message) = stream.recv().await {
89 tx_clone.send(message).await.unwrap();
90 }
91 });
92 Ok(())
93 }
94}
95
96#[async_trait]
98pub trait ClientHandler: Send + Sync {
99 async fn shutdown(&self) -> Result<(), Error>;
101
102 async fn handle_request(
104 &self,
105 method: String,
106 params: Option<serde_json::Value>,
107 ) -> Result<serde_json::Value, Error>;
108
109 async fn handle_notification(
111 &self,
112 method: String,
113 params: Option<serde_json::Value>,
114 ) -> Result<(), Error>;
115}
116
117#[derive(Clone, Default)]
118pub struct DefaultClientHandler;
119#[async_trait]
120impl ClientHandler for DefaultClientHandler {
121 async fn handle_request(&self, method: String, _params: Option<Value>) -> Result<Value, Error> {
123 match method.as_str() {
124 "sampling/createMessage" => {
125 log::debug!("Got sampling/createMessage");
126 Ok(json!({}))
127 }
128 _ => Err(Error::Other("unknown method".to_string())),
129 }
130 }
131
132 async fn handle_notification(
134 &self,
135 method: String,
136 params: Option<Value>,
137 ) -> Result<(), Error> {
138 match method.as_str() {
139 "notifications/message" => {
140 if let Some(p) = params {
142 let message: LoggingMessage = serde_json::from_value(p)?;
143 log::log!(message.level.into(), "{}", message.data);
144 }
145 Ok(())
146 }
147 "notifications/resources/updated" => {
148 if let Some(p) = params {
149 let update_params: HashMap<String, Value> = serde_json::from_value(p)?;
150 if let Some(uri_val) = update_params.get("uri") {
151 let uri = uri_val.as_str().ok_or("some file").unwrap();
152 log::debug!("resource updated: {uri}");
153 }
154 }
155 Ok(())
156 }
157 _ => Err(Error::Other("unknown notification".to_string())),
158 }
159 }
160
161 async fn shutdown(&self) -> Result<(), Error> {
162 log::debug!("Client shutting down");
163 Ok(())
164 }
165}
166
167#[derive(Clone)]
169pub struct Client {
170 server_capabilities: Arc<RwLock<Option<ServerCapabilities>>>,
171 request_counter: Arc<RwLock<i64>>,
172 #[allow(dead_code)]
173 sender: Arc<UnboundedSender<Message>>,
174 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
175}
176
177impl Client {
178 pub fn new(sender: UnboundedSender<Message>, receiver: UnboundedReceiver<Message>) -> Self {
180 Self {
181 server_capabilities: Arc::new(RwLock::new(None)),
182 request_counter: Arc::new(RwLock::new(0)),
183 sender: Arc::new(sender),
184 receiver: Arc::new(Mutex::new(receiver)),
185 }
186 }
187
188 pub async fn initialize(
190 &self,
191 implementation: Implementation,
192 capabilities: Option<ClientCapabilities>,
193 ) -> Result<ServerCapabilities, Error> {
194 let params = serde_json::json!({
195 "clientInfo": implementation,
196 "capabilities": capabilities.unwrap_or_default(),
197 "protocolVersion": crate::LATEST_PROTOCOL_VERSION,
198 });
199 log::debug!("initializing client with capabilities: {}", params);
200 let response = self.request("initialize", Some(params)).await?;
201 let mut caps = ServerCapabilities::default();
202 if let Some(resp_obj) = response.as_object() {
203 if let Some(protocol_version) = resp_obj.get("protocolVersion") {
204 if let Some(v) = protocol_version.as_str() {
205 if v != crate::LATEST_PROTOCOL_VERSION {
206 log::error!("incorrect protocol version");
207 self.shutdown().await?;
208 return Err(Error::Other("incorrect protocol version".to_string()));
209 }
210 }
211 }
212 if let Some(server_caps) = resp_obj.get("capabilities") {
213 caps = serde_json::from_value(server_caps.clone())?;
214 }
215 }
216 *self.server_capabilities.write().await = Some(caps.clone());
217 self.notify("initialized", None).await?;
219 Ok(caps)
220 }
221
222 pub async fn request(
227 &self,
228 method: &str,
229 params: Option<serde_json::Value>,
230 ) -> Result<serde_json::Value, Error> {
231 let mut counter = self.request_counter.write().await;
232 *counter += 1;
233 let id = RequestId::Number(*counter);
234
235 let request = Request::new(method, params, id.clone());
236 self.sender
237 .send(Message::Request(request))
238 .map_err(|_| Error::Transport("failed to send request message".to_string()))?;
239 let mut receiver = self.receiver.lock().await;
241 while let Some(message) = receiver.recv().await {
242 if let Message::Response(response) = message {
243 if response.id == id {
244 if let Some(error) = response.error {
245 return Err(Error::protocol(
246 match error.code {
247 -32700 => ErrorCode::ParseError,
248 -32600 => ErrorCode::InvalidRequest,
249 -32601 => ErrorCode::MethodNotFound,
250 -32602 => ErrorCode::InvalidParams,
251 -32603 => ErrorCode::InternalError,
252 -32002 => ErrorCode::ServerNotInitialized,
253 -32001 => ErrorCode::UnknownErrorCode,
254 -32000 => ErrorCode::RequestFailed,
255 _ => ErrorCode::UnknownErrorCode,
256 },
257 &error.message,
258 ));
259 }
260 return response.result.ok_or_else(|| {
261 Error::protocol(ErrorCode::InternalError, "Response missing result")
262 });
263 }
264 }
265 }
266
267 Err(Error::protocol(
268 ErrorCode::InternalError,
269 "Connection closed while waiting for response",
270 ))
271 }
272
273 pub async fn subscribe(&self, uri: &str) -> Result<(), Error> {
274 let mut counter = self.request_counter.write().await;
275 *counter += 1;
276 let id = RequestId::Number(*counter);
277
278 let request = Request::new("resources/subscribe", Some(json!({"uri": uri})), id.clone());
279 self.sender.send(Message::Request(request)).map_err(|_| {
280 Error::Transport("failed to send subscribe request message".to_string())
281 })?;
282 Ok(())
283 }
284
285 pub async fn set_log_level(&self, level: LoggingLevel) -> Result<(), Error> {
286 let mut counter = self.request_counter.write().await;
287 *counter += 1;
288 let id = RequestId::Number(*counter);
289 let request = Request::new(
290 "logging/setLevel",
291 Some(json!({"level": level})),
292 id.clone(),
293 );
294 self.sender
295 .send(Message::Request(request))
296 .map_err(|_| Error::Transport("failed to set logging level".to_string()))?;
297 Ok(())
298 }
299
300 pub async fn notify(
302 &self,
303 method: &str,
304 params: Option<serde_json::Value>,
305 ) -> Result<(), Error> {
306 let notification = Notification::new(method, params);
307 self.sender
308 .send(Message::Notification(notification))
309 .map_err(|_| Error::Transport("failed to send notification message".to_string()))
310 }
311
312 pub async fn capabilities(&self) -> Option<ServerCapabilities> {
314 self.server_capabilities.read().await.clone()
315 }
316
317 pub async fn shutdown(&self) -> Result<(), Error> {
319 self.request("shutdown", None).await?;
321 self.notify("exit", None).await?;
323 Ok(())
324 }
325}