1use std::sync::atomic::{AtomicBool, Ordering};
2
3use futures::StreamExt;
4use serde_json::from_str;
5use tokio::net::TcpStream;
6use tokio_tungstenite::tungstenite::handshake::client::Response;
7use tokio_tungstenite::tungstenite::Message;
8use tokio_tungstenite::WebSocketStream;
9use tokio_tungstenite::{connect_async, MaybeTlsStream};
10use url::Url;
11
12use crate::config::Config;
13use crate::errors::*;
14
// --- URL path segments appended to the configured websocket base URL ---

/// Path segment for combined streams: `<base>/stream?streams=a/b/...`.
pub static STREAM_ENDPOINT: &str = "stream";
/// Path segment for a single raw stream: `<base>/ws/<stream name>`.
pub static WS_ENDPOINT: &str = "ws";

// --- Payload discriminator strings ---
// NOTE(review): the code that consumes these is not in this file; presumably
// they are compared against the event-type field of incoming JSON — confirm.

/// Account-info user-data event name.
pub static OUTBOUND_ACCOUNT_INFO: &str = "outboundAccountInfo";
/// Account-position user-data event name.
pub static OUTBOUND_ACCOUNT_POSITION: &str = "outboundAccountPosition";
/// Order execution-report user-data event name.
pub static EXECUTION_REPORT: &str = "executionReport";
/// Kline (candlestick) event name.
pub static KLINE: &str = "kline";
/// Aggregated-trade event name.
pub static AGGREGATED_TRADE: &str = "aggTrade";
/// Incremental order-book (depth) update event name.
pub static DEPTH_ORDERBOOK: &str = "depthUpdate";
/// Unlike its neighbours this is a JSON field name, not an event-type value;
/// presumably used to recognise partial order-book snapshots — confirm at the call site.
pub static PARTIAL_ORDERBOOK: &str = "lastUpdateId";
/// 24-hour ticker event name.
pub static DAYTICKER: &str = "24hrTicker";
/// Mark-price event name.
pub static MARK_PRICE: &str = "markPrice";
26
/// Normalises a trading-pair symbol for use inside a stream name.
///
/// Binance stream names expect lowercase symbols, so `"BTCUSDT"` must become
/// `"btcusdt"`; an uppercase symbol yields a stream that never delivers data.
/// Only the symbol is normalised — other parts of a stream name (e.g. the kline
/// interval, where `1m` and `1M` differ) are case-sensitive and left untouched.
fn stream_symbol(symbol: &str) -> String {
    symbol.to_ascii_lowercase()
}

/// Stream name `!ticker@arr` (ticker array covering all symbols).
pub fn all_ticker_stream() -> &'static str {
    "!ticker@arr"
}

/// Stream name `<symbol>@ticker`; `symbol` is lowercased.
pub fn ticker_stream(symbol: &str) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@ticker")
}

/// Stream name `<symbol>@aggTrade`; `symbol` is lowercased.
pub fn agg_trade_stream(symbol: &str) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@aggTrade")
}

/// Stream name `<symbol>@trade`; `symbol` is lowercased.
pub fn trade_stream(symbol: &str) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@trade")
}

/// Stream name `<symbol>@kline_<interval>`.
///
/// Only `symbol` is lowercased; `interval` is used verbatim because it is
/// case-sensitive.
pub fn kline_stream(symbol: &str, interval: &str) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@kline_{interval}")
}

/// Stream name `<symbol>@bookTicker`; `symbol` is lowercased.
pub fn book_ticker_stream(symbol: &str) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@bookTicker")
}

/// Stream name `!bookTicker` (book ticker for all symbols).
pub fn all_book_ticker_stream() -> &'static str {
    "!bookTicker"
}

/// Stream name `!miniTicker@arr` (mini-ticker array covering all symbols).
pub fn all_mini_ticker_stream() -> &'static str {
    "!miniTicker@arr"
}

/// Stream name `<symbol>@miniTicker`; `symbol` is lowercased.
pub fn mini_ticker_stream(symbol: &str) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@miniTicker")
}

/// Stream name `<symbol>@markPrice@<update_speed>s`; `symbol` is lowercased.
pub fn mark_price_stream(symbol: &str, update_speed: u8) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@markPrice@{update_speed}s")
}

