1use super::transport::{
2 JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Transport,
3};
4use super::types::ErrorCode;
5use anyhow::anyhow;
6use anyhow::Result;
7use async_trait::async_trait;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use std::pin::Pin;
11use std::sync::atomic::Ordering;
12use std::time::Duration;
13use std::{
14 collections::HashMap,
15 sync::{atomic::AtomicU64, Arc},
16};
17use tokio::sync::oneshot;
18use tokio::sync::Mutex;
19use tokio::time::timeout;
20use tracing::debug;
21
22#[derive(Clone)]
23pub struct Protocol<T: Transport> {
24 transport: Arc<T>,
25
26 request_id: Arc<AtomicU64>,
27 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
29 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
30}
31
32impl<T: Transport> Protocol<T> {
33 pub fn builder(transport: T) -> ProtocolBuilder<T> {
34 ProtocolBuilder::new(transport)
35 }
36
37 pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
38 let notification = JsonRpcNotification {
39 method: method.to_string(),
40 params,
41 ..Default::default()
42 };
43 let msg = JsonRpcMessage::Notification(notification);
44 self.transport.send(&msg).await?;
45 Ok(())
46 }
47
48 pub async fn request(
49 &self,
50 method: &str,
51 params: Option<serde_json::Value>,
52 options: RequestOptions,
53 ) -> Result<JsonRpcResponse> {
54 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
55
56 let (tx, rx) = oneshot::channel();
58
59 {
61 let mut pending = self.pending_requests.lock().await;
62 pending.insert(id, tx);
63 }
64
65 let msg = JsonRpcMessage::Request(JsonRpcRequest {
67 id,
68 method: method.to_string(),
69 params,
70 ..Default::default()
71 });
72 self.transport.send(&msg).await?;
73
74 match timeout(options.timeout, rx)
76 .await
77 .map_err(|_| anyhow!("Request timed out"))?
78 {
79 Ok(response) => Ok(response),
80 Err(_) => {
81 let mut pending = self.pending_requests.lock().await;
83 pending.remove(&id);
84 Err(anyhow!("Request cancelled"))
85 }
86 }
87 }
88
89 pub async fn listen(&self) -> Result<()> {
90 debug!("Listening for requests");
91 loop {
92 let message = self.transport.receive().await;
93
94 let message = match message {
95 Ok(msg) => msg,
96 Err(e) => {
97 tracing::error!("Failed to parse message: {:?}", e);
98 continue;
99 }
100 };
101
102 if message.is_none() {
104 break;
105 }
106
107 match message.unwrap() {
108 JsonRpcMessage::Request(request) => self.handle_request(request).await?,
109 JsonRpcMessage::Response(response) => {
110 let id = response.id;
111 let mut pending = self.pending_requests.lock().await;
112 if let Some(tx) = pending.remove(&id) {
113 let _ = tx.send(response);
114 }
115 }
116 JsonRpcMessage::Notification(notification) => {
117 let handlers = self.notification_handlers.lock().await;
118 if let Some(handler) = handlers.get(¬ification.method) {
119 handler.handle(notification).await?;
120 }
121 }
122 }
123 }
124 Ok(())
125 }
126
127 async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
128 let handlers = self.request_handlers.lock().await;
129 if let Some(handler) = handlers.get(&request.method) {
130 match handler.handle(request.clone()).await {
131 Ok(response) => {
132 let msg = JsonRpcMessage::Response(response);
133 self.transport.send(&msg).await?;
134 }
135 Err(e) => {
136 let error_response = JsonRpcResponse {
137 id: request.id,
138 result: None,
139 error: Some(JsonRpcError {
140 code: ErrorCode::InternalError as i32,
141 message: e.to_string(),
142 data: None,
143 }),
144 ..Default::default()
145 };
146 let msg = JsonRpcMessage::Response(error_response);
147 self.transport.send(&msg).await?;
148 }
149 }
150 } else {
151 self.transport
152 .send(&JsonRpcMessage::Response(JsonRpcResponse {
153 id: request.id,
154 error: Some(JsonRpcError {
155 code: ErrorCode::MethodNotFound as i32,
156 message: format!("Method not found: {}", request.method),
157 data: None,
158 }),
159 ..Default::default()
160 }))
161 .await?;
162 }
163 Ok(())
164 }
165}
166
/// Default time, in milliseconds, that [`RequestOptions::default`] allows for a response.
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Per-call options for `Protocol::request`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RequestOptions {
    /// How long to wait for the matching response before giving up.
    timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the response timeout replaced by `timeout`.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout }
    }
}

