rmqtt_net/
builder.rs

1//! # MQTT Server Implementation
2//!
3//! ## Overall Example
4//!
5//! ```rust,no_run
6//! use std::net::{Ipv4Addr, SocketAddr};
7//! use std::time::Duration;
8//!
9//! #[tokio::main]
10//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
11//!     // Create server configuration
12//!     let builder = rmqtt_net::Builder::new()
13//!         .name("MyMQTTBroker")
14//!         .laddr(SocketAddr::from((Ipv4Addr::LOCALHOST, 1883)))
15//!         .max_connections(5000);
16//!
17//!     // Bind TCP listener
18//!     let listener = builder.bind()?;
19//!
20//!     // Accept and handle connections
21//!     loop {
22//!         let acceptor = listener.accept().await?;
23//!         tokio::spawn(async move {
24//!             let dispatcher = acceptor.tcp().unwrap();
25//!             // Handle MQTT protocol...
26//!         });
27//!     }
28//!     Ok(())
29//! }
30//! ```
31
32use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
33use std::num::{NonZeroU16, NonZeroU32};
34use std::sync::Arc;
35use std::time::Duration;
36
37use anyhow::anyhow;
38use nonzero_ext::nonzero;
39use rmqtt_codec::types::QoS;
40#[cfg(not(target_os = "windows"))]
41#[cfg(feature = "tls")]
42use rustls::crypto::aws_lc_rs as provider;
43#[cfg(feature = "tls")]
44#[cfg(target_os = "windows")]
45use rustls::crypto::ring as provider;
46#[cfg(feature = "tls")]
47use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
48use socket2::{Domain, SockAddr, Socket, Type};
49use tokio::io::{AsyncRead, AsyncWrite};
50use tokio::net::{TcpListener, TcpStream};
51#[cfg(feature = "tls")]
52use tokio_rustls::{server::TlsStream, TlsAcceptor};
53#[cfg(feature = "ws")]
54use tokio_tungstenite::{
55    accept_hdr_async,
56    tungstenite::handshake::server::{ErrorResponse, Request, Response},
57};
58
59use crate::stream::Dispatcher;
60#[cfg(feature = "ws")]
61use crate::ws::WsStream;
62use crate::{Error, Result};
63
64/// Configuration builder for MQTT server instances
65#[derive(Clone, Debug)]
66pub struct Builder {
67    /// Server identifier for logging and monitoring
68    pub name: String,
69    /// Network address to listen on
70    pub laddr: SocketAddr,
71    /// Maximum number of pending connections in the accept queue
72    pub backlog: i32,
73    /// Enable TCP_NODELAY option for lower latency
74    pub nodelay: bool,
75    /// Set SO_REUSEADDR socket option
76    pub reuseaddr: Option<bool>,
77    /// Set SO_REUSEPORT socket option
78    pub reuseport: Option<bool>,
79    /// Maximum concurrent active connections
80    pub max_connections: usize,
81    /// Maximum simultaneous handshakes during connection setup
82    pub max_handshaking_limit: usize,
83    /// Maximum allowed MQTT packet size in bytes (0 = unlimited)
84    pub max_packet_size: u32,
85
86    /// Allow unauthenticated client connections
87    pub allow_anonymous: bool,
88    /// Minimum acceptable keepalive value in seconds
89    pub min_keepalive: u16,
90    /// Maximum acceptable keepalive value in seconds
91    pub max_keepalive: u16,
92    /// Allow clients to disable keepalive mechanism
93    pub allow_zero_keepalive: bool,
94    /// Multiplier for calculating actual keepalive timeout
95    pub keepalive_backoff: f32,
96    /// Window size for unacknowledged QoS 1/2 messages
97    pub max_inflight: NonZeroU16,
98    /// Timeout for completing connection handshake
99    pub handshake_timeout: Duration,
100    /// Network I/O timeout for sending operations
101    pub send_timeout: Duration,
102    /// Maximum messages queued per client
103    pub max_mqueue_len: usize,
104    /// Rate limiting for message delivery (messages per duration)
105    pub mqueue_rate_limit: (NonZeroU32, Duration),
106    /// Maximum length of client identifiers
107    pub max_clientid_len: usize,
108    /// Highest QoS level permitted for publishing
109    pub max_qos_allowed: QoS,
110    /// Maximum depth for topic hierarchy (0 = unlimited)
111    pub max_topic_levels: usize,
112    /// Duration before inactive sessions expire
113    pub session_expiry_interval: Duration,
114    /// The upper limit for how long a session can remain valid before it must expire,
115    /// regardless of the client's requested session expiry interval. (0 = unlimited)
116    pub max_session_expiry_interval: Duration,
117    /// Retry interval for unacknowledged messages
118    pub message_retry_interval: Duration,
119    /// Time-to-live for undelivered messages
120    pub message_expiry_interval: Duration,
121    /// Maximum subscriptions per client (0 = unlimited)
122    pub max_subscriptions: usize,
123    /// Enable shared subscription support
124    pub shared_subscription: bool,
125    /// Maximum topic aliases (MQTTv5 feature)
126    pub max_topic_aliases: u16,
127    /// Enable subscription count limiting
128    pub limit_subscription: bool,
129    /// Enable future-dated message publishing
130    pub delayed_publish: bool,
131
132    /// Enable mutual TLS authentication
133    pub tls_cross_certificate: bool,
134    /// Path to TLS certificate chain
135    pub tls_cert: Option<String>,
136    /// Path to TLS private key
137    pub tls_key: Option<String>,
138}
139
140impl Default for Builder {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146/// # Examples
147/// ```
148/// use std::net::SocketAddr;
149/// use rmqtt_net::Builder;
150///
151/// let builder = Builder::new()
152///     .name("EdgeBroker")
153///     .laddr("127.0.0.1:1883".parse().unwrap())
154///     .max_connections(10_000);
155/// ```
156impl Builder {
157    /// Creates a new builder with default configuration values
158    pub fn new() -> Builder {
159        Builder {
160            name: Default::default(),
161            laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
162            max_connections: 1_000_000,
163            max_handshaking_limit: 1_000,
164            max_packet_size: 1024 * 1024,
165            backlog: 512,
166            nodelay: false,
167            reuseaddr: None,
168            reuseport: None,
169
170            allow_anonymous: true,
171            min_keepalive: 0,
172            max_keepalive: 65535,
173            allow_zero_keepalive: true,
174            keepalive_backoff: 0.75,
175            max_inflight: nonzero!(16u16),
176            handshake_timeout: Duration::from_secs(30),
177            send_timeout: Duration::from_secs(10),
178            max_mqueue_len: 1000,
179
180            mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
181            max_clientid_len: 65535,
182            max_qos_allowed: QoS::ExactlyOnce,
183            max_topic_levels: 0,
184            session_expiry_interval: Duration::from_secs(2 * 60 * 60),
185            max_session_expiry_interval: Duration::ZERO,
186            message_retry_interval: Duration::from_secs(20),
187            message_expiry_interval: Duration::from_secs(5 * 60),
188            max_subscriptions: 0,
189            shared_subscription: true,
190            max_topic_aliases: 0,
191
192            limit_subscription: false,
193            delayed_publish: false,
194
195            tls_cross_certificate: false,
196            tls_cert: None,
197            tls_key: None,
198        }
199    }
200
201    /// Sets the server name identifier
202    pub fn name<N: Into<String>>(mut self, name: N) -> Self {
203        self.name = name.into();
204        self
205    }
206
207    /// Configures the network listen address
208    pub fn laddr(mut self, laddr: SocketAddr) -> Self {
209        self.laddr = laddr;
210        self
211    }
212
213    /// Sets the TCP backlog size
214    pub fn backlog(mut self, backlog: i32) -> Self {
215        self.backlog = backlog;
216        self
217    }
218
219    /// Enables/disables TCP_NODELAY option
220    pub fn nodelay(mut self, nodelay: bool) -> Self {
221        self.nodelay = nodelay;
222        self
223    }
224
225    /// Configures SO_REUSEADDR socket option
226    pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
227        self.reuseaddr = reuseaddr;
228        self
229    }
230
231    /// Configures SO_REUSEPORT socket option
232    pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
233        self.reuseport = reuseport;
234        self
235    }
236
237    /// Sets maximum concurrent connections
238    pub fn max_connections(mut self, max_connections: usize) -> Self {
239        self.max_connections = max_connections;
240        self
241    }
242
243    /// Sets maximum concurrent handshakes
244    pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
245        self.max_handshaking_limit = max_handshaking_limit;
246        self
247    }
248
249    /// Configures maximum MQTT packet size
250    pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
251        self.max_packet_size = max_packet_size;
252        self
253    }
254
255    /// Enables anonymous client access
256    pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
257        self.allow_anonymous = allow_anonymous;
258        self
259    }
260
261    /// Sets minimum acceptable keepalive value
262    pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
263        self.min_keepalive = min_keepalive;
264        self
265    }
266
267    /// Sets maximum acceptable keepalive value
268    pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
269        self.max_keepalive = max_keepalive;
270        self
271    }
272
273    /// Allows clients to disable keepalive
274    pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
275        self.allow_zero_keepalive = allow_zero_keepalive;
276        self
277    }
278
279    /// Configures keepalive backoff multiplier
280    pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
281        self.keepalive_backoff = keepalive_backoff;
282        self
283    }
284
285    /// Sets inflight message window size
286    pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
287        self.max_inflight = max_inflight;
288        self
289    }
290
291    /// Configures handshake timeout duration
292    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
293        self.handshake_timeout = handshake_timeout;
294        self
295    }
296
297    /// Sets network send timeout duration
298    pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
299        self.send_timeout = send_timeout;
300        self
301    }
302
303    /// Configures maximum message queue length
304    pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
305        self.max_mqueue_len = max_mqueue_len;
306        self
307    }
308
309    /// Sets message rate limiting parameters
310    pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
311        self.mqueue_rate_limit = (rate_limit, duration);
312        self
313    }
314
315    /// Sets maximum client ID length
316    pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
317        self.max_clientid_len = max_clientid_len;
318        self
319    }
320
321    /// Configures maximum allowed QoS level
322    pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
323        self.max_qos_allowed = max_qos_allowed;
324        self
325    }
326
327    /// Sets maximum topic hierarchy depth
328    pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
329        self.max_topic_levels = max_topic_levels;
330        self
331    }
332
333    /// Configures session expiration interval
334    pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
335        self.session_expiry_interval = session_expiry_interval;
336        self
337    }
338
339    /// Configures max session expiration interval
340    pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
341        self.max_session_expiry_interval = max_session_expiry_interval;
342        self
343    }
344
345    /// Sets message retry interval for QoS 1/2
346    pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
347        self.message_retry_interval = message_retry_interval;
348        self
349    }
350
351    /// Configures message expiration time
352    pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
353        self.message_expiry_interval = message_expiry_interval;
354        self
355    }
356
357    /// Sets maximum subscriptions per client
358    pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
359        self.max_subscriptions = max_subscriptions;
360        self
361    }
362
363    /// Enables shared subscription support
364    pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
365        self.shared_subscription = shared_subscription;
366        self
367    }
368
369    /// Configures maximum topic aliases (MQTTv5)
370    pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
371        self.max_topic_aliases = max_topic_aliases;
372        self
373    }
374
375    /// Enables subscription count limiting
376    pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
377        self.limit_subscription = limit_subscription;
378        self
379    }
380
381    /// Enables delayed message publishing
382    pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
383        self.delayed_publish = delayed_publish;
384        self
385    }
386
387    /// Enables mutual TLS authentication
388    pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
389        self.tls_cross_certificate = cross_certificate;
390        self
391    }
392
393    /// Sets path to TLS certificate chain
394    pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
395        self.tls_cert = tls_cert.map(|c| c.into());
396        self
397    }
398
399    /// Sets path to TLS private key
400    pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
401        self.tls_key = tls_key.map(|c| c.into());
402        self
403    }
404
405    /// Binds the server to the configured address
406    #[allow(unused_variables)]
407    pub fn bind(self) -> Result<Listener> {
408        let builder = match self.laddr {
409            SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
410            SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
411        };
412
413        builder.set_linger(Some(Duration::from_secs(10)))?;
414
415        builder.set_nonblocking(true)?;
416
417        if let Some(reuseaddr) = self.reuseaddr {
418            builder.set_reuse_address(reuseaddr)?;
419        }
420
421        #[cfg(not(windows))]
422        if let Some(reuseport) = self.reuseport {
423            builder.set_reuse_port(reuseport)?;
424        }
425
426        builder.bind(&SockAddr::from(self.laddr))?;
427        builder.listen(self.backlog)?;
428        let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
429
430        log::info!(
431            "MQTT Broker Listening on {} {}",
432            self.name,
433            tcp_listener.local_addr().unwrap_or(self.laddr)
434        );
435        Ok(Listener {
436            typ: ListenerType::TCP,
437            cfg: Arc::new(self),
438            tcp_listener,
439            #[cfg(feature = "tls")]
440            tls_acceptor: None,
441        })
442    }
443}
444
/// Transport protocol a [`Listener`] applies to the connections it accepts.
///
/// Only `TCP` is always present; `TLS` needs the `tls` feature, `WS` the `ws`
/// feature, and `WSS` both. Equality and hashing are derived so callers can
/// compare protocol types directly instead of going through `matches!`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ListenerType {
    /// Plain TCP listener
    TCP,
    #[cfg(feature = "tls")]
    /// TLS-secured TCP listener
    TLS,
    #[cfg(feature = "ws")]
    /// WebSocket listener
    WS,
    #[cfg(feature = "tls")]
    #[cfg(feature = "ws")]
    /// TLS-secured WebSocket listener
    WSS,
}
461
462/// Network listener for accepting client connections
463pub struct Listener {
464    /// Active listener protocol type
465    pub typ: ListenerType,
466    /// Shared server configuration
467    pub cfg: Arc<Builder>,
468    tcp_listener: TcpListener,
469    #[cfg(feature = "tls")]
470    tls_acceptor: Option<TlsAcceptor>,
471}
472
473/// # Examples
474/// ```
475/// # use rmqtt_net::{Builder, Listener};
476/// # fn setup() -> Result<(), Box<dyn std::error::Error>> {
477/// let builder = Builder::new();
478/// let listener = builder.bind()?;
479/// # Ok(())
480/// # }
481/// ```
482impl Listener {
483    /// Converts listener to plain TCP mode
484    pub fn tcp(mut self) -> Result<Self> {
485        let _err = anyhow!("Protocol downgrade from TLS/WS/WSS to TCP is not permitted");
486        #[cfg(feature = "tls")]
487        if matches!(self.typ, ListenerType::TLS) {
488            return Err(_err);
489        }
490        #[cfg(feature = "tls")]
491        #[cfg(feature = "ws")]
492        if matches!(self.typ, ListenerType::WSS) {
493            return Err(_err);
494        }
495        #[cfg(feature = "ws")]
496        if matches!(self.typ, ListenerType::WS) {
497            return Err(_err);
498        }
499        self.typ = ListenerType::TCP;
500        Ok(self)
501    }
502
503    #[cfg(feature = "ws")]
504    /// Upgrades listener to WebSocket protocol
505    pub fn ws(mut self) -> Result<Self> {
506        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
507            self.typ = ListenerType::WS;
508        } else {
509            return Err(anyhow!("Protocol upgrade from TLS/WSS to WS is not permitted"));
510        }
511        Ok(self)
512    }
513
514    #[cfg(feature = "tls")]
515    #[cfg(feature = "ws")]
516    /// Upgrades listener to secure WebSocket (WSS)
517    pub fn wss(mut self) -> Result<Self> {
518        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
519            self = self.tls()?;
520        }
521        self.typ = ListenerType::WSS;
522        Ok(self)
523    }
524
525    #[cfg(feature = "tls")]
526    /// Upgrades listener to TLS-secured TCP
527    pub fn tls(mut self) -> Result<Listener> {
528        match self.typ {
529            #[cfg(feature = "ws")]
530            ListenerType::WS | ListenerType::WSS => {
531                return Err(anyhow!("Protocol downgrade from WS/WSS to TLS is not permitted"));
532            }
533            ListenerType::TLS => return Ok(self),
534            ListenerType::TCP => {}
535        }
536
537        let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
538        let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
539
540        let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
541            .map_err(|e| anyhow!(e))?
542            .collect::<std::result::Result<Vec<_>, _>>()
543            .map_err(|e| anyhow!(e))?;
544        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
545
546        let provider = Arc::new(provider::default_provider());
547        let client_auth = if self.cfg.tls_cross_certificate {
548            let root_chain = cert_chain.clone();
549            let mut client_auth_roots = RootCertStore::empty();
550            for root in root_chain {
551                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
552            }
553            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
554                .build()
555                .map_err(|e| anyhow!(e))?
556        } else {
557            WebPkiClientVerifier::no_client_auth()
558        };
559
560        let tls_config = ServerConfig::builder_with_provider(provider)
561            .with_safe_default_protocol_versions()
562            .map_err(|e| anyhow!(e))?
563            .with_client_cert_verifier(client_auth)
564            .with_single_cert(cert_chain, key)
565            .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
566
567        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
568        self.tls_acceptor = Some(acceptor);
569        self.typ = ListenerType::TLS;
570        Ok(self)
571    }
572
573    /// Accepts incoming client connections
574    pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
575        let (socket, remote_addr) = self.tcp_listener.accept().await?;
576        if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
577            return Err(Error::from(e));
578        }
579        Ok(Acceptor {
580            socket,
581            remote_addr,
582            #[cfg(feature = "tls")]
583            acceptor: self.tls_acceptor.clone(),
584            cfg: self.cfg.clone(),
585            typ: self.typ,
586        })
587    }
588
589    pub fn local_addr(&self) -> Result<SocketAddr> {
590        Ok(self.tcp_listener.local_addr()?)
591    }
592}
593
594/// Connection handler for processing client streams
595pub struct Acceptor<S> {
596    /// Underlying network transport
597    pub(crate) socket: S,
598    #[cfg(feature = "tls")]
599    acceptor: Option<TlsAcceptor>,
600    /// Remote client address
601    pub remote_addr: SocketAddr,
602    /// Shared server configuration
603    pub cfg: Arc<Builder>,
604    /// Active protocol type
605    pub typ: ListenerType,
606}
607
608impl<S> Acceptor<S>
609where
610    S: AsyncRead + AsyncWrite + Unpin,
611{
612    /// Creates TCP protocol dispatcher
613    #[inline]
614    pub fn tcp(self) -> Result<Dispatcher<S>> {
615        if matches!(self.typ, ListenerType::TCP) {
616            Ok(Dispatcher::new(self.socket, self.remote_addr, self.cfg))
617        } else {
618            Err(anyhow!("Protocol mismatch: Expected TCP listener"))
619        }
620    }
621
622    #[cfg(feature = "tls")]
623    /// Performs TLS handshake and creates secure dispatcher
624    #[inline]
625    pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
626        if !matches!(self.typ, ListenerType::TLS) {
627            return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
628        }
629
630        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
631        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
632        {
633            Ok(Ok(tls_s)) => tls_s,
634            Ok(Err(e)) => return Err(e.into()),
635            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
636        };
637        Ok(Dispatcher::new(tls_s, self.remote_addr, self.cfg))
638    }
639
640    #[cfg(feature = "ws")]
641    /// Performs WebSocket upgrade and creates WS dispatcher
642    #[inline]
643    pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
644        if !matches!(self.typ, ListenerType::WS) {
645            return Err(anyhow!("Protocol mismatch: Expected WS listener"));
646        }
647
648        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
649            .await
650        {
651            Ok(Ok(ws_stream)) => {
652                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, self.cfg.clone()))
653            }
654            Ok(Err(e)) => Err(e.into()),
655            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
656        }
657    }
658
659    #[cfg(feature = "tls")]
660    #[cfg(feature = "ws")]
661    /// Performs TLS handshake and WebSocket upgrade
662    #[inline]
663    pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
664        if !matches!(self.typ, ListenerType::WSS) {
665            return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
666        }
667
668        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
669        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
670        {
671            Ok(Ok(tls_s)) => tls_s,
672            Ok(Err(e)) => return Err(e.into()),
673            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
674        };
675        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
676            Ok(Ok(ws_stream)) => {
677                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, self.cfg.clone()))
678            }
679            Ok(Err(e)) => Err(e.into()),
680            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
681        }
682    }
683}
684
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
/// Validates a WebSocket upgrade request for the MQTT subprotocol.
///
/// The client's `Sec-WebSocket-Protocol` header(s) hold a comma-separated list of
/// offered subprotocols (RFC 6455); the upgrade is accepted when `mqtt` is among
/// them, and the response then selects `mqtt`.
///
/// # Errors
/// Returns a `400 Bad Request` error response when `mqtt` is not offered.
/// Previously the status defaulted to `200 OK`, which hid the rejection from clients.
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_HEADER: &str = "Sec-WebSocket-Protocol";
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";

    // Clients may offer several subprotocols ("mqtt, mqttv3.1") and may split them
    // across multiple header lines; accept as long as "mqtt" is one of the offers.
    let offers_mqtt = req
        .headers()
        .get_all(PROTOCOL_HEADER)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|protocol| protocol.trim() == "mqtt");

    if !offers_mqtt {
        let mut error = ErrorResponse::new(Some(PROTOCOL_ERROR.into()));
        *error.status_mut() = tokio_tungstenite::tungstenite::http::StatusCode::BAD_REQUEST;
        return Err(error);
    }

    // The server must answer with exactly one selected subprotocol.
    response.headers_mut().insert(
        PROTOCOL_HEADER,
        tokio_tungstenite::tungstenite::http::HeaderValue::from_static("mqtt"),
    );
    Ok(response)
}