// ajsonrpc/lib.rs

use std::{collections::HashMap, error::Error, fmt, str::FromStr, sync::Arc, time::Duration};

use futures::{stream::SplitSink, SinkExt};
use futures_util::StreamExt;
use http::{Request, Uri};
use tokio::{net::TcpStream, sync::{oneshot, Mutex}};
use tokio_tungstenite::{tungstenite::{self, Message}, MaybeTlsStream, WebSocketStream};
use tracing;

/// Errors produced by [`WsRouter`] request operations.
#[derive(Debug)]
pub enum WsError {
    /// No response arrived within the caller-supplied timeout.
    Timeout,
    /// The websocket connection has been closed.
    ConnectionClosed,
    /// An operation was attempted on a connection that was already closed.
    AlreadyClosed,
    /// Underlying I/O failure.
    IoError(std::io::Error),
    /// Any other failure (protocol errors, a dropped response channel, ...).
    Other(Box<dyn Error + Send + Sync>),
}

impl fmt::Display for WsError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            WsError::Timeout => f.write_str("request timed out"),
            WsError::ConnectionClosed => f.write_str("websocket connection closed"),
            WsError::AlreadyClosed => f.write_str("websocket connection already closed"),
            WsError::IoError(e) => write!(f, "websocket I/O error: {}", e),
            WsError::Other(e) => write!(f, "websocket error: {}", e),
        }
    }
}

/// Implementing `Error` lets callers propagate `WsError` with `?` into
/// `Box<dyn Error>` / other error types and walk the cause chain.
impl Error for WsError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            WsError::IoError(e) => Some(e),
            WsError::Other(e) => Some(&**e),
            WsError::Timeout | WsError::ConnectionClosed | WsError::AlreadyClosed => None,
        }
    }
}

