endpoint_libs/libs/ws/
client.rs1use eyre::{bail, eyre, Context, Result};
2use futures::SinkExt;
3use futures::StreamExt;
4use reqwest::header::HeaderValue;
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use tokio::net::TcpStream;
8use tokio_tungstenite::connect_async;
9use tokio_tungstenite::tungstenite::client::IntoClientRequest;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_tungstenite::MaybeTlsStream;
12use tokio_tungstenite::WebSocketStream;
13use tracing::*;
14
15use crate::libs::log::LogLevel;
16use crate::libs::ws::WsLogResponse;
17use crate::libs::ws::WsRequestGeneric;
18use crate::libs::ws::WsResponseGeneric;
19
20use super::WsResponseValue;
21
22pub trait WsRequest: Serialize + DeserializeOwned + Send + Sync + Clone {
23 type Response: WsResponse;
24 const METHOD_ID: u32;
25 const SCHEMA: &'static str;
26}
27pub trait WsResponse: Serialize + DeserializeOwned + Send + Sync + Clone {
28 type Request: WsRequest;
29}
30pub struct WsClient {
31 stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
32 seq: u32,
33}
34impl WsClient {
35 pub async fn new(connect_addr: &str, protocol_header: &str, headers: Option<Vec<(&'static str, &'static str)>>) -> Result<Self> {
36 let mut req = <&str as IntoClientRequest>::into_client_request(connect_addr)?;
37 req.headers_mut()
38 .insert("Sec-WebSocket-Protocol", HeaderValue::from_str(&protocol_header)?);
39
40 if let Some(headers) = headers {
41 for header in headers {
42 req.headers_mut().insert(header.0, HeaderValue::from_str(header.1)?);
43 }
44 }
45
46 let (ws_stream, _) = connect_async(req).await.context("Failed to connect to endpoint")?;
47 Ok(Self {
48 stream: ws_stream,
49 seq: 0,
50 })
51 }
52 pub async fn send_req(&mut self, method: u32, params: impl Serialize) -> Result<()> {
53 self.seq += 1;
54 let req = serde_json::to_string(&WsRequestGeneric {
55 method,
56 seq: self.seq,
57 params,
58 })?;
59 debug!("send req: {}", req);
60 self.stream.send(Message::Text(req)).await?;
61 Ok(())
62 }
63 pub async fn recv_raw(&mut self) -> Result<WsResponseValue> {
64 let msg = self.stream.next().await.ok_or(eyre!("Connection closed"))??;
65 let resp: WsResponseValue = serde_json::from_str(&msg.to_string())?;
66 Ok(resp)
67 }
68 pub async fn recv_resp<T: DeserializeOwned>(&mut self) -> Result<T> {
69 loop {
70 let msg = self.stream.next().await.ok_or(eyre!("Connection closed"))??;
71 match msg {
72 Message::Text(text) => {
73 debug!("recv resp: {}", text);
74 let resp: WsResponseGeneric<T> = serde_json::from_str(&text)?;
75 match resp {
76 WsResponseGeneric::Immediate(resp) if resp.seq == self.seq => {
77 return Ok(resp.params);
78 }
79 WsResponseGeneric::Immediate(resp) => {
80 bail!("Seq mismatch this: {} got: {}", self.seq, resp.seq)
81 }
82 WsResponseGeneric::Stream(_) => {
83 debug!("expect immediate response, got stream")
84 }
85 WsResponseGeneric::Forwarded(_) => {
86 debug!("expect immediate response, got forwarded")
87 }
88 WsResponseGeneric::Close => {
89 bail!("unreachable")
90 }
91 WsResponseGeneric::Log(WsLogResponse {
92 log_id, level, message, ..
93 }) => match level {
94 LogLevel::Error => error!(?log_id, "{}", message),
95 LogLevel::Warn => warn!(?log_id, "{}", message),
96 LogLevel::Info => info!(?log_id, "{}", message),
97 LogLevel::Debug => debug!(?log_id, "{}", message),
98 LogLevel::Trace => trace!(?log_id, "{}", message),
99 LogLevel::Detail => trace!(?log_id, "{}", message),
100 LogLevel::Off => {}
101 },
102 WsResponseGeneric::Error(err) => {
103 bail!("Error: {} {:?}", err.code, err.params)
104 }
105 }
106 }
107 Message::Close(_) => {
108 self.stream.close(None).await?;
109 bail!("Connection closed")
110 }
111 _ => {}
112 }
113 }
114 }
115 pub async fn request<T: WsRequest>(&mut self, params: T) -> Result<T::Response> {
116 self.send_req(T::METHOD_ID, params).await?;
117 self.recv_resp().await
118 }
119 pub async fn close(mut self) -> Result<()> {
120 self.stream.close(None).await?;
121 Ok(())
122 }
123}