1use super::transport::{
2 JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message,
3 Transport,
4};
5use super::types::ErrorCode;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::pin::Pin;
12use std::sync::atomic::Ordering;
13use std::time::Duration;
14use std::{
15 collections::HashMap,
16 sync::{atomic::AtomicU64, Arc},
17};
18use tokio::sync::oneshot;
19use tokio::sync::Mutex;
20use tokio::time::timeout;
21use tracing::debug;
22
23#[derive(Clone)]
24pub struct Protocol<T: Transport> {
25 transport: Arc<T>,
26
27 request_id: Arc<AtomicU64>,
28 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
29 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
30 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
31}
32
33impl<T: Transport> Protocol<T> {
34 pub fn builder(transport: T) -> ProtocolBuilder<T> {
35 ProtocolBuilder::new(transport)
36 }
37
38 pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
39 let notification = JsonRpcNotification {
40 method: method.to_string(),
41 params,
42 ..Default::default()
43 };
44 let msg = JsonRpcMessage::Notification(notification);
45 self.transport.send(&msg).await?;
46 Ok(())
47 }
48
49 pub async fn request(
50 &self,
51 method: &str,
52 params: Option<serde_json::Value>,
53 options: RequestOptions,
54 ) -> Result<JsonRpcResponse> {
55 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
56
57 let (tx, rx) = oneshot::channel();
59
60 {
62 let mut pending = self.pending_requests.lock().await;
63 pending.insert(id, tx);
64 }
65
66 let msg = JsonRpcMessage::Request(JsonRpcRequest {
68 id,
69 method: method.to_string(),
70 params,
71 ..Default::default()
72 });
73 self.transport.send(&msg).await?;
74
75 match timeout(options.timeout, rx)
77 .await
78 .map_err(|_| anyhow!("Request timed out"))?
79 {
80 Ok(response) => Ok(response),
81 Err(_) => {
82 let mut pending = self.pending_requests.lock().await;
84 pending.remove(&id);
85 Err(anyhow!("Request cancelled"))
86 }
87 }
88 }
89
90 pub async fn listen(&self) -> Result<()> {
91 debug!("Listening for requests");
92 loop {
93 let message: Option<Message> = self.transport.receive().await?;
94
95 if message.is_none() {
97 break;
98 }
99
100 match message.unwrap() {
101 JsonRpcMessage::Request(request) => self.handle_request(request).await?,
102 JsonRpcMessage::Response(response) => {
103 let id = response.id;
104 let mut pending = self.pending_requests.lock().await;
105 if let Some(tx) = pending.remove(&id) {
106 let _ = tx.send(response);
107 }
108 }
109 JsonRpcMessage::Notification(notification) => {
110 let handlers = self.notification_handlers.lock().await;
111 if let Some(handler) = handlers.get(¬ification.method) {
112 handler.handle(notification).await?;
113 }
114 }
115 }
116 }
117 Ok(())
118 }
119
120 async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
121 let handlers = self.request_handlers.lock().await;
122 if let Some(handler) = handlers.get(&request.method) {
123 match handler.handle(request.clone()).await {
124 Ok(response) => {
125 let msg = JsonRpcMessage::Response(response);
126 self.transport.send(&msg).await?;
127 }
128 Err(e) => {
129 let error_response = JsonRpcResponse {
130 id: request.id,
131 result: None,
132 error: Some(JsonRpcError {
133 code: ErrorCode::InternalError as i32,
134 message: e.to_string(),
135 data: None,
136 }),
137 ..Default::default()
138 };
139 let msg = JsonRpcMessage::Response(error_response);
140 self.transport.send(&msg).await?;
141 }
142 }
143 } else {
144 self.transport
145 .send(&JsonRpcMessage::Response(JsonRpcResponse {
146 id: request.id,
147 error: Some(JsonRpcError {
148 code: ErrorCode::MethodNotFound as i32,
149 message: format!("Method not found: {}", request.method),
150 data: None,
151 }),
152 ..Default::default()
153 }))
154 .await?;
155 }
156 Ok(())
157 }
158}
159
/// Default time [`Protocol::request`] waits for a response, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60000;

