asteroid_mq_sdk/connection/
ws2.rs

1use asteroid_mq_model::codec::Json;
2use asteroid_mq_model::connection::EdgeNodeConnection;
3use asteroid_mq_model::{
4    codec::{Bincode, Codec},
5    connection::EdgeConnectionError,
6    EdgePayload,
7};
8use futures_util::stream::FusedStream;
9use futures_util::{Sink, Stream};
10use tokio::net::TcpStream;
11use tokio_tungstenite::tungstenite::http::Request;
12use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message};
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
14use tracing::Instrument;
15
16use crate::{ClientNode, ClientNodeError};
17
18use super::auto_reconnect::{ReconnectableConnection, ReconnectableConnectionExt};
19
pin_project_lite::pin_project! {
    /// WebSocket-backed edge connection that exchanges [`EdgePayload`] values
    /// encoded with the codec `C` (defaults to [`Bincode`]).
    ///
    /// The type is both a `Stream` of incoming payloads and a `Sink` for
    /// outgoing payloads; see the trait impls in this file.
    #[derive(Debug)]
    pub struct Ws2Client<C = Bincode> {
        // Underlying websocket stream; pinned so its `Stream`/`Sink` poll
        // methods can be called through the projection.
        #[pin]
        inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
        // Codec used to encode outgoing payloads and decode incoming binary frames.
        codec: C,
        // Handshake request that opened this connection; cloned again by
        // `reconnect` to open a replacement connection.
        request: Request<()>,
    }
}
29
30impl<C> Ws2Client<C>
31where
32    C: Codec,
33{
34    pub fn new(
35        inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
36        request: Request<()>,
37        codec: C,
38    ) -> Self {
39        Self {
40            inner,
41            codec,
42            request,
43        }
44    }
45    pub fn with_codec<C2>(self, codec: C2) -> Ws2Client<C2>
46    where
47        C2: Codec,
48    {
49        Ws2Client {
50            inner: self.inner,
51            request: self.request,
52            codec,
53        }
54    }
55}
56
57impl<C: Codec> Ws2Client<C> {
58    pub async fn create_by_request<R>(request: R, codec: C) -> Result<Self, EdgeConnectionError>
59    where
60        R: IntoClientRequest + Unpin,
61    {
62        let request = request
63            .into_client_request()
64            .map_err(EdgeConnectionError::underlying("ws create_by_request"))?;
65        let (stream, _resp) = tokio_tungstenite::connect_async(request.clone())
66            .await
67            .map_err(EdgeConnectionError::underlying("ws create_by_request"))?;
68        Ok(Self::new(stream, request, codec))
69    }
70}
71
72impl<C> Stream for Ws2Client<C>
73where
74    C: Codec,
75{
76    type Item = Result<EdgePayload, EdgeConnectionError>;
77    fn poll_next(
78        mut self: std::pin::Pin<&mut Self>,
79        cx: &mut std::task::Context<'_>,
80    ) -> std::task::Poll<Option<Self::Item>> {
81        let this = self.as_mut().project();
82        let message = futures_util::ready!(this.inner.poll_next(cx));
83        let message = match message {
84            Some(Ok(message)) => message,
85            Some(Err(e)) => {
86                use tokio_tungstenite::tungstenite::{error::ProtocolError, Error};
87                match e {
88                    Error::ConnectionClosed
89                    | Error::AlreadyClosed
90                    | Error::Protocol(ProtocolError::ResetWithoutClosingHandshake) => {
91                        return std::task::Poll::Ready(None);
92                    }
93                    _ => {
94                        return std::task::Poll::Ready(Some(Err(EdgeConnectionError::underlying(
95                            "ws poll_next",
96                        )(e))));
97                    }
98                }
99            }
100            None => {
101                return std::task::Poll::Ready(None);
102            }
103        };
104        let Message::Binary(payload) = message else {
105            // skip
106            return self.poll_next(cx);
107        };
108
109        let payload = this
110            .codec
111            .decode(&payload)
112            .map_err(EdgeConnectionError::codec("ws poll_next"))?;
113        // tracing::error!(?payload, "[debug] got payload");
114        std::task::Poll::Ready(Some(Ok(payload)))
115    }
116}
117impl<C> Sink<EdgePayload> for Ws2Client<C>
118where
119    C: Codec,
120{
121    type Error = EdgeConnectionError;
122    fn poll_close(
123        self: std::pin::Pin<&mut Self>,
124        cx: &mut std::task::Context<'_>,
125    ) -> std::task::Poll<Result<(), Self::Error>> {
126        let this = self.project();
127        this.inner
128            .poll_close(cx)
129            .map_err(EdgeConnectionError::underlying("ws poll_close"))
130    }
131    fn poll_flush(
132        self: std::pin::Pin<&mut Self>,
133        cx: &mut std::task::Context<'_>,
134    ) -> std::task::Poll<Result<(), Self::Error>> {
135        let this = self.project();
136        this.inner
137            .poll_flush(cx)
138            .map_err(EdgeConnectionError::underlying("ws poll_flush"))
139    }
140    fn poll_ready(
141        self: std::pin::Pin<&mut Self>,
142        cx: &mut std::task::Context<'_>,
143    ) -> std::task::Poll<Result<(), Self::Error>> {
144        let this = self.project();
145        this.inner
146            .poll_ready(cx)
147            .map_err(EdgeConnectionError::underlying("ws poll_ready"))
148    }
149    fn start_send(self: std::pin::Pin<&mut Self>, item: EdgePayload) -> Result<(), Self::Error> {
150        // tracing::warn!(?item, "[debug] ws payload do send");
151        let this = self.project();
152        let payload = this
153            .codec
154            .encode(&item)
155            .map_err(EdgeConnectionError::codec("ws start_send"))?;
156        this.inner
157            .start_send(tokio_tungstenite::tungstenite::Message::Binary(payload))
158            .map_err(EdgeConnectionError::underlying("ws start_send"))?;
159        Ok(())
160    }
161}
162
163impl<C> EdgeNodeConnection for Ws2Client<C> where C: Codec {}
164
165impl ClientNode {
166    pub async fn connect_ws2<R: IntoClientRequest + Unpin>(
167        req: R,
168        codec: impl Codec + Clone,
169    ) -> Result<ClientNode, ClientNodeError> {
170        let client = Ws2Client::create_by_request(req, codec)
171            .await?
172            .auto_reconnect();
173        let node = ClientNode::connect(client).await?;
174        Ok(node)
175    }
176    pub async fn connect_ws2_bincode<R: IntoClientRequest + Unpin>(
177        req: R,
178    ) -> Result<ClientNode, ClientNodeError> {
179        ClientNode::connect_ws2(req, Bincode).await
180    }
181    pub async fn connect_ws2_json<R: IntoClientRequest + Unpin>(
182        req: R,
183    ) -> Result<ClientNode, ClientNodeError> {
184        ClientNode::connect_ws2(req, Json).await
185    }
186}
187
188impl<C: Codec + Clone> ReconnectableConnection for Ws2Client<C> {
189    type ReconnectFuture = std::pin::Pin<
190        Box<dyn std::future::Future<Output = Result<Self, EdgeConnectionError>> + Send>,
191    >;
192    type SleepFuture = tokio::time::Sleep;
193    fn is_closed(&self) -> bool {
194        self.inner.is_terminated()
195    }
196    fn reconnect(&self) -> Self::ReconnectFuture {
197        let request = self.request.clone();
198        let codec = self.codec.clone();
199        let span = tracing::span!(
200            tracing::Level::INFO,
201            "ws2_reconnect",
202            request = ?request.uri(),
203        );
204        Box::pin(
205            async move {
206                tracing::info!("ws2 connection reconnecting");
207                let client = Ws2Client::create_by_request(request, codec).await?;
208                Ok(client)
209            }
210            .instrument(span),
211        )
212    }
213    fn sleep(&self, duration: std::time::Duration) -> Self::SleepFuture {
214        tokio::time::sleep(duration)
215    }
216}