1use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
33use std::num::{NonZeroU16, NonZeroU32};
34use std::sync::Arc;
35use std::time::Duration;
36
37use anyhow::anyhow;
38use nonzero_ext::nonzero;
39use proxy_protocol::parse;
40use proxy_protocol::ProxyHeader;
41use proxy_protocol::{version1 as v1, version2 as v2};
42#[cfg(feature = "quic")]
43use quinn::{crypto::rustls::QuicServerConfig, IdleTimeout};
44use rmqtt_codec::types::QoS;
45#[cfg(not(target_os = "windows"))]
46#[cfg(feature = "tls")]
47use rustls::crypto::aws_lc_rs as provider;
48#[cfg(feature = "tls")]
49#[cfg(target_os = "windows")]
50use rustls::crypto::ring as provider;
51#[cfg(feature = "tls")]
52use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
53use socket2::{Domain, SockAddr, Socket, Type};
54use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
55use tokio::net::{TcpListener, TcpStream};
56#[cfg(feature = "tls")]
57use tokio_rustls::{server::TlsStream, TlsAcceptor};
58#[cfg(feature = "ws")]
59use tokio_tungstenite::{
60 accept_hdr_async,
61 tungstenite::handshake::server::{ErrorResponse, Request, Response},
62};
63
64#[cfg(feature = "quic")]
65use crate::quic::QuinnBiStream;
66use crate::stream::Dispatcher;
67#[cfg(feature = "ws")]
68use crate::ws::WsStream;
69use crate::{CertInfo, Error, Result, TlsCertExtractor};
70
/// Configuration for a single MQTT listener.
///
/// The struct doubles as a builder: every setter in `impl Builder` consumes `self`, updates one
/// field and returns `self`. Finish with `bind()` (TCP) or `bind_quic()` (QUIC) to obtain a
/// [`Listener`]; the configuration is then shared with every accepted connection via
/// `Arc<Builder>`.
#[derive(Clone, Debug)]
pub struct Builder {
    /// Listener name; only used in the "MQTT Broker Listening on ..." log line.
    pub name: String,
    /// Local address the TCP socket / QUIC endpoint binds to (default `0.0.0.0:1883`).
    pub laddr: SocketAddr,
    /// Backlog passed to `listen(2)` for TCP listeners (default 512).
    pub backlog: i32,
    /// Value applied with `set_nodelay` to every accepted TCP socket.
    pub nodelay: bool,
    /// `SO_REUSEADDR`; `None` leaves the OS default untouched.
    pub reuseaddr: Option<bool>,
    /// `SO_REUSEPORT`; `None` leaves the OS default untouched. Ignored on Windows.
    pub reuseport: Option<bool>,
    // NOTE(review): the fields from here down to `delayed_publish` are not read in this file;
    // the descriptions follow the field names and should be confirmed against the consumers
    // of `Builder`.
    /// Maximum number of concurrent connections (presumably enforced elsewhere).
    pub max_connections: usize,
    /// Maximum number of connections allowed to be mid-handshake at once (presumably).
    pub max_handshaking_limit: usize,
    /// Maximum MQTT packet size in bytes (default 1 MiB).
    pub max_packet_size: u32,

