// endpoint_libs/libs/ws/client.rs

1use eyre::{bail, eyre, Context, Result};
2use futures::SinkExt;
3use futures::StreamExt;
4use reqwest::header::HeaderValue;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::net::TcpStream;
8use tokio_tungstenite::connect_async;
9use tokio_tungstenite::tungstenite::client::IntoClientRequest;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::WebSocketStream;
13use tracing::*;
14
15use crate::libs::log::LogLevel;
16use crate::libs::ws::WsLogResponse;
17use crate::libs::ws::WsRequestGeneric;
18use crate::libs::ws::WsResponseGeneric;
19
20use super::WsResponseValue;
21
22pub trait WsRequest: Serialize + DeserializeOwned + Send + Sync + Clone {
23    type Response: WsResponse;
24    const METHOD_ID: u32;
25    const SCHEMA: &'static str;
26}
27pub trait WsResponse: Serialize + DeserializeOwned + Send + Sync + Clone {
28    type Request: WsRequest;
29}
30pub struct WsClient {
31    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
32    seq: u32,
33}
34impl WsClient {
35    pub async fn new(connect_addr: &str, protocol_header: &str, headers: Option<Vec<(&'static str, &'static str)>>) -> Result<Self> {
36        let mut req = <&str as IntoClientRequest>::into_client_request(connect_addr)?;
37        req.headers_mut()
38            .insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocol_header)?);
39
40        if let Some(headers) = headers {
41            for header in headers {
42                req.headers_mut().insert(header.0, HeaderValue::from_str(header.1)?);
43            }
44        }
45
46        let (ws_stream, _) = connect_async(req).await.context("Failed to connect to endpoint")?;
47        Ok(Self {
48            stream: ws_stream,
49            seq: 0,
50        })
51    }
52    pub async fn send_req(&mut self, method: u32, params: impl Serialize) -> Result<()> {
53        self.seq += 1;
54        let req = serde_json::to_string(&WsRequestGeneric {
55            method,
56            seq: self.seq,
57            params,
58        })?;
59        debug!("send req: {}", req);
60        self.stream.send(Message::Text(req)).await?;
61        Ok(())
62    }
63    pub async fn recv_raw(&mut self) -> Result<WsResponseValue> {
64        let msg = self.stream.next().await.ok_or(eyre!("Connection closed"))??;
65        let resp: WsResponseValue = serde_json::from_str(&msg.to_string())?;
66        Ok(resp)
67    }
68    pub async fn recv_resp<T: DeserializeOwned>(&mut self) -> Result<T> {
69        loop {
70            let msg = self.stream.next().await.ok_or(eyre!("Connection closed"))??;
71            match msg {
72                Message::Text(text) => {
73                    debug!("recv resp: {}", text);
74                    let resp: WsResponseGeneric<T> = serde_json::from_str(&text)?;
75                    match resp {
76                        WsResponseGeneric::Immediate(resp) if resp.seq == self.seq => {
77                            return Ok(resp.params);
78                        }
79                        WsResponseGeneric::Immediate(resp) => {
80                            bail!("Seq mismatch this: {} got: {}", self.seq, resp.seq)
81                        }
82                        WsResponseGeneric::Stream(_) => {
83                            debug!("expect immediate response, got stream")
84                        }
85                        WsResponseGeneric::Forwarded(_) => {
86                            debug!("expect immediate response, got forwarded")
87                        }
88                        WsResponseGeneric::Close => {
89                            bail!("unreachable")
90                        }
91                        WsResponseGeneric::Log(WsLogResponse {
92                            log_id, level, message, ..
93                        }) => match level {
94                            LogLevel::Error => error!(?log_id, "{}", message),
95                            LogLevel::Warn => warn!(?log_id, "{}", message),
96                            LogLevel::Info => info!(?log_id, "{}", message),
97                            LogLevel::Debug => debug!(?log_id, "{}", message),
98                            LogLevel::Trace => trace!(?log_id, "{}", message),
99                            LogLevel::Detail => trace!(?log_id, "{}", message),
100                            LogLevel::Off => {}
101                        },
102                        WsResponseGeneric::Error(err) => {
103                            bail!("Error: {} {:?}", err.code, err.params)
104                        }
105                    }
106                }
107                Message::Close(_) => {
108                    self.stream.close(None).await?;
109                    bail!("Connection closed")
110                }
111                _ => {}
112            }
113        }
114    }
115    pub async fn request<T: WsRequest>(&mut self, params: T) -> Result<T::Response> {
116        self.send_req(T::METHOD_ID, params).await?;
117        self.recv_resp().await
118    }
119    pub async fn close(mut self) -> Result<()> {
120        self.stream.close(None).await?;
121        Ok(())
122    }
123}