endpoint_libs/libs/ws/
client.rs

1use eyre::{bail, eyre, Context, Result};
2use futures::SinkExt;
3use futures::StreamExt;
4use reqwest::header::HeaderValue;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::net::TcpStream;
8use tokio_tungstenite::connect_async;
9use tokio_tungstenite::tungstenite::client::IntoClientRequest;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::WebSocketStream;
13use tracing::*;
14
15use crate::libs::log::LogLevel;
16use crate::libs::ws::WsLogResponse;
17use crate::libs::ws::WsRequestGeneric;
18use crate::libs::ws::WsResponseGeneric;
19
20use super::WsResponseValue;
21
22pub trait WsRequest: Serialize + DeserializeOwned + Send + Sync + Clone {
23    type Response: WsResponse;
24    const METHOD_ID: u32;
25    const SCHEMA: &'static str;
26    const ROLES: Option<&'static [u32]>;
27}
28pub trait WsResponse: Serialize + DeserializeOwned + Send + Sync + Clone {
29    type Request: WsRequest;
30}
31pub struct WsClient {
32    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
33    seq: u32,
34}
35impl WsClient {
36    pub async fn new(
37        connect_addr: &str,
38        protocol_header: &str,
39        headers: Option<Vec<(&'static str, &'static str)>>,
40    ) -> Result<Self> {
41        let mut req = <&str as IntoClientRequest>::into_client_request(connect_addr)?;
42        req.headers_mut().insert(
43            "Sec-WebSocket-Protocol",
44            HeaderValue::from_str(protocol_header)?,
45        );
46
47        if let Some(headers) = headers {
48            for header in headers {
49                req.headers_mut()
50                    .insert(header.0, HeaderValue::from_str(header.1)?);
51            }
52        }
53
54        let (ws_stream, _) = connect_async(req)
55            .await
56            .context("Failed to connect to endpoint")?;
57        Ok(Self {
58            stream: ws_stream,
59            seq: 0,
60        })
61    }
62    pub async fn send_req(&mut self, method: u32, params: impl Serialize) -> Result<()> {
63        self.seq += 1;
64        let req = serde_json::to_string(&WsRequestGeneric {
65            method,
66            seq: self.seq,
67            params,
68        })?;
69        debug!("send req: {}", req);
70        self.stream.send(Message::Text(req)).await?;
71        Ok(())
72    }
73    pub async fn recv_raw(&mut self) -> Result<WsResponseValue> {
74        let msg = self
75            .stream
76            .next()
77            .await
78            .ok_or(eyre!("Connection closed"))??;
79        let resp: WsResponseValue = serde_json::from_str(&msg.to_string())?;
80        Ok(resp)
81    }
82    pub async fn recv_resp<T: DeserializeOwned>(&mut self) -> Result<T> {
83        loop {
84            let msg = self
85                .stream
86                .next()
87                .await
88                .ok_or(eyre!("Connection closed"))??;
89            match msg {
90                Message::Text(text) => {
91                    debug!("recv resp: {}", text);
92                    let resp: WsResponseGeneric<T> = serde_json::from_str(&text)?;
93                    match resp {
94                        WsResponseGeneric::Immediate(resp) if resp.seq == self.seq => {
95                            return Ok(resp.params);
96                        }
97                        WsResponseGeneric::Immediate(resp) => {
98                            bail!("Seq mismatch this: {} got: {}", self.seq, resp.seq)
99                        }
100                        WsResponseGeneric::Stream(_) => {
101                            debug!("expect immediate response, got stream")
102                        }
103                        WsResponseGeneric::Forwarded(_) => {
104                            debug!("expect immediate response, got forwarded")
105                        }
106                        WsResponseGeneric::Close => {
107                            bail!("unreachable")
108                        }
109                        WsResponseGeneric::Log(WsLogResponse {
110                            log_id,
111                            level,
112                            message,
113                            ..
114                        }) => match level {
115                            LogLevel::Error => error!(?log_id, "{}", message),
116                            LogLevel::Warn => warn!(?log_id, "{}", message),
117                            LogLevel::Info => info!(?log_id, "{}", message),
118                            LogLevel::Debug => debug!(?log_id, "{}", message),
119                            LogLevel::Trace => trace!(?log_id, "{}", message),
120                            LogLevel::Detail => trace!(?log_id, "{}", message),
121                            LogLevel::Off => {}
122                        },
123                        WsResponseGeneric::Error(err) => {
124                            bail!("Error: {} {:?}", err.code, err.params)
125                        }
126                    }
127                }
128                Message::Close(_) => {
129                    self.stream.close(None).await?;
130                    bail!("Connection closed")
131                }
132                _ => {}
133            }
134        }
135    }
136    pub async fn request<T: WsRequest>(&mut self, params: T) -> Result<T::Response> {
137        self.send_req(T::METHOD_ID, params).await?;
138        self.recv_resp().await
139    }
140    pub async fn close(mut self) -> Result<()> {
141        self.stream.close(None).await?;
142        Ok(())
143    }
144}