1use super::transport::{JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
2use super::types::ErrorCode;
3use anyhow::Result;
4use async_trait::async_trait;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use std::pin::Pin;
8
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::time::Duration;
11use std::{collections::HashMap, sync::Arc};
12use tokio::sync::{oneshot, Mutex};
13
14#[derive(Clone)]
15pub struct Protocol {
16 request_id: Arc<AtomicU64>,
17 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
18 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
19 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
20}
21
22impl Protocol {
23 pub fn builder() -> ProtocolBuilder {
24 ProtocolBuilder::new()
25 }
26
27 pub async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
28 let handlers = self.request_handlers.lock().await;
29 if let Some(handler) = handlers.get(&request.method) {
30 match handler.handle(request.clone()).await {
31 Ok(response) => response,
32 Err(e) => JsonRpcResponse {
33 id: request.id,
34 result: None,
35 error: Some(JsonRpcError {
36 code: ErrorCode::InternalError as i32,
37 message: e.to_string(),
38 data: None,
39 }),
40 ..Default::default()
41 },
42 }
43 } else {
44 JsonRpcResponse {
45 id: request.id,
46 error: Some(JsonRpcError {
47 code: ErrorCode::MethodNotFound as i32,
48 message: format!("Method not found: {}", request.method),
49 data: None,
50 }),
51 ..Default::default()
52 }
53 }
54 }
55
56 pub async fn handle_notification(&self, request: JsonRpcNotification) {
57 let handlers = self.notification_handlers.lock().await;
58 if let Some(handler) = handlers.get(&request.method) {
59 match handler.handle(request.clone()).await {
60 Ok(_) => tracing::info!("Received notification: {:?}", request.method),
61 Err(e) => tracing::error!("Error handling notification: {}", e),
62 }
63 } else {
64 tracing::debug!("No handler for notification: {}", request.method);
65 }
66 }
67
68 pub async fn create_request(&self) -> (u64, oneshot::Receiver<JsonRpcResponse>) {
69 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
70 let (tx, rx) = oneshot::channel();
71
72 {
73 let mut pending = self.pending_requests.lock().await;
74 pending.insert(id, tx);
75 }
76
77 (id, rx)
78 }
79
80 pub async fn handle_response(&self, response: JsonRpcResponse) {
81 if let Some(tx) = self.pending_requests.lock().await.remove(&response.id) {
82 let _ = tx.send(response);
83 }
84 }
85
86 pub async fn cancel_response(&self, id: u64) {
87 if let Some(tx) = self.pending_requests.lock().await.remove(&id) {
88 let _ = tx.send(JsonRpcResponse {
89 id,
90 result: None,
91 error: Some(JsonRpcError {
92 code: ErrorCode::RequestTimeout as i32,
93 message: "Request cancelled".to_string(),
94 data: None,
95 }),
96 ..Default::default()
97 });
98 }
99 }
100}
101
102pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;
104pub struct RequestOptions {
105 pub timeout: Duration,
106}
107
108impl RequestOptions {
109 pub fn timeout(self, timeout: Duration) -> Self {
110 Self { timeout }
111 }
112}
113
114impl Default for RequestOptions {
115 fn default() -> Self {
116 Self {
117 timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
118 }
119 }
120}
121
122#[derive(Clone)]
123pub struct ProtocolBuilder {
124 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
125 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
126}
127
128impl ProtocolBuilder {
129 pub fn new() -> Self {
130 Self {
131 request_handlers: Arc::new(Mutex::new(HashMap::new())),
132 notification_handlers: Arc::new(Mutex::new(HashMap::new())),
133 }
134 }
135
136 pub fn request_handler<Req, Resp>(
138 self,
139 method: &str,
140 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
141 + Send
142 + Sync
143 + 'static,
144 ) -> Self
145 where
146 Req: DeserializeOwned + Send + Sync + 'static,
147 Resp: Serialize + Send + Sync + 'static,
148 {
149 let handler = TypedRequestHandler {
150 handler: Box::new(handler),
151 _phantom: std::marker::PhantomData,
152 };
153
154 if let Ok(mut handlers) = self.request_handlers.try_lock() {
155 handlers.insert(method.to_string(), Box::new(handler));
156 }
157 self
158 }
159
160 pub fn has_request_handler(&self, method: &str) -> bool {
161 self.request_handlers
162 .try_lock()
163 .map(|handlers| handlers.contains_key(method))
164 .unwrap_or(false)
165 }
166
167 pub fn notification_handler<N>(
168 self,
169 method: &str,
170 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
171 + Send
172 + Sync
173 + 'static,
174 ) -> Self
175 where
176 N: DeserializeOwned + Send + Sync + 'static,
177 {
178 let handler = TypedNotificationHandler {
179 handler: Box::new(handler),
180 _phantom: std::marker::PhantomData,
181 };
182
183 if let Ok(mut handlers) = self.notification_handlers.try_lock() {
184 handlers.insert(method.to_string(), Box::new(handler));
185 }
186 self
187 }
188
189 pub fn has_notification_handler(&self, method: &str) -> bool {
190 self.notification_handlers
191 .try_lock()
192 .map(|handlers| handlers.contains_key(method))
193 .unwrap_or(false)
194 }
195
196 pub fn build(self) -> Protocol {
197 Protocol {
198 request_id: Arc::new(AtomicU64::new(0)),
199 pending_requests: Arc::new(Mutex::new(HashMap::new())),
200 request_handlers: self.request_handlers,
201 notification_handlers: self.notification_handlers,
202 }
203 }
204}
205
206#[async_trait]
208trait RequestHandler: Send + Sync {
209 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
210}
211
212#[async_trait]
213trait NotificationHandler: Send + Sync {
214 async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
215}
216
217struct TypedRequestHandler<Req, Resp>
219where
220 Req: DeserializeOwned + Send + Sync + 'static,
221 Resp: Serialize + Send + Sync + 'static,
222{
223 handler: Box<
224 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
225 + Send
226 + Sync,
227 >,
228 _phantom: std::marker::PhantomData<(Req, Resp)>,
229}
230
231#[async_trait]
232impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
233where
234 Req: DeserializeOwned + Send + Sync + 'static,
235 Resp: Serialize + Send + Sync + 'static,
236{
237 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
238 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
239 {
240 serde_json::from_value(serde_json::Value::Null)?
241 } else {
242 serde_json::from_value(request.params.unwrap())?
243 };
244 let result = (self.handler)(params).await?;
245 Ok(JsonRpcResponse {
246 id: request.id,
247 result: Some(serde_json::to_value(result)?),
248 error: None,
249 ..Default::default()
250 })
251 }
252}
253
254struct TypedNotificationHandler<N>
255where
256 N: DeserializeOwned + Send + Sync + 'static,
257{
258 handler: Box<
259 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
260 + Send
261 + Sync,
262 >,
263 _phantom: std::marker::PhantomData<N>,
264}
265
266#[async_trait]
267impl<N> NotificationHandler for TypedNotificationHandler<N>
268where
269 N: DeserializeOwned + Send + Sync + 'static,
270{
271 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
272 let params: N =
273 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
274 serde_json::from_value(serde_json::Value::Null)?
275 } else {
276 serde_json::from_value(notification.params.unwrap())?
277 };
278 (self.handler)(params).await
279 }
280}