exc_binance/websocket/protocol/mod.rs

1use std::{
2    collections::HashSet,
3    pin::Pin,
4    sync::Arc,
5    task::{Context, Poll},
6    time::Duration,
7};
8
9use self::{
10    frame::Name,
11    stream::{MultiplexRequest, MultiplexResponse},
12};
13
14use super::response::WsResponse;
15use super::{connect::BinanceWsHost, request::WsRequest};
16use super::{error::WsError, request::RequestKind};
17use exc_core::transport::websocket::WsStream;
18use futures::{future::BoxFuture, FutureExt, Sink, SinkExt, Stream, TryFutureExt, TryStreamExt};
19use tokio_tower::multiplex::{Client as Multiplex, TagStore};
20use tower::Service;
21
22/// Multiplex protocol.
23pub mod stream;
24
25/// Frame protocol.
26pub mod frame;
27
28/// Keep-alive protocol.
29pub mod keep_alive;
30
31type Req = MultiplexRequest;
32type Resp = MultiplexResponse;
33
34trait Transport: Sink<Req, Error = WsError> + Stream<Item = Result<Resp, WsError>> {}
35
36impl<T> Transport for T
37where
38    T: Sink<Req, Error = WsError>,
39    T: Stream<Item = Result<Resp, WsError>>,
40{
41}
42
43type BoxTransport = Pin<Box<dyn Transport + Send>>;
44type Refresh = BoxFuture<'static, ()>;
45
46pin_project_lite::pin_project! {
47    /// Binance websocket protocol.
48    pub struct Protocol {
49        #[pin]
50        transport: BoxTransport,
51        next_stream_id: usize,
52    }
53}
54
55impl Protocol {
56    fn new(
57        websocket: WsStream,
58        main_stream: HashSet<Name>,
59        keep_alive_timeout: Duration,
60        default_stream_timeout: Duration,
61        refresh: Option<Refresh>,
62    ) -> (Self, Arc<stream::Shared>) {
63        let transport = keep_alive::layer(
64            websocket.sink_map_err(WsError::from).map_err(WsError::from),
65            keep_alive_timeout,
66        );
67        let transport = frame::layer(transport);
68        let (transport, state) =
69            stream::layer(transport, main_stream, default_stream_timeout, refresh);
70        (
71            Self {
72                transport: Box::pin(transport),
73                next_stream_id: 1,
74            },
75            state,
76        )
77    }
78}
79
80impl Sink<Req> for Protocol {
81    type Error = WsError;
82
83    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        self.project().transport.poll_ready(cx)
85    }
86
87    fn start_send(self: Pin<&mut Self>, item: Req) -> Result<(), Self::Error> {
88        self.project().transport.start_send(item)
89    }
90
91    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92        self.project().transport.poll_flush(cx)
93    }
94
95    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96        self.project().transport.poll_close(cx)
97    }
98}
99
100impl Stream for Protocol {
101    type Item = Result<Resp, WsError>;
102
103    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
104        self.project().transport.poll_next(cx)
105    }
106
107    fn size_hint(&self) -> (usize, Option<usize>) {
108        self.transport.size_hint()
109    }
110}
111
112impl TagStore<Req, Resp> for Protocol {
113    type Tag = usize;
114
115    fn assign_tag(self: Pin<&mut Self>, r: &mut Req) -> Self::Tag {
116        let this = self.project();
117        let id = *this.next_stream_id;
118        *this.next_stream_id += 1;
119        r.id = id;
120        id
121    }
122
123    fn finish_tag(self: Pin<&mut Self>, r: &Resp) -> Self::Tag {
124        match r {
125            Resp::MainStream(id, _) => *id,
126            Resp::SubStream { id, .. } => *id,
127        }
128    }
129}
130
131impl From<tokio_tower::Error<Protocol, Req>> for WsError {
132    fn from(err: tokio_tower::Error<Protocol, Req>) -> Self {
133        match err {
134            tokio_tower::Error::BrokenTransportSend(err)
135            | tokio_tower::Error::BrokenTransportRecv(Some(err)) => err,
136            err => Self::TokioTower(err.into()),
137        }
138    }
139}
140
141/// Binance websocket service.
142pub struct WsClient {
143    main_stream: HashSet<Name>,
144    endpoint: BinanceWsHost,
145    state: Arc<stream::Shared>,
146    svc: Multiplex<Protocol, WsError, Req>,
147    reconnect: bool,
148}
149
150impl WsClient {
151    /// Create a [`WsClient`] using the given websocket stream.
152    pub fn with_websocket(
153        endpoint: BinanceWsHost,
154        websocket: WsStream,
155        main_stream: HashSet<Name>,
156        keep_alive_timeout: Duration,
157        default_stream_timeout: Duration,
158        refresh: Option<Refresh>,
159    ) -> Result<Self, WsError> {
160        let (protocol, state) = Protocol::new(
161            websocket,
162            main_stream.clone(),
163            keep_alive_timeout,
164            default_stream_timeout,
165            refresh,
166        );
167        let shared = state.clone();
168        let svc = Multiplex::with_error_handler(protocol, move |err| {
169            shared.waker.wake();
170            tracing::error!("protocol error: {err}");
171        });
172        Ok(Self {
173            endpoint,
174            main_stream,
175            svc,
176            state,
177            reconnect: false,
178        })
179    }
180
181    fn dispatch(&self, req: WsRequest) -> WsRequest {
182        tracing::trace!(
183            "ws client; dispatching request with endpoint: {:?}",
184            self.endpoint,
185        );
186        match &req.inner {
187            RequestKind::DispatchTrades(trades) => match self.endpoint {
188                BinanceWsHost::EuropeanOptions => {
189                    WsRequest::sub_stream(Name::trade(&trades.instrument))
190                }
191                _ => WsRequest::sub_stream(Name::agg_trade(&trades.instrument)),
192            },
193            RequestKind::DispatchBidAsk(bid_ask) => match self.endpoint {
194                BinanceWsHost::EuropeanOptions => {
195                    WsRequest::sub_stream(Name::depth(&bid_ask.instrument, "10", "100ms"))
196                }
197                _ => WsRequest::sub_stream(Name::book_ticker(&bid_ask.instrument)),
198            },
199            RequestKind::DispatchSubscribe(name) => {
200                if self.main_stream.contains(name) {
201                    WsRequest::main_stream(name.clone())
202                } else {
203                    WsRequest::sub_stream(name.clone())
204                }
205            }
206            _ => {
207                tracing::error!("ws client; not a dispatch request");
208                req
209            }
210        }
211    }
212}
213
214impl Service<WsRequest> for WsClient {
215    type Response = WsResponse;
216    type Error = WsError;
217    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
218
219    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
220        if self.reconnect {
221            Poll::Ready(Err(WsError::TransportIsBoken))
222        } else {
223            self.state.waker.register(cx.waker());
224            self.svc.poll_ready(cx)
225        }
226    }
227
228    fn call(&mut self, mut req: WsRequest) -> Self::Future {
229        let is_stream = req.stream;
230        let mut dispatched = false;
231        loop {
232            match req.inner {
233                RequestKind::Multiplex(req) => {
234                    return self
235                        .svc
236                        .call(req)
237                        .and_then(move |resp| {
238                            let resp: WsResponse = resp.into();
239                            if is_stream {
240                                resp.stream().left_future()
241                            } else {
242                                futures::future::ready(Ok(resp)).right_future()
243                            }
244                        })
245                        .boxed()
246                }
247                RequestKind::Reconnect => {
248                    self.reconnect = true;
249                    return futures::future::ready(Ok(WsResponse::Reconnected)).boxed();
250                }
251                _ => {
252                    if dispatched {
253                        break;
254                    }
255                    req = self.dispatch(req);
256                    dispatched = true;
257                }
258            }
259        }
260        tracing::error!("ws client; failed to dispatch request");
261        futures::future::ready(Err(WsError::TransportIsBoken)).boxed()
262    }
263}