1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
2use super::types::ErrorCode;
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::json;
8use std::pin::Pin;
9
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12use std::{collections::HashMap, sync::Arc};
13use tokio::sync::{oneshot, Mutex};
14
/// JSON-RPC dispatcher: routes incoming requests and notifications to the
/// handlers registered for their method, and completes outstanding outgoing
/// requests when their responses arrive.
///
/// Every field is behind an `Arc`, so clones of a `Protocol` share the same id
/// counter, pending-request table and handler maps.
#[derive(Clone)]
pub struct Protocol {
    /// Counter that hands out ids for outgoing requests (see `new_message_id`).
    request_id: Arc<AtomicU64>,
    /// Outgoing requests still awaiting a response, keyed by request id; each
    /// sender completes the receiver returned from `create_request`.
    pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
    /// Request handlers keyed by JSON-RPC method name.
    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
    /// Notification handlers keyed by JSON-RPC method name.
    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
}
22
23impl Protocol {
24 pub fn builder() -> ProtocolBuilder {
25 ProtocolBuilder::new()
26 }
27
28 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
29 let handlers = self.request_handlers.lock().await;
30 if let Some(handler) = handlers.get(&request.method) {
31 match handler.handle(request.clone()).await {
32 Ok(response) => response,
33 Err(e) => JsonRpcResponse {
34 id: request.id,
35 result: None,
36 error: Some(JsonRpcError {
37 code: ErrorCode::InternalError as i32,
38 message: e.to_string(),
39 data: None,
40 }),
41 ..Default::default()
42 },
43 }
44 } else {
45 JsonRpcResponse {
46 id: request.id,
47 error: Some(JsonRpcError {
48 code: ErrorCode::MethodNotFound as i32,
49 message: format!("Method not found: {}", request.method),
50 data: None,
51 }),
52 ..Default::default()
53 }
54 }
55 }
56
57 pub async fn handle_notification(&self, request: JsonRpcNotification) {
58 let handlers = self.notification_handlers.lock().await;
59 if let Some(handler) = handlers.get(&request.method) {
60 match handler.handle(request.clone()).await {
61 Ok(_) => tracing::info!("Received notification: {:?}", request.method),
62 Err(e) => tracing::error!("Error handling notification: {}", e),
63 }
64 } else {
65 tracing::debug!("No handler for notification: {}", request.method);
66 }
67 }
68
69 pub fn new_message_id(&self) -> u64 {
70 self.request_id.fetch_add(1, Ordering::SeqCst)
71 }
72
73 pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
74 let id = self.new_message_id();
75 let (tx, rx) = oneshot::channel();
76
77 {
78 let mut pending = self.pending_requests.lock().await;
79 pending.insert(id, tx);
80 }
81
82 (id, rx)
83 }
84
85 pub async fn handle_response(&self, response: JsonRpcResponse) {
86 if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
87 let _ = tx.send(response);
88 }
89 }
90
91 pub async fn cancel_response(&self, id: u64) {
92 if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
93 let _ = tx.send(JsonRpcResponse {
94 id,
95 result: None,
96 error: Some(JsonRpcError {
97 code: ErrorCode::RequestTimeout as i32,
98 message: "Request cancelled".to_string(),
99 data: None,
100 }),
101 ..Default::default()
102 });
103 }
104 }
105}
106
/// Default timeout for outgoing requests, in milliseconds (60 seconds).
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Options applied to a single outgoing request.
#[derive(Debug, Clone)]
pub struct RequestOptions {
    /// Timeout for the request; defaults to [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    pub timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the timeout set to `timeout`.
    ///
    /// Builder-style: consumes `self`, so the returned value must be used.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout }
    }
}

