rmqtt_net/
builder.rs

1//! # MQTT Server Implementation
2//!
3//! ## Overall Example
4//!
//! ```rust,no_run
//! use std::net::{Ipv4Addr, SocketAddr};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Create server configuration
//!     let builder = rmqtt_net::Builder::new()
//!         .name("MyMQTTBroker")
//!         .laddr(SocketAddr::from((Ipv4Addr::LOCALHOST, 1883)))
//!         .max_connections(5000);
//!
//!     // Bind TCP listener
//!     let listener = builder.bind()?;
//!
//!     // Accept and handle connections. A failure on a single connection
//!     // (e.g. a malformed PROXY protocol header) must not stop the loop.
//!     loop {
//!         let acceptor = match listener.accept().await {
//!             Ok(acceptor) => acceptor,
//!             Err(e) => {
//!                 eprintln!("accept failed: {e}");
//!                 continue;
//!             }
//!         };
//!         tokio::spawn(async move {
//!             let _dispatcher = acceptor.tcp().unwrap();
//!             // Handle MQTT protocol...
//!         });
//!     }
//! }
//! ```
31
32use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
33use std::num::{NonZeroU16, NonZeroU32};
34use std::sync::Arc;
35use std::time::Duration;
36
37use anyhow::anyhow;
38use nonzero_ext::nonzero;
39use proxy_protocol::parse;
40use proxy_protocol::ProxyHeader;
41use proxy_protocol::{version1 as v1, version2 as v2};
42#[cfg(feature = "quic")]
43use quinn::{crypto::rustls::QuicServerConfig, IdleTimeout};
44use rmqtt_codec::types::QoS;
45#[cfg(not(target_os = "windows"))]
46#[cfg(feature = "tls")]
47use rustls::crypto::aws_lc_rs as provider;
48#[cfg(feature = "tls")]
49#[cfg(target_os = "windows")]
50use rustls::crypto::ring as provider;
51#[cfg(feature = "tls")]
52use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
53use socket2::{Domain, SockAddr, Socket, Type};
54use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
55use tokio::net::{TcpListener, TcpStream};
56#[cfg(feature = "tls")]
57use tokio_rustls::{server::TlsStream, TlsAcceptor};
58#[cfg(feature = "ws")]
59use tokio_tungstenite::{
60    accept_hdr_async,
61    tungstenite::handshake::server::{ErrorResponse, Request, Response},
62};
63
64#[cfg(feature = "quic")]
65use crate::quic::QuinnBiStream;
66use crate::stream::Dispatcher;
67#[cfg(feature = "ws")]
68use crate::ws::WsStream;
69#[cfg(feature = "tls")]
70use crate::{CertInfo, TlsCertExtractor};
71use crate::{Error, Result};
72
/// Configuration builder for MQTT server instances.
///
/// Holds listener socket options, MQTT protocol limits, TLS/QUIC settings and
/// PROXY protocol settings. Create one with [`Builder::new`] (or `Default`),
/// adjust it with the chainable setters, then call [`Builder::bind`] (TCP) or
/// `Builder::bind_quic` (QUIC, `quic` feature) to obtain a [`Listener`]. The
/// finished configuration is shared with every accepted connection through an
/// `Arc<Builder>`.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Server identifier for logging and monitoring; printed in the
    /// "MQTT Broker Listening on" log line. Default: empty string.
    pub name: String,
    /// Network address to listen on. Default: `0.0.0.0:1883`.
    pub laddr: SocketAddr,
    /// Maximum number of pending connections in the accept queue; passed to
    /// `listen` by [`Builder::bind`]. Default: 512.
    pub backlog: i32,
    /// Enable TCP_NODELAY on every accepted TCP connection for lower latency.
    /// Default: `false`.
    pub nodelay: bool,
    /// Set SO_REUSEADDR socket option; `None` leaves the socket default
    /// untouched. Default: `None`.
    pub reuseaddr: Option<bool>,
    /// Set SO_REUSEPORT socket option (not applied on Windows); `None` leaves
    /// the socket default untouched. Default: `None`.
    pub reuseport: Option<bool>,
    /// Maximum concurrent active connections. Default: 1,000,000.
    pub max_connections: usize,
    /// Maximum simultaneous handshakes during connection setup. Default: 1,000.
    pub max_handshaking_limit: usize,
    /// Maximum allowed MQTT packet size in bytes (0 = unlimited).
    /// Default: 1 MiB.
    pub max_packet_size: u32,

    /// Allow unauthenticated client connections. Default: `true`.
    pub allow_anonymous: bool,
    /// Minimum acceptable keepalive value in seconds. Default: 0.
    pub min_keepalive: u16,
    /// Maximum acceptable keepalive value in seconds. Default: 65535.
    pub max_keepalive: u16,
    /// Allow clients to disable keepalive mechanism. Default: `true`.
    pub allow_zero_keepalive: bool,
    /// Multiplier for calculating actual keepalive timeout. Default: 0.75.
    pub keepalive_backoff: f32,
    /// Window size for unacknowledged QoS 1/2 messages. Default: 16.
    pub max_inflight: NonZeroU16,
    /// Timeout for completing connection handshake. In this file it also
    /// bounds the TLS handshake and the WebSocket upgrade performed by
    /// `Acceptor`. Default: 30 s.
    pub handshake_timeout: Duration,
    /// Network I/O timeout for sending operations. Default: 10 s.
    pub send_timeout: Duration,
    /// Maximum messages queued per client. Default: 1000.
    pub max_mqueue_len: usize,
    /// Rate limiting for message delivery (messages per duration).
    /// Default: `u32::MAX` messages per 1 s.
    pub mqueue_rate_limit: (NonZeroU32, Duration),
    /// Maximum length of client identifiers. Default: 65535.
    pub max_clientid_len: usize,
    /// Highest QoS level permitted for publishing. Default: `QoS::ExactlyOnce`.
    pub max_qos_allowed: QoS,
    /// Maximum depth for topic hierarchy (0 = unlimited). Default: 0.
    pub max_topic_levels: usize,
    /// Duration before inactive sessions expire. Default: 2 hours.
    pub session_expiry_interval: Duration,
    /// The upper limit for how long a session can remain valid before it must expire,
    /// regardless of the client's requested session expiry interval. (0 = unlimited)
    /// Default: `Duration::ZERO`.
    pub max_session_expiry_interval: Duration,
    /// Retry interval for unacknowledged messages. Default: 20 s.
    pub message_retry_interval: Duration,
    /// Time-to-live for undelivered messages. Default: 5 minutes.
    pub message_expiry_interval: Duration,
    /// Maximum subscriptions per client (0 = unlimited). Default: 0.
    pub max_subscriptions: usize,
    /// Enable shared subscription support. Default: `true`.
    pub shared_subscription: bool,
    /// Maximum topic aliases (MQTTv5 feature). Default: 0.
    pub max_topic_aliases: u16,
    /// Enable subscription count limiting. Default: `false`.
    pub limit_subscription: bool,
    /// Enable future-dated message publishing. Default: `false`.
    pub delayed_publish: bool,

    /// Enable mutual TLS authentication. When set, client certificates are
    /// verified using the certificates loaded from `tls_cert` as trust roots.
    /// Default: `false`.
    pub tls_cross_certificate: bool,
    /// Path to TLS certificate chain (PEM file). Default: `None`.
    pub tls_cert: Option<String>,
    /// Path to TLS private key (PEM file). Default: `None`.
    pub tls_key: Option<String>,
    /// Enable PROXY protocol parsing: when set, each accepted TCP connection
    /// must start with a PROXY header, whose source address replaces the
    /// socket's peer address. Default: `false`.
    pub proxy_protocol: bool,
    /// Maximum time to wait for PROXY protocol header data on a new
    /// connection. Default: 5 s.
    pub proxy_protocol_timeout: Duration,

    /// Use TLS Certificate CN as Username. When set, the TLS/WSS acceptors
    /// extract client-certificate info and hand it to the dispatcher.
    /// Default: `false`.
    pub cert_cn_as_username: bool,

    /// QUIC max idle timeout, applied to the transport by `bind_quic`.
    /// Default: 90 s.
    pub idle_timeout: Duration,
}
158
159impl Default for Builder {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165/// # Examples
166/// ```
167/// use std::net::SocketAddr;
168/// use rmqtt_net::Builder;
169///
170/// let builder = Builder::new()
171///     .name("EdgeBroker")
172///     .laddr("127.0.0.1:1883".parse().unwrap())
173///     .max_connections(10_000);
174/// ```
175impl Builder {
176    /// Creates a new builder with default configuration values
177    pub fn new() -> Builder {
178        Builder {
179            name: Default::default(),
180            laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
181            max_connections: 1_000_000,
182            max_handshaking_limit: 1_000,
183            max_packet_size: 1024 * 1024,
184            backlog: 512,
185            nodelay: false,
186            reuseaddr: None,
187            reuseport: None,
188
189            allow_anonymous: true,
190            min_keepalive: 0,
191            max_keepalive: 65535,
192            allow_zero_keepalive: true,
193            keepalive_backoff: 0.75,
194            max_inflight: nonzero!(16u16),
195            handshake_timeout: Duration::from_secs(30),
196            send_timeout: Duration::from_secs(10),
197            max_mqueue_len: 1000,
198
199            mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
200            max_clientid_len: 65535,
201            max_qos_allowed: QoS::ExactlyOnce,
202            max_topic_levels: 0,
203            session_expiry_interval: Duration::from_secs(2 * 60 * 60),
204            max_session_expiry_interval: Duration::ZERO,
205            message_retry_interval: Duration::from_secs(20),
206            message_expiry_interval: Duration::from_secs(5 * 60),
207            max_subscriptions: 0,
208            shared_subscription: true,
209            max_topic_aliases: 0,
210
211            limit_subscription: false,
212            delayed_publish: false,
213
214            tls_cross_certificate: false,
215            tls_cert: None,
216            tls_key: None,
217            proxy_protocol: false,
218            proxy_protocol_timeout: Duration::from_secs(5),
219
220            cert_cn_as_username: false,
221
222            idle_timeout: Duration::from_secs(90),
223        }
224    }
225
226    /// Sets the server name identifier
227    pub fn name<N: Into<String>>(mut self, name: N) -> Self {
228        self.name = name.into();
229        self
230    }
231
232    /// Configures the network listen address
233    pub fn laddr(mut self, laddr: SocketAddr) -> Self {
234        self.laddr = laddr;
235        self
236    }
237
238    /// Sets the TCP backlog size
239    pub fn backlog(mut self, backlog: i32) -> Self {
240        self.backlog = backlog;
241        self
242    }
243
244    /// Enables/disables TCP_NODELAY option
245    pub fn nodelay(mut self, nodelay: bool) -> Self {
246        self.nodelay = nodelay;
247        self
248    }
249
250    /// Configures SO_REUSEADDR socket option
251    pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
252        self.reuseaddr = reuseaddr;
253        self
254    }
255
256    /// Configures SO_REUSEPORT socket option
257    pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
258        self.reuseport = reuseport;
259        self
260    }
261
262    /// Sets maximum concurrent connections
263    pub fn max_connections(mut self, max_connections: usize) -> Self {
264        self.max_connections = max_connections;
265        self
266    }
267
268    /// Sets maximum concurrent handshakes
269    pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
270        self.max_handshaking_limit = max_handshaking_limit;
271        self
272    }
273
274    /// Configures maximum MQTT packet size
275    pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
276        self.max_packet_size = max_packet_size;
277        self
278    }
279
280    /// Enables anonymous client access
281    pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
282        self.allow_anonymous = allow_anonymous;
283        self
284    }
285
286    /// Sets minimum acceptable keepalive value
287    pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
288        self.min_keepalive = min_keepalive;
289        self
290    }
291
292    /// Sets maximum acceptable keepalive value
293    pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
294        self.max_keepalive = max_keepalive;
295        self
296    }
297
298    /// Allows clients to disable keepalive
299    pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
300        self.allow_zero_keepalive = allow_zero_keepalive;
301        self
302    }
303
304    /// Configures keepalive backoff multiplier
305    pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
306        self.keepalive_backoff = keepalive_backoff;
307        self
308    }
309
310    /// Sets inflight message window size
311    pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
312        self.max_inflight = max_inflight;
313        self
314    }
315
316    /// Configures handshake timeout duration
317    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
318        self.handshake_timeout = handshake_timeout;
319        self
320    }
321
322    /// Sets network send timeout duration
323    pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
324        self.send_timeout = send_timeout;
325        self
326    }
327
328    /// Configures maximum message queue length
329    pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
330        self.max_mqueue_len = max_mqueue_len;
331        self
332    }
333
334    /// Sets message rate limiting parameters
335    pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
336        self.mqueue_rate_limit = (rate_limit, duration);
337        self
338    }
339
340    /// Sets maximum client ID length
341    pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
342        self.max_clientid_len = max_clientid_len;
343        self
344    }
345
346    /// Configures maximum allowed QoS level
347    pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
348        self.max_qos_allowed = max_qos_allowed;
349        self
350    }
351
352    /// Sets maximum topic hierarchy depth
353    pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
354        self.max_topic_levels = max_topic_levels;
355        self
356    }
357
358    /// Configures session expiration interval
359    pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
360        self.session_expiry_interval = session_expiry_interval;
361        self
362    }
363
364    /// Configures max session expiration interval
365    pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
366        self.max_session_expiry_interval = max_session_expiry_interval;
367        self
368    }
369
370    /// Sets message retry interval for QoS 1/2
371    pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
372        self.message_retry_interval = message_retry_interval;
373        self
374    }
375
376    /// Configures message expiration time
377    pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
378        self.message_expiry_interval = message_expiry_interval;
379        self
380    }
381
382    /// Sets maximum subscriptions per client
383    pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
384        self.max_subscriptions = max_subscriptions;
385        self
386    }
387
388    /// Enables shared subscription support
389    pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
390        self.shared_subscription = shared_subscription;
391        self
392    }
393
394    /// Configures maximum topic aliases (MQTTv5)
395    pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
396        self.max_topic_aliases = max_topic_aliases;
397        self
398    }
399
400    /// Enables subscription count limiting
401    pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
402        self.limit_subscription = limit_subscription;
403        self
404    }
405
406    /// Enables delayed message publishing
407    pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
408        self.delayed_publish = delayed_publish;
409        self
410    }
411
412    /// Enables mutual TLS authentication
413    pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
414        self.tls_cross_certificate = cross_certificate;
415        self
416    }
417
418    /// Sets path to TLS certificate chain
419    pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
420        self.tls_cert = tls_cert.map(|c| c.into());
421        self
422    }
423
424    /// Sets path to TLS private key
425    pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
426        self.tls_key = tls_key.map(|c| c.into());
427        self
428    }
429
430    pub fn cert_cn_as_username(mut self, cert_cn_as_username: bool) -> Self {
431        self.cert_cn_as_username = cert_cn_as_username;
432        self
433    }
434
435    /// Enable proxy protocol parse
436    pub fn proxy_protocol(mut self, enable_protocol_proxy: bool) -> Self {
437        self.proxy_protocol = enable_protocol_proxy;
438        self
439    }
440
441    /// Sets proxy protocol timeout
442    pub fn proxy_protocol_timeout(mut self, proxy_protocol_timeout: Duration) -> Self {
443        self.proxy_protocol_timeout = proxy_protocol_timeout;
444        self
445    }
446
447    /// Sets idle timeout (QUIC)
448    pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
449        self.idle_timeout = idle_timeout;
450        self
451    }
452
453    /// Binds the server to the configured address
454    #[allow(unused_variables)]
455    pub fn bind(self) -> Result<Listener> {
456        let builder = match self.laddr {
457            SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
458            SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
459        };
460
461        builder.set_linger(Some(Duration::from_secs(10)))?;
462
463        builder.set_nonblocking(true)?;
464
465        if let Some(reuseaddr) = self.reuseaddr {
466            builder.set_reuse_address(reuseaddr)?;
467        }
468
469        #[cfg(not(windows))]
470        if let Some(reuseport) = self.reuseport {
471            builder.set_reuse_port(reuseport)?;
472        }
473
474        builder.bind(&SockAddr::from(self.laddr))?;
475        builder.listen(self.backlog)?;
476        let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
477
478        log::info!(
479            "MQTT Broker Listening on {} {}",
480            self.name,
481            tcp_listener.local_addr().unwrap_or(self.laddr)
482        );
483        Ok(Listener {
484            typ: ListenerType::TCP,
485            cfg: Arc::new(self),
486            tcp_listener: Some(tcp_listener),
487            #[cfg(feature = "tls")]
488            tls_acceptor: None,
489            #[cfg(feature = "quic")]
490            quinn_endpoint: None,
491        })
492    }
493
494    #[allow(unused_variables)]
495    #[cfg(feature = "quic")]
496    pub fn bind_quic(self) -> Result<Listener> {
497        let cert_file = self.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
498        let key_file = self.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
499
500        let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
501            .map_err(|e| anyhow!(e))?
502            .collect::<std::result::Result<Vec<_>, _>>()
503            .map_err(|e| anyhow!(e))?;
504        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
505
506        let provider = Arc::new(provider::default_provider());
507        let client_auth = if self.tls_cross_certificate {
508            let root_chain = cert_chain.clone();
509            let mut client_auth_roots = RootCertStore::empty();
510            for root in root_chain {
511                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
512            }
513            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
514                .build()
515                .map_err(|e| anyhow!(e))?
516        } else {
517            WebPkiClientVerifier::no_client_auth()
518        };
519
520        let mut tls_config = ServerConfig::builder_with_provider(provider)
521            .with_safe_default_protocol_versions()
522            .map_err(|e| anyhow!(e))?
523            .with_client_cert_verifier(client_auth)
524            .with_single_cert(cert_chain, key)
525            .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
526
527        tls_config.alpn_protocols = vec![b"mqtt".to_vec(), b"mqttv5".to_vec()];
528        let server_crypto = QuicServerConfig::try_from(tls_config)?;
529        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
530
531        let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
532        transport_config.max_concurrent_uni_streams(0_u8.into());
533        transport_config.max_idle_timeout(Some(IdleTimeout::try_from(self.idle_timeout)?));
534
535        let endpoint = quinn::Endpoint::server(server_config, self.laddr)?;
536
537        log::info!("MQTT Broker Listening on {} {}", self.name, endpoint.local_addr().unwrap_or(self.laddr));
538        Ok(Listener {
539            typ: ListenerType::QUIC,
540            cfg: Arc::new(self),
541            tcp_listener: None,
542            #[cfg(feature = "tls")]
543            tls_acceptor: None,
544            quinn_endpoint: Some(endpoint),
545        })
546    }
547}
548
/// Protocol variants for network listeners.
///
/// A [`Listener`] starts as `TCP` (from [`Builder::bind`]) or `QUIC` (from
/// `Builder::bind_quic`) and can be switched with `Listener::tls`,
/// `Listener::ws` or `Listener::wss`. Every [`Acceptor`] carries the variant
/// of the listener that produced it, and only the matching `Acceptor` method
/// (`tcp`, `tls`, `ws`, `wss`, `quic`) accepts it.
#[derive(Debug, Copy, Clone)]
pub enum ListenerType {
    /// Plain TCP listener
    TCP,
    #[cfg(feature = "tls")]
    /// TLS-secured TCP listener
    TLS,
    #[cfg(feature = "ws")]
    /// WebSocket listener
    WS,
    #[cfg(feature = "tls")]
    #[cfg(feature = "ws")]
    /// TLS-secured WebSocket listener
    WSS,
    #[cfg(feature = "quic")]
    /// QUIC listener (UDP-based, multiplexed and secured by default)
    QUIC,
}
568
/// Network listener for accepting client connections.
///
/// `Builder::bind` fills in the TCP listener, and `Builder::bind_quic` fills
/// in the QUIC endpoint.
pub struct Listener {
    /// Active listener protocol type
    pub typ: ListenerType,
    /// Shared server configuration
    pub cfg: Arc<Builder>,
    /// Bound TCP socket; `Some` for listeners created by `Builder::bind`.
    tcp_listener: Option<TcpListener>,
    /// TLS acceptor set by `Listener::tls` (also used for WSS); `None` otherwise.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<TlsAcceptor>,
    /// QUIC endpoint; `Some` for listeners created by `Builder::bind_quic`.
    #[cfg(feature = "quic")]
    quinn_endpoint: Option<quinn::Endpoint>,
}
581
582/// # Examples
583/// ```
584/// # use rmqtt_net::{Builder, Listener};
585/// # fn setup() -> Result<(), Box<dyn std::error::Error>> {
586/// let builder = Builder::new();
587/// let listener = builder.bind()?;
588/// # Ok(())
589/// # }
590/// ```
591impl Listener {
592    /// Converts listener to plain TCP mode
593    pub fn tcp(mut self) -> Result<Self> {
594        let _err = anyhow!("Protocol downgrade from TLS/WS/WSS/QUIC to TCP is not permitted");
595        #[cfg(feature = "tls")]
596        if matches!(self.typ, ListenerType::TLS) {
597            return Err(_err);
598        }
599        #[cfg(feature = "tls")]
600        #[cfg(feature = "ws")]
601        if matches!(self.typ, ListenerType::WSS) {
602            return Err(_err);
603        }
604        #[cfg(feature = "ws")]
605        if matches!(self.typ, ListenerType::WS) {
606            return Err(_err);
607        }
608        #[cfg(feature = "quic")]
609        if matches!(self.typ, ListenerType::QUIC) {
610            return Err(_err);
611        }
612
613        self.typ = ListenerType::TCP;
614        Ok(self)
615    }
616
617    #[cfg(feature = "ws")]
618    /// Upgrades listener to WebSocket protocol
619    pub fn ws(mut self) -> Result<Self> {
620        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
621            self.typ = ListenerType::WS;
622        } else {
623            return Err(anyhow!("Protocol upgrade from TLS/WSS/QUIC to WS is not permitted"));
624        }
625        Ok(self)
626    }
627
628    #[cfg(feature = "tls")]
629    #[cfg(feature = "ws")]
630    /// Upgrades listener to secure WebSocket (WSS)
631    pub fn wss(mut self) -> Result<Self> {
632        #[cfg(feature = "quic")]
633        if matches!(self.typ, ListenerType::QUIC) {
634            return Err(anyhow!("Protocol upgrade from QUIC to WS is not permitted"));
635        }
636
637        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
638            self = self.tls()?;
639        }
640        self.typ = ListenerType::WSS;
641        Ok(self)
642    }
643
644    #[cfg(feature = "tls")]
645    /// Upgrades listener to TLS-secured TCP
646    pub fn tls(mut self) -> Result<Listener> {
647        match self.typ {
648            #[cfg(feature = "ws")]
649            ListenerType::WS | ListenerType::WSS => {
650                return Err(anyhow!("Protocol downgrade from WS/WSS/QUIC to TLS is not permitted"));
651            }
652            #[cfg(feature = "quic")]
653            ListenerType::QUIC => {
654                return Err(anyhow!("Protocol downgrade from QUIC to TLS is not permitted"));
655            }
656            ListenerType::TLS => return Ok(self),
657            ListenerType::TCP => {}
658        }
659
660        let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
661        let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
662
663        let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
664            .map_err(|e| anyhow!(e))?
665            .collect::<std::result::Result<Vec<_>, _>>()
666            .map_err(|e| anyhow!(e))?;
667        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
668
669        let provider = Arc::new(provider::default_provider());
670        let client_auth = if self.cfg.tls_cross_certificate {
671            let root_chain = cert_chain.clone();
672            let mut client_auth_roots = RootCertStore::empty();
673            for root in root_chain {
674                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
675            }
676            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
677                .build()
678                .map_err(|e| anyhow!(e))?
679        } else {
680            WebPkiClientVerifier::no_client_auth()
681        };
682
683        let tls_config = ServerConfig::builder_with_provider(provider)
684            .with_safe_default_protocol_versions()
685            .map_err(|e| anyhow!(e))?
686            .with_client_cert_verifier(client_auth)
687            .with_single_cert(cert_chain, key)
688            .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
689
690        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
691        self.tls_acceptor = Some(acceptor);
692        self.typ = ListenerType::TLS;
693        Ok(self)
694    }
695
696    /// Accepts incoming client connections
697    pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
698        if let Some(tcp_listener) = &self.tcp_listener {
699            self.accept_tcp(tcp_listener).await
700        } else {
701            Err(anyhow!(""))
702        }
703    }
704
705    async fn accept_tcp(&self, tcp_listener: &TcpListener) -> Result<Acceptor<TcpStream>> {
706        let (mut socket, mut remote_addr) = tcp_listener.accept().await?;
707        if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
708            return Err(Error::from(e));
709        }
710        log::debug!("remote_addr: {remote_addr}, proxy_protocol: {}", self.cfg.proxy_protocol);
711        if self.cfg.proxy_protocol {
712            let mut buffer = [0u8; u16::MAX as usize];
713            let read_bytes =
714                tokio::time::timeout(self.cfg.proxy_protocol_timeout, socket.peek(&mut buffer)).await??;
715            let len = {
716                let mut slice = &buffer[..read_bytes];
717                let header = parse(&mut slice)?;
718                if let Some((src, _)) = handle_header(header) {
719                    remote_addr = src;
720                }
721                read_bytes - slice.len()
722            };
723            // skip proxy protocol data
724            let _ = socket.read_exact(&mut buffer[..len]).await;
725        }
726        Ok(Acceptor {
727            socket,
728            remote_addr,
729            #[cfg(feature = "tls")]
730            acceptor: self.tls_acceptor.clone(),
731            cfg: self.cfg.clone(),
732            typ: self.typ,
733        })
734    }
735
736    #[cfg(feature = "quic")]
737    pub async fn accept_quic(&self) -> Result<Acceptor<QuinnBiStream>> {
738        if let Some(endpoint) = &self.quinn_endpoint {
739            let incoming =
740                endpoint.accept().await.ok_or_else(|| anyhow!("No incoming QUIC connection available"))?;
741            let conn = incoming.await?;
742            let remote_addr = conn.remote_address();
743
744            let (send, recv) = conn.accept_bi().await?;
745            let socket = QuinnBiStream::new(send, recv);
746
747            Ok(Acceptor {
748                socket,
749                remote_addr,
750                #[cfg(feature = "tls")]
751                acceptor: self.tls_acceptor.clone(),
752                cfg: self.cfg.clone(),
753                typ: self.typ,
754            })
755        } else {
756            Err(anyhow!(""))
757        }
758    }
759
760    pub fn local_addr(&self) -> Result<SocketAddr> {
761        if let Some(tcp_listener) = &self.tcp_listener {
762            Ok(tcp_listener.local_addr()?)
763        } else {
764            #[cfg(feature = "quic")]
765            if let Some(endpoint) = &self.quinn_endpoint {
766                Ok(endpoint.local_addr()?)
767            } else {
768                Err(anyhow!("No active listener (neither TCP nor QUIC endpoint is available)"))
769            }
770            #[cfg(not(feature = "quic"))]
771            Err(anyhow!("No active listener"))
772        }
773    }
774}
775
/// Connection handler for processing client streams.
///
/// Produced by `Listener::accept` (TCP) or `Listener::accept_quic`. Call the
/// method matching [`Acceptor::typ`] (`tcp`, `tls`, `ws`, `wss` or `quic`) to
/// finish the protocol handshake and obtain a `Dispatcher`.
pub struct Acceptor<S> {
    /// Underlying network transport
    pub(crate) socket: S,
    /// TLS acceptor cloned from the listener; `None` unless the listener was
    /// switched to TLS/WSS.
    #[cfg(feature = "tls")]
    acceptor: Option<TlsAcceptor>,
    /// Remote client address (taken from the PROXY header when that is enabled)
    pub remote_addr: SocketAddr,
    /// Shared server configuration
    pub cfg: Arc<Builder>,
    /// Active protocol type
    pub typ: ListenerType,
}
789
790impl<S> Acceptor<S>
791where
792    S: AsyncRead + AsyncWrite + Unpin,
793{
794    /// Creates TCP protocol dispatcher
795    #[inline]
796    pub fn tcp(self) -> Result<Dispatcher<S>> {
797        if matches!(self.typ, ListenerType::TCP) {
798            Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
799        } else {
800            Err(anyhow!("Protocol mismatch: Expected TCP listener"))
801        }
802    }
803
804    #[cfg(feature = "tls")]
805    /// Performs TLS handshake and creates secure dispatcher
806    #[inline]
807    pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
808        if !matches!(self.typ, ListenerType::TLS) {
809            return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
810        }
811
812        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
813        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
814        {
815            Ok(Ok(tls_s)) => tls_s,
816            Ok(Err(e)) => return Err(e.into()),
817            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
818        };
819
820        let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
821
822        Ok(Dispatcher::new(tls_s, self.remote_addr, cert_info, self.cfg))
823    }
824
825    #[cfg(feature = "ws")]
826    /// Performs WebSocket upgrade and creates WS dispatcher
827    #[inline]
828    pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
829        if !matches!(self.typ, ListenerType::WS) {
830            return Err(anyhow!("Protocol mismatch: Expected WS listener"));
831        }
832
833        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
834            .await
835        {
836            Ok(Ok(ws_stream)) => {
837                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, None, self.cfg.clone()))
838            }
839            Ok(Err(e)) => Err(e.into()),
840            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
841        }
842    }
843
844    #[cfg(feature = "tls")]
845    #[cfg(feature = "ws")]
846    /// Performs TLS handshake and WebSocket upgrade
847    #[inline]
848    pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
849        if !matches!(self.typ, ListenerType::WSS) {
850            return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
851        }
852
853        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
854        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
855        {
856            Ok(Ok(tls_s)) => tls_s,
857            Ok(Err(e)) => return Err(e.into()),
858            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
859        };
860
861        let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
862
863        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
864            Ok(Ok(ws_stream)) => {
865                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, cert_info, self.cfg.clone()))
866            }
867            Ok(Err(e)) => Err(e.into()),
868            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
869        }
870    }
871
872    #[cfg(feature = "quic")]
873    #[inline]
874    pub async fn quic(self) -> Result<Dispatcher<S>> {
875        if !matches!(self.typ, ListenerType::QUIC) {
876            return Err(anyhow!("Protocol mismatch: Expected QUIC listener"));
877        }
878        Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
879    }
880
881    #[inline]
882    #[cfg(feature = "tls")]
883    fn get_extract_cert_info<C: TlsCertExtractor>(io: &C, cert_cn_as_username: bool) -> Option<CertInfo> {
884        if cert_cn_as_username {
885            // Extract cert info BEFORE consuming self
886            let cert_info: Option<CertInfo> = io.extract_cert_info();
887            // Certificate info is now available in s.cert_info
888            if let Some(ref cert) = cert_info {
889                log::debug!("Client certificate: {}", cert);
890                log::debug!("CN: {:?}, Org: {:?}", cert.common_name, cert.organization);
891            }
892            cert_info
893        } else {
894            None
895        }
896    }
897}
898
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
/// Validates WebSocket handshake requests for the MQTT protocol.
///
/// RFC 6455 lets a client offer several subprotocols, either comma-separated in one
/// `Sec-WebSocket-Protocol` header (e.g. `mqtt, mqttv3.1`) or spread over repeated header
/// lines. MQTT over WebSocket only requires `mqtt` to be one of the offered values. The request
/// is therefore accepted when any offered token equals `mqtt`, and `mqtt` is echoed back as
/// the selected subprotocol.
///
/// # Errors
/// Returns an `ErrorResponse` when no `mqtt` subprotocol is offered. Header values that are
/// not visible ASCII are ignored rather than matched.
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";
    let offers_mqtt = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|token| token.trim() == "mqtt");
    if !offers_mqtt {
        return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
    }
    response.headers_mut().append(
        "Sec-WebSocket-Protocol",
        "mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
    );
    Ok(response)
}
917
918// from https://github.com/zhboner/realm/blob/master/realm_core/src/tcp/proxy.rs
919fn handle_header(header: ProxyHeader) -> Option<(SocketAddr, SocketAddr)> {
920    use ProxyHeader::{Version1, Version2};
921    match header {
922        Version1 { addresses } => handle_header_v1(addresses),
923        Version2 { command, transport_protocol, addresses } => {
924            handle_header_v2(command, transport_protocol, addresses)
925        }
926        _ => {
927            log::info!("[tcp]accept proxy-protocol-v?");
928            None
929        }
930    }
931}
932
933fn handle_header_v1(addr: v1::ProxyAddresses) -> Option<(SocketAddr, SocketAddr)> {
934    use v1::ProxyAddresses::*;
935    match addr {
936        Unknown => {
937            log::debug!("[tcp]accept proxy-protocol-v1: unknown");
938            None
939        }
940        Ipv4 { source, destination } => {
941            log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
942            Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
943        }
944        Ipv6 { source, destination } => {
945            log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
946            Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
947        }
948    }
949}
950
951fn handle_header_v2(
952    cmd: v2::ProxyCommand,
953    proto: v2::ProxyTransportProtocol,
954    addr: v2::ProxyAddresses,
955) -> Option<(SocketAddr, SocketAddr)> {
956    use v2::ProxyAddresses as Address;
957    use v2::ProxyCommand as Command;
958    use v2::ProxyTransportProtocol as Protocol;
959
960    // The connection endpoints are the sender and the receiver.
961    // Such connections exist when the proxy sends health-checks to the server.
962    // The receiver must accept this connection as valid and must use the
963    // real connection endpoints and discard the protocol block including the
964    // family which is ignored
965    if let Command::Local = cmd {
966        log::debug!("[tcp]accept proxy-protocol-v2: command = LOCAL, ignore");
967        return None;
968    }
969
970    // only get tcp address
971    match proto {
972        Protocol::Stream => {}
973        Protocol::Unspec => {
974            log::debug!("[tcp]accept proxy-protocol-v2: protocol = UNSPEC, ignore");
975            return None;
976        }
977        Protocol::Datagram => {
978            log::debug!("[tcp]accept proxy-protocol-v2: protocol = DGRAM, ignore");
979            return None;
980        }
981    }
982
983    match addr {
984        Address::Ipv4 { source, destination } => {
985            log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
986            Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
987        }
988        Address::Ipv6 { source, destination } => {
989            log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
990            Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
991        }
992        Address::Unspec => {
993            log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNSPEC, ignore");
994            None
995        }
996        Address::Unix { .. } => {
997            log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNIX, ignore");
998            None
999        }
1000    }
1001}