impl Default for RequestOptions {
    /// Uses a timeout of [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
186
187pub struct ProtocolBuilder<T: Transport> {
188 transport: T,
189 request_handlers: HashMap<String, Box<dyn RequestHandler>>,
190 notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
191}
192impl<T: Transport> ProtocolBuilder<T> {
193 pub fn new(transport: T) -> Self {
194 Self {
195 transport,
196 request_handlers: HashMap::new(),
197 notification_handlers: HashMap::new(),
198 }
199 }
200 pub fn request_handler<Req, Resp>(
202 mut self,
203 method: &str,
204 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
205 + Send
206 + Sync
207 + 'static,
208 ) -> Self
209 where
210 Req: DeserializeOwned + Send + Sync + 'static,
211 Resp: Serialize + Send + Sync + 'static,
212 {
213 let handler = TypedRequestHandler {
214 handler: Box::new(handler),
215 _phantom: std::marker::PhantomData,
216 };
217
218 self.request_handlers
219 .insert(method.to_string(), Box::new(handler));
220 self
221 }
222
223 pub fn has_request_handler(&self, method: &str) -> bool {
224 self.request_handlers.contains_key(method)
225 }
226
227 pub fn notification_handler<N>(
228 mut self,
229 method: &str,
230 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
231 + Send
232 + Sync
233 + 'static,
234 ) -> Self
235 where
236 N: DeserializeOwned + Send + Sync + 'static,
237 {
238 self.notification_handlers.insert(
239 method.to_string(),
240 Box::new(TypedNotificationHandler {
241 handler: Box::new(handler),
242 _phantom: std::marker::PhantomData,
243 }),
244 );
245 self
246 }
247
248 pub fn build(self) -> Protocol<T> {
249 Protocol {
250 transport: Arc::new(self.transport),
251 request_handlers: Arc::new(Mutex::new(self.request_handlers)),
252 notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
253 request_id: Arc::new(AtomicU64::new(0)),
254 pending_requests: Arc::new(Mutex::new(HashMap::new())),
255 }
256 }
257}
258
259#[async_trait]
261trait RequestHandler: Send + Sync {
262 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
263}
264
265#[async_trait]
266trait NotificationHandler: Send + Sync {
267 async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
268}
269
270struct TypedRequestHandler<Req, Resp>
272where
273 Req: DeserializeOwned + Send + Sync + 'static,
274 Resp: Serialize + Send + Sync + 'static,
275{
276 handler: Box<
277 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
278 + Send
279 + Sync,
280 >,
281 _phantom: std::marker::PhantomData<(Req, Resp)>,
282}
283
284#[async_trait]
285impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
286where
287 Req: DeserializeOwned + Send + Sync + 'static,
288 Resp: Serialize + Send + Sync + 'static,
289{
290 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
291 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
292 {
293 serde_json::from_value(serde_json::Value::Null)?
294 } else {
295 serde_json::from_value(request.params.unwrap())?
296 };
297 let result = (self.handler)(params).await?;
298 Ok(JsonRpcResponse {
299 id: request.id,
300 result: Some(serde_json::to_value(result)?),
301 error: None,
302 ..Default::default()
303 })
304 }
305}
306
307struct TypedNotificationHandler<N>
308where
309 N: DeserializeOwned + Send + Sync + 'static,
310{
311 handler: Box<
312 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
313 + Send
314 + Sync,
315 >,
316 _phantom: std::marker::PhantomData<N>,
317}
318
319#[async_trait]
320impl<N> NotificationHandler for TypedNotificationHandler<N>
321where
322 N: DeserializeOwned + Send + Sync + 'static,
323{
324 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
325 let params: N =
326 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
327 serde_json::from_value(serde_json::Value::Null)?
328 } else {
329 match ¬ification.params {
330 Some(params) => {
331 let res = serde_json::from_value(params.clone());
332 match res {
333 Ok(r) => r,
334 Err(e) => {
335 tracing::warn!(
336 "Failed to parse notification params: {:?}. Params: {:?}",
337 e,
338 notification.params
339 );
340 serde_json::from_value(serde_json::Value::Null)?
341 }
342 }
343 }
344 None => serde_json::from_value(serde_json::Value::Null)?,
345 }
346 };
347 (self.handler)(params).await
348 }
349}