1use super::transport::{
2 JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message,
3 Transport,
4};
5use super::types::ErrorCode;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::pin::Pin;
12use std::sync::atomic::Ordering;
13use std::time::Duration;
14use std::{
15 collections::HashMap,
16 sync::{atomic::AtomicU64, Arc},
17};
18use tokio::sync::oneshot;
19use tokio::sync::Mutex;
20use tokio::time::timeout;
21use tracing::debug;
22
23#[derive(Clone)]
24pub struct Protocol<T: Transport> {
25 transport: Arc<T>,
26
27 request_id: Arc<AtomicU64>,
28 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
29 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
30 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
31}
32
33impl<T: Transport> Protocol<T> {
34 pub fn builder(transport: T) -> ProtocolBuilder<T> {
35 ProtocolBuilder::new(transport)
36 }
37
38 pub async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
39 let notification = JsonRpcNotification {
40 method: method.to_string(),
41 params,
42 ..Default::default()
43 };
44 let msg = JsonRpcMessage::Notification(notification);
45 self.transport.send(&msg).await?;
46 Ok(())
47 }
48
49 pub async fn request(
50 &self,
51 method: &str,
52 params: Option<serde_json::Value>,
53 options: RequestOptions,
54 ) -> Result<JsonRpcResponse> {
55 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
56
57 let (tx, rx) = oneshot::channel();
59
60 {
62 let mut pending = self.pending_requests.lock().await;
63 pending.insert(id, tx);
64 }
65
66 let msg = JsonRpcMessage::Request(JsonRpcRequest {
68 id,
69 method: method.to_string(),
70 params,
71 ..Default::default()
72 });
73
74 self.transport.send(&msg).await?;
75
76 match timeout(options.timeout, rx)
78 .await
79 .map_err(|_| anyhow!("Request timed out"))?
80 {
81 Ok(response) => Ok(response),
82 Err(_) => {
83 let mut pending = self.pending_requests.lock().await;
85 pending.remove(&id);
86 Err(anyhow!("Request cancelled"))
87 }
88 }
89 }
90
91 pub async fn listen(&self) -> Result<()> {
92 debug!("Listening for requests");
93 loop {
94 let message: Option<Message> = self.transport.receive().await?;
95
96 if message.is_none() {
98 break;
99 }
100
101 match message.unwrap() {
102 JsonRpcMessage::Request(request) => self.handle_request(request).await?,
103 JsonRpcMessage::Response(response) => {
104 let id = response.id;
105 let mut pending = self.pending_requests.lock().await;
106 if let Some(tx) = pending.remove(&id) {
107 let _ = tx.send(response);
108 }
109 }
110 JsonRpcMessage::Notification(notification) => {
111 let handlers = self.notification_handlers.lock().await;
112 if let Some(handler) = handlers.get(¬ification.method) {
113 handler.handle(notification).await?;
114 }
115 }
116 }
117 }
118 Ok(())
119 }
120
121 async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
122 let handlers = self.request_handlers.lock().await;
123 if let Some(handler) = handlers.get(&request.method) {
124 match handler.handle(request.clone()).await {
125 Ok(response) => {
126 let msg = JsonRpcMessage::Response(response);
127 self.transport.send(&msg).await?;
128 }
129 Err(e) => {
130 let error_response = JsonRpcResponse {
131 id: request.id,
132 result: None,
133 error: Some(JsonRpcError {
134 code: ErrorCode::InternalError as i32,
135 message: e.to_string(),
136 data: None,
137 }),
138 ..Default::default()
139 };
140 let msg = JsonRpcMessage::Response(error_response);
141 self.transport.send(&msg).await?;
142 }
143 }
144 } else {
145 self.transport
146 .send(&JsonRpcMessage::Response(JsonRpcResponse {
147 id: request.id,
148 error: Some(JsonRpcError {
149 code: ErrorCode::MethodNotFound as i32,
150 message: format!("Method not found: {}", request.method),
151 data: None,
152 }),
153 ..Default::default()
154 }))
155 .await?;
156 }
157 Ok(())
158 }
159}
160
/// Response timeout used by `RequestOptions::default`, in milliseconds.
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60_000;

