rumqttc_core/
websockets.rs1use http::{Response, header::ToStrError};
4
5#[cfg(feature = "websocket")]
6use async_tungstenite::{
7 WebSocketReceiver, WebSocketSender, WebSocketStream,
8 bytes::{ByteReader, ByteWriter},
9 tungstenite::Message,
10};
11#[cfg(feature = "websocket")]
12use futures_io::{AsyncRead as FuturesAsyncRead, AsyncWrite as FuturesAsyncWrite};
13#[cfg(feature = "websocket")]
14use std::{
15 io,
16 pin::Pin,
17 task::{Context, Poll},
18};
19#[cfg(feature = "websocket")]
20use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
21
#[cfg(feature = "websocket")]
/// Exposes a [`WebSocketStream`] as a tokio [`AsyncRead`] + [`AsyncWrite`] byte stream.
///
/// The WebSocket stream is split into its receive and send halves, and each half is
/// wrapped in async-tungstenite's `ByteReader` / `ByteWriter` adapter (see
/// [`WsAdapter::new`]). The tokio trait impls below simply forward to those adapters.
pub struct WsAdapter<S> {
    /// Receive half of the split WebSocket stream, wrapped in a `ByteReader`.
    reader: ByteReader<WebSocketReceiver<S>>,
    /// Send half of the split WebSocket stream, wrapped in a `ByteWriter`.
    writer: ByteWriter<WebSocketSender<S>>,
}
33
#[cfg(feature = "websocket")]
impl<S: FuturesAsyncRead + FuturesAsyncWrite + Unpin> WsAdapter<S> {
    /// Wraps a WebSocket connection so it can be used as a tokio byte stream.
    ///
    /// The connection is split into its send and receive halves; the receive half backs
    /// the reader side and the send half backs the writer side of the adapter.
    pub fn new(ws: WebSocketStream<S>) -> Self {
        let (send_half, recv_half) = ws.split();
        let reader = ByteReader::new(recv_half);
        let writer = ByteWriter::new(send_half);
        Self { reader, writer }
    }
}
47
#[cfg(feature = "websocket")]
impl<S: Unpin> AsyncRead for WsAdapter<S>
where
    WebSocketReceiver<S>:
        futures_util::Stream<Item = Result<Message, async_tungstenite::tungstenite::Error>> + Unpin,
{
    /// Forwards the read to the tokio `AsyncRead` impl of the wrapped `ByteReader`.
    ///
    /// The fully-qualified trait call is deliberate: the reader may implement both the
    /// tokio and the futures `AsyncRead`, which would make plain method syntax ambiguous.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let reader = Pin::new(&mut self.get_mut().reader);
        tokio::io::AsyncRead::poll_read(reader, cx, buf)
    }
}
62
#[cfg(feature = "websocket")]
impl<S: Unpin> AsyncWrite for WsAdapter<S>
where
    WebSocketSender<S>: async_tungstenite::bytes::Sender + Unpin,
{
    /// Forwards `buf` to the tokio `AsyncWrite` impl of the wrapped `ByteWriter`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let writer = Pin::new(&mut self.get_mut().writer);
        tokio::io::AsyncWrite::poll_write(writer, cx, buf)
    }

    /// Forwards the flush to the wrapped `ByteWriter`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = Pin::new(&mut self.get_mut().writer);
        tokio::io::AsyncWrite::poll_flush(writer, cx)
    }

    /// Forwards shutdown to the wrapped `ByteWriter`; the reader half is not touched.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = Pin::new(&mut self.get_mut().writer);
        tokio::io::AsyncWrite::poll_shutdown(writer, cx)
    }
}
84
85#[derive(Debug, thiserror::Error)]
86pub enum UrlError {
87 #[error("Invalid protocol specified inside url.")]
88 Protocol,
89 #[error("Couldn't parse host from url.")]
90 Host,
91 #[error("Couldn't parse host url.")]
92 Parse(#[from] http::uri::InvalidUri),
93}
94
95#[derive(Debug, thiserror::Error)]
96pub enum ValidationError {
97 #[error("Websocket response does not contain subprotocol header")]
98 SubprotocolHeaderMissing,
99 #[error("MQTT not in subprotocol header: {0}")]
100 SubprotocolMqttMissing(String),
101 #[error("Subprotocol header couldn't be converted into string representation")]
102 HeaderToStr(#[from] ToStrError),
103}
104
105pub fn validate_response_headers(
106 response: Response<Option<Vec<u8>>>,
107) -> Result<(), ValidationError> {
108 let subprotocol = response
109 .headers()
110 .get("Sec-WebSocket-Protocol")
111 .ok_or(ValidationError::SubprotocolHeaderMissing)?
112 .to_str()?;
113
114 if subprotocol.trim() != "mqtt" {
117 return Err(ValidationError::SubprotocolMqttMissing(
118 subprotocol.to_owned(),
119 ));
120 }
121
122 Ok(())
123}
124
125pub fn split_url(url: &str) -> Result<(String, u16), UrlError> {
126 let uri = url.parse::<http::Uri>()?;
127 let domain = domain(&uri).ok_or(UrlError::Protocol)?;
128 let port = port(&uri).ok_or(UrlError::Host)?;
129 Ok((domain, port))
130}
131
132fn domain(uri: &http::Uri) -> Option<String> {
133 uri.host().map(|host| {
134 let host = if host.starts_with('[') {
140 &host[1..host.len() - 1]
141 } else {
142 host
143 };
144
145 host.to_owned()
146 })
147}
148
149fn port(uri: &http::Uri) -> Option<u16> {
150 uri.port_u16().or_else(|| match uri.scheme_str() {
151 Some("wss") => Some(443),
152 Some("ws") => Some(80),
153 _ => None,
154 })
155}