cloudpub_common/transport/
websocket.rs

1use core::result::Result;
2use std::io::{Error, ErrorKind};
3#[cfg(unix)]
4use std::os::fd::AsRawFd;
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
8use super::{
9    AddrMaybeCached, Listener, NamedSocketAddr, ProtobufStream, SocketAddr, SocketOpts, Stream,
10    TcpTransport, Transport,
11};
12use crate::config::TransportConfig;
13use crate::constants::MESSAGE_TIMEOUT_SECS;
14use anyhow::{anyhow, Context as _};
15use async_trait::async_trait;
16use bytes::{Bytes, BytesMut};
17use futures_core::stream::Stream as AsyncStream;
18use std::time::Duration;
19use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
20use tokio::time::timeout;
21
22use crate::utils::trace_message;
23use dashmap::DashMap;
24#[cfg(unix)]
25use std::os::fd::RawFd;
26use std::sync::Arc;
27use tokio_tungstenite::tungstenite::client::IntoClientRequest;
28use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
29use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
30use tokio_tungstenite::{accept_hdr_async_with_config, client_async_with_config, WebSocketStream};
31use tokio_util::io::StreamReader;
32use tracing::{debug, error, trace};
33
34use futures_util::sink::{Sink, SinkExt};
35use futures_util::stream::StreamExt;
36
37#[cfg(feature = "rustls")]
38use super::tls::{get_stream, TlsStream, TlsTransport};
39
40use crate::protocol::message::Message as ProtocolMessage;
41use crate::protocol::ProstMessage;
42
43#[derive(Debug)]
44enum TransportStream {
45    Insecure(Stream),
46    #[cfg(feature = "rustls")]
47    Secure(Box<TlsStream<Stream>>),
48}
49
50impl TransportStream {
51    fn get_tcpstream(&self) -> &Stream {
52        match self {
53            TransportStream::Insecure(s) => s,
54            #[cfg(feature = "rustls")]
55            TransportStream::Secure(s) => get_stream(s.as_ref()),
56        }
57    }
58}
59
60impl AsyncRead for TransportStream {
61    fn poll_read(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        buf: &mut ReadBuf<'_>,
65    ) -> Poll<std::io::Result<()>> {
66        match self.get_mut() {
67            TransportStream::Insecure(s) => Pin::new(s).poll_read(cx, buf),
68            #[cfg(feature = "rustls")]
69            TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
70        }
71    }
72}
73
74impl AsyncWrite for TransportStream {
75    fn poll_write(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        buf: &[u8],
79    ) -> Poll<Result<usize, std::io::Error>> {
80        match self.get_mut() {
81            TransportStream::Insecure(s) => Pin::new(s).poll_write(cx, buf),
82            #[cfg(feature = "rustls")]
83            TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
84        }
85    }
86
87    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
88        match self.get_mut() {
89            TransportStream::Insecure(s) => Pin::new(s).poll_flush(cx),
90            #[cfg(feature = "rustls")]
91            TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_flush(cx),
92        }
93    }
94
95    fn poll_shutdown(
96        self: Pin<&mut Self>,
97        cx: &mut Context<'_>,
98    ) -> Poll<Result<(), std::io::Error>> {
99        match self.get_mut() {
100            TransportStream::Insecure(s) => Pin::new(s).poll_shutdown(cx),
101            #[cfg(feature = "rustls")]
102            TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
103        }
104    }
105}
106
107#[derive(Debug)]
108struct StreamWrapper {
109    inner: WebSocketStream<TransportStream>,
110}
111
112#[async_trait]
113impl ProtobufStream for StreamWrapper {
114    async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
115        let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
116        let result = timeout(timeout_duration, self.inner.next()).await;
117
118        match result {
119            Err(_) => Err(anyhow!(
120                "Timeout reading message after {} seconds",
121                MESSAGE_TIMEOUT_SECS
122            )),
123            Ok(msg_result) => match msg_result {
124                Some(Ok(Message::Binary(b))) => {
125                    let msg = crate::protocol::Message::decode(b.as_ref())
126                        .context("Failed to decode protobuf message")?;
127                    let msg = msg
128                        .message
129                        .context("Message field is missing in the protobuf message")?;
130                    trace_message("Recv", &msg);
131                    Ok(Some(msg))
132                }
133                Some(Ok(Message::Close(_))) => {
134                    debug!("WebSocket connection closed");
135                    Ok(None)
136                }
137                Some(Ok(Message::Ping(data))) => {
138                    debug!("Received ping, sending pong");
139                    self.inner
140                        .send(Message::Pong(data))
141                        .await
142                        .context("Failed to send pong")?;
143                    Ok(None)
144                }
145                Some(Ok(Message::Pong(_))) => {
146                    debug!("Received pong");
147                    Ok(None)
148                }
149                Some(Ok(Message::Text(_))) => {
150                    error!("Received unexpected text message");
151                    Err(anyhow!("Unexpected text message received"))
152                }
153                Some(Ok(m)) => {
154                    error!("Received unexpected  message: {:?}", m);
155                    Err(anyhow!("Unexpected  message received"))
156                }
157                None => Ok(None),
158                Some(Err(e)) => Err(anyhow!("WebSocket error: {}", e)),
159            },
160        }
161    }
162
163    async fn send_message(
164        &mut self,
165        msg: &crate::protocol::message::Message,
166    ) -> anyhow::Result<()> {
167        trace_message("Send", msg);
168        let mut buf = BytesMut::new();
169        let msg = crate::protocol::Message {
170            message: Some(msg.clone()),
171        };
172        msg.encode(&mut buf)
173            .context("Failed to encode protobuf message")?;
174
175        let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
176        timeout(
177            timeout_duration,
178            self.inner.send(Message::Binary(buf.into())),
179        )
180        .await
181        .map_err(|_| {
182            anyhow::anyhow!(
183                "Timeout writing message after {} seconds",
184                MESSAGE_TIMEOUT_SECS
185            )
186        })?
187        .context("Failed to send WebSocket message")?;
188        Ok(())
189    }
190
191    async fn close(&mut self) -> anyhow::Result<()> {
192        debug!("Closing WebSocket connection");
193        self.inner
194            .close(None)
195            .await
196            .context("Failed to close WebSocket stream")
197    }
198}
199
200impl AsyncStream for StreamWrapper {
201    type Item = Result<Bytes, Error>;
202
203    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204        match Pin::new(&mut self.get_mut().inner).poll_next(cx) {
205            Poll::Pending => Poll::Pending,
206            Poll::Ready(None) => Poll::Ready(None),
207            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::other(err)))),
208            Poll::Ready(Some(Ok(res))) => {
209                if let Message::Binary(b) = res {
210                    Poll::Ready(Some(Ok(b)))
211                } else {
212                    Poll::Ready(Some(Err(Error::new(
213                        ErrorKind::InvalidData,
214                        "unexpected frame",
215                    ))))
216                }
217            }
218        }
219    }
220
221    fn size_hint(&self) -> (usize, Option<usize>) {
222        self.inner.size_hint()
223    }
224}
225
226#[derive(Debug)]
227pub struct WebsocketStream {
228    inner: StreamReader<StreamWrapper, Bytes>,
229}
230
231impl AsyncRead for WebsocketStream {
232    fn poll_read(
233        self: Pin<&mut Self>,
234        cx: &mut Context<'_>,
235        buf: &mut ReadBuf<'_>,
236    ) -> Poll<std::io::Result<()>> {
237        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
238    }
239}
240
241impl AsyncBufRead for WebsocketStream {
242    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
243        Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx)
244    }
245
246    fn consume(self: Pin<&mut Self>, amt: usize) {
247        Pin::new(&mut self.get_mut().inner).consume(amt)
248    }
249}
250
251impl AsyncWrite for WebsocketStream {
252    fn poll_write(
253        self: Pin<&mut Self>,
254        cx: &mut Context<'_>,
255        buf: &[u8],
256    ) -> Poll<Result<usize, std::io::Error>> {
257        let sw = self.get_mut().inner.get_mut();
258        ready!(Pin::new(&mut sw.inner).poll_ready(cx).map_err(Error::other))?;
259
260        let bbuf = BytesMut::from(buf);
261
262        match Pin::new(&mut sw.inner).start_send(Message::Binary(bbuf.into())) {
263            Ok(()) => Poll::Ready(Ok(buf.len())),
264            Err(e) => Poll::Ready(Err(Error::other(e))),
265        }
266    }
267
268    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
269        Pin::new(&mut self.get_mut().inner.get_mut().inner)
270            .poll_flush(cx)
271            .map_err(Error::other)
272    }
273
274    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
275        Pin::new(&mut self.get_mut().inner.get_mut().inner)
276            .poll_close(cx)
277            .map_err(Error::other)
278    }
279}
280
281#[async_trait]
282impl ProtobufStream for WebsocketStream {
283    async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
284        self.inner.get_mut().recv_message().await
285    }
286
287    async fn send_message(&mut self, msg: &ProtocolMessage) -> anyhow::Result<()> {
288        self.inner.get_mut().send_message(msg).await
289    }
290
291    async fn close(&mut self) -> anyhow::Result<()> {
292        debug!("Closing WebSocket stream");
293        self.inner.get_mut().close().await
294    }
295}
296
297#[derive(Debug)]
298enum SubTransport {
299    #[cfg(feature = "rustls")]
300    Secure(TlsTransport),
301    Insecure(TcpTransport),
302}
303
304#[derive(Debug)]
305pub struct WebsocketTransport {
306    sub: SubTransport,
307    conf: WebSocketConfig,
308    headers: Arc<DashMap<String, String>>,
309}
310
311#[async_trait]
312impl Transport for WebsocketTransport {
313    type Acceptor = Listener;
314    type RawStream = Stream;
315    type Stream = WebsocketStream;
316
317    fn new(config: &TransportConfig) -> anyhow::Result<Self> {
318        let wsconfig = config
319            .websocket
320            .as_ref()
321            .ok_or_else(|| anyhow!("Missing websocket config"))?;
322
323        let conf = WebSocketConfig::default().write_buffer_size(0);
324
325        let sub = match wsconfig.tls {
326            #[cfg(feature = "rustls")]
327            true => SubTransport::Secure(TlsTransport::new(config)?),
328            #[cfg(not(feature = "rustls"))]
329            true => unreachable!("TLS support not enabled"),
330            false => SubTransport::Insecure(TcpTransport::new(config)?),
331        };
332        let headers = Arc::new(DashMap::new());
333        Ok(WebsocketTransport { sub, conf, headers })
334    }
335
336    fn hint(conn: &Self::Stream, opt: SocketOpts) {
337        opt.apply(conn.inner.get_ref().inner.get_ref().get_tcpstream())
338    }
339
340    #[cfg(unix)]
341    fn as_raw_fd(conn: &Self::Stream) -> RawFd {
342        match conn.inner.get_ref().inner.get_ref().get_tcpstream() {
343            Stream::Tcp(tcp_stream) => tcp_stream.as_raw_fd(),
344            Stream::Unix(unix_stream) => unix_stream.as_raw_fd(),
345        }
346    }
347
348    async fn bind(&self, addr: NamedSocketAddr) -> anyhow::Result<Self::Acceptor> {
349        Listener::bind(&addr).await.map_err(Into::into)
350    }
351
352    async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
353        let (s, addr) = match &self.sub {
354            SubTransport::Insecure(t) => t.accept(a).await?,
355            #[cfg(feature = "rustls")]
356            SubTransport::Secure(t) => t.accept(a).await?,
357        };
358        Ok((s, addr))
359    }
360
361    async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
362        let tsream = match &self.sub {
363            SubTransport::Insecure(t) => {
364                TransportStream::Insecure(t.handshake(conn).await?.into_stream())
365            }
366            #[cfg(feature = "rustls")]
367            SubTransport::Secure(t) => TransportStream::Secure(Box::new(t.handshake(conn).await?)),
368        };
369
370        let headers = self.headers.clone();
371
372        let callback = move |req: &Request, res: Response| {
373            for ref header in req.headers() {
374                trace!("WS headers: {:?}", header);
375                headers.insert(
376                    header.0.to_string(),
377                    header.1.to_str().unwrap_or_default().to_string(),
378                );
379            }
380            Ok(res)
381        };
382
383        let wsstream = accept_hdr_async_with_config(tsream, callback, Some(self.conf)).await?;
384
385        let tun = WebsocketStream {
386            inner: StreamReader::new(StreamWrapper { inner: wsstream }),
387        };
388        Ok(tun)
389    }
390
391    async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
392        let u = format!("wss://{}/endpoint/v3", &addr.addr.as_str());
393        let tstream = match &self.sub {
394            SubTransport::Insecure(t) => {
395                TransportStream::Insecure(t.connect(addr).await?.into_stream())
396            }
397            #[cfg(feature = "rustls")]
398            SubTransport::Secure(t) => TransportStream::Secure(Box::new(t.connect(addr).await?)),
399        };
400        debug!("Connecting to {}", u);
401        let (wsstream, _) = client_async_with_config(
402            u.clone()
403                .into_client_request()
404                .context("Failed to create client request")?,
405            tstream,
406            Some(self.conf),
407        )
408        .await
409        .with_context(|| format!("Failed to connect to {}", u))?;
410
411        debug!("Connected");
412
413        let tun = WebsocketStream {
414            inner: StreamReader::new(StreamWrapper { inner: wsstream }),
415        };
416        Ok(tun)
417    }
418
419    fn get_header(&self, name: &str) -> Option<String> {
420        self.headers.get(&name.to_lowercase()).map(|v| v.clone())
421    }
422}