1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
2use super::types::ErrorCode;
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::json;
8use std::pin::Pin;
9
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12use std::{collections::HashMap, sync::Arc};
13use tokio::sync::{oneshot, Mutex};
14
15#[derive(Clone)]
16pub struct Protocol {
17 request_id: Arc<AtomicU64>,
18 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
19 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
20 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
21}
22
23impl Protocol {
24 pub fn builder() -> ProtocolBuilder {
25 ProtocolBuilder::new()
26 }
27
28 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
29 let handlers = self.request_handlers.lock().await;
30 if let Some(handler) = handlers.get(&request.method) {
31 match handler.handle(request.clone()).await {
32 Ok(response) => response,
33 Err(e) => JsonRpcResponse {
34 id: request.id,
35 result: None,
36 error: Some(JsonRpcError {
37 code: ErrorCode::InternalError as i32,
38 message: e.to_string(),
39 data: None,
40 }),
41 ..Default::default()
42 },
43 }
44 } else {
45 JsonRpcResponse {
46 id: request.id,
47 error: Some(JsonRpcError {
48 code: ErrorCode::MethodNotFound as i32,
49 message: format!("Method not found: {}", request.method),
50 data: None,
51 }),
52 ..Default::default()
53 }
54 }
55 }
56
57 pub async fn handle_notification(&self, request: JsonRpcNotification) {
58 let handlers = self.notification_handlers.lock().await;
59 if let Some(handler) = handlers.get(&request.method) {
60 match handler.handle(request.clone()).await {
61 Ok(_) => tracing::info!("Received notification: {:?}", request.method),
62 Err(e) => tracing::error!("Error handling notification: {}", e),
63 }
64 } else {
65 tracing::debug!("No handler for notification: {}", request.method);
66 }
67 }
68
69 pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
70 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
71 let (tx, rx) = oneshot::channel();
72
73 {
74 let mut pending = self.pending_requests.lock().await;
75 pending.insert(id, tx);
76 }
77
78 (id, rx)
79 }
80
81 pub async fn handle_response(&self, response: JsonRpcResponse) {
82 if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
83 let _ = tx.send(response);
84 }
85 }
86
87 pub async fn cancel_response(&self, id: u64) {
88 if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
89 let _ = tx.send(JsonRpcResponse {
90 id,
91 result: None,
92 error: Some(JsonRpcError {
93 code: ErrorCode::RequestTimeout as i32,
94 message: "Request cancelled".to_string(),
95 data: None,
96 }),
97 ..Default::default()
98 });
99 }
100 }
101}
102
/// Default value of [`RequestOptions::timeout`], in milliseconds (60 seconds).
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Options for a request; currently only a timeout.
#[derive(Debug, Clone)]
pub struct RequestOptions {
    /// Request timeout; defaults to [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    pub timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the timeout replaced by `timeout`.
    ///
    /// Updates `self` in place rather than rebuilding the struct, so any fields
    /// added later keep their values.
    #[must_use]
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    /// Uses a timeout of [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
122
123#[derive(Clone)]
124pub struct ProtocolBuilder {
125 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
126 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
127}
128
129impl ProtocolBuilder {
130 pub fn new() -> Self {
131 Self {
132 request_handlers: Arc::new(Mutex::new(HashMap::new())),
133 notification_handlers: Arc::new(Mutex::new(HashMap::new())),
134 }
135 }
136
137 pub fn request_handler<Req, Resp>(
139 self,
140 method: &str,
141 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
142 + Send
143 + Sync
144 + 'static,
145 ) -> Self
146 where
147 Req: DeserializeOwned + Send + Sync + 'static,
148 Resp: Serialize + Send + Sync + 'static,
149 {
150 let handler = TypedRequestHandler {
151 handler: Box::new(handler),
152 _phantom: std::marker::PhantomData,
153 };
154
155 if let Ok(mut handlers) = self.request_handlers.try_lock() {
156 handlers.insert(method.to_string(), Box::new(handler));
157 }
158 self
159 }
160
161 pub fn has_request_handler(&self, method: &str) -> bool {
162 self.request_handlers
163 .try_lock()
164 .map(|handlers| handlers.contains_key(method))
165 .unwrap_or(false)
166 }
167
168 pub fn notification_handler<N>(
169 self,
170 method: &str,
171 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
172 + Send
173 + Sync
174 + 'static,
175 ) -> Self
176 where
177 N: DeserializeOwned + Send + Sync + 'static,
178 {
179 let handler = TypedNotificationHandler {
180 handler: Box::new(handler),
181 _phantom: std::marker::PhantomData,
182 };
183
184 if let Ok(mut handlers) = self.notification_handlers.try_lock() {
185 handlers.insert(method.to_string(), Box::new(handler));
186 }
187 self
188 }
189
190 pub fn has_notification_handler(&self, method: &str) -> bool {
191 self.notification_handlers
192 .try_lock()
193 .map(|handlers| handlers.contains_key(method))
194 .unwrap_or(false)
195 }
196
197 pub fn build(self) -> Protocol {
198 Protocol {
199 request_id: Arc::new(AtomicU64::new(0)),
200 pending_requests: Arc::new(Mutex::new(HashMap::new())),
201 request_handlers: self.request_handlers,
202 notification_handlers: self.notification_handlers,
203 }
204 }
205}
206
/// Object-safe interface for handling one JSON-RPC request method.
///
/// Stored as `Box<dyn RequestHandler>` in the protocol's handler map;
/// `Protocol::handle_request` turns an `Err` from `handle` into an
/// `InternalError` response.
#[async_trait]
trait RequestHandler: Send + Sync {
    /// Handles `request` and produces the response to send back.
    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
}
212
/// Object-safe interface for handling one JSON-RPC notification method.
///
/// Stored as `Box<dyn NotificationHandler>` in the protocol's handler map.
/// Notifications get no response; `Protocol::handle_notification` only logs
/// the outcome of `handle`.
#[async_trait]
trait NotificationHandler: Send + Sync {
    /// Handles `notification`; an `Err` is logged by the caller.
    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
}
217
218struct TypedRequestHandler<Req, Resp>
220where
221 Req: DeserializeOwned + Send + Sync + 'static,
222 Resp: Serialize + Send + Sync + 'static,
223{
224 handler: Box<
225 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
226 + Send
227 + Sync,
228 >,
229 _phantom: std::marker::PhantomData<(Req, Resp)>,
230}
231
232#[async_trait]
233impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
234where
235 Req: DeserializeOwned + Send + Sync + 'static,
236 Resp: Serialize + Send + Sync + 'static,
237{
238 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
239 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
240 {
241 serde_json::from_value(json!({}))?
242 } else {
243 serde_json::from_value(request.params.unwrap())?
244 };
245 let result = (self.handler)(params).await?;
246 Ok(JsonRpcResponse {
247 id: request.id,
248 result: Some(serde_json::to_value(result)?),
249 error: None,
250 ..Default::default()
251 })
252 }
253}
254
255struct TypedNotificationHandler<N>
256where
257 N: DeserializeOwned + Send + Sync + 'static,
258{
259 handler: Box<
260 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
261 + Send
262 + Sync,
263 >,
264 _phantom: std::marker::PhantomData<N>,
265}
266
267#[async_trait]
268impl<N> NotificationHandler for TypedNotificationHandler<N>
269where
270 N: DeserializeOwned + Send + Sync + 'static,
271{
272 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
273 let params: N =
274 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
275 serde_json::from_value(serde_json::Value::Null)?
276 } else {
277 serde_json::from_value(notification.params.unwrap())?
278 };
279 (self.handler)(params).await
280 }
281}