Skip to main content

exc_okx/websocket/transport/protocol/
mod.rs

1use crate::websocket::types::{
2    request::{ClientStream, Request},
3    response::{Response, ServerStream, Status, StatusKind},
4};
5use atomic_waker::AtomicWaker;
6use exc_core::transport::websocket::WsStream;
7use futures::{
8    future::{ready, BoxFuture},
9    FutureExt, Sink, SinkExt, Stream, StreamExt, TryStreamExt,
10};
11use pin_project_lite::pin_project;
12use std::{pin::Pin, sync::Arc};
13use std::{
14    task::{Context, Poll},
15    time::Duration,
16};
17use thiserror::Error;
18use tokio_tower::multiplex::{Client, TagStore};
19use tokio_tungstenite::tungstenite::Message;
20use tower::Service;
21
22mod frame;
23mod message;
24mod ping_pong;
25mod stream;
26
27pub use frame::FrameError;
28pub use message::MessageError;
29pub use ping_pong::PingPongError;
30pub use stream::StreamingError;
31
32type Req = ClientStream;
33type Resp = Result<ServerStream, Status>;
34
35/// Protocol Error.
36#[derive(Debug, Error)]
37pub enum ProtocolError {
38    /// Transport Errors.
39    #[error("transport: {0}")]
40    Transport(#[from] StreamingError<FrameError<MessageError<PingPongError>>>),
41
42    /// Tokio tower error.
43    #[error("tokio-tower: {0}")]
44    TokioTower(anyhow::Error),
45    // /// Subsribed.
46    // #[error("subscribed: {0}")]
47    // Subscribed(Args),
48    /// Reconnect.
49    #[error("reconnect")]
50    Reconnect,
51}
52
53/// Okx websocket transport stream.
54pub trait OkxWsStream:
55    Sink<Req, Error = ProtocolError> + Stream<Item = Result<Resp, ProtocolError>>
56{
57}
58
59impl<S> OkxWsStream for S
60where
61    S: Sink<Req, Error = ProtocolError>,
62    S: Stream<Item = Result<Resp, ProtocolError>>,
63{
64}
65
66type BoxStream = Pin<Box<dyn OkxWsStream + Send>>;
67
68pin_project! {
69    /// Okx websocket transport of v5 api.
70    pub struct Transport {
71        #[pin]
72        inner: BoxStream,
73        stream_id: usize,
74    }
75}
76
77impl Transport {
78    pub(crate) fn new<S, Err>(
79        transport: S,
80        ping_timeout: Duration,
81        waker: Arc<AtomicWaker>,
82    ) -> Transport
83    where
84        S: 'static + Send,
85        Err: 'static,
86        S: Sink<String, Error = Err>,
87        S: Stream<Item = Result<String, Err>>,
88        Err: Into<anyhow::Error>,
89    {
90        let transport = ping_pong::layer(transport, ping_timeout);
91        let transport = message::layer(transport);
92        let transport = frame::layer(transport);
93        let transport = stream::layer(transport, waker);
94        let inner = transport
95            .sink_map_err(ProtocolError::from)
96            .map_err(ProtocolError::from);
97        Self {
98            inner: Box::pin(inner),
99            stream_id: 1,
100        }
101    }
102}
103
104impl Sink<Req> for Transport {
105    type Error = ProtocolError;
106
107    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
108        self.project().inner.poll_ready(cx)
109    }
110
111    fn start_send(self: Pin<&mut Self>, item: Req) -> Result<(), Self::Error> {
112        self.project().inner.start_send(item)
113    }
114
115    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116        self.project().inner.poll_flush(cx)
117    }
118
119    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        self.project().inner.poll_close(cx)
121    }
122}
123
124impl Stream for Transport {
125    type Item = Result<Resp, ProtocolError>;
126
127    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128        self.project().inner.poll_next(cx)
129    }
130
131    fn size_hint(&self) -> (usize, Option<usize>) {
132        self.inner.size_hint()
133    }
134}
135
136impl TagStore<Req, Resp> for Transport {
137    type Tag = usize;
138
139    fn assign_tag(self: Pin<&mut Self>, r: &mut Req) -> Self::Tag {
140        let this = self.project();
141        let id = *this.stream_id;
142        *this.stream_id += 1;
143        r.id = id;
144        id
145    }
146
147    fn finish_tag(self: Pin<&mut Self>, r: &Resp) -> Self::Tag {
148        match r.as_ref() {
149            Ok(s) => s.id,
150            Err(e) => e.stream_id,
151        }
152    }
153}
154
155impl From<tokio_tower::Error<Transport, Req>> for ProtocolError {
156    fn from(err: tokio_tower::Error<Transport, Req>) -> Self {
157        Self::TokioTower(err.into())
158    }
159}
160
161/// Okx websocket api protocol.
162pub struct Protocol {
163    waker: Arc<AtomicWaker>,
164    inner: Client<Transport, ProtocolError, Req>,
165    reconnect: bool,
166}
167
168impl Protocol {
169    pub(crate) async fn init(
170        websocket: WsStream,
171        ping_timeout: Duration,
172    ) -> Result<Self, ProtocolError> {
173        let transport = websocket
174            .with(|msg: String| async move { Ok(Message::Text(msg)) })
175            .filter_map(|msg| async move {
176                match msg {
177                    Ok(msg) => match msg {
178                        Message::Text(text) => Some(Ok(text)),
179                        _ => None,
180                    },
181                    Err(err) => Some(Err(err)),
182                }
183            });
184        let waker = Arc::new(AtomicWaker::default());
185        let transport = Transport::new(transport, ping_timeout, waker.clone());
186        Ok(Self {
187            waker,
188            inner: Client::with_error_handler(transport, |e| {
189                tracing::error!("protocol error: {e}");
190            }),
191            reconnect: false,
192        })
193    }
194}
195
196impl Service<Request> for Protocol {
197    type Response = Response;
198    type Error = ProtocolError;
199    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
200
201    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202        if self.reconnect {
203            Poll::Ready(Err(ProtocolError::Reconnect))
204        } else {
205            // wake up when the transport is dead.
206            self.waker.register(cx.waker());
207            self.inner.poll_ready(cx)
208        }
209    }
210
211    fn call(&mut self, req: Request) -> Self::Future {
212        if req.reconnect {
213            self.reconnect = true;
214            ready(Ok(Response::Reconnected)).boxed()
215        } else {
216            let resp = self.inner.call(req.into_client_stream());
217            async move {
218                let resp = resp.await?;
219                let resp = match resp {
220                    Ok(stream) => {
221                        let mut stream = Box::pin(stream.peekable());
222                        if let Some(frame) = stream.as_mut().peek().await {
223                            trace!("wait header; peeked {frame:?}");
224                            Response::Streaming(stream)
225                        } else {
226                            trace!("wait header; no header");
227                            Response::Error(StatusKind::EmptyResponse)
228                        }
229                    }
230                    Err(err) => Response::Error(err.kind),
231                };
232                Ok(resp)
233            }
234            .boxed()
235        }
236    }
237}