mockforge_mqtt/
server.rs

1//! MQTT server implementation using tokio TCP listener
2//!
3//! This module provides a complete MQTT 3.1.1 broker implementation that:
4//! - Accepts TCP connections from MQTT clients
5//! - Parses and handles all MQTT control packets
6//! - Manages client sessions and subscriptions
7//! - Routes messages between publishers and subscribers
8//! - Supports QoS 0, 1, and 2 message delivery
9
10use std::io;
11use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14use std::time::Duration;
15
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
17use tokio::net::{TcpListener, TcpStream};
18use tokio::sync::mpsc;
19use tokio_rustls::server::TlsStream;
20use tracing::{debug, error, info, warn};
21
22use crate::broker::MqttConfig;
23use crate::metrics::MqttMetrics;
24use crate::protocol::{
25    ConnackCode, Packet, PacketDecoder, PacketEncoder, ProtocolError, PublishPacket, QoS,
26};
27use crate::session::{
28    build_connack, build_puback, build_pubcomp, build_pubrec, build_pubrel, build_suback,
29    build_unsuback, SessionManager,
30};
31use crate::tls::{create_tls_acceptor_with_client_auth, TlsError};
32
/// Size, in bytes, of the read buffer initially allocated for each connection
/// (64 KiB); the configured maximum packet size caps it further.
const READ_BUFFER_SIZE: usize = 65_536;
/// Bounded capacity of the per-client channel that queues outgoing packets.
const CLIENT_CHANNEL_CAPACITY: usize = 256;
/// Period, in seconds, between sweeps of expired sessions.
const CLEANUP_INTERVAL_SECS: u64 = 30;
39
40/// Stream type that can be either plain TCP or TLS-encrypted
41pub enum MqttStream {
42    /// Plain TCP stream
43    Plain(TcpStream),
44    /// TLS-encrypted stream
45    Tls(TlsStream<TcpStream>),
46}
47
48impl AsyncRead for MqttStream {
49    fn poll_read(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52        buf: &mut ReadBuf<'_>,
53    ) -> Poll<io::Result<()>> {
54        match self.get_mut() {
55            MqttStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
56            MqttStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
57        }
58    }
59}
60
61impl AsyncWrite for MqttStream {
62    fn poll_write(
63        self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65        buf: &[u8],
66    ) -> Poll<io::Result<usize>> {
67        match self.get_mut() {
68            MqttStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
69            MqttStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
70        }
71    }
72
73    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
74        match self.get_mut() {
75            MqttStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
76            MqttStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
77        }
78    }
79
80    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
81        match self.get_mut() {
82            MqttStream::Plain(stream) => Pin::new(stream).poll_shutdown(cx),
83            MqttStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx),
84        }
85    }
86}
87
88/// MQTT Server handle
89pub struct MqttServer {
90    session_manager: Arc<SessionManager>,
91    metrics: Arc<MqttMetrics>,
92}
93
94impl MqttServer {
95    /// Create a new MQTT server
96    pub fn new(config: &MqttConfig, metrics: Arc<MqttMetrics>) -> Self {
97        Self {
98            session_manager: Arc::new(SessionManager::new(
99                config.max_connections,
100                Some(metrics.clone()),
101            )),
102            metrics,
103        }
104    }
105
106    /// Get the session manager
107    pub fn session_manager(&self) -> Arc<SessionManager> {
108        self.session_manager.clone()
109    }
110
111    /// Get the metrics
112    pub fn metrics(&self) -> Arc<MqttMetrics> {
113        self.metrics.clone()
114    }
115}
116
117/// Start an MQTT server using tokio TCP listener
118///
119/// This is the main entry point for the MQTT broker. It binds to the
120/// configured address and handles client connections with full MQTT
121/// protocol support.
122pub async fn start_mqtt_server(
123    config: MqttConfig,
124) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
125    let metrics = Arc::new(MqttMetrics::new());
126    start_mqtt_server_with_metrics(config, metrics).await
127}
128
129/// Start an MQTT server with custom metrics
130pub async fn start_mqtt_server_with_metrics(
131    config: MqttConfig,
132    metrics: Arc<MqttMetrics>,
133) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
134    let addr = format!("{}:{}", config.host, config.port);
135
136    info!(
137        "Starting MQTT broker on {}:{} (MQTT {:?})",
138        config.host, config.port, config.version
139    );
140
141    let listener = TcpListener::bind(&addr).await?;
142    let session_manager =
143        Arc::new(SessionManager::new(config.max_connections, Some(metrics.clone())));
144
145    info!(
146        "MQTT broker listening on {}:{} (MQTT {:?})",
147        config.host, config.port, config.version
148    );
149
150    // Spawn session cleanup task
151    let cleanup_manager = session_manager.clone();
152    tokio::spawn(async move {
153        let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_INTERVAL_SECS));
154        loop {
155            interval.tick().await;
156            let expired = cleanup_manager.cleanup_expired_sessions().await;
157            if !expired.is_empty() {
158                debug!("Cleaned up {} expired sessions", expired.len());
159            }
160        }
161    });
162
163    // Accept connections in a loop
164    loop {
165        match listener.accept().await {
166            Ok((socket, addr)) => {
167                info!("New MQTT connection from {}", addr);
168
169                let session_manager = session_manager.clone();
170                let metrics = metrics.clone();
171                let max_packet_size = config.max_packet_size;
172
173                tokio::spawn(async move {
174                    if let Err(e) =
175                        handle_connection(socket, addr, session_manager, metrics, max_packet_size)
176                            .await
177                    {
178                        warn!("Connection error from {}: {}", addr, e);
179                    }
180                });
181            }
182            Err(e) => {
183                error!("Error accepting MQTT connection: {}", e);
184            }
185        }
186    }
187}
188
189/// Start an MQTT server with TLS support
190///
191/// This starts an MQTTS listener on the configured TLS port (default 8883).
192/// Requires TLS certificate and key to be configured.
193pub async fn start_mqtt_tls_server(config: MqttConfig) -> Result<(), TlsError> {
194    let metrics = Arc::new(MqttMetrics::new());
195    start_mqtt_tls_server_with_metrics(config, metrics).await
196}
197
198/// Start an MQTT server with TLS and custom metrics
199pub async fn start_mqtt_tls_server_with_metrics(
200    config: MqttConfig,
201    metrics: Arc<MqttMetrics>,
202) -> Result<(), TlsError> {
203    if !config.tls_enabled {
204        return Err(TlsError::ConfigError("TLS is not enabled in configuration".to_string()));
205    }
206
207    let tls_acceptor = create_tls_acceptor_with_client_auth(&config)?;
208    let addr = format!("{}:{}", config.host, config.tls_port);
209
210    let listener = TcpListener::bind(&addr)
211        .await
212        .map_err(|e| TlsError::ConfigError(format!("Failed to bind to {}: {}", addr, e)))?;
213
214    info!(
215        "Starting MQTTS broker with TLS on {}:{} (MQTT {:?})",
216        config.host, config.tls_port, config.version
217    );
218
219    let session_manager =
220        Arc::new(SessionManager::new(config.max_connections, Some(metrics.clone())));
221
222    // Spawn session cleanup task
223    let cleanup_manager = session_manager.clone();
224    tokio::spawn(async move {
225        let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_INTERVAL_SECS));
226        loop {
227            interval.tick().await;
228            let expired = cleanup_manager.cleanup_expired_sessions().await;
229            if !expired.is_empty() {
230                debug!("Cleaned up {} expired sessions", expired.len());
231            }
232        }
233    });
234
235    // Accept TLS connections
236    loop {
237        match listener.accept().await {
238            Ok((socket, addr)) => {
239                info!("New MQTTS connection from {}", addr);
240
241                let tls_acceptor = tls_acceptor.clone();
242                let session_manager = session_manager.clone();
243                let metrics = metrics.clone();
244                let max_packet_size = config.max_packet_size;
245
246                tokio::spawn(async move {
247                    match tls_acceptor.accept(socket).await {
248                        Ok(tls_stream) => {
249                            if let Err(e) = handle_tls_connection(
250                                tls_stream,
251                                addr,
252                                session_manager,
253                                metrics,
254                                max_packet_size,
255                            )
256                            .await
257                            {
258                                warn!("TLS connection error from {}: {}", addr, e);
259                            }
260                        }
261                        Err(e) => {
262                            warn!("TLS handshake failed from {}: {}", addr, e);
263                        }
264                    }
265                });
266            }
267            Err(e) => {
268                error!("Error accepting MQTTS connection: {}", e);
269            }
270        }
271    }
272}
273
274/// Start both plain and TLS listeners concurrently
275pub async fn start_mqtt_dual_server(
276    config: MqttConfig,
277) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
278    let metrics = Arc::new(MqttMetrics::new());
279    start_mqtt_dual_server_with_metrics(config, metrics).await
280}
281
282/// Start both plain and TLS listeners with custom metrics
283pub async fn start_mqtt_dual_server_with_metrics(
284    config: MqttConfig,
285    metrics: Arc<MqttMetrics>,
286) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
287    let session_manager =
288        Arc::new(SessionManager::new(config.max_connections, Some(metrics.clone())));
289
290    // Spawn session cleanup task
291    let cleanup_manager = session_manager.clone();
292    tokio::spawn(async move {
293        let mut interval = tokio::time::interval(Duration::from_secs(CLEANUP_INTERVAL_SECS));
294        loop {
295            interval.tick().await;
296            let expired = cleanup_manager.cleanup_expired_sessions().await;
297            if !expired.is_empty() {
298                debug!("Cleaned up {} expired sessions", expired.len());
299            }
300        }
301    });
302
303    // Start plain TCP listener
304    let plain_addr = format!("{}:{}", config.host, config.port);
305    let plain_listener = TcpListener::bind(&plain_addr).await?;
306    info!("Starting MQTT broker on {} (MQTT {:?})", plain_addr, config.version);
307
308    let plain_session_manager = session_manager.clone();
309    let plain_metrics = metrics.clone();
310    let plain_max_packet_size = config.max_packet_size;
311
312    tokio::spawn(async move {
313        loop {
314            match plain_listener.accept().await {
315                Ok((socket, addr)) => {
316                    info!("New MQTT connection from {}", addr);
317
318                    let session_manager = plain_session_manager.clone();
319                    let metrics = plain_metrics.clone();
320
321                    tokio::spawn(async move {
322                        if let Err(e) = handle_connection(
323                            socket,
324                            addr,
325                            session_manager,
326                            metrics,
327                            plain_max_packet_size,
328                        )
329                        .await
330                        {
331                            warn!("Connection error from {}: {}", addr, e);
332                        }
333                    });
334                }
335                Err(e) => {
336                    error!("Error accepting MQTT connection: {}", e);
337                }
338            }
339        }
340    });
341
342    // Start TLS listener if enabled
343    if config.tls_enabled {
344        let tls_acceptor = create_tls_acceptor_with_client_auth(&config)?;
345        let tls_addr = format!("{}:{}", config.host, config.tls_port);
346        let tls_listener = TcpListener::bind(&tls_addr).await?;
347        info!("Starting MQTTS broker with TLS on {}", tls_addr);
348
349        let tls_session_manager = session_manager.clone();
350        let tls_metrics = metrics.clone();
351        let tls_max_packet_size = config.max_packet_size;
352
353        tokio::spawn(async move {
354            loop {
355                match tls_listener.accept().await {
356                    Ok((socket, addr)) => {
357                        info!("New MQTTS connection from {}", addr);
358
359                        let tls_acceptor = tls_acceptor.clone();
360                        let session_manager = tls_session_manager.clone();
361                        let metrics = tls_metrics.clone();
362
363                        tokio::spawn(async move {
364                            match tls_acceptor.accept(socket).await {
365                                Ok(tls_stream) => {
366                                    if let Err(e) = handle_tls_connection(
367                                        tls_stream,
368                                        addr,
369                                        session_manager,
370                                        metrics,
371                                        tls_max_packet_size,
372                                    )
373                                    .await
374                                    {
375                                        warn!("TLS connection error from {}: {}", addr, e);
376                                    }
377                                }
378                                Err(e) => {
379                                    warn!("TLS handshake failed from {}: {}", addr, e);
380                                }
381                            }
382                        });
383                    }
384                    Err(e) => {
385                        error!("Error accepting MQTTS connection: {}", e);
386                    }
387                }
388            }
389        });
390    }
391
392    // Keep the main task running
393    loop {
394        tokio::time::sleep(Duration::from_secs(3600)).await;
395    }
396}
397
398/// Handle a single TLS client connection
399async fn handle_tls_connection(
400    stream: TlsStream<TcpStream>,
401    addr: std::net::SocketAddr,
402    session_manager: Arc<SessionManager>,
403    metrics: Arc<MqttMetrics>,
404    max_packet_size: usize,
405) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
406    let (read_half, write_half) = tokio::io::split(stream);
407    let mut reader = tokio::io::BufReader::new(read_half);
408    let mut writer = write_half;
409
410    // Buffer for reading packets
411    let mut buffer = vec![0u8; READ_BUFFER_SIZE.min(max_packet_size)];
412    let mut buf_len = 0usize;
413
414    // Client state
415    let mut client_id: Option<String> = None;
416    let mut packet_rx: Option<mpsc::Receiver<Packet>> = None;
417
418    // Read first packet - must be CONNECT
419    let connect_timeout = Duration::from_secs(10);
420    let first_read = tokio::time::timeout(connect_timeout, reader.read(&mut buffer[buf_len..]))
421        .await
422        .map_err(|_| "Connection timeout waiting for CONNECT")?;
423
424    match first_read {
425        Ok(0) => {
426            debug!("TLS client {} closed connection before CONNECT", addr);
427            return Ok(());
428        }
429        Ok(n) => buf_len += n,
430        Err(e) => return Err(e.into()),
431    }
432
433    // Parse CONNECT packet
434    let (connect_packet, consumed) = match PacketDecoder::decode(&buffer[..buf_len])? {
435        Some((Packet::Connect(connect), consumed)) => (connect, consumed),
436        Some((_, _)) => {
437            warn!("First packet from TLS client {} was not CONNECT", addr);
438            let connack = build_connack(false, ConnackCode::NotAuthorized);
439            let bytes = PacketEncoder::encode(&connack)?;
440            writer.write_all(&bytes).await?;
441            return Err("Expected CONNECT packet".into());
442        }
443        None => {
444            return Err("Incomplete CONNECT packet".into());
445        }
446    };
447
448    // Shift buffer
449    buffer.copy_within(consumed..buf_len, 0);
450    buf_len -= consumed;
451
452    // Validate CONNECT packet
453    let cid = if connect_packet.client_id.is_empty() {
454        if connect_packet.clean_session {
455            format!("auto-tls-{}", uuid::Uuid::new_v4())
456        } else {
457            let connack = build_connack(false, ConnackCode::IdentifierRejected);
458            let bytes = PacketEncoder::encode(&connack)?;
459            writer.write_all(&bytes).await?;
460            return Err("Empty client ID with clean_session=false".into());
461        }
462    } else {
463        connect_packet.client_id.clone()
464    };
465
466    info!(
467        "TLS CONNECT from {} (client_id={}, clean_session={})",
468        addr, cid, connect_packet.clean_session
469    );
470
471    // Create channel for sending packets to this client
472    let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY);
473    packet_rx = Some(rx);
474
475    // Register with session manager
476    let connect_result = session_manager
477        .connect(cid.clone(), connect_packet.clean_session, connect_packet.keep_alive, tx)
478        .await;
479
480    let session_present = match connect_result {
481        Ok((session_present, code)) => {
482            let connack = build_connack(session_present, code);
483            let bytes = PacketEncoder::encode(&connack)?;
484            writer.write_all(&bytes).await?;
485            session_present
486        }
487        Err(code) => {
488            let connack = build_connack(false, code);
489            let bytes = PacketEncoder::encode(&connack)?;
490            writer.write_all(&bytes).await?;
491            return Err(format!("Connection rejected: {:?}", code).into());
492        }
493    };
494
495    client_id = Some(cid.clone());
496
497    info!("TLS client {} connected (session_present={})", cid, session_present);
498
499    // Send retained messages for existing subscriptions (if session was restored)
500    if session_present {
501        let subscriptions = session_manager.get_client_subscriptions(&cid).await;
502
503        for (filter, _sub_qos) in subscriptions {
504            let retained = session_manager.get_retained_messages(&filter).await;
505            for (topic, mut publish) in retained {
506                if publish.qos != QoS::AtMostOnce {
507                    if let Some(id) = session_manager.assign_packet_id(&cid).await {
508                        publish.packet_id = Some(id);
509                    }
510                }
511
512                let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
513                writer.write_all(&bytes).await?;
514
515                debug!(
516                    "Delivered retained message for topic {} to restored TLS session {}",
517                    topic, cid
518                );
519            }
520        }
521    }
522
523    // Main connection loop
524    let mut rx = packet_rx.take().unwrap();
525    let result = handle_tls_client_loop(
526        &mut reader,
527        &mut writer,
528        &mut rx,
529        &cid,
530        &session_manager,
531        &metrics,
532        &mut buffer,
533        &mut buf_len,
534        max_packet_size,
535    )
536    .await;
537
538    // Clean up on disconnect
539    session_manager.disconnect(&cid).await;
540    info!("TLS client {} disconnected", cid);
541
542    result
543}
544
545/// Handle the main TLS client message loop
546async fn handle_tls_client_loop<R, W>(
547    reader: &mut tokio::io::BufReader<R>,
548    writer: &mut W,
549    packet_rx: &mut mpsc::Receiver<Packet>,
550    client_id: &str,
551    session_manager: &Arc<SessionManager>,
552    metrics: &Arc<MqttMetrics>,
553    buffer: &mut Vec<u8>,
554    buf_len: &mut usize,
555    max_packet_size: usize,
556) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
557where
558    R: AsyncRead + Unpin,
559    W: AsyncWrite + Unpin,
560{
561    loop {
562        tokio::select! {
563            // Handle incoming packets from client
564            read_result = reader.read(&mut buffer[*buf_len..]) => {
565                match read_result {
566                    Ok(0) => {
567                        debug!("TLS client {} closed connection", client_id);
568                        return Ok(());
569                    }
570                    Ok(n) => {
571                        *buf_len += n;
572
573                        // Check for oversized packet
574                        if *buf_len > max_packet_size {
575                            warn!("TLS client {} sent oversized packet", client_id);
576                            metrics.record_error("oversized_packet");
577                            return Err("Packet too large".into());
578                        }
579
580                        // Parse and handle packets
581                        while let Some((packet, consumed)) = PacketDecoder::decode(&buffer[..*buf_len])? {
582                            // Shift buffer
583                            buffer.copy_within(consumed..*buf_len, 0);
584                            *buf_len -= consumed;
585
586                            // Handle the packet
587                            match handle_tls_packet(
588                                client_id,
589                                packet,
590                                writer,
591                                session_manager,
592                                metrics,
593                            ).await {
594                                Ok(true) => continue,
595                                Ok(false) => return Ok(()), // Disconnect requested
596                                Err(e) => {
597                                    warn!("Error handling packet from TLS client {}: {}", client_id, e);
598                                    metrics.record_error(&e.to_string());
599                                }
600                            }
601                        }
602                    }
603                    Err(e) => {
604                        return Err(e.into());
605                    }
606                }
607            }
608
609            // Handle outgoing packets to client
610            packet = packet_rx.recv() => {
611                match packet {
612                    Some(Packet::Disconnect) => {
613                        debug!("Sending disconnect to TLS client {}", client_id);
614                        return Ok(());
615                    }
616                    Some(mut packet) => {
617                        // Assign packet ID for QoS > 0 publish packets
618                        if let Packet::Publish(ref mut publish) = packet {
619                            if publish.qos != QoS::AtMostOnce && publish.packet_id.is_none() {
620                                if let Some(id) = session_manager.assign_packet_id(client_id).await {
621                                    publish.packet_id = Some(id);
622                                }
623                            }
624                        }
625
626                        let bytes = PacketEncoder::encode(&packet)?;
627                        writer.write_all(&bytes).await?;
628                    }
629                    None => {
630                        debug!("Channel closed for TLS client {}", client_id);
631                        return Ok(());
632                    }
633                }
634            }
635        }
636    }
637}
638
639/// Handle a single MQTT packet from TLS client
640async fn handle_tls_packet<W>(
641    client_id: &str,
642    packet: Packet,
643    writer: &mut W,
644    session_manager: &Arc<SessionManager>,
645    metrics: &Arc<MqttMetrics>,
646) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>
647where
648    W: AsyncWrite + Unpin,
649{
650    match packet {
651        Packet::Connect(_) => {
652            warn!("TLS client {} sent second CONNECT packet", client_id);
653            return Ok(false);
654        }
655
656        Packet::Publish(publish) => {
657            debug!(
658                "PUBLISH from TLS client {} to topic {} (QoS {:?})",
659                client_id, publish.topic, publish.qos
660            );
661
662            match publish.qos {
663                QoS::AtMostOnce => {}
664                QoS::AtLeastOnce => {
665                    if let Some(packet_id) = publish.packet_id {
666                        let puback = build_puback(packet_id);
667                        let bytes = PacketEncoder::encode(&puback)?;
668                        writer.write_all(&bytes).await?;
669                    }
670                }
671                QoS::ExactlyOnce => {
672                    if let Some(packet_id) = publish.packet_id {
673                        session_manager.start_qos2_inbound(client_id, packet_id).await;
674                        let pubrec = build_pubrec(packet_id);
675                        let bytes = PacketEncoder::encode(&pubrec)?;
676                        writer.write_all(&bytes).await?;
677                        session_manager.mark_pubrec_sent(client_id, packet_id).await;
678                    }
679                }
680            }
681
682            session_manager.publish(client_id, &publish).await;
683        }
684
685        Packet::Puback(puback) => {
686            debug!("PUBACK from TLS client {} for packet {}", client_id, puback.packet_id);
687            session_manager.handle_puback(client_id, puback.packet_id).await;
688        }
689
690        Packet::Pubrec(pubrec) => {
691            debug!("PUBREC from TLS client {} for packet {}", client_id, pubrec.packet_id);
692            if session_manager.handle_pubrec(client_id, pubrec.packet_id).await {
693                let pubrel = build_pubrel(pubrec.packet_id);
694                let bytes = PacketEncoder::encode(&pubrel)?;
695                writer.write_all(&bytes).await?;
696            }
697        }
698
699        Packet::Pubrel(pubrel) => {
700            debug!("PUBREL from TLS client {} for packet {}", client_id, pubrel.packet_id);
701            if session_manager.handle_pubrel(client_id, pubrel.packet_id).await {
702                let pubcomp = build_pubcomp(pubrel.packet_id);
703                let bytes = PacketEncoder::encode(&pubcomp)?;
704                writer.write_all(&bytes).await?;
705                session_manager.complete_qos2_inbound(client_id, pubrel.packet_id).await;
706            }
707        }
708
709        Packet::Pubcomp(pubcomp) => {
710            debug!("PUBCOMP from TLS client {} for packet {}", client_id, pubcomp.packet_id);
711            session_manager.handle_pubcomp(client_id, pubcomp.packet_id).await;
712        }
713
714        Packet::Subscribe(subscribe) => {
715            debug!(
716                "SUBSCRIBE from TLS client {} for {} topics",
717                client_id,
718                subscribe.subscriptions.len()
719            );
720
721            if let Some(return_codes) =
722                session_manager.subscribe(client_id, subscribe.subscriptions.clone()).await
723            {
724                let suback = build_suback(subscribe.packet_id, return_codes);
725                let bytes = PacketEncoder::encode(&suback)?;
726                writer.write_all(&bytes).await?;
727
728                for (filter, _) in &subscribe.subscriptions {
729                    let retained = session_manager.get_retained_messages(filter).await;
730                    for (topic, mut publish) in retained {
731                        if publish.qos != QoS::AtMostOnce {
732                            if let Some(id) = session_manager.assign_packet_id(client_id).await {
733                                publish.packet_id = Some(id);
734                            }
735                        }
736                        let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
737                        writer.write_all(&bytes).await?;
738                        debug!(
739                            "Sent retained message for topic {} to TLS client {}",
740                            topic, client_id
741                        );
742                    }
743                }
744            }
745        }
746
747        Packet::Unsubscribe(unsubscribe) => {
748            debug!(
749                "UNSUBSCRIBE from TLS client {} for {} topics",
750                client_id,
751                unsubscribe.topics.len()
752            );
753
754            session_manager.unsubscribe(client_id, unsubscribe.topics).await;
755
756            let unsuback = build_unsuback(unsubscribe.packet_id);
757            let bytes = PacketEncoder::encode(&unsuback)?;
758            writer.write_all(&bytes).await?;
759        }
760
761        Packet::Pingreq => {
762            debug!("PINGREQ from TLS client {}", client_id);
763            session_manager.touch(client_id).await;
764
765            let pingresp = Packet::Pingresp;
766            let bytes = PacketEncoder::encode(&pingresp)?;
767            writer.write_all(&bytes).await?;
768        }
769
770        Packet::Disconnect => {
771            info!("DISCONNECT from TLS client {}", client_id);
772            return Ok(false);
773        }
774
775        Packet::Connack(_) | Packet::Suback(_) | Packet::Unsuback(_) | Packet::Pingresp => {
776            warn!("TLS client {} sent unexpected packet type: {:?}", client_id, packet);
777            metrics.record_error("unexpected_packet_type");
778        }
779    }
780
781    Ok(true)
782}
783
784/// Handle a single client connection
785async fn handle_connection(
786    socket: tokio::net::TcpStream,
787    addr: std::net::SocketAddr,
788    session_manager: Arc<SessionManager>,
789    metrics: Arc<MqttMetrics>,
790    max_packet_size: usize,
791) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
792    let (read_half, mut write_half) = socket.into_split();
793    let mut reader = tokio::io::BufReader::new(read_half);
794
795    // Buffer for reading packets
796    let mut buffer = vec![0u8; READ_BUFFER_SIZE.min(max_packet_size)];
797    let mut buf_len = 0usize;
798
799    // Client state
800    let mut client_id: Option<String> = None;
801    let mut packet_rx: Option<mpsc::Receiver<Packet>> = None;
802
803    // Read first packet - must be CONNECT
804    let connect_timeout = Duration::from_secs(10);
805    let first_read = tokio::time::timeout(connect_timeout, reader.read(&mut buffer[buf_len..]))
806        .await
807        .map_err(|_| "Connection timeout waiting for CONNECT")?;
808
809    match first_read {
810        Ok(0) => {
811            debug!("Client {} closed connection before CONNECT", addr);
812            return Ok(());
813        }
814        Ok(n) => buf_len += n,
815        Err(e) => return Err(e.into()),
816    }
817
818    // Parse CONNECT packet
819    let (connect_packet, consumed) = match PacketDecoder::decode(&buffer[..buf_len])? {
820        Some((Packet::Connect(connect), consumed)) => (connect, consumed),
821        Some((_, _)) => {
822            warn!("First packet from {} was not CONNECT", addr);
823            let connack = build_connack(false, ConnackCode::NotAuthorized);
824            let bytes = PacketEncoder::encode(&connack)?;
825            write_half.write_all(&bytes).await?;
826            return Err("Expected CONNECT packet".into());
827        }
828        None => {
829            return Err("Incomplete CONNECT packet".into());
830        }
831    };
832
833    // Shift buffer
834    buffer.copy_within(consumed..buf_len, 0);
835    buf_len -= consumed;
836
837    // Validate CONNECT packet
838    let cid = if connect_packet.client_id.is_empty() {
839        // Generate client ID for empty client ID with clean session
840        if connect_packet.clean_session {
841            format!("auto-{}", uuid::Uuid::new_v4())
842        } else {
843            let connack = build_connack(false, ConnackCode::IdentifierRejected);
844            let bytes = PacketEncoder::encode(&connack)?;
845            write_half.write_all(&bytes).await?;
846            return Err("Empty client ID with clean_session=false".into());
847        }
848    } else {
849        connect_packet.client_id.clone()
850    };
851
852    info!(
853        "CONNECT from {} (client_id={}, clean_session={})",
854        addr, cid, connect_packet.clean_session
855    );
856
857    // Create channel for sending packets to this client
858    let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY);
859    packet_rx = Some(rx);
860
861    // Register with session manager
862    let connect_result = session_manager
863        .connect(cid.clone(), connect_packet.clean_session, connect_packet.keep_alive, tx)
864        .await;
865
866    let session_present = match connect_result {
867        Ok((session_present, code)) => {
868            let connack = build_connack(session_present, code);
869            let bytes = PacketEncoder::encode(&connack)?;
870            write_half.write_all(&bytes).await?;
871            session_present
872        }
873        Err(code) => {
874            let connack = build_connack(false, code);
875            let bytes = PacketEncoder::encode(&connack)?;
876            write_half.write_all(&bytes).await?;
877            return Err(format!("Connection rejected: {:?}", code).into());
878        }
879    };
880
881    client_id = Some(cid.clone());
882
883    info!("Client {} connected (session_present={})", cid, session_present);
884
885    // Send retained messages for existing subscriptions (if session was restored)
886    if session_present {
887        // Get the client's restored subscriptions
888        let subscriptions = session_manager.get_client_subscriptions(&cid).await;
889
890        for (filter, _sub_qos) in subscriptions {
891            let retained = session_manager.get_retained_messages(&filter).await;
892            for (topic, mut publish) in retained {
893                // Assign packet ID if needed for QoS > 0
894                if publish.qos != QoS::AtMostOnce {
895                    if let Some(id) = session_manager.assign_packet_id(&cid).await {
896                        publish.packet_id = Some(id);
897                    }
898                }
899
900                let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
901                write_half.write_all(&bytes).await?;
902
903                debug!(
904                    "Delivered retained message for topic {} to restored session {}",
905                    topic, cid
906                );
907            }
908        }
909    }
910
911    // Main connection loop
912    let mut rx = packet_rx.take().unwrap();
913    let result = handle_client_loop(
914        &mut reader,
915        &mut write_half,
916        &mut rx,
917        &cid,
918        &session_manager,
919        &metrics,
920        &mut buffer,
921        &mut buf_len,
922        max_packet_size,
923    )
924    .await;
925
926    // Clean up on disconnect
927    session_manager.disconnect(&cid).await;
928    info!("Client {} disconnected", cid);
929
930    result
931}
932
933/// Handle the main client message loop
934async fn handle_client_loop(
935    reader: &mut tokio::io::BufReader<tokio::net::tcp::OwnedReadHalf>,
936    writer: &mut tokio::net::tcp::OwnedWriteHalf,
937    packet_rx: &mut mpsc::Receiver<Packet>,
938    client_id: &str,
939    session_manager: &Arc<SessionManager>,
940    metrics: &Arc<MqttMetrics>,
941    buffer: &mut Vec<u8>,
942    buf_len: &mut usize,
943    max_packet_size: usize,
944) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
945    loop {
946        tokio::select! {
947            // Handle incoming packets from client
948            read_result = reader.read(&mut buffer[*buf_len..]) => {
949                match read_result {
950                    Ok(0) => {
951                        debug!("Client {} closed connection", client_id);
952                        return Ok(());
953                    }
954                    Ok(n) => {
955                        *buf_len += n;
956
957                        // Check for oversized packet
958                        if *buf_len > max_packet_size {
959                            warn!("Client {} sent oversized packet", client_id);
960                            metrics.record_error("oversized_packet");
961                            return Err("Packet too large".into());
962                        }
963
964                        // Parse and handle packets
965                        while let Some((packet, consumed)) = PacketDecoder::decode(&buffer[..*buf_len])? {
966                            // Shift buffer
967                            buffer.copy_within(consumed..*buf_len, 0);
968                            *buf_len -= consumed;
969
970                            // Handle the packet
971                            match handle_packet(
972                                client_id,
973                                packet,
974                                writer,
975                                session_manager,
976                                metrics,
977                            ).await {
978                                Ok(true) => continue,
979                                Ok(false) => return Ok(()), // Disconnect requested
980                                Err(e) => {
981                                    warn!("Error handling packet from {}: {}", client_id, e);
982                                    metrics.record_error(&e.to_string());
983                                }
984                            }
985                        }
986                    }
987                    Err(e) => {
988                        return Err(e.into());
989                    }
990                }
991            }
992
993            // Handle outgoing packets to client
994            packet = packet_rx.recv() => {
995                match packet {
996                    Some(Packet::Disconnect) => {
997                        debug!("Sending disconnect to {}", client_id);
998                        return Ok(());
999                    }
1000                    Some(mut packet) => {
1001                        // Assign packet ID for QoS > 0 publish packets
1002                        if let Packet::Publish(ref mut publish) = packet {
1003                            if publish.qos != QoS::AtMostOnce && publish.packet_id.is_none() {
1004                                if let Some(id) = session_manager.assign_packet_id(client_id).await {
1005                                    publish.packet_id = Some(id);
1006                                }
1007                            }
1008                        }
1009
1010                        let bytes = PacketEncoder::encode(&packet)?;
1011                        writer.write_all(&bytes).await?;
1012                    }
1013                    None => {
1014                        debug!("Channel closed for {}", client_id);
1015                        return Ok(());
1016                    }
1017                }
1018            }
1019        }
1020    }
1021}
1022
1023/// Handle a single MQTT packet
1024async fn handle_packet(
1025    client_id: &str,
1026    packet: Packet,
1027    writer: &mut tokio::net::tcp::OwnedWriteHalf,
1028    session_manager: &Arc<SessionManager>,
1029    metrics: &Arc<MqttMetrics>,
1030) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
1031    match packet {
1032        Packet::Connect(_) => {
1033            // Second CONNECT is a protocol error
1034            warn!("Client {} sent second CONNECT packet", client_id);
1035            return Ok(false);
1036        }
1037
1038        Packet::Publish(publish) => {
1039            debug!("PUBLISH from {} to topic {} (QoS {:?})", client_id, publish.topic, publish.qos);
1040
1041            // Handle QoS acknowledgments
1042            match publish.qos {
1043                QoS::AtMostOnce => {
1044                    // No acknowledgment needed
1045                }
1046                QoS::AtLeastOnce => {
1047                    // Send PUBACK
1048                    if let Some(packet_id) = publish.packet_id {
1049                        let puback = build_puback(packet_id);
1050                        let bytes = PacketEncoder::encode(&puback)?;
1051                        writer.write_all(&bytes).await?;
1052                    }
1053                }
1054                QoS::ExactlyOnce => {
1055                    // Start QoS 2 flow
1056                    if let Some(packet_id) = publish.packet_id {
1057                        session_manager.start_qos2_inbound(client_id, packet_id).await;
1058                        let pubrec = build_pubrec(packet_id);
1059                        let bytes = PacketEncoder::encode(&pubrec)?;
1060                        writer.write_all(&bytes).await?;
1061                        session_manager.mark_pubrec_sent(client_id, packet_id).await;
1062                    }
1063                }
1064            }
1065
1066            // Route message to subscribers
1067            session_manager.publish(client_id, &publish).await;
1068        }
1069
1070        Packet::Puback(puback) => {
1071            debug!("PUBACK from {} for packet {}", client_id, puback.packet_id);
1072            session_manager.handle_puback(client_id, puback.packet_id).await;
1073        }
1074
1075        Packet::Pubrec(pubrec) => {
1076            debug!("PUBREC from {} for packet {}", client_id, pubrec.packet_id);
1077            if session_manager.handle_pubrec(client_id, pubrec.packet_id).await {
1078                let pubrel = build_pubrel(pubrec.packet_id);
1079                let bytes = PacketEncoder::encode(&pubrel)?;
1080                writer.write_all(&bytes).await?;
1081            }
1082        }
1083
1084        Packet::Pubrel(pubrel) => {
1085            debug!("PUBREL from {} for packet {}", client_id, pubrel.packet_id);
1086            if session_manager.handle_pubrel(client_id, pubrel.packet_id).await {
1087                let pubcomp = build_pubcomp(pubrel.packet_id);
1088                let bytes = PacketEncoder::encode(&pubcomp)?;
1089                writer.write_all(&bytes).await?;
1090                session_manager.complete_qos2_inbound(client_id, pubrel.packet_id).await;
1091            }
1092        }
1093
1094        Packet::Pubcomp(pubcomp) => {
1095            debug!("PUBCOMP from {} for packet {}", client_id, pubcomp.packet_id);
1096            session_manager.handle_pubcomp(client_id, pubcomp.packet_id).await;
1097        }
1098
1099        Packet::Subscribe(subscribe) => {
1100            debug!("SUBSCRIBE from {} for {} topics", client_id, subscribe.subscriptions.len());
1101
1102            if let Some(return_codes) =
1103                session_manager.subscribe(client_id, subscribe.subscriptions.clone()).await
1104            {
1105                // Send SUBACK
1106                let suback = build_suback(subscribe.packet_id, return_codes);
1107                let bytes = PacketEncoder::encode(&suback)?;
1108                writer.write_all(&bytes).await?;
1109
1110                // Send retained messages for new subscriptions
1111                for (filter, _) in &subscribe.subscriptions {
1112                    let retained = session_manager.get_retained_messages(filter).await;
1113                    for (topic, mut publish) in retained {
1114                        // Assign packet ID if needed
1115                        if publish.qos != QoS::AtMostOnce {
1116                            if let Some(id) = session_manager.assign_packet_id(client_id).await {
1117                                publish.packet_id = Some(id);
1118                            }
1119                        }
1120                        let bytes = PacketEncoder::encode(&Packet::Publish(publish))?;
1121                        writer.write_all(&bytes).await?;
1122                        debug!("Sent retained message for topic {} to {}", topic, client_id);
1123                    }
1124                }
1125            }
1126        }
1127
1128        Packet::Unsubscribe(unsubscribe) => {
1129            debug!("UNSUBSCRIBE from {} for {} topics", client_id, unsubscribe.topics.len());
1130
1131            session_manager.unsubscribe(client_id, unsubscribe.topics).await;
1132
1133            // Send UNSUBACK
1134            let unsuback = build_unsuback(unsubscribe.packet_id);
1135            let bytes = PacketEncoder::encode(&unsuback)?;
1136            writer.write_all(&bytes).await?;
1137        }
1138
1139        Packet::Pingreq => {
1140            debug!("PINGREQ from {}", client_id);
1141            session_manager.touch(client_id).await;
1142
1143            // Send PINGRESP
1144            let pingresp = Packet::Pingresp;
1145            let bytes = PacketEncoder::encode(&pingresp)?;
1146            writer.write_all(&bytes).await?;
1147        }
1148
1149        Packet::Disconnect => {
1150            info!("DISCONNECT from {}", client_id);
1151            return Ok(false);
1152        }
1153
1154        // These are server-to-client packets, shouldn't receive from client
1155        Packet::Connack(_) | Packet::Suback(_) | Packet::Unsuback(_) | Packet::Pingresp => {
1156            warn!("Client {} sent unexpected packet type: {:?}", client_id, packet);
1157            metrics.record_error("unexpected_packet_type");
1158        }
1159    }
1160
1161    Ok(true)
1162}
1163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::broker::MqttVersion;

    /// Render a config's host and port as a `host:port` string.
    fn bind_addr(config: &MqttConfig) -> String {
        format!("{}:{}", config.host, config.port)
    }

    #[test]
    fn test_mqtt_config_address_formatting() {
        let loopback = MqttConfig {
            host: "127.0.0.1".to_string(),
            port: 1883,
            ..MqttConfig::default()
        };
        assert_eq!(bind_addr(&loopback), "127.0.0.1:1883");
    }

    #[test]
    fn test_mqtt_config_default_host_port() {
        assert_eq!(bind_addr(&MqttConfig::default()), "0.0.0.0:1883");
    }

    #[test]
    fn test_mqtt_config_custom_port() {
        let tls_port = MqttConfig {
            port: 8883,
            ..MqttConfig::default()
        };
        assert_eq!(tls_port.port, 8883);
    }

    #[test]
    fn test_mqtt_config_version_v3() {
        let v3 = MqttConfig {
            version: MqttVersion::V3_1_1,
            ..MqttConfig::default()
        };
        assert!(matches!(v3.version, MqttVersion::V3_1_1));
    }

    #[test]
    fn test_mqtt_config_version_v5() {
        let v5 = MqttConfig {
            version: MqttVersion::V5_0,
            ..MqttConfig::default()
        };
        assert!(matches!(v5.version, MqttVersion::V5_0));
    }

    #[tokio::test]
    async fn test_tcp_listener_bind_localhost() {
        // Port 0 lets the OS pick any free port.
        let ephemeral = MqttConfig {
            host: "127.0.0.1".to_string(),
            port: 0,
            ..MqttConfig::default()
        };
        let target = bind_addr(&ephemeral);

        // Test that we can bind to the address
        assert!(TcpListener::bind(target.as_str()).await.is_ok());
    }

    #[tokio::test]
    async fn test_tcp_listener_local_addr() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound = listener.local_addr().unwrap();

        assert_eq!(bound.ip().to_string(), "127.0.0.1");
        assert_ne!(bound.port(), 0);
    }

    #[test]
    fn test_mqtt_version_debug_format() {
        let cases = [(MqttVersion::V3_1_1, "V3_1_1"), (MqttVersion::V5_0, "V5_0")];
        for (version, expected) in cases {
            assert!(format!("{:?}", version).contains(expected));
        }
    }

    #[test]
    fn test_config_max_connections() {
        let limited = MqttConfig {
            max_connections: 500,
            ..MqttConfig::default()
        };
        assert_eq!(limited.max_connections, 500);
    }

    #[test]
    fn test_config_max_packet_size() {
        let small_packets = MqttConfig {
            max_packet_size: 2048,
            ..MqttConfig::default()
        };
        assert_eq!(small_packets.max_packet_size, 2048);
    }

    #[test]
    fn test_config_keep_alive_secs() {
        let slow_ping = MqttConfig {
            keep_alive_secs: 120,
            ..MqttConfig::default()
        };
        assert_eq!(slow_ping.keep_alive_secs, 120);
    }

    #[test]
    fn test_config_clone() {
        let original = MqttConfig {
            port: 9999,
            host: "localhost".to_string(),
            max_connections: 200,
            max_packet_size: 4096,
            keep_alive_secs: 90,
            version: MqttVersion::V3_1_1,
            ..MqttConfig::default()
        };
        let copy = original.clone();

        assert_eq!(original.port, copy.port);
        assert_eq!(original.host, copy.host);
        assert_eq!(original.max_connections, copy.max_connections);
    }

    #[test]
    fn test_config_debug_format() {
        let rendered = format!("{:?}", MqttConfig::default());
        assert!(rendered.contains("MqttConfig"));
        assert!(rendered.contains("1883"));
    }

    #[tokio::test]
    async fn test_mqtt_server_creation() {
        let server = MqttServer::new(&MqttConfig::default(), Arc::new(MqttMetrics::new()));

        assert_eq!(server.session_manager().connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_manager_integration() {
        let config = MqttConfig {
            max_connections: 10,
            ..MqttConfig::default()
        };
        let server = MqttServer::new(&config, Arc::new(MqttMetrics::new()));

        let (sender, _receiver) = mpsc::channel(10);
        let outcome = server
            .session_manager()
            .connect("test-client".to_string(), true, 60, sender)
            .await;

        assert!(outcome.is_ok());
        assert_eq!(server.session_manager().connection_count().await, 1);

        let connected = server.session_manager().get_connected_clients().await;
        assert!(connected.contains(&"test-client".to_string()));
    }
}