1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
17use super::types::ErrorCode;
18use anyhow::Result;
19use async_trait::async_trait;
20use serde::de::DeserializeOwned;
21use serde::Serialize;
22use serde_json::json;
23use std::pin::Pin;
24
25use std::sync::atomic::{AtomicU64, Ordering};
26use std::time::Duration;
27use std::{collections::HashMap, sync::Arc};
28use tokio::sync::{oneshot, Mutex};
29
30#[derive(Clone)]
36pub struct Protocol {
37 request_id: Arc<AtomicU64>,
38 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
39 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
40 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
41}
42
43impl Protocol {
44 pub fn builder() -> ProtocolBuilder {
50 ProtocolBuilder::new()
51 }
52
53 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
66 let handlers = self.request_handlers.lock().await;
67 if let Some(handler) = handlers.get(&request.method) {
68 match handler.handle(request.clone()).await {
69 Ok(response) => response,
70 Err(e) => JsonRpcResponse {
71 id: request.id,
72 result: None,
73 error: Some(JsonRpcError {
74 code: ErrorCode::InternalError as i32,
75 message: e.to_string(),
76 data: None,
77 }),
78 ..Default::default()
79 },
80 }
81 } else {
82 JsonRpcResponse {
83 id: request.id,
84 error: Some(JsonRpcError {
85 code: ErrorCode::MethodNotFound as i32,
86 message: format!("Method not found: {}", request.method),
87 data: None,
88 }),
89 ..Default::default()
90 }
91 }
92 }
93
94 pub async fn handle_notification(&self, request: JsonRpcNotification) {
103 let handlers = self.notification_handlers.lock().await;
104 if let Some(handler) = handlers.get(&request.method) {
105 match handler.handle(request.clone()).await {
106 Ok(_) => tracing::info!("Received notification: {:?}", request.method),
107 Err(e) => tracing::error!("Error handling notification: {}", e),
108 }
109 } else {
110 tracing::debug!("No handler for notification: {}", request.method);
111 }
112 }
113
114 pub fn new_message_id(&self) -> u64 {
120 self.request_id.fetch_add(1, Ordering::SeqCst)
121 }
122
123 pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
129 let id = self.new_message_id();
130 let (tx, rx) = oneshot::channel();
131
132 {
133 let mut pending = self.pending_requests.lock().await;
134 pending.insert(id, tx);
135 }
136
137 (id, rx)
138 }
139
140 pub async fn handle_response(&self, response: JsonRpcResponse) {
149 if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
150 let _ = tx.send(response);
151 }
152 }
153
154 pub async fn cancel_response(&self, id: u64) {
160 if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
161 let _ = tx.send(JsonRpcResponse {
162 id,
163 result: None,
164 error: Some(JsonRpcError {
165 code: ErrorCode::RequestTimeout as i32,
166 message: "Request cancelled".to_string(),
167 data: None,
168 }),
169 ..Default::default()
170 });
171 }
172 }
173}
174
/// Timeout, in milliseconds, used by [`RequestOptions::default`].
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60_000;

/// Options attached to a single request.
pub struct RequestOptions {
    /// Timeout for the request; enforcement is left to the code that consumes
    /// these options.
    pub timeout: Duration,
}

impl RequestOptions {
    /// Returns the options with `timeout` replaced by the given duration.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    /// Builds options whose timeout is [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    fn default() -> Self {
        let timeout = Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC);
        Self { timeout }
    }
}
209
210#[derive(Clone)]
215pub struct ProtocolBuilder {
216 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
217 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
218}
219
220impl ProtocolBuilder {
221 pub fn new() -> Self {
227 Self {
228 request_handlers: Arc::new(Mutex::new(HashMap::new())),
229 notification_handlers: Arc::new(Mutex::new(HashMap::new())),
230 }
231 }
232
233 pub fn request_handler<Req, Resp>(
244 self,
245 method: &str,
246 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
247 + Send
248 + Sync
249 + 'static,
250 ) -> Self
251 where
252 Req: DeserializeOwned + Send + Sync + 'static,
253 Resp: Serialize + Send + Sync + 'static,
254 {
255 let handler = TypedRequestHandler {
256 handler: Box::new(handler),
257 _phantom: std::marker::PhantomData,
258 };
259
260 if let Ok(mut handlers) = self.request_handlers.try_lock() {
261 handlers.insert(method.to_string(), Box::new(handler));
262 }
263 self
264 }
265
266 pub fn has_request_handler(&self, method: &str) -> bool {
276 self.request_handlers
277 .try_lock()
278 .map(|handlers| handlers.contains_key(method))
279 .unwrap_or(false)
280 }
281
282 pub fn notification_handler<N>(
293 self,
294 method: &str,
295 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
296 + Send
297 + Sync
298 + 'static,
299 ) -> Self
300 where
301 N: DeserializeOwned + Send + Sync + 'static,
302 {
303 let handler = TypedNotificationHandler {
304 handler: Box::new(handler),
305 _phantom: std::marker::PhantomData,
306 };
307
308 if let Ok(mut handlers) = self.notification_handlers.try_lock() {
309 handlers.insert(method.to_string(), Box::new(handler));
310 }
311 self
312 }
313
314 pub fn has_notification_handler(&self, method: &str) -> bool {
324 self.notification_handlers
325 .try_lock()
326 .map(|handlers| handlers.contains_key(method))
327 .unwrap_or(false)
328 }
329
330 pub fn build(self) -> Protocol {
336 Protocol {
337 request_id: Arc::new(AtomicU64::new(0)),
338 pending_requests: Arc::new(Mutex::new(HashMap::new())),
339 request_handlers: self.request_handlers,
340 notification_handlers: self.notification_handlers,
341 }
342 }
343}
344
345#[async_trait]
350trait RequestHandler: Send + Sync {
351 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
361}
362
363#[async_trait]
367trait NotificationHandler: Send + Sync {
368 async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
378}
379
380struct TypedRequestHandler<Req, Resp>
385where
386 Req: DeserializeOwned + Send + Sync + 'static,
387 Resp: Serialize + Send + Sync + 'static,
388{
389 handler: Box<
390 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
391 + Send
392 + Sync,
393 >,
394 _phantom: std::marker::PhantomData<(Req, Resp)>,
395}
396
397#[async_trait]
398impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
399where
400 Req: DeserializeOwned + Send + Sync + 'static,
401 Resp: Serialize + Send + Sync + 'static,
402{
403 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
404 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
405 {
406 serde_json::from_value(json!({}))?
407 } else {
408 serde_json::from_value(request.params.unwrap())?
409 };
410 let result = (self.handler)(params).await?;
411 Ok(JsonRpcResponse {
412 id: request.id,
413 result: Some(serde_json::to_value(result)?),
414 error: None,
415 ..Default::default()
416 })
417 }
418}
419
420struct TypedNotificationHandler<N>
425where
426 N: DeserializeOwned + Send + Sync + 'static,
427{
428 handler: Box<
429 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
430 + Send
431 + Sync,
432 >,
433 _phantom: std::marker::PhantomData<N>,
434}
435
436#[async_trait]
437impl<N> NotificationHandler for TypedNotificationHandler<N>
438where
439 N: DeserializeOwned + Send + Sync + 'static,
440{
441 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
442 let params: N =
443 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
444 serde_json::from_value(serde_json::Value::Null)?
445 } else {
446 serde_json::from_value(notification.params.unwrap())?
447 };
448 (self.handler)(params).await
449 }
450}