    /// Whether clients may connect without credentials.
    pub allow_anonymous: bool,
    /// Lower bound for the client-requested keep-alive, in seconds (presumably).
    pub min_keepalive: u16,
    /// Upper bound for the client-requested keep-alive, in seconds (presumably).
    pub max_keepalive: u16,
    /// Whether a keep-alive of 0 is accepted.
    pub allow_zero_keepalive: bool,
    /// Keep-alive backoff factor (default 0.75).
    pub keepalive_backoff: f32,
    /// Maximum in-flight messages per session.
    pub max_inflight: NonZeroU16,
    /// Timeout applied to the TLS and WebSocket handshakes by `Acceptor` (default 30 s).
    pub handshake_timeout: Duration,
    /// Send timeout (default 10 s).
    pub send_timeout: Duration,
    /// Maximum message-queue length per session (default 1000).
    pub max_mqueue_len: usize,
    /// Message-queue rate limit as `(messages, per duration)`.
    pub mqueue_rate_limit: (NonZeroU32, Duration),
    /// Maximum client-id length.
    pub max_clientid_len: usize,
    /// Highest QoS level the listener accepts.
    pub max_qos_allowed: QoS,
    /// Maximum topic levels; default 0 (presumably "unlimited" — confirm).
    pub max_topic_levels: usize,
    /// Default session expiry interval (default 2 h).
    pub session_expiry_interval: Duration,
    /// Upper bound on the session expiry interval; default zero (meaning to be confirmed).
    pub max_session_expiry_interval: Duration,
    /// Interval between message retransmission attempts (default 20 s).
    pub message_retry_interval: Duration,
    /// Message expiry interval (default 5 min).
    pub message_expiry_interval: Duration,
    /// Maximum subscriptions per client; default 0 (presumably "unlimited" — confirm).
    pub max_subscriptions: usize,
    /// Whether shared subscriptions are enabled.
    pub shared_subscription: bool,
    /// Maximum topic aliases; default 0.
    pub max_topic_aliases: u16,
    /// Whether limit subscriptions are enabled.
    pub limit_subscription: bool,
    /// Whether delayed publish is enabled.
    pub delayed_publish: bool,

    /// When true, client certificates are verified against a root store built from the
    /// certificates in `tls_cert`; when false, client authentication is disabled.
    pub tls_cross_certificate: bool,
    /// Path to the PEM certificate chain used for TLS/WSS/QUIC.
    pub tls_cert: Option<String>,
    /// Path to the PEM private key used for TLS/WSS/QUIC.
    pub tls_key: Option<String>,
    /// When true, each accepted TCP connection must start with a PROXY protocol header.
    pub proxy_protocol: bool,
    /// How long to wait for the PROXY protocol header to arrive (default 5 s).
    pub proxy_protocol_timeout: Duration,

    /// When true, client-certificate info is extracted after the TLS handshake and handed to
    /// the dispatcher.
    pub cert_cn_as_username: bool,

