1use super::transport::{
2 JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message,
3 Transport,
4};
5use super::types::ErrorCode;
6use anyhow::anyhow;
7use anyhow::Result;
8use async_trait::async_trait;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use std::sync::atomic::Ordering;
12use std::time::Duration;
13use std::{
14 collections::HashMap,
15 sync::{atomic::AtomicU64, Arc},
16};
17use tokio::sync::oneshot;
18use tokio::sync::Mutex;
19use tokio::time::timeout;
20use tracing::debug;
21
22#[derive(Clone)]
23pub struct Protocol<T: Transport> {
24 transport: Arc<T>,
25
26 request_id: Arc<AtomicU64>,
27 pending_requests: Arc<Mutex<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28 request_handlers: Arc<Mutex<HashMap<String, Box<dyn RequestHandler>>>>,
29 notification_handlers: Arc<Mutex<HashMap<String, Box<dyn NotificationHandler>>>>,
30}
31
32impl<T: Transport> Protocol<T> {
33 pub fn builder(transport: T) -> ProtocolBuilder<T> {
34 ProtocolBuilder::new(transport)
35 }
36
37 pub fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
38 let notification = JsonRpcNotification {
39 method: method.to_string(),
40 params,
41 ..Default::default()
42 };
43 let msg = JsonRpcMessage::Notification(notification);
44 self.transport.send(&msg)?;
45 Ok(())
46 }
47
48 pub async fn request(
49 &self,
50 method: &str,
51 params: Option<serde_json::Value>,
52 options: RequestOptions,
53 ) -> Result<JsonRpcResponse> {
54 let id = self.request_id.fetch_add(1, Ordering::SeqCst);
55
56 let (tx, rx) = oneshot::channel();
58
59 {
61 let mut pending = self.pending_requests.lock().await;
62 pending.insert(id, tx);
63 }
64
65 let msg = JsonRpcMessage::Request(JsonRpcRequest {
67 id,
68 method: method.to_string(),
69 params,
70 ..Default::default()
71 });
72 self.transport.send(&msg)?;
73
74 match timeout(options.timeout, rx)
76 .await
77 .map_err(|_| anyhow!("Request timed out"))?
78 {
79 Ok(response) => Ok(response),
80 Err(_) => {
81 let mut pending = self.pending_requests.lock().await;
83 pending.remove(&id);
84 Err(anyhow!("Request cancelled"))
85 }
86 }
87 }
88
89 pub async fn listen(&self) -> Result<()> {
90 debug!("Listening for requests");
91 loop {
92 let message: Message = self.transport.receive()?;
93 match message {
94 JsonRpcMessage::Request(request) => self.handle_request(request).await?,
95 JsonRpcMessage::Response(response) => {
96 let id = response.id;
98 let mut pending = self.pending_requests.lock().await;
99 if let Some(tx) = pending.remove(&id) {
100 let _ = tx.send(response);
102 }
103 }
104 JsonRpcMessage::Notification(json_rpc_notification) => {
105 let handlers = self.notification_handlers.lock().await;
106 if let Some(handler) = handlers.get(&json_rpc_notification.method) {
107 handler.handle(json_rpc_notification)?;
108 }
109 }
110 }
111 }
112 }
113
114 async fn handle_request(&self, request: JsonRpcRequest) -> Result<()> {
115 let handlers = self.request_handlers.lock().await;
116 if let Some(handler) = handlers.get(&request.method) {
117 match handler.handle(request.clone()).await {
118 Ok(response) => {
119 let msg = JsonRpcMessage::Response(response);
120 self.transport.send(&msg)?;
121 }
122 Err(e) => {
123 let error_response = JsonRpcResponse {
124 id: request.id,
125 result: None,
126 error: Some(JsonRpcError {
127 code: ErrorCode::InternalError as i32,
128 message: e.to_string(),
129 data: None,
130 }),
131 ..Default::default()
132 };
133 let msg = JsonRpcMessage::Response(error_response);
134 self.transport.send(&msg)?;
135 }
136 }
137 } else {
138 self.transport
139 .send(&JsonRpcMessage::Response(JsonRpcResponse {
140 id: request.id,
141 error: Some(JsonRpcError {
142 code: ErrorCode::MethodNotFound as i32,
143 message: format!("Method not found: {}", request.method),
144 data: None,
145 }),
146 ..Default::default()
147 }))?;
148 }
149 Ok(())
150 }
151}
152
/// Default time, in milliseconds, that a request waits for its response.
pub const DEFAULT_REQUEST_TIMEOUT_MSEC: u64 = 60_000;

/// Per-call options for [`Protocol::request`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RequestOptions {
    /// How long to wait for the matching response before failing.
    timeout: Duration,
}

impl RequestOptions {
    /// Returns these options with the response timeout replaced by `timeout`.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout }
    }
}

