asteroid_mq_sdk/connection/
ws2.rs1use asteroid_mq_model::codec::Json;
2use asteroid_mq_model::connection::EdgeNodeConnection;
3use asteroid_mq_model::{
4 codec::{Bincode, Codec},
5 connection::EdgeConnectionError,
6 EdgePayload,
7};
8use futures_util::stream::FusedStream;
9use futures_util::{Sink, Stream};
10use tokio::net::TcpStream;
11use tokio_tungstenite::tungstenite::http::Request;
12use tokio_tungstenite::tungstenite::{client::IntoClientRequest, Message};
13use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
14use tracing::Instrument;
15
16use crate::{ClientNode, ClientNodeError};
17
18use super::auto_reconnect::{ReconnectableConnection, ReconnectableConnectionExt};
19
20pin_project_lite::pin_project! {
21 #[derive(Debug)]
22 pub struct Ws2Client<C = Bincode> {
23 #[pin]
24 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
25 codec: C,
26 request: Request<()>,
27 }
28}
29
30impl<C> Ws2Client<C>
31where
32 C: Codec,
33{
34 pub fn new(
35 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
36 request: Request<()>,
37 codec: C,
38 ) -> Self {
39 Self {
40 inner,
41 codec,
42 request,
43 }
44 }
45 pub fn with_codec<C2>(self, codec: C2) -> Ws2Client<C2>
46 where
47 C2: Codec,
48 {
49 Ws2Client {
50 inner: self.inner,
51 request: self.request,
52 codec,
53 }
54 }
55}
56
57impl<C: Codec> Ws2Client<C> {
58 pub async fn create_by_request<R>(request: R, codec: C) -> Result<Self, EdgeConnectionError>
59 where
60 R: IntoClientRequest + Unpin,
61 {
62 let request = request
63 .into_client_request()
64 .map_err(EdgeConnectionError::underlying("ws create_by_request"))?;
65 let (stream, _resp) = tokio_tungstenite::connect_async(request.clone())
66 .await
67 .map_err(EdgeConnectionError::underlying("ws create_by_request"))?;
68 Ok(Self::new(stream, request, codec))
69 }
70}
71
72impl<C> Stream for Ws2Client<C>
73where
74 C: Codec,
75{
76 type Item = Result<EdgePayload, EdgeConnectionError>;
77 fn poll_next(
78 mut self: std::pin::Pin<&mut Self>,
79 cx: &mut std::task::Context<'_>,
80 ) -> std::task::Poll<Option<Self::Item>> {
81 let this = self.as_mut().project();
82 let message = futures_util::ready!(this.inner.poll_next(cx));
83 let message = match message {
84 Some(Ok(message)) => message,
85 Some(Err(e)) => {
86 use tokio_tungstenite::tungstenite::{error::ProtocolError, Error};
87 match e {
88 Error::ConnectionClosed
89 | Error::AlreadyClosed
90 | Error::Protocol(ProtocolError::ResetWithoutClosingHandshake) => {
91 return std::task::Poll::Ready(None);
92 }
93 _ => {
94 return std::task::Poll::Ready(Some(Err(EdgeConnectionError::underlying(
95 "ws poll_next",
96 )(e))));
97 }
98 }
99 }
100 None => {
101 return std::task::Poll::Ready(None);
102 }
103 };
104 let Message::Binary(payload) = message else {
105 return self.poll_next(cx);
107 };
108
109 let payload = this
110 .codec
111 .decode(&payload)
112 .map_err(EdgeConnectionError::codec("ws poll_next"))?;
113 std::task::Poll::Ready(Some(Ok(payload)))
115 }
116}
117impl<C> Sink<EdgePayload> for Ws2Client<C>
118where
119 C: Codec,
120{
121 type Error = EdgeConnectionError;
122 fn poll_close(
123 self: std::pin::Pin<&mut Self>,
124 cx: &mut std::task::Context<'_>,
125 ) -> std::task::Poll<Result<(), Self::Error>> {
126 let this = self.project();
127 this.inner
128 .poll_close(cx)
129 .map_err(EdgeConnectionError::underlying("ws poll_close"))
130 }
131 fn poll_flush(
132 self: std::pin::Pin<&mut Self>,
133 cx: &mut std::task::Context<'_>,
134 ) -> std::task::Poll<Result<(), Self::Error>> {
135 let this = self.project();
136 this.inner
137 .poll_flush(cx)
138 .map_err(EdgeConnectionError::underlying("ws poll_flush"))
139 }
140 fn poll_ready(
141 self: std::pin::Pin<&mut Self>,
142 cx: &mut std::task::Context<'_>,
143 ) -> std::task::Poll<Result<(), Self::Error>> {
144 let this = self.project();
145 this.inner
146 .poll_ready(cx)
147 .map_err(EdgeConnectionError::underlying("ws poll_ready"))
148 }
149 fn start_send(self: std::pin::Pin<&mut Self>, item: EdgePayload) -> Result<(), Self::Error> {
150 let this = self.project();
152 let payload = this
153 .codec
154 .encode(&item)
155 .map_err(EdgeConnectionError::codec("ws start_send"))?;
156 this.inner
157 .start_send(tokio_tungstenite::tungstenite::Message::Binary(payload))
158 .map_err(EdgeConnectionError::underlying("ws start_send"))?;
159 Ok(())
160 }
161}
162
163impl<C> EdgeNodeConnection for Ws2Client<C> where C: Codec {}
164
165impl ClientNode {
166 pub async fn connect_ws2<R: IntoClientRequest + Unpin>(
167 req: R,
168 codec: impl Codec + Clone,
169 ) -> Result<ClientNode, ClientNodeError> {
170 let client = Ws2Client::create_by_request(req, codec)
171 .await?
172 .auto_reconnect();
173 let node = ClientNode::connect(client).await?;
174 Ok(node)
175 }
176 pub async fn connect_ws2_bincode<R: IntoClientRequest + Unpin>(
177 req: R,
178 ) -> Result<ClientNode, ClientNodeError> {
179 ClientNode::connect_ws2(req, Bincode).await
180 }
181 pub async fn connect_ws2_json<R: IntoClientRequest + Unpin>(
182 req: R,
183 ) -> Result<ClientNode, ClientNodeError> {
184 ClientNode::connect_ws2(req, Json).await
185 }
186}
187
188impl<C: Codec + Clone> ReconnectableConnection for Ws2Client<C> {
189 type ReconnectFuture = std::pin::Pin<
190 Box<dyn std::future::Future<Output = Result<Self, EdgeConnectionError>> + Send>,
191 >;
192 type SleepFuture = tokio::time::Sleep;
193 fn is_closed(&self) -> bool {
194 self.inner.is_terminated()
195 }
196 fn reconnect(&self) -> Self::ReconnectFuture {
197 let request = self.request.clone();
198 let codec = self.codec.clone();
199 let span = tracing::span!(
200 tracing::Level::INFO,
201 "ws2_reconnect",
202 request = ?request.uri(),
203 );
204 Box::pin(
205 async move {
206 tracing::info!("ws2 connection reconnecting");
207 let client = Ws2Client::create_by_request(request, codec).await?;
208 Ok(client)
209 }
210 .instrument(span),
211 )
212 }
213 fn sleep(&self, duration: std::time::Duration) -> Self::SleepFuture {
214 tokio::time::sleep(duration)
215 }
216}