exc_okx/websocket/transport/protocol/
stream.rs1use crate::error::OkxError;
2use crate::websocket::types::callback::Callback;
3use crate::websocket::types::frames::client::ClientFrame;
4use crate::websocket::types::frames::server::ServerFrame;
5use crate::websocket::types::request::ClientStream;
6use crate::websocket::types::response::Status;
7use crate::websocket::types::response::{ServerStream, StatusKind};
8use atomic_waker::AtomicWaker;
9use futures::channel::mpsc::{self, SendError, UnboundedReceiver, UnboundedSender};
10use futures::SinkExt;
11use futures::{Sink, Stream, StreamExt};
12use pin_project_lite::pin_project;
13use std::collections::hash_map::RandomState;
14use std::collections::{BTreeMap, HashSet};
15use std::pin::Pin;
16use std::sync::Arc;
17use std::task::{Context, Poll};
18use thiserror::Error;
19use tokio::sync::oneshot;
20
/// Lifecycle of one multiplexed stream, as tracked by the streaming worker.
///
/// A stream starts `Idle`, becomes `Open` on its first non-end client frame, moves to
/// `LocalClosed` or `RemoteClosed` when one side sends its end-of-stream frame, and ends up
/// `Closed` once both sides are done (or when it is rejected / closed while still idle).
#[derive(Clone, Copy, Debug)]
enum StreamState {
    /// Registered, but no client frame has been accepted yet.
    Idle,
    /// Both sides may still exchange frames.
    Open,
    /// The client sent its end-of-stream frame; server frames are still delivered.
    LocalClosed,
    /// The server sent its end-of-stream frame; the client may still send.
    RemoteClosed,
    /// Finished; no further frames are accepted in either direction.
    Closed,
}

