1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::sync::RwLock;
4use tracing::{info, warn};
5
/// File-local result alias: the error side is any boxed [`std::error::Error`].
///
/// The boxed trait object carries no `Send`/`Sync` bound, so errors produced
/// through this alias cannot be moved across threads as-is.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
7
/// Delivery-guarantee level attached to a message.
///
/// The explicit discriminants (0, 1, 2) are the numeric codes accepted by
/// [`QoS::from_u8`] and produced by [`QoS::as_u8`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QoS {
    /// Level 0.
    AtMostOnce = 0,
    /// Level 1.
    AtLeastOnce = 1,
    /// Level 2.
    ExactlyOnce = 2,
}

impl QoS {
    /// Converts a numeric level into a [`QoS`].
    ///
    /// Returns `None` for any value other than 0, 1 or 2.
    pub fn from_u8(value: u8) -> Option<Self> {
        // Indexed by numeric level; anything past the table is not a valid level.
        const LEVELS: [QoS; 3] = [QoS::AtMostOnce, QoS::AtLeastOnce, QoS::ExactlyOnce];
        LEVELS.get(usize::from(value)).copied()
    }

    /// Returns the numeric level (0, 1 or 2) of this variant.
    pub fn as_u8(&self) -> u8 {
        match *self {
            QoS::AtMostOnce => 0,
            QoS::AtLeastOnce => 1,
            QoS::ExactlyOnce => 2,
        }
    }
}
30
/// A published message together with the metadata the QoS handlers need.
#[derive(Debug, Clone)]
pub struct MessageState {
    /// Packet identifier; the handlers in this file use it as the map key
    /// for pending QoS 1 entries and QoS 2 handshake states.
    pub packet_id: u16,
    /// Topic the message was published to.
    pub topic: String,
    /// Raw message body.
    pub payload: Vec<u8>,
    /// Requested delivery-guarantee level.
    pub qos: QoS,
    /// Retain flag (not inspected by any handler visible in this file).
    pub retained: bool,
    /// Message timestamp. NOTE(review): unit and epoch are not visible
    /// here — confirm against the code that fills this in.
    pub timestamp: u64,
}
41
/// Bookkeeping entry for a QoS 1 message that has not been acknowledged yet.
#[derive(Debug, Clone)]
struct PendingQoS1Message {
    // The stored message is never read by the code visible in this file (the
    // resend step only logs), hence the lint suppression.
    // NOTE(review): presumably kept so a real resend can reuse the payload — confirm.
    #[allow(dead_code)]
    message: MessageState,
    /// Client the message is associated with; used when logging a resend.
    client_id: String,
    /// Number of retry passes that have already picked this entry up.
    retry_count: u8,
}
50
/// Progress of a QoS 2 handshake for one packet id.
#[derive(Debug, Clone)]
enum QoS2State {
    /// Initial state, recorded by `handle_qo_s2` before it sends PUBREC.
    Received,
    /// Set by `handle_pubrec`; `handle_pubrel` removes the entry when it
    /// finds it in this state.
    Released,
}
57
/// Tracks in-flight QoS 1 and QoS 2 exchanges, keyed by packet id.
///
/// Both maps sit behind `Arc<RwLock<..>>` (tokio's async lock).
pub struct QoSHandler {
    /// QoS 1 messages stored by `handle_qo_s1` and removed by `handle_puback`.
    qos1_pending: Arc<RwLock<HashMap<u16, PendingQoS1Message>>>,
    /// Handshake state of each QoS 2 packet currently being processed.
    qos2_states: Arc<RwLock<HashMap<u16, QoS2State>>>,
    /// Upper bound on `retry_count` enforced by `retry_pending_messages`.
    max_retries: u8,
}
64
65impl Default for QoSHandler {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl QoSHandler {
72 pub fn new() -> Self {
73 Self {
74 qos1_pending: Arc::new(RwLock::new(HashMap::new())),
75 qos2_states: Arc::new(RwLock::new(HashMap::new())),
76 max_retries: 3,
77 }
78 }
79
    /// Handles a QoS 0 message.
    ///
    /// Nothing is stored and nothing is acknowledged: the message is dropped
    /// unused after a log line is emitted. Always returns `Ok(())`.
    pub async fn handle_qo_s0(&self, _message: MessageState) -> Result<()> {
        info!("QoS 0: Fire and forget delivery");
        Ok(())
    }
86
87 pub async fn handle_qo_s1(&self, message: MessageState, client_id: &str) -> Result<()> {
89 info!(
90 "QoS 1: Storing message for at-least-once delivery, packet {}",
91 message.packet_id
92 );
93
94 let pending = PendingQoS1Message {
95 message: message.clone(),
96 client_id: client_id.to_string(),
97 retry_count: 0,
98 };
99
100 self.qos1_pending.write().await.insert(message.packet_id, pending);
101
102 self.send_puback(client_id, message.packet_id).await?;
104
105 Ok(())
106 }
107
108 pub async fn handle_qo_s2(&self, message: MessageState, client_id: &str) -> Result<()> {
110 info!("QoS 2: Starting exactly-once delivery handshake, packet {}", message.packet_id);
111
112 self.qos2_states.write().await.insert(message.packet_id, QoS2State::Received);
114
115 self.send_pubrec(client_id, message.packet_id).await?;
117
118 Ok(())
119 }
120
121 pub async fn handle_puback(&self, packet_id: u16) -> Result<()> {
123 if let Some(_pending) = self.qos1_pending.write().await.remove(&packet_id) {
124 info!("QoS 1: Received PUBACK for packet {}, delivery confirmed", packet_id);
125 } else {
126 warn!("QoS 1: Received PUBACK for unknown packet {}", packet_id);
127 }
128 Ok(())
129 }
130
131 pub async fn handle_pubrec(&self, packet_id: u16, client_id: &str) -> Result<()> {
133 let mut states = self.qos2_states.write().await;
134 if let Some(state) = states.get_mut(&packet_id) {
135 match state {
136 QoS2State::Received => {
137 *state = QoS2State::Released;
138 info!("QoS 2: Received PUBREC for packet {}, sending PUBREL", packet_id);
139 self.send_pubrel(client_id, packet_id).await?;
141 }
142 _ => {
143 warn!("QoS 2: Unexpected PUBREC for packet {} in state {:?}", packet_id, state);
144 }
145 }
146 } else {
147 warn!("QoS 2: Received PUBREC for unknown packet {}", packet_id);
148 }
149 Ok(())
150 }
151
152 pub async fn handle_pubrel(&self, packet_id: u16, client_id: &str) -> Result<()> {
154 let mut states = self.qos2_states.write().await;
155 if let Some(state) = states.get_mut(&packet_id) {
156 match state {
157 QoS2State::Released => {
158 states.remove(&packet_id);
159 info!("QoS 2: Received PUBREL for packet {}, sending PUBCOMP", packet_id);
160 self.send_pubcomp(client_id, packet_id).await?;
162 }
163 _ => {
164 warn!("QoS 2: Unexpected PUBREL for packet {} in state {:?}", packet_id, state);
165 }
166 }
167 } else {
168 warn!("QoS 2: Received PUBREL for unknown packet {}", packet_id);
169 }
170 Ok(())
171 }
172
173 pub async fn handle_pubcomp(&self, packet_id: u16) -> Result<()> {
175 if self.qos2_states.write().await.remove(&packet_id).is_some() {
176 info!("QoS 2: Received PUBCOMP for packet {}, delivery completed", packet_id);
177 } else {
178 warn!("QoS 2: Received PUBCOMP for unknown packet {}", packet_id);
179 }
180 Ok(())
181 }
182
    /// Emits the PUBACK step for `packet_id` towards `client_id`.
    ///
    /// As written this only logs the intent and returns `Ok(())`; no packet
    /// is transmitted by the code visible here.
    async fn send_puback(&self, client_id: &str, packet_id: u16) -> Result<()> {
        info!("QoS 1: Sending PUBACK for packet {} to client {}", packet_id, client_id);
        Ok(())
    }
190
    /// Emits the PUBREC step for `packet_id` towards `client_id`.
    ///
    /// As written this only logs the intent and returns `Ok(())`; no packet
    /// is transmitted by the code visible here.
    async fn send_pubrec(&self, client_id: &str, packet_id: u16) -> Result<()> {
        info!("QoS 2: Sending PUBREC for packet {} to client {}", packet_id, client_id);
        Ok(())
    }
198
    /// Emits the PUBREL step for `packet_id` towards `client_id`.
    ///
    /// As written this only logs the intent and returns `Ok(())`; no packet
    /// is transmitted by the code visible here.
    async fn send_pubrel(&self, client_id: &str, packet_id: u16) -> Result<()> {
        info!("QoS 2: Sending PUBREL for packet {} to client {}", packet_id, client_id);
        Ok(())
    }
206
    /// Emits the PUBCOMP step for `packet_id` towards `client_id`.
    ///
    /// As written this only logs the intent and returns `Ok(())`; no packet
    /// is transmitted by the code visible here.
    async fn send_pubcomp(&self, client_id: &str, packet_id: u16) -> Result<()> {
        info!("QoS 2: Sending PUBCOMP for packet {} to client {}", packet_id, client_id);
        Ok(())
    }
214
215 pub async fn retry_pending_messages(&self) -> Result<()> {
217 let mut pending = self.qos1_pending.write().await;
218 let mut to_retry = Vec::new();
219
220 for (packet_id, message) in pending.iter_mut() {
221 if message.retry_count < self.max_retries {
222 message.retry_count += 1;
223 to_retry.push((*packet_id, message.client_id.clone()));
224 info!(
225 "Retrying QoS 1 message for packet {} (attempt {})",
226 packet_id, message.retry_count
227 );
228 } else {
229 warn!("QoS 1 message for packet {} exceeded max retries", packet_id);
230 }
231 }
232
233 for (packet_id, client_id) in to_retry {
235 info!("Resending QoS 1 message for packet {} to client {}", packet_id, client_id);
236 }
238
239 Ok(())
240 }
241}