rmqtt_net/
builder.rs

//! # MQTT Server Implementation
//!
//! ## Overall Example
//!
//! ```rust,no_run
//! use std::net::{Ipv4Addr, SocketAddr};
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     // Create server configuration
//!     let builder = rmqtt_net::Builder::new()
//!         .name("MyMQTTBroker")
//!         .laddr(SocketAddr::from((Ipv4Addr::LOCALHOST, 1883)))
//!         .max_connections(5000);
//!
//!     // Bind TCP listener
//!     let listener = builder.bind()?;
//!
//!     // Accept and handle connections; the loop only ends when `accept` fails.
//!     loop {
//!         let acceptor = listener.accept().await?;
//!         tokio::spawn(async move {
//!             let _dispatcher = acceptor.tcp().unwrap();
//!             // Handle MQTT protocol...
//!         });
//!     }
//! }
//! ```
31
32use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
33use std::num::{NonZeroU16, NonZeroU32};
34use std::sync::Arc;
35use std::time::Duration;
36
37use anyhow::anyhow;
38use nonzero_ext::nonzero;
39use proxy_protocol::parse;
40use proxy_protocol::ProxyHeader;
41use proxy_protocol::{version1 as v1, version2 as v2};
42#[cfg(feature = "quic")]
43use quinn::{crypto::rustls::QuicServerConfig, IdleTimeout};
44use rmqtt_codec::types::QoS;
45#[cfg(not(target_os = "windows"))]
46#[cfg(feature = "tls")]
47use rustls::crypto::aws_lc_rs as provider;
48#[cfg(feature = "tls")]
49#[cfg(target_os = "windows")]
50use rustls::crypto::ring as provider;
51#[cfg(feature = "tls")]
52use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
53use socket2::{Domain, SockAddr, Socket, Type};
54use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
55use tokio::net::{TcpListener, TcpStream};
56#[cfg(feature = "tls")]
57use tokio_rustls::{server::TlsStream, TlsAcceptor};
58#[cfg(feature = "ws")]
59use tokio_tungstenite::{
60    accept_hdr_async,
61    tungstenite::handshake::server::{ErrorResponse, Request, Response},
62};
63
64#[cfg(feature = "quic")]
65use crate::quic::QuinnBiStream;
66use crate::stream::Dispatcher;
67#[cfg(feature = "ws")]
68use crate::ws::WsStream;
69use crate::{CertInfo, Error, Result, TlsCertExtractor};
70
/// Configuration builder for MQTT server instances.
///
/// Every field is public and has a chainable setter of the same name.
/// [`Builder::bind`] (TCP) or `Builder::bind_quic` (QUIC) consumes the builder
/// and shares it with every accepted connection as an `Arc<Builder>`.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Server identifier, included in the "listening" log line
    pub name: String,
    /// Local socket address to bind (default `0.0.0.0:1883`)
    pub laddr: SocketAddr,
    /// Length of the pending-connection queue passed to `listen` (default 512)
    pub backlog: i32,
    /// Apply `TCP_NODELAY` to every accepted TCP socket for lower latency
    pub nodelay: bool,
    /// `SO_REUSEADDR` value; `None` leaves the OS default untouched
    pub reuseaddr: Option<bool>,
    /// `SO_REUSEPORT` value; `None` leaves the OS default untouched (ignored on Windows)
    pub reuseport: Option<bool>,
    /// Maximum concurrent active connections
    pub max_connections: usize,
    /// Maximum simultaneous handshakes during connection setup
    pub max_handshaking_limit: usize,
    /// Maximum allowed MQTT packet size in bytes (0 = unlimited)
    pub max_packet_size: u32,

    /// Allow unauthenticated client connections
    pub allow_anonymous: bool,
    /// Minimum acceptable keepalive value in seconds
    pub min_keepalive: u16,
    /// Maximum acceptable keepalive value in seconds
    pub max_keepalive: u16,
    /// Allow clients to disable the keepalive mechanism (keepalive = 0)
    pub allow_zero_keepalive: bool,
    /// Multiplier for calculating the actual keepalive timeout (default 0.75)
    pub keepalive_backoff: f32,
    /// Window size for unacknowledged QoS 1/2 messages
    pub max_inflight: NonZeroU16,
    /// Timeout for completing the connection handshake; in this module it bounds
    /// the TLS handshake and, separately, the WebSocket upgrade
    pub handshake_timeout: Duration,
    /// Network I/O timeout for sending operations
    pub send_timeout: Duration,
    /// Maximum messages queued per client
    pub max_mqueue_len: usize,
    /// Rate limiting for message delivery as `(messages, per duration)`;
    /// the default `(u32::MAX, 1s)` is effectively unlimited
    pub mqueue_rate_limit: (NonZeroU32, Duration),
    /// Maximum length of client identifiers
    pub max_clientid_len: usize,
    /// Highest QoS level permitted for publishing
    pub max_qos_allowed: QoS,
    /// Maximum depth for topic hierarchy (0 = unlimited)
    pub max_topic_levels: usize,
    /// Duration before inactive sessions expire
    pub session_expiry_interval: Duration,
    /// The upper limit for how long a session can remain valid before it must expire,
    /// regardless of the client's requested session expiry interval. (0 = unlimited)
    pub max_session_expiry_interval: Duration,
    /// Retry interval for unacknowledged messages
    pub message_retry_interval: Duration,
    /// Time-to-live for undelivered messages
    pub message_expiry_interval: Duration,
    /// Maximum subscriptions per client (0 = unlimited)
    pub max_subscriptions: usize,
    /// Enable shared subscription support
    pub shared_subscription: bool,
    /// Maximum topic aliases (MQTTv5 feature)
    pub max_topic_aliases: u16,
    /// Enable subscription count limiting
    pub limit_subscription: bool,
    /// Enable future-dated (delayed) message publishing
    pub delayed_publish: bool,

    /// Enable mutual TLS authentication; client certificates are verified
    /// against the certificate chain loaded from `tls_cert`
    pub tls_cross_certificate: bool,
    /// Path to the PEM-encoded TLS certificate chain
    pub tls_cert: Option<String>,
    /// Path to the PEM-encoded TLS private key
    pub tls_key: Option<String>,
    /// Expect a PROXY protocol header at the start of each TCP connection and
    /// use the source address it carries as the client address
    pub proxy_protocol: bool,
    /// Maximum time to wait for the PROXY protocol header
    pub proxy_protocol_timeout: Duration,

    /// Use the TLS client certificate CN as the username; client certificate
    /// details are only extracted after the TLS handshake when this is set
    pub cert_cn_as_username: bool,

    /// QUIC `max_idle_timeout` transport parameter (default 90 s)
    pub idle_timeout: Duration,
}
156
157impl Default for Builder {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
163/// # Examples
164/// ```
165/// use std::net::SocketAddr;
166/// use rmqtt_net::Builder;
167///
168/// let builder = Builder::new()
169///     .name("EdgeBroker")
170///     .laddr("127.0.0.1:1883".parse().unwrap())
171///     .max_connections(10_000);
172/// ```
173impl Builder {
174    /// Creates a new builder with default configuration values
175    pub fn new() -> Builder {
176        Builder {
177            name: Default::default(),
178            laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
179            max_connections: 1_000_000,
180            max_handshaking_limit: 1_000,
181            max_packet_size: 1024 * 1024,
182            backlog: 512,
183            nodelay: false,
184            reuseaddr: None,
185            reuseport: None,
186
187            allow_anonymous: true,
188            min_keepalive: 0,
189            max_keepalive: 65535,
190            allow_zero_keepalive: true,
191            keepalive_backoff: 0.75,
192            max_inflight: nonzero!(16u16),
193            handshake_timeout: Duration::from_secs(30),
194            send_timeout: Duration::from_secs(10),
195            max_mqueue_len: 1000,
196
197            mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
198            max_clientid_len: 65535,
199            max_qos_allowed: QoS::ExactlyOnce,
200            max_topic_levels: 0,
201            session_expiry_interval: Duration::from_secs(2 * 60 * 60),
202            max_session_expiry_interval: Duration::ZERO,
203            message_retry_interval: Duration::from_secs(20),
204            message_expiry_interval: Duration::from_secs(5 * 60),
205            max_subscriptions: 0,
206            shared_subscription: true,
207            max_topic_aliases: 0,
208
209            limit_subscription: false,
210            delayed_publish: false,
211
212            tls_cross_certificate: false,
213            tls_cert: None,
214            tls_key: None,
215            proxy_protocol: false,
216            proxy_protocol_timeout: Duration::from_secs(5),
217
218            cert_cn_as_username: false,
219
220            idle_timeout: Duration::from_secs(90),
221        }
222    }
223
224    /// Sets the server name identifier
225    pub fn name<N: Into<String>>(mut self, name: N) -> Self {
226        self.name = name.into();
227        self
228    }
229
230    /// Configures the network listen address
231    pub fn laddr(mut self, laddr: SocketAddr) -> Self {
232        self.laddr = laddr;
233        self
234    }
235
236    /// Sets the TCP backlog size
237    pub fn backlog(mut self, backlog: i32) -> Self {
238        self.backlog = backlog;
239        self
240    }
241
242    /// Enables/disables TCP_NODELAY option
243    pub fn nodelay(mut self, nodelay: bool) -> Self {
244        self.nodelay = nodelay;
245        self
246    }
247
248    /// Configures SO_REUSEADDR socket option
249    pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
250        self.reuseaddr = reuseaddr;
251        self
252    }
253
254    /// Configures SO_REUSEPORT socket option
255    pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
256        self.reuseport = reuseport;
257        self
258    }
259
260    /// Sets maximum concurrent connections
261    pub fn max_connections(mut self, max_connections: usize) -> Self {
262        self.max_connections = max_connections;
263        self
264    }
265
266    /// Sets maximum concurrent handshakes
267    pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
268        self.max_handshaking_limit = max_handshaking_limit;
269        self
270    }
271
272    /// Configures maximum MQTT packet size
273    pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
274        self.max_packet_size = max_packet_size;
275        self
276    }
277
278    /// Enables anonymous client access
279    pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
280        self.allow_anonymous = allow_anonymous;
281        self
282    }
283
284    /// Sets minimum acceptable keepalive value
285    pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
286        self.min_keepalive = min_keepalive;
287        self
288    }
289
290    /// Sets maximum acceptable keepalive value
291    pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
292        self.max_keepalive = max_keepalive;
293        self
294    }
295
296    /// Allows clients to disable keepalive
297    pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
298        self.allow_zero_keepalive = allow_zero_keepalive;
299        self
300    }
301
302    /// Configures keepalive backoff multiplier
303    pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
304        self.keepalive_backoff = keepalive_backoff;
305        self
306    }
307
308    /// Sets inflight message window size
309    pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
310        self.max_inflight = max_inflight;
311        self
312    }
313
314    /// Configures handshake timeout duration
315    pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
316        self.handshake_timeout = handshake_timeout;
317        self
318    }
319
320    /// Sets network send timeout duration
321    pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
322        self.send_timeout = send_timeout;
323        self
324    }
325
326    /// Configures maximum message queue length
327    pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
328        self.max_mqueue_len = max_mqueue_len;
329        self
330    }
331
332    /// Sets message rate limiting parameters
333    pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
334        self.mqueue_rate_limit = (rate_limit, duration);
335        self
336    }
337
338    /// Sets maximum client ID length
339    pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
340        self.max_clientid_len = max_clientid_len;
341        self
342    }
343
344    /// Configures maximum allowed QoS level
345    pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
346        self.max_qos_allowed = max_qos_allowed;
347        self
348    }
349
350    /// Sets maximum topic hierarchy depth
351    pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
352        self.max_topic_levels = max_topic_levels;
353        self
354    }
355
356    /// Configures session expiration interval
357    pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
358        self.session_expiry_interval = session_expiry_interval;
359        self
360    }
361
362    /// Configures max session expiration interval
363    pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
364        self.max_session_expiry_interval = max_session_expiry_interval;
365        self
366    }
367
368    /// Sets message retry interval for QoS 1/2
369    pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
370        self.message_retry_interval = message_retry_interval;
371        self
372    }
373
374    /// Configures message expiration time
375    pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
376        self.message_expiry_interval = message_expiry_interval;
377        self
378    }
379
380    /// Sets maximum subscriptions per client
381    pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
382        self.max_subscriptions = max_subscriptions;
383        self
384    }
385
386    /// Enables shared subscription support
387    pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
388        self.shared_subscription = shared_subscription;
389        self
390    }
391
392    /// Configures maximum topic aliases (MQTTv5)
393    pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
394        self.max_topic_aliases = max_topic_aliases;
395        self
396    }
397
398    /// Enables subscription count limiting
399    pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
400        self.limit_subscription = limit_subscription;
401        self
402    }
403
404    /// Enables delayed message publishing
405    pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
406        self.delayed_publish = delayed_publish;
407        self
408    }
409
410    /// Enables mutual TLS authentication
411    pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
412        self.tls_cross_certificate = cross_certificate;
413        self
414    }
415
416    /// Sets path to TLS certificate chain
417    pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
418        self.tls_cert = tls_cert.map(|c| c.into());
419        self
420    }
421
422    /// Sets path to TLS private key
423    pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
424        self.tls_key = tls_key.map(|c| c.into());
425        self
426    }
427
428    pub fn cert_cn_as_username(mut self, cert_cn_as_username: bool) -> Self {
429        self.cert_cn_as_username = cert_cn_as_username;
430        self
431    }
432
433    /// Enable proxy protocol parse
434    pub fn proxy_protocol(mut self, enable_protocol_proxy: bool) -> Self {
435        self.proxy_protocol = enable_protocol_proxy;
436        self
437    }
438
439    /// Sets proxy protocol timeout
440    pub fn proxy_protocol_timeout(mut self, proxy_protocol_timeout: Duration) -> Self {
441        self.proxy_protocol_timeout = proxy_protocol_timeout;
442        self
443    }
444
445    /// Sets idle timeout (QUIC)
446    pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
447        self.idle_timeout = idle_timeout;
448        self
449    }
450
451    /// Binds the server to the configured address
452    #[allow(unused_variables)]
453    pub fn bind(self) -> Result<Listener> {
454        let builder = match self.laddr {
455            SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
456            SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
457        };
458
459        builder.set_linger(Some(Duration::from_secs(10)))?;
460
461        builder.set_nonblocking(true)?;
462
463        if let Some(reuseaddr) = self.reuseaddr {
464            builder.set_reuse_address(reuseaddr)?;
465        }
466
467        #[cfg(not(windows))]
468        if let Some(reuseport) = self.reuseport {
469            builder.set_reuse_port(reuseport)?;
470        }
471
472        builder.bind(&SockAddr::from(self.laddr))?;
473        builder.listen(self.backlog)?;
474        let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
475
476        log::info!(
477            "MQTT Broker Listening on {} {}",
478            self.name,
479            tcp_listener.local_addr().unwrap_or(self.laddr)
480        );
481        Ok(Listener {
482            typ: ListenerType::TCP,
483            cfg: Arc::new(self),
484            tcp_listener: Some(tcp_listener),
485            #[cfg(feature = "tls")]
486            tls_acceptor: None,
487            #[cfg(feature = "quic")]
488            quinn_endpoint: None,
489        })
490    }
491
492    #[allow(unused_variables)]
493    #[cfg(feature = "quic")]
494    pub fn bind_quic(self) -> Result<Listener> {
495        let cert_file = self.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
496        let key_file = self.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
497
498        let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
499            .map_err(|e| anyhow!(e))?
500            .collect::<std::result::Result<Vec<_>, _>>()
501            .map_err(|e| anyhow!(e))?;
502        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
503
504        let provider = Arc::new(provider::default_provider());
505        let client_auth = if self.tls_cross_certificate {
506            let root_chain = cert_chain.clone();
507            let mut client_auth_roots = RootCertStore::empty();
508            for root in root_chain {
509                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
510            }
511            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
512                .build()
513                .map_err(|e| anyhow!(e))?
514        } else {
515            WebPkiClientVerifier::no_client_auth()
516        };
517
518        let mut tls_config = ServerConfig::builder_with_provider(provider)
519            .with_safe_default_protocol_versions()
520            .map_err(|e| anyhow!(e))?
521            .with_client_cert_verifier(client_auth)
522            .with_single_cert(cert_chain, key)
523            .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
524
525        tls_config.alpn_protocols = vec![b"mqtt".to_vec(), b"mqttv5".to_vec()];
526        let server_crypto = QuicServerConfig::try_from(tls_config)?;
527        let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
528
529        let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
530        transport_config.max_concurrent_uni_streams(0_u8.into());
531        transport_config.max_idle_timeout(Some(IdleTimeout::try_from(self.idle_timeout)?));
532
533        let endpoint = quinn::Endpoint::server(server_config, self.laddr)?;
534
535        log::info!("MQTT Broker Listening on {} {}", self.name, endpoint.local_addr().unwrap_or(self.laddr));
536        Ok(Listener {
537            typ: ListenerType::QUIC,
538            cfg: Arc::new(self),
539            tcp_listener: None,
540            #[cfg(feature = "tls")]
541            tls_acceptor: None,
542            quinn_endpoint: Some(endpoint),
543        })
544    }
545}
546
/// Transport protocol a [`Listener`] serves and an [`Acceptor`] expects.
///
/// Variants other than `TCP` only exist when the matching cargo feature is enabled.
#[derive(Clone, Copy, Debug)]
pub enum ListenerType {
    /// Plain, unencrypted TCP.
    TCP,
    /// TCP wrapped in TLS.
    #[cfg(feature = "tls")]
    TLS,
    /// MQTT carried over WebSocket.
    #[cfg(feature = "ws")]
    WS,
    /// MQTT carried over WebSocket, secured by TLS.
    #[cfg(all(feature = "tls", feature = "ws"))]
    WSS,
    /// QUIC listener (UDP-based, multiplexed and secured by default).
    #[cfg(feature = "quic")]
    QUIC,
}
566
567/// Network listener for accepting client connections
568pub struct Listener {
569    /// Active listener protocol type
570    pub typ: ListenerType,
571    /// Shared server configuration
572    pub cfg: Arc<Builder>,
573    tcp_listener: Option<TcpListener>,
574    #[cfg(feature = "tls")]
575    tls_acceptor: Option<TlsAcceptor>,
576    #[cfg(feature = "quic")]
577    quinn_endpoint: Option<quinn::Endpoint>,
578}
579
580/// # Examples
581/// ```
582/// # use rmqtt_net::{Builder, Listener};
583/// # fn setup() -> Result<(), Box<dyn std::error::Error>> {
584/// let builder = Builder::new();
585/// let listener = builder.bind()?;
586/// # Ok(())
587/// # }
588/// ```
589impl Listener {
590    /// Converts listener to plain TCP mode
591    pub fn tcp(mut self) -> Result<Self> {
592        let _err = anyhow!("Protocol downgrade from TLS/WS/WSS/QUIC to TCP is not permitted");
593        #[cfg(feature = "tls")]
594        if matches!(self.typ, ListenerType::TLS) {
595            return Err(_err);
596        }
597        #[cfg(feature = "tls")]
598        #[cfg(feature = "ws")]
599        if matches!(self.typ, ListenerType::WSS) {
600            return Err(_err);
601        }
602        #[cfg(feature = "ws")]
603        if matches!(self.typ, ListenerType::WS) {
604            return Err(_err);
605        }
606        #[cfg(feature = "quic")]
607        if matches!(self.typ, ListenerType::QUIC) {
608            return Err(_err);
609        }
610
611        self.typ = ListenerType::TCP;
612        Ok(self)
613    }
614
615    #[cfg(feature = "ws")]
616    /// Upgrades listener to WebSocket protocol
617    pub fn ws(mut self) -> Result<Self> {
618        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
619            self.typ = ListenerType::WS;
620        } else {
621            return Err(anyhow!("Protocol upgrade from TLS/WSS/QUIC to WS is not permitted"));
622        }
623        Ok(self)
624    }
625
626    #[cfg(feature = "tls")]
627    #[cfg(feature = "ws")]
628    /// Upgrades listener to secure WebSocket (WSS)
629    pub fn wss(mut self) -> Result<Self> {
630        #[cfg(feature = "quic")]
631        if matches!(self.typ, ListenerType::QUIC) {
632            return Err(anyhow!("Protocol upgrade from QUIC to WS is not permitted"));
633        }
634
635        if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
636            self = self.tls()?;
637        }
638        self.typ = ListenerType::WSS;
639        Ok(self)
640    }
641
642    #[cfg(feature = "tls")]
643    /// Upgrades listener to TLS-secured TCP
644    pub fn tls(mut self) -> Result<Listener> {
645        match self.typ {
646            #[cfg(feature = "ws")]
647            ListenerType::WS | ListenerType::WSS => {
648                return Err(anyhow!("Protocol downgrade from WS/WSS/QUIC to TLS is not permitted"));
649            }
650            #[cfg(feature = "quic")]
651            ListenerType::QUIC => {
652                return Err(anyhow!("Protocol downgrade from QUIC to TLS is not permitted"));
653            }
654            ListenerType::TLS => return Ok(self),
655            ListenerType::TCP => {}
656        }
657
658        let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
659        let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
660
661        let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
662            .map_err(|e| anyhow!(e))?
663            .collect::<std::result::Result<Vec<_>, _>>()
664            .map_err(|e| anyhow!(e))?;
665        let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
666
667        let provider = Arc::new(provider::default_provider());
668        let client_auth = if self.cfg.tls_cross_certificate {
669            let root_chain = cert_chain.clone();
670            let mut client_auth_roots = RootCertStore::empty();
671            for root in root_chain {
672                client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
673            }
674            WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
675                .build()
676                .map_err(|e| anyhow!(e))?
677        } else {
678            WebPkiClientVerifier::no_client_auth()
679        };
680
681        let tls_config = ServerConfig::builder_with_provider(provider)
682            .with_safe_default_protocol_versions()
683            .map_err(|e| anyhow!(e))?
684            .with_client_cert_verifier(client_auth)
685            .with_single_cert(cert_chain, key)
686            .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
687
688        let acceptor = TlsAcceptor::from(Arc::new(tls_config));
689        self.tls_acceptor = Some(acceptor);
690        self.typ = ListenerType::TLS;
691        Ok(self)
692    }
693
694    /// Accepts incoming client connections
695    pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
696        if let Some(tcp_listener) = &self.tcp_listener {
697            self.accept_tcp(tcp_listener).await
698        } else {
699            Err(anyhow!(""))
700        }
701    }
702
703    async fn accept_tcp(&self, tcp_listener: &TcpListener) -> Result<Acceptor<TcpStream>> {
704        let (mut socket, mut remote_addr) = tcp_listener.accept().await?;
705        if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
706            return Err(Error::from(e));
707        }
708        log::debug!("remote_addr: {remote_addr}, proxy_protocol: {}", self.cfg.proxy_protocol);
709        if self.cfg.proxy_protocol {
710            let mut buffer = [0u8; u16::MAX as usize];
711            let read_bytes =
712                tokio::time::timeout(self.cfg.proxy_protocol_timeout, socket.peek(&mut buffer)).await??;
713            let len = {
714                let mut slice = &buffer[..read_bytes];
715                let header = parse(&mut slice)?;
716                if let Some((src, _)) = handle_header(header) {
717                    remote_addr = src;
718                }
719                read_bytes - slice.len()
720            };
721            // skip proxy protocol data
722            let _ = socket.read_exact(&mut buffer[..len]).await;
723        }
724        Ok(Acceptor {
725            socket,
726            remote_addr,
727            #[cfg(feature = "tls")]
728            acceptor: self.tls_acceptor.clone(),
729            cfg: self.cfg.clone(),
730            typ: self.typ,
731        })
732    }
733
734    #[cfg(feature = "quic")]
735    pub async fn accept_quic(&self) -> Result<Acceptor<QuinnBiStream>> {
736        if let Some(endpoint) = &self.quinn_endpoint {
737            let incoming =
738                endpoint.accept().await.ok_or_else(|| anyhow!("No incoming QUIC connection available"))?;
739            let conn = incoming.await?;
740            let remote_addr = conn.remote_address();
741
742            let (send, recv) = conn.accept_bi().await?;
743            let socket = QuinnBiStream::new(send, recv);
744
745            Ok(Acceptor {
746                socket,
747                remote_addr,
748                #[cfg(feature = "tls")]
749                acceptor: self.tls_acceptor.clone(),
750                cfg: self.cfg.clone(),
751                typ: self.typ,
752            })
753        } else {
754            Err(anyhow!(""))
755        }
756    }
757
758    pub fn local_addr(&self) -> Result<SocketAddr> {
759        if let Some(tcp_listener) = &self.tcp_listener {
760            Ok(tcp_listener.local_addr()?)
761        } else {
762            #[cfg(feature = "quic")]
763            if let Some(endpoint) = &self.quinn_endpoint {
764                Ok(endpoint.local_addr()?)
765            } else {
766                Err(anyhow!("No active listener (neither TCP nor QUIC endpoint is available)"))
767            }
768            #[cfg(not(feature = "quic"))]
769            Err(anyhow!("No active listener"))
770        }
771    }
772}
773
774/// Connection handler for processing client streams
775pub struct Acceptor<S> {
776    /// Underlying network transport
777    pub(crate) socket: S,
778    #[cfg(feature = "tls")]
779    acceptor: Option<TlsAcceptor>,
780    /// Remote client address
781    pub remote_addr: SocketAddr,
782    /// Shared server configuration
783    pub cfg: Arc<Builder>,
784    /// Active protocol type
785    pub typ: ListenerType,
786}
787
788impl<S> Acceptor<S>
789where
790    S: AsyncRead + AsyncWrite + Unpin,
791{
792    /// Creates TCP protocol dispatcher
793    #[inline]
794    pub fn tcp(self) -> Result<Dispatcher<S>> {
795        if matches!(self.typ, ListenerType::TCP) {
796            Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
797        } else {
798            Err(anyhow!("Protocol mismatch: Expected TCP listener"))
799        }
800    }
801
802    #[cfg(feature = "tls")]
803    /// Performs TLS handshake and creates secure dispatcher
804    #[inline]
805    pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
806        if !matches!(self.typ, ListenerType::TLS) {
807            return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
808        }
809
810        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
811        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
812        {
813            Ok(Ok(tls_s)) => tls_s,
814            Ok(Err(e)) => return Err(e.into()),
815            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
816        };
817
818        let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
819
820        Ok(Dispatcher::new(tls_s, self.remote_addr, cert_info, self.cfg))
821    }
822
823    #[cfg(feature = "ws")]
824    /// Performs WebSocket upgrade and creates WS dispatcher
825    #[inline]
826    pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
827        if !matches!(self.typ, ListenerType::WS) {
828            return Err(anyhow!("Protocol mismatch: Expected WS listener"));
829        }
830
831        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
832            .await
833        {
834            Ok(Ok(ws_stream)) => {
835                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, None, self.cfg.clone()))
836            }
837            Ok(Err(e)) => Err(e.into()),
838            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
839        }
840    }
841
842    #[cfg(feature = "tls")]
843    #[cfg(feature = "ws")]
844    /// Performs TLS handshake and WebSocket upgrade
845    #[inline]
846    pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
847        if !matches!(self.typ, ListenerType::WSS) {
848            return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
849        }
850
851        let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
852        let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
853        {
854            Ok(Ok(tls_s)) => tls_s,
855            Ok(Err(e)) => return Err(e.into()),
856            Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
857        };
858
859        let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
860
861        match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
862            Ok(Ok(ws_stream)) => {
863                Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, cert_info, self.cfg.clone()))
864            }
865            Ok(Err(e)) => Err(e.into()),
866            Err(_) => Err(crate::MqttError::ReadTimeout.into()),
867        }
868    }
869
870    #[cfg(feature = "quic")]
871    #[inline]
872    pub async fn quic(self) -> Result<Dispatcher<S>> {
873        if !matches!(self.typ, ListenerType::QUIC) {
874            return Err(anyhow!("Protocol mismatch: Expected QUIC listener"));
875        }
876        Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
877    }
878
879    #[inline]
880    #[cfg(feature = "tls")]
881    fn get_extract_cert_info<C: TlsCertExtractor>(io: &C, cert_cn_as_username: bool) -> Option<CertInfo> {
882        if cert_cn_as_username {
883            // Extract cert info BEFORE consuming self
884            let cert_info: Option<CertInfo> = io.extract_cert_info();
885            // Certificate info is now available in s.cert_info
886            if let Some(ref cert) = cert_info {
887                log::debug!("Client certificate: {}", cert);
888                log::debug!("CN: {:?}, Org: {:?}", cert.common_name, cert.organization);
889            }
890            cert_info
891        } else {
892            None
893        }
894    }
895}
896
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
/// Validates WebSocket handshake requests for the MQTT protocol.
///
/// RFC 6455 lets a client offer several subprotocols, either as a
/// comma-separated list or across repeated `Sec-WebSocket-Protocol` headers.
/// MQTT only requires that `mqtt` is among them. The request is accepted when
/// any offered entry (after trimming whitespace) is exactly `mqtt`. The
/// response then selects `mqtt` as the single negotiated subprotocol.
///
/// # Errors
/// Returns an `ErrorResponse` when no `mqtt` subprotocol is offered, or in
/// the (unexpected) case that the response header value cannot be built.
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";
    // Scan every header instance and every list item, not just the first value.
    // Values that are not valid visible-ASCII strings are skipped.
    let offers_mqtt = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|protocol| protocol.trim() == "mqtt");
    if !offers_mqtt {
        return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
    }
    // `insert` (rather than `append`) guarantees exactly one selected subprotocol.
    response.headers_mut().insert(
        "Sec-WebSocket-Protocol",
        "mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
    );
    Ok(response)
}
915
916// from https://github.com/zhboner/realm/blob/master/realm_core/src/tcp/proxy.rs
917fn handle_header(header: ProxyHeader) -> Option<(SocketAddr, SocketAddr)> {
918    use ProxyHeader::{Version1, Version2};
919    match header {
920        Version1 { addresses } => handle_header_v1(addresses),
921        Version2 { command, transport_protocol, addresses } => {
922            handle_header_v2(command, transport_protocol, addresses)
923        }
924        _ => {
925            log::info!("[tcp]accept proxy-protocol-v?");
926            None
927        }
928    }
929}
930
931fn handle_header_v1(addr: v1::ProxyAddresses) -> Option<(SocketAddr, SocketAddr)> {
932    use v1::ProxyAddresses::*;
933    match addr {
934        Unknown => {
935            log::debug!("[tcp]accept proxy-protocol-v1: unknown");
936            None
937        }
938        Ipv4 { source, destination } => {
939            log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
940            Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
941        }
942        Ipv6 { source, destination } => {
943            log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
944            Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
945        }
946    }
947}
948
949fn handle_header_v2(
950    cmd: v2::ProxyCommand,
951    proto: v2::ProxyTransportProtocol,
952    addr: v2::ProxyAddresses,
953) -> Option<(SocketAddr, SocketAddr)> {
954    use v2::ProxyAddresses as Address;
955    use v2::ProxyCommand as Command;
956    use v2::ProxyTransportProtocol as Protocol;
957
958    // The connection endpoints are the sender and the receiver.
959    // Such connections exist when the proxy sends health-checks to the server.
960    // The receiver must accept this connection as valid and must use the
961    // real connection endpoints and discard the protocol block including the
962    // family which is ignored
963    if let Command::Local = cmd {
964        log::debug!("[tcp]accept proxy-protocol-v2: command = LOCAL, ignore");
965        return None;
966    }
967
968    // only get tcp address
969    match proto {
970        Protocol::Stream => {}
971        Protocol::Unspec => {
972            log::debug!("[tcp]accept proxy-protocol-v2: protocol = UNSPEC, ignore");
973            return None;
974        }
975        Protocol::Datagram => {
976            log::debug!("[tcp]accept proxy-protocol-v2: protocol = DGRAM, ignore");
977            return None;
978        }
979    }
980
981    match addr {
982        Address::Ipv4 { source, destination } => {
983            log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
984            Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
985        }
986        Address::Ipv6 { source, destination } => {
987            log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
988            Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
989        }
990        Address::Unspec => {
991            log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNSPEC, ignore");
992            None
993        }
994        Address::Unix { .. } => {
995            log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNIX, ignore");
996            None
997        }
998    }
999}