impl Default for RequestOptions {
    /// Options using a timeout of [`DEFAULT_REQUEST_TIMEOUT_MSEC`] milliseconds.
    fn default() -> Self {
        Self {
            timeout: Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MSEC),
        }
    }
}
172
173pub struct ProtocolBuilder<T: Transport> {
174 transport: T,
175 request_handlers: HashMap<String, Box<dyn RequestHandler>>,
176 notification_handlers: HashMap<String, Box<dyn NotificationHandler>>,
177}
178impl<T: Transport> ProtocolBuilder<T> {
179 pub fn new(transport: T) -> Self {
180 Self {
181 transport,
182 request_handlers: HashMap::new(),
183 notification_handlers: HashMap::new(),
184 }
185 }
186 pub fn request_handler<Req, Resp>(
188 mut self,
189 method: &str,
190 handler: impl Fn(Req) -> Result<Resp> + Send + Sync + 'static,
191 ) -> Self
192 where
193 Req: DeserializeOwned + Send + Sync + 'static,
194 Resp: Serialize + Send + Sync + 'static,
195 {
196 let handler = TypedRequestHandler {
197 handler,
198 _phantom: std::marker::PhantomData,
199 };
200
201 self.request_handlers
202 .insert(method.to_string(), Box::new(handler));
203 self
204 }
205
206 pub fn has_request_handler(&self, method: &str) -> bool {
207 self.request_handlers.contains_key(method)
208 }
209
210 pub fn notification_handler<N>(
211 mut self,
212 method: &str,
213 handler: impl Fn(N) -> Result<()> + Send + Sync + 'static,
214 ) -> Self
215 where
216 N: DeserializeOwned + Send + Sync + 'static,
217 {
218 self.notification_handlers.insert(
219 method.to_string(),
220 Box::new(TypedNotificationHandler {
221 handler,
222 _phantom: std::marker::PhantomData,
223 }),
224 );
225 self
226 }
227
228 pub fn build(self) -> Protocol<T> {
229 Protocol {
230 transport: Arc::new(self.transport),
231 request_handlers: Arc::new(Mutex::new(self.request_handlers)),
232 notification_handlers: Arc::new(Mutex::new(self.notification_handlers)),
233 request_id: Arc::new(AtomicU64::new(0)),
234 pending_requests: Arc::new(Mutex::new(HashMap::new())),
235 }
236 }
237}
238
239#[async_trait]
241trait RequestHandler: Send + Sync {
242 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse>;
243}
244
245trait NotificationHandler: Send + Sync {
246 fn handle(&self, notification: JsonRpcNotification) -> Result<()>;
247}
248
249struct TypedRequestHandler<Req, Resp, F>
251where
252 Req: DeserializeOwned + Send + Sync + 'static,
253 Resp: Serialize + Send + Sync + 'static,
254 F: Fn(Req) -> Result<Resp> + Send + Sync + 'static,
255{
256 handler: F,
257 _phantom: std::marker::PhantomData<(Req, Resp)>,
258}
259
260#[async_trait]
261impl<Req, Resp, F> RequestHandler for TypedRequestHandler<Req, Resp, F>
262where
263 Req: DeserializeOwned + Send + Sync + 'static,
264 Resp: Serialize + Send + Sync + 'static,
265 F: Fn(Req) -> Result<Resp> + Send + Sync + 'static,
266{
267 async fn handle(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
268 let params: Req = if request.params.is_none() || request.params.as_ref().unwrap().is_null()
270 {
271 serde_json::from_value(serde_json::Value::Null)?
272 } else {
273 serde_json::from_value(request.params.unwrap())?
274 };
275 let result = (self.handler)(params)?;
276 Ok(JsonRpcResponse {
277 id: request.id,
278 result: Some(serde_json::to_value(result)?),
279 error: None,
280 ..Default::default()
281 })
282 }
283}
284pub struct TypedNotificationHandler<N, F>
285where
286 N: DeserializeOwned + Send + Sync + 'static,
287 F: Fn(N) -> Result<()> + Send + Sync + 'static,
288{
289 handler: F,
290 _phantom: std::marker::PhantomData<N>,
291}
292
293#[async_trait]
294impl<N, F> NotificationHandler for TypedNotificationHandler<N, F>
295where
296 N: DeserializeOwned + Send + Sync + 'static,
297 F: Fn(N) -> Result<()> + Send + Sync + 'static,
298{
299 fn handle(&self, notification: JsonRpcNotification) -> Result<()> {
300 let params: N =
301 if notification.params.is_none() || notification.params.as_ref().unwrap().is_null() {
302 serde_json::from_value(serde_json::Value::Null)?
303 } else {
304 serde_json::from_value(notification.params.unwrap())?
305 };
306 (self.handler)(params)
307 }
308}