1use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
33use std::num::{NonZeroU16, NonZeroU32};
34use std::sync::Arc;
35use std::time::Duration;
36
37use anyhow::anyhow;
38use nonzero_ext::nonzero;
39use proxy_protocol::parse;
40use proxy_protocol::ProxyHeader;
41use proxy_protocol::{version1 as v1, version2 as v2};
42#[cfg(feature = "quic")]
43use quinn::{crypto::rustls::QuicServerConfig, IdleTimeout};
44use rmqtt_codec::types::QoS;
45#[cfg(not(target_os = "windows"))]
46#[cfg(feature = "tls")]
47use rustls::crypto::aws_lc_rs as provider;
48#[cfg(feature = "tls")]
49#[cfg(target_os = "windows")]
50use rustls::crypto::ring as provider;
51#[cfg(feature = "tls")]
52use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
53use socket2::{Domain, SockAddr, Socket, Type};
54use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
55use tokio::net::{TcpListener, TcpStream};
56#[cfg(feature = "tls")]
57use tokio_rustls::{server::TlsStream, TlsAcceptor};
58#[cfg(feature = "ws")]
59use tokio_tungstenite::{
60 accept_hdr_async,
61 tungstenite::handshake::server::{ErrorResponse, Request, Response},
62};
63
64#[cfg(feature = "quic")]
65use crate::quic::QuinnBiStream;
66use crate::stream::Dispatcher;
67#[cfg(feature = "ws")]
68use crate::ws::WsStream;
69#[cfg(feature = "tls")]
70use crate::{CertInfo, TlsCertExtractor};
71use crate::{Error, Result};
72
/// Configuration for a single MQTT listener.
///
/// Values are set fluently through the setter methods (plain field assignments) and the
/// builder is then turned into a [`Listener`] with `bind` (TCP-based protocols) or
/// `bind_quic` (QUIC). Many fields are only carried here and consumed by other parts of
/// the broker; their descriptions below are hedged accordingly.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Listener name; included in the "MQTT Broker Listening on ..." log line.
    pub name: String,
    /// Local address to bind (default `0.0.0.0:1883`).
    pub laddr: SocketAddr,
    /// Backlog passed to `listen` when binding a TCP listener (default 512).
    pub backlog: i32,
    /// Whether `TCP_NODELAY` is set on each accepted TCP socket (default `false`).
    pub nodelay: bool,
    /// When `Some`, `SO_REUSEADDR` is set to this value before binding; `None` keeps the OS default.
    pub reuseaddr: Option<bool>,
    /// When `Some`, `SO_REUSEPORT` is set to this value before binding (skipped on Windows);
    /// `None` keeps the OS default.
    pub reuseport: Option<bool>,
    /// Connection limit (default 1,000,000); not enforced in this file — presumably by the broker.
    pub max_connections: usize,
    /// Limit on concurrent handshakes (default 1,000); not enforced in this file.
    pub max_handshaking_limit: usize,
    /// Maximum MQTT packet size (default 1024 * 1024, presumably bytes).
    pub max_packet_size: u32,

    /// Whether clients may connect without credentials (default `true`); consumed elsewhere.
    pub allow_anonymous: bool,
    /// Lower bound for the client keep-alive (default 0); consumed elsewhere.
    pub min_keepalive: u16,
    /// Upper bound for the client keep-alive (default 65535); consumed elsewhere.
    pub max_keepalive: u16,
    /// Whether a keep-alive of 0 is accepted (default `true`); consumed elsewhere.
    pub allow_zero_keepalive: bool,
    /// Keep-alive backoff factor (default 0.75); its exact interpretation is not visible here.
    pub keepalive_backoff: f32,
    /// In-flight window size (default 16); consumed elsewhere.
    pub max_inflight: NonZeroU16,
    /// Timeout applied to each TLS and WebSocket handshake phase in this file (default 30 s);
    /// may also be used elsewhere.
    pub handshake_timeout: Duration,
    /// Send timeout (default 10 s); consumed elsewhere.
    pub send_timeout: Duration,
    /// Message queue length limit (default 1000); consumed elsewhere.
    pub max_mqueue_len: usize,
    /// `(count, period)` rate limit for the message queue (default `u32::MAX` per 1 s);
    /// presumably messages per period — confirm with the consumer.
    pub mqueue_rate_limit: (NonZeroU32, Duration),
    /// Maximum client identifier length (default 65535); consumed elsewhere.
    pub max_clientid_len: usize,
    /// Highest QoS granted to clients (default `QoS::ExactlyOnce`); consumed elsewhere.
    pub max_qos_allowed: QoS,
    /// Topic level limit (default 0, presumably "unlimited"); consumed elsewhere.
    pub max_topic_levels: usize,
    /// Default session expiry interval (default 2 h); consumed elsewhere.
    pub session_expiry_interval: Duration,
    /// Upper bound on session expiry (default `Duration::ZERO`, presumably "no bound"); consumed elsewhere.
    pub max_session_expiry_interval: Duration,
    /// Retry interval for unacknowledged messages (default 20 s); consumed elsewhere.
    pub message_retry_interval: Duration,
    /// Default message expiry interval (default 5 min); consumed elsewhere.
    pub message_expiry_interval: Duration,
    /// Subscription limit per client (default 0, presumably "unlimited"); consumed elsewhere.
    pub max_subscriptions: usize,
    /// Whether shared subscriptions are enabled (default `true`); consumed elsewhere.
    pub shared_subscription: bool,
    /// Topic alias limit (default 0); consumed elsewhere.
    pub max_topic_aliases: u16,
    /// Subscription-limiting feature flag (default `false`); consumed elsewhere.
    pub limit_subscription: bool,
    /// Delayed-publish feature flag (default `false`); consumed elsewhere.
    pub delayed_publish: bool,

    /// When `true`, clients must present a TLS certificate that verifies against the
    /// certificates loaded from `tls_cert` (mutual TLS). Default `false`.
    pub tls_cross_certificate: bool,
    /// Path to the PEM certificate chain used for TLS/WSS/QUIC.
    pub tls_cert: Option<String>,
    /// Path to the PEM private key used for TLS/WSS/QUIC.
    pub tls_key: Option<String>,
    /// When `true`, each accepted TCP connection must begin with a PROXY protocol (v1/v2)
    /// header; a usable source address from it replaces the peer address. Default `false`.
    pub proxy_protocol: bool,
    /// How long to wait for the PROXY protocol header (default 5 s).
    pub proxy_protocol_timeout: Duration,

    /// When `true`, client certificate info is extracted after the TLS handshake and passed to
    /// the dispatcher (presumably so the certificate CN can serve as the username). Default `false`.
    pub cert_cn_as_username: bool,

    /// Idle timeout (default 90 s); in this file it is only applied as the QUIC transport idle timeout.
    pub idle_timeout: Duration,
}
158
159impl Default for Builder {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165impl Builder {
176 pub fn new() -> Builder {
178 Builder {
179 name: Default::default(),
180 laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
181 max_connections: 1_000_000,
182 max_handshaking_limit: 1_000,
183 max_packet_size: 1024 * 1024,
184 backlog: 512,
185 nodelay: false,
186 reuseaddr: None,
187 reuseport: None,
188
189 allow_anonymous: true,
190 min_keepalive: 0,
191 max_keepalive: 65535,
192 allow_zero_keepalive: true,
193 keepalive_backoff: 0.75,
194 max_inflight: nonzero!(16u16),
195 handshake_timeout: Duration::from_secs(30),
196 send_timeout: Duration::from_secs(10),
197 max_mqueue_len: 1000,
198
199 mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
200 max_clientid_len: 65535,
201 max_qos_allowed: QoS::ExactlyOnce,
202 max_topic_levels: 0,
203 session_expiry_interval: Duration::from_secs(2 * 60 * 60),
204 max_session_expiry_interval: Duration::ZERO,
205 message_retry_interval: Duration::from_secs(20),
206 message_expiry_interval: Duration::from_secs(5 * 60),
207 max_subscriptions: 0,
208 shared_subscription: true,
209 max_topic_aliases: 0,
210
211 limit_subscription: false,
212 delayed_publish: false,
213
214 tls_cross_certificate: false,
215 tls_cert: None,
216 tls_key: None,
217 proxy_protocol: false,
218 proxy_protocol_timeout: Duration::from_secs(5),
219
220 cert_cn_as_username: false,
221
222 idle_timeout: Duration::from_secs(90),
223 }
224 }
225
226 pub fn name<N: Into<String>>(mut self, name: N) -> Self {
228 self.name = name.into();
229 self
230 }
231
232 pub fn laddr(mut self, laddr: SocketAddr) -> Self {
234 self.laddr = laddr;
235 self
236 }
237
238 pub fn backlog(mut self, backlog: i32) -> Self {
240 self.backlog = backlog;
241 self
242 }
243
244 pub fn nodelay(mut self, nodelay: bool) -> Self {
246 self.nodelay = nodelay;
247 self
248 }
249
250 pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
252 self.reuseaddr = reuseaddr;
253 self
254 }
255
256 pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
258 self.reuseport = reuseport;
259 self
260 }
261
262 pub fn max_connections(mut self, max_connections: usize) -> Self {
264 self.max_connections = max_connections;
265 self
266 }
267
268 pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
270 self.max_handshaking_limit = max_handshaking_limit;
271 self
272 }
273
274 pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
276 self.max_packet_size = max_packet_size;
277 self
278 }
279
280 pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
282 self.allow_anonymous = allow_anonymous;
283 self
284 }
285
286 pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
288 self.min_keepalive = min_keepalive;
289 self
290 }
291
292 pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
294 self.max_keepalive = max_keepalive;
295 self
296 }
297
298 pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
300 self.allow_zero_keepalive = allow_zero_keepalive;
301 self
302 }
303
304 pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
306 self.keepalive_backoff = keepalive_backoff;
307 self
308 }
309
310 pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
312 self.max_inflight = max_inflight;
313 self
314 }
315
316 pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
318 self.handshake_timeout = handshake_timeout;
319 self
320 }
321
322 pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
324 self.send_timeout = send_timeout;
325 self
326 }
327
328 pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
330 self.max_mqueue_len = max_mqueue_len;
331 self
332 }
333
334 pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
336 self.mqueue_rate_limit = (rate_limit, duration);
337 self
338 }
339
340 pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
342 self.max_clientid_len = max_clientid_len;
343 self
344 }
345
346 pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
348 self.max_qos_allowed = max_qos_allowed;
349 self
350 }
351
352 pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
354 self.max_topic_levels = max_topic_levels;
355 self
356 }
357
358 pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
360 self.session_expiry_interval = session_expiry_interval;
361 self
362 }
363
364 pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
366 self.max_session_expiry_interval = max_session_expiry_interval;
367 self
368 }
369
370 pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
372 self.message_retry_interval = message_retry_interval;
373 self
374 }
375
376 pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
378 self.message_expiry_interval = message_expiry_interval;
379 self
380 }
381
382 pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
384 self.max_subscriptions = max_subscriptions;
385 self
386 }
387
388 pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
390 self.shared_subscription = shared_subscription;
391 self
392 }
393
394 pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
396 self.max_topic_aliases = max_topic_aliases;
397 self
398 }
399
400 pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
402 self.limit_subscription = limit_subscription;
403 self
404 }
405
406 pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
408 self.delayed_publish = delayed_publish;
409 self
410 }
411
412 pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
414 self.tls_cross_certificate = cross_certificate;
415 self
416 }
417
418 pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
420 self.tls_cert = tls_cert.map(|c| c.into());
421 self
422 }
423
424 pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
426 self.tls_key = tls_key.map(|c| c.into());
427 self
428 }
429
430 pub fn cert_cn_as_username(mut self, cert_cn_as_username: bool) -> Self {
431 self.cert_cn_as_username = cert_cn_as_username;
432 self
433 }
434
435 pub fn proxy_protocol(mut self, enable_protocol_proxy: bool) -> Self {
437 self.proxy_protocol = enable_protocol_proxy;
438 self
439 }
440
441 pub fn proxy_protocol_timeout(mut self, proxy_protocol_timeout: Duration) -> Self {
443 self.proxy_protocol_timeout = proxy_protocol_timeout;
444 self
445 }
446
447 pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
449 self.idle_timeout = idle_timeout;
450 self
451 }
452
453 #[allow(unused_variables)]
455 pub fn bind(self) -> Result<Listener> {
456 let builder = match self.laddr {
457 SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
458 SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
459 };
460
461 builder.set_linger(Some(Duration::from_secs(10)))?;
462
463 builder.set_nonblocking(true)?;
464
465 if let Some(reuseaddr) = self.reuseaddr {
466 builder.set_reuse_address(reuseaddr)?;
467 }
468
469 #[cfg(not(windows))]
470 if let Some(reuseport) = self.reuseport {
471 builder.set_reuse_port(reuseport)?;
472 }
473
474 builder.bind(&SockAddr::from(self.laddr))?;
475 builder.listen(self.backlog)?;
476 let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
477
478 log::info!(
479 "MQTT Broker Listening on {} {}",
480 self.name,
481 tcp_listener.local_addr().unwrap_or(self.laddr)
482 );
483 Ok(Listener {
484 typ: ListenerType::TCP,
485 cfg: Arc::new(self),
486 tcp_listener: Some(tcp_listener),
487 #[cfg(feature = "tls")]
488 tls_acceptor: None,
489 #[cfg(feature = "quic")]
490 quinn_endpoint: None,
491 })
492 }
493
494 #[allow(unused_variables)]
495 #[cfg(feature = "quic")]
496 pub fn bind_quic(self) -> Result<Listener> {
497 let cert_file = self.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
498 let key_file = self.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
499
500 let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
501 .map_err(|e| anyhow!(e))?
502 .collect::<std::result::Result<Vec<_>, _>>()
503 .map_err(|e| anyhow!(e))?;
504 let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
505
506 let provider = Arc::new(provider::default_provider());
507 let client_auth = if self.tls_cross_certificate {
508 let root_chain = cert_chain.clone();
509 let mut client_auth_roots = RootCertStore::empty();
510 for root in root_chain {
511 client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
512 }
513 WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
514 .build()
515 .map_err(|e| anyhow!(e))?
516 } else {
517 WebPkiClientVerifier::no_client_auth()
518 };
519
520 let mut tls_config = ServerConfig::builder_with_provider(provider)
521 .with_safe_default_protocol_versions()
522 .map_err(|e| anyhow!(e))?
523 .with_client_cert_verifier(client_auth)
524 .with_single_cert(cert_chain, key)
525 .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
526
527 tls_config.alpn_protocols = vec![b"mqtt".to_vec(), b"mqttv5".to_vec()];
528 let server_crypto = QuicServerConfig::try_from(tls_config)?;
529 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
530
531 let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
532 transport_config.max_concurrent_uni_streams(0_u8.into());
533 transport_config.max_idle_timeout(Some(IdleTimeout::try_from(self.idle_timeout)?));
534
535 let endpoint = quinn::Endpoint::server(server_config, self.laddr)?;
536
537 log::info!("MQTT Broker Listening on {} {}", self.name, endpoint.local_addr().unwrap_or(self.laddr));
538 Ok(Listener {
539 typ: ListenerType::QUIC,
540 cfg: Arc::new(self),
541 tcp_listener: None,
542 #[cfg(feature = "tls")]
543 tls_acceptor: None,
544 quinn_endpoint: Some(endpoint),
545 })
546 }
547}
548
/// Protocol that a [`Listener`] hands to each [`Acceptor`] it produces.
///
/// Every variant except `TCP` only exists when the matching Cargo feature is enabled.
#[derive(Debug, Copy, Clone)]
pub enum ListenerType {
    /// Plain MQTT over TCP.
    TCP,
    /// MQTT over TLS (feature `tls`).
    #[cfg(feature = "tls")]
    TLS,
    /// MQTT over WebSocket (feature `ws`).
    #[cfg(feature = "ws")]
    WS,
    /// MQTT over WebSocket over TLS (features `tls` and `ws`).
    #[cfg(feature = "tls")]
    #[cfg(feature = "ws")]
    WSS,
    /// MQTT over a QUIC bidirectional stream (feature `quic`).
    #[cfg(feature = "quic")]
    QUIC,
}
568
/// A bound listening endpoint plus the configuration shared with every accepted connection.
///
/// Created by `Builder::bind` (which populates `tcp_listener`) or `Builder::bind_quic`
/// (which populates `quinn_endpoint`).
pub struct Listener {
    /// Protocol handed to each [`Acceptor`]; changed via the `tcp`/`tls`/`ws`/`wss` methods.
    pub typ: ListenerType,
    /// Shared listener configuration; each [`Acceptor`] receives a clone of this `Arc`.
    pub cfg: Arc<Builder>,
    /// Bound TCP listener; `None` for QUIC listeners.
    tcp_listener: Option<TcpListener>,
    /// TLS acceptor, installed when the listener is switched to TLS or WSS.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<TlsAcceptor>,
    /// QUIC endpoint; `None` for TCP-based listeners.
    #[cfg(feature = "quic")]
    quinn_endpoint: Option<quinn::Endpoint>,
}
581
582impl Listener {
592 pub fn tcp(mut self) -> Result<Self> {
594 let _err = anyhow!("Protocol downgrade from TLS/WS/WSS/QUIC to TCP is not permitted");
595 #[cfg(feature = "tls")]
596 if matches!(self.typ, ListenerType::TLS) {
597 return Err(_err);
598 }
599 #[cfg(feature = "tls")]
600 #[cfg(feature = "ws")]
601 if matches!(self.typ, ListenerType::WSS) {
602 return Err(_err);
603 }
604 #[cfg(feature = "ws")]
605 if matches!(self.typ, ListenerType::WS) {
606 return Err(_err);
607 }
608 #[cfg(feature = "quic")]
609 if matches!(self.typ, ListenerType::QUIC) {
610 return Err(_err);
611 }
612
613 self.typ = ListenerType::TCP;
614 Ok(self)
615 }
616
617 #[cfg(feature = "ws")]
618 pub fn ws(mut self) -> Result<Self> {
620 if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
621 self.typ = ListenerType::WS;
622 } else {
623 return Err(anyhow!("Protocol upgrade from TLS/WSS/QUIC to WS is not permitted"));
624 }
625 Ok(self)
626 }
627
628 #[cfg(feature = "tls")]
629 #[cfg(feature = "ws")]
630 pub fn wss(mut self) -> Result<Self> {
632 #[cfg(feature = "quic")]
633 if matches!(self.typ, ListenerType::QUIC) {
634 return Err(anyhow!("Protocol upgrade from QUIC to WS is not permitted"));
635 }
636
637 if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
638 self = self.tls()?;
639 }
640 self.typ = ListenerType::WSS;
641 Ok(self)
642 }
643
644 #[cfg(feature = "tls")]
645 pub fn tls(mut self) -> Result<Listener> {
647 match self.typ {
648 #[cfg(feature = "ws")]
649 ListenerType::WS | ListenerType::WSS => {
650 return Err(anyhow!("Protocol downgrade from WS/WSS/QUIC to TLS is not permitted"));
651 }
652 #[cfg(feature = "quic")]
653 ListenerType::QUIC => {
654 return Err(anyhow!("Protocol downgrade from QUIC to TLS is not permitted"));
655 }
656 ListenerType::TLS => return Ok(self),
657 ListenerType::TCP => {}
658 }
659
660 let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
661 let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
662
663 let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
664 .map_err(|e| anyhow!(e))?
665 .collect::<std::result::Result<Vec<_>, _>>()
666 .map_err(|e| anyhow!(e))?;
667 let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
668
669 let provider = Arc::new(provider::default_provider());
670 let client_auth = if self.cfg.tls_cross_certificate {
671 let root_chain = cert_chain.clone();
672 let mut client_auth_roots = RootCertStore::empty();
673 for root in root_chain {
674 client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
675 }
676 WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
677 .build()
678 .map_err(|e| anyhow!(e))?
679 } else {
680 WebPkiClientVerifier::no_client_auth()
681 };
682
683 let tls_config = ServerConfig::builder_with_provider(provider)
684 .with_safe_default_protocol_versions()
685 .map_err(|e| anyhow!(e))?
686 .with_client_cert_verifier(client_auth)
687 .with_single_cert(cert_chain, key)
688 .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
689
690 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
691 self.tls_acceptor = Some(acceptor);
692 self.typ = ListenerType::TLS;
693 Ok(self)
694 }
695
696 pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
698 if let Some(tcp_listener) = &self.tcp_listener {
699 self.accept_tcp(tcp_listener).await
700 } else {
701 Err(anyhow!(""))
702 }
703 }
704
705 async fn accept_tcp(&self, tcp_listener: &TcpListener) -> Result<Acceptor<TcpStream>> {
706 let (mut socket, mut remote_addr) = tcp_listener.accept().await?;
707 if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
708 return Err(Error::from(e));
709 }
710 log::debug!("remote_addr: {remote_addr}, proxy_protocol: {}", self.cfg.proxy_protocol);
711 if self.cfg.proxy_protocol {
712 let mut buffer = [0u8; u16::MAX as usize];
713 let read_bytes =
714 tokio::time::timeout(self.cfg.proxy_protocol_timeout, socket.peek(&mut buffer)).await??;
715 let len = {
716 let mut slice = &buffer[..read_bytes];
717 let header = parse(&mut slice)?;
718 if let Some((src, _)) = handle_header(header) {
719 remote_addr = src;
720 }
721 read_bytes - slice.len()
722 };
723 let _ = socket.read_exact(&mut buffer[..len]).await;
725 }
726 Ok(Acceptor {
727 socket,
728 remote_addr,
729 #[cfg(feature = "tls")]
730 acceptor: self.tls_acceptor.clone(),
731 cfg: self.cfg.clone(),
732 typ: self.typ,
733 })
734 }
735
736 #[cfg(feature = "quic")]
737 pub async fn accept_quic(&self) -> Result<Acceptor<QuinnBiStream>> {
738 if let Some(endpoint) = &self.quinn_endpoint {
739 let incoming =
740 endpoint.accept().await.ok_or_else(|| anyhow!("No incoming QUIC connection available"))?;
741 let conn = incoming.await?;
742 let remote_addr = conn.remote_address();
743
744 let (send, recv) = conn.accept_bi().await?;
745 let socket = QuinnBiStream::new(send, recv);
746
747 Ok(Acceptor {
748 socket,
749 remote_addr,
750 #[cfg(feature = "tls")]
751 acceptor: self.tls_acceptor.clone(),
752 cfg: self.cfg.clone(),
753 typ: self.typ,
754 })
755 } else {
756 Err(anyhow!(""))
757 }
758 }
759
760 pub fn local_addr(&self) -> Result<SocketAddr> {
761 if let Some(tcp_listener) = &self.tcp_listener {
762 Ok(tcp_listener.local_addr()?)
763 } else {
764 #[cfg(feature = "quic")]
765 if let Some(endpoint) = &self.quinn_endpoint {
766 Ok(endpoint.local_addr()?)
767 } else {
768 Err(anyhow!("No active listener (neither TCP nor QUIC endpoint is available)"))
769 }
770 #[cfg(not(feature = "quic"))]
771 Err(anyhow!("No active listener"))
772 }
773 }
774}
775
/// A freshly accepted connection whose protocol handshake has not been performed yet.
///
/// Call the method matching [`Acceptor::typ`] (`tcp`, `tls`, `ws`, `wss` or `quic`) to
/// complete the handshake and obtain a `Dispatcher`.
pub struct Acceptor<S> {
    /// Raw transport stream (in this module: a `TcpStream` or a `QuinnBiStream`).
    pub(crate) socket: S,
    /// TLS acceptor copied from the listener; `None` unless the listener was switched to TLS/WSS.
    #[cfg(feature = "tls")]
    acceptor: Option<TlsAcceptor>,
    /// Peer address; with PROXY protocol enabled, the source address from the header.
    pub remote_addr: SocketAddr,
    /// Listener configuration shared with the originating `Listener`.
    pub cfg: Arc<Builder>,
    /// Protocol the originating listener was configured for.
    pub typ: ListenerType,
}
789
790impl<S> Acceptor<S>
791where
792 S: AsyncRead + AsyncWrite + Unpin,
793{
794 #[inline]
796 pub fn tcp(self) -> Result<Dispatcher<S>> {
797 if matches!(self.typ, ListenerType::TCP) {
798 Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
799 } else {
800 Err(anyhow!("Protocol mismatch: Expected TCP listener"))
801 }
802 }
803
804 #[cfg(feature = "tls")]
805 #[inline]
807 pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
808 if !matches!(self.typ, ListenerType::TLS) {
809 return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
810 }
811
812 let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
813 let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
814 {
815 Ok(Ok(tls_s)) => tls_s,
816 Ok(Err(e)) => return Err(e.into()),
817 Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
818 };
819
820 let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
821
822 Ok(Dispatcher::new(tls_s, self.remote_addr, cert_info, self.cfg))
823 }
824
825 #[cfg(feature = "ws")]
826 #[inline]
828 pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
829 if !matches!(self.typ, ListenerType::WS) {
830 return Err(anyhow!("Protocol mismatch: Expected WS listener"));
831 }
832
833 match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
834 .await
835 {
836 Ok(Ok(ws_stream)) => {
837 Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, None, self.cfg.clone()))
838 }
839 Ok(Err(e)) => Err(e.into()),
840 Err(_) => Err(crate::MqttError::ReadTimeout.into()),
841 }
842 }
843
844 #[cfg(feature = "tls")]
845 #[cfg(feature = "ws")]
846 #[inline]
848 pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
849 if !matches!(self.typ, ListenerType::WSS) {
850 return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
851 }
852
853 let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
854 let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
855 {
856 Ok(Ok(tls_s)) => tls_s,
857 Ok(Err(e)) => return Err(e.into()),
858 Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
859 };
860
861 let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
862
863 match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
864 Ok(Ok(ws_stream)) => {
865 Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, cert_info, self.cfg.clone()))
866 }
867 Ok(Err(e)) => Err(e.into()),
868 Err(_) => Err(crate::MqttError::ReadTimeout.into()),
869 }
870 }
871
872 #[cfg(feature = "quic")]
873 #[inline]
874 pub async fn quic(self) -> Result<Dispatcher<S>> {
875 if !matches!(self.typ, ListenerType::QUIC) {
876 return Err(anyhow!("Protocol mismatch: Expected QUIC listener"));
877 }
878 Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
879 }
880
881 #[inline]
882 #[cfg(feature = "tls")]
883 fn get_extract_cert_info<C: TlsCertExtractor>(io: &C, cert_cn_as_username: bool) -> Option<CertInfo> {
884 if cert_cn_as_username {
885 let cert_info: Option<CertInfo> = io.extract_cert_info();
887 if let Some(ref cert) = cert_info {
889 log::debug!("Client certificate: {}", cert);
890 log::debug!("CN: {:?}, Org: {:?}", cert.common_name, cert.organization);
891 }
892 cert_info
893 } else {
894 None
895 }
896 }
897}
898
/// WebSocket handshake callback: requires the client to offer the `mqtt` subprotocol and
/// echoes `Sec-WebSocket-Protocol: mqtt` in the response.
///
/// Clients may list several subprotocols in one comma-separated header value and/or repeat
/// the header (RFC 6455 §4.1), so every offered token is checked. Previously only the first
/// header value was compared verbatim, which rejected offers like `mqtt, mqttv3.1`.
///
/// # Errors
/// Returns an error response when `mqtt` is not among the offered subprotocols.
#[allow(clippy::result_large_err)] // `ErrorResponse` is a full HTTP response type dictated by tungstenite.
#[cfg(feature = "ws")]
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";
    let offers_mqtt = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        // Non-visible-ASCII header values cannot contain a valid `mqtt` token; skip them.
        .filter_map(|value| value.to_str().ok())
        .flat_map(|list| list.split(','))
        .any(|protocol| protocol.trim() == "mqtt");
    if !offers_mqtt {
        return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
    }
    response.headers_mut().append(
        "Sec-WebSocket-Protocol",
        "mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
    );
    Ok(response)
}
917
918fn handle_header(header: ProxyHeader) -> Option<(SocketAddr, SocketAddr)> {
920 use ProxyHeader::{Version1, Version2};
921 match header {
922 Version1 { addresses } => handle_header_v1(addresses),
923 Version2 { command, transport_protocol, addresses } => {
924 handle_header_v2(command, transport_protocol, addresses)
925 }
926 _ => {
927 log::info!("[tcp]accept proxy-protocol-v?");
928 None
929 }
930 }
931}
932
933fn handle_header_v1(addr: v1::ProxyAddresses) -> Option<(SocketAddr, SocketAddr)> {
934 use v1::ProxyAddresses::*;
935 match addr {
936 Unknown => {
937 log::debug!("[tcp]accept proxy-protocol-v1: unknown");
938 None
939 }
940 Ipv4 { source, destination } => {
941 log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
942 Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
943 }
944 Ipv6 { source, destination } => {
945 log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
946 Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
947 }
948 }
949}
950
951fn handle_header_v2(
952 cmd: v2::ProxyCommand,
953 proto: v2::ProxyTransportProtocol,
954 addr: v2::ProxyAddresses,
955) -> Option<(SocketAddr, SocketAddr)> {
956 use v2::ProxyAddresses as Address;
957 use v2::ProxyCommand as Command;
958 use v2::ProxyTransportProtocol as Protocol;
959
960 if let Command::Local = cmd {
966 log::debug!("[tcp]accept proxy-protocol-v2: command = LOCAL, ignore");
967 return None;
968 }
969
970 match proto {
972 Protocol::Stream => {}
973 Protocol::Unspec => {
974 log::debug!("[tcp]accept proxy-protocol-v2: protocol = UNSPEC, ignore");
975 return None;
976 }
977 Protocol::Datagram => {
978 log::debug!("[tcp]accept proxy-protocol-v2: protocol = DGRAM, ignore");
979 return None;
980 }
981 }
982
983 match addr {
984 Address::Ipv4 { source, destination } => {
985 log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
986 Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
987 }
988 Address::Ipv6 { source, destination } => {
989 log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
990 Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
991 }
992 Address::Unspec => {
993 log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNSPEC, ignore");
994 None
995 }
996 Address::Unix { .. } => {
997 log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNIX, ignore");
998 None
999 }
1000 }
1001}