use crate::error::{MqttError, Result};
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_tungstenite::{
accept_hdr_async, tungstenite::protocol::WebSocketConfig, WebSocketStream,
};
use tracing::{debug, error};
/// Server-side settings used when accepting MQTT-over-WebSocket upgrade requests.
#[derive(Debug, Clone)]
pub struct WebSocketServerConfig {
    /// HTTP request path the upgrade request must target; any other path is answered with 404.
    pub path: String,
    /// Subprotocol name looked for in the client's `Sec-WebSocket-Protocol` offer.
    pub subprotocol: String,
    /// Frame size limit (presumably in bytes).
    /// NOTE(review): not currently applied — `build_ws_config` always returns `None`.
    pub max_frame_size: Option<usize>,
    /// Message size limit (presumably in bytes).
    /// NOTE(review): not currently applied — `build_ws_config` always returns `None`.
    pub max_message_size: Option<usize>,
    /// When `Some`, only requests whose `Origin` header matches one of these entries
    /// (ASCII case-insensitive) are accepted; `None` disables the origin check.
    pub allowed_origins: Option<Vec<String>>,
}
impl Default for WebSocketServerConfig {
    /// Path `/mqtt`, subprotocol `mqtt`, frame limit 16 MiB, message limit 64 MiB,
    /// and no origin allow-list.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            path: String::from("/mqtt"),
            subprotocol: String::from("mqtt"),
            max_frame_size: Some(16 * MIB),
            max_message_size: Some(64 * MIB),
            allowed_origins: None,
        }
    }
}
impl WebSocketServerConfig {
    /// Creates a configuration populated with the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the HTTP request path that WebSocket upgrade requests must target.
    #[must_use]
    pub fn with_path(mut self, path: impl Into<String>) -> Self {
        self.path = path.into();
        self
    }

    /// Sets the subprotocol name that is looked for in the client's
    /// `Sec-WebSocket-Protocol` offer and echoed back when present.
    #[must_use]
    pub fn with_subprotocol(mut self, subprotocol: impl Into<String>) -> Self {
        self.subprotocol = subprotocol.into();
        self
    }

    /// Restricts upgrades to requests whose `Origin` header matches one of `origins`
    /// (ASCII case-insensitive). Once set, requests without an `Origin` header are rejected.
    #[must_use]
    pub fn with_allowed_origins(mut self, origins: Vec<String>) -> Self {
        self.allowed_origins = Some(origins);
        self
    }

