cloudpub_common/transport/
websocket.rs1use core::result::Result;
2use std::io::{Error, ErrorKind};
3#[cfg(unix)]
4use std::os::fd::AsRawFd;
5use std::pin::Pin;
6use std::task::{ready, Context, Poll};
7
8use super::{
9 AddrMaybeCached, Listener, NamedSocketAddr, ProtobufStream, SocketAddr, SocketOpts, Stream,
10 TcpTransport, Transport,
11};
12use crate::config::TransportConfig;
13use crate::constants::MESSAGE_TIMEOUT_SECS;
14use anyhow::{anyhow, Context as _};
15use async_trait::async_trait;
16use bytes::{Bytes, BytesMut};
17use futures_core::stream::Stream as AsyncStream;
18use std::time::Duration;
19use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
20use tokio::time::timeout;
21
22use crate::utils::trace_message;
23use dashmap::DashMap;
24#[cfg(unix)]
25use std::os::fd::RawFd;
26use std::sync::Arc;
27use tokio_tungstenite::tungstenite::client::IntoClientRequest;
28use tokio_tungstenite::tungstenite::handshake::server::{Request, Response};
29use tokio_tungstenite::tungstenite::protocol::{Message, WebSocketConfig};
30use tokio_tungstenite::{accept_hdr_async_with_config, client_async_with_config, WebSocketStream};
31use tokio_util::io::StreamReader;
32use tracing::{debug, error, trace};
33
34use futures_util::sink::{Sink, SinkExt};
35use futures_util::stream::StreamExt;
36
37#[cfg(feature = "rustls")]
38use super::tls::{get_stream, TlsStream, TlsTransport};
39
40use crate::protocol::message::Message as ProtocolMessage;
41use crate::protocol::ProstMessage;
42
43#[derive(Debug)]
44enum TransportStream {
45 Insecure(Stream),
46 #[cfg(feature = "rustls")]
47 Secure(Box<TlsStream<Stream>>),
48}
49
50impl TransportStream {
51 fn get_tcpstream(&self) -> &Stream {
52 match self {
53 TransportStream::Insecure(s) => s,
54 #[cfg(feature = "rustls")]
55 TransportStream::Secure(s) => get_stream(s.as_ref()),
56 }
57 }
58}
59
60impl AsyncRead for TransportStream {
61 fn poll_read(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 buf: &mut ReadBuf<'_>,
65 ) -> Poll<std::io::Result<()>> {
66 match self.get_mut() {
67 TransportStream::Insecure(s) => Pin::new(s).poll_read(cx, buf),
68 #[cfg(feature = "rustls")]
69 TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
70 }
71 }
72}
73
74impl AsyncWrite for TransportStream {
75 fn poll_write(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 buf: &[u8],
79 ) -> Poll<Result<usize, std::io::Error>> {
80 match self.get_mut() {
81 TransportStream::Insecure(s) => Pin::new(s).poll_write(cx, buf),
82 #[cfg(feature = "rustls")]
83 TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
84 }
85 }
86
87 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
88 match self.get_mut() {
89 TransportStream::Insecure(s) => Pin::new(s).poll_flush(cx),
90 #[cfg(feature = "rustls")]
91 TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_flush(cx),
92 }
93 }
94
95 fn poll_shutdown(
96 self: Pin<&mut Self>,
97 cx: &mut Context<'_>,
98 ) -> Poll<Result<(), std::io::Error>> {
99 match self.get_mut() {
100 TransportStream::Insecure(s) => Pin::new(s).poll_shutdown(cx),
101 #[cfg(feature = "rustls")]
102 TransportStream::Secure(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
103 }
104 }
105}
106
107#[derive(Debug)]
108struct StreamWrapper {
109 inner: WebSocketStream<TransportStream>,
110}
111
112#[async_trait]
113impl ProtobufStream for StreamWrapper {
114 async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
115 let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
116 let result = timeout(timeout_duration, self.inner.next()).await;
117
118 match result {
119 Err(_) => Err(anyhow!(
120 "Timeout reading message after {} seconds",
121 MESSAGE_TIMEOUT_SECS
122 )),
123 Ok(msg_result) => match msg_result {
124 Some(Ok(Message::Binary(b))) => {
125 let msg = crate::protocol::Message::decode(b.as_ref())
126 .context("Failed to decode protobuf message")?;
127 let msg = msg
128 .message
129 .context("Message field is missing in the protobuf message")?;
130 trace_message("Recv", &msg);
131 Ok(Some(msg))
132 }
133 Some(Ok(Message::Close(_))) => {
134 debug!("WebSocket connection closed");
135 Ok(None)
136 }
137 Some(Ok(Message::Ping(data))) => {
138 debug!("Received ping, sending pong");
139 self.inner
140 .send(Message::Pong(data))
141 .await
142 .context("Failed to send pong")?;
143 Ok(None)
144 }
145 Some(Ok(Message::Pong(_))) => {
146 debug!("Received pong");
147 Ok(None)
148 }
149 Some(Ok(Message::Text(_))) => {
150 error!("Received unexpected text message");
151 Err(anyhow!("Unexpected text message received"))
152 }
153 Some(Ok(m)) => {
154 error!("Received unexpected message: {:?}", m);
155 Err(anyhow!("Unexpected message received"))
156 }
157 None => Ok(None),
158 Some(Err(e)) => Err(anyhow!("WebSocket error: {}", e)),
159 },
160 }
161 }
162
163 async fn send_message(
164 &mut self,
165 msg: &crate::protocol::message::Message,
166 ) -> anyhow::Result<()> {
167 trace_message("Send", msg);
168 let mut buf = BytesMut::new();
169 let msg = crate::protocol::Message {
170 message: Some(msg.clone()),
171 };
172 msg.encode(&mut buf)
173 .context("Failed to encode protobuf message")?;
174
175 let timeout_duration = Duration::from_secs(MESSAGE_TIMEOUT_SECS);
176 timeout(
177 timeout_duration,
178 self.inner.send(Message::Binary(buf.into())),
179 )
180 .await
181 .map_err(|_| {
182 anyhow::anyhow!(
183 "Timeout writing message after {} seconds",
184 MESSAGE_TIMEOUT_SECS
185 )
186 })?
187 .context("Failed to send WebSocket message")?;
188 Ok(())
189 }
190
191 async fn close(&mut self) -> anyhow::Result<()> {
192 debug!("Closing WebSocket connection");
193 self.inner
194 .close(None)
195 .await
196 .context("Failed to close WebSocket stream")
197 }
198}
199
200impl AsyncStream for StreamWrapper {
201 type Item = Result<Bytes, Error>;
202
203 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
204 match Pin::new(&mut self.get_mut().inner).poll_next(cx) {
205 Poll::Pending => Poll::Pending,
206 Poll::Ready(None) => Poll::Ready(None),
207 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(Error::other(err)))),
208 Poll::Ready(Some(Ok(res))) => {
209 if let Message::Binary(b) = res {
210 Poll::Ready(Some(Ok(b)))
211 } else {
212 Poll::Ready(Some(Err(Error::new(
213 ErrorKind::InvalidData,
214 "unexpected frame",
215 ))))
216 }
217 }
218 }
219 }
220
221 fn size_hint(&self) -> (usize, Option<usize>) {
222 self.inner.size_hint()
223 }
224}
225
226#[derive(Debug)]
227pub struct WebsocketStream {
228 inner: StreamReader<StreamWrapper, Bytes>,
229}
230
231impl AsyncRead for WebsocketStream {
232 fn poll_read(
233 self: Pin<&mut Self>,
234 cx: &mut Context<'_>,
235 buf: &mut ReadBuf<'_>,
236 ) -> Poll<std::io::Result<()>> {
237 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
238 }
239}
240
241impl AsyncBufRead for WebsocketStream {
242 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
243 Pin::new(&mut self.get_mut().inner).poll_fill_buf(cx)
244 }
245
246 fn consume(self: Pin<&mut Self>, amt: usize) {
247 Pin::new(&mut self.get_mut().inner).consume(amt)
248 }
249}
250
251impl AsyncWrite for WebsocketStream {
252 fn poll_write(
253 self: Pin<&mut Self>,
254 cx: &mut Context<'_>,
255 buf: &[u8],
256 ) -> Poll<Result<usize, std::io::Error>> {
257 let sw = self.get_mut().inner.get_mut();
258 ready!(Pin::new(&mut sw.inner).poll_ready(cx).map_err(Error::other))?;
259
260 let bbuf = BytesMut::from(buf);
261
262 match Pin::new(&mut sw.inner).start_send(Message::Binary(bbuf.into())) {
263 Ok(()) => Poll::Ready(Ok(buf.len())),
264 Err(e) => Poll::Ready(Err(Error::other(e))),
265 }
266 }
267
268 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
269 Pin::new(&mut self.get_mut().inner.get_mut().inner)
270 .poll_flush(cx)
271 .map_err(Error::other)
272 }
273
274 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
275 Pin::new(&mut self.get_mut().inner.get_mut().inner)
276 .poll_close(cx)
277 .map_err(Error::other)
278 }
279}
280
281#[async_trait]
282impl ProtobufStream for WebsocketStream {
283 async fn recv_message(&mut self) -> anyhow::Result<Option<ProtocolMessage>> {
284 self.inner.get_mut().recv_message().await
285 }
286
287 async fn send_message(&mut self, msg: &ProtocolMessage) -> anyhow::Result<()> {
288 self.inner.get_mut().send_message(msg).await
289 }
290
291 async fn close(&mut self) -> anyhow::Result<()> {
292 debug!("Closing WebSocket stream");
293 self.inner.get_mut().close().await
294 }
295}
296
297#[derive(Debug)]
298enum SubTransport {
299 #[cfg(feature = "rustls")]
300 Secure(TlsTransport),
301 Insecure(TcpTransport),
302}
303
304#[derive(Debug)]
305pub struct WebsocketTransport {
306 sub: SubTransport,
307 conf: WebSocketConfig,
308 headers: Arc<DashMap<String, String>>,
309}
310
311#[async_trait]
312impl Transport for WebsocketTransport {
313 type Acceptor = Listener;
314 type RawStream = Stream;
315 type Stream = WebsocketStream;
316
317 fn new(config: &TransportConfig) -> anyhow::Result<Self> {
318 let wsconfig = config
319 .websocket
320 .as_ref()
321 .ok_or_else(|| anyhow!("Missing websocket config"))?;
322
323 let conf = WebSocketConfig::default().write_buffer_size(0);
324
325 let sub = match wsconfig.tls {
326 #[cfg(feature = "rustls")]
327 true => SubTransport::Secure(TlsTransport::new(config)?),
328 #[cfg(not(feature = "rustls"))]
329 true => unreachable!("TLS support not enabled"),
330 false => SubTransport::Insecure(TcpTransport::new(config)?),
331 };
332 let headers = Arc::new(DashMap::new());
333 Ok(WebsocketTransport { sub, conf, headers })
334 }
335
336 fn hint(conn: &Self::Stream, opt: SocketOpts) {
337 opt.apply(conn.inner.get_ref().inner.get_ref().get_tcpstream())
338 }
339
340 #[cfg(unix)]
341 fn as_raw_fd(conn: &Self::Stream) -> RawFd {
342 match conn.inner.get_ref().inner.get_ref().get_tcpstream() {
343 Stream::Tcp(tcp_stream) => tcp_stream.as_raw_fd(),
344 Stream::Unix(unix_stream) => unix_stream.as_raw_fd(),
345 }
346 }
347
348 async fn bind(&self, addr: NamedSocketAddr) -> anyhow::Result<Self::Acceptor> {
349 Listener::bind(&addr).await.map_err(Into::into)
350 }
351
352 async fn accept(&self, a: &Self::Acceptor) -> anyhow::Result<(Self::RawStream, SocketAddr)> {
353 let (s, addr) = match &self.sub {
354 SubTransport::Insecure(t) => t.accept(a).await?,
355 #[cfg(feature = "rustls")]
356 SubTransport::Secure(t) => t.accept(a).await?,
357 };
358 Ok((s, addr))
359 }
360
361 async fn handshake(&self, conn: Self::RawStream) -> anyhow::Result<Self::Stream> {
362 let tsream = match &self.sub {
363 SubTransport::Insecure(t) => {
364 TransportStream::Insecure(t.handshake(conn).await?.into_stream())
365 }
366 #[cfg(feature = "rustls")]
367 SubTransport::Secure(t) => TransportStream::Secure(Box::new(t.handshake(conn).await?)),
368 };
369
370 let headers = self.headers.clone();
371
372 let callback = move |req: &Request, res: Response| {
373 for ref header in req.headers() {
374 trace!("WS headers: {:?}", header);
375 headers.insert(
376 header.0.to_string(),
377 header.1.to_str().unwrap_or_default().to_string(),
378 );
379 }
380 Ok(res)
381 };
382
383 let wsstream = accept_hdr_async_with_config(tsream, callback, Some(self.conf)).await?;
384
385 let tun = WebsocketStream {
386 inner: StreamReader::new(StreamWrapper { inner: wsstream }),
387 };
388 Ok(tun)
389 }
390
391 async fn connect(&self, addr: &AddrMaybeCached) -> anyhow::Result<Self::Stream> {
392 let u = format!("wss://{}/endpoint/v3", &addr.addr.as_str());
393 let tstream = match &self.sub {
394 SubTransport::Insecure(t) => {
395 TransportStream::Insecure(t.connect(addr).await?.into_stream())
396 }
397 #[cfg(feature = "rustls")]
398 SubTransport::Secure(t) => TransportStream::Secure(Box::new(t.connect(addr).await?)),
399 };
400 debug!("Connecting to {}", u);
401 let (wsstream, _) = client_async_with_config(
402 u.clone()
403 .into_client_request()
404 .context("Failed to create client request")?,
405 tstream,
406 Some(self.conf),
407 )
408 .await
409 .with_context(|| format!("Failed to connect to {}", u))?;
410
411 debug!("Connected");
412
413 let tun = WebsocketStream {
414 inner: StreamReader::new(StreamWrapper { inner: wsstream }),
415 };
416 Ok(tun)
417 }
418
419 fn get_header(&self, name: &str) -> Option<String> {
420 self.headers.get(&name.to_lowercase()).map(|v| v.clone())
421 }
422}