/// Stream name `<symbol>@depth<levels>@<update_speed>ms` (partial book depth);
/// `symbol` is lowercased.
pub fn partial_book_depth_stream(symbol: &str, levels: u16, update_speed: u16) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@depth{levels}@{update_speed}ms")
}

/// Stream name `<symbol>@depth@<update_speed>ms` (diff book depth);
/// `symbol` is lowercased.
pub fn diff_book_depth_stream(symbol: &str, update_speed: u16) -> String {
    let symbol = stream_symbol(symbol);
    format!("{symbol}@depth@{update_speed}ms")
}

/// Joins stream names with `/`, the form expected in the `streams=` query of a
/// combined-stream URL.
///
/// Names are joined verbatim (no case normalisation), since already-built names
/// may contain case-sensitive parts such as a kline interval.
fn combined_stream(streams: Vec<String>) -> String {
    streams.join("/")
}
67
68pub struct WebSockets<'a, WE> {
69 pub socket: Option<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response)>,
70 handler: Box<dyn FnMut(WE) -> Result<()> + 'a + Send>,
71 conf: Config,
72}
73
74impl<'a, WE: serde::de::DeserializeOwned> WebSockets<'a, WE> {
75 pub fn new<Callback>(handler: Callback) -> WebSockets<'a, WE>
79 where
80 Callback: FnMut(WE) -> Result<()> + 'a + Send,
81 {
82 Self::new_with_options(handler, Config::default())
83 }
84
85 pub fn new_with_options<Callback>(handler: Callback, conf: Config) -> WebSockets<'a, WE>
89 where
90 Callback: FnMut(WE) -> Result<()> + 'a + Send,
91 {
92 WebSockets {
93 socket: None,
94 handler: Box::new(handler),
95 conf,
96 }
97 }
98
99 pub async fn connect_multiple(&mut self, endpoints: Vec<String>) -> Result<()> {
102 let mut url = Url::parse(&self.conf.ws_endpoint)?;
103 url.path_segments_mut()
104 .map_err(|_| Error::UrlParserError(url::ParseError::RelativeUrlWithoutBase))?
105 .push(STREAM_ENDPOINT);
106 url.set_query(Some(&format!("streams={}", combined_stream(endpoints))));
107
108 self.handle_connect(url).await
109 }
110
111 pub async fn connect(&mut self, endpoint: &str) -> Result<()> {
113 let wss: String = format!("{}/{}/{}", self.conf.ws_endpoint, WS_ENDPOINT, endpoint);
114 let url = Url::parse(&wss)?;
115
116 self.handle_connect(url).await
117 }
118
119 pub async fn connect_futures(&mut self, endpoint: &str) -> Result<()> {
121 let wss: String = format!("{}/{}/{}", self.conf.futures_ws_endpoint, WS_ENDPOINT, endpoint);
122 let url = Url::parse(&wss)?;
123
124 self.handle_connect(url).await
125 }
126
127 async fn handle_connect(&mut self, url: Url) -> Result<()> {
128 match connect_async(url).await {
129 Ok(answer) => {
130 self.socket = Some(answer);
131 Ok(())
132 }
133 Err(e) => Err(Error::Msg(format!("Error during handshake {e}"))),
134 }
135 }
136
137 pub async fn disconnect(&mut self) -> Result<()> {
139 if let Some(ref mut socket) = self.socket {
140 socket.0.close(None).await?;
141 Ok(())
142 } else {
143 Err(Error::Msg("Not able to close the connection".to_string()))
144 }
145 }
146
147 pub fn socket(&self) -> &Option<(WebSocketStream<MaybeTlsStream<TcpStream>>, Response)> { &self.socket }
148
149 pub async fn event_loop(&mut self, running: &AtomicBool) -> Result<()> {
150 while running.load(Ordering::Relaxed) {
151 if let Some((ref mut socket, _)) = self.socket {
152 let message = socket.next().await.unwrap()?;
154
155 match message {
156 Message::Text(msg) => {
157 if msg.is_empty() {
158 return Ok(());
159 }
160 let event: WE = from_str(msg.as_str())?;
161 (self.handler)(event)?;
162 }
163 Message::Ping(_) | Message::Pong(_) | Message::Binary(_) | Message::Frame(_) => {}
164 Message::Close(e) => {
165 return Err(Error::Msg(format!("Disconnected {e:?}")));
166 }
167 }
168 }
169 }
170 Ok(())
171 }
172}