exc_okx/websocket/transport/protocol/
mod.rs1use crate::websocket::types::{
2 request::{ClientStream, Request},
3 response::{Response, ServerStream, Status, StatusKind},
4};
5use atomic_waker::AtomicWaker;
6use exc_core::transport::websocket::WsStream;
7use futures::{
8 future::{ready, BoxFuture},
9 FutureExt, Sink, SinkExt, Stream, StreamExt, TryStreamExt,
10};
11use pin_project_lite::pin_project;
12use std::{pin::Pin, sync::Arc};
13use std::{
14 task::{Context, Poll},
15 time::Duration,
16};
17use thiserror::Error;
18use tokio_tower::multiplex::{Client, TagStore};
19use tokio_tungstenite::tungstenite::Message;
20use tower::Service;
21
22mod frame;
23mod message;
24mod ping_pong;
25mod stream;
26
27pub use frame::FrameError;
28pub use message::MessageError;
29pub use ping_pong::PingPongError;
30pub use stream::StreamingError;
31
32type Req = ClientStream;
33type Resp = Result<ServerStream, Status>;
34
35#[derive(Debug, Error)]
37pub enum ProtocolError {
38 #[error("transport: {0}")]
40 Transport(#[from] StreamingError<FrameError<MessageError<PingPongError>>>),
41
42 #[error("tokio-tower: {0}")]
44 TokioTower(anyhow::Error),
45 #[error("reconnect")]
50 Reconnect,
51}
52
53pub trait OkxWsStream:
55 Sink<Req, Error = ProtocolError> + Stream<Item = Result<Resp, ProtocolError>>
56{
57}
58
59impl<S> OkxWsStream for S
60where
61 S: Sink<Req, Error = ProtocolError>,
62 S: Stream<Item = Result<Resp, ProtocolError>>,
63{
64}
65
66type BoxStream = Pin<Box<dyn OkxWsStream + Send>>;
67
68pin_project! {
69 pub struct Transport {
71 #[pin]
72 inner: BoxStream,
73 stream_id: usize,
74 }
75}
76
77impl Transport {
78 pub(crate) fn new<S, Err>(
79 transport: S,
80 ping_timeout: Duration,
81 waker: Arc<AtomicWaker>,
82 ) -> Transport
83 where
84 S: 'static + Send,
85 Err: 'static,
86 S: Sink<String, Error = Err>,
87 S: Stream<Item = Result<String, Err>>,
88 Err: Into<anyhow::Error>,
89 {
90 let transport = ping_pong::layer(transport, ping_timeout);
91 let transport = message::layer(transport);
92 let transport = frame::layer(transport);
93 let transport = stream::layer(transport, waker);
94 let inner = transport
95 .sink_map_err(ProtocolError::from)
96 .map_err(ProtocolError::from);
97 Self {
98 inner: Box::pin(inner),
99 stream_id: 1,
100 }
101 }
102}
103
104impl Sink<Req> for Transport {
105 type Error = ProtocolError;
106
107 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
108 self.project().inner.poll_ready(cx)
109 }
110
111 fn start_send(self: Pin<&mut Self>, item: Req) -> Result<(), Self::Error> {
112 self.project().inner.start_send(item)
113 }
114
115 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 self.project().inner.poll_flush(cx)
117 }
118
119 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 self.project().inner.poll_close(cx)
121 }
122}
123
124impl Stream for Transport {
125 type Item = Result<Resp, ProtocolError>;
126
127 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128 self.project().inner.poll_next(cx)
129 }
130
131 fn size_hint(&self) -> (usize, Option<usize>) {
132 self.inner.size_hint()
133 }
134}
135
136impl TagStore<Req, Resp> for Transport {
137 type Tag = usize;
138
139 fn assign_tag(self: Pin<&mut Self>, r: &mut Req) -> Self::Tag {
140 let this = self.project();
141 let id = *this.stream_id;
142 *this.stream_id += 1;
143 r.id = id;
144 id
145 }
146
147 fn finish_tag(self: Pin<&mut Self>, r: &Resp) -> Self::Tag {
148 match r.as_ref() {
149 Ok(s) => s.id,
150 Err(e) => e.stream_id,
151 }
152 }
153}
154
155impl From<tokio_tower::Error<Transport, Req>> for ProtocolError {
156 fn from(err: tokio_tower::Error<Transport, Req>) -> Self {
157 Self::TokioTower(err.into())
158 }
159}
160
161pub struct Protocol {
163 waker: Arc<AtomicWaker>,
164 inner: Client<Transport, ProtocolError, Req>,
165 reconnect: bool,
166}
167
168impl Protocol {
169 pub(crate) async fn init(
170 websocket: WsStream,
171 ping_timeout: Duration,
172 ) -> Result<Self, ProtocolError> {
173 let transport = websocket
174 .with(|msg: String| async move { Ok(Message::Text(msg)) })
175 .filter_map(|msg| async move {
176 match msg {
177 Ok(msg) => match msg {
178 Message::Text(text) => Some(Ok(text)),
179 _ => None,
180 },
181 Err(err) => Some(Err(err)),
182 }
183 });
184 let waker = Arc::new(AtomicWaker::default());
185 let transport = Transport::new(transport, ping_timeout, waker.clone());
186 Ok(Self {
187 waker,
188 inner: Client::with_error_handler(transport, |e| {
189 tracing::error!("protocol error: {e}");
190 }),
191 reconnect: false,
192 })
193 }
194}
195
196impl Service<Request> for Protocol {
197 type Response = Response;
198 type Error = ProtocolError;
199 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
200
201 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202 if self.reconnect {
203 Poll::Ready(Err(ProtocolError::Reconnect))
204 } else {
205 self.waker.register(cx.waker());
207 self.inner.poll_ready(cx)
208 }
209 }
210
211 fn call(&mut self, req: Request) -> Self::Future {
212 if req.reconnect {
213 self.reconnect = true;
214 ready(Ok(Response::Reconnected)).boxed()
215 } else {
216 let resp = self.inner.call(req.into_client_stream());
217 async move {
218 let resp = resp.await?;
219 let resp = match resp {
220 Ok(stream) => {
221 let mut stream = Box::pin(stream.peekable());
222 if let Some(frame) = stream.as_mut().peek().await {
223 trace!("wait header; peeked {frame:?}");
224 Response::Streaming(stream)
225 } else {
226 trace!("wait header; no header");
227 Response::Error(StatusKind::EmptyResponse)
228 }
229 }
230 Err(err) => Response::Error(err.kind),
231 };
232 Ok(resp)
233 }
234 .boxed()
235 }
236 }
237}