1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use tokio::io::AsyncWriteExt;
11use tokio::net::tcp::OwnedWriteHalf;
12use tokio::sync::{mpsc, RwLock};
13use tracing::{debug, info, warn};
14
15use crate::metrics::MqttMetrics;
16use crate::protocol::{
17 ConnackCode, ConnackPacket, Packet, PacketEncoder, PubackPacket, PubcompPacket, PublishPacket,
18 PubrecPacket, PubrelPacket, QoS, SubackPacket, SubackReturnCode, UnsubackPacket,
19};
20use crate::topics::TopicTree;
21
/// Outbound QoS 1 message held in `ClientSession::pending_qos1_out` until the
/// client acknowledges it with PUBACK (see `ClientSession::handle_puback`).
#[derive(Debug, Clone)]
pub struct PendingMessage {
    /// Packet identifier of this delivery; also the key in `pending_qos1_out`.
    pub packet_id: u16,
    /// Topic the message was published on.
    pub topic: String,
    /// Raw message payload.
    pub payload: Vec<u8>,
    /// QoS the message is delivered with.
    pub qos: QoS,
    /// RETAIN flag of the delivered PUBLISH.
    pub retain: bool,
    /// Creation time; presumably Unix seconds like the session timestamps — confirm.
    pub timestamp: u64,
    /// Number of redelivery attempts (nothing in this file increments it).
    pub retry_count: u8,
}
33
/// Receiver-side state of an inbound (client → broker) QoS 2 exchange.
///
/// Transitions: `start_qos2_inbound` → `PendingPubrec`, `mark_pubrec_sent` →
/// `WaitingPubrel`, `handle_pubrel` → `PendingPubcomp`, `complete_qos2_inbound`
/// removes the entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Qos2State {
    /// PUBLISH received; PUBREC not yet sent.
    PendingPubrec,
    /// PUBREC sent; waiting for the client's PUBREL.
    WaitingPubrel,
    /// PUBREL received; PUBCOMP not yet sent.
    PendingPubcomp,
}
44
/// Sender-side state of an outbound (broker → client) QoS 2 exchange.
///
/// Transitions: `start_qos2_outbound` → `WaitingPubrec`, `handle_pubrec` →
/// `WaitingPubcomp`, `handle_pubcomp` removes the entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Qos2OutboundState {
    /// PUBLISH sent; waiting for the client's PUBREC.
    WaitingPubrec,
    /// PUBREC received; waiting for the client's PUBCOMP.
    WaitingPubcomp,
}
53
54#[derive(Debug)]
56pub struct ClientSession {
57 pub client_id: String,
59 pub clean_session: bool,
61 pub keep_alive: u16,
63 pub subscriptions: HashMap<String, QoS>,
65 pub pending_qos1_out: HashMap<u16, PendingMessage>,
67 pub pending_qos2_out: HashMap<u16, Qos2OutboundState>,
69 pub pending_qos2_in: HashMap<u16, Qos2State>,
71 pub last_activity: u64,
73 pub connected_at: u64,
75 next_packet_id: u16,
77 pub username: Option<String>,
79}
80
81impl ClientSession {
82 pub fn new(client_id: String, clean_session: bool, keep_alive: u16) -> Self {
84 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
85
86 Self {
87 client_id,
88 clean_session,
89 keep_alive,
90 subscriptions: HashMap::new(),
91 pending_qos1_out: HashMap::new(),
92 pending_qos2_out: HashMap::new(),
93 pending_qos2_in: HashMap::new(),
94 last_activity: now,
95 connected_at: now,
96 next_packet_id: 1,
97 username: None,
98 }
99 }
100
101 pub fn next_packet_id(&mut self) -> u16 {
103 let id = self.next_packet_id;
104 self.next_packet_id = self.next_packet_id.wrapping_add(1);
105 if self.next_packet_id == 0 {
106 self.next_packet_id = 1; }
108 id
109 }
110
111 pub fn touch(&mut self) {
113 self.last_activity =
114 SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
115 }
116
117 pub fn is_expired(&self) -> bool {
119 if self.keep_alive == 0 {
120 return false; }
122
123 let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_secs();
124
125 let timeout = (self.keep_alive as u64) * 3 / 2;
127 now - self.last_activity > timeout
128 }
129
130 pub fn subscribe(&mut self, topic_filter: String, qos: QoS) {
132 self.subscriptions.insert(topic_filter, qos);
133 }
134
135 pub fn unsubscribe(&mut self, topic_filter: &str) -> bool {
137 self.subscriptions.remove(topic_filter).is_some()
138 }
139
140 pub fn queue_qos1_message(&mut self, packet_id: u16, message: PendingMessage) {
142 self.pending_qos1_out.insert(packet_id, message);
143 }
144
145 pub fn handle_puback(&mut self, packet_id: u16) -> Option<PendingMessage> {
147 self.pending_qos1_out.remove(&packet_id)
148 }
149
150 pub fn start_qos2_outbound(&mut self, packet_id: u16) {
152 self.pending_qos2_out.insert(packet_id, Qos2OutboundState::WaitingPubrec);
153 }
154
155 pub fn handle_pubrec(&mut self, packet_id: u16) -> bool {
157 if let Some(state) = self.pending_qos2_out.get_mut(&packet_id) {
158 if *state == Qos2OutboundState::WaitingPubrec {
159 *state = Qos2OutboundState::WaitingPubcomp;
160 return true;
161 }
162 }
163 false
164 }
165
166 pub fn handle_pubcomp(&mut self, packet_id: u16) -> bool {
168 if let Some(state) = self.pending_qos2_out.get(&packet_id) {
169 if *state == Qos2OutboundState::WaitingPubcomp {
170 self.pending_qos2_out.remove(&packet_id);
171 return true;
172 }
173 }
174 false
175 }
176
177 pub fn start_qos2_inbound(&mut self, packet_id: u16) {
179 self.pending_qos2_in.insert(packet_id, Qos2State::PendingPubrec);
180 }
181
182 pub fn handle_pubrel(&mut self, packet_id: u16) -> bool {
184 if let Some(state) = self.pending_qos2_in.get_mut(&packet_id) {
185 if *state == Qos2State::WaitingPubrel {
186 *state = Qos2State::PendingPubcomp;
187 return true;
188 }
189 }
190 false
191 }
192
193 pub fn complete_qos2_inbound(&mut self, packet_id: u16) {
195 self.pending_qos2_in.remove(&packet_id);
196 }
197
198 pub fn mark_pubrec_sent(&mut self, packet_id: u16) {
200 if let Some(state) = self.pending_qos2_in.get_mut(&packet_id) {
201 if *state == Qos2State::PendingPubrec {
202 *state = Qos2State::WaitingPubrel;
203 }
204 }
205 }
206}
207
/// Sender half of a bounded tokio channel used to queue packets for one connected client.
pub type ClientSender = mpsc::Sender<Packet>;
210
/// A currently connected client: its session plus the channel for sending it packets.
pub struct ActiveClient {
    /// Session state for this connection.
    pub session: ClientSession,
    /// Channel used to queue outbound packets for this connection.
    pub sender: ClientSender,
}
218
219pub struct SessionManager {
221 active_clients: RwLock<HashMap<String, ActiveClient>>,
223 persistent_sessions: RwLock<HashMap<String, ClientSession>>,
225 topics: RwLock<TopicTree>,
227 metrics: Option<Arc<MqttMetrics>>,
229 max_connections: usize,
231}
232
233impl SessionManager {
234 pub fn new(max_connections: usize, metrics: Option<Arc<MqttMetrics>>) -> Self {
236 Self {
237 active_clients: RwLock::new(HashMap::new()),
238 persistent_sessions: RwLock::new(HashMap::new()),
239 topics: RwLock::new(TopicTree::new()),
240 metrics,
241 max_connections,
242 }
243 }
244
245 pub async fn connect(
247 &self,
248 client_id: String,
249 clean_session: bool,
250 keep_alive: u16,
251 sender: ClientSender,
252 ) -> Result<(bool, ConnackCode), ConnackCode> {
253 let active = self.active_clients.read().await;
254 if active.len() >= self.max_connections {
255 return Err(ConnackCode::ServerUnavailable);
256 }
257 drop(active);
258
259 let mut active = self.active_clients.write().await;
261 if let Some(existing) = active.remove(&client_id) {
262 info!("Disconnecting existing client {} for new connection", client_id);
264 let _ = existing.sender.send(Packet::Disconnect).await;
265
266 if let Some(metrics) = &self.metrics {
267 metrics.record_connection_closed();
268 }
269 }
270
271 let mut persistent = self.persistent_sessions.write().await;
273 let (session, session_present) = if clean_session {
274 persistent.remove(&client_id);
276 let mut topics = self.topics.write().await;
278 if let Some(old_session) = persistent.get(&client_id) {
279 for filter in old_session.subscriptions.keys() {
280 topics.unsubscribe(filter, &client_id);
281 }
282 }
283 (ClientSession::new(client_id.clone(), true, keep_alive), false)
284 } else if let Some(mut session) = persistent.remove(&client_id) {
285 session.keep_alive = keep_alive;
287 session.touch();
288 (session, true)
289 } else {
290 (ClientSession::new(client_id.clone(), false, keep_alive), false)
291 };
292
293 active.insert(client_id.clone(), ActiveClient { session, sender });
294
295 if let Some(metrics) = &self.metrics {
296 metrics.record_connection();
297 }
298
299 info!(
300 "Client {} connected (clean_session={}, session_present={})",
301 client_id, clean_session, session_present
302 );
303
304 Ok((session_present, ConnackCode::Accepted))
305 }
306
307 pub async fn disconnect(&self, client_id: &str) {
309 let mut active = self.active_clients.write().await;
310 if let Some(client) = active.remove(client_id) {
311 if !client.session.clean_session {
312 let mut persistent = self.persistent_sessions.write().await;
314 persistent.insert(client_id.to_string(), client.session);
315 info!("Persisted session for client {}", client_id);
316 } else {
317 let mut topics = self.topics.write().await;
319 for filter in client.session.subscriptions.keys() {
320 topics.unsubscribe(filter, client_id);
321
322 if let Some(metrics) = &self.metrics {
323 metrics.record_unsubscription();
324 }
325 }
326 info!("Cleaned up session for client {}", client_id);
327 }
328
329 if let Some(metrics) = &self.metrics {
330 metrics.record_connection_closed();
331 }
332 }
333 }
334
335 pub async fn subscribe(
337 &self,
338 client_id: &str,
339 subscriptions: Vec<(String, QoS)>,
340 ) -> Option<Vec<SubackReturnCode>> {
341 let mut active = self.active_clients.write().await;
342 let client = active.get_mut(client_id)?;
343
344 let mut topics = self.topics.write().await;
345 let mut return_codes = Vec::new();
346
347 for (filter, requested_qos) in subscriptions {
348 topics.subscribe(&filter, requested_qos as u8, client_id);
350
351 client.session.subscribe(filter.clone(), requested_qos);
353
354 return_codes.push(SubackReturnCode::success(requested_qos));
356
357 if let Some(metrics) = &self.metrics {
358 metrics.record_subscription();
359 }
360
361 debug!("Client {} subscribed to {} with QoS {:?}", client_id, filter, requested_qos);
362 }
363
364 Some(return_codes)
365 }
366
367 pub async fn unsubscribe(&self, client_id: &str, topic_filters: Vec<String>) -> bool {
369 let mut active = self.active_clients.write().await;
370 let client = active.get_mut(client_id);
371
372 if client.is_none() {
373 return false;
374 }
375
376 let client = client.unwrap();
377 let mut topics = self.topics.write().await;
378
379 for filter in topic_filters {
380 topics.unsubscribe(&filter, client_id);
381 client.session.unsubscribe(&filter);
382
383 if let Some(metrics) = &self.metrics {
384 metrics.record_unsubscription();
385 }
386
387 debug!("Client {} unsubscribed from {}", client_id, filter);
388 }
389
390 true
391 }
392
393 pub async fn publish(&self, publisher_id: &str, publish: &PublishPacket) {
395 {
397 let mut active = self.active_clients.write().await;
398 if let Some(client) = active.get_mut(publisher_id) {
399 client.session.touch();
400 }
401 }
402
403 if let Some(metrics) = &self.metrics {
404 metrics.record_publish(publish.qos as u8);
405 }
406
407 if publish.retain {
409 let mut topics = self.topics.write().await;
410 topics.retain_message(&publish.topic, publish.payload.clone(), publish.qos as u8);
411
412 if let Some(metrics) = &self.metrics {
413 metrics.record_retained_message();
414 }
415 }
416
417 let topics = self.topics.read().await;
419 let subscribers = topics.match_topic(&publish.topic);
420
421 let active = self.active_clients.read().await;
423 for sub in subscribers {
424 if sub.client_id == publisher_id {
425 continue; }
427
428 if let Some(client) = active.get(&sub.client_id) {
429 let delivery_qos = std::cmp::min(publish.qos as u8, sub.qos);
431 let delivery_qos = QoS::try_from(delivery_qos).unwrap_or(QoS::AtMostOnce);
432
433 let packet = Packet::Publish(PublishPacket {
434 dup: false,
435 qos: delivery_qos,
436 retain: false, topic: publish.topic.clone(),
438 packet_id: if delivery_qos != QoS::AtMostOnce {
439 Some(0) } else {
441 None
442 },
443 payload: publish.payload.clone(),
444 });
445
446 if client.sender.send(packet).await.is_ok() {
447 if let Some(metrics) = &self.metrics {
448 metrics.record_delivery();
449 }
450 debug!("Delivered message to {} on topic {}", sub.client_id, publish.topic);
451 }
452 }
453 }
454 }
455
456 pub async fn handle_puback(&self, client_id: &str, packet_id: u16) {
458 let mut active = self.active_clients.write().await;
459 if let Some(client) = active.get_mut(client_id) {
460 client.session.touch();
461 if client.session.handle_puback(packet_id).is_some() {
462 debug!("QoS 1 delivery confirmed for client {}, packet {}", client_id, packet_id);
463 }
464 }
465 }
466
467 pub async fn handle_pubrec(&self, client_id: &str, packet_id: u16) -> bool {
469 let mut active = self.active_clients.write().await;
470 if let Some(client) = active.get_mut(client_id) {
471 client.session.touch();
472 if client.session.handle_pubrec(packet_id) {
473 debug!("QoS 2 PUBREC received for client {}, packet {}", client_id, packet_id);
474 return true;
475 }
476 }
477 false
478 }
479
480 pub async fn handle_pubrel(&self, client_id: &str, packet_id: u16) -> bool {
482 let mut active = self.active_clients.write().await;
483 if let Some(client) = active.get_mut(client_id) {
484 client.session.touch();
485 if client.session.handle_pubrel(packet_id) {
486 debug!("QoS 2 PUBREL received for client {}, packet {}", client_id, packet_id);
487 return true;
488 }
489 }
490 false
491 }
492
493 pub async fn handle_pubcomp(&self, client_id: &str, packet_id: u16) {
495 let mut active = self.active_clients.write().await;
496 if let Some(client) = active.get_mut(client_id) {
497 client.session.touch();
498 if client.session.handle_pubcomp(packet_id) {
499 debug!("QoS 2 delivery completed for client {}, packet {}", client_id, packet_id);
500 }
501 }
502 }
503
504 pub async fn touch(&self, client_id: &str) {
506 let mut active = self.active_clients.write().await;
507 if let Some(client) = active.get_mut(client_id) {
508 client.session.touch();
509 }
510 }
511
512 pub async fn get_retained_messages(&self, filter: &str) -> Vec<(String, PublishPacket)> {
514 let topics = self.topics.read().await;
515 topics
516 .get_retained_for_filter(filter)
517 .into_iter()
518 .map(|(topic, msg)| {
519 (
520 topic.to_string(),
521 PublishPacket {
522 dup: false,
523 qos: QoS::try_from(msg.qos).unwrap_or(QoS::AtMostOnce),
524 retain: true,
525 topic: topic.to_string(),
526 packet_id: None,
527 payload: msg.payload.clone(),
528 },
529 )
530 })
531 .collect()
532 }
533
534 pub async fn get_sender(&self, client_id: &str) -> Option<ClientSender> {
536 let active = self.active_clients.read().await;
537 active.get(client_id).map(|c| c.sender.clone())
538 }
539
540 pub async fn get_connected_clients(&self) -> Vec<String> {
542 let active = self.active_clients.read().await;
543 active.keys().cloned().collect()
544 }
545
546 pub async fn connection_count(&self) -> usize {
548 let active = self.active_clients.read().await;
549 active.len()
550 }
551
552 pub async fn cleanup_expired_sessions(&self) -> Vec<String> {
554 let mut expired = Vec::new();
555 let active = self.active_clients.read().await;
556
557 for (client_id, client) in active.iter() {
558 if client.session.is_expired() {
559 expired.push(client_id.clone());
560 }
561 }
562 drop(active);
563
564 for client_id in &expired {
565 warn!("Disconnecting expired session: {}", client_id);
566 self.disconnect(client_id).await;
567 }
568
569 expired
570 }
571
572 pub async fn assign_packet_id(&self, client_id: &str) -> Option<u16> {
574 let mut active = self.active_clients.write().await;
575 active.get_mut(client_id).map(|c| c.session.next_packet_id())
576 }
577
578 pub async fn get_client_subscriptions(&self, client_id: &str) -> Vec<(String, QoS)> {
580 let active = self.active_clients.read().await;
581 if let Some(client) = active.get(client_id) {
582 client
583 .session
584 .subscriptions
585 .iter()
586 .map(|(filter, qos)| (filter.clone(), *qos))
587 .collect()
588 } else {
589 Vec::new()
590 }
591 }
592
593 pub async fn start_qos2_inbound(&self, client_id: &str, packet_id: u16) {
595 let mut active = self.active_clients.write().await;
596 if let Some(client) = active.get_mut(client_id) {
597 client.session.start_qos2_inbound(packet_id);
598 }
599 }
600
601 pub async fn mark_pubrec_sent(&self, client_id: &str, packet_id: u16) {
603 let mut active = self.active_clients.write().await;
604 if let Some(client) = active.get_mut(client_id) {
605 client.session.mark_pubrec_sent(packet_id);
606 }
607 }
608
609 pub async fn complete_qos2_inbound(&self, client_id: &str, packet_id: u16) {
611 let mut active = self.active_clients.write().await;
612 if let Some(client) = active.get_mut(client_id) {
613 client.session.complete_qos2_inbound(packet_id);
614 }
615 }
616}
617
618pub async fn write_packet(
620 writer: &mut OwnedWriteHalf,
621 packet: &Packet,
622) -> Result<(), std::io::Error> {
623 let bytes = PacketEncoder::encode(packet)
624 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
625 writer.write_all(&bytes).await?;
626 writer.flush().await?;
627 Ok(())
628}
629
630pub fn build_connack(session_present: bool, code: ConnackCode) -> Packet {
632 Packet::Connack(ConnackPacket {
633 session_present,
634 return_code: code,
635 })
636}
637
638pub fn build_suback(packet_id: u16, return_codes: Vec<SubackReturnCode>) -> Packet {
640 Packet::Suback(SubackPacket {
641 packet_id,
642 return_codes,
643 })
644}
645
646pub fn build_unsuback(packet_id: u16) -> Packet {
648 Packet::Unsuback(UnsubackPacket { packet_id })
649}
650
651pub fn build_puback(packet_id: u16) -> Packet {
653 Packet::Puback(PubackPacket { packet_id })
654}
655
656pub fn build_pubrec(packet_id: u16) -> Packet {
658 Packet::Pubrec(PubrecPacket { packet_id })
659}
660
661pub fn build_pubrel(packet_id: u16) -> Packet {
663 Packet::Pubrel(PubrelPacket { packet_id })
664}
665
666pub fn build_pubcomp(packet_id: u16) -> Packet {
668 Packet::Pubcomp(PubcompPacket { packet_id })
669}
670
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh clean session with a 60 s keep-alive, shared by the unit tests below.
    fn fresh_session(id: &str) -> ClientSession {
        ClientSession::new(id.to_string(), true, 60)
    }

    #[test]
    fn test_client_session_new() {
        let s = fresh_session("test-client");
        assert_eq!(s.client_id, "test-client");
        assert!(s.clean_session);
        assert_eq!(s.keep_alive, 60);
        assert!(s.subscriptions.is_empty());
    }

    #[test]
    fn test_client_session_packet_id() {
        let mut s = fresh_session("test");
        for expected in 1..=3u16 {
            assert_eq!(s.next_packet_id(), expected);
        }
    }

    #[test]
    fn test_client_session_packet_id_wrap() {
        let mut s = fresh_session("test");
        s.next_packet_id = u16::MAX;
        assert_eq!(s.next_packet_id(), u16::MAX);
        // 0 is never handed out: the counter wraps straight to 1.
        assert_eq!(s.next_packet_id(), 1);
    }

    #[test]
    fn test_client_session_subscribe() {
        let mut s = fresh_session("test");
        s.subscribe("topic/a".to_string(), QoS::AtLeastOnce);
        s.subscribe("topic/b".to_string(), QoS::ExactlyOnce);

        assert_eq!(s.subscriptions.len(), 2);
        assert_eq!(s.subscriptions.get("topic/a"), Some(&QoS::AtLeastOnce));
    }

    #[test]
    fn test_client_session_unsubscribe() {
        let mut s = fresh_session("test");
        s.subscribe("topic/a".to_string(), QoS::AtLeastOnce);
        assert!(s.unsubscribe("topic/a"));
        // Removing it a second time finds nothing.
        assert!(!s.unsubscribe("topic/a"));
    }

    #[test]
    fn test_client_session_qos1_flow() {
        let mut s = fresh_session("test");

        s.queue_qos1_message(
            100,
            PendingMessage {
                packet_id: 100,
                topic: "test".to_string(),
                payload: vec![1, 2, 3],
                qos: QoS::AtLeastOnce,
                retain: false,
                timestamp: 0,
                retry_count: 0,
            },
        );
        assert!(s.pending_qos1_out.contains_key(&100));

        assert!(s.handle_puback(100).is_some());
        assert!(!s.pending_qos1_out.contains_key(&100));
    }

    #[test]
    fn test_client_session_qos2_outbound_flow() {
        let mut s = fresh_session("test");

        // PUBLISH sent: waiting for PUBREC.
        s.start_qos2_outbound(200);
        assert!(s.pending_qos2_out.contains_key(&200));
        assert_eq!(s.pending_qos2_out.get(&200), Some(&Qos2OutboundState::WaitingPubrec));

        // PUBREC received: waiting for PUBCOMP.
        assert!(s.handle_pubrec(200));
        assert_eq!(s.pending_qos2_out.get(&200), Some(&Qos2OutboundState::WaitingPubcomp));

        // PUBCOMP received: exchange finished.
        assert!(s.handle_pubcomp(200));
        assert!(!s.pending_qos2_out.contains_key(&200));
    }

    #[test]
    fn test_client_session_qos2_inbound_flow() {
        let mut s = fresh_session("test");

        // PUBLISH received.
        s.start_qos2_inbound(300);
        assert!(s.pending_qos2_in.contains_key(&300));

        // PUBREC sent.
        s.mark_pubrec_sent(300);
        assert_eq!(s.pending_qos2_in.get(&300), Some(&Qos2State::WaitingPubrel));

        // PUBREL received.
        assert!(s.handle_pubrel(300));
        assert_eq!(s.pending_qos2_in.get(&300), Some(&Qos2State::PendingPubcomp));

        // PUBCOMP sent.
        s.complete_qos2_inbound(300);
        assert!(!s.pending_qos2_in.contains_key(&300));
    }

    #[tokio::test]
    async fn test_session_manager_connect() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = mpsc::channel(10);

        let outcome = manager.connect("client-1".to_string(), true, 60, tx).await;
        assert_eq!(outcome, Ok((false, ConnackCode::Accepted)));

        assert_eq!(manager.connection_count().await, 1);
    }

    #[tokio::test]
    async fn test_session_manager_disconnect() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = mpsc::channel(10);

        manager.connect("client-1".to_string(), true, 60, tx).await.unwrap();
        manager.disconnect("client-1").await;

        assert_eq!(manager.connection_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_manager_persistent_session() {
        let manager = SessionManager::new(100, None);

        // First connection with a persistent session and one subscription.
        let (first_tx, _first_rx) = mpsc::channel(10);
        manager.connect("client-1".to_string(), false, 60, first_tx).await.unwrap();
        let _ = manager
            .subscribe("client-1", vec![("topic/a".to_string(), QoS::AtLeastOnce)])
            .await;
        manager.disconnect("client-1").await;

        // Reconnecting without clean session resumes the stored session.
        let (second_tx, _second_rx) = mpsc::channel(10);
        let (session_present, _) = manager
            .connect("client-1".to_string(), false, 60, second_tx)
            .await
            .unwrap();
        assert!(session_present);
    }

    #[tokio::test]
    async fn test_session_manager_subscribe() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = mpsc::channel(10);

        manager.connect("client-1".to_string(), true, 60, tx).await.unwrap();

        let requested = vec![
            ("topic/a".to_string(), QoS::AtMostOnce),
            ("topic/b".to_string(), QoS::AtLeastOnce),
        ];
        let codes = manager
            .subscribe("client-1", requested)
            .await
            .expect("client-1 is connected");

        assert_eq!(codes, vec![SubackReturnCode::SuccessQoS0, SubackReturnCode::SuccessQoS1]);
    }

    #[tokio::test]
    async fn test_session_manager_unsubscribe() {
        let manager = SessionManager::new(100, None);
        let (tx, _rx) = mpsc::channel(10);

        manager.connect("client-1".to_string(), true, 60, tx).await.unwrap();
        let _ = manager
            .subscribe("client-1", vec![("topic/a".to_string(), QoS::AtMostOnce)])
            .await;

        assert!(manager.unsubscribe("client-1", vec!["topic/a".to_string()]).await);
    }

    #[tokio::test]
    async fn test_session_manager_max_connections() {
        let manager = SessionManager::new(2, None);

        // Keep every receiver alive until the end of the test.
        let mut receivers = Vec::new();
        for id in ["client-1", "client-2"] {
            let (tx, rx) = mpsc::channel(10);
            receivers.push(rx);
            manager.connect(id.to_string(), true, 60, tx).await.unwrap();
        }

        let (overflow_tx, _overflow_rx) = mpsc::channel(10);
        let rejected = manager.connect("client-3".to_string(), true, 60, overflow_tx).await;
        assert_eq!(rejected, Err(ConnackCode::ServerUnavailable));
    }

    #[test]
    fn test_build_connack() {
        match build_connack(true, ConnackCode::Accepted) {
            Packet::Connack(connack) => {
                assert!(connack.session_present);
                assert_eq!(connack.return_code, ConnackCode::Accepted);
            }
            _ => panic!("Expected Connack packet"),
        }
    }

    #[test]
    fn test_build_suback() {
        let codes = vec![SubackReturnCode::SuccessQoS0, SubackReturnCode::SuccessQoS1];
        match build_suback(100, codes) {
            Packet::Suback(suback) => {
                assert_eq!(suback.packet_id, 100);
                assert_eq!(suback.return_codes.len(), 2);
            }
            _ => panic!("Expected Suback packet"),
        }
    }

    #[test]
    fn test_suback_return_code_success() {
        let cases = [
            (QoS::AtMostOnce, SubackReturnCode::SuccessQoS0),
            (QoS::AtLeastOnce, SubackReturnCode::SuccessQoS1),
            (QoS::ExactlyOnce, SubackReturnCode::SuccessQoS2),
        ];
        for (qos, expected) in cases {
            assert_eq!(SubackReturnCode::success(qos), expected);
        }
    }
}