1use std::io;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
17use tokio::net::{TcpListener, TcpStream};
18use tokio::sync::mpsc;
19use tokio_rustls::server::TlsStream;
20use tracing::{debug, error, info, warn};
21
22use crate::broker::MqttConfig;
23use crate::metrics::MqttMetrics;
24use crate::protocol::{
25 ConnackCode, Packet, PacketDecoder, PacketEncoder, ProtocolError, PublishPacket, QoS,
26};
27use crate::session::{
28 build_connack, build_puback, build_pubcomp, build_pubrec, build_pubrel, build_suback,
29 build_unsuback, SessionManager,
30};
31use crate::tls::{create_tls_acceptor_with_client_auth, TlsError};
32
/// Size in bytes (64 KiB) of the per-connection read buffer; the connection
/// handlers cap it at the configured `max_packet_size`.
const READ_BUFFER_SIZE: usize = 1 << 16;

/// Capacity of the bounded mpsc channel created for each client; its sender
/// half is registered with the session manager on CONNECT.
const CLIENT_CHANNEL_CAPACITY: usize = 256;

/// Period, in seconds, of the background task that purges expired sessions.
const CLEANUP_INTERVAL_SECS: u64 = 30;
39
40pub enum MqttStream {
42 Plain(TcpStream),
44 Tls(TlsStream<TcpStream>),
46}
47
48impl AsyncRead for MqttStream {
49 fn poll_read(
50 self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 buf: &mut ReadBuf<'_>,
53 ) -> Poll<io::Result<()>> {
54 match self.get_mut() {
55 MqttStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
56 MqttStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
57 }
58 }
59}
60
61impl AsyncWrite for MqttStream {
62 fn poll_write(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &[u8],
66 ) -> Poll<io::Result<usize>> {
67 match self.get_mut() {
68 MqttStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
69 MqttStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
70 }
71 }
72
73 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
74 match self.get_mut() {
75 MqttStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
76 MqttStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
77 }
78 }
79
80 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81 match self.get_mut() {
82 MqttStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
83 MqttStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
84 }
85 }
86}
87
88pub struct MqttServer {
90 session_manager: Arc<SessionManager>,
91 metrics: Arc<MqttMetrics>,
92}
93
94impl MqttServer {
95 pub fn new(config: &MqttConfig, metrics: Arc<MqttMetrics>) -> Self {
97 Self {
98 session_manager: Arc::new(SessionManager::new(
99 config.max_connections,
100 Some(metrics.clone()),
101 )),
102 metrics,
103 }
104 }
105
106 pub fn session_manager(&self) -> Arc<SessionManager> {
108 self.session_manager.clone()
109 }
110
111 pub fn metrics(&self) -> Arc<MqttMetrics> {
113 self.metrics.clone()
114 }
115}
116
117pub async fn start_mqtt_server(
123 config: MqttConfig,
124) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
125 let metrics = Arc::new(MqttMetrics::new());
126 start_mqtt_server_with_metrics(config, metrics).await
127}
128
129pub async fn start_mqtt_server_with_metrics(
131 config: MqttConfig,
132 metrics: Arc<MqttMetrics>,
133) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
134 let addr = format!("{}:{}", config.host, config.port);
135
136 info!(
137 "Starting MQTT broker on {}:{} (MQTT {:?})",
138 config.host, config.port, config.version
139 );
140
141 let listener = TcpListener::bind(&addr).await?;
142 let session_manager =
143 Arc::new(SessionManager::new(config.max_connections, Some(metrics.clone())));
144
145 info!(
146 "MQTT broker listening on {}:{} (MQTT {:?})",
147 config.host, config.port, config.version
148 );
149
150 let cleanup_manager = session_manager.clone();
152 tokio::spawn(async move {
153 let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_INTERVAL_SECS));
154 loop {
155 interval.tick().await;
156 let expired = cleanup_manager.cleanup_expired_sessions().await;
157 if !expired.is_empty() {
158 debug!("Cleaned up {} expired sessions", expired.len());
159 }
160 }
161 });
162
163 loop {
165 match listener.accept().await {
166 Ok((socket, addr)) => {
167 info!("New MQTT connection from {}", addr);
168
169 let session_manager = session_manager.clone();
170 let metrics = metrics.clone();
171 let max_packet_size = config.max_packet_size;
172
173 tokio::spawn(async move {
174 if let Err(e) =
175 handle_connection(socket, addr, session_manager, metrics, max_packet_size)
176 .await
177 {
178 warn!("Connection error from {}: {}", addr, e);
179 }
180 });
181 }
182 Err(e) => {
183 error!("Error accepting MQTT connection: {}", e);
184 }
185 }
186 }
187}
188
189pub async fn start_mqtt_tls_server(config: MqttConfig) -> Result<(), TlsError> {
194 let metrics = Arc::new(MqttMetrics::new());
195 start_mqtt_tls_server_with_metrics(config, metrics).await
196}
197
198pub async fn start_mqtt_tls_server_with_metrics(
200 config: MqttConfig,
201 metrics: Arc<MqttMetrics>,
202) -> Result<(), TlsError> {
203 if !config.tls_enabled {
204 return Err(TlsError::ConfigError("TLS is not enabled in configuration".to_string()));
205 }
206
207 let tls_acceptor = create_tls_acceptor_with_client_auth(&config)?;
208 let addr = format!("{}:{}", config.host, config.tls_port);
209
210 let listener = TcpListener::bind(&addr)
211 .await
212 .map_err(|e| TlsError::ConfigError(format!("Failed to bind to {}: {}", addr, e)))?;
213
214 info!(
215 "Starting MQTTS broker with TLS on {}:{} (MQTT {:?})",
216 config.host, config.tls_port, config.version
217 );
218
219 let session_manager =
220 Arc::new(SessionManager::new(config.max_connections, Some(metrics.clone())));
221
222 let cleanup_manager = session_manager.clone();
224 tokio::spawn(async move {
225 let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_INTERVAL_SECS));
226 loop {
227 interval.tick().await;
228 let expired = cleanup_manager.cleanup_expired_sessions().await;
229 if !expired.is_empty() {
230 debug!("Cleaned up {} expired sessions", expired.len());
231 }
232 }
233 });
234
235 loop {
237 match listener.accept().await {
238 Ok((socket, addr)) => {
239 info!("New MQTTS connection from {}", addr);
240
241 let tls_acceptor = tls_acceptor.clone();
242 let session_manager = session_manager.clone();
243 let metrics = metrics.clone();
244 let max_packet_size = config.max_packet_size;
245
246 tokio::spawn(async move {
247 match tls_acceptor.accept(socket).await {
248 Ok(tls_stream) => {
249 if let Err(e) = handle_tls_connection(
250 tls_stream,
251 addr,
252 session_manager,
253 metrics,
254 max_packet_size,
255 )
256 .await
257 {
258 warn!("TLS connection error from {}: {}", addr, e);
259 }
260 }
261 Err(e) => {
262 warn!("TLS handshake failed from {}: {}", addr, e);
263 }
264 }
265 });
266 }
267 Err(e) => {
268 error!("Error accepting MQTTS connection: {}", e);
269 }
270 }
271 }
272}
273
274pub async fn start_mqtt_dual_server(
276 config: MqttConfig,
277) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
278 let metrics = Arc::new(MqttMetrics::new());
279 start_mqtt_dual_server_with_metrics(config, metrics).await
280}
281
282pub async fn start_mqtt_dual_server_with_metrics(
284 config: MqttConfig,
285 metrics: Arc<MqttMetrics>,
286) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
287 let session_manager =
288 Arc::new(SessionManager::new(config.max_connections, Some(metrics.clone())));
289
290 let cleanup_manager = session_manager.clone();
292 tokio::spawn(async move {
293 let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_INTERVAL_SECS));
294 loop {
295 interval.tick().await;
296 let expired = cleanup_manager.cleanup_expired_sessions().await;
297 if !expired.is_empty() {
298 debug!("Cleaned up {} expired sessions", expired.len());
299 }
300 }
301 });
302
303 let plain_addr = format!("{}:{}", config.host, config.port);
305 let plain_listener = TcpListener::bind(&plain_addr).await?;
306 info!("Starting MQTT broker on {} (MQTT {:?})", plain_addr, config.version);
307
308 let plain_session_manager = session_manager.clone();
309 let plain_metrics = metrics.clone();
310 let plain_max_packet_size = config.max_packet_size;
311
312 tokio::spawn(async move {
313 loop {
314 match plain_listener.accept().await {
315 Ok((socket, addr)) => {
316 info!("New MQTT connection from {}", addr);
317
318 let session_manager = plain_session_manager.clone();
319 let metrics = plain_metrics.clone();
320
321 tokio::spawn(async move {
322 if let Err(e) = handle_connection(
323 socket,
324 addr,
325 session_manager,
326 metrics,
327 plain_max_packet_size,
328 )
329 .await
330 {
331 warn!("Connection error from {}: {}", addr, e);
332 }
333 });
334 }
335 Err(e) => {
336 error!("Error accepting MQTT connection: {}", e);
337 }
338 }
339 }
340 });
341
342 if config.tls_enabled {
344 let tls_acceptor = create_tls_acceptor_with_client_auth(&config)?;
345 let tls_addr = format!("{}:{}", config.host, config.tls_port);
346 let tls_listener = TcpListener::bind(&tls_addr).await?;
347 info!("Starting MQTTS broker with TLS on {}", tls_addr);
348
349 let tls_session_manager = session_manager.clone();
350 let tls_metrics = metrics.clone();
351 let tls_max_packet_size = config.max_packet_size;
352
353 tokio::spawn(async move {
354 loop {
355 match tls_listener.accept().await {
356 Ok((socket, addr)) => {
357 info!("New MQTTS connection from {}", addr);
358
359 let tls_acceptor = tls_acceptor.clone();
360 let session_manager = tls_session_manager.clone();
361 let metrics = tls_metrics.clone();
362
363 tokio::spawn(async move {
364 match tls_acceptor.accept(socket).await {
365 Ok(tls_stream) => {
366 if let Err(e) = handle_tls_connection(
367 tls_stream,
368 addr,
369 session_manager,
370 metrics,
371 tls_max_packet_size,
372 )
373 .await
374 {
375 warn!("TLS connection error from {}: {}", addr, e);
376 }
377 }
378 Err(e) => {
379 warn!("TLS handshake failed from {}: {}", addr, e);
380 }
381 }
382 });
383 }
384 Err(e) => {
385 error!("Error accepting MQTTS connection: {}", e);
386 }
387 }
388 }
389 });
390 }
391
392 loop {
394 tokio::time::sleep(Duration::from_secs(3600)).await;
395 }
396}
397
398async fn handle_tls_connection(
400 stream: TlsStream<TcpStream>,
401 addr: std::net::SocketAddr,
402 session_manager: Arc<SessionManager>,
403 metrics: Arc<MqttMetrics>,
404 max_packet_size: usize,
405) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
406 let (read_half, write_half) = tokio::io::split(stream);
407 let mut reader = tokio::io::BufReader::new(read_half);
408 let mut writer = write_half;
409
410 let mut buffer = vec![0u8; READ_BUFFER_SIZE.min(max_packet_size)];
412 let mut buf_len = 0usize;
413
414 let mut client_id: Option<String> = None;
416 let mut packet_rx: Option<mpsc::Receiver<Packet>> = None;
417
418 let connect_timeout = Duration::from_secs(10);
420 let first_read = tokio::time::timeout(connect_timeout, reader.read(&mut buffer[buf_len..]))
421 .await
422 .map_err(|_| "Connection timeout waiting for CONNECT")?;
423
424 match first_read {
425 Ok(0) => {
426 debug!("TLS client {} closed connection before CONNECT", addr);
427 return Ok(());
428 }
429 Ok(n) => buf_len += n,
430 Err(e) => return Err(e.into()),
431 }
432
433 let (connect_packet, consumed) = match PacketDecoder::decode(&buffer[..buf_len])? {
435 Some((Packet::Connect(connect), consumed)) => (connect, consumed),
436 Some((_, _)) => {
437 warn!("First packet from TLS client {} was not CONNECT", addr);
438 let connack = build_connack(false, ConnackCode::NotAuthorized);
439 let bytes = PacketEncoder::encode(&connack)?;
440 writer.write_all(&bytes).await?;
441 return Err("Expected CONNECT packet".into());
442 }
443 None => {
444 return Err("Incomplete CONNECT packet".into());
445 }
446 };
447
448 buffer.copy_within(consumed..buf_len, 0);
450 buf_len -= consumed;
451
452 let cid = if connect_packet.client_id.is_empty() {
454 if connect_packet.clean_session {
455 format!("auto-tls-{}", uuid::Uuid::new_v4())
456 } else {
457 let connack = build_connack(false, ConnackCode::IdentifierRejected);
458 let bytes = PacketEncoder::encode(&connack)?;
459 writer.write_all(&bytes).await?;
460 return Err("Empty client ID with clean_session=false".into());
461 }
462 } else {
463 connect_packet.client_id.clone()
464 };
465
466 info!(
467 "TLS CONNECT from {} (client_id={}, clean_session={})",
468 addr, cid, connect_packet.clean_session
469 );
470
471 let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY);
473 packet_rx = Some(rx);
474
475 let connect_result = session_manager
477 .connect(cid.clone(), connect_packet.clean_session, connect_packet.keep_alive, tx)
478 .await;
479
480 let session_present = match connect_result {
481 Ok((session_present, code)) => {
482 let connack = build_connack(session_present, code);
483 let bytes = PacketEncoder::encode(&connack)?;
484 writer.write_all(&bytes).await?;
485 session_present
486 }
487 Err(code) => {
488 let connack = build_connack(false, code);
489 let bytes = PacketEncoder::encode(&connack)?;
490 writer.write_all(&bytes).await?;
491 return Err(format!("Connection rejected: {:?}", code).into());
492 }
493 };
494
495 client_id = Some(cid.clone());
496
497 info!("TLS client {} connected (session_present={})", cid, session_present);
498
499 if session_present {
501 let subscriptions = session_manager.get_client_subscriptions(&cid).await;
502
503 for (filter, _sub_qos) in subscriptions {
504 let retained = session_manager.get_retained_messages(&filter).await;
505 for (topic, mut publish) in retained {
506 if publish.qos != QoS::AtMostOnce {
507 if let Some(id) = session_manager.assign_packet_id(&cid).await {
508 publish.packet_id = Some(id);
509 }
510 }
511
512 let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
513 writer.write_all(&bytes).await?;
514
515 debug!(
516 "Delivered retained message for topic {} to restored TLS session {}",
517 topic, cid
518 );
519 }
520 }
521 }
522
523 let mut rx = packet_rx.take().unwrap();
525 let result = handle_tls_client_loop(
526 &mut reader,
527 &mut writer,
528 &mut rx,
529 &cid,
530 &session_manager,
531 &metrics,
532 &mut buffer,
533 &mut buf_len,
534 max_packet_size,
535 )
536 .await;
537
538 session_manager.disconnect(&cid).await;
540 info!("TLS client {} disconnected", cid);
541
542 result
543}
544
545async fn handle_tls_client_loop<R, W>(
547 reader: &mut tokio::io::BufReader<R>,
548 writer: &mut W,
549 packet_rx: &mut mpsc::Receiver<Packet>,
550 client_id: &str,
551 session_manager: &Arc<SessionManager>,
552 metrics: &Arc<MqttMetrics>,
553 buffer: &mut Vec<u8>,
554 buf_len: &mut usize,
555 max_packet_size: usize,
556) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
557where
558 R: AsyncRead + Unpin,
559 W: AsyncWrite + Unpin,
560{
561 loop {
562 tokio::select! {
563 read_result = reader.read(&mut buffer[*buf_len..]) => {
565 match read_result {
566 Ok(0) => {
567 debug!("TLS client {} closed connection", client_id);
568 return Ok(());
569 }
570 Ok(n) => {
571 *buf_len += n;
572
573 if *buf_len > max_packet_size {
575 warn!("TLS client {} sent oversized packet", client_id);
576 metrics.record_error("oversized_packet");
577 return Err("Packet too large".into());
578 }
579
580 while let Some((packet, consumed)) = PacketDecoder::decode(&buffer[..*buf_len])? {
582 buffer.copy_within(consumed..*buf_len, 0);
584 *buf_len -= consumed;
585
586 match handle_tls_packet(
588 client_id,
589 packet,
590 writer,
591 session_manager,
592 metrics,
593 ).await {
594 Ok(true) => continue,
595 Ok(false) => return Ok(()), Err(e) => {
597 warn!("Error handling packet from TLS client {}: {}", client_id, e);
598 metrics.record_error(&e.to_string());
599 }
600 }
601 }
602 }
603 Err(e) => {
604 return Err(e.into());
605 }
606 }
607 }
608
609 packet = packet_rx.recv() => {
611 match packet {
612 Some(Packet::Disconnect) => {
613 debug!("Sending disconnect to TLS client {}", client_id);
614 return Ok(());
615 }
616 Some(mut packet) => {
617 if let Packet::Publish(ref mut publish) = packet {
619 if publish.qos != QoS::AtMostOnce && publish.packet_id.is_none() {
620 if let Some(id) = session_manager.assign_packet_id(client_id).await {
621 publish.packet_id = Some(id);
622 }
623 }
624 }
625
626 let bytes = PacketEncoder::encode(&packet)?;
627 writer.write_all(&bytes).await?;
628 }
629 None => {
630 debug!("Channel closed for TLS client {}", client_id);
631 return Ok(());
632 }
633 }
634 }
635 }
636 }
637}
638
639async fn handle_tls_packet<W>(
641 client_id: &str,
642 packet: Packet,
643 writer: &mut W,
644 session_manager: &Arc<SessionManager>,
645 metrics: &Arc<MqttMetrics>,
646) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>
647where
648 W: AsyncWrite + Unpin,
649{
650 match packet {
651 Packet::Connect(_) => {
652 warn!("TLS client {} sent second CONNECT packet", client_id);
653 return Ok(false);
654 }
655
656 Packet::Publish(publish) => {
657 debug!(
658 "PUBLISH from TLS client {} to topic {} (QoS {:?})",
659 client_id, publish.topic, publish.qos
660 );
661
662 match publish.qos {
663 QoS::AtMostOnce => {}
664 QoS::AtLeastOnce => {
665 if let Some(packet_id) = publish.packet_id {
666 let puback = build_puback(packet_id);
667 let bytes = PacketEncoder::encode(&puback)?;
668 writer.write_all(&bytes).await?;
669 }
670 }
671 QoS::ExactlyOnce => {
672 if let Some(packet_id) = publish.packet_id {
673 session_manager.start_qos2_inbound(client_id, packet_id).await;
674 let pubrec = build_pubrec(packet_id);
675 let bytes = PacketEncoder::encode(&pubrec)?;
676 writer.write_all(&bytes).await?;
677 session_manager.mark_pubrec_sent(client_id, packet_id).await;
678 }
679 }
680 }
681
682 session_manager.publish(client_id, &publish).await;
683 }
684
685 Packet::Puback(puback) => {
686 debug!("PUBACK from TLS client {} for packet {}", client_id, puback.packet_id);
687 session_manager.handle_puback(client_id, puback.packet_id).await;
688 }
689
690 Packet::Pubrec(pubrec) => {
691 debug!("PUBREC from TLS client {} for packet {}", client_id, pubrec.packet_id);
692 if session_manager.handle_pubrec(client_id, pubrec.packet_id).await {
693 let pubrel = build_pubrel(pubrec.packet_id);
694 let bytes = PacketEncoder::encode(&pubrel)?;
695 writer.write_all(&bytes).await?;
696 }
697 }
698
699 Packet::Pubrel(pubrel) => {
700 debug!("PUBREL from TLS client {} for packet {}", client_id, pubrel.packet_id);
701 if session_manager.handle_pubrel(client_id, pubrel.packet_id).await {
702 let pubcomp = build_pubcomp(pubrel.packet_id);
703 let bytes = PacketEncoder::encode(&pubcomp)?;
704 writer.write_all(&bytes).await?;
705 session_manager.complete_qos2_inbound(client_id, pubrel.packet_id).await;
706 }
707 }
708
709 Packet::Pubcomp(pubcomp) => {
710 debug!("PUBCOMP from TLS client {} for packet {}", client_id, pubcomp.packet_id);
711 session_manager.handle_pubcomp(client_id, pubcomp.packet_id).await;
712 }
713
714 Packet::Subscribe(subscribe) => {
715 debug!(
716 "SUBSCRIBE from TLS client {} for {} topics",
717 client_id,
718 subscribe.subscriptions.len()
719 );
720
721 if let Some(return_codes) =
722 session_manager.subscribe(client_id, subscribe.subscriptions.clone()).await
723 {
724 let suback = build_suback(subscribe.packet_id, return_codes);
725 let bytes = PacketEncoder::encode(&suback)?;
726 writer.write_all(&bytes).await?;
727
728 for (filter, _) in &subscribe.subscriptions {
729 let retained = session_manager.get_retained_messages(filter).await;
730 for (topic, mut publish) in retained {
731 if publish.qos != QoS::AtMostOnce {
732 if let Some(id) = session_manager.assign_packet_id(client_id).await {
733 publish.packet_id = Some(id);
734 }
735 }
736 let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
737 writer.write_all(&bytes).await?;
738 debug!(
739 "Sent retained message for topic {} to TLS client {}",
740 topic, client_id
741 );
742 }
743 }
744 }
745 }
746
747 Packet::Unsubscribe(unsubscribe) => {
748 debug!(
749 "UNSUBSCRIBE from TLS client {} for {} topics",
750 client_id,
751 unsubscribe.topics.len()
752 );
753
754 session_manager.unsubscribe(client_id, unsubscribe.topics).await;
755
756 let unsuback = build_unsuback(unsubscribe.packet_id);
757 let bytes = PacketEncoder::encode(&unsuback)?;
758 writer.write_all(&bytes).await?;
759 }
760
761 Packet::Pingreq => {
762 debug!("PINGREQ from TLS client {}", client_id);
763 session_manager.touch(client_id).await;
764
765 let pingresp = Packet::Pingresp;
766 let bytes = PacketEncoder::encode(&pingresp)?;
767 writer.write_all(&bytes).await?;
768 }
769
770 Packet::Disconnect => {
771 info!("DISCONNECT from TLS client {}", client_id);
772 return Ok(false);
773 }
774
775 Packet::Connack(_) | Packet::Suback(_) | Packet::Unsuback(_) | Packet::Pingresp => {
776 warn!("TLS client {} sent unexpected packet type: {:?}", client_id, packet);
777 metrics.record_error("unexpected_packet_type");
778 }
779 }
780
781 Ok(true)
782}
783
784async fn handle_connection(
786 socket: tokio::net::TcpStream,
787 addr: std::net::SocketAddr,
788 session_manager: Arc<SessionManager>,
789 metrics: Arc<MqttMetrics>,
790 max_packet_size: usize,
791) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
792 let (read_half, mut write_half) = socket.into_split();
793 let mut reader = tokio::io::BufReader::new(read_half);
794
795 let mut buffer = vec![0u8; READ_BUFFER_SIZE.min(max_packet_size)];
797 let mut buf_len = 0usize;
798
799 let mut client_id: Option<String> = None;
801 let mut packet_rx: Option<mpsc::Receiver<Packet>> = None;
802
803 let connect_timeout = Duration::from_secs(10);
805 let first_read = tokio::time::timeout(connect_timeout, reader.read(&mut buffer[buf_len..]))
806 .await
807 .map_err(|_| "Connection timeout waiting for CONNECT")?;
808
809 match first_read {
810 Ok(0) => {
811 debug!("Client {} closed connection before CONNECT", addr);
812 return Ok(());
813 }
814 Ok(n) => buf_len += n,
815 Err(e) => return Err(e.into()),
816 }
817
818 let (connect_packet, consumed) = match PacketDecoder::decode(&buffer[..buf_len])? {
820 Some((Packet::Connect(connect), consumed)) => (connect, consumed),
821 Some((_, _)) => {
822 warn!("First packet from {} was not CONNECT", addr);
823 let connack = build_connack(false, ConnackCode::NotAuthorized);
824 let bytes = PacketEncoder::encode(&connack)?;
825 write_half.write_all(&bytes).await?;
826 return Err("Expected CONNECT packet".into());
827 }
828 None => {
829 return Err("Incomplete CONNECT packet".into());
830 }
831 };
832
833 buffer.copy_within(consumed..buf_len, 0);
835 buf_len -= consumed;
836
837 let cid = if connect_packet.client_id.is_empty() {
839 if connect_packet.clean_session {
841 format!("auto-{}", uuid::Uuid::new_v4())
842 } else {
843 let connack = build_connack(false, ConnackCode::IdentifierRejected);
844 let bytes = PacketEncoder::encode(&connack)?;
845 write_half.write_all(&bytes).await?;
846 return Err("Empty client ID with clean_session=false".into());
847 }
848 } else {
849 connect_packet.client_id.clone()
850 };
851
852 info!(
853 "CONNECT from {} (client_id={}, clean_session={})",
854 addr, cid, connect_packet.clean_session
855 );
856
857 let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY);
859 packet_rx = Some(rx);
860
861 let connect_result = session_manager
863 .connect(cid.clone(), connect_packet.clean_session, connect_packet.keep_alive, tx)
864 .await;
865
866 let session_present = match connect_result {
867 Ok((session_present, code)) => {
868 let connack = build_connack(session_present, code);
869 let bytes = PacketEncoder::encode(&connack)?;
870 write_half.write_all(&bytes).await?;
871 session_present
872 }
873 Err(code) => {
874 let connack = build_connack(false, code);
875 let bytes = PacketEncoder::encode(&connack)?;
876 write_half.write_all(&bytes).await?;
877 return Err(format!("Connection rejected: {:?}", code).into());
878 }
879 };
880
881 client_id = Some(cid.clone());
882
883 info!("Client {} connected (session_present={})", cid, session_present);
884
885 if session_present {
887 let subscriptions = session_manager.get_client_subscriptions(&cid).await;
889
890 for (filter, _sub_qos) in subscriptions {
891 let retained = session_manager.get_retained_messages(&filter).await;
892 for (topic, mut publish) in retained {
893 if publish.qos != QoS::AtMostOnce {
895 if let Some(id) = session_manager.assign_packet_id(&cid).await {
896 publish.packet_id = Some(id);
897 }
898 }
899
900 let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
901 write_half.write_all(&bytes).await?;
902
903 debug!(
904 "Delivered retained message for topic {} to restored session {}",
905 topic, cid
906 );
907 }
908 }
909 }
910
911 let mut rx = packet_rx.take().unwrap();
913 let result = handle_client_loop(
914 &mut reader,
915 &mut write_half,
916 &mut rx,
917 &cid,
918 &session_manager,
919 &metrics,
920 &mut buffer,
921 &mut buf_len,
922 max_packet_size,
923 )
924 .await;
925
926 session_manager.disconnect(&cid).await;
928 info!("Client {} disconnected", cid);
929
930 result
931}
932
933async fn handle_client_loop(
935 reader: &mut tokio::io::BufReader<tokio::net::tcp::OwnedReadHalf>,
936 writer: &mut tokio::net::tcp::OwnedWriteHalf,
937 packet_rx: &mut mpsc::Receiver<Packet>,
938 client_id: &str,
939 session_manager: &Arc<SessionManager>,
940 metrics: &Arc<MqttMetrics>,
941 buffer: &mut Vec<u8>,
942 buf_len: &mut usize,
943 max_packet_size: usize,
944) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
945 loop {
946 tokio::select! {
947 read_result = reader.read(&mut buffer[*buf_len..]) => {
949 match read_result {
950 Ok(0) => {
951 debug!("Client {} closed connection", client_id);
952 return Ok(());
953 }
954 Ok(n) => {
955 *buf_len += n;
956
957 if *buf_len > max_packet_size {
959 warn!("Client {} sent oversized packet", client_id);
960 metrics.record_error("oversized_packet");
961 return Err("Packet too large".into());
962 }
963
964 while let Some((packet, consumed)) = PacketDecoder::decode(&buffer[..*buf_len])? {
966 buffer.copy_within(consumed..*buf_len, 0);
968 *buf_len -= consumed;
969
970 match handle_packet(
972 client_id,
973 packet,
974 writer,
975 session_manager,
976 metrics,
977 ).await {
978 Ok(true) => continue,
979 Ok(false) => return Ok(()), Err(e) => {
981 warn!("Error handling packet from {}: {}", client_id, e);
982 metrics.record_error(&e.to_string());
983 }
984 }
985 }
986 }
987 Err(e) => {
988 return Err(e.into());
989 }
990 }
991 }
992
993 packet = packet_rx.recv() => {
995 match packet {
996 Some(Packet::Disconnect) => {
997 debug!("Sending disconnect to {}", client_id);
998 return Ok(());
999 }
1000 Some(mut packet) => {
1001 if let Packet::Publish(ref mut publish) = packet {
1003 if publish.qos != QoS::AtMostOnce && publish.packet_id.is_none() {
1004 if let Some(id) = session_manager.assign_packet_id(client_id).await {
1005 publish.packet_id = Some(id);
1006 }
1007 }
1008 }
1009
1010 let bytes = PacketEncoder::encode(&packet)?;
1011 writer.write_all(&bytes).await?;
1012 }
1013 None => {
1014 debug!("Channel closed for {}", client_id);
1015 return Ok(());
1016 }
1017 }
1018 }
1019 }
1020 }
1021}
1022
1023async fn handle_packet(
1025 client_id: &str,
1026 packet: Packet,
1027 writer: &mut tokio::net::tcp::OwnedWriteHalf,
1028 session_manager: &Arc<SessionManager>,
1029 metrics: &Arc<MqttMetrics>,
1030) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
1031 match packet {
1032 Packet::Connect(_) => {
1033 warn!("Client {} sent second CONNECT packet", client_id);
1035 return Ok(false);
1036 }
1037
1038 Packet::Publish(publish) => {
1039 debug!("PUBLISH from {} to topic {} (QoS {:?})", client_id, publish.topic, publish.qos);
1040
1041 match publish.qos {
1043 QoS::AtMostOnce => {
1044 }
1046 QoS::AtLeastOnce => {
1047 if let Some(packet_id) = publish.packet_id {
1049 let puback = build_puback(packet_id);
1050 let bytes = PacketEncoder::encode(&puback)?;
1051 writer.write_all(&bytes).await?;
1052 }
1053 }
1054 QoS::ExactlyOnce => {
1055 if let Some(packet_id) = publish.packet_id {
1057 session_manager.start_qos2_inbound(client_id, packet_id).await;
1058 let pubrec = build_pubrec(packet_id);
1059 let bytes = PacketEncoder::encode(&pubrec)?;
1060 writer.write_all(&bytes).await?;
1061 session_manager.mark_pubrec_sent(client_id, packet_id).await;
1062 }
1063 }
1064 }
1065
1066 session_manager.publish(client_id, &publish).await;
1068 }
1069
1070 Packet::Puback(puback) => {
1071 debug!("PUBACK from {} for packet {}", client_id, puback.packet_id);
1072 session_manager.handle_puback(client_id, puback.packet_id).await;
1073 }
1074
1075 Packet::Pubrec(pubrec) => {
1076 debug!("PUBREC from {} for packet {}", client_id, pubrec.packet_id);
1077 if session_manager.handle_pubrec(client_id, pubrec.packet_id).await {
1078 let pubrel = build_pubrel(pubrec.packet_id);
1079 let bytes = PacketEncoder::encode(&pubrel)?;
1080 writer.write_all(&bytes).await?;
1081 }
1082 }
1083
1084 Packet::Pubrel(pubrel) => {
1085 debug!("PUBREL from {} for packet {}", client_id, pubrel.packet_id);
1086 if session_manager.handle_pubrel(client_id, pubrel.packet_id).await {
1087 let pubcomp = build_pubcomp(pubrel.packet_id);
1088 let bytes = PacketEncoder::encode(&pubcomp)?;
1089 writer.write_all(&bytes).await?;
1090 session_manager.complete_qos2_inbound(client_id, pubrel.packet_id).await;
1091 }
1092 }
1093
1094 Packet::Pubcomp(pubcomp) => {
1095 debug!("PUBCOMP from {} for packet {}", client_id, pubcomp.packet_id);
1096 session_manager.handle_pubcomp(client_id, pubcomp.packet_id).await;
1097 }
1098
1099 Packet::Subscribe(subscribe) => {
1100 debug!("SUBSCRIBE from {} for {} topics", client_id, subscribe.subscriptions.len());
1101
1102 if let Some(return_codes) =
1103 session_manager.subscribe(client_id, subscribe.subscriptions.clone()).await
1104 {
1105 let suback = build_suback(subscribe.packet_id, return_codes);
1107 let bytes = PacketEncoder::encode(&suback)?;
1108 writer.write_all(&bytes).await?;
1109
1110 for (filter, _) in &subscribe.subscriptions {
1112 let retained = session_manager.get_retained_messages(filter).await;
1113 for (topic, mut publish) in retained {
1114 if publish.qos != QoS::AtMostOnce {
1116 if let Some(id) = session_manager.assign_packet_id(client_id).await {
1117 publish.packet_id = Some(id);
1118 }
1119 }
1120 let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
1121 writer.write_all(&bytes).await?;
1122 debug!("Sent retained message for topic {} to {}", topic, client_id);
1123 }
1124 }
1125 }
1126 }
1127
1128 Packet::Unsubscribe(unsubscribe) => {
1129 debug!("UNSUBSCRIBE from {} for {} topics", client_id, unsubscribe.topics.len());
1130
1131 session_manager.unsubscribe(client_id, unsubscribe.topics).await;
1132
1133 let unsuback = build_unsuback(unsubscribe.packet_id);
1135 let bytes = PacketEncoder::encode(&unsuback)?;
1136 writer.write_all(&bytes).await?;
1137 }
1138
1139 Packet::Pingreq => {
1140 debug!("PINGREQ from {}", client_id);
1141 session_manager.touch(client_id).await;
1142
1143 let pingresp = Packet::Pingresp;
1145 let bytes = PacketEncoder::encode(&pingresp)?;
1146 writer.write_all(&bytes).await?;
1147 }
1148
1149 Packet::Disconnect => {
1150 info!("DISCONNECT from {}", client_id);
1151 return Ok(false);
1152 }
1153
1154 Packet::Connack(_) | Packet::Suback(_) | Packet::Unsuback(_) | Packet::Pingresp => {
1156 warn!("Client {} sent unexpected packet type: {:?}", client_id, packet);
1157 metrics.record_error("unexpected_packet_type");
1158 }
1159 }
1160
1161 Ok(true)
1162}
1163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::broker::MqttVersion;

    /// Renders the `host:port` listen address the same way the server entry points do.
    fn listen_addr(config: &MqttConfig) -> String {
        format!("{}:{}", config.host, config.port)
    }

    #[test]
    fn test_mqtt_config_address_formatting() {
        let config = MqttConfig {
            host: "127.0.0.1".to_string(),
            port: 1883,
            ..Default::default()
        };
        assert_eq!(listen_addr(&config), "127.0.0.1:1883");
    }

    #[test]
    fn test_mqtt_config_default_host_port() {
        assert_eq!(listen_addr(&MqttConfig::default()), "0.0.0.0:1883");
    }

    #[test]
    fn test_mqtt_config_custom_port() {
        let config = MqttConfig { port: 8883, ..Default::default() };
        assert_eq!(config.port, 8883);
    }

    #[test]
    fn test_mqtt_config_version_v3() {
        let config = MqttConfig { version: MqttVersion::V3_1_1, ..Default::default() };
        assert!(matches!(config.version, MqttVersion::V3_1_1));
    }

    #[test]
    fn test_mqtt_config_version_v5() {
        let config = MqttConfig { version: MqttVersion::V5_0, ..Default::default() };
        assert!(matches!(config.version, MqttVersion::V5_0));
    }

    #[tokio::test]
    async fn test_tcp_listener_bind_localhost() {
        let config = MqttConfig {
            host: "127.0.0.1".to_string(),
            // Port 0 lets the OS pick any free port.
            port: 0,
            ..Default::default()
        };
        let addr = listen_addr(&config);
        let bound = TcpListener::bind(&addr).await;
        assert!(bound.is_ok());
    }

    #[tokio::test]
    async fn test_tcp_listener_local_addr() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = listener.local_addr().unwrap();
        assert_eq!(local.ip().to_string(), "127.0.0.1");
        assert_ne!(local.port(), 0);
    }

    #[test]
    fn test_mqtt_version_debug_format() {
        for (version, tag) in [(MqttVersion::V3_1_1, "V3_1_1"), (MqttVersion::V5_0, "V5_0")] {
            assert!(format!("{:?}", version).contains(tag));
        }
    }

    #[test]
    fn test_config_max_connections() {
        let config = MqttConfig { max_connections: 500, ..Default::default() };
        assert_eq!(config.max_connections, 500);
    }

    #[test]
    fn test_config_max_packet_size() {
        let config = MqttConfig { max_packet_size: 2048, ..Default::default() };
        assert_eq!(config.max_packet_size, 2048);
    }

    #[test]
    fn test_config_keep_alive_secs() {
        let config = MqttConfig { keep_alive_secs: 120, ..Default::default() };
        assert_eq!(config.keep_alive_secs, 120);
    }

    #[test]
    fn test_config_clone() {
        let original = MqttConfig {
            port: 9999,
            host: "localhost".to_string(),
            max_connections: 200,
            max_packet_size: 4096,
            keep_alive_secs: 90,
            version: MqttVersion::V3_1_1,
            ..Default::default()
        };
        let copy = original.clone();
        assert_eq!(copy.port, original.port);
        assert_eq!(copy.host, original.host);
        assert_eq!(copy.max_connections, original.max_connections);
    }

    #[test]
    fn test_config_debug_format() {
        let rendered = format!("{:?}", MqttConfig::default());
        for needle in ["MqttConfig", "1883"] {
            assert!(rendered.contains(needle));
        }
    }

    #[tokio::test]
    async fn test_mqtt_server_creation() {
        let server = MqttServer::new(&MqttConfig::default(), Arc::new(MqttMetrics::new()));
        assert_eq!(server.session_manager().connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_manager_integration() {
        let config = MqttConfig { max_connections: 10, ..Default::default() };
        let server = MqttServer::new(&config, Arc::new(MqttMetrics::new()));
        let manager = server.session_manager();

        // Keep the receiver alive for the whole test so the client's channel stays open.
        let (tx, _rx) = mpsc::channel(10);
        let outcome = manager.connect("test-client".to_string(), true, 60, tx).await;

        assert!(outcome.is_ok());
        assert_eq!(manager.connection_count().await, 1);

        let connected = manager.get_connected_clients().await;
        assert!(connected.contains(&"test-client".to_string()));
    }
}