exc_okx/websocket/transport/protocol/stream.rs

1use crate::error::OkxError;
2use crate::websocket::types::callback::Callback;
3use crate::websocket::types::frames::client::ClientFrame;
4use crate::websocket::types::frames::server::ServerFrame;
5use crate::websocket::types::request::ClientStream;
6use crate::websocket::types::response::Status;
7use crate::websocket::types::response::{ServerStream, StatusKind};
8use atomic_waker::AtomicWaker;
9use futures::channel::mpsc::{self, SendError, UnboundedReceiver, UnboundedSender};
10use futures::SinkExt;
11use futures::{Sink, Stream, StreamExt};
12use pin_project_lite::pin_project;
13use std::collections::hash_map::RandomState;
14use std::collections::{BTreeMap, HashSet};
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use thiserror::Error;
19use tokio::sync::oneshot;
20
/// Lifecycle of one multiplexed stream, as tracked by the streaming worker.
#[derive(Clone, Copy, Debug)]
enum StreamState {
    /// Registered, but no client frame has been seen yet.
    Idle,
    /// The first (header) frame was forwarded; traffic flows both ways.
    Open,
    /// The client sent its end-of-stream frame; waiting for the server's.
    LocalClosed,
    /// The server sent its end-of-stream frame; waiting for the client's.
    RemoteClosed,
    /// Both sides are done (or the stream was rejected); about to be removed.
    Closed,
}
29
/// Per-stream bookkeeping held by the streaming worker.
struct StreamContext {
    /// Delivers server frames (or errors) to the consumer's `ServerStream`.
    sender: UnboundedSender<Result<ServerFrame, OkxError>>,
    /// The consumer-facing stream; `Some` until it is handed out when the stream
    /// transitions from `Idle` to `Open`.
    stream: Option<ServerStream>,
    /// Current lifecycle state.
    state: StreamState,
    /// Subscription tag read from the stream's first frame, if any; released from the
    /// worker's tag set when the stream fully closes.
    tag: Option<String>,
}
36
37impl StreamContext {
38    fn new(id: usize, cb: Callback) -> Self {
39        let (server_frame_tx, server_frame_rx) = mpsc::unbounded();
40        let stream = ServerStream {
41            id,
42            cb,
43            inner: server_frame_rx.boxed(),
44        };
45        Self {
46            sender: server_frame_tx,
47            stream: Some(stream),
48            state: StreamState::Idle,
49            tag: None,
50        }
51    }
52}
53
54impl Drop for StreamContext {
55    fn drop(&mut self) {
56        let _fut = self.sender.send(Err(OkxError::StreamDropped));
57    }
58}
59
60/// Stream layer errors.
61#[derive(Debug, Error)]
62pub enum StreamingError<E> {
63    /// Transport error.
64    #[error(transparent)]
65    Transport(#[from] E),
66
67    /// Sender error.
68    #[error(transparent)]
69    Sender(SendError),
70
71    /// Idle stream missing.
72    #[error("idle stream missing")]
73    IdleStreamMissing,
74
75    /// Borken streaming layer.
76    #[error("broken streaming layer")]
77    BlokenStreamingLayer,
78}
79
80pub(super) fn layer<T, E>(
81    transport: T,
82    waker: Arc<AtomicWaker>,
83) -> impl Sink<ClientStream, Error = StreamingError<E>>
84       + Stream<Item = Result<Result<ServerStream, Status>, StreamingError<E>>>
85where
86    E: Send + 'static + std::fmt::Display,
87    T: Send + 'static,
88    T: Sink<ClientFrame, Error = E>,
89    T: Stream<Item = Result<ServerFrame, E>>,
90{
91    let (mut tx, mut rx) = transport.split();
92    let (client_frame_tx, mut client_frame_rx) = mpsc::unbounded::<ClientFrame>();
93    let (sender, mut client_stream_rx) = mpsc::unbounded::<ClientStream>();
94    let (mut server_stream_tx, receiver) = mpsc::unbounded();
95    let mut streams: BTreeMap<usize, StreamContext> = BTreeMap::default();
96    let mut last_server_stream_tx = server_stream_tx.clone();
97    let mut tags = HashSet::<String, RandomState>::new();
98    let worker = async move {
99        loop {
100            tokio::select! {
101                Some(mut client_stream) = client_stream_rx.next() => {
102                    let cb = client_stream.cb.take().expect("client stream must contains a callback");
103                    let id = client_stream.id;
104                    let ctx = StreamContext::new(id, cb);
105                    streams.insert(id, ctx);
106                    let mut client_frame_tx = client_frame_tx.clone();
107                    tokio::spawn(async move {
108                        while let Some(mut frame) = client_stream.inner.next().await {
109                            frame.stream_id = id;
110                            if let Err(err) = client_frame_tx.send(frame).await {
111                                error!("streaming client worker; send error id={id} err={err}");
112                                break;
113                            }
114                        }
115                    });
116                }
117                Some(client_frame) = client_frame_rx.next() => {
118                    let id = client_frame.stream_id;
119                    if let Some(ctx) = streams.get_mut(&id) {
120                        let is_end_stream = client_frame.is_end_stream();
121                        match ctx.state {
122                            StreamState::Idle => {
123                                if is_end_stream {
124                                    ctx.state = StreamState::Closed;
125                                    server_stream_tx.send(Ok(Err(Status { stream_id: id, kind: StatusKind::CloseIdleStream }))).await.map_err(StreamingError::Sender)?;
126                                    streams.remove(&id);
127                    trace!("stream {id}; idle -> closed");
128                                    // client frame is ignored
129                                    continue;
130                                } else {
131                                    // the first client frame is considered to be a stream header, so we need to check the tag.
132                                    if let Some(tag) = client_frame.tag() {
133                                        if tags.contains(&tag) {
134                                            server_stream_tx.send(Ok(Err(Status { stream_id: id, kind: StatusKind::AlreadySubscribed(tag) }))).await.map_err(StreamingError::Sender)?;
135                                            ctx.state = StreamState::Closed;
136                                            streams.remove(&id);
137                                            // client frame is ignored
138                                            continue;
139                                        } else {
140                                            tags.insert(tag.clone());
141                                            ctx.tag = Some(tag);
142                                        }
143                                    }
144                                    ctx.state = StreamState::Open;
145                    trace!("stream {id}; idle -> open");
146                                    if let Some(stream) = ctx.stream.take() {
147                                        server_stream_tx.send(Ok(Ok(stream))).await.map_err(StreamingError::Sender)?;
148                                    } else {
149                                        return Err(StreamingError::IdleStreamMissing);
150                                    }
151                                }
152                            },
153                            StreamState::Open => {
154                                if is_end_stream {
155                                    ctx.state = StreamState::LocalClosed;
156                    trace!("stream {id}; open -> local-closed");
157                                }
158                            },
159                            StreamState::RemoteClosed => {
160                                if is_end_stream {
161                                    ctx.state = StreamState::Closed;
162                                    if let Some(tag) = ctx.tag.take() {
163                                        tags.remove(&tag);
164                                    }
165                                    streams.remove(&id);
166                                    debug!("stream {id} closed abnormally (remote -> local)");
167                    trace!("stream {id}; remote-closed -> closed");
168                                }
169                            }
170                            StreamState::LocalClosed | StreamState::Closed => {
171                                warn!("streamming worker; trying to send a client frame from a closed or local closed stream: id={id}, ignored");
172                                continue;
173                            }
174                        }
175                    } else {
176                        warn!("streaming worker; recevied an outdated client frame: {client_frame:?}, ignored");
177                        continue;
178                    }
179                    tx.send(client_frame).await?;
180                }
181                Some(server_frame) = rx.next() => {
182                    let frame = server_frame?;
183            trace!("received a server frame: {frame:?}");
184                    let id = frame.stream_id;
185                    let is_end_stream = frame.is_end_stream();
186                    if let Some(ctx) = streams.get_mut(&id) {
187                        match ctx.state {
188                            StreamState::Idle => {
189                                warn!("streaming worker; recevied a server frame from an idle stream: id={id}, ignored");
190                            },
191                            StreamState::Open => {
192                                if is_end_stream {
193                                    ctx.state = StreamState::RemoteClosed;
194                                    debug!("streaming worker; received a remote close frame: id={id}");
195                    trace!("stream {id}; open -> remote-closed");
196                                }
197                                let _ = ctx.sender.send(Ok(frame)).await;
198                            },
199                            StreamState::LocalClosed => {
200                                if is_end_stream {
201                                    ctx.state = StreamState::Closed;
202                                    let _ = ctx.sender.send(Ok(frame)).await;
203                                    if let Some(tag) = ctx.tag.take() {
204                                        tags.remove(&tag);
205                                    }
206                                    debug!("stream {id} closed normally (local -> remote)");
207                    trace!("stream {id}; local-closed -> closed");
208                                    streams.remove(&id);
209                                } else {
210                                    let _ = ctx.sender.send(Ok(frame)).await;
211                                }
212                            },
213                            StreamState::RemoteClosed | StreamState::Closed => {
214                                warn!("streaming worker; recevied a server frame from a closed or remote closed stream: id={id}, ignored");
215                            }
216                        }
217                    } else {
218                        warn!("streaming worker; received an outdated server frame: {frame:?}, ignored");
219                    }
220                }
221                else => {
222                    break;
223                }
224            }
225        }
226        Result::<(), _>::Err(StreamingError::BlokenStreamingLayer)
227    };
228    let (_cancel, cancel) = oneshot::channel();
229    tokio::spawn(async move {
230        tokio::select! {
231            res = worker => {
232                if let Err(err) = res {
233                    error!("streaming worker: {err}");
234                    let _ = last_server_stream_tx.send(Err(err)).await;
235                    trace!("streaming worker finished");
236                }
237            },
238            _ = cancel => {
239                tracing::trace!("streaming worker cancelled");
240            }
241        }
242    });
243    Streaming {
244        waker,
245        sender,
246        receiver,
247        _cancel,
248    }
249}
250
pin_project! {
    // Handle returned by `layer`: a `Sink` of client streams and a `Stream` of
    // opened server streams / rejection statuses, backed by the spawned worker task.
    struct Streaming<E> {
        // Woken by the `Sink` impl whenever submitting a client stream fails.
        waker: Arc<AtomicWaker>,
        // Submits new client streams to the worker.
        #[pin]
        sender: UnboundedSender<ClientStream>,
        // Yields opened streams, rejections (`Status`) and the worker's terminal error.
        #[pin]
        receiver: UnboundedReceiver<Result<Result<ServerStream, Status>, StreamingError<E>>>,
        // Never read: dropping it resolves the worker's oneshot receiver, which
        // cancels the spawned worker task.
        _cancel: oneshot::Sender<()>,
    }
}
261
262impl<E> Sink<ClientStream> for Streaming<E> {
263    type Error = StreamingError<E>;
264
265    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266        let this = self.project();
267        this.sender.poll_ready(cx).map_err(|err| {
268            this.waker.wake();
269            StreamingError::Sender(err)
270        })
271    }
272
273    fn start_send(self: Pin<&mut Self>, item: ClientStream) -> Result<(), Self::Error> {
274        let this = self.project();
275        this.sender.start_send(item).map_err(|err| {
276            this.waker.wake();
277            StreamingError::Sender(err)
278        })
279    }
280
281    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
282        let this = self.project();
283        this.sender.poll_flush(cx).map_err(|err| {
284            this.waker.wake();
285            StreamingError::Sender(err)
286        })
287    }
288
289    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
290        let this = self.project();
291        this.sender.poll_close(cx).map_err(|err| {
292            this.waker.wake();
293            StreamingError::Sender(err)
294        })
295    }
296}
297
298impl<E> Stream for Streaming<E> {
299    type Item = Result<Result<ServerStream, Status>, StreamingError<E>>;
300
301    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302        let this = self.project();
303        match this.receiver.poll_next(cx) {
304            Poll::Pending => Poll::Pending,
305            Poll::Ready(None) => {
306                trace!("streaming poll stream; stream end.");
307                Poll::Ready(None)
308            }
309            Poll::Ready(Some(Ok(stream))) => Poll::Ready(Some(Ok(stream))),
310            Poll::Ready(Some(Err(err))) => {
311                trace!("streaming poll stream; stream error.");
312                Poll::Ready(Some(Err(err)))
313            }
314        }
315    }
316
317    fn size_hint(&self) -> (usize, Option<usize>) {
318        self.receiver.size_hint()
319    }
320}