    /// Returns the tungstenite configuration for accepted connections.
    ///
    /// Currently always `None`.
    /// NOTE(review): `max_frame_size` / `max_message_size` are not forwarded here and
    /// `accept_websocket_connection` does not call this method — confirm whether the
    /// configured limits are meant to be enforced.
    #[must_use]
    pub fn build_ws_config(&self) -> Option<WebSocketConfig> {
        None
    }
}
/// Adapts a [`WebSocketStream`] into a byte stream implementing tokio's `AsyncRead` /
/// `AsyncWrite`: binary message payloads are read back-to-back as one contiguous stream,
/// and each write is sent as a binary message.
pub struct WebSocketStreamWrapper<S = TcpStream>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// The negotiated WebSocket connection.
    inner: WebSocketStream<S>,
    /// Payload of the most recently received binary message not yet fully handed to the reader.
    read_buffer: Vec<u8>,
    /// Offset of the next unread byte in `read_buffer`.
    read_pos: usize,
}
impl<S> WebSocketStreamWrapper<S>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Wraps an already-negotiated WebSocket stream, starting with an empty read buffer.
    pub fn new(stream: WebSocketStream<S>) -> Self {
        let read_buffer = Vec::new();
        Self {
            inner: stream,
            read_buffer,
            read_pos: 0,
        }
    }

    /// Returns the remote address of the underlying TCP connection.
    ///
    /// NOTE(review): plain `TcpStream` does not implement `Deref<Target = TcpStream>`, so this
    /// method is not callable for the default `S = TcpStream` — confirm that is intended.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::Io`] if the socket cannot report its peer address.
    pub fn peer_addr(&self) -> Result<SocketAddr>
    where
        S: std::ops::Deref<Target = TcpStream>,
    {
        let tcp: &TcpStream = self.inner.get_ref();
        match tcp.peer_addr() {
            Ok(addr) => Ok(addr),
            Err(e) => Err(MqttError::Io(format!("Failed to get peer address: {e}"))),
        }
    }
}
/// Performs the server side of the WebSocket handshake on `stream`.
///
/// The upgrade request is checked against `config` (request path, optional `Origin`
/// allow-list, subprotocol negotiation) and the resulting connection is wrapped for
/// byte-stream I/O. `peer_addr` is only used for logging.
///
/// # Errors
///
/// Returns [`MqttError::ConnectionError`] if the handshake fails or the request is rejected.
pub async fn accept_websocket_connection<S>(
    stream: S,
    config: &WebSocketServerConfig,
    peer_addr: SocketAddr,
) -> Result<WebSocketStreamWrapper<S>>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    debug!("Starting WebSocket handshake with {}", peer_addr);
    // The callback is `FnOnce` and outlives this borrow of `config`, so it owns copies.
    let handshake_check = websocket_handshake_callback(
        config.subprotocol.clone(),
        config.path.clone(),
        config.allowed_origins.clone(),
    );
    let ws_stream = accept_hdr_async(stream, handshake_check)
        .await
        .map_err(|e| {
            error!("WebSocket handshake failed with {}: {}", peer_addr, e);
            MqttError::ConnectionError(format!("WebSocket handshake failed: {e}"))
        })?;
    debug!("WebSocket handshake completed with {}", peer_addr);
    Ok(WebSocketStreamWrapper::new(ws_stream))
}
impl<S> AsyncRead for WebSocketStreamWrapper<S>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Fills `buf` with bytes from consecutive binary WebSocket messages.
    ///
    /// Bytes of a message that do not fit into `buf` are kept in `read_buffer` and handed
    /// out by later calls. Returning `Ready(Ok(()))` always means at least one byte was
    /// written, unless `buf` itself has no space. Errors:
    /// - `UnexpectedEof` when a Close frame arrives or the stream ends;
    /// - `InvalidData` when a text frame arrives (\[MQTT-6.0.0-1\]);
    /// - `Other` wrapping any tungstenite error.
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::task::Poll;
        use tokio_tungstenite::tungstenite::Message;

        loop {
            // Serve leftover bytes from the previously received binary message first.
            if self.read_pos < self.read_buffer.len() {
                let this = &mut *self;
                let remaining = &this.read_buffer[this.read_pos..];
                let to_copy = remaining.len().min(buf.remaining());
                buf.put_slice(&remaining[..to_copy]);
                this.read_pos += to_copy;
                if this.read_pos >= this.read_buffer.len() {
                    this.read_buffer.clear();
                    this.read_pos = 0;
                }
                return Poll::Ready(Ok(()));
            }

            match self.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(Message::Binary(data)))) => {
                    // Stage the payload (reusing the allocation) and loop back to copy it out.
                    // An empty payload leaves the buffer empty and we keep polling: returning
                    // Ready(Ok(())) with nothing written would be read as end-of-stream.
                    self.read_buffer.clear();
                    self.read_buffer.extend_from_slice(&data);
                    self.read_pos = 0;
                }
                Poll::Ready(Some(Ok(Message::Close(_)))) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Ready(Some(Ok(Message::Text(_)))) => {
                    error!("[MQTT-6.0.0-1] received text frame on MQTT WebSocket, closing");
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::InvalidData,
                        "MQTT control packets must use binary frames",
                    )));
                }
                Poll::Ready(Some(Ok(Message::Ping(_) | Message::Pong(_) | Message::Frame(_)))) => {
                    // No MQTT bytes; poll the stream again rather than self-waking with Pending.
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(std::io::Error::other(e.to_string())));
                }
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket stream ended",
                    )));
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
impl<S> AsyncWrite for WebSocketStreamWrapper<S>
where
    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    /// Queues `buf` as one binary WebSocket message and reports it fully written.
    ///
    /// The message is only handed to the sink; it reaches the socket on `poll_flush`.
    /// An empty `buf` returns `Ok(0)` without sending anything, so no empty binary frame
    /// (which a byte-stream reader on the peer could mistake for EOF) is ever produced.
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::task::Poll;
        use tokio_tungstenite::tungstenite::Message;

        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        match self.inner.poll_ready_unpin(cx) {
            Poll::Ready(Ok(())) => {}
            Poll::Ready(Err(e)) => return Poll::Ready(Err(std::io::Error::other(e.to_string()))),
            Poll::Pending => return Poll::Pending,
        }
        // Copy the payload only once the sink is ready to take it.
        let message = Message::Binary(buf.to_vec().into());
        Poll::Ready(
            self.inner
                .start_send_unpin(message)
                .map(|()| buf.len())
                .map_err(|e| std::io::Error::other(e.to_string())),
        )
    }

    /// Flushes queued WebSocket messages to the underlying stream.
    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        self.inner
            .poll_flush_unpin(cx)
            .map_err(|e| std::io::Error::other(e.to_string()))
    }

    /// Closes the WebSocket sink (tungstenite's close handling) and the underlying stream.
    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        self.inner
            .poll_close_unpin(cx)
            .map_err(|e| std::io::Error::other(e.to_string()))
    }
}
/// Builds the tungstenite handshake callback that validates an upgrade request.
///
/// - Requests whose URI path differs from `path` are rejected with 404.
/// - If `allowed_origins` is `Some`, requests whose `Origin` header is missing or does not
///   match an entry (ASCII case-insensitive) are rejected with 403.
/// - If the client offers `subprotocol` in any `Sec-WebSocket-Protocol` header (the header
///   may be repeated and each value is a comma-separated list), it is echoed back in the
///   response. Otherwise, no subprotocol header is sent and the upgrade still proceeds.
// The error type is dictated by tungstenite's handshake `Callback` trait.
#[allow(clippy::result_large_err)]
fn websocket_handshake_callback(
    subprotocol: String,
    path: String,
    allowed_origins: Option<Vec<String>>,
) -> impl FnOnce(
    &tokio_tungstenite::tungstenite::handshake::server::Request,
    tokio_tungstenite::tungstenite::handshake::server::Response,
) -> std::result::Result<
    tokio_tungstenite::tungstenite::handshake::server::Response,
    http::Response<Option<String>>,
