endpoint_libs/libs/ws/
client.rs1use eyre::{bail, eyre, Context, Result};
2use futures::SinkExt;
3use futures::StreamExt;
4use reqwest::header::HeaderValue;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::net::TcpStream;
8use tokio_tungstenite::connect_async;
9use tokio_tungstenite::tungstenite::client::IntoClientRequest;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::WebSocketStream;
13use tracing::*;
14
15use crate::libs::log::LogLevel;
16use crate::libs::ws::WsLogResponse;
17use crate::libs::ws::WsRequestGeneric;
18use crate::libs::ws::WsResponseGeneric;
19
20use super::WsResponseValue;
21
22pub trait WsRequest: Serialize + DeserializeOwned + Send + Sync + Clone {
23 type Response: WsResponse;
24 const METHOD_ID: u32;
25 const SCHEMA: &'static str;
26 const ROLES: Option<&'static [u32]>;
27}
28pub trait WsResponse: Serialize + DeserializeOwned + Send + Sync + Clone {
29 type Request: WsRequest;
30}
31pub struct WsClient {
32 stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
33 seq: u32,
34}
35impl WsClient {
36 pub async fn new(
37 connect_addr: &str,
38 protocol_header: &str,
39 headers: Option<Vec<(&'static str, &'static str)>>,
40 ) -> Result<Self> {
41 let mut req = <&str as IntoClientRequest>::into_client_request(connect_addr)?;
42 req.headers_mut().insert(
43 "Sec-WebSocket-Protocol",
44 HeaderValue::from_str(protocol_header)?,
45 );
46
47 if let Some(headers) = headers {
48 for header in headers {
49 req.headers_mut()
50 .insert(header.0, HeaderValue::from_str(header.1)?);
51 }
52 }
53
54 let (ws_stream, _) = connect_async(req)
55 .await
56 .context("Failed to connect to endpoint")?;
57 Ok(Self {
58 stream: ws_stream,
59 seq: 0,
60 })
61 }
62 pub async fn send_req(&mut self, method: u32, params: impl Serialize) -> Result<()> {
63 self.seq += 1;
64 let req = serde_json::to_string(&WsRequestGeneric {
65 method,
66 seq: self.seq,
67 params,
68 })?;
69 debug!("send req: {}", req);
70 self.stream.send(Message::Text(req)).await?;
71 Ok(())
72 }
73 pub async fn recv_raw(&mut self) -> Result<WsResponseValue> {
74 let msg = self
75 .stream
76 .next()
77 .await
78 .ok_or(eyre!("Connection closed"))??;
79 let resp: WsResponseValue = serde_json::from_str(&msg.to_string())?;
80 Ok(resp)
81 }
82 pub async fn recv_resp<T: DeserializeOwned>(&mut self) -> Result<T> {
83 loop {
84 let msg = self
85 .stream
86 .next()
87 .await
88 .ok_or(eyre!("Connection closed"))??;
89 match msg {
90 Message::Text(text) => {
91 debug!("recv resp: {}", text);
92 let resp: WsResponseGeneric<T> = serde_json::from_str(&text)?;
93 match resp {
94 WsResponseGeneric::Immediate(resp) if resp.seq == self.seq => {
95 return Ok(resp.params);
96 }
97 WsResponseGeneric::Immediate(resp) => {
98 bail!("Seq mismatch this: {} got: {}", self.seq, resp.seq)
99 }
100 WsResponseGeneric::Stream(_) => {
101 debug!("expect immediate response, got stream")
102 }
103 WsResponseGeneric::Forwarded(_) => {
104 debug!("expect immediate response, got forwarded")
105 }
106 WsResponseGeneric::Close => {
107 bail!("unreachable")
108 }
109 WsResponseGeneric::Log(WsLogResponse {
110 log_id,
111 level,
112 message,
113 ..
114 }) => match level {
115 LogLevel::Error => error!(?log_id, "{}", message),
116 LogLevel::Warn => warn!(?log_id, "{}", message),
117 LogLevel::Info => info!(?log_id, "{}", message),
118 LogLevel::Debug => debug!(?log_id, "{}", message),
119 LogLevel::Trace => trace!(?log_id, "{}", message),
120 LogLevel::Detail => trace!(?log_id, "{}", message),
121 LogLevel::Off => {}
122 },
123 WsResponseGeneric::Error(err) => {
124 bail!("Error: {} {:?}", err.code, err.params)
125 }
126 }
127 }
128 Message::Close(_) => {
129 self.stream.close(None).await?;
130 bail!("Connection closed")
131 }
132 _ => {}
133 }
134 }
135 }
136 pub async fn request<T: WsRequest>(&mut self, params: T) -> Result<T::Response> {
137 self.send_req(T::METHOD_ID, params).await?;
138 self.recv_resp().await
139 }
140 pub async fn close(mut self) -> Result<()> {
141 self.stream.close(None).await?;
142 Ok(())
143 }
144}