Skip to main content

rumqttc_core/
websockets.rs

1//! WebSocket URL/handshake helpers and the `WsAdapter` transport bridge.
2
3use http::{Response, header::ToStrError};
4
5#[cfg(feature = "websocket")]
6use async_tungstenite::{
7    WebSocketReceiver, WebSocketSender, WebSocketStream,
8    bytes::{ByteReader, ByteWriter},
9    tungstenite::Message,
10};
11#[cfg(feature = "websocket")]
12use futures_io::{AsyncRead as FuturesAsyncRead, AsyncWrite as FuturesAsyncWrite};
13#[cfg(feature = "websocket")]
14use std::{
15    io,
16    pin::Pin,
17    task::{Context, Poll},
18};
19#[cfg(feature = "websocket")]
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21
/// Bridges `async-tungstenite`'s split WebSocket halves (`ByteReader` + `ByteWriter`)
/// into a unified `AsyncRead + AsyncWrite` object compatible with `rumqttc`'s `Network`
/// abstraction.
///
/// Reads are served entirely by `reader`; writes, flushes and shutdowns are served
/// entirely by `writer`. Both halves are produced by splitting a single
/// `WebSocketStream` in `WsAdapter::new`.
///
/// This replaces the former `ws_stream_tungstenite::WsStream` wrapper and eliminates
/// the `ws_stream_tungstenite` dependency from the `websocket` feature.
#[cfg(feature = "websocket")]
pub struct WsAdapter<S> {
    /// Receive half of the split stream, wrapped so it can be consumed as bytes.
    reader: ByteReader<WebSocketReceiver<S>>,
    /// Send half of the split stream, wrapped so it can be fed raw bytes.
    writer: ByteWriter<WebSocketSender<S>>,
}
33
#[cfg(feature = "websocket")]
impl<S> WsAdapter<S>
where
    S: FuturesAsyncRead + FuturesAsyncWrite + Unpin,
{
    /// Wraps a WebSocket stream in an adapter.
    ///
    /// The stream is split into independent send and receive halves. The receive
    /// half backs this type's `AsyncRead` impl and the send half backs its
    /// `AsyncWrite` impl.
    pub fn new(ws: WebSocketStream<S>) -> Self {
        let (send_half, recv_half) = ws.split();
        let reader = ByteReader::new(recv_half);
        let writer = ByteWriter::new(send_half);
        Self { reader, writer }
    }
}
47
#[cfg(feature = "websocket")]
impl<S: Unpin> AsyncRead for WsAdapter<S>
where
    WebSocketReceiver<S>:
        futures_util::Stream<Item = Result<Message, async_tungstenite::tungstenite::Error>> + Unpin,
{
    /// Delegates the read to the wrapped `ByteReader`.
    ///
    /// Presumably `ByteReader` surfaces incoming WebSocket message payloads as a
    /// contiguous byte stream; confirm against async-tungstenite's `bytes` module.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = &mut *self;
        // Fully-qualified call: the futures `AsyncRead` trait is also in scope, so
        // method syntax could be ambiguous if `ByteReader` implements both.
        AsyncRead::poll_read(Pin::new(&mut this.reader), cx, buf)
    }
}
62
#[cfg(feature = "websocket")]
impl<S: Unpin> AsyncWrite for WsAdapter<S>
where
    WebSocketSender<S>: async_tungstenite::bytes::Sender + Unpin,
{
    /// Hands `buf` to the wrapped `ByteWriter` and reports how many bytes it accepted.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = &mut *self;
        // Fully-qualified calls keep dispatch on tokio's `AsyncWrite`, since the
        // futures `AsyncWrite` trait is in scope as well.
        AsyncWrite::poll_write(Pin::new(&mut this.writer), cx, buf)
    }

    /// Flushes any bytes buffered in the wrapped `ByteWriter`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = &mut *self;
        AsyncWrite::poll_flush(Pin::new(&mut this.writer), cx)
    }

    /// Shuts down the wrapped `ByteWriter` (the send half only; the read half is
    /// untouched here). NOTE(review): presumably this closes the WebSocket sender;
    /// confirm in async-tungstenite's `ByteWriter` docs.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = &mut *self;
        AsyncWrite::poll_shutdown(Pin::new(&mut this.writer), cx)
    }
}
84
/// Errors that can occur while splitting a broker URL into host and port
/// (see [`split_url`]).
#[derive(Debug, thiserror::Error)]
pub enum UrlError {
    /// The URL's protocol (scheme) is missing, invalid, or unsupported.
    #[error("Invalid protocol specified inside url.")]
    Protocol,
    /// No host could be extracted from the URL.
    #[error("Couldn't parse host from url.")]
    Host,
    /// The input string is not a valid URI; wraps the underlying parse error.
    #[error("Couldn't parse host url.")]
    Parse(#[from] http::uri::InvalidUri),
}
94
/// Errors returned by [`validate_response_headers`] when a WebSocket handshake
/// response does not confirm the `mqtt` subprotocol.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
    /// The response carries no `Sec-WebSocket-Protocol` header.
    #[error("Websocket response does not contain subprotocol header")]
    SubprotocolHeaderMissing,
    /// The header is present but its (trimmed) value is not `mqtt`; carries the
    /// raw header value as received.
    #[error("MQTT not in subprotocol header: {0}")]
    SubprotocolMqttMissing(String),
    /// The header value could not be converted to a `&str`.
    #[error("Subprotocol header couldn't be converted into string representation")]
    HeaderToStr(#[from] ToStrError),
}
104
105pub fn validate_response_headers(
106    response: Response<Option<Vec<u8>>>,
107) -> Result<(), ValidationError> {
108    let subprotocol = response
109        .headers()
110        .get("Sec-WebSocket-Protocol")
111        .ok_or(ValidationError::SubprotocolHeaderMissing)?
112        .to_str()?;
113
114    // Server must respond with Sec-WebSocket-Protocol header value of "mqtt"
115    // https://http.dev/ws#sec-websocket-protocol
116    if subprotocol.trim() != "mqtt" {
117        return Err(ValidationError::SubprotocolMqttMissing(
118            subprotocol.to_owned(),
119        ));
120    }
121
122    Ok(())
123}
124
125pub fn split_url(url: &str) -> Result<(String, u16), UrlError> {
126    let uri = url.parse::<http::Uri>()?;
127    let domain = domain(&uri).ok_or(UrlError::Protocol)?;
128    let port = port(&uri).ok_or(UrlError::Host)?;
129    Ok((domain, port))
130}
131
132fn domain(uri: &http::Uri) -> Option<String> {
133    uri.host().map(|host| {
134        // If host is an IPv6 address, it might be surrounded by brackets. These brackets are
135        // *not* part of a valid IP, so they must be stripped out.
136        //
137        // The URI from the request is guaranteed to be valid, so we don't need a separate
138        // check for the closing bracket.
139        let host = if host.starts_with('[') {
140            &host[1..host.len() - 1]
141        } else {
142            host
143        };
144
145        host.to_owned()
146    })
147}
148
149fn port(uri: &http::Uri) -> Option<u16> {
150    uri.port_u16().or_else(|| match uri.scheme_str() {
151        Some("wss") => Some(443),
152        Some("ws") => Some(80),
153        _ => None,
154    })
155}