30struct StreamContext {
31 sender: UnboundedSender<Result<ServerFrame, OkxError>>,
32 stream: Option<ServerStream>,
33 state: StreamState,
34 tag: Option<String>,
35}
36
37impl StreamContext {
38 fn new(id: usize, cb: Callback) -> Self {
39 let (server_frame_tx, server_frame_rx) = mpsc::unbounded();
40 let stream = ServerStream {
41 id,
42 cb,
43 inner: server_frame_rx.boxed(),
44 };
45 Self {
46 sender: server_frame_tx,
47 stream: Some(stream),
48 state: StreamState::Idle,
49 tag: None,
50 }
51 }
52}
53
54impl Drop for StreamContext {
55 fn drop(&mut self) {
56 let _fut = self.sender.send(Err(OkxError::StreamDropped));
57 }
58}
59
60#[derive(Debug, Error)]
62pub enum StreamingError<E> {
63 #[error(transparent)]
65 Transport(#[from] E),
66
67 #[error(transparent)]
69 Sender(SendError),
70
71 #[error("idle stream missing")]
73 IdleStreamMissing,
74
75 #[error("broken streaming layer")]
77 BlokenStreamingLayer,
78}
79
80pub(super) fn layer<T, E>(
81 transport: T,
82 waker: Arc<AtomicWaker>,
83) -> impl Sink<ClientStream, Error = StreamingError<E>>
84 + Stream<Item = Result<Result<ServerStream, Status>, StreamingError<E>>>
85where
86 E: Send + 'static + std::fmt::Display,
87 T: Send + 'static,
88 T: Sink<ClientFrame, Error = E>,
89 T: Stream<Item = Result<ServerFrame, E>>,
90{
91 let (mut tx, mut rx) = transport.split();
92 let (client_frame_tx, mut client_frame_rx) = mpsc::unbounded::<ClientFrame>();
93 let (sender, mut client_stream_rx) = mpsc::unbounded::<ClientStream>();
94 let (mut server_stream_tx, receiver) = mpsc::unbounded();
95 let mut streams: BTreeMap<usize, StreamContext> = BTreeMap::default();
96 let mut last_server_stream_tx = server_stream_tx.clone();
97 let mut tags = HashSet::<String, RandomState>::new();
98 let worker = async move {
99 loop {
100 tokio::select! {
101 Some(mut client_stream) = client_stream_rx.next() => {
102 let cb = client_stream.cb.take().expect("client stream must contains a callback");
103 let id = client_stream.id;
104 let ctx = StreamContext::new(id, cb);
105 streams.insert(id, ctx);
106 let mut client_frame_tx = client_frame_tx.clone();
107 tokio::spawn(async move {
108 while let Some(mut frame) = client_stream.inner.next().await {
109 frame.stream_id = id;
110 if let Err(err) = client_frame_tx.send(frame).await {
111 error!("streaming client worker; send error id={id} err={err}");
112 break;
113 }
114 }
115 });
116 }
117 Some(client_frame) = client_frame_rx.next() => {
118 let id = client_frame.stream_id;
119 if let Some(ctx) = streams.get_mut(&id) {
120 let is_end_stream = client_frame.is_end_stream();
121 match ctx.state {
122 StreamState::Idle => {
123 if is_end_stream {
124 ctx.state = StreamState::Closed;
125 server_stream_tx.send(Ok(Err(Status { stream_id: id, kind: StatusKind::CloseIdleStream }))).await.map_err(StreamingError::Sender)?;
126 streams.remove(&id);
127 trace!("stream {id}; idle -> closed");
128 continue;
130 } else {
131 if let Some(tag) = client_frame.tag() {
133 if tags.contains(&tag) {
134 server_stream_tx.send(Ok(Err(Status { stream_id: id, kind: StatusKind::AlreadySubscribed(tag) }))).await.map_err(StreamingError::Sender)?;
135 ctx.state = StreamState::Closed;
136 streams.remove(&id);
137 continue;
139 } else {
140 tags.insert(tag.clone());
141 ctx.tag = Some(tag);
142 }
143 }
144 ctx.state = StreamState::Open;
145 trace!("stream {id}; idle -> open");
146 if let Some(stream) = ctx.stream.take() {
147 server_stream_tx.send(Ok(Ok(stream))).await.map_err(StreamingError::Sender)?;
148 } else {
149 return Err(StreamingError::IdleStreamMissing);
150 }
151 }
152 },
153 StreamState::Open => {
154 if is_end_stream {
155 ctx.state = StreamState::LocalClosed;
156 trace!("stream {id}; open -> local-closed");
157 }
158 },
159 StreamState::RemoteClosed => {
160 if is_end_stream {
161 ctx.state = StreamState::Closed;
162 if let Some(tag) = ctx.tag.take() {
163 tags.remove(&tag);
164 }
165 streams.remove(&id);
166 debug!("stream {id} closed abnormally (remote -> local)");
167 trace!("stream {id}; remote-closed -> closed");
168 }
169 }
170 StreamState::LocalClosed | StreamState::Closed => {
171 warn!("streamming worker; trying to send a client frame from a closed or local closed stream: id={id}, ignored");
172 continue;
173 }
174 }
175 } else {
176 warn!("streaming worker; recevied an outdated client frame: {client_frame:?}, ignored");
177 continue;
178 }
179 tx.send(client_frame).await?;
180 }
181 Some(server_frame) = rx.next() => {
182 let frame = server_frame?;
183 trace!("received a server frame: {frame:?}");
184 let id = frame.stream_id;
185 let is_end_stream = frame.is_end_stream();
186 if let Some(ctx) = streams.get_mut(&id) {
187 match ctx.state {
188 StreamState::Idle => {
189 warn!("streaming worker; recevied a server frame from an idle stream: id={id}, ignored");
190 },
191 StreamState::Open => {
192 if is_end_stream {
193 ctx.state = StreamState::RemoteClosed;
194 debug!("streaming worker; received a remote close frame: id={id}");
195 trace!("stream {id}; open -> remote-closed");
196 }
197 let _ = ctx.sender.send(Ok(frame)).await;
198 },
199 StreamState::LocalClosed => {
200 if is_end_stream {
201 ctx.state = StreamState::Closed;
202 let _ = ctx.sender.send(Ok(frame)).await;
203 if let Some(tag) = ctx.tag.take() {
204 tags.remove(&tag);
205 }
206 debug!("stream {id} closed normally (local -> remote)");
207 trace!("stream {id}; local-closed -> closed");
208 streams.remove(&id);
209 } else {
210 let _ = ctx.sender.send(Ok(frame)).await;
211 }
212 },
213 StreamState::RemoteClosed | StreamState::Closed => {
214 warn!("streaming worker; recevied a server frame from a closed or remote closed stream: id={id}, ignored");
215 }
216 }
217 } else {
218 warn!("streaming worker; received an outdated server frame: {frame:?}, ignored");
219 }
220 }
221 else => {
222 break;
223 }
224 }
225 }
226 Result::<(), _>::Err(StreamingError::BlokenStreamingLayer)
227 };
228 let (_cancel, cancel) = oneshot::channel();
229 tokio::spawn(async move {
230 tokio::select! {
231 res = worker => {
232 if let Err(err) = res {
233 error!("streaming worker: {err}");
234 let _ = last_server_stream_tx.send(Err(err)).await;
235 trace!("streaming worker finished");
236 }
237 },
238 _ = cancel => {
239 tracing::trace!("streaming worker cancelled");
240 }
241 }
242 });
243 Streaming {
244 waker,
245 sender,
246 receiver,
247 _cancel,
248 }
249}
250
251pin_project! {
252 struct Streaming<E> {
253 waker: Arc<AtomicWaker>,
254 #[pin]
255 sender: UnboundedSender<ClientStream>,
256 #[pin]
257 receiver: UnboundedReceiver<Result<Result<ServerStream, Status>, StreamingError<E>>>,
258 _cancel: oneshot::Sender<()>,
259 }
260}
261
262impl<E> Sink<ClientStream> for Streaming<E> {
263 type Error = StreamingError<E>;
264
265 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
266 let this = self.project();
267 this.sender.poll_ready(cx).map_err(|err| {
268 this.waker.wake();
269 StreamingError::Sender(err)
270 })
271 }
272
273 fn start_send(self: Pin<&mut Self>, item: ClientStream) -> Result<(), Self::Error> {
274 let this = self.project();
275 this.sender.start_send(item).map_err(|err| {
276 this.waker.wake();
277 StreamingError::Sender(err)
278 })
279 }
280
281 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
282 let this = self.project();
283 this.sender.poll_flush(cx).map_err(|err| {
284 this.waker.wake();
285 StreamingError::Sender(err)
286 })
287 }
288
289 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
290 let this = self.project();
291 this.sender.poll_close(cx).map_err(|err| {
292 this.waker.wake();
293 StreamingError::Sender(err)
294 })
295 }
296}
297
298impl<E> Stream for Streaming<E> {
299 type Item = Result<Result<ServerStream, Status>, StreamingError<E>>;
300
301 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
302 let this = self.project();
303 match this.receiver.poll_next(cx) {
304 Poll::Pending => Poll::Pending,
305 Poll::Ready(None) => {
306 trace!("streaming poll stream; stream end.");
307 Poll::Ready(None)
308 }
309 Poll::Ready(Some(Ok(stream))) => Poll::Ready(Some(Ok(stream))),
310 Poll::Ready(Some(Err(err))) => {
311 trace!("streaming poll stream; stream error.");
312 Poll::Ready(Some(Err(err)))
313 }
314 }
315 }
316
317 fn size_hint(&self) -> (usize, Option<usize>) {
318 self.receiver.size_hint()
319 }
320}