19impl From<tungstenite::error::Error> for WsError {
20    fn from(e: tungstenite::error::Error) -> Self {
21        match e {
22            tungstenite::error::Error::ConnectionClosed => WsError::ConnectionClosed,
23            tungstenite::error::Error::AlreadyClosed => WsError::AlreadyClosed,
24            tungstenite::error::Error::Io(e) => WsError::IoError(e),
25            _ => WsError::Other(Box::new(e)),
26        }
27    }
28}
29
30struct WriteReqmap {
31    ws_write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
32    reqmap: HashMap<u64, oneshot::Sender<String>>,
33}
34
35pub struct WsRouter {
36    write_reqmap: Arc<Mutex<WriteReqmap>>,
37}
38
39impl WsRouter {
40    pub async fn new(node: String) -> Result<WsRouter, Box<dyn Error>> {
41        let url = Uri::from_str(&node)?;
42
43        #[allow(deprecated)]
44        let config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
45            write_buffer_size: 0,
46            max_message_size: None,
47            max_frame_size: None,
48            max_write_buffer_size: usize::MAX,
49            accept_unmasked_frames: false,
50            max_send_queue: None,   // deprecated
51        };
52
53        let (ws, _) = tokio_tungstenite::connect_async_with_config(url, Some(config), false).await?;
54        let (ws_write, ws_read) = ws.split();
55        tracing::info!("Websocket connection to {} established.", node);
56
57        let write_reqmap = Arc::new(Mutex::new(WriteReqmap{
58            ws_write,
59            reqmap: HashMap::new(),
60        }));
61        let write_reqmap_clone = write_reqmap.clone(); // Clone for closure
62
63        let router = WsRouter{
64            write_reqmap: write_reqmap,
65        };
66        
67        let read_loop = ws_read.for_each_concurrent(None, move |msg| {
68            let write_reqmap = write_reqmap_clone.clone(); // Clone for closure
69            
70            async move {
71                match msg {
72                    Ok(msg) => {
73                        let msg = msg.into_text().unwrap();
74
75                        if msg.is_empty() {
76                            return;
77                        }
78
79                        let resp: Result<serde_json::Value, serde_json::Error>  = serde_json::from_str(&msg);
80                        let resp = match resp {
81                            Ok(resp) => resp,
82                            Err(e) => {
83                                tracing::error!("Error parsing message: {}", e);
84                                return;
85                            }
86                        };
87
88                        // Payload ID can come in as "1" or just 1
89                        let payload_id: u64 = match &resp["id"] {
90                            serde_json::Value::Number(number) => {
91                                if let Some(payload_id) = number.as_u64() {
92                                    payload_id
93                                } else {
94                                    tracing::error!("Invalid payload ID format: {}", number);
95                                    return;
96                                }
97                            }
98                            serde_json::Value::String(s) => match s.parse::<u64>() {
99                                Ok(parsed_id) => parsed_id,
100                                Err(e) => {
101                                    tracing::error!("Error parsing payload ID as u64: {}", e);
102                                    return;
103                                }
104                            },
105                            _ => {
106                                tracing::error!("Unexpected payload ID format: {:?}", resp["id"]);
107                                return;
108                            }
109                        };
110
111                        if let Some(tx) = write_reqmap.lock().await.reqmap.remove(&payload_id) {
112                            let channelres = tx.send(msg);
113                            if let Err(e) = channelres {
114                                tracing::error!("Error sending message to channel: {}", e);
115                            }
116                        } else {
117                            tracing::warn!("No corresponding sender found for payload_id: {}", payload_id);
118                        }
119                    },
120                    Err(e) => {
121                        tracing::error!("Error reading message: {}", e);
122                    }
123                }
124            }
125        });
126
127        tokio::spawn(read_loop);
128
129        Ok(router)
130    }
131
132     pub async fn new_with_jwt(node: String, jwt: String) -> Result<WsRouter, Box<dyn Error>> {
133        let url = Uri::from_str(&node)?;
134
135        let request = Request::builder()
136            .method("GET")
137            .uri(&url)
138            .header("Upgrade", "websocket")
139            .header("Connection", "Upgrade")
140            .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
141            .header("Sec-WebSocket-Version", "13")
142            .header("Authorization", format!("Bearer {}", jwt))
143            .header("Host", url.host().unwrap())
144            .body(())?;
145            
146        
147        #[allow(deprecated)]
148        let config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
149            write_buffer_size: 0,
150            max_message_size: None,
151            max_frame_size: None,
152            max_write_buffer_size: usize::MAX,
153            accept_unmasked_frames: false,
154            max_send_queue: None,   // deprecated
155        };
156
157        let (ws, _) = tokio_tungstenite::connect_async_with_config(request, Some(config), false).await?;
158        let (ws_write, ws_read) = ws.split();
159        tracing::info!("Websocket connection to {} established.", node);
160
161        let write_reqmap = Arc::new(Mutex::new(WriteReqmap{
162            ws_write,
163            reqmap: HashMap::new(),
164        }));
165        let write_reqmap_clone = write_reqmap.clone(); // Clone for closure
166
167        let router = WsRouter{
168            write_reqmap: write_reqmap,
169        };
170        
171        let read_loop = ws_read.for_each_concurrent(None, move |msg| {
172            let write_reqmap = write_reqmap_clone.clone(); // Clone for closure
173            
174            async move {
175                match msg {
176                    Ok(msg) => {
177                        let msg = msg.into_text().unwrap();
178
179                        if msg.is_empty() {
180                            return;
181                        }
182
183                        let resp: Result<serde_json::Value, serde_json::Error>  = serde_json::from_str(&msg);
184                        let resp = match resp {
185                            Ok(resp) => resp,
186                            Err(e) => {
187                                tracing::error!("Error parsing message: {}", e);
188                                return;
189                            }
190                        };
191
192                        // Payload ID can come in as "1" or just 1
193                        let payload_id: u64 = match &resp["id"] {
194                            serde_json::Value::Number(number) => {
195                                if let Some(payload_id) = number.as_u64() {
196                                    payload_id
197                                } else {
198                                    tracing::error!("Invalid payload ID format: {}", number);
199                                    return;
200                                }
201                            }
202                            serde_json::Value::String(s) => match s.parse::<u64>() {
203                                Ok(parsed_id) => parsed_id,
204                                Err(e) => {
205                                    tracing::error!("Error parsing payload ID as u64: {}", e);
206                                    return;
207                                }
208                            },
209                            _ => {
210                                tracing::error!("Unexpected payload ID format: {:?}", resp["id"]);
211                                return;
212                            }
213                        };
214
215                        if let Some(tx) = write_reqmap.lock().await.reqmap.remove(&payload_id) {
216                            let channelres = tx.send(msg);
217                            if let Err(e) = channelres {
218                                tracing::error!("Error sending message to channel: {}", e);
219                            }
220                        } else {
221                            tracing::warn!("No corresponding sender found for payload_id: {}", payload_id);
222                        }
223                    },
224                    Err(e) => {
225                        tracing::error!("Error reading message: {}", e);
226                    }
227                }
228            }
229        });
230
231        tokio::spawn(read_loop);
232
233        Ok(router)
234    }
235
236
237    // make sure that the payload_id is the same id your using in your json rpc request
238    pub async fn send(&self, req: String, payload_id: u64) -> Result<oneshot::Receiver<String>, WsError> {
239
240        let (tx, rx) = oneshot::channel();
241        let req = Message::Text(req);
242        let mut write_reqmap = self.write_reqmap.lock().await;
243        write_reqmap.reqmap.insert(payload_id, tx);
244        // send the req to the node
245        write_reqmap.ws_write.send(req).await?;
246        drop(write_reqmap); // drop here to yield the lock as soon as possible
247        Ok(rx)
248    }
249
250    // makes + waits for response
251    pub async fn make_request(&self, req: String, payload_id: u64) -> Result<String, WsError> {
252        Ok(self.send(req, payload_id).await?.await.map_err(|e| WsError::Other(Box::new(e)))?)
253    }
254
255    pub async fn make_request_timeout(&self, req: String, payload_id: u64, timeout: Duration) -> Result<String, WsError> {
256        let rx = self.send(req, payload_id).await?;
257        tokio::time::timeout(timeout, rx).await.map_err(|_| WsError::Timeout)?.map_err(|e| WsError::Other(Box::new(e)))
258    }
259
260    pub async fn stop(&self) {
261        tracing::debug!("Shutting down websocket connection.");
262        // call close on the websocket connection
263        self.write_reqmap.lock().await.ws_write.close().await.unwrap();
264    }
265}