/// Per-call settings for [`Protocol::request`].
///
/// `Clone` lets one configured value be reused for many requests, since
/// `request` takes the options by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequestOptions {
    /// How long to wait for the matching response before the request fails
    /// with "Request timed out".
    timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the response timeout set to `timeout`.
    ///
    /// Other settings are kept as they are rather than rebuilt.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
}

impl Default for RequestOptions {
    /// Uses [`DEFAULT_REQUEST_TIMEOUT_MSEC`] as the response timeout.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
179
180pub struct ProtocolBuilder<T: Transport> {
181 transport: T,
182 request_handlers: HashMap<String, Box<dyn RequestHandler>>,
183 notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
184}
185impl<T: Transport> ProtocolBuilder<T> {
186 pub fn new(transport: T) -> Self {
187 Self {
188 transport,
189 request_handlers: HashMap::new(),
190 notification_handlers: HashMap::new(),
191 }
192 }
193 pub fn request_handler<Req, Resp>(
195 mut self,
196 method: &str,
197 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
198 + Send
199 + Sync
200 + 'static,
201 ) -> Self
202 where
203 Req: DeserializeOwned + Send + Sync + 'static,
204 Resp: Serialize + Send + Sync + 'static,
205 {
206 let handler = TypedRequestHandler {
207 handler: Box::new(handler),
208 _phantom: std::marker::PhantomData,
209 };
210
211 self.request_handlers
212 .insert(method.to_string(), Box::new(handler));
213 self
214 }
215
216 pub fn has_request_handler(&self, method: &str) -> bool {
217 self.request_handlers.contains_key(method)
218 }
219
220 pub fn notification_handler<N>(
221 mut self,
222 method: &str,
223 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
224 + Send
225 + Sync
226 + 'static,
227 ) -> Self
228 where
229 N: DeserializeOwned + Send + Sync + 'static,
230 {
231 self.notification_handlers.insert(
232 method.to_string(),
233 Box::new(TypedNotificationHandler {
234 handler: Box::new(handler),
235 _phantom: std::marker::PhantomData,
236 }),
237 );
238 self
239 }
240
241 pub fn build(self) -> Protocol<T> {
242 Protocol {
243 transport: Arc::new(self.transport),
244 request_handlers: Arc::new(Mutex::new(self.request_handlers)),
245 notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
246 request_id: Arc::new(AtomicU64::new(0)),
247 pending_requests: Arc::new(Mutex::new(HashMap::new())),
248 }
249 }
250}
251
252#[async_trait]
254trait RequestHandler: Send + Sync {
255 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
256}
257
258#[async_trait]
259trait NotificationHandler: Send + Sync {
260 async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
261}
262
263struct TypedRequestHandler<Req, Resp>
265where
266 Req: DeserializeOwned + Send + Sync + 'static,
267 Resp: Serialize + Send + Sync + 'static,
268{
269 handler: Box<
270 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
271 + Send
272 + Sync,
273 >,
274 _phantom: std::marker::PhantomData<(Req, Resp)>,
275}
276
277#[async_trait]
278impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
279where
280 Req: DeserializeOwned + Send + Sync + 'static,
281 Resp: Serialize + Send + Sync + 'static,
282{
283 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
284 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
285 {
286 serde_json::from_value(serde_json::Value::Null)?
287 } else {
288 serde_json::from_value(request.params.unwrap())?
289 };
290 let result = (self.handler)(params).await?;
291 Ok(JsonRpcResponse {
292 id: request.id,
293 result: Some(serde_json::to_value(result)?),
294 error: None,
295 ..Default::default()
296 })
297 }
298}
299
300struct TypedNotificationHandler<N>
301where
302 N: DeserializeOwned + Send + Sync + 'static,
303{
304 handler: Box<
305 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
306 + Send
307 + Sync,
308 >,
309 _phantom: std::marker::PhantomData<N>,
310}
311
312#[async_trait]
313impl<N> NotificationHandler for TypedNotificationHandler<N>
314where
315 N: DeserializeOwned + Send + Sync + 'static,
316{
317 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
318 let params: N =
319 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
320 serde_json::from_value(serde_json::Value::Null)?
321 } else {
322 serde_json::from_value(notification.params.unwrap())?
323 };
324 (self.handler)(params).await
325 }
326}