1use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
33use std::num::{NonZeroU16, NonZeroU32};
34use std::sync::Arc;
35use std::time::Duration;
36
37use anyhow::anyhow;
38use nonzero_ext::nonzero;
39use rmqtt_codec::types::QoS;
40#[cfg(not(target_os = "windows"))]
41#[cfg(feature = "tls")]
42use rustls::crypto::aws_lc_rs as provider;
43#[cfg(feature = "tls")]
44#[cfg(target_os = "windows")]
45use rustls::crypto::ring as provider;
46#[cfg(feature = "tls")]
47use rustls::{pki_types::pem::PemObject, server::WebPkiClientVerifier, RootCertStore, ServerConfig};
48use socket2::{Domain, SockAddr, Socket, Type};
49use tokio::io::{AsyncRead, AsyncWrite};
50use tokio::net::{TcpListener, TcpStream};
51#[cfg(feature = "tls")]
52use tokio_rustls::{server::TlsStream, TlsAcceptor};
53#[cfg(feature = "ws")]
54use tokio_tungstenite::{
55 accept_hdr_async,
56 tungstenite::handshake::server::{ErrorResponse, Request, Response},
57};
58
59use crate::stream::Dispatcher;
60#[cfg(feature = "ws")]
61use crate::ws::WsStream;
62use crate::{Error, Result};
63
/// Configuration for an MQTT listener.
///
/// Configure it with the chained setter methods, then call [`Builder::bind`] to obtain
/// a bound [`Listener`]. Default values for every field are set in [`Builder::new`].
#[derive(Clone, Debug)]
pub struct Builder {
    /// Listener name; printed in the "MQTT Broker Listening on" log line (default: empty).
    pub name: String,
    /// Local socket address to bind (default `0.0.0.0:1883`).
    pub laddr: SocketAddr,
    /// Backlog passed to `listen()` in [`Builder::bind`] (default 512).
    pub backlog: i32,
    /// Applied with `set_nodelay` to every stream returned by [`Listener::accept`] (default `false`).
    pub nodelay: bool,
    /// When `Some`, applied with `set_reuse_address` before binding; `None` keeps the OS default.
    pub reuseaddr: Option<bool>,
    /// When `Some`, applied with `set_reuse_port` before binding (skipped on Windows);
    /// `None` keeps the OS default.
    pub reuseport: Option<bool>,
    /// Default 1_000_000. Not read in this section; presumably a cap on concurrent connections.
    pub max_connections: usize,
    /// Default 1_000. Not read in this section; presumably a cap on in-progress handshakes.
    pub max_handshaking_limit: usize,
    /// Default 1 MiB (`1024 * 1024`). Not read in this section; presumably the largest accepted packet in bytes.
    pub max_packet_size: u32,

    /// Default `true`. Not read in this section; presumably whether unauthenticated clients may connect.
    pub allow_anonymous: bool,
    /// Default 0. Not read in this section; presumably the lower bound on client keep-alive.
    pub min_keepalive: u16,
    /// Default 65535. Not read in this section; presumably the upper bound on client keep-alive.
    pub max_keepalive: u16,
    /// Default `true`. Not read in this section; presumably whether a keep-alive of 0 is accepted.
    pub allow_zero_keepalive: bool,
    /// Default 0.75. Not read in this section — TODO confirm how it scales the keep-alive timeout.
    pub keepalive_backoff: f32,
    /// Default 16. Not read in this section; presumably the in-flight message window.
    pub max_inflight: NonZeroU16,
    /// Timeout applied to the TLS and WebSocket handshakes in [`Acceptor`] (default 30 s).
    pub handshake_timeout: Duration,
    /// Default 10 s. Not read in this section; presumably a per-send timeout.
    pub send_timeout: Duration,
    /// Default 1000. Not read in this section; presumably the per-session message queue length.
    pub max_mqueue_len: usize,
    /// `(count, window)` pair, default `(u32::MAX, 1 s)`. Not read in this section;
    /// presumably a rate limit on the message queue — confirm.
    pub mqueue_rate_limit: (NonZeroU32, Duration),
    /// Default 65535. Not read in this section; presumably the longest accepted client id.
    pub max_clientid_len: usize,
    /// Default `QoS::ExactlyOnce`. Not read in this section; presumably the highest QoS granted.
    pub max_qos_allowed: QoS,
    /// Default 0. Not read in this section; 0 presumably means "unlimited" — confirm.
    pub max_topic_levels: usize,
    /// Default 2 hours. Not read in this section.
    pub session_expiry_interval: Duration,
    /// Default `Duration::ZERO`. Not read in this section; zero presumably means "no cap" — confirm.
    pub max_session_expiry_interval: Duration,
    /// Default 20 s. Not read in this section; presumably the redelivery interval for unacknowledged messages.
    pub message_retry_interval: Duration,
    /// Default 5 minutes. Not read in this section.
    pub message_expiry_interval: Duration,
    /// Default 0. Not read in this section; 0 presumably means "unlimited" — confirm.
    pub max_subscriptions: usize,
    /// Default `true`. Not read in this section; presumably enables shared subscriptions.
    pub shared_subscription: bool,
    /// Default 0. Not read in this section; presumably the topic-alias maximum.
    pub max_topic_aliases: u16,
    /// Default `false`. Not read in this section.
    pub limit_subscription: bool,
    /// Default `false`. Not read in this section.
    pub delayed_publish: bool,