/// Per-call settings for [`Protocol::request`].
#[derive(Debug, Clone, Copy)]
pub struct RequestOptions {
    /// How long to wait for the matching response before giving up.
    timeout: Duration,
}
166
167impl RequestOptions {
168 pub fn timeout(self, timeout: Duration) -> Self {
169 Self { timeout }
170 }
171}
172
173impl Default for RequestOptions {
174 fn default() -> Self {
175 Self {
176 timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
177 }
178 }
179}
180
181pub struct ProtocolBuilder<T: Transport> {
182 transport: T,
183 request_handlers: HashMap<String, Box<dyn RequestHandler>>,
184 notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
185}
186impl<T: Transport> ProtocolBuilder<T> {
187 pub fn new(transport: T) -> Self {
188 Self {
189 transport,
190 request_handlers: HashMap::new(),
191 notification_handlers: HashMap::new(),
192 }
193 }
194 pub fn request_handler<Req, Resp>(
196 mut self,
197 method: &str,
198 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
199 + Send
200 + Sync
201 + 'static,
202 ) -> Self
203 where
204 Req: DeserializeOwned + Send + Sync + 'static,
205 Resp: Serialize + Send + Sync + 'static,
206 {
207 let handler = TypedRequestHandler {
208 handler: Box::new(handler),
209 _phantom: std::marker::PhantomData,
210 };
211
212 self.request_handlers
213 .insert(method.to_string(), Box::new(handler));
214 self
215 }
216
217 pub fn has_request_handler(&self, method: &str) -> bool {
218 self.request_handlers.contains_key(method)
219 }
220
221 pub fn notification_handler<N>(
222 mut self,
223 method: &str,
224 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
225 + Send
226 + Sync
227 + 'static,
228 ) -> Self
229 where
230 N: DeserializeOwned + Send + Sync + 'static,
231 {
232 self.notification_handlers.insert(
233 method.to_string(),
234 Box::new(TypedNotificationHandler {
235 handler: Box::new(handler),
236 _phantom: std::marker::PhantomData,
237 }),
238 );
239 self
240 }
241
242 pub fn build(self) -> Protocol<T> {
243 Protocol {
244 transport: Arc::new(self.transport),
245 request_handlers: Arc::new(Mutex::new(self.request_handlers)),
246 notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
247 request_id: Arc::new(AtomicU64::new(0)),
248 pending_requests: Arc::new(Mutex::new(HashMap::new())),
249 }
250 }
251}
252
253#[async_trait]
255trait RequestHandler: Send + Sync {
256 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
257}
258
259#[async_trait]
260trait NotificationHandler: Send + Sync {
261 async fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
262}
263
264struct TypedRequestHandler<Req, Resp>
266where
267 Req: DeserializeOwned + Send + Sync + 'static,
268 Resp: Serialize + Send + Sync + 'static,
269{
270 handler: Box<
271 dyn Fn(Req) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
272 + Send
273 + Sync,
274 >,
275 _phantom: std::marker::PhantomData<(Req, Resp)>,
276}
277
278#[async_trait]
279impl<Req, Resp> RequestHandler for TypedRequestHandler<Req, Resp>
280where
281 Req: DeserializeOwned + Send + Sync + 'static,
282 Resp: Serialize + Send + Sync + 'static,
283{
284 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
285 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
286 {
287 serde_json::from_value(serde_json::Value::Null)?
288 } else {
289 serde_json::from_value(request.params.unwrap())?
290 };
291 let result = (self.handler)(params).await?;
292 Ok(JsonRpcResponse {
293 id: request.id,
294 result: Some(serde_json::to_value(result)?),
295 error: None,
296 ..Default::default()
297 })
298 }
299}
300
301struct TypedNotificationHandler<N>
302where
303 N: DeserializeOwned + Send + Sync + 'static,
304{
305 handler: Box<
306 dyn Fn(N) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
307 + Send
308 + Sync,
309 >,
310 _phantom: std::marker::PhantomData<N>,
311}
312
313#[async_trait]
314impl<N> NotificationHandler for TypedNotificationHandler<N>
315where
316 N: DeserializeOwned + Send + Sync + 'static,
317{
318 async fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
319 let params: N =
320 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
321 serde_json::from_value(serde_json::Value::Null)?
322 } else {
323 serde_json::from_value(notification.params.unwrap())?
324 };
325 (self.handler)(params).await
326 }
327}