exc_binance/websocket/protocol/
mod.rs1use std::{
2 collections::HashSet,
3 pin::Pin,
4 sync::Arc,
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use self::{
10 frame::Name,
11 stream::{MultiplexRequest, MultiplexResponse},
12};
13
14use super::response::WsResponse;
15use super::{connect::BinanceWsHost, request::WsRequest};
16use super::{error::WsError, request::RequestKind};
17use exc_core::transport::websocket::WsStream;
18use futures::{future::BoxFuture, FutureExt, Sink, SinkExt, Stream, TryFutureExt, TryStreamExt};
19use tokio_tower::multiplex::{Client as Multiplex, TagStore};
20use tower::Service;
21
22pub mod stream;
24
25pub mod frame;
27
28pub mod keep_alive;
30
31type Req = MultiplexRequest;
32type Resp = MultiplexResponse;
33
34trait Transport: Sink<Req, Error = WsError> + Stream<Item = Result<Resp, WsError>> {}
35
36impl<T> Transport for T
37where
38 T: Sink<Req, Error = WsError>,
39 T: Stream<Item = Result<Resp, WsError>>,
40{
41}
42
43type BoxTransport = Pin<Box<dyn Transport + Send>>;
44type Refresh = BoxFuture<'static, ()>;
45
46pin_project_lite::pin_project! {
47 pub struct Protocol {
49 #[pin]
50 transport: BoxTransport,
51 next_stream_id: usize,
52 }
53}
54
55impl Protocol {
56 fn new(
57 websocket: WsStream,
58 main_stream: HashSet<Name>,
59 keep_alive_timeout: Duration,
60 default_stream_timeout: Duration,
61 refresh: Option<Refresh>,
62 ) -> (Self, Arc<stream::Shared>) {
63 let transport = keep_alive::layer(
64 websocket.sink_map_err(WsError::from).map_err(WsError::from),
65 keep_alive_timeout,
66 );
67 let transport = frame::layer(transport);
68 let (transport, state) =
69 stream::layer(transport, main_stream, default_stream_timeout, refresh);
70 (
71 Self {
72 transport: Box::pin(transport),
73 next_stream_id: 1,
74 },
75 state,
76 )
77 }
78}
79
80impl Sink<Req> for Protocol {
81 type Error = WsError;
82
83 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 self.project().transport.poll_ready(cx)
85 }
86
87 fn start_send(self: Pin<&mut Self>, item: Req) -> Result<(), Self::Error> {
88 self.project().transport.start_send(item)
89 }
90
91 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 self.project().transport.poll_flush(cx)
93 }
94
95 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96 self.project().transport.poll_close(cx)
97 }
98}
99
100impl Stream for Protocol {
101 type Item = Result<Resp, WsError>;
102
103 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
104 self.project().transport.poll_next(cx)
105 }
106
107 fn size_hint(&self) -> (usize, Option<usize>) {
108 self.transport.size_hint()
109 }
110}
111
112impl TagStore<Req, Resp> for Protocol {
113 type Tag = usize;
114
115 fn assign_tag(self: Pin<&mut Self>, r: &mut Req) -> Self::Tag {
116 let this = self.project();
117 let id = *this.next_stream_id;
118 *this.next_stream_id += 1;
119 r.id = id;
120 id
121 }
122
123 fn finish_tag(self: Pin<&mut Self>, r: &Resp) -> Self::Tag {
124 match r {
125 Resp::MainStream(id, _) => *id,
126 Resp::SubStream { id, .. } => *id,
127 }
128 }
129}
130
131impl From<tokio_tower::Error<Protocol, Req>> for WsError {
132 fn from(err: tokio_tower::Error<Protocol, Req>) -> Self {
133 match err {
134 tokio_tower::Error::BrokenTransportSend(err)
135 | tokio_tower::Error::BrokenTransportRecv(Some(err)) => err,
136 err => Self::TokioTower(err.into()),
137 }
138 }
139}
140
141pub struct WsClient {
143 main_stream: HashSet<Name>,
144 endpoint: BinanceWsHost,
145 state: Arc<stream::Shared>,
146 svc: Multiplex<Protocol, WsError, Req>,
147 reconnect: bool,
148}
149
150impl WsClient {
151 pub fn with_websocket(
153 endpoint: BinanceWsHost,
154 websocket: WsStream,
155 main_stream: HashSet<Name>,
156 keep_alive_timeout: Duration,
157 default_stream_timeout: Duration,
158 refresh: Option<Refresh>,
159 ) -> Result<Self, WsError> {
160 let (protocol, state) = Protocol::new(
161 websocket,
162 main_stream.clone(),
163 keep_alive_timeout,
164 default_stream_timeout,
165 refresh,
166 );
167 let shared = state.clone();
168 let svc = Multiplex::with_error_handler(protocol, move |err| {
169 shared.waker.wake();
170 tracing::error!("protocol error: {err}");
171 });
172 Ok(Self {
173 endpoint,
174 main_stream,
175 svc,
176 state,
177 reconnect: false,
178 })
179 }
180
181 fn dispatch(&self, req: WsRequest) -> WsRequest {
182 tracing::trace!(
183 "ws client; dispatching request with endpoint: {:?}",
184 self.endpoint,
185 );
186 match &req.inner {
187 RequestKind::DispatchTrades(trades) => match self.endpoint {
188 BinanceWsHost::EuropeanOptions => {
189 WsRequest::sub_stream(Name::trade(&trades.instrument))
190 }
191 _ => WsRequest::sub_stream(Name::agg_trade(&trades.instrument)),
192 },
193 RequestKind::DispatchBidAsk(bid_ask) => match self.endpoint {
194 BinanceWsHost::EuropeanOptions => {
195 WsRequest::sub_stream(Name::depth(&bid_ask.instrument, "10", "100ms"))
196 }
197 _ => WsRequest::sub_stream(Name::book_ticker(&bid_ask.instrument)),
198 },
199 RequestKind::DispatchSubscribe(name) => {
200 if self.main_stream.contains(name) {
201 WsRequest::main_stream(name.clone())
202 } else {
203 WsRequest::sub_stream(name.clone())
204 }
205 }
206 _ => {
207 tracing::error!("ws client; not a dispatch request");
208 req
209 }
210 }
211 }
212}
213
214impl Service<WsRequest> for WsClient {
215 type Response = WsResponse;
216 type Error = WsError;
217 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
218
219 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220 if self.reconnect {
221 Poll::Ready(Err(WsError::TransportIsBoken))
222 } else {
223 self.state.waker.register(cx.waker());
224 self.svc.poll_ready(cx)
225 }
226 }
227
228 fn call(&mut self, mut req: WsRequest) -> Self::Future {
229 let is_stream = req.stream;
230 let mut dispatched = false;
231 loop {
232 match req.inner {
233 RequestKind::Multiplex(req) => {
234 return self
235 .svc
236 .call(req)
237 .and_then(move |resp| {
238 let resp: WsResponse = resp.into();
239 if is_stream {
240 resp.stream().left_future()
241 } else {
242 futures::future::ready(Ok(resp)).right_future()
243 }
244 })
245 .boxed()
246 }
247 RequestKind::Reconnect => {
248 self.reconnect = true;
249 return futures::future::ready(Ok(WsResponse::Reconnected)).boxed();
250 }
251 _ => {
252 if dispatched {
253 break;
254 }
255 req = self.dispatch(req);
256 dispatched = true;
257 }
258 }
259 }
260 tracing::error!("ws client; failed to dispatch request");
261 futures::future::ready(Err(WsError::TransportIsBoken)).boxed()
262 }
263}