> {
    move |req, mut response| {
        let requested_path = req.uri().path();
        if requested_path != path {
            debug!("WebSocket path mismatch: {} != {}", requested_path, path);
            let reject = http::Response::builder()
                .status(http::StatusCode::NOT_FOUND)
                .body(None)
                .expect("building 404 response");
            return Err(reject);
        }

        if let Some(ref origins) = allowed_origins {
            let origin = req
                .headers()
                .get(http::header::ORIGIN)
                .and_then(|v| v.to_str().ok());
            let allowed = origin.is_some_and(|o| origins.iter().any(|a| a.eq_ignore_ascii_case(o)));
            if !allowed {
                debug!("WebSocket origin rejected: {:?}", origin);
                let reject = http::Response::builder()
                    .status(http::StatusCode::FORBIDDEN)
                    .body(None)
                    .expect("building 403 response");
                return Err(reject);
            }
        }

        // RFC 6455 allows Sec-WebSocket-Protocol to appear more than once, so inspect all of them.
        let offered = req
            .headers()
            .get_all(http::header::SEC_WEBSOCKET_PROTOCOL)
            .iter()
            .filter_map(|v| v.to_str().ok())
            .flat_map(|list| list.split(','))
            .any(|p| p.trim() == subprotocol.as_str());
        // Only ever echo a value the client actually offered; a mismatched subprotocol in
        // the response makes compliant clients fail the connection.
        let selected = offered
            .then(|| subprotocol.parse::<http::HeaderValue>().ok())
            .flatten();
        if let Some(value) = selected {
            response
                .headers_mut()
                .insert(http::header::SEC_WEBSOCKET_PROTOCOL, value);
        }
        Ok(response)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// One mebibyte, used to spell out the default size limits.
    const MIB: usize = 1024 * 1024;

    #[test]
    fn test_websocket_config_default() {
        let cfg = WebSocketServerConfig::default();
        assert_eq!(
            (cfg.path.as_str(), cfg.subprotocol.as_str()),
            ("/mqtt", "mqtt")
        );
        assert_eq!(cfg.max_frame_size, Some(16 * MIB));
        assert_eq!(cfg.max_message_size, Some(64 * MIB));
    }

    #[test]
    fn test_websocket_config_builder() {
        let cfg = WebSocketServerConfig::new()
            .with_path("/ws")
            .with_subprotocol("mqttv5.0");
        assert_eq!(cfg.path, "/ws");
        assert_eq!(cfg.subprotocol, "mqttv5.0");
    }

    #[test]
    fn test_ws_config_build() {
        assert!(WebSocketServerConfig::default().build_ws_config().is_none());
    }
}