rmqtt_net/
builder.rs

1//! # MQTT Server Implementation
2//!
3//! ## Overall Example
4//!
5//! ```rust,no_run
6//! use std::net::{Ipv4Addr, SocketAddr};
7//! use std::time::Duration;
8//!
9//! #[tokio::main]
10//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
11//!     // Create server configuration
12//!     let builder = rmqtt_net::Builder::new()
13//!         .name("MyMQTTBroker")
14//!         .laddr(SocketAddr::from((Ipv4Addr::LOCALHOST, 1883)))
15//!         .max_connections(5000);
16//!
17//!     // Bind TCP listener
18//!     let listener = builder.bind()?;
19//!
20//!     // Accept and handle connections
21//!     loop {
22//!         let acceptor = listener.accept().await?;
23//!         tokio::spawn(async move {
24//!             let dispatcher = acceptor.tcp().unwrap();
25//!             // Handle MQTT protocol...
26//!         });
27//!     }
28//!     Ok(())
29//! }
30//! ```
31
32use proxy_protocol::parse;
33use proxy_protocol::ProxyHeader;
34use proxy_protocol::{version1 as v1, version2 as v2};
35use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
36use std::num::{NonZeroU16, NonZeroU32};
37use std::sync::Arc;
38use std::time::Duration;
39
40use anyhow::anyhow;
41use nonzero_ext::nonzero;
42use rmqtt_codec::types::QoS;
43#[cfg(not(target_os = "windows"))]
44#[cfg(feature = "tls")]
45use rustls::crypto::aws_lc_rs as provider;
46#[cfg(feature = "tls")]
47#[cfg(target_os = "windows")]
48use rustls::crypto::ring as provider;
49#[cfg(feature = "tls")]
50use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
51use socket2::{Domain, SockAddr, Socket, Type};
52use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
53use tokio::net::{TcpListener, TcpStream};
54#[cfg(feature = "tls")]
55use tokio_rustls::{server::TlsStream, TlsAcceptor};
56#[cfg(feature = "ws")]
57use tokio_tungstenite::{
58    accept_hdr_async,
59    tungstenite::handshake::server::{ErrorResponse, Request, Response},
60};
61
62use crate::stream::Dispatcher;
63#[cfg(feature = "ws")]
64use crate::ws::WsStream;
65use crate::{Error, Result};
66
67/// Configuration builder for MQTT server instances
68#[derive(Clone, Debug)]
69pub struct Builder {
70    /// Server identifier for logging and monitoring
71    pub name: String,
72    /// Network address to listen on
73    pub laddr: SocketAddr,
74    /// Maximum number of pending connections in the accept queue
75    pub backlog: i32,
76    /// Enable TCP_NODELAY option for lower latency
77    pub nodelay: bool,
78    /// Set SO_REUSEADDR socket option
79    pub reuseaddr: Option<bool>,
80    /// Set SO_REUSEPORT socket option
81    pub reuseport: Option<bool>,
82    /// Maximum concurrent active connections
83    pub max_connections: usize,
84    /// Maximum simultaneous handshakes during connection setup
85    pub max_handshaking_limit: usize,
86    /// Maximum allowed MQTT packet size in bytes (0 = unlimited)
87    pub max_packet_size: u32,
88
89    /// Allow unauthenticated client connections
90    pub allow_anonymous: bool,
91    /// Minimum acceptable keepalive value in seconds
92    pub min_keepalive: u16,
93    /// Maximum acceptable keepalive value in seconds
94    pub max_keepalive: u16,
95    /// Allow clients to disable keepalive mechanism
96    pub allow_zero_keepalive: bool,
97    /// Multiplier for calculating actual keepalive timeout
98    pub keepalive_backoff: f32,
99    /// Window size for unacknowledged QoS 1/2 messages
100    pub max_inflight: NonZeroU16,
101    /// Timeout for completing connection handshake
102    pub handshake_timeout: Duration,
103    /// Network I/O timeout for sending operations
104    pub send_timeout: Duration,
105    /// Maximum messages queued per client
106    pub max_mqueue_len: usize,
107    /// Rate limiting for message delivery (messages per duration)
108    pub mqueue_rate_limit: (NonZeroU32, Duration),
109    /// Maximum length of client identifiers
110    pub max_clientid_len: usize,
111    /// Highest QoS level permitted for publishing
112    pub max_qos_allowed: QoS,
113    /// Maximum depth for topic hierarchy (0 = unlimited)
114    pub max_topic_levels: usize,
115    /// Duration before inactive sessions expire
116    pub session_expiry_interval: Duration,
117    /// The upper limit for how long a session can remain valid before it must expire,
118    /// regardless of the client's requested session expiry interval. (0 = unlimited)
119    pub max_session_expiry_interval: Duration,
120    /// Retry interval for unacknowledged messages
121    pub message_retry_interval: Duration,
122    /// Time-to-live for undelivered messages
123    pub message_expiry_interval: Duration,
124    /// Maximum subscriptions per client (0 = unlimited)
125    pub max_subscriptions: usize,
126    /// Enable shared subscription support
127    pub shared_subscription: bool,
128    /// Maximum topic aliases (MQTTv5 feature)
129    pub max_topic_aliases: u16,
130    /// Enable subscription count limiting
131    pub limit_subscription: bool,
132    /// Enable future-dated message publishing
133    pub delayed_publish: bool,
134
135    /// Enable mutual TLS authentication
136    pub tls_cross_certificate: bool,
137    /// Path to TLS certificate chain
138    pub tls_cert: Option<String>,
139    /// Path to TLS private key
140    pub tls_key: Option<String>,
141    /// Enable Proxy Protocol
142    pub proxy_protocol: bool,
143    /// Proxy Protocol timeout
144    pub proxy_protocol_timeout: Duration,
145}
146
147impl Default for Builder {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153/// # Examples
154/// ```
155/// use std::net::SocketAddr;
156/// use rmqtt_net::Builder;
157///
158/// let builder = Builder::new()
159///     .name("EdgeBroker")
160///     .laddr("127.0.0.1:1883".parse().unwrap())
161///     .max_connections(10_000);
162/// ```
163impl Builder {
164    /// Creates a new builder with default configuration values
165    pub fn new() -> Builder {
166        Builder {
167            name: Default::default(),
168            laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
169            max_connections: 1_000_000,
170            max_handshaking_limit: 1_000,
171            max_packet_size: 1024 * 1024,
172            backlog: 512,
173            nodelay: false,
174            reuseaddr: None,
175            reuseport: None,
176
177            allow_anonymous: true,
178            min_keepalive: 0,
179            max_keepalive: 65535,
180            allow_zero_keepalive: true,
181            keepalive_backoff: 0.75,
182            max_inflight: nonzero!(16u16),
183            handshake_timeout: Duration::from_secs(30),
184            send_timeout: Duration::from_secs(10),
185            max_mqueue_len: 1000,
186
187            mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
188            max_clientid_len: 65535,
189            max_qos_allowed: QoS::ExactlyOnce,
190            max_topic_levels: 0,
191            session_expiry_interval: Duration::from_secs(2 * 60 * 60),
192            max_session_expiry_interval: Duration::ZERO,
193            message_retry_interval: Duration::from_secs(20),
194            message_expiry_interval: Duration::from_secs(5 * 60),
195            max_subscriptions: 0,
196            shared_subscription: true,
197            max_topic_aliases: 0,
198
199            limit_subscription: false,
200            delayed_publish: false,
201
202            tls_cross_certificate: false,
203            tls_cert: None,
204            tls_key: None,
205            proxy_protocol: false,
206            proxy_protocol_timeout: Duration::from_secs(5),
207        }
208    }
209
210    /// Sets the server name identifier
211    pub fn name<N: Into<String>>(mut self, name: N) -> Self {
212        self.name = name.into();
213        self
214    }
215
216    /// Configures the network listen address
217    pub fn laddr(mut self, laddr: SocketAddr) -> Self {
218        self.laddr = laddr;
219        self
220    }
221
222    /// Sets the TCP backlog size
223    pub fn backlog(mut self, backlog: i32) -> Self {
224        self.backlog = backlog;
225        self
226    }
227
228    /// Enables/disables TCP_NODELAY option
229    pub fn nodelay(mut self, nodelay: bool) -> Self {
230        self.nodelay = nodelay;
231        self
232    }
233
234    /// Configures SO_REUSEADDR socket option
235    pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
236        self.reuseaddr = reuseaddr;
237        self
238    }
239
240    /// Configures SO_REUSEPORT socket option
241    pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
242        self.reuseport = reuseport;
243        self
244    }
245
246    /// Sets maximum concurrent connections
247    pub fn max_connections(mut self, max_connections: usize) -> Self {
248        self.max_connections = max_connections;
249        self
250    }
251
252    /// Sets maximum concurrent handshakes
253    pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
254        self.max_handshaking_limit = max_handshaking_limit;
255        self
256    }
257
258    /// Configures maximum MQTT packet size
259    pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
260        self.max_packet_size = max_packet_size;
261        self
262    }
263
264    /// Enables anonymous client access
265    pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
266        self.allow_anonymous = allow_anonymous;
267        self
268    }
269
270    /// Sets minimum acceptable keepalive value
271    pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
272        self.min_keepalive = min_keepalive;
273        self
274    }
275
276    /// Sets maximum acceptable keepalive value
277    pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
278        self.max_keepalive = max_keepalive;
279        self
280    }
281
282    /// Allows clients to disable keepalive
283    pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
284        self.allow_zero_keepalive = allow_zero_keepalive;
285        self
286    }
287
288    /// Configures keepalive backoff multiplier
289    pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
290        self.keepalive_backoff = keepalive_backoff;
291        self
292    }
293
294    /// Sets inflight message window size
295    pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
296        self.max_inflight = max_inflight;
297        self
298    }
299
300    /// Configures handshake timeout duration
301    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
302        self.handshake_timeout = handshake_timeout;
303        self
304    }
305
306    /// Sets network send timeout duration
307    pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
308        self.send_timeout = send_timeout;
309        self
310    }
311
312    /// Configures maximum message queue length
313    pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
314        self.max_mqueue_len = max_mqueue_len;
315        self
316    }
317
318    /// Sets message rate limiting parameters
319    pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
320        self.mqueue_rate_limit = (rate_limit, duration);
321        self
322    }
323
324    /// Sets maximum client ID length
325    pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
326        self.max_clientid_len = max_clientid_len;
327        self
328    }
329
330    /// Configures maximum allowed QoS level
331    pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
332        self.max_qos_allowed = max_qos_allowed;
333        self
334    }
335
336    /// Sets maximum topic hierarchy depth
337    pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
338        self.max_topic_levels = max_topic_levels;
339        self
340    }
341
342    /// Configures session expiration interval
343    pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
344        self.session_expiry_interval = session_expiry_interval;
345        self
346    }
347
348    /// Configures max session expiration interval
349    pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
350        self.max_session_expiry_interval = max_session_expiry_interval;
351        self
352    }
353
354    /// Sets message retry interval for QoS 1/2
355    pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
356        self.message_retry_interval = message_retry_interval;
357        self
358    }
359
360    /// Configures message expiration time
361    pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
362        self.message_expiry_interval = message_expiry_interval;
363        self
364    }
365
366    /// Sets maximum subscriptions per client
367    pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
368        self.max_subscriptions = max_subscriptions;
369        self
370    }
371
372    /// Enables shared subscription support
373    pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
374        self.shared_subscription = shared_subscription;
375        self
376    }
377
378    /// Configures maximum topic aliases (MQTTv5)
379    pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
380        self.max_topic_aliases = max_topic_aliases;
381        self
382    }
383
384    /// Enables subscription count limiting
385    pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
386        self.limit_subscription = limit_subscription;
387        self
388    }
389
390    /// Enables delayed message publishing
391    pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
392        self.delayed_publish = delayed_publish;
393        self
394    }
395
396    /// Enables mutual TLS authentication
397    pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
398        self.tls_cross_certificate = cross_certificate;
399        self
400    }
401
402    /// Sets path to TLS certificate chain
403    pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
404        self.tls_cert = tls_cert.map(|c| c.into());
405        self
406    }
407
408    /// Sets path to TLS private key
409    pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
410        self.tls_key = tls_key.map(|c| c.into());
411        self
412    }
413
414    /// Enable proxy protocol parse
415    pub fn proxy_protocol(mut self, enable_protocol_proxy: bool) -> Self {
416        self.proxy_protocol = enable_protocol_proxy;
417        self
418    }
419
420    /// Sets proxy protocol timeout
421    pub fn proxy_protocol_timeout(mut self, proxy_protocol_timeout: Duration) -> Self {
422        self.proxy_protocol_timeout = proxy_protocol_timeout;
423        self
424    }
425
426    /// Binds the server to the configured address
427    #[allow(unused_variables)]
428    pub fn bind(self) -> Result<Listener> {
429        let builder = match self.laddr {
430            SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
431            SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
432        };
433
434        builder.set_linger(Some(Duration::from_secs(10)))?;
435
436        builder.set_nonblocking(true)?;
437
438        if let Some(reuseaddr) = self.reuseaddr {
439            builder.set_reuse_address(reuseaddr)?;
440        }
441
442        #[cfg(not(windows))]
443        if let Some(reuseport) = self.reuseport {
444            builder.set_reuse_port(reuseport)?;
445        }
446
447        builder.bind(&SockAddr::from(self.laddr))?;
448        builder.listen(self.backlog)?;
449        let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
450
451        log::info!(
452            "MQTT Broker Listening on {} {}",
453            self.name,
454            tcp_listener.local_addr().unwrap_or(self.laddr)
455        );
456        Ok(Listener {
457            typ: ListenerType::TCP,
458            cfg: Arc::new(self),
459            tcp_listener,
460            #[cfg(feature = "tls")]
461            tls_acceptor: None,
462        })
463    }
464}
465
/// Transport protocol spoken by a listener and the connections it accepts.
///
/// `TCP` always exists; the other variants are compiled in only when the
/// `tls` and/or `ws` cargo features are enabled.
#[derive(Clone, Copy, Debug)]
pub enum ListenerType {
    /// Raw MQTT over TCP.
    TCP,
    /// MQTT over TLS (requires the `tls` feature).
    #[cfg(feature = "tls")]
    TLS,
    /// MQTT over WebSocket (requires the `ws` feature).
    #[cfg(feature = "ws")]
    WS,
    /// MQTT over WebSocket over TLS (requires both `tls` and `ws`).
    #[cfg(all(feature = "tls", feature = "ws"))]
    WSS,
}
482
483/// Network listener for accepting client connections
484pub struct Listener {
485    /// Active listener protocol type
486    pub typ: ListenerType,
487    /// Shared server configuration
488    pub cfg: Arc<Builder>,
489    tcp_listener: TcpListener,
490    #[cfg(feature = "tls")]
491    tls_acceptor: Option<TlsAcceptor>,
492}
493
494/// # Examples
495/// ```
496/// # use rmqtt_net::{Builder, Listener};
497/// # fn setup() -> Result<(), Box<dyn std::error::Error>> {
498/// let builder = Builder::new();
499/// let listener = builder.bind()?;
500/// # Ok(())
501/// # }
502/// ```
503impl Listener {
504    /// Converts listener to plain TCP mode
505    pub fn tcp(mut self) -> Result<Self> {
506        let _err = anyhow!("Protocol downgrade from TLS/WS/WSS to TCP is not permitted");
507        #[cfg(feature = "tls")]
508        if matches!(self.typ, ListenerType::TLS) {
509            return Err(_err);
510        }
511        #[cfg(feature = "tls")]
512        #[cfg(feature = "ws")]
513        if matches!(self.typ, ListenerType::WSS) {
514            return Err(_err);
515        }
516        #[cfg(feature = "ws")]
517        if matches!(self.typ, ListenerType::WS) {
518            return Err(_err);
519        }
520        self.typ = ListenerType::TCP;
521        Ok(self)
522    }
523
524    #[cfg(feature = "ws")]
525    /// Upgrades listener to WebSocket protocol
526    pub fn ws(mut self) -> Result<Self> {
527        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
528            self.typ = ListenerType::WS;
529        } else {
530            return Err(anyhow!("Protocol upgrade from TLS/WSS to WS is not permitted"));
531        }
532        Ok(self)
533    }
534
535    #[cfg(feature = "tls")]
536    #[cfg(feature = "ws")]
537    /// Upgrades listener to secure WebSocket (WSS)
538    pub fn wss(mut self) -> Result<Self> {
539        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
540            self = self.tls()?;
541        }
542        self.typ = ListenerType::WSS;
543        Ok(self)
544    }
545
546    #[cfg(feature = "tls")]
547    /// Upgrades listener to TLS-secured TCP
548    pub fn tls(mut self) -> Result<Listener> {
549        match self.typ {
550            #[cfg(feature = "ws")]
551            ListenerType::WS | ListenerType::WSS => {
552                return Err(anyhow!("Protocol downgrade from WS/WSS to TLS is not permitted"));
553            }
554            ListenerType::TLS => return Ok(self),
555            ListenerType::TCP => {}
556        }
557
558        let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
559        let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
560
561        let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
562            .map_err(|e| anyhow!(e))?
563            .collect::<std::result::Result<Vec<_>, _>>()
564            .map_err(|e| anyhow!(e))?;
565        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
566
567        let provider = Arc::new(provider::default_provider());
568        let client_auth = if self.cfg.tls_cross_certificate {
569            let root_chain = cert_chain.clone();
570            let mut client_auth_roots = RootCertStore::empty();
571            for root in root_chain {
572                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
573            }
574            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
575                .build()
576                .map_err(|e| anyhow!(e))?
577        } else {
578            WebPkiClientVerifier::no_client_auth()
579        };
580
581        let tls_config = ServerConfig::builder_with_provider(provider)
582            .with_safe_default_protocol_versions()
583            .map_err(|e| anyhow!(e))?
584            .with_client_cert_verifier(client_auth)
585            .with_single_cert(cert_chain, key)
586            .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
587
588        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
589        self.tls_acceptor = Some(acceptor);
590        self.typ = ListenerType::TLS;
591        Ok(self)
592    }
593
594    /// Accepts incoming client connections
595    pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
596        let (mut socket, mut remote_addr) = self.tcp_listener.accept().await?;
597        if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
598            return Err(Error::from(e));
599        }
600        if self.cfg.proxy_protocol {
601            let mut buffer = [0u8; u16::MAX as usize];
602            let read_bytes =
603                tokio::time::timeout(self.cfg.proxy_protocol_timeout, socket.peek(&mut buffer)).await??;
604            let len = {
605                let mut slice = &buffer[..read_bytes];
606                let header = parse(&mut slice)?;
607                if let Some((src, _)) = handle_header(header) {
608                    remote_addr = src;
609                }
610                read_bytes - slice.len()
611            };
612            // skip proxy protocol data
613            let _ = socket.read_exact(&mut buffer[..len]).await;
614        }
615        Ok(Acceptor {
616            socket,
617            remote_addr,
618            #[cfg(feature = "tls")]
619            acceptor: self.tls_acceptor.clone(),
620            cfg: self.cfg.clone(),
621            typ: self.typ,
622        })
623    }
624
625    pub fn local_addr(&self) -> Result<SocketAddr> {
626        Ok(self.tcp_listener.local_addr()?)
627    }
628}
629
630/// Connection handler for processing client streams
631pub struct Acceptor<S> {
632    /// Underlying network transport
633    pub(crate) socket: S,
634    #[cfg(feature = "tls")]
635    acceptor: Option<TlsAcceptor>,
636    /// Remote client address
637    pub remote_addr: SocketAddr,
638    /// Shared server configuration
639    pub cfg: Arc<Builder>,
640    /// Active protocol type
641    pub typ: ListenerType,
642}
643
644impl<S> Acceptor<S>
645where
646    S: AsyncRead + AsyncWrite + Unpin,
647{
648    /// Creates TCP protocol dispatcher
649    #[inline]
650    pub fn tcp(self) -> Result<Dispatcher<S>> {
651        if matches!(self.typ, ListenerType::TCP) {
652            Ok(Dispatcher::new(self.socket, self.remote_addr, self.cfg))
653        } else {
654            Err(anyhow!("Protocol mismatch: Expected TCP listener"))
655        }
656    }
657
658    #[cfg(feature = "tls")]
659    /// Performs TLS handshake and creates secure dispatcher
660    #[inline]
661    pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
662        if !matches!(self.typ, ListenerType::TLS) {
663            return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
664        }
665
666        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
667        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
668        {
669            Ok(Ok(tls_s)) => tls_s,
670            Ok(Err(e)) => return Err(e.into()),
671            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
672        };
673        Ok(Dispatcher::new(tls_s, self.remote_addr, self.cfg))
674    }
675
676    #[cfg(feature = "ws")]
677    /// Performs WebSocket upgrade and creates WS dispatcher
678    #[inline]
679    pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
680        if !matches!(self.typ, ListenerType::WS) {
681            return Err(anyhow!("Protocol mismatch: Expected WS listener"));
682        }
683
684        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
685            .await
686        {
687            Ok(Ok(ws_stream)) => {
688                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, self.cfg.clone()))
689            }
690            Ok(Err(e)) => Err(e.into()),
691            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
692        }
693    }
694
695    #[cfg(feature = "tls")]
696    #[cfg(feature = "ws")]
697    /// Performs TLS handshake and WebSocket upgrade
698    #[inline]
699    pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
700        if !matches!(self.typ, ListenerType::WSS) {
701            return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
702        }
703
704        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
705        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
706        {
707            Ok(Ok(tls_s)) => tls_s,
708            Ok(Err(e)) => return Err(e.into()),
709            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
710        };
711        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
712            Ok(Ok(ws_stream)) => {
713                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, self.cfg.clone()))
714            }
715            Ok(Err(e)) => Err(e.into()),
716            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
717        }
718    }
719}
720
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
/// Validates a WebSocket upgrade request for MQTT and selects the `mqtt` subprotocol.
///
/// RFC 6455 lets a client offer several subprotocols: as a comma-separated list in
/// one `Sec-WebSocket-Protocol` header (e.g. `mqtt, mqttv3.1`), or across several
/// such headers. The upgrade is accepted when any offered entry is exactly `mqtt`,
/// and the response then carries `Sec-WebSocket-Protocol: mqtt`.
///
/// # Errors
/// Returns an error response when no offered subprotocol is `mqtt`, which includes
/// a missing header and values that are not visible ASCII.
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";
    let offers_mqtt = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|protocol| protocol.trim() == "mqtt");
    if !offers_mqtt {
        return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
    }
    response.headers_mut().append(
        "Sec-WebSocket-Protocol",
        "mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
    );
    Ok(response)
}
739
740// from https://github.com/zhboner/realm/blob/master/realm_core/src/tcp/proxy.rs
741fn handle_header(header: ProxyHeader) -> Option<(SocketAddr, SocketAddr)> {
742    use ProxyHeader::{Version1, Version2};
743    match header {
744        Version1 { addresses } => handle_header_v1(addresses),
745        Version2 { command, transport_protocol, addresses } => {
746            handle_header_v2(command, transport_protocol, addresses)
747        }
748        _ => {
749            log::info!("[tcp]accept proxy-protocol-v?");
750            None
751        }
752    }
753}
754
755fn handle_header_v1(addr: v1::ProxyAddresses) -> Option<(SocketAddr, SocketAddr)> {
756    use v1::ProxyAddresses::*;
757    match addr {
758        Unknown => {
759            log::debug!("[tcp]accept proxy-protocol-v1: unknown");
760            None
761        }
762        Ipv4 { source, destination } => {
763            log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
764            Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
765        }
766        Ipv6 { source, destination } => {
767            log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
768            Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
769        }
770    }
771}
772
773fn handle_header_v2(
774    cmd: v2::ProxyCommand,
775    proto: v2::ProxyTransportProtocol,
776    addr: v2::ProxyAddresses,
777) -> Option<(SocketAddr, SocketAddr)> {
778    use v2::ProxyAddresses as Address;
779    use v2::ProxyCommand as Command;
780    use v2::ProxyTransportProtocol as Protocol;
781
782    // The connection endpoints are the sender and the receiver.
783    // Such connections exist when the proxy sends health-checks to the server.
784    // The receiver must accept this connection as valid and must use the
785    // real connection endpoints and discard the protocol block including the
786    // family which is ignored
787    if let Command::Local = cmd {
788        log::debug!("[tcp]accept proxy-protocol-v2: command = LOCAL, ignore");
789        return None;
790    }
791
792    // only get tcp address
793    match proto {
794        Protocol::Stream => {}
795        Protocol::Unspec => {
796            log::debug!("[tcp]accept proxy-protocol-v2: protocol = UNSPEC, ignore");
797            return None;
798        }
799        Protocol::Datagram => {
800            log::debug!("[tcp]accept proxy-protocol-v2: protocol = DGRAM, ignore");
801            return None;
802        }
803    }
804
805    match addr {
806        Address::Ipv4 { source, destination } => {
807            log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
808            Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
809        }
810        Address::Ipv6 { source, destination } => {
811            log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
812            Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
813        }
814        Address::Unspec => {
815            log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNSPEC, ignore");
816            None
817        }
818        Address::Unix { .. } => {
819            log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNIX, ignore");
820            None
821        }
822    }
823}