impl Default for RequestOptions {
    /// Options with a timeout of [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
126
/// Builder that collects request and notification handlers, keyed by method name,
/// before producing a [`Protocol`] with `build`.
///
/// The handler maps are behind `Arc`, so clones of a builder share them, and so does the
/// `Protocol` that `build` produces.
#[derive(Clone)]
pub struct ProtocolBuilder {
    request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
    notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
}
132
133impl ProtocolBuilder {
134 pub fn new() -> Self {
135 Self {
136 request_handlers: Arc::new(Mutex::new(HashMap::new())),
137 notification_handlers: Arc::new(Mutex::new(HashMap::new())),
138 }
139 }
140
141 pub fn request_handler<Req, Resp>(
143 self,
144 method: &str,
145 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
146 + Send
147 + Sync
148 + 'static,
149 ) -> Self
150 where
151 Req: DeserializeOwned + Send + Sync + 'static,
152 Resp: Serialize + Send + Sync + 'static,
153 {
154 let handler = TypedRequestHandler {
155 handler: Box::new(handler),
156 _phantom: std::marker::PhantomData,
157 };
158
159 if let Ok(mut handlers) = self.request_handlers.try_lock() {
160 handlers.insert(method.to_string(), Box::new(handler));
161 }
162 self
163 }
164
165 pub fn has_request_handler(&self, method: &str) -> bool {
166 self.request_handlers
167 .try_lock()
168 .map(|handlers| handlers.contains_key(method))
169 .unwrap_or(false)
170 }
171
172 pub fn notification_handler<N>(
173 self,
174 method: &str,
175 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
176 + Send
177 + Sync
178 + 'static,
179 ) -> Self
180 where
181 N: DeserializeOwned + Send + Sync + 'static,
182 {
183 let handler = TypedNotificationHandler {
184 handler: Box::new(handler),
185 _phantom: std::marker::PhantomData,
186 };
187
188 if let Ok(mut handlers) = self.notification_handlers.try_lock() {
189 handlers.insert(method.to_string(), Box::new(handler));
190 }
191 self
192 }
193
194 pub fn has_notification_handler(&self, method: &str) -> bool {
195 self.notification_handlers
196 .try_lock()
197 .map(|handlers| handlers.contains_key(method))
198 .unwrap_or(false)
199 }
200
201 pub fn build(self) -> Protocol {
202 Protocol {
203 request_id: Arc::new(AtomicU64::new(0)),
204 pending_requests: Arc::new(Mutex::new(HashMap::new())),
205 request_handlers: self.request_handlers,
206 notification_handlers: self.notification_handlers,
207 }
208 }
209}
210
/// Type-erased handler for one JSON-RPC request method, stored in the protocol's
/// request-handler map.
#[async_trait]
trait RequestHandler: Send + Sync {
    /// Handles `request` and produces the complete response.
    ///
    /// An `Err` is turned into an `InternalError` response by `Protocol::handle_request`.
    async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
}
216
/// Type-erased handler for one JSON-RPC notification method, stored in the protocol's
/// notification-handler map.
#[async_trait]
trait NotificationHandler: Send + Sync {
    /// Handles `notification`. An `Err` is only logged by `Protocol::handle_notification`.
    async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
}
220}
221
222struct TypedRequestHandler<Req, Resp>
224where
225 Req: DeserializeOwned + Send + Sync + 'static,
226 Resp: Serialize + Send + Sync + 'static,
227{
228 handler: Box<
229 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
230 + Send
231 + Sync,
232 >,
233 _phantom: std::marker::PhantomData<(Req, Resp)>,
234}
235
236#[async_trait]
237impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
238where
239 Req: DeserializeOwned + Send + Sync + 'static,
240 Resp: Serialize + Send + Sync + 'static,
241{
242 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
243 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
244 {
245 serde_json::from_value(json!({}))?
246 } else {
247 serde_json::from_value(request.params.unwrap())?
248 };
249 let result = (self.handler)(params).await?;
250 Ok(JsonRpcResponse {
251 id: request.id,
252 result: Some(serde_json::to_value(result)?),
253 error: None,
254 ..Default::default()
255 })
256 }
257}
258
259struct TypedNotificationHandler<N>
260where
261 N: DeserializeOwned + Send + Sync + 'static,
262{
263 handler: Box<
264 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
265 + Send
266 + Sync,
267 >,
268 _phantom: std::marker::PhantomData<N>,
269}
270
271#[async_trait]
272impl<N> NotificationHandler for TypedNotificationHandler<N>
273where
274 N: DeserializeOwned + Send + Sync + 'static,
275{
276 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
277 let params: N =
278 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
279 serde_json::from_value(serde_json::Value::Null)?
280 } else {
281 serde_json::from_value(notification.params.unwrap())?
282 };
283 (self.handler)(params).await
284 }
285}