use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::num::{NonZeroU16, NonZeroU32};
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "quic")]
use crate::quic::QuinnBiStream;
use crate::stream::Dispatcher;
#[cfg(feature = "ws")]
use crate::ws::WsStream;
#[cfg(feature = "tls")]
use crate::TlsCertExtractor;
use crate::{Error, Result};
#[allow(unused_imports)]
use anyhow::{anyhow, Context};
use nonzero_ext::nonzero;
use proxy_protocol::parse;
use proxy_protocol::ProxyHeader;
use proxy_protocol::{version1 as v1, version2 as v2};
#[cfg(feature = "quic")]
use quinn::{crypto::rustls::QuicServerConfig, IdleTimeout};
#[cfg(feature = "tls")]
use rmqtt_codec::cert::CertInfo;
use rmqtt_codec::types::QoS;
#[cfg(not(target_os = "windows"))]
#[cfg(feature = "tls")]
use rustls::crypto::aws_lc_rs as provider;
#[cfg(feature = "tls")]
#[cfg(target_os = "windows")]
use rustls::crypto::ring as provider;
#[cfg(feature = "tls")]
use rustls::pki_types::CertificateDer;
#[cfg(feature = "tls")]
use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
use socket2::{Domain, SockAddr, Socket, Type};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
#[cfg(feature = "tls")]
use tokio_rustls::{server::TlsStream, TlsAcceptor};
#[cfg(feature = "ws")]
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::handshake::server::{ErrorResponse, Request, Response},
};
/// Configuration for one MQTT listener, built fluently and then turned into a [`Listener`]
/// via `bind` (TCP) or `bind_quic` (QUIC).
///
/// After binding, the builder is shared as `Arc<Builder>` (`Listener::cfg` / `Acceptor::cfg`)
/// and handed to every `Dispatcher`, so fields not read in this file are consumed elsewhere.
#[derive(Clone, Debug)]
pub struct Builder {
/// Listener name; included in the "MQTT Broker Listening on" log line.
pub name: String,
/// Local address to bind (TCP socket or QUIC endpoint).
pub laddr: SocketAddr,
/// Backlog passed to `listen` for TCP listeners.
pub backlog: i32,
/// Applied with `set_nodelay` to every accepted TCP socket.
pub nodelay: bool,
/// `Some(v)` sets the socket's reuse-address option to `v`; `None` leaves the OS default.
pub reuseaddr: Option<bool>,
/// `Some(v)` sets the socket's reuse-port option to `v` (ignored on Windows); `None` leaves the OS default.
pub reuseport: Option<bool>,
// The following limits/intervals are not read in this file; they travel to the
// session layer inside `cfg`. NOTE(review): exact semantics live with their consumers.
pub max_connections: usize,
pub max_handshaking_limit: usize,
pub max_packet_size: u32,
pub allow_anonymous: bool,
pub min_keepalive: u16,
pub max_keepalive: u16,
pub allow_zero_keepalive: bool,
pub keepalive_backoff: f32,
pub max_inflight: NonZeroU16,
/// Upper bound for the TLS handshake and for the WebSocket upgrade (each applied separately).
pub handshake_timeout: Duration,
pub send_timeout: Duration,
pub max_mqueue_len: usize,
/// `(rate, per-duration)` pair; not read in this file.
pub mqueue_rate_limit: (NonZeroU32, Duration),
pub max_clientid_len: usize,
pub max_qos_allowed: QoS,
pub max_topic_levels: usize,
pub session_expiry_interval: Duration,
pub max_session_expiry_interval: Duration,
pub message_retry_interval: Duration,
pub message_expiry_interval: Duration,
pub max_subscriptions: usize,
pub shared_subscription: bool,
pub max_topic_aliases: u16,
pub limit_subscription: bool,
pub delayed_publish: bool,
/// When true, client certificates are verified with a WebPKI verifier rooted at
/// `tls_client_ca_certs` (or at the `tls_cert` chain when that is unset).
pub tls_cross_certificate: bool,
/// Path to the PEM server certificate chain.
pub tls_cert: Option<String>,
/// Path to the PEM server private key.
pub tls_key: Option<String>,
/// Path to a PEM file of CA certificates trusted for client authentication.
pub tls_client_ca_certs: Option<String>,
/// When true, each accepted TCP connection must start with a PROXY protocol header whose
/// source address replaces the socket's peer address.
pub proxy_protocol: bool,
/// How long to wait for the PROXY protocol header bytes to arrive.
pub proxy_protocol_timeout: Duration,
// If any of the next three flags is set, certificate info is extracted after the TLS handshake.
pub cert_cn_as_username: bool,
pub cert_subject_dn_as_username: bool,
pub collect_cert_info: bool,
/// Maximum idle timeout configured on the QUIC transport.
pub idle_timeout: Duration,
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
impl Builder {
    /// Creates a builder with this module's default listener settings.
    ///
    /// It binds `0.0.0.0:1883` with a backlog of 512. TLS and the PROXY protocol are off.
    /// Every other default is spelled out in the initialiser below.
    pub fn new() -> Builder {
        Builder {
            name: Default::default(),
            laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
            max_connections: 1_000_000,
            max_handshaking_limit: 1_000,
            max_packet_size: 1024 * 1024,
            backlog: 512,
            nodelay: false,
            reuseaddr: None,
            reuseport: None,
            allow_anonymous: true,
            min_keepalive: 0,
            max_keepalive: 65535,
            allow_zero_keepalive: true,
            keepalive_backoff: 0.75,
            max_inflight: nonzero!(16u16),
            handshake_timeout: Duration::from_secs(30),
            send_timeout: Duration::from_secs(10),
            max_mqueue_len: 1000,
            mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
            max_clientid_len: 65535,
            max_qos_allowed: QoS::ExactlyOnce,
            max_topic_levels: 0,
            session_expiry_interval: Duration::from_secs(2 * 60 * 60),
            max_session_expiry_interval: Duration::ZERO,
            message_retry_interval: Duration::from_secs(20),
            message_expiry_interval: Duration::from_secs(5 * 60),
            max_subscriptions: 0,
            shared_subscription: true,
            max_topic_aliases: 0,
            limit_subscription: false,
            delayed_publish: false,
            tls_cross_certificate: false,
            tls_cert: None,
            tls_key: None,
            tls_client_ca_certs: None,
            proxy_protocol: false,
            proxy_protocol_timeout: Duration::from_secs(5),
            cert_cn_as_username: false,
            cert_subject_dn_as_username: false,
            collect_cert_info: false,
            idle_timeout: Duration::from_secs(90),
        }
    }

