Skip to main content

mockforge_mqtt/
session.rs

1//! MQTT Session Management
2//!
3//! This module handles client session tracking, subscription management,
4//! and QoS message delivery state for the MQTT broker.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use tokio::io::AsyncWriteExt;
11use tokio::net::tcp::OwnedWriteHalf;
12use tokio::sync::{mpsc, RwLock};
13use tracing::{debug, info, warn};
14
15use crate::metrics::MqttMetrics;
16use crate::protocol::{
17    ConnackCode, ConnackPacket, Packet, PacketEncoder, PubackPacket, PubcompPacket, PublishPacket,
18    PubrecPacket, PubrelPacket, QoS, SubackPacket, SubackReturnCode, UnsubackPacket,
19};
20use crate::topics::TopicTree;
21
/// Message to be delivered to a client
///
/// Stored in `ClientSession::pending_qos1_out`, keyed by packet identifier, until the
/// matching PUBACK is handled by `ClientSession::handle_puback`.
#[derive(Debug, Clone)]
pub struct PendingMessage {
    /// Packet identifier of this delivery
    pub packet_id: u16,
    /// Topic name the message is published on
    pub topic: String,
    /// Application payload bytes
    pub payload: Vec<u8>,
    /// QoS level of the delivery
    pub qos: QoS,
    /// Retain flag carried by the message
    pub retain: bool,
    /// NOTE(review): not written anywhere in this file; presumably Unix seconds like
    /// `ClientSession::last_activity` — confirm with the code that builds these messages
    pub timestamp: u64,
    /// NOTE(review): presumably the number of redelivery attempts; nothing in this file
    /// increments it — confirm with the retry logic
    pub retry_count: u8,
}
33
/// QoS 2 message state tracking
///
/// State of an inbound (client → broker) QoS 2 PUBLISH, stored in
/// `ClientSession::pending_qos2_in` keyed by packet identifier. The flow is
/// `PendingPubrec` (`start_qos2_inbound`) → `WaitingPubrel` (`mark_pubrec_sent`) →
/// `PendingPubcomp` (`handle_pubrel`), after which `complete_qos2_inbound` drops the entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Qos2State {
    /// PUBLISH received, waiting to send PUBREC
    PendingPubrec,
    /// PUBREC sent, waiting for PUBREL
    WaitingPubrel,
    /// PUBREL received, waiting to send PUBCOMP
    PendingPubcomp,
}
44
/// State of a QoS 2 outbound message
///
/// Tracks a broker → client QoS 2 delivery in `ClientSession::pending_qos2_out`, keyed by
/// packet identifier: `WaitingPubrec` (`start_qos2_outbound`) → `WaitingPubcomp`
/// (`handle_pubrec`); `handle_pubcomp` then removes the entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Qos2OutboundState {
    /// PUBLISH sent, waiting for PUBREC
    WaitingPubrec,
    /// PUBREC received, PUBREL sent, waiting for PUBCOMP
    WaitingPubcomp,
}
53
/// Client session state
///
/// Everything tracked for one client ID: connection parameters, subscriptions and
/// in-flight QoS 1/2 message state. Timestamps are Unix time in whole seconds.
#[derive(Debug)]
pub struct ClientSession {
    /// Client identifier
    pub client_id: String,
    /// Clean session flag from CONNECT
    pub clean_session: bool,
    /// Keep-alive interval in seconds (0 disables the check in `is_expired`)
    pub keep_alive: u16,
    /// Topic subscriptions: topic filter → subscribed QoS level
    pub subscriptions: HashMap<String, QoS>,
    /// Outbound messages pending acknowledgment (QoS 1), keyed by packet identifier
    pub pending_qos1_out: HashMap<u16, PendingMessage>,
    /// Outbound QoS 2 message states, keyed by packet identifier
    pub pending_qos2_out: HashMap<u16, Qos2OutboundState>,
    /// Inbound QoS 2 message states (for duplicate detection), keyed by packet identifier
    pub pending_qos2_in: HashMap<u16, Qos2State>,
    /// Last activity timestamp (Unix seconds), refreshed by `touch`
    pub last_activity: u64,
    /// Timestamp (Unix seconds) at which this session was created
    pub connected_at: u64,
    /// Next packet ID for this session; starts at 1 and never becomes 0
    next_packet_id: u16,
    /// Username if authenticated
    pub username: Option<String>,
}
80
81impl ClientSession {
82    /// Create a new client session
83    pub fn new(client_id: String, clean_session: bool, keep_alive: u16) -> Self {
84        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
85
86        Self {
87            client_id,
88            clean_session,
89            keep_alive,
90            subscriptions: HashMap::new(),
91            pending_qos1_out: HashMap::new(),
92            pending_qos2_out: HashMap::new(),
93            pending_qos2_in: HashMap::new(),
94            last_activity: now,
95            connected_at: now,
96            next_packet_id: 1,
97            username: None,
98        }
99    }
100
101    /// Generate next packet ID for this session
102    pub fn next_packet_id(&mut self) -> u16 {
103        let id = self.next_packet_id;
104        self.next_packet_id = self.next_packet_id.wrapping_add(1);
105        if self.next_packet_id == 0 {
106            self.next_packet_id = 1; // Skip 0
107        }
108        id
109    }
110
111    /// Update last activity timestamp
112    pub fn touch(&mut self) {
113        self.last_activity =
114            SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
115    }
116
117    /// Check if session has timed out
118    pub fn is_expired(&self) -> bool {
119        if self.keep_alive == 0 {
120            return false; // No timeout
121        }
122
123        let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
124
125        // MQTT spec says server should wait 1.5x keep_alive before disconnecting
126        let timeout = (self.keep_alive as u64) * 3 / 2;
127        now - self.last_activity > timeout
128    }
129
130    /// Add a subscription
131    pub fn subscribe(&mut self, topic_filter: String, qos: QoS) {
132        self.subscriptions.insert(topic_filter, qos);
133    }
134
135    /// Remove a subscription
136    pub fn unsubscribe(&mut self, topic_filter: &str) -> bool {
137        self.subscriptions.remove(topic_filter).is_some()
138    }
139
140    /// Queue a message for QoS 1 delivery
141    pub fn queue_qos1_message(&mut self, packet_id: u16, message: PendingMessage) {
142        self.pending_qos1_out.insert(packet_id, message);
143    }
144
145    /// Handle PUBACK for QoS 1
146    pub fn handle_puback(&mut self, packet_id: u16) -> Option<PendingMessage> {
147        self.pending_qos1_out.remove(&packet_id)
148    }
149
150    /// Start QoS 2 outbound flow (PUBLISH sent)
151    pub fn start_qos2_outbound(&mut self, packet_id: u16) {
152        self.pending_qos2_out.insert(packet_id, Qos2OutboundState::WaitingPubrec);
153    }
154
155    /// Handle PUBREC for QoS 2 outbound
156    pub fn handle_pubrec(&mut self, packet_id: u16) -> bool {
157        if let Some(state) = self.pending_qos2_out.get_mut(&packet_id) {
158            if *state == Qos2OutboundState::WaitingPubrec {
159                *state = Qos2OutboundState::WaitingPubcomp;
160                return true;
161            }
162        }
163        false
164    }
165
166    /// Handle PUBCOMP for QoS 2 outbound (completes flow)
167    pub fn handle_pubcomp(&mut self, packet_id: u16) -> bool {
168        if let Some(state) = self.pending_qos2_out.get(&packet_id) {
169            if *state == Qos2OutboundState::WaitingPubcomp {
170                self.pending_qos2_out.remove(&packet_id);
171                return true;
172            }
173        }
174        false
175    }
176
177    /// Start QoS 2 inbound flow (PUBLISH received)
178    pub fn start_qos2_inbound(&mut self, packet_id: u16) {
179        self.pending_qos2_in.insert(packet_id, Qos2State::PendingPubrec);
180    }
181
182    /// Handle PUBREL for QoS 2 inbound
183    pub fn handle_pubrel(&mut self, packet_id: u16) -> bool {
184        if let Some(state) = self.pending_qos2_in.get_mut(&packet_id) {
185            if *state == Qos2State::WaitingPubrel {
186                *state = Qos2State::PendingPubcomp;
187                return true;
188            }
189        }
190        false
191    }
192
193    /// Complete QoS 2 inbound flow (PUBCOMP sent)
194    pub fn complete_qos2_inbound(&mut self, packet_id: u16) {
195        self.pending_qos2_in.remove(&packet_id);
196    }
197
198    /// Mark PUBREC as sent for QoS 2 inbound
199    pub fn mark_pubrec_sent(&mut self, packet_id: u16) {
200        if let Some(state) = self.pending_qos2_in.get_mut(&packet_id) {
201            if *state == Qos2State::PendingPubrec {
202                *state = Qos2State::WaitingPubrel;
203            }
204        }
205    }
206}
207
/// Channel for sending packets to a connected client
///
/// A bounded `tokio::sync::mpsc` sender. `send(..).await` waits while the channel is full
/// and fails once the receiving side has been dropped.
pub type ClientSender = mpsc::Sender<Packet>;
210
/// Active client connection state
///
/// Pairs a connected client's session with the channel used to push packets to it.
pub struct ActiveClient {
    /// The client session
    pub session: ClientSession,
    /// Channel to send packets to the client
    pub sender: ClientSender,
}
218
/// Session manager for tracking all client sessions
///
/// Holds the currently connected clients, the sessions kept for clients that disconnected
/// with `clean_session == false`, and the shared topic tree.
pub struct SessionManager {
    /// Active connected clients, keyed by client ID
    active_clients: RwLock<HashMap<String, ActiveClient>>,
    /// Persistent sessions for reconnecting clients, keyed by client ID
    persistent_sessions: RwLock<HashMap<String, ClientSession>>,
    /// Topic subscription tree (also holds retained messages)
    topics: RwLock<TopicTree>,
    /// Metrics collector; `None` disables metric recording
    metrics: Option<Arc<MqttMetrics>>,
    /// Maximum number of simultaneously active connections
    max_connections: usize,
}
232
233impl SessionManager {
234    /// Create a new session manager
235    pub fn new(max_connections: usize, metrics: Option<Arc<MqttMetrics>>) -> Self {
236        Self {
237            active_clients: RwLock::new(HashMap::new()),
238            persistent_sessions: RwLock::new(HashMap::new()),
239            topics: RwLock::new(TopicTree::new()),
240            metrics,
241            max_connections,
242        }
243    }
244
245    /// Handle a new client connection
246    pub async fn connect(
247        &self,
248        client_id: String,
249        clean_session: bool,
250        keep_alive: u16,
251        sender: ClientSender,
252    ) -> Result<(bool, ConnackCode), ConnackCode> {
253        let active = self.active_clients.read().await;
254        if active.len() >= self.max_connections {
255            return Err(ConnackCode::ServerUnavailable);
256        }
257        drop(active);
258
259        // Check if client is already connected
260        let mut active = self.active_clients.write().await;
261        if let Some(existing) = active.remove(&client_id) {
262            // Disconnect existing client
263            info!("Disconnecting existing client {} for new connection", client_id);
264            let _ = existing.sender.send(Packet::Disconnect).await;
265
266            if let Some(metrics) = &self.metrics {
267                metrics.record_connection_closed();
268            }
269        }
270
271        // Check for persistent session
272        let mut persistent = self.persistent_sessions.write().await;
273        let (session, session_present) = if clean_session {
274            // Remove any existing session
275            persistent.remove(&client_id);
276            // Remove subscriptions from topic tree
277            let mut topics = self.topics.write().await;
278            if let Some(old_session) = persistent.get(&client_id) {
279                for filter in old_session.subscriptions.keys() {
280                    topics.unsubscribe(filter, &client_id);
281                }
282            }
283            (ClientSession::new(client_id.clone(), true, keep_alive), false)
284        } else if let Some(mut session) = persistent.remove(&client_id) {
285            // Restore persistent session
286            session.keep_alive = keep_alive;
287            session.touch();
288            (session, true)
289        } else {
290            (ClientSession::new(client_id.clone(), false, keep_alive), false)
291        };
292
293        active.insert(client_id.clone(), ActiveClient { session, sender });
294
295        if let Some(metrics) = &self.metrics {
296            metrics.record_connection();
297        }
298
299        info!(
300            "Client {} connected (clean_session={}, session_present={})",
301            client_id, clean_session, session_present
302        );
303
304        Ok((session_present, ConnackCode::Accepted))
305    }
306
307    /// Handle client disconnect
308    pub async fn disconnect(&self, client_id: &str) {
309        let mut active = self.active_clients.write().await;
310        if let Some(client) = active.remove(client_id) {
311            if !client.session.clean_session {
312                // Persist session for later reconnection
313                let mut persistent = self.persistent_sessions.write().await;
314                persistent.insert(client_id.to_string(), client.session);
315                info!("Persisted session for client {}", client_id);
316            } else {
317                // Clean session - remove subscriptions
318                let mut topics = self.topics.write().await;
319                for filter in client.session.subscriptions.keys() {
320                    topics.unsubscribe(filter, client_id);
321
322                    if let Some(metrics) = &self.metrics {
323                        metrics.record_unsubscription();
324                    }
325                }
326                info!("Cleaned up session for client {}", client_id);
327            }
328
329            if let Some(metrics) = &self.metrics {
330                metrics.record_connection_closed();
331            }
332        }
333    }
334
335    /// Handle SUBSCRIBE packet
336    pub async fn subscribe(
337        &self,
338        client_id: &str,
339        subscriptions: Vec<(String, QoS)>,
340    ) -> Option<Vec<SubackReturnCode>> {
341        let mut active = self.active_clients.write().await;
342        let client = active.get_mut(client_id)?;
343
344        let mut topics = self.topics.write().await;
345        let mut return_codes = Vec::new();
346
347        for (filter, requested_qos) in subscriptions {
348            // Add to topic tree
349            topics.subscribe(&filter, requested_qos as u8, client_id);
350
351            // Add to session
352            client.session.subscribe(filter.clone(), requested_qos);
353
354            // Return granted QoS (we grant what was requested)
355            return_codes.push(SubackReturnCode::success(requested_qos));
356
357            if let Some(metrics) = &self.metrics {
358                metrics.record_subscription();
359            }
360
361            debug!("Client {} subscribed to {} with QoS {:?}", client_id, filter, requested_qos);
362        }
363
364        Some(return_codes)
365    }
366
367    /// Handle UNSUBSCRIBE packet
368    pub async fn unsubscribe(&self, client_id: &str, topic_filters: Vec<String>) -> bool {
369        let mut active = self.active_clients.write().await;
370        let client = active.get_mut(client_id);
371
372        if client.is_none() {
373            return false;
374        }
375
376        let client = client.unwrap();
377        let mut topics = self.topics.write().await;
378
379        for filter in topic_filters {
380            topics.unsubscribe(&filter, client_id);
381            client.session.unsubscribe(&filter);
382
383            if let Some(metrics) = &self.metrics {
384                metrics.record_unsubscription();
385            }
386
387            debug!("Client {} unsubscribed from {}", client_id, filter);
388        }
389
390        true
391    }
392
393    /// Handle PUBLISH packet - route to subscribers
394    pub async fn publish(&self, publisher_id: &str, publish: &PublishPacket) {
395        // Update publisher's last activity
396        {
397            let mut active = self.active_clients.write().await;
398            if let Some(client) = active.get_mut(publisher_id) {
399                client.session.touch();
400            }
401        }
402
403        if let Some(metrics) = &self.metrics {
404            metrics.record_publish(publish.qos as u8);
405        }
406
407        // Handle retained messages
408        if publish.retain {
409            let mut topics = self.topics.write().await;
410            topics.retain_message(&publish.topic, publish.payload.clone(), publish.qos as u8);
411
412            if let Some(metrics) = &self.metrics {
413                metrics.record_retained_message();
414            }
415        }
416
417        // Find matching subscribers
418        let topics = self.topics.read().await;
419        let subscribers = topics.match_topic(&publish.topic);
420
421        // Deliver to each subscriber
422        let active = self.active_clients.read().await;
423        for sub in subscribers {
424            if sub.client_id == publisher_id {
425                continue; // Don't send to self
426            }
427
428            if let Some(client) = active.get(&sub.client_id) {
429                // Determine delivery QoS (minimum of publish and subscription QoS)
430                let delivery_qos = std::cmp::min(publish.qos as u8, sub.qos);
431                let delivery_qos = QoS::try_from(delivery_qos).unwrap_or(QoS::AtMostOnce);
432
433                let packet = Packet::Publish(PublishPacket {
434                    dup: false,
435                    qos: delivery_qos,
436                    retain: false, // Only first delivery can have retain
437                    topic: publish.topic.clone(),
438                    packet_id: if delivery_qos != QoS::AtMostOnce {
439                        Some(0) // Will be assigned by receiver
440                    } else {
441                        None
442                    },
443                    payload: publish.payload.clone(),
444                });
445
446                if client.sender.send(packet).await.is_ok() {
447                    if let Some(metrics) = &self.metrics {
448                        metrics.record_delivery();
449                    }
450                    debug!("Delivered message to {} on topic {}", sub.client_id, publish.topic);
451                }
452            }
453        }
454    }
455
456    /// Handle PUBACK from client
457    pub async fn handle_puback(&self, client_id: &str, packet_id: u16) {
458        let mut active = self.active_clients.write().await;
459        if let Some(client) = active.get_mut(client_id) {
460            client.session.touch();
461            if client.session.handle_puback(packet_id).is_some() {
462                debug!("QoS 1 delivery confirmed for client {}, packet {}", client_id, packet_id);
463            }
464        }
465    }
466
467    /// Handle PUBREC from client (QoS 2 step 1)
468    pub async fn handle_pubrec(&self, client_id: &str, packet_id: u16) -> bool {
469        let mut active = self.active_clients.write().await;
470        if let Some(client) = active.get_mut(client_id) {
471            client.session.touch();
472            if client.session.handle_pubrec(packet_id) {
473                debug!("QoS 2 PUBREC received for client {}, packet {}", client_id, packet_id);
474                return true;
475            }
476        }
477        false
478    }
479
480    /// Handle PUBREL from client (QoS 2 step 2)
481    pub async fn handle_pubrel(&self, client_id: &str, packet_id: u16) -> bool {
482        let mut active = self.active_clients.write().await;
483        if let Some(client) = active.get_mut(client_id) {
484            client.session.touch();
485            if client.session.handle_pubrel(packet_id) {
486                debug!("QoS 2 PUBREL received for client {}, packet {}", client_id, packet_id);
487                return true;
488            }
489        }
490        false
491    }
492
493    /// Handle PUBCOMP from client (QoS 2 step 3)
494    pub async fn handle_pubcomp(&self, client_id: &str, packet_id: u16) {
495        let mut active = self.active_clients.write().await;
496        if let Some(client) = active.get_mut(client_id) {
497            client.session.touch();
498            if client.session.handle_pubcomp(packet_id) {
499                debug!("QoS 2 delivery completed for client {}, packet {}", client_id, packet_id);
500            }
501        }
502    }
503
504    /// Update client activity timestamp (for PINGREQ)
505    pub async fn touch(&self, client_id: &str) {
506        let mut active = self.active_clients.write().await;
507        if let Some(client) = active.get_mut(client_id) {
508            client.session.touch();
509        }
510    }
511
512    /// Get retained messages matching a topic filter
513    pub async fn get_retained_messages(&self, filter: &str) -> Vec<(String, PublishPacket)> {
514        let topics = self.topics.read().await;
515        topics
516            .get_retained_for_filter(filter)
517            .into_iter()
518            .map(|(topic, msg)| {
519                (
520                    topic.to_string(),
521                    PublishPacket {
522                        dup: false,
523                        qos: QoS::try_from(msg.qos).unwrap_or(QoS::AtMostOnce),
524                        retain: true,
525                        topic: topic.to_string(),
526                        packet_id: None,
527                        payload: msg.payload.clone(),
528                    },
529                )
530            })
531            .collect()
532    }
533
534    /// Get sender for a specific client
535    pub async fn get_sender(&self, client_id: &str) -> Option<ClientSender> {
536        let active = self.active_clients.read().await;
537        active.get(client_id).map(|c| c.sender.clone())
538    }
539
540    /// Get list of connected client IDs
541    pub async fn get_connected_clients(&self) -> Vec<String> {
542        let active = self.active_clients.read().await;
543        active.keys().cloned().collect()
544    }
545
546    /// Get count of active connections
547    pub async fn connection_count(&self) -> usize {
548        let active = self.active_clients.read().await;
549        active.len()
550    }
551
552    /// Check and disconnect expired sessions
553    pub async fn cleanup_expired_sessions(&self) -> Vec<String> {
554        let mut expired = Vec::new();
555        let active = self.active_clients.read().await;
556
557        for (client_id, client) in active.iter() {
558            if client.session.is_expired() {
559                expired.push(client_id.clone());
560            }
561        }
562        drop(active);
563
564        for client_id in &expired {
565            warn!("Disconnecting expired session: {}", client_id);
566            self.disconnect(client_id).await;
567        }
568
569        expired
570    }
571
572    /// Assign a packet ID for outgoing QoS > 0 message
573    pub async fn assign_packet_id(&self, client_id: &str) -> Option<u16> {
574        let mut active = self.active_clients.write().await;
575        active.get_mut(client_id).map(|c| c.session.next_packet_id())
576    }
577
578    /// Get a client's subscriptions
579    pub async fn get_client_subscriptions(&self, client_id: &str) -> Vec<(String, QoS)> {
580        let active = self.active_clients.read().await;
581        if let Some(client) = active.get(client_id) {
582            client
583                .session
584                .subscriptions
585                .iter()
586                .map(|(filter, qos)| (filter.clone(), *qos))
587                .collect()
588        } else {
589            Vec::new()
590        }
591    }
592
593    /// Start QoS 2 inbound tracking
594    pub async fn start_qos2_inbound(&self, client_id: &str, packet_id: u16) {
595        let mut active = self.active_clients.write().await;
596        if let Some(client) = active.get_mut(client_id) {
597            client.session.start_qos2_inbound(packet_id);
598        }
599    }
600
601    /// Mark PUBREC sent for QoS 2 inbound
602    pub async fn mark_pubrec_sent(&self, client_id: &str, packet_id: u16) {
603        let mut active = self.active_clients.write().await;
604        if let Some(client) = active.get_mut(client_id) {
605            client.session.mark_pubrec_sent(packet_id);
606        }
607    }
608
609    /// Complete QoS 2 inbound (PUBCOMP sent)
610    pub async fn complete_qos2_inbound(&self, client_id: &str, packet_id: u16) {
611        let mut active = self.active_clients.write().await;
612        if let Some(client) = active.get_mut(client_id) {
613            client.session.complete_qos2_inbound(packet_id);
614        }
615    }
616}
617
618/// Helper to write a packet to a TCP stream
619pub async fn write_packet(
620    writer: &mut OwnedWriteHalf,
621    packet: &Packet,
622) -> Result<(), std::io::Error> {
623    let bytes = PacketEncoder::encode(packet)
624        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
625    writer.write_all(&bytes).await?;
626    writer.flush().await?;
627    Ok(())
628}
629
630/// Build a CONNACK response packet
631pub fn build_connack(session_present: bool, code: ConnackCode) -> Packet {
632    Packet::Connack(ConnackPacket {
633        session_present,
634        return_code: code,
635    })
636}
637
638/// Build a SUBACK response packet
639pub fn build_suback(packet_id: u16, return_codes: Vec<SubackReturnCode>) -> Packet {
640    Packet::Suback(SubackPacket {
641        packet_id,
642        return_codes,
643    })
644}
645
646/// Build an UNSUBACK response packet
647pub fn build_unsuback(packet_id: u16) -> Packet {
648    Packet::Unsuback(UnsubackPacket { packet_id })
649}
650
651/// Build a PUBACK response packet
652pub fn build_puback(packet_id: u16) -> Packet {
653    Packet::Puback(PubackPacket { packet_id })
654}
655
656/// Build a PUBREC response packet
657pub fn build_pubrec(packet_id: u16) -> Packet {
658    Packet::Pubrec(PubrecPacket { packet_id })
659}
660
661/// Build a PUBREL packet
662pub fn build_pubrel(packet_id: u16) -> Packet {
663    Packet::Pubrel(PubrelPacket { packet_id })
664}
665
666/// Build a PUBCOMP response packet
667pub fn build_pubcomp(packet_id: u16) -> Packet {
668    Packet::Pubcomp(PubcompPacket { packet_id })
669}
670
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh clean session with the 60-second keep-alive used throughout these tests.
    fn new_session(client_id: &str) -> ClientSession {
        ClientSession::new(client_id.to_string(), true, 60)
    }

    /// Bounded packet channel like the one a connection task hands to the manager.
    fn packet_channel() -> (ClientSender, mpsc::Receiver<Packet>) {
        mpsc::channel(10)
    }

    #[test]
    fn test_client_session_new() {
        let s = new_session("test-client");
        assert_eq!(s.client_id, "test-client");
        assert!(s.clean_session);
        assert_eq!(s.keep_alive, 60);
        assert!(s.subscriptions.is_empty());
    }

    #[test]
    fn test_client_session_packet_id() {
        let mut s = new_session("test");
        let issued: Vec<u16> = (0..3).map(|_| s.next_packet_id()).collect();
        assert_eq!(issued, vec![1, 2, 3]);
    }

    #[test]
    fn test_client_session_packet_id_wrap() {
        let mut s = new_session("test");
        s.next_packet_id = u16::MAX;
        assert_eq!(s.next_packet_id(), u16::MAX);
        // The counter wraps past the reserved value 0 straight to 1.
        assert_eq!(s.next_packet_id(), 1);
    }

    #[test]
    fn test_client_session_subscribe() {
        let mut s = new_session("test");
        for (filter, qos) in [("topic/a", QoS::AtLeastOnce), ("topic/b", QoS::ExactlyOnce)] {
            s.subscribe(filter.to_string(), qos);
        }

        assert_eq!(s.subscriptions.len(), 2);
        assert_eq!(s.subscriptions.get("topic/a"), Some(&QoS::AtLeastOnce));
    }

    #[test]
    fn test_client_session_unsubscribe() {
        let mut s = new_session("test");
        s.subscribe("topic/a".to_string(), QoS::AtLeastOnce);

        let removed_first = s.unsubscribe("topic/a");
        let removed_again = s.unsubscribe("topic/a");
        assert!(removed_first);
        assert!(!removed_again);
    }

    #[test]
    fn test_client_session_qos1_flow() {
        let mut s = new_session("test");
        let message = PendingMessage {
            packet_id: 100,
            topic: "test".to_string(),
            payload: vec![1, 2, 3],
            qos: QoS::AtLeastOnce,
            retain: false,
            timestamp: 0,
            retry_count: 0,
        };

        s.queue_qos1_message(100, message);
        assert!(s.pending_qos1_out.contains_key(&100));

        assert!(s.handle_puback(100).is_some());
        assert!(!s.pending_qos1_out.contains_key(&100));
    }

    #[test]
    fn test_client_session_qos2_outbound_flow() {
        let mut s = new_session("test");
        let state_of = |sess: &ClientSession| sess.pending_qos2_out.get(&200).copied();

        // PUBLISH sent
        s.start_qos2_outbound(200);
        assert_eq!(state_of(&s), Some(Qos2OutboundState::WaitingPubrec));

        // PUBREC received
        assert!(s.handle_pubrec(200));
        assert_eq!(state_of(&s), Some(Qos2OutboundState::WaitingPubcomp));

        // PUBCOMP received
        assert!(s.handle_pubcomp(200));
        assert_eq!(state_of(&s), None);
    }

    #[test]
    fn test_client_session_qos2_inbound_flow() {
        let mut s = new_session("test");
        let state_of = |sess: &ClientSession| sess.pending_qos2_in.get(&300).copied();

        // PUBLISH received
        s.start_qos2_inbound(300);
        assert!(state_of(&s).is_some());

        // PUBREC sent
        s.mark_pubrec_sent(300);
        assert_eq!(state_of(&s), Some(Qos2State::WaitingPubrel));

        // PUBREL received
        assert!(s.handle_pubrel(300));
        assert_eq!(state_of(&s), Some(Qos2State::PendingPubcomp));

        // PUBCOMP sent
        s.complete_qos2_inbound(300);
        assert_eq!(state_of(&s), None);
    }

    #[tokio::test]
    async fn test_session_manager_connect() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = packet_channel();

        let (session_present, code) = manager
            .connect("client-1".to_string(), true, 60, tx)
            .await
            .expect("connect below the limit should succeed");
        assert!(!session_present);
        assert_eq!(code, ConnackCode::Accepted);

        assert_eq!(manager.connection_count().await, 1);
    }

    #[tokio::test]
    async fn test_session_manager_disconnect() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = packet_channel();

        manager.connect("client-1".to_string(), true, 60, tx).await.unwrap();
        manager.disconnect("client-1").await;

        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_manager_persistent_session() {
        let manager = SessionManager::new(100, None);
        let client_id = "client-1";

        // First connection keeps its session (clean_session = false) and subscribes.
        let (tx1, _rx1) = packet_channel();
        manager.connect(client_id.to_string(), false, 60, tx1).await.unwrap();
        manager
            .subscribe(client_id, vec![("topic/a".to_string(), QoS::AtLeastOnce)])
            .await;
        manager.disconnect(client_id).await;

        // Reconnecting with the same ID must find the stored session.
        let (tx2, _rx2) = packet_channel();
        let (session_present, _) =
            manager.connect(client_id.to_string(), false, 60, tx2).await.unwrap();
        assert!(session_present);
    }

    #[tokio::test]
    async fn test_session_manager_subscribe() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = packet_channel();
        manager.connect("client-1".to_string(), true, 60, tx).await.unwrap();

        let requested = vec![
            ("topic/a".to_string(), QoS::AtMostOnce),
            ("topic/b".to_string(), QoS::AtLeastOnce),
        ];
        let codes = manager.subscribe("client-1", requested).await.expect("client-1 is connected");

        assert_eq!(codes, vec![SubackReturnCode::SuccessQoS0, SubackReturnCode::SuccessQoS1]);
    }

    #[tokio::test]
    async fn test_session_manager_unsubscribe() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = packet_channel();
        manager.connect("client-1".to_string(), true, 60, tx).await.unwrap();

        manager
            .subscribe("client-1", vec![("topic/a".to_string(), QoS::AtMostOnce)])
            .await;

        assert!(manager.unsubscribe("client-1", vec!["topic/a".to_string()]).await);
    }

    #[tokio::test]
    async fn test_session_manager_max_connections() {
        let manager = SessionManager::new(2, None);
        let (tx1, _rx1) = packet_channel();
        let (tx2, _rx2) = packet_channel();
        let (tx3, _rx3) = packet_channel();

        for (client_id, tx) in [("client-1", tx1), ("client-2", tx2)] {
            manager.connect(client_id.to_string(), true, 60, tx).await.unwrap();
        }

        let rejected = manager.connect("client-3".to_string(), true, 60, tx3).await;
        assert_eq!(rejected, Err(ConnackCode::ServerUnavailable));
    }

    #[test]
    fn test_build_connack() {
        match build_connack(true, ConnackCode::Accepted) {
            Packet::Connack(connack) => {
                assert!(connack.session_present);
                assert_eq!(connack.return_code, ConnackCode::Accepted);
            }
            _ => panic!("Expected Connack packet"),
        }
    }

    #[test]
    fn test_build_suback() {
        let codes = vec![SubackReturnCode::SuccessQoS0, SubackReturnCode::SuccessQoS1];
        match build_suback(100, codes) {
            Packet::Suback(suback) => {
                assert_eq!(suback.packet_id, 100);
                assert_eq!(suback.return_codes.len(), 2);
            }
            _ => panic!("Expected Suback packet"),
        }
    }

    #[test]
    fn test_suback_return_code_success() {
        let expectations = [
            (QoS::AtMostOnce, SubackReturnCode::SuccessQoS0),
            (QoS::AtLeastOnce, SubackReturnCode::SuccessQoS1),
            (QoS::ExactlyOnce, SubackReturnCode::SuccessQoS2),
        ];
        for (qos, expected) in expectations {
            assert_eq!(SubackReturnCode::success(qos), expected);
        }
    }
}