1use std::{collections::HashMap, error::Error, sync::Arc, time::Duration, str::FromStr};
2use futures::{stream::SplitSink, SinkExt};
3use futures_util::StreamExt;
4use tokio_tungstenite::{WebSocketStream, MaybeTlsStream, tungstenite::{Message, self}};
5use tracing;
6use tokio::{sync::oneshot, net::TcpStream, sync::Mutex};
7use http::{Request, Uri};
8
9
/// Errors returned by the request methods of `WsRouter`.
#[derive(Debug)]
pub enum WsError {
    /// No response arrived within the caller-supplied timeout.
    Timeout,
    /// The connection was closed (mapped from tungstenite's `ConnectionClosed`).
    ConnectionClosed,
    /// An operation was attempted on an already-closed connection
    /// (mapped from tungstenite's `AlreadyClosed`).
    AlreadyClosed,
    /// Underlying I/O failure.
    IoError(std::io::Error),
    /// Any other failure, e.g. an unmapped tungstenite error or a response
    /// channel whose sender was dropped.
    Other(Box<dyn Error + Send + Sync>),
}

impl std::fmt::Display for WsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WsError::Timeout => f.write_str("websocket request timed out"),
            WsError::ConnectionClosed => f.write_str("websocket connection closed"),
            WsError::AlreadyClosed => f.write_str("websocket connection already closed"),
            WsError::IoError(e) => write!(f, "websocket I/O error: {}", e),
            WsError::Other(e) => write!(f, "websocket error: {}", e),
        }
    }
}

