mockforge_mqtt/
qos.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use tracing::{info, warn};
5
/// Module-local result type: the error side is a boxed `std::error::Error`
/// trait object, so any concrete error type can be propagated with `?`.
/// Note that the boxed error is not required to be `Send + Sync`.
type Result<T> = std::result::Result<T, std::boxed::Box<dyn std::error::Error>>;
7
/// MQTT Quality of Service level. Each discriminant equals the numeric QoS
/// value used on the wire (0, 1 or 2).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QoS {
    /// QoS 0: at most once, no acknowledgment.
    AtMostOnce = 0,
    /// QoS 1: at least once, acknowledged with PUBACK.
    AtLeastOnce = 1,
    /// QoS 2: exactly once, via the PUBREC / PUBREL / PUBCOMP handshake.
    ExactlyOnce = 2,
}
15
16impl QoS {
17    pub fn from_u8(value: u8) -> Option<Self> {
18        match value {
19            0 => Some(QoS::AtMostOnce),
20            1 => Some(QoS::AtLeastOnce),
21            2 => Some(QoS::ExactlyOnce),
22            _ => None,
23        }
24    }
25
26    pub fn as_u8(&self) -> u8 {
27        *self as u8
28    }
29}
30
31/// Message state for QoS handling
32#[derive(Debug, Clone)]
33pub struct MessageState {
34    pub packet_id: u16,
35    pub topic: String,
36    pub payload: Vec<u8>,
37    pub qos: QoS,
38    pub retained: bool,
39    pub timestamp: u64,
40}
41
42/// QoS 1 message awaiting acknowledgment
43#[derive(Debug, Clone)]
44struct PendingQoS1Message {
45    #[allow(dead_code)]
46    message: MessageState,
47    client_id: String,
48    retry_count: u8,
49}
50
/// Progress of a QoS 2 (exactly-once) exchange, tracked per packet id by
/// `QoSHandler`.
#[derive(Debug, Clone)]
enum QoS2State {
    /// A QoS 2 PUBLISH was accepted and PUBREC has been sent.
    /// Set by `QoSHandler::handle_qo_s2`. The peer's PUBREL is expected next.
    Received,
    /// A PUBREC was received and PUBREL has been sent in reply.
    /// Set by `QoSHandler::handle_pubrec`. Per MQTT, the peer answers with
    /// PUBCOMP.
    Released,
}
57
58/// QoS handler for managing message delivery guarantees
59pub struct QoSHandler {
60    qos1_pending: Arc<RwLock<HashMap<u16, PendingQoS1Message>>>,
61    qos2_states: Arc<RwLock<HashMap<u16, QoS2State>>>,
62    max_retries: u8,
63}
64
65impl Default for QoSHandler {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl QoSHandler {
72    pub fn new() -> Self {
73        Self {
74            qos1_pending: Arc::new(RwLock::new(HashMap::new())),
75            qos2_states: Arc::new(RwLock::new(HashMap::new())),
76            max_retries: 3,
77        }
78    }
79
80    /// Handle QoS 0: At most once delivery
81    pub async fn handle_qo_s0(&self, _message: MessageState) -> Result<()> {
82        info!("QoS 0: Fire and forget delivery");
83        // QoS 0 - no acknowledgment needed
84        Ok(())
85    }
86
87    /// Handle QoS 1: At least once delivery
88    pub async fn handle_qo_s1(&self, message: MessageState, client_id: &str) -> Result<()> {
89        info!(
90            "QoS 1: Storing message for at-least-once delivery, packet {}",
91            message.packet_id
92        );
93
94        let pending = PendingQoS1Message {
95            message: message.clone(),
96            client_id: client_id.to_string(),
97            retry_count: 0,
98        };
99
100        self.qos1_pending.write().await.insert(message.packet_id, pending);
101
102        // Send PUBACK to client
103        self.send_puback(client_id, message.packet_id).await?;
104
105        Ok(())
106    }
107
108    /// Handle QoS 2: Exactly once delivery
109    pub async fn handle_qo_s2(&self, message: MessageState, client_id: &str) -> Result<()> {
110        info!("QoS 2: Starting exactly-once delivery handshake, packet {}", message.packet_id);
111
112        // Store the message state
113        self.qos2_states.write().await.insert(message.packet_id, QoS2State::Received);
114
115        // Send PUBREC to client
116        self.send_pubrec(client_id, message.packet_id).await?;
117
118        Ok(())
119    }
120
121    /// Handle PUBACK (QoS 1 acknowledgment)
122    pub async fn handle_puback(&self, packet_id: u16) -> Result<()> {
123        if let Some(_pending) = self.qos1_pending.write().await.remove(&packet_id) {
124            info!("QoS 1: Received PUBACK for packet {}, delivery confirmed", packet_id);
125        } else {
126            warn!("QoS 1: Received PUBACK for unknown packet {}", packet_id);
127        }
128        Ok(())
129    }
130
131    /// Handle PUBREC (QoS 2 first acknowledgment)
132    pub async fn handle_pubrec(&self, packet_id: u16, client_id: &str) -> Result<()> {
133        let mut states = self.qos2_states.write().await;
134        if let Some(state) = states.get_mut(&packet_id) {
135            match state {
136                QoS2State::Received => {
137                    *state = QoS2State::Released;
138                    info!("QoS 2: Received PUBREC for packet {}, sending PUBREL", packet_id);
139                    // Send PUBREL to client
140                    self.send_pubrel(client_id, packet_id).await?;
141                }
142                _ => {
143                    warn!("QoS 2: Unexpected PUBREC for packet {} in state {:?}", packet_id, state);
144                }
145            }
146        } else {
147            warn!("QoS 2: Received PUBREC for unknown packet {}", packet_id);
148        }
149        Ok(())
150    }
151
152    /// Handle PUBREL (QoS 2 release)
153    pub async fn handle_pubrel(&self, packet_id: u16, client_id: &str) -> Result<()> {
154        let mut states = self.qos2_states.write().await;
155        if let Some(state) = states.get_mut(&packet_id) {
156            match state {
157                QoS2State::Released => {
158                    states.remove(&packet_id);
159                    info!("QoS 2: Received PUBREL for packet {}, sending PUBCOMP", packet_id);
160                    // Send PUBCOMP to client
161                    self.send_pubcomp(client_id, packet_id).await?;
162                }
163                _ => {
164                    warn!("QoS 2: Unexpected PUBREL for packet {} in state {:?}", packet_id, state);
165                }
166            }
167        } else {
168            warn!("QoS 2: Received PUBREL for unknown packet {}", packet_id);
169        }
170        Ok(())
171    }
172
173    /// Handle PUBCOMP (QoS 2 completion)
174    pub async fn handle_pubcomp(&self, packet_id: u16) -> Result<()> {
175        if self.qos2_states.write().await.remove(&packet_id).is_some() {
176            info!("QoS 2: Received PUBCOMP for packet {}, delivery completed", packet_id);
177        } else {
178            warn!("QoS 2: Received PUBCOMP for unknown packet {}", packet_id);
179        }
180        Ok(())
181    }
182
183    /// Send PUBACK packet to client (QoS 1 acknowledgment)
184    async fn send_puback(&self, client_id: &str, packet_id: u16) -> Result<()> {
185        info!("QoS 1: Sending PUBACK for packet {} to client {}", packet_id, client_id);
186        // In a real implementation, this would send the actual MQTT PUBACK packet
187        // For the management layer, we simulate the send
188        Ok(())
189    }
190
191    /// Send PUBREC packet to client (QoS 2 first acknowledgment)
192    async fn send_pubrec(&self, client_id: &str, packet_id: u16) -> Result<()> {
193        info!("QoS 2: Sending PUBREC for packet {} to client {}", packet_id, client_id);
194        // In a real implementation, this would send the actual MQTT PUBREC packet
195        // For the management layer, we simulate the send
196        Ok(())
197    }
198
199    /// Send PUBREL packet to client (QoS 2 release)
200    async fn send_pubrel(&self, client_id: &str, packet_id: u16) -> Result<()> {
201        info!("QoS 2: Sending PUBREL for packet {} to client {}", packet_id, client_id);
202        // In a real implementation, this would send the actual MQTT PUBREL packet
203        // For the management layer, we simulate the send
204        Ok(())
205    }
206
207    /// Send PUBCOMP packet to client (QoS 2 completion)
208    async fn send_pubcomp(&self, client_id: &str, packet_id: u16) -> Result<()> {
209        info!("QoS 2: Sending PUBCOMP for packet {} to client {}", packet_id, client_id);
210        // In a real implementation, this would send the actual MQTT PUBCOMP packet
211        // For the management layer, we simulate the send
212        Ok(())
213    }
214
215    /// Retry pending QoS 1 messages
216    pub async fn retry_pending_messages(&self) -> Result<()> {
217        let mut pending = self.qos1_pending.write().await;
218        let mut to_retry = Vec::new();
219
220        for (packet_id, message) in pending.iter_mut() {
221            if message.retry_count < self.max_retries {
222                message.retry_count += 1;
223                to_retry.push((*packet_id, message.client_id.clone()));
224                info!(
225                    "Retrying QoS 1 message for packet {} (attempt {})",
226                    packet_id, message.retry_count
227                );
228            } else {
229                warn!("QoS 1 message for packet {} exceeded max retries", packet_id);
230            }
231        }
232
233        // Resend the messages
234        for (packet_id, client_id) in to_retry {
235            info!("Resending QoS 1 message for packet {} to client {}", packet_id, client_id);
236            // In a real implementation, this would resend the PUBLISH packet
237        }
238
239        Ok(())
240    }
241}