    /// Sets the listener name shown in the "MQTT Broker Listening on" log line.
    pub fn name<N: Into<String>>(mut self, name: N) -> Self {
        self.name = name.into();
        self
    }

    /// Sets the local address to bind.
    pub fn laddr(mut self, laddr: SocketAddr) -> Self {
        self.laddr = laddr;
        self
    }

    /// Sets the backlog passed to `listen` by [`Builder::bind`].
    pub fn backlog(mut self, backlog: i32) -> Self {
        self.backlog = backlog;
        self
    }

    /// Sets whether `TCP_NODELAY` is applied to accepted TCP sockets.
    pub fn nodelay(mut self, nodelay: bool) -> Self {
        self.nodelay = nodelay;
        self
    }

    /// Sets the reuse-address socket option (`None` keeps the OS default).
    pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
        self.reuseaddr = reuseaddr;
        self
    }

    /// Sets the reuse-port socket option (`None` keeps the OS default; ignored on Windows).
    pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
        self.reuseport = reuseport;
        self
    }

    /// Sets `max_connections`.
    pub fn max_connections(mut self, max_connections: usize) -> Self {
        self.max_connections = max_connections;
        self
    }

    /// Sets `max_handshaking_limit`.
    pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
        self.max_handshaking_limit = max_handshaking_limit;
        self
    }

    /// Sets `max_packet_size`.
    pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
        self.max_packet_size = max_packet_size;
        self
    }

    /// Sets `allow_anonymous`.
    pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
        self.allow_anonymous = allow_anonymous;
        self
    }

    /// Sets `min_keepalive`.
    pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
        self.min_keepalive = min_keepalive;
        self
    }

    /// Sets `max_keepalive`.
    pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
        self.max_keepalive = max_keepalive;
        self
    }

    /// Sets `allow_zero_keepalive`.
    pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
        self.allow_zero_keepalive = allow_zero_keepalive;
        self
    }

    /// Sets `keepalive_backoff`.
    pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
        self.keepalive_backoff = keepalive_backoff;
        self
    }

    /// Sets `max_inflight`.
    pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
        self.max_inflight = max_inflight;
        self
    }

    /// Sets the timeout applied to the TLS handshake and to the WebSocket upgrade.
    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
        self.handshake_timeout = handshake_timeout;
        self
    }

    /// Sets `send_timeout`.
    pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
        self.send_timeout = send_timeout;
        self
    }

    /// Sets `max_mqueue_len`.
    pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
        self.max_mqueue_len = max_mqueue_len;
        self
    }

    /// Sets `mqueue_rate_limit` to the pair `(rate_limit, duration)`.
    pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
        self.mqueue_rate_limit = (rate_limit, duration);
        self
    }

    /// Sets `max_clientid_len`.
    pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
        self.max_clientid_len = max_clientid_len;
        self
    }

    /// Sets `max_qos_allowed`.
    pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
        self.max_qos_allowed = max_qos_allowed;
        self
    }

    /// Sets `max_topic_levels`.
    pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
        self.max_topic_levels = max_topic_levels;
        self
    }

    /// Sets `session_expiry_interval`.
    pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
        self.session_expiry_interval = session_expiry_interval;
        self
    }

    /// Sets `max_session_expiry_interval`.
    pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
        self.max_session_expiry_interval = max_session_expiry_interval;
        self
    }

    /// Sets `message_retry_interval`.
    pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
        self.message_retry_interval = message_retry_interval;
        self
    }

    /// Sets `message_expiry_interval`.
    pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
        self.message_expiry_interval = message_expiry_interval;
        self
    }

    /// Sets `max_subscriptions`.
    pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
        self.max_subscriptions = max_subscriptions;
        self
    }

    /// Sets `shared_subscription`.
    pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
        self.shared_subscription = shared_subscription;
        self
    }

    /// Sets `max_topic_aliases`.
    pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
        self.max_topic_aliases = max_topic_aliases;
        self
    }

    /// Sets `limit_subscription`.
    pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
        self.limit_subscription = limit_subscription;
        self
    }

    /// Sets `delayed_publish`.
    pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
        self.delayed_publish = delayed_publish;
        self
    }

    /// Enables or disables verification of client certificates (mutual TLS).
    pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
        self.tls_cross_certificate = cross_certificate;
        self
    }

    /// Sets the path of the PEM server certificate chain.
    pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
        self.tls_cert = tls_cert.map(Into::into);
        self
    }

    /// Sets the path of the PEM server private key.
    pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
        self.tls_key = tls_key.map(Into::into);
        self
    }

    /// Sets the path of the PEM file with CA certificates trusted for client authentication.
    ///
    /// When unset, the server's own `tls_cert` chain is used as the trust roots.
    pub fn tls_client_ca_certs<N: Into<String>>(mut self, tls_client_ca_certs: Option<N>) -> Self {
        self.tls_client_ca_certs = tls_client_ca_certs.map(Into::into);
        self
    }

    /// Sets `cert_cn_as_username`.
    pub fn cert_cn_as_username(mut self, cert_cn_as_username: bool) -> Self {
        self.cert_cn_as_username = cert_cn_as_username;
        self
    }

    /// Sets `cert_subject_dn_as_username`.
    pub fn cert_subject_dn_as_username(mut self, v: bool) -> Self {
        self.cert_subject_dn_as_username = v;
        self
    }

    /// Sets `collect_cert_info`.
    pub fn collect_cert_info(mut self, collect_cert_info: bool) -> Self {
        self.collect_cert_info = collect_cert_info;
        self
    }

    /// Enables parsing of a PROXY protocol header on accepted TCP connections.
    pub fn proxy_protocol(mut self, enable_protocol_proxy: bool) -> Self {
        self.proxy_protocol = enable_protocol_proxy;
        self
    }

    /// Sets how long to wait for the PROXY protocol header bytes.
    pub fn proxy_protocol_timeout(mut self, proxy_protocol_timeout: Duration) -> Self {
        self.proxy_protocol_timeout = proxy_protocol_timeout;
        self
    }

    /// Sets the QUIC transport's maximum idle timeout.
    pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
        self.idle_timeout = idle_timeout;
        self
    }

    /// Binds a non-blocking TCP listening socket on `laddr` and returns a [`Listener`] of type
    /// [`ListenerType::TCP`].
    ///
    /// `reuseaddr` / `reuseport` are applied only when set; `reuseport` is skipped on Windows.
    /// `backlog` is passed to `listen`. Use `Listener::tls` / `ws` / `wss` afterwards to switch
    /// the transport.
    ///
    /// # Errors
    /// Returns any error from socket creation, option setting, `bind`, or `listen`.
    pub fn bind(self) -> Result<Listener> {
        let builder = match self.laddr {
            SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
            SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
        };
        // tokio requires the std listener to already be in non-blocking mode.
        builder.set_nonblocking(true)?;
        if let Some(reuseaddr) = self.reuseaddr {
            builder.set_reuse_address(reuseaddr)?;
        }
        #[cfg(not(windows))]
        if let Some(reuseport) = self.reuseport {
            builder.set_reuse_port(reuseport)?;
        }
        builder.bind(&SockAddr::from(self.laddr))?;
        builder.listen(self.backlog)?;
        let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
        log::info!(
            "MQTT Broker Listening on {} {}",
            self.name,
            tcp_listener.local_addr().unwrap_or(self.laddr)
        );
        Ok(Listener {
            typ: ListenerType::TCP,
            cfg: Arc::new(self),
            tcp_listener: Some(tcp_listener),
            #[cfg(feature = "tls")]
            tls_acceptor: None,
            #[cfg(feature = "quic")]
            quinn_endpoint: None,
        })
    }

    /// Creates a QUIC server endpoint on `laddr` and returns a [`Listener`] of type
    /// `ListenerType::QUIC`.
    ///
    /// The endpoint uses the same TLS material as `Listener::tls`. It advertises the ALPN
    /// identifiers `mqtt` and `mqttv5`, and it uses `idle_timeout` as the QUIC idle timeout.
    ///
    /// # Errors
    /// Returns TLS configuration errors, an invalid idle timeout, or endpoint creation errors.
    #[cfg(feature = "quic")]
    pub fn bind_quic(self) -> Result<Listener> {
        let mut tls_config = self.build_tls_config()?;
        tls_config.alpn_protocols = vec![b"mqtt".to_vec(), b"mqttv5".to_vec()];
        let server_crypto = QuicServerConfig::try_from(tls_config)?;
        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
        let transport_config = Arc::get_mut(&mut server_config.transport)
            .expect("a freshly built quinn::ServerConfig holds the only reference to its TransportConfig");
        // Connections are served over a bidirectional stream (see `Listener::accept_quic`),
        // so peers may not open unidirectional streams.
        transport_config.max_concurrent_uni_streams(0_u8.into());
        transport_config.max_idle_timeout(Some(IdleTimeout::try_from(self.idle_timeout)?));
        let endpoint = quinn::Endpoint::server(server_config, self.laddr)?;
        log::info!("MQTT Broker Listening on {} {}", self.name, endpoint.local_addr().unwrap_or(self.laddr));
        Ok(Listener {
            typ: ListenerType::QUIC,
            cfg: Arc::new(self),
            tcp_listener: None,
            #[cfg(feature = "tls")]
            tls_acceptor: None,
            quinn_endpoint: Some(endpoint),
        })
    }

    /// Builds the rustls server configuration from `tls_cert`, `tls_key` and, when
    /// `tls_cross_certificate` is set, the client CA roots.
    ///
    /// Client CA roots come from `tls_client_ca_certs` or, when that is unset, from the
    /// server's own certificate chain.
    ///
    /// # Errors
    /// Fails when a path is not configured, a PEM file cannot be read, or rustls rejects the
    /// certificates/key.
    #[cfg(feature = "tls")]
    fn build_tls_config(&self) -> Result<ServerConfig> {
        let cert_file = self.tls_cert.as_ref().ok_or_else(|| anyhow!("TLS certificate path not set"))?;
        let key_file = self.tls_key.as_ref().ok_or_else(|| anyhow!("TLS key path not set"))?;
        let cert_chain = read_ca_certs(cert_file).context("Failed to read TLS certificate chain")?;
        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file)
            .map_err(|e| anyhow!("Failed to read TLS private key from {key_file}: {e}"))?;
        let client_ca_certs = match self.tls_client_ca_certs.as_ref() {
            Some(ca_certs_file) => {
                read_ca_certs(ca_certs_file).context("Failed to read TLS Client CA certificates")?
            }
            None => cert_chain.clone(),
        };
        let crypto_provider = Arc::new(provider::default_provider());
        let client_auth = if self.tls_cross_certificate {
            let mut client_auth_roots = RootCertStore::empty();
            for root in client_ca_certs {
                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
            }
            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), crypto_provider.clone())
                .build()
                .map_err(|e| anyhow!(e))?
        } else {
            WebPkiClientVerifier::no_client_auth()
        };
        ServerConfig::builder_with_provider(crypto_provider)
            .with_safe_default_protocol_versions()
            .map_err(|e| anyhow!(e))?
            .with_client_cert_verifier(client_auth)
            .with_single_cert(cert_chain, key)
            .map_err(|e| anyhow!("Certificate error: {e}"))
    }
}
/// Transport a [`Listener`] (and the [`Acceptor`]s it yields) speaks.
///
/// Every variant except `TCP` exists only when its cargo feature(s) are enabled.
#[derive(Copy, Clone, Debug)]
pub enum ListenerType {
    /// Plain MQTT over TCP.
    TCP,
    /// MQTT over TLS on TCP (`tls` feature).
    #[cfg(feature = "tls")]
    TLS,
    /// MQTT over WebSocket on TCP (`ws` feature).
    #[cfg(feature = "ws")]
    WS,
    /// MQTT over WebSocket over TLS (`tls` and `ws` features).
    #[cfg(all(feature = "tls", feature = "ws"))]
    WSS,
    /// MQTT over a QUIC bidirectional stream (`quic` feature).
    #[cfg(feature = "quic")]
    QUIC,
}
/// A bound MQTT listener produced by `Builder::bind` (TCP) or `Builder::bind_quic` (QUIC).
///
/// Its transport can be switched with `tcp` / `tls` / `ws` / `wss` before accepting.
pub struct Listener {
/// Current transport of this listener.
pub typ: ListenerType,
/// Shared configuration, cloned into every `Acceptor`.
pub cfg: Arc<Builder>,
/// `Some` for listeners created by `Builder::bind`.
tcp_listener: Option<TcpListener>,
/// Set by `Listener::tls` (also via `wss`); `None` otherwise.
#[cfg(feature = "tls")]
tls_acceptor: Option<TlsAcceptor>,
/// `Some` for listeners created by `Builder::bind_quic`.
#[cfg(feature = "quic")]
quinn_endpoint: Option<quinn::Endpoint>,
}
impl Listener {
pub fn tcp(mut self) -> Result<Self> {
let _err = anyhow!("Protocol downgrade from TLS/WS/WSS/QUIC to TCP is not permitted");
#[cfg(feature = "tls")]
if matches!(self.typ, ListenerType::TLS) {
return Err(_err);
}
#[cfg(feature = "tls")]
#[cfg(feature = "ws")]
if matches!(self.typ, ListenerType::WSS) {
return Err(_err);
}
#[cfg(feature = "ws")]
if matches!(self.typ, ListenerType::WS) {
return Err(_err);
}
#[cfg(feature = "quic")]
if matches!(self.typ, ListenerType::QUIC) {
return Err(_err);
}
self.typ = ListenerType::TCP;
Ok(self)
}
#[cfg(feature = "ws")]
pub fn ws(mut self) -> Result<Self> {
if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
self.typ = ListenerType::WS;
} else {
return Err(anyhow!("Protocol upgrade from TLS/WSS/QUIC to WS is not permitted"));
}
Ok(self)
}
#[cfg(feature = "tls")]
#[cfg(feature = "ws")]
pub fn wss(mut self) -> Result<Self> {
#[cfg(feature = "quic")]
if matches!(self.typ, ListenerType::QUIC) {
return Err(anyhow!("Protocol upgrade from QUIC to WS is not permitted"));
}
if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
self = self.tls()?;
}
self.typ = ListenerType::WSS;
Ok(self)
}
#[cfg(feature = "tls")]
pub fn tls(mut self) -> Result<Listener> {
match self.typ {
#[cfg(feature = "ws")]
ListenerType::WS | ListenerType::WSS => {
return Err(anyhow!("Protocol downgrade from WS/WSS/QUIC to TLS is not permitted"));
}
#[cfg(feature = "quic")]
ListenerType::QUIC => {
return Err(anyhow!("Protocol downgrade from QUIC to TLS is not permitted"));
}
ListenerType::TLS => return Ok(self),
ListenerType::TCP => {}
}
let tls_config = self.cfg.build_tls_config()?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
self.tls_acceptor = Some(acceptor);
self.typ = ListenerType::TLS;
Ok(self)
}
pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
if let Some(tcp_listener) = &self.tcp_listener {
self.accept_tcp(tcp_listener).await
} else {
Err(anyhow!(""))
}
}
async fn accept_tcp(&self, tcp_listener: &TcpListener) -> Result<Acceptor<TcpStream>> {
let (mut socket, mut remote_addr) = tcp_listener.accept().await?;
if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
return Err(Error::from(e));
}
log::debug!("remote_addr: {remote_addr}, proxy_protocol: {}", self.cfg.proxy_protocol);
if self.cfg.proxy_protocol {
let mut buffer = [0u8; u16::MAX as usize];
let read_bytes =
tokio::time::timeout(self.cfg.proxy_protocol_timeout, socket.peek(&mut buffer)).await??;
let len = {
let mut slice = &buffer[..read_bytes];
let header = parse(&mut slice)?;
if let Some((src, _)) = handle_header(header) {
remote_addr = src;
}
read_bytes - slice.len()
};
let _ = socket.read_exact(&mut buffer[..len]).await;
}
Ok(Acceptor {
socket,
remote_addr,
#[cfg(feature = "tls")]
acceptor: self.tls_acceptor.clone(),
cfg: self.cfg.clone(),
typ: self.typ,
})
}
#[cfg(feature = "quic")]
pub async fn accept_quic(&self) -> Result<Acceptor<QuinnBiStream>> {
if let Some(endpoint) = &self.quinn_endpoint {
let incoming =
endpoint.accept().await.ok_or_else(|| anyhow!("No incoming QUIC connection available"))?;
let conn = incoming.await?;
let remote_addr = conn.remote_address();
let (send, recv) = conn.accept_bi().await?;
let socket = QuinnBiStream::new(send, recv);
Ok(Acceptor {
socket,
remote_addr,
#[cfg(feature = "tls")]
acceptor: self.tls_acceptor.clone(),
cfg: self.cfg.clone(),
typ: self.typ,
})
} else {
Err(anyhow!(""))
}
}
pub fn local_addr(&self) -> Result<SocketAddr> {
if let Some(tcp_listener) = &self.tcp_listener {
Ok(tcp_listener.local_addr()?)
} else {
#[cfg(feature = "quic")]
if let Some(endpoint) = &self.quinn_endpoint {
Ok(endpoint.local_addr()?)
} else {
Err(anyhow!("No active listener (neither TCP nor QUIC endpoint is available)"))
}
#[cfg(not(feature = "quic"))]
Err(anyhow!("No active listener"))
}
}
}
/// One accepted connection whose protocol handshake (if any) has not run yet.
///
/// Call the method matching `typ` (`tcp`, `tls`, `ws`, `wss`, `quic`) to obtain a `Dispatcher`.
pub struct Acceptor<S> {
/// The raw accepted stream.
pub(crate) socket: S,
/// TLS acceptor copied from the listener; `None` unless the listener was switched to TLS/WSS.
#[cfg(feature = "tls")]
acceptor: Option<TlsAcceptor>,
/// Peer address, or the PROXY protocol source address when that feature is enabled.
pub remote_addr: SocketAddr,
/// Shared listener configuration.
pub cfg: Arc<Builder>,
/// Transport of the listener this connection came from.
pub typ: ListenerType,
}
impl<S> Acceptor<S>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    /// Hands the raw stream to a [`Dispatcher`] without any further handshake.
    ///
    /// # Errors
    /// Fails unless this acceptor came from a TCP listener.
    #[inline]
    pub fn tcp(self) -> Result<Dispatcher<S>> {
        if matches!(self.typ, ListenerType::TCP) {
            Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
        } else {
            Err(anyhow!("Protocol mismatch: Expected TCP listener"))
        }
    }

    /// Runs the TLS handshake (bounded by `handshake_timeout`) and returns a [`Dispatcher`]
    /// over the TLS stream, with client certificate info attached when configured.
    ///
    /// # Errors
    /// Fails for non-TLS acceptors, a missing TLS acceptor, a handshake error, or a timeout.
    #[cfg(feature = "tls")]
    #[inline]
    pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
        if !matches!(self.typ, ListenerType::TLS) {
            return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
        }
        let tls_s = Self::tls_handshake(self.acceptor, self.socket, self.cfg.handshake_timeout).await?;
        let cert_info = Self::get_extract_cert_info(
            &tls_s,
            self.cfg.cert_cn_as_username,
            self.cfg.cert_subject_dn_as_username,
            self.cfg.collect_cert_info,
        );
        Ok(Dispatcher::new(tls_s, self.remote_addr, cert_info, self.cfg))
    }

    /// Performs the WebSocket upgrade (bounded by `handshake_timeout`) and returns a
    /// [`Dispatcher`] over the WebSocket stream.
    ///
    /// # Errors
    /// Fails for non-WS acceptors, a rejected/failed upgrade, or a timeout.
    #[cfg(feature = "ws")]
    #[inline]
    pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
        if !matches!(self.typ, ListenerType::WS) {
            return Err(anyhow!("Protocol mismatch: Expected WS listener"));
        }
        let timeout = self.cfg.handshake_timeout;
        match tokio::time::timeout(timeout, accept_hdr_async(self.socket, on_handshake)).await {
            Ok(Ok(ws_stream)) => Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, None, self.cfg)),
            Ok(Err(e)) => Err(e.into()),
            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
        }
    }

    /// Runs the TLS handshake, then the WebSocket upgrade on top of it. Each step gets its own
    /// `handshake_timeout`.
    ///
    /// # Errors
    /// Fails for non-WSS acceptors or on any TLS/WebSocket handshake error or timeout.
    #[cfg(feature = "tls")]
    #[cfg(feature = "ws")]
    #[inline]
    pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
        if !matches!(self.typ, ListenerType::WSS) {
            return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
        }
        let timeout = self.cfg.handshake_timeout;
        let tls_s = Self::tls_handshake(self.acceptor, self.socket, timeout).await?;
        let cert_info = Self::get_extract_cert_info(
            &tls_s,
            self.cfg.cert_cn_as_username,
            self.cfg.cert_subject_dn_as_username,
            self.cfg.collect_cert_info,
        );
        match tokio::time::timeout(timeout, accept_hdr_async(tls_s, on_handshake)).await {
            Ok(Ok(ws_stream)) => {
                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, cert_info, self.cfg))
            }
            Ok(Err(e)) => Err(e.into()),
            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
        }
    }

    /// Returns a [`Dispatcher`] over the already-established QUIC bidirectional stream.
    ///
    /// # Errors
    /// Fails unless this acceptor came from a QUIC listener.
    #[cfg(feature = "quic")]
    #[inline]
    pub async fn quic(self) -> Result<Dispatcher<S>> {
        if !matches!(self.typ, ListenerType::QUIC) {
            return Err(anyhow!("Protocol mismatch: Expected QUIC listener"));
        }
        Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
    }

    /// Server-side TLS handshake shared by `tls` and `wss`, bounded by `timeout`.
    ///
    /// # Errors
    /// - `MqttError::ServiceUnavailable` if no TLS acceptor was configured on the listener.
    /// - The handshake's I/O error.
    /// - `MqttError::ReadTimeout` when `timeout` elapses.
    #[cfg(feature = "tls")]
    async fn tls_handshake(
        acceptor: Option<TlsAcceptor>,
        socket: S,
        timeout: Duration,
    ) -> Result<TlsStream<S>> {
        let acceptor = acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
        match tokio::time::timeout(timeout, acceptor.accept(socket)).await {
            Ok(Ok(tls_s)) => Ok(tls_s),
            Ok(Err(e)) => Err(e.into()),
            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
        }
    }

    /// Extracts the peer certificate info from `io` when any of the three flags is set.
    /// Returns `None` without inspecting the stream otherwise.
    #[inline]
    #[cfg(feature = "tls")]
    fn get_extract_cert_info<C: TlsCertExtractor>(
        io: &C,
        cert_cn_as_username: bool,
        cert_subject_dn_as_username: bool,
        collect_cert_info: bool,
    ) -> Option<CertInfo> {
        if !(cert_cn_as_username || cert_subject_dn_as_username || collect_cert_info) {
            return None;
        }
        let cert_info: Option<CertInfo> = io.extract_cert_info();
        if let Some(ref cert) = cert_info {
            log::debug!("Client certificate: {cert}");
            log::debug!("CN: {:?}, Org: {:?}", cert.common_name, cert.organization);
        }
        cert_info
    }
}
/// WebSocket upgrade callback that negotiates the MQTT subprotocol.
///
/// Per RFC 6455 the client may offer several subprotocols. It can list them comma-separated,
/// repeat the `Sec-WebSocket-Protocol` header, or both. The first offered value that is
/// `mqtt` or `mqttv3.1` is selected and echoed exactly once in the response. Values that are
/// not valid visible ASCII are skipped.
///
/// # Errors
/// Rejects the upgrade when no supported subprotocol is offered.
// `ErrorResponse` is an HTTP response and is large by design; boxing would not match the
// callback signature expected by `accept_hdr_async`.
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";
    let selected: &'static str = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .map(str::trim)
        .find_map(|offered| match offered {
            "mqtt" => Some("mqtt"),
            "mqttv3.1" => Some("mqttv3.1"),
            _ => None,
        })
        .ok_or_else(|| ErrorResponse::new(Some(PROTOCOL_ERROR.into())))?;
    // `insert` rather than `append`: the response must carry the header at most once.
    response.headers_mut().insert(
        "Sec-WebSocket-Protocol",
        selected.parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
    );
    Ok(response)
}
fn handle_header(header: ProxyHeader) -> Option<(SocketAddr, SocketAddr)> {
use ProxyHeader::{Version1, Version2};
match header {
Version1 { addresses } => handle_header_v1(addresses),
Version2 { command, transport_protocol, addresses } => {
handle_header_v2(command, transport_protocol, addresses)
}
_ => {
log::info!("[tcp]accept proxy-protocol-v?");
None
}
}
}
fn handle_header_v1(addr: v1::ProxyAddresses) -> Option<(SocketAddr, SocketAddr)> {
use v1::ProxyAddresses::*;
match addr {
Unknown => {
log::debug!("[tcp]accept proxy-protocol-v1: unknown");
None
}
Ipv4 { source, destination } => {
log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
}
Ipv6 { source, destination } => {
log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
}
}
}
fn handle_header_v2(
cmd: v2::ProxyCommand,
proto: v2::ProxyTransportProtocol,
addr: v2::ProxyAddresses,
) -> Option<(SocketAddr, SocketAddr)> {
use v2::ProxyAddresses as Address;
use v2::ProxyCommand as Command;
use v2::ProxyTransportProtocol as Protocol;
if let Command::Local = cmd {
log::debug!("[tcp]accept proxy-protocol-v2: command = LOCAL, ignore");
return None;
}
match proto {
Protocol::Stream => {}
Protocol::Unspec => {
log::debug!("[tcp]accept proxy-protocol-v2: protocol = UNSPEC, ignore");
return None;
}
Protocol::Datagram => {
log::debug!("[tcp]accept proxy-protocol-v2: protocol = DGRAM, ignore");
return None;
}
}
match addr {
Address::Ipv4 { source, destination } => {
log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
}
Address::Ipv6 { source, destination } => {
log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
}
Address::Unspec => {
log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNSPEC, ignore");
None
}
Address::Unix { .. } => {
log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNIX, ignore");
None
}
}
}
/// Reads every PEM `CERTIFICATE` block from `cert_file`.
///
/// It is used both for the server certificate chain and for client CA roots.
///
/// # Errors
/// Fails, naming the file, if it cannot be opened or any certificate block fails to parse.
#[cfg(feature = "tls")]
fn read_ca_certs(cert_file: &str) -> Result<Vec<CertificateDer<'static>>> {
    CertificateDer::pem_file_iter(cert_file)
        .map_err(|e| anyhow!("Failed to open PEM file {cert_file}: {e}"))?
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| anyhow!("Failed to parse PEM certificates in {cert_file}: {e}"))
}