    /// When `true`, [`Listener::tls`] verifies client certificates against the server's own
    /// certificate chain loaded from `tls_cert` (default `false`).
    pub tls_cross_certificate: bool,
    /// Path to the PEM certificate chain; required by [`Listener::tls`].
    pub tls_cert: Option<String>,
    /// Path to the PEM private key; required by [`Listener::tls`].
    pub tls_key: Option<String>,
}
139
140impl Default for Builder {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146impl Builder {
157 pub fn new() -> Builder {
159 Builder {
160 name: Default::default(),
161 laddr: SocketAddr::from(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 1883)),
162 max_connections: 1_000_000,
163 max_handshaking_limit: 1_000,
164 max_packet_size: 1024 * 1024,
165 backlog: 512,
166 nodelay: false,
167 reuseaddr: None,
168 reuseport: None,
169
170 allow_anonymous: true,
171 min_keepalive: 0,
172 max_keepalive: 65535,
173 allow_zero_keepalive: true,
174 keepalive_backoff: 0.75,
175 max_inflight: nonzero!(16u16),
176 handshake_timeout: Duration::from_secs(30),
177 send_timeout: Duration::from_secs(10),
178 max_mqueue_len: 1000,
179
180 mqueue_rate_limit: (nonzero!(u32::MAX), Duration::from_secs(1)),
181 max_clientid_len: 65535,
182 max_qos_allowed: QoS::ExactlyOnce,
183 max_topic_levels: 0,
184 session_expiry_interval: Duration::from_secs(2 * 60 * 60),
185 max_session_expiry_interval: Duration::ZERO,
186 message_retry_interval: Duration::from_secs(20),
187 message_expiry_interval: Duration::from_secs(5 * 60),
188 max_subscriptions: 0,
189 shared_subscription: true,
190 max_topic_aliases: 0,
191
192 limit_subscription: false,
193 delayed_publish: false,
194
195 tls_cross_certificate: false,
196 tls_cert: None,
197 tls_key: None,
198 }
199 }
200
201 pub fn name<N: Into<String>>(mut self, name: N) -> Self {
203 self.name = name.into();
204 self
205 }
206
207 pub fn laddr(mut self, laddr: SocketAddr) -> Self {
209 self.laddr = laddr;
210 self
211 }
212
213 pub fn backlog(mut self, backlog: i32) -> Self {
215 self.backlog = backlog;
216 self
217 }
218
219 pub fn nodelay(mut self, nodelay: bool) -> Self {
221 self.nodelay = nodelay;
222 self
223 }
224
225 pub fn reuseaddr(mut self, reuseaddr: Option<bool>) -> Self {
227 self.reuseaddr = reuseaddr;
228 self
229 }
230
231 pub fn reuseport(mut self, reuseport: Option<bool>) -> Self {
233 self.reuseport = reuseport;
234 self
235 }
236
237 pub fn max_connections(mut self, max_connections: usize) -> Self {
239 self.max_connections = max_connections;
240 self
241 }
242
243 pub fn max_handshaking_limit(mut self, max_handshaking_limit: usize) -> Self {
245 self.max_handshaking_limit = max_handshaking_limit;
246 self
247 }
248
249 pub fn max_packet_size(mut self, max_packet_size: u32) -> Self {
251 self.max_packet_size = max_packet_size;
252 self
253 }
254
255 pub fn allow_anonymous(mut self, allow_anonymous: bool) -> Self {
257 self.allow_anonymous = allow_anonymous;
258 self
259 }
260
261 pub fn min_keepalive(mut self, min_keepalive: u16) -> Self {
263 self.min_keepalive = min_keepalive;
264 self
265 }
266
267 pub fn max_keepalive(mut self, max_keepalive: u16) -> Self {
269 self.max_keepalive = max_keepalive;
270 self
271 }
272
273 pub fn allow_zero_keepalive(mut self, allow_zero_keepalive: bool) -> Self {
275 self.allow_zero_keepalive = allow_zero_keepalive;
276 self
277 }
278
279 pub fn keepalive_backoff(mut self, keepalive_backoff: f32) -> Self {
281 self.keepalive_backoff = keepalive_backoff;
282 self
283 }
284
285 pub fn max_inflight(mut self, max_inflight: NonZeroU16) -> Self {
287 self.max_inflight = max_inflight;
288 self
289 }
290
291 pub fn handshake_timeout(mut self, handshake_timeout: Duration) -> Self {
293 self.handshake_timeout = handshake_timeout;
294 self
295 }
296
297 pub fn send_timeout(mut self, send_timeout: Duration) -> Self {
299 self.send_timeout = send_timeout;
300 self
301 }
302
303 pub fn max_mqueue_len(mut self, max_mqueue_len: usize) -> Self {
305 self.max_mqueue_len = max_mqueue_len;
306 self
307 }
308
309 pub fn mqueue_rate_limit(mut self, rate_limit: NonZeroU32, duration: Duration) -> Self {
311 self.mqueue_rate_limit = (rate_limit, duration);
312 self
313 }
314
315 pub fn max_clientid_len(mut self, max_clientid_len: usize) -> Self {
317 self.max_clientid_len = max_clientid_len;
318 self
319 }
320
321 pub fn max_qos_allowed(mut self, max_qos_allowed: QoS) -> Self {
323 self.max_qos_allowed = max_qos_allowed;
324 self
325 }
326
327 pub fn max_topic_levels(mut self, max_topic_levels: usize) -> Self {
329 self.max_topic_levels = max_topic_levels;
330 self
331 }
332
333 pub fn session_expiry_interval(mut self, session_expiry_interval: Duration) -> Self {
335 self.session_expiry_interval = session_expiry_interval;
336 self
337 }
338
339 pub fn max_session_expiry_interval(mut self, max_session_expiry_interval: Duration) -> Self {
341 self.max_session_expiry_interval = max_session_expiry_interval;
342 self
343 }
344
345 pub fn message_retry_interval(mut self, message_retry_interval: Duration) -> Self {
347 self.message_retry_interval = message_retry_interval;
348 self
349 }
350
351 pub fn message_expiry_interval(mut self, message_expiry_interval: Duration) -> Self {
353 self.message_expiry_interval = message_expiry_interval;
354 self
355 }
356
357 pub fn max_subscriptions(mut self, max_subscriptions: usize) -> Self {
359 self.max_subscriptions = max_subscriptions;
360 self
361 }
362
363 pub fn shared_subscription(mut self, shared_subscription: bool) -> Self {
365 self.shared_subscription = shared_subscription;
366 self
367 }
368
369 pub fn max_topic_aliases(mut self, max_topic_aliases: u16) -> Self {
371 self.max_topic_aliases = max_topic_aliases;
372 self
373 }
374
375 pub fn limit_subscription(mut self, limit_subscription: bool) -> Self {
377 self.limit_subscription = limit_subscription;
378 self
379 }
380
381 pub fn delayed_publish(mut self, delayed_publish: bool) -> Self {
383 self.delayed_publish = delayed_publish;
384 self
385 }
386
387 pub fn tls_cross_certificate(mut self, cross_certificate: bool) -> Self {
389 self.tls_cross_certificate = cross_certificate;
390 self
391 }
392
393 pub fn tls_cert<N: Into<String>>(mut self, tls_cert: Option<N>) -> Self {
395 self.tls_cert = tls_cert.map(|c| c.into());
396 self
397 }
398
399 pub fn tls_key<N: Into<String>>(mut self, tls_key: Option<N>) -> Self {
401 self.tls_key = tls_key.map(|c| c.into());
402 self
403 }
404
405 #[allow(unused_variables)]
407 pub fn bind(self) -> Result<Listener> {
408 let builder = match self.laddr {
409 SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
410 SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
411 };
412
413 builder.set_linger(Some(Duration::from_secs(10)))?;
414
415 builder.set_nonblocking(true)?;
416
417 if let Some(reuseaddr) = self.reuseaddr {
418 builder.set_reuse_address(reuseaddr)?;
419 }
420
421 #[cfg(not(windows))]
422 if let Some(reuseport) = self.reuseport {
423 builder.set_reuse_port(reuseport)?;
424 }
425
426 builder.bind(&SockAddr::from(self.laddr))?;
427 builder.listen(self.backlog)?;
428 let tcp_listener = TcpListener::from_std(std::net::TcpListener::from(builder))?;
429
430 log::info!(
431 "MQTT Broker Listening on {} {}",
432 self.name,
433 tcp_listener.local_addr().unwrap_or(self.laddr)
434 );
435 Ok(Listener {
436 typ: ListenerType::TCP,
437 cfg: Arc::new(self),
438 tcp_listener,
439 #[cfg(feature = "tls")]
440 tls_acceptor: None,
441 })
442 }
443}
444
/// Transport protocol served by a [`Listener`] and its [`Acceptor`]s.
///
/// Only `TCP` is always present; the other variants exist when the matching cargo
/// feature(s) are enabled.
#[derive(Debug, Copy, Clone)]
pub enum ListenerType {
    /// Plain MQTT over TCP.
    TCP,
    /// MQTT over TLS (feature `tls`).
    #[cfg(feature = "tls")]
    TLS,
    /// MQTT over WebSocket (feature `ws`).
    #[cfg(feature = "ws")]
    WS,
    /// MQTT over WebSocket over TLS (features `tls` and `ws`).
    #[cfg(feature = "tls")]
    #[cfg(feature = "ws")]
    WSS,
}
461
/// A bound TCP listener together with the protocol it serves.
///
/// [`Builder::bind`] creates it as `TCP`; the `tcp`/`ws`/`tls`/`wss` methods switch
/// the protocol before connections are accepted.
pub struct Listener {
    /// Protocol applied to every accepted connection.
    pub typ: ListenerType,
    /// Shared configuration; the `Arc` is cloned into each [`Acceptor`].
    pub cfg: Arc<Builder>,
    /// Underlying Tokio listener.
    tcp_listener: TcpListener,
    /// TLS acceptor; `None` after `bind`, set by `Listener::tls`.
    #[cfg(feature = "tls")]
    tls_acceptor: Option<TlsAcceptor>,
}
472
473impl Listener {
483 pub fn tcp(mut self) -> Result<Self> {
485 let _err = anyhow!("Protocol downgrade from TLS/WS/WSS to TCP is not permitted");
486 #[cfg(feature = "tls")]
487 if matches!(self.typ, ListenerType::TLS) {
488 return Err(_err);
489 }
490 #[cfg(feature = "tls")]
491 #[cfg(feature = "ws")]
492 if matches!(self.typ, ListenerType::WSS) {
493 return Err(_err);
494 }
495 #[cfg(feature = "ws")]
496 if matches!(self.typ, ListenerType::WS) {
497 return Err(_err);
498 }
499 self.typ = ListenerType::TCP;
500 Ok(self)
501 }
502
503 #[cfg(feature = "ws")]
504 pub fn ws(mut self) -> Result<Self> {
506 if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
507 self.typ = ListenerType::WS;
508 } else {
509 return Err(anyhow!("Protocol upgrade from TLS/WSS to WS is not permitted"));
510 }
511 Ok(self)
512 }
513
514 #[cfg(feature = "tls")]
515 #[cfg(feature = "ws")]
516 pub fn wss(mut self) -> Result<Self> {
518 if matches!(self.typ, ListenerType::TCP | ListenerType::WS) {
519 self = self.tls()?;
520 }
521 self.typ = ListenerType::WSS;
522 Ok(self)
523 }
524
525 #[cfg(feature = "tls")]
526 pub fn tls(mut self) -> Result<Listener> {
528 match self.typ {
529 #[cfg(feature = "ws")]
530 ListenerType::WS | ListenerType::WSS => {
531 return Err(anyhow!("Protocol downgrade from WS/WSS to TLS is not permitted"));
532 }
533 ListenerType::TLS => return Ok(self),
534 ListenerType::TCP => {}
535 }
536
537 let cert_file = self.cfg.tls_cert.as_ref().ok_or(anyhow!("TLS certificate path not set"))?;
538 let key_file = self.cfg.tls_key.as_ref().ok_or(anyhow!("TLS key path not set"))?;
539
540 let cert_chain = rustls::pki_types::CertificateDer::pem_file_iter(cert_file)
541 .map_err(|e| anyhow!(e))?
542 .collect::<std::result::Result<Vec<_>, _>>()
543 .map_err(|e| anyhow!(e))?;
544 let key = rustls::pki_types::PrivateKeyDer::from_pem_file(key_file).map_err(|e| anyhow!(e))?;
545
546 let provider = Arc::new(provider::default_provider());
547 let client_auth = if self.cfg.tls_cross_certificate {
548 let root_chain = cert_chain.clone();
549 let mut client_auth_roots = RootCertStore::empty();
550 for root in root_chain {
551 client_auth_roots.add(root).map_err(|e| anyhow!(e))?;
552 }
553 WebPkiClientVerifier::builder_with_provider(client_auth_roots.into(), provider.clone())
554 .build()
555 .map_err(|e| anyhow!(e))?
556 } else {
557 WebPkiClientVerifier::no_client_auth()
558 };
559
560 let tls_config = ServerConfig::builder_with_provider(provider)
561 .with_safe_default_protocol_versions()
562 .map_err(|e| anyhow!(e))?
563 .with_client_cert_verifier(client_auth)
564 .with_single_cert(cert_chain, key)
565 .map_err(|e| anyhow!(format!("Certificate error: {}", e)))?;
566
567 let acceptor = TlsAcceptor::from(Arc::new(tls_config));
568 self.tls_acceptor = Some(acceptor);
569 self.typ = ListenerType::TLS;
570 Ok(self)
571 }
572
573 pub async fn accept(&self) -> Result<Acceptor<TcpStream>> {
575 let (socket, remote_addr) = self.tcp_listener.accept().await?;
576 if let Err(e) = socket.set_nodelay(self.cfg.nodelay) {
577 return Err(Error::from(e));
578 }
579 Ok(Acceptor {
580 socket,
581 remote_addr,
582 #[cfg(feature = "tls")]
583 acceptor: self.tls_acceptor.clone(),
584 cfg: self.cfg.clone(),
585 typ: self.typ,
586 })
587 }
588
589 pub fn local_addr(&self) -> Result<SocketAddr> {
590 Ok(self.tcp_listener.local_addr()?)
591 }
592}
593
/// An accepted connection whose protocol handshake has not yet been performed.
///
/// Call the method matching `typ` (`tcp`, `tls`, `ws` or `wss`) to run the handshake
/// and obtain a [`Dispatcher`].
pub struct Acceptor<S> {
    /// The raw accepted stream.
    pub(crate) socket: S,
    /// TLS acceptor copied from the originating [`Listener`], if it had one.
    #[cfg(feature = "tls")]
    acceptor: Option<TlsAcceptor>,
    /// Peer address reported by `accept()`.
    pub remote_addr: SocketAddr,
    /// Configuration shared with the originating listener.
    pub cfg: Arc<Builder>,
    /// Protocol of the originating listener.
    pub typ: ListenerType,
}
607
608impl<S> Acceptor<S>
609where
610 S: AsyncRead + AsyncWrite + Unpin,
611{
612 #[inline]
614 pub fn tcp(self) -> Result<Dispatcher<S>> {
615 if matches!(self.typ, ListenerType::TCP) {
616 Ok(Dispatcher::new(self.socket, self.remote_addr, self.cfg))
617 } else {
618 Err(anyhow!("Protocol mismatch: Expected TCP listener"))
619 }
620 }
621
622 #[cfg(feature = "tls")]
623 #[inline]
625 pub async fn tls(self) -> Result<Dispatcher<TlsStream<S>>> {
626 if !matches!(self.typ, ListenerType::TLS) {
627 return Err(anyhow!("Protocol mismatch: Expected TLS listener"));
628 }
629
630 let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
631 let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
632 {
633 Ok(Ok(tls_s)) => tls_s,
634 Ok(Err(e)) => return Err(e.into()),
635 Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
636 };
637 Ok(Dispatcher::new(tls_s, self.remote_addr, self.cfg))
638 }
639
640 #[cfg(feature = "ws")]
641 #[inline]
643 pub async fn ws(self) -> Result<Dispatcher<WsStream<S>>> {
644 if !matches!(self.typ, ListenerType::WS) {
645 return Err(anyhow!("Protocol mismatch: Expected WS listener"));
646 }
647
648 match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(self.socket, on_handshake))
649 .await
650 {
651 Ok(Ok(ws_stream)) => {
652 Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, self.cfg.clone()))
653 }
654 Ok(Err(e)) => Err(e.into()),
655 Err(_) => Err(crate::MqttError::ReadTimeout.into()),
656 }
657 }
658
659 #[cfg(feature = "tls")]
660 #[cfg(feature = "ws")]
661 #[inline]
663 pub async fn wss(self) -> Result<Dispatcher<WsStream<TlsStream<S>>>> {
664 if !matches!(self.typ, ListenerType::WSS) {
665 return Err(anyhow!("Protocol mismatch: Expected WSS listener"));
666 }
667
668 let acceptor = self.acceptor.ok_or_else(|| crate::MqttError::ServiceUnavailable)?;
669 let tls_s = match tokio::time::timeout(self.cfg.handshake_timeout, acceptor.accept(self.socket)).await
670 {
671 Ok(Ok(tls_s)) => tls_s,
672 Ok(Err(e)) => return Err(e.into()),
673 Err(_) => return Err(crate::MqttError::ReadTimeout.into()),
674 };
675 match tokio::time::timeout(self.cfg.handshake_timeout, accept_hdr_async(tls_s, on_handshake)).await {
676 Ok(Ok(ws_stream)) => {
677 Ok(Dispatcher::new(WsStream::new(ws_stream), self.remote_addr, self.cfg.clone()))
678 }
679 Ok(Err(e)) => Err(e.into()),
680 Err(_) => Err(crate::MqttError::ReadTimeout.into()),
681 }
682 }
683}
684
/// WebSocket handshake callback: requires the client to offer the `mqtt` subprotocol and
/// echoes `Sec-WebSocket-Protocol: mqtt` in the response.
///
/// Per RFC 6455 a client may offer several subprotocols, either comma-separated or across
/// repeated header lines. MQTT only requires that `mqtt` be among them, so every offered
/// token is checked rather than only the first header value.
///
/// # Errors
/// Returns an `ErrorResponse` when `mqtt` is not offered.
#[allow(clippy::result_large_err)]
#[cfg(feature = "ws")]
fn on_handshake(req: &Request, mut response: Response) -> std::result::Result<Response, ErrorResponse> {
    const PROTOCOL_ERROR: &str = "Missing required 'Sec-WebSocket-Protocol: mqtt' header";
    let offers_mqtt = req
        .headers()
        .get_all("Sec-WebSocket-Protocol")
        .iter()
        // Non-ASCII header values cannot contain the token; skip them.
        .filter_map(|value| value.to_str().ok())
        .flat_map(|value| value.split(','))
        .any(|protocol| protocol.trim() == "mqtt");
    if !offers_mqtt {
        return Err(ErrorResponse::new(Some(PROTOCOL_ERROR.into())));
    }
    response.headers_mut().append(
        "Sec-WebSocket-Protocol",
        "mqtt".parse().map_err(|_| ErrorResponse::new(Some("InvalidHeaderValue".into())))?,
    );
    Ok(response)
}