/// Implementing `Error` lets callers propagate `WsError` with `?` into
/// `Box<dyn Error>` and walk the cause chain via `source()`.
impl Error for WsError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            WsError::IoError(e) => Some(e),
            WsError::Other(e) => Some(&**e),
            WsError::Timeout | WsError::ConnectionClosed | WsError::AlreadyClosed => None,
        }
    }
}
18
19impl From<tungstenite::error::Error> for WsError {
20 fn from(e: tungstenite::error::Error) -> Self {
21 match e {
22 tungstenite::error::Error::ConnectionClosed => WsError::ConnectionClosed,
23 tungstenite::error::Error::AlreadyClosed => WsError::AlreadyClosed,
24 tungstenite::error::Error::Io(e) => WsError::IoError(e),
25 _ => WsError::Other(Box::new(e)),
26 }
27 }
28}
29
/// State shared between request senders and the background read loop,
/// guarded together by a single mutex.
struct WriteReqmap {
    /// Write half of the split WebSocket stream; outgoing requests go through it.
    ws_write: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
    /// Pending requests keyed by payload id. The read loop removes the sender
    /// for a response's `id` and delivers the raw response text through it.
    reqmap: HashMap<u64, oneshot::Sender<String>>,
}
34
/// Client that sends requests over one WebSocket connection and routes each
/// incoming response back to its caller by matching the response's `id` field.
pub struct WsRouter {
    /// Write half plus pending-request map, shared with the spawned read loop.
    write_reqmap: Arc<Mutex<WriteReqmap>>,
}
38
39impl WsRouter {
40 pub async fn new(node: String) -> Result<WsRouter, Box<dyn Error>> {
41 let url = Uri::from_str(&node)?;
42
43 #[allow(deprecated)]
44 let config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
45 write_buffer_size: 0,
46 max_message_size: None,
47 max_frame_size: None,
48 max_write_buffer_size: usize::MAX,
49 accept_unmasked_frames: false,
50 max_send_queue: None, };
52
53 let (ws, _) = tokio_tungstenite::connect_async_with_config(url, Some(config), false).await?;
54 let (ws_write, ws_read) = ws.split();
55 tracing::info!("Websocket connection to {} established.", node);
56
57 let write_reqmap = Arc::new(Mutex::new(WriteReqmap{
58 ws_write,
59 reqmap: HashMap::new(),
60 }));
61 let write_reqmap_clone = write_reqmap.clone(); let router = WsRouter{
64 write_reqmap: write_reqmap,
65 };
66
67 let read_loop = ws_read.for_each_concurrent(None, move |msg| {
68 let write_reqmap = write_reqmap_clone.clone(); async move {
71 match msg {
72 Ok(msg) => {
73 let msg = msg.into_text().unwrap();
74
75 if msg.is_empty() {
76 return;
77 }
78
79 let resp: Result<serde_json::Value, serde_json::Error> = serde_json::from_str(&msg);
80 let resp = match resp {
81 Ok(resp) => resp,
82 Err(e) => {
83 tracing::error!("Error parsing message: {}", e);
84 return;
85 }
86 };
87
88 let payload_id: u64 = match &resp["id"] {
90 serde_json::Value::Number(number) => {
91 if let Some(payload_id) = number.as_u64() {
92 payload_id
93 } else {
94 tracing::error!("Invalid payload ID format: {}", number);
95 return;
96 }
97 }
98 serde_json::Value::String(s) => match s.parse::<u64>() {
99 Ok(parsed_id) => parsed_id,
100 Err(e) => {
101 tracing::error!("Error parsing payload ID as u64: {}", e);
102 return;
103 }
104 },
105 _ => {
106 tracing::error!("Unexpected payload ID format: {:?}", resp["id"]);
107 return;
108 }
109 };
110
111 if let Some(tx) = write_reqmap.lock().await.reqmap.remove(&payload_id) {
112 let channelres = tx.send(msg);
113 if let Err(e) = channelres {
114 tracing::error!("Error sending message to channel: {}", e);
115 }
116 } else {
117 tracing::warn!("No corresponding sender found for payload_id: {}", payload_id);
118 }
119 },
120 Err(e) => {
121 tracing::error!("Error reading message: {}", e);
122 }
123 }
124 }
125 });
126
127 tokio::spawn(read_loop);
128
129 Ok(router)
130 }
131
132 pub async fn new_with_jwt(node: String, jwt: String) -> Result<WsRouter, Box<dyn Error>> {
133 let url = Uri::from_str(&node)?;
134
135 let request = Request::builder()
136 .method("GET")
137 .uri(&url)
138 .header("Upgrade", "websocket")
139 .header("Connection", "Upgrade")
140 .header("Sec-WebSocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
141 .header("Sec-WebSocket-Version", "13")
142 .header("Authorization", format!("Bearer {}", jwt))
143 .header("Host", url.host().unwrap())
144 .body(())?;
145
146
147 #[allow(deprecated)]
148 let config = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
149 write_buffer_size: 0,
150 max_message_size: None,
151 max_frame_size: None,
152 max_write_buffer_size: usize::MAX,
153 accept_unmasked_frames: false,
154 max_send_queue: None, };
156
157 let (ws, _) = tokio_tungstenite::connect_async_with_config(request, Some(config), false).await?;
158 let (ws_write, ws_read) = ws.split();
159 tracing::info!("Websocket connection to {} established.", node);
160
161 let write_reqmap = Arc::new(Mutex::new(WriteReqmap{
162 ws_write,
163 reqmap: HashMap::new(),
164 }));
165 let write_reqmap_clone = write_reqmap.clone(); let router = WsRouter{
168 write_reqmap: write_reqmap,
169 };
170
171 let read_loop = ws_read.for_each_concurrent(None, move |msg| {
172 let write_reqmap = write_reqmap_clone.clone(); async move {
175 match msg {
176 Ok(msg) => {
177 let msg = msg.into_text().unwrap();
178
179 if msg.is_empty() {
180 return;
181 }
182
183 let resp: Result<serde_json::Value, serde_json::Error> = serde_json::from_str(&msg);
184 let resp = match resp {
185 Ok(resp) => resp,
186 Err(e) => {
187 tracing::error!("Error parsing message: {}", e);
188 return;
189 }
190 };
191
192 let payload_id: u64 = match &resp["id"] {
194 serde_json::Value::Number(number) => {
195 if let Some(payload_id) = number.as_u64() {
196 payload_id
197 } else {
198 tracing::error!("Invalid payload ID format: {}", number);
199 return;
200 }
201 }
202 serde_json::Value::String(s) => match s.parse::<u64>() {
203 Ok(parsed_id) => parsed_id,
204 Err(e) => {
205 tracing::error!("Error parsing payload ID as u64: {}", e);
206 return;
207 }
208 },
209 _ => {
210 tracing::error!("Unexpected payload ID format: {:?}", resp["id"]);
211 return;
212 }
213 };
214
215 if let Some(tx) = write_reqmap.lock().await.reqmap.remove(&payload_id) {
216 let channelres = tx.send(msg);
217 if let Err(e) = channelres {
218 tracing::error!("Error sending message to channel: {}", e);
219 }
220 } else {
221 tracing::warn!("No corresponding sender found for payload_id: {}", payload_id);
222 }
223 },
224 Err(e) => {
225 tracing::error!("Error reading message: {}", e);
226 }
227 }
228 }
229 });
230
231 tokio::spawn(read_loop);
232
233 Ok(router)
234 }
235
236
237 pub async fn send(&self, req: String, payload_id: u64) -> Result<oneshot::Receiver<String>, WsError> {
239
240 let (tx, rx) = oneshot::channel();
241 let req = Message::Text(req);
242 let mut write_reqmap = self.write_reqmap.lock().await;
243 write_reqmap.reqmap.insert(payload_id, tx);
244 write_reqmap.ws_write.send(req).await?;
246 drop(write_reqmap); Ok(rx)
248 }
249
250 pub async fn make_request(&self, req: String, payload_id: u64) -> Result<String, WsError> {
252 Ok(self.send(req, payload_id).await?.await.map_err(|e| WsError::Other(Box::new(e)))?)
253 }
254
255 pub async fn make_request_timeout(&self, req: String, payload_id: u64, timeout: Duration) -> Result<String, WsError> {
256 let rx = self.send(req, payload_id).await?;
257 tokio::time::timeout(timeout, rx).await.map_err(|_| WsError::Timeout)?.map_err(|e| WsError::Other(Box::new(e)))
258 }
259
260 pub async fn stop(&self) {
261 tracing::debug!("Shutting down websocket connection.");
262 self.write_reqmap.lock().await.ws_write.close().await.unwrap();
264 }
265}