    /// QUIC max idle timeout (default 90 s).
    pub idle_timeout: Duration,
}
156
157impl Default for Builder {
158 fn default() -> Self {
159 Self::new()
160 }
161}
162
163impl Builder {
174 pub fn new() -> Builder {
176 Builder {
177 name: Default::default(),
178 laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
179 max_connections: 1_000_000,
180 max_handshaking_limit: 1_000,
181 max_packet_size: 1024 * 1024,
182 backlog: 512,
183 nodelay: false,
184 reuseaddr: None,
185 reuseport: None,
186
187 allow_anonymous: true,
188 min_keepalive: 0,
189 max_keepalive: 65535,
190 allow_zero_keepalive: true,
191 keepalive_backoff: 0.75,
192 max_inflight: nonzero!(16u16),
193 handshake_timeout: Duration::from_secs(30),
194 send_timeout: Duration::from_secs(10),
195 max_mqueue_len: 1000,
196
197 mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
198 max_clientid_len: 65535,
199 max_qos_allowed: QoS::ExactlyOnce,
200 max_topic_levels: 0,
201 session_expiry_interval: Duration::from_secs(2 * 60 * 60),
202 max_session_expiry_interval: Duration::ZERO,
203 message_retry_interval: Duration::from_secs(20),
204 message_expiry_interval: Duration::from_secs(5 * 60),
205 max_subscriptions: 0,
206 shared_subscription: true,
207 max_topic_aliases: 0,
208
209 limit_subscription: false,
210 delayed_publish: false,
211
212 tls_cross_certificate: false,
213 tls_cert: None,
214 tls_key: None,
215 proxy_protocol: false,
216 proxy_protocol_timeout: Duration::from_secs(5),
217
218 cert_cn_as_username: false,
219
220 idle_timeout: Duration::from_secs(90),
221 }
222 }
223
224 pub fn name<N: Into<String>>(mut self, name: N) -> Self {
226 self.name = name.into();
227 self
228 }
229
230 pub fn laddr(mut self, laddr: SocketAddr) -> Self {
232 self.laddr = laddr;
233 self
234 }
235
236 pub fn backlog(mut self, backlog: i32) -> Self {
238 self.backlog = backlog;
239 self
240 }
241
242 pub fn nodelay(mut self, nodelay: bool) -> Self {
244 self.nodelay = nodelay;
245 self
246 }
247
248 pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
250 self.reuseaddr = reuseaddr;
251 self
252 }
253
254 pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
256 self.reuseport = reuseport;
257 self
258 }
259
260 pub fn max_connections(mut self, max_connections: usize) -> Self {
262 self.max_connections = max_connections;
263 self
264 }
265
266 pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
268 self.max_handshaking_limit = max_handshaking_limit;
269 self
270 }
271
272 pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
274 self.max_packet_size = max_packet_size;
275 self
276 }
277
278 pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
280 self.allow_anonymous = allow_anonymous;
281 self
282 }
283
284 pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
286 self.min_keepalive = min_keepalive;
287 self
288 }
289
290 pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
292 self.max_keepalive = max_keepalive;
293 self
294 }
295
296 pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
298 self.allow_zero_keepalive = allow_zero_keepalive;
299 self
300 }
301
302 pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
304 self.keepalive_backoff = keepalive_backoff;
305 self
306 }
307
308 pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
310 self.max_inflight = max_inflight;
311 self
312 }
313
314 pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
316 self.handshake_timeout = handshake_timeout;
317 self
318 }
319
320 pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
322 self.send_timeout = send_timeout;
323 self
324 }
325
326 pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
328 self.max_mqueue_len = max_mqueue_len;
329 self
330 }
331
332 pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
334 self.mqueue_rate_limit = (rate_limit, duration);
335 self
336 }
337
338 pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
340 self.max_clientid_len = max_clientid_len;
341 self
342 }
343
344 pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
346 self.max_qos_allowed = max_qos_allowed;
347 self
348 }
349
350 pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
352 self.max_topic_levels = max_topic_levels;
353 self
354 }
355
356 pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
358 self.session_expiry_interval = session_expiry_interval;
359 self
360 }
361
362 pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
364 self.max_session_expiry_interval = max_session_expiry_interval;
365 self
366 }
367
368 pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
370 self.message_retry_interval = message_retry_interval;
371 self
372 }
373
374 pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
376 self.message_expiry_interval = message_expiry_interval;
377 self
378 }
379
380 pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
382 self.max_subscriptions = max_subscriptions;
383 self
384 }
385
386 pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
388 self.shared_subscription = shared_subscription;
389 self
390 }
391
392 pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
394 self.max_topic_aliases = max_topic_aliases;
395 self
396 }
397
398 pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
400 self.limit_subscription = limit_subscription;
401 self
402 }
403
404 pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
406 self.delayed_publish = delayed_publish;
407 self
408 }
409
410 pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
412 self.tls_cross_certificate = cross_certificate;
413 self
414 }
415
416 pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
418 self.tls_cert = tls_cert.map(|c| c.into());
419 self
420 }
421
422 pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
424 self.tls_key = tls_key.map(|c| c.into());
425 self
426 }
427
428 pub fn cert_cn_as_username(mut self, cert_cn_as_username: bool) -> Self {
429 self.cert_cn_as_username = cert_cn_as_username;
430 self
431 }
432
433 pub fn proxy_protocol(mut self, enable_protocol_proxy: bool) -> Self {
435 self.proxy_protocol = enable_protocol_proxy;
436 self
437 }
438
439 pub fn proxy_protocol_timeout(mut self, proxy_protocol_timeout: Duration) -> Self {
441 self.proxy_protocol_timeout = proxy_protocol_timeout;
442 self
443 }
444
445 pub fn idle_timeout(mut self, idle_timeout: Duration) -> Self {
447 self.idle_timeout = idle_timeout;
448 self
449 }
450
451 #[allow(unused_variables)]
453 pub fn bind(self) -> Result<Listener> {
454 let builder = match self.laddr {
455 SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
456 SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
457 };
458
459 builder.set_linger(Some(Duration::from_secs(10)))?;
460
461 builder.set_nonblocking(true)?;
462
463 if let Some(reuseaddr) = self.reuseaddr {
464 builder.set_reuse_address(reuseaddr)?;
465 }
466
467 #[cfg(not(windows))]
468 if let Some(reuseport) = self.reuseport {
469 builder.set_reuse_port(reuseport)?;
470 }
471
472 builder.bind(&SockAddr::from(self.laddr))?;
473 builder.listen(self.backlog)?;
474 let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
475
476 log::info!(
477 "MQTT Broker Listening on {} {}",
478 self.name,
479 tcp_listener.local_addr().unwrap_or(self.laddr)
480 );
481 Ok(Listener {
482 typ: ListenerType::TCP,
483 cfg: Arc::new(self),
484 tcp_listener: Some(tcp_listener),
485 #[cfg(feature = "tls")]
486 tls_acceptor: None,
487 #[cfg(feature = "quic")]
488 quinn_endpoint: None,
489 })
490 }
491
492 #[allow(unused_variables)]
493 #[cfg(feature = "quic")]
494 pub fn bind_quic(self) -> Result<Listener> {
495 let cert_file = self.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
496 let key_file = self.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
497
498 let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
499 .map_err(|e| anyhow!(e))?
500 .collect::<std::result::Result<Vec<_>, _>>()
501 .map_err(|e| anyhow!(e))?;
502 let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
503
504 let provider = Arc::new(provider::default_provider());
505 let client_auth = if self.tls_cross_certificate {
506 let root_chain = cert_chain.clone();
507 let mut client_auth_roots = RootCertStore::empty();
508 for root in root_chain {
509 client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
510 }
511 WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
512 .build()
513 .map_err(|e| anyhow!(e))?
514 } else {
515 WebPkiClientVerifier::no_client_auth()
516 };
517
518 let mut tls_config = ServerConfig::builder_with_provider(provider)
519 .with_safe_default_protocol_versions()
520 .map_err(|e| anyhow!(e))?
521 .with_client_cert_verifier(client_auth)
522 .with_single_cert(cert_chain, key)
523 .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
524
525 tls_config.alpn_protocols = vec![b"mqtt".to_vec(), b"mqttv5".to_vec()];
526 let server_crypto = QuicServerConfig::try_from(tls_config)?;
527 let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
528
529 let transport_config = Arc::get_mut(&mut server_config.transport).unwrap();
530 transport_config.max_concurrent_uni_streams(0_u8.into());
531 transport_config.max_idle_timeout(Some(IdleTimeout::try_from(self.idle_timeout)?));
532
533 let endpoint = quinn::Endpoint::server(server_config, self.laddr)?;
534
535 log::info!("MQTT Broker Listening on {} {}", self.name, endpoint.local_addr().unwrap_or(self.laddr));
536 Ok(Listener {
537 typ: ListenerType::QUIC,
538 cfg: Arc::new(self),
539 tcp_listener: None,
540 #[cfg(feature = "tls")]
541 tls_acceptor: None,
542 quinn_endpoint: Some(endpoint),
543 })
544 }
545}
546
/// Transport/protocol stack a [`Listener`] (and its accepted connections) speaks.
///
/// Variants other than `TCP` only exist when the corresponding cargo features are enabled.
/// Equality and hashing are derived so callers can compare or key on the type directly instead
/// of going through `matches!`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ListenerType {
    /// Plain MQTT over TCP.
    TCP,
    /// MQTT over TLS over TCP.
    #[cfg(feature = "tls")]
    TLS,
    /// MQTT over WebSocket over TCP.
    #[cfg(feature = "ws")]
    WS,
    /// MQTT over WebSocket over TLS over TCP.
    #[cfg(feature = "tls")]
    #[cfg(feature = "ws")]
    WSS,
    /// MQTT over a QUIC bidirectional stream.
    #[cfg(feature = "quic")]
    QUIC,
}
566
/// A bound listener produced by `Builder::bind` (TCP) or `Builder::bind_quic` (QUIC).
///
/// Its protocol can be changed with `tcp()`, `tls()`, `ws()` and `wss()` before accepting;
/// `accept()` / `accept_quic()` then yield an [`Acceptor`] per incoming connection.
pub struct Listener {
    /// Protocol stack used for connections accepted by this listener.
    pub typ: ListenerType,
    /// Shared configuration; cloned (as an `Arc`) into every `Acceptor`.
    pub cfg: Arc<Builder>,
    // Set by `Builder::bind`; `None` for QUIC listeners.
    tcp_listener: Option<TcpListener>,
    // Set once the listener is switched to TLS or WSS.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<TlsAcceptor>,
    // Set by `Builder::bind_quic`; `None` for TCP-based listeners.
    #[cfg(feature = "quic")]
    quinn_endpoint: Option<quinn::Endpoint>,
}
579
580impl Listener {
590 pub fn tcp(mut self) -> Result<Self> {
592 let _err = anyhow!("Protocol downgrade from TLS/WS/WSS/QUIC to TCP is not permitted");
593 #[cfg(feature = "tls")]
594 if matches!(self.typ, ListenerType::TLS) {
595 return Err(_err);
596 }
597 #[cfg(feature = "tls")]
598 #[cfg(feature = "ws")]
599 if matches!(self.typ, ListenerType::WSS) {
600 return Err(_err);
601 }
602 #[cfg(feature = "ws")]
603 if matches!(self.typ, ListenerType::WS) {
604 return Err(_err);
605 }
606 #[cfg(feature = "quic")]
607 if matches!(self.typ, ListenerType::QUIC) {
608 return Err(_err);
609 }
610
611 self.typ = ListenerType::TCP;
612 Ok(self)
613 }
614
615 #[cfg(feature = "ws")]
616 pub fn ws(mut self) -> Result<Self> {
618 if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
619 self.typ = ListenerType::WS;
620 } else {
621 return Err(anyhow!("Protocol upgrade from TLS/WSS/QUIC to WS is not permitted"));
622 }
623 Ok(self)
624 }
625
626 #[cfg(feature = "tls")]
627 #[cfg(feature = "ws")]
628 pub fn wss(mut self) -> Result<Self> {
630 #[cfg(feature = "quic")]
631 if matches!(self.typ, ListenerType::QUIC) {
632 return Err(anyhow!("Protocol upgrade from QUIC to WS is not permitted"));
633 }
634
635 if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
636 self = self.tls()?;
637 }
638 self.typ = ListenerType::WSS;
639 Ok(self)
640 }
641
642 #[cfg(feature = "tls")]
643 pub fn tls(mut self) -> Result<Listener> {
645 match self.typ {
646 #[cfg(feature = "ws")]
647 ListenerType::WS | ListenerType::WSS => {
648 return Err(anyhow!("Protocol downgrade from WS/WSS/QUIC to TLS is not permitted"));
649 }
650 #[cfg(feature = "quic")]
651 ListenerType::QUIC => {
652 return Err(anyhow!("Protocol downgrade from QUIC to TLS is not permitted"));
653 }
654 ListenerType::TLS => return Ok(self),
655 ListenerType::TCP => {}
656 }
657
658 let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
659 let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
660
661 let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
662 .map_err(|e| anyhow!(e))?
663 .collect::<std::result::Result<Vec<_>, _>>()
664 .map_err(|e| anyhow!(e))?;
665 let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
666
667 let provider = Arc::new(provider::default_provider());
668 let client_auth = if self.cfg.tls_cross_certificate {
669 let root_chain = cert_chain.clone();
670 let mut client_auth_roots = RootCertStore::empty();
671 for root in root_chain {
672 client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
673 }
674 WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
675 .build()
676 .map_err(|e| anyhow!(e))?
677 } else {
678 WebPkiClientVerifier::no_client_auth()
679 };
680
681 let tls_config = ServerConfig::builder_with_provider(provider)
682 .with_safe_default_protocol_versions()
683 .map_err(|e| anyhow!(e))?
684 .with_client_cert_verifier(client_auth)
685 .with_single_cert(cert_chain, key)
686 .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
687
688 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
689 self.tls_acceptor = Some(acceptor);
690 self.typ = ListenerType::TLS;
691 Ok(self)
692 }
693
694 pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
696 if let Some(tcp_listener) = &self.tcp_listener {
697 self.accept_tcp(tcp_listener).await
698 } else {
699 Err(anyhow!(""))
700 }
701 }
702
703 async fn accept_tcp(&self, tcp_listener: &TcpListener) -> Result<Acceptor<TcpStream>> {
704 let (mut socket, mut remote_addr) = tcp_listener.accept().await?;
705 if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
706 return Err(Error::from(e));
707 }
708 log::debug!("remote_addr: {remote_addr}, proxy_protocol: {}", self.cfg.proxy_protocol);
709 if self.cfg.proxy_protocol {
710 let mut buffer = [0u8; u16::MAX as usize];
711 let read_bytes =
712 tokio::time::timeout(self.cfg.proxy_protocol_timeout, socket.peek(&mut buffer)).await??;
713 let len = {
714 let mut slice = &buffer[..read_bytes];
715 let header = parse(&mut slice)?;
716 if let Some((src, _)) = handle_header(header) {
717 remote_addr = src;
718 }
719 read_bytes - slice.len()
720 };
721 let _ = socket.read_exact(&mut buffer[..len]).await;
723 }
724 Ok(Acceptor {
725 socket,
726 remote_addr,
727 #[cfg(feature = "tls")]
728 acceptor: self.tls_acceptor.clone(),
729 cfg: self.cfg.clone(),
730 typ: self.typ,
731 })
732 }
733
734 #[cfg(feature = "quic")]
735 pub async fn accept_quic(&self) -> Result<Acceptor<QuinnBiStream>> {
736 if let Some(endpoint) = &self.quinn_endpoint {
737 let incoming =
738 endpoint.accept().await.ok_or_else(|| anyhow!("No incoming QUIC connection available"))?;
739 let conn = incoming.await?;
740 let remote_addr = conn.remote_address();
741
742 let (send, recv) = conn.accept_bi().await?;
743 let socket = QuinnBiStream::new(send, recv);
744
745 Ok(Acceptor {
746 socket,
747 remote_addr,
748 #[cfg(feature = "tls")]
749 acceptor: self.tls_acceptor.clone(),
750 cfg: self.cfg.clone(),
751 typ: self.typ,
752 })
753 } else {
754 Err(anyhow!(""))
755 }
756 }
757
758 pub fn local_addr(&self) -> Result<SocketAddr> {
759 if let Some(tcp_listener) = &self.tcp_listener {
760 Ok(tcp_listener.local_addr()?)
761 } else {
762 #[cfg(feature = "quic")]
763 if let Some(endpoint) = &self.quinn_endpoint {
764 Ok(endpoint.local_addr()?)
765 } else {
766 Err(anyhow!("No active listener (neither TCP nor QUIC endpoint is available)"))
767 }
768 #[cfg(not(feature = "quic"))]
769 Err(anyhow!("No active listener"))
770 }
771 }
772}
773
/// One accepted, not-yet-upgraded connection.
///
/// Call the method matching `typ` (`tcp`, `tls`, `ws`, `wss` or `quic`) to run any remaining
/// handshakes and obtain a `Dispatcher`.
pub struct Acceptor<S> {
    /// The raw accepted stream (TCP socket or QUIC bidirectional stream).
    pub(crate) socket: S,
    // TLS acceptor copied from the listener; `None` unless the listener is TLS/WSS.
    #[cfg(feature = "tls")]
    acceptor: Option<TlsAcceptor>,
    /// Peer address; for TCP with PROXY protocol enabled, the source address from the header.
    pub remote_addr: SocketAddr,
    /// Shared listener configuration.
    pub cfg: Arc<Builder>,
    /// Protocol of the listener that produced this connection.
    pub typ: ListenerType,
}
787
788impl<S> Acceptor<S>
789where
790 S: AsyncRead + AsyncWrite + Unpin,
791{
792 #[inline]
794 pub fn tcp(self) -> Result<Dispatcher<S>> {
795 if matches!(self.typ, ListenerType::TCP) {
796 Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
797 } else {
798 Err(anyhow!("Protocol mismatch: Expected TCP listener"))
799 }
800 }
801
802 #[cfg(feature = "tls")]
803 #[inline]
805 pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
806 if !matches!(self.typ, ListenerType::TLS) {
807 return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
808 }
809
810 let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
811 let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
812 {
813 Ok(Ok(tls_s)) => tls_s,
814 Ok(Err(e)) => return Err(e.into()),
815 Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
816 };
817
818 let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
819
820 Ok(Dispatcher::new(tls_s, self.remote_addr, cert_info, self.cfg))
821 }
822
823 #[cfg(feature = "ws")]
824 #[inline]
826 pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
827 if !matches!(self.typ, ListenerType::WS) {
828 return Err(anyhow!("Protocol mismatch: Expected WS listener"));
829 }
830
831 match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
832 .await
833 {
834 Ok(Ok(ws_stream)) => {
835 Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, None, self.cfg.clone()))
836 }
837 Ok(Err(e)) => Err(e.into()),
838 Err(_) => Err(crate::MqttError::ReadTimeout.into()),
839 }
840 }
841
842 #[cfg(feature = "tls")]
843 #[cfg(feature = "ws")]
844 #[inline]
846 pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
847 if !matches!(self.typ, ListenerType::WSS) {
848 return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
849 }
850
851 let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
852 let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
853 {
854 Ok(Ok(tls_s)) => tls_s,
855 Ok(Err(e)) => return Err(e.into()),
856 Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
857 };
858
859 let cert_info = Self::get_extract_cert_info(&tls_s, self.cfg.cert_cn_as_username);
860
861 match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
862 Ok(Ok(ws_stream)) => {
863 Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, cert_info, self.cfg.clone()))
864 }
865 Ok(Err(e)) => Err(e.into()),
866 Err(_) => Err(crate::MqttError::ReadTimeout.into()),
867 }
868 }
869
870 #[cfg(feature = "quic")]
871 #[inline]
872 pub async fn quic(self) -> Result<Dispatcher<S>> {
873 if !matches!(self.typ, ListenerType::QUIC) {
874 return Err(anyhow!("Protocol mismatch: Expected QUIC listener"));
875 }
876 Ok(Dispatcher::new(self.socket, self.remote_addr, None, self.cfg))
877 }
878
879 #[inline]
880 #[cfg(feature = "tls")]
881 fn get_extract_cert_info<C: TlsCertExtractor>(io: &C, cert_cn_as_username: bool) -> Option<CertInfo> {
882 if cert_cn_as_username {
883 let cert_info: Option<CertInfo> = io.extract_cert_info();
885 if let Some(ref cert) = cert_info {
887 log::debug!("Client certificate: {}", cert);
888 log::debug!("CN: {:?}, Org: {:?}", cert.common_name, cert.organization);
889 }
890 cert_info
891 } else {
892 None
893 }
894 }
895}
896
/// WebSocket handshake callback: requires the client to offer the `mqtt` subprotocol and
/// echoes `Sec-WebSocket-Protocol: mqtt` back in the accepted response.
///
/// `Sec-WebSocket-Protocol` is a comma-separated list of offered subprotocols and may be split
/// over several header lines, so every listed token is checked (e.g. `mqtt, mqttv3.1` is
/// accepted).
///
/// Rejections carry `400 Bad Request`: tungstenite only sends a custom error response whose
/// status is not a success code. With the default `200` the rejection would never reach the
/// client and the handshake would fail with a generic protocol error instead.
// `ErrorResponse` (a full HTTP response) is large, but the error type is fixed by
// tungstenite's callback signature, so the lint cannot be addressed here.
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";

    let reject = |reason: &str| -> ErrorResponse {
        let mut resp = ErrorResponse::new(Some(reason.to_owned()));
        if let Ok(status) = 400u16.try_into() {
            *resp.status_mut() = status;
        }
        resp
    };

    let offers_mqtt = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|protocol| protocol.trim() == "mqtt");
    if !offers_mqtt {
        return Err(reject(PROTOCOL_ERROR));
    }

    // `insert` rather than `append`: the selected subprotocol must appear exactly once.
    response.headers_mut().insert(
        "Sec-WebSocket-Protocol",
        "mqtt".parse().map_err(|_| reject("InvalidHeaderValue"))?,
    );
    Ok(response)
}
915
916fn handle_header(header: ProxyHeader) -> Option<(SocketAddr, SocketAddr)> {
918 use ProxyHeader::{Version1, Version2};
919 match header {
920 Version1 { addresses } => handle_header_v1(addresses),
921 Version2 { command, transport_protocol, addresses } => {
922 handle_header_v2(command, transport_protocol, addresses)
923 }
924 _ => {
925 log::info!("[tcp]accept proxy-protocol-v?");
926 None
927 }
928 }
929}
930
931fn handle_header_v1(addr: v1::ProxyAddresses) -> Option<(SocketAddr, SocketAddr)> {
932 use v1::ProxyAddresses::*;
933 match addr {
934 Unknown => {
935 log::debug!("[tcp]accept proxy-protocol-v1: unknown");
936 None
937 }
938 Ipv4 { source, destination } => {
939 log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
940 Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
941 }
942 Ipv6 { source, destination } => {
943 log::debug!("[tcp]accept proxy-protocol-v1: {} => {}", &source, &destination);
944 Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
945 }
946 }
947}
948
949fn handle_header_v2(
950 cmd: v2::ProxyCommand,
951 proto: v2::ProxyTransportProtocol,
952 addr: v2::ProxyAddresses,
953) -> Option<(SocketAddr, SocketAddr)> {
954 use v2::ProxyAddresses as Address;
955 use v2::ProxyCommand as Command;
956 use v2::ProxyTransportProtocol as Protocol;
957
958 if let Command::Local = cmd {
964 log::debug!("[tcp]accept proxy-protocol-v2: command = LOCAL, ignore");
965 return None;
966 }
967
968 match proto {
970 Protocol::Stream => {}
971 Protocol::Unspec => {
972 log::debug!("[tcp]accept proxy-protocol-v2: protocol = UNSPEC, ignore");
973 return None;
974 }
975 Protocol::Datagram => {
976 log::debug!("[tcp]accept proxy-protocol-v2: protocol = DGRAM, ignore");
977 return None;
978 }
979 }
980
981 match addr {
982 Address::Ipv4 { source, destination } => {
983 log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
984 Some((SocketAddr::V4(source), SocketAddr::V4(destination)))
985 }
986 Address::Ipv6 { source, destination } => {
987 log::debug!("[tcp]accept proxy-protocol-v2: {} => {}", &source, &destination);
988 Some((SocketAddr::V6(source), SocketAddr::V6(destination)))
989 }
990 Address::Unspec => {
991 log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNSPEC, ignore");
992 None
993 }
994 Address::Unix { .. } => {
995 log::debug!("[tcp]accept proxy-protocol-v2: af_family = AF_UNIX, ignore");
996 None
997 }
998 }
999}