pyth_lazer_client/
ws_connection.rs1use std::hash::{DefaultHasher, Hash, Hasher};
2
3use anyhow::Result;
4use derive_more::From;
5use futures_util::{SinkExt, StreamExt, TryStreamExt};
6use pyth_lazer_protocol::{
7 api::{ErrorResponse, SubscribeRequest, UnsubscribeRequest, WsRequest, WsResponse},
8 binary_update::BinaryWsUpdate,
9};
10use tokio_tungstenite::{connect_async, tungstenite::Message};
11use url::Url;
12
13pub struct PythLazerWSConnection {
21 endpoint: Url,
22 access_token: String,
23 ws_sender: Option<
24 futures_util::stream::SplitSink<
25 tokio_tungstenite::WebSocketStream<
26 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
27 >,
28 Message,
29 >,
30 >,
31}
32
33#[derive(Debug, Clone, PartialEq, Eq, Hash, From)]
34pub enum AnyResponse {
35 Json(WsResponse),
36 Binary(BinaryWsUpdate),
37}
38
39impl AnyResponse {
40 pub fn cache_key(&self) -> u64 {
41 let mut hasher = DefaultHasher::new();
42 self.hash(&mut hasher);
43 hasher.finish()
44 }
45}
46impl PythLazerWSConnection {
47 pub fn new(endpoint: Url, access_token: String) -> Result<Self> {
56 Ok(Self {
57 endpoint,
58 access_token,
59 ws_sender: None,
60 })
61 }
62
63 pub async fn start(&mut self) -> Result<impl futures_util::Stream<Item = Result<AnyResponse>>> {
68 let url = self.endpoint.clone();
69 let mut request =
70 tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(url)?;
71
72 request.headers_mut().insert(
73 "Authorization",
74 format!("Bearer {}", self.access_token).parse().unwrap(),
75 );
76
77 let (ws_stream, _) = connect_async(request).await?;
78 let (ws_sender, ws_receiver) = ws_stream.split();
79
80 self.ws_sender = Some(ws_sender);
81 let response_stream =
82 ws_receiver
83 .map_err(anyhow::Error::from)
84 .try_filter_map(|msg| async {
85 let r: Result<Option<AnyResponse>> = match msg {
86 Message::Text(text) => {
87 Ok(Some(serde_json::from_str::<WsResponse>(&text)?.into()))
88 }
89 Message::Binary(data) => {
90 Ok(Some(BinaryWsUpdate::deserialize_slice(&data)?.into()))
91 }
92 Message::Close(_) => Ok(Some(
93 WsResponse::Error(ErrorResponse {
94 error: "WebSocket connection closed".to_string(),
95 })
96 .into(),
97 )),
98 _ => Ok(None),
99 };
100 r
101 });
102
103 Ok(response_stream)
104 }
105
106 pub async fn send_request(&mut self, request: WsRequest) -> Result<()> {
107 if let Some(sender) = &mut self.ws_sender {
108 let msg = serde_json::to_string(&request)?;
109 sender.send(Message::Text(msg)).await?;
110 Ok(())
111 } else {
112 anyhow::bail!("WebSocket connection not started")
113 }
114 }
115
116 pub async fn subscribe(&mut self, request: SubscribeRequest) -> Result<()> {
121 let request = WsRequest::Subscribe(request);
122 self.send_request(request).await
123 }
124
125 pub async fn unsubscribe(&mut self, request: UnsubscribeRequest) -> Result<()> {
130 let request = WsRequest::Unsubscribe(request);
131 self.send_request(request).await
132 }
133
134 pub async fn close(&mut self) -> Result<()> {
136 if let Some(sender) = &mut self.ws_sender {
137 sender.send(Message::Close(None)).await?;
138 self.ws_sender = None;
139 Ok(())
140 } else {
141 anyhow::bail!("WebSocket connection not started")
142 }
143 }
144}