mqtt_proto/v3/
subscribe.rs

1use alloc::vec::Vec;
2
3use crate::{
4    read_string, read_u16, read_u8, write_string, write_u16, write_u8, AsyncRead, Encodable, Error,
5    Pid, QoS, SyncWrite, TopicFilter,
6};
7
8/// Subscribe packet body type.
9#[derive(Debug, Clone, PartialEq, Eq, Hash)]
10#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
11pub struct Subscribe {
12    pub pid: Pid,
13    pub topics: Vec<(TopicFilter, QoS)>,
14}
15
16/// Suback packet body type.
17#[derive(Debug, Clone, PartialEq, Eq)]
18#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
19pub struct Suback {
20    pub pid: Pid,
21    pub topics: Vec<SubscribeReturnCode>,
22}
23
24/// Unsubscribe packet body type.
25#[derive(Debug, Clone, PartialEq, Eq, Hash)]
26#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
27pub struct Unsubscribe {
28    pub pid: Pid,
29    pub topics: Vec<TopicFilter>,
30}
31
32impl Subscribe {
33    pub fn new(pid: Pid, topics: Vec<(TopicFilter, QoS)>) -> Self {
34        Self { pid, topics }
35    }
36
37    pub async fn decode_async<T: AsyncRead + Unpin>(
38        reader: &mut T,
39        mut remaining_len: usize,
40    ) -> Result<Self, Error> {
41        let pid = Pid::try_from(read_u16(reader).await?)?;
42        remaining_len = remaining_len
43            .checked_sub(2)
44            .ok_or(Error::InvalidRemainingLength)?;
45        if remaining_len == 0 {
46            return Err(Error::EmptySubscription);
47        }
48        let mut topics = Vec::new();
49        while remaining_len > 0 {
50            let topic_filter = TopicFilter::try_from(read_string(reader).await?)?;
51            let max_qos = QoS::from_u8(read_u8(reader).await?)?;
52            remaining_len = remaining_len
53                .checked_sub(3 + topic_filter.len())
54                .ok_or(Error::InvalidRemainingLength)?;
55            topics.push((topic_filter, max_qos));
56        }
57        Ok(Subscribe { pid, topics })
58    }
59}
60
61impl Encodable for Subscribe {
62    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
63        write_u16(writer, self.pid.value())?;
64        for (topic_filter, max_qos) in &self.topics {
65            write_string(writer, topic_filter)?;
66            write_u8(writer, *max_qos as u8)?;
67        }
68        Ok(())
69    }
70
71    fn encode_len(&self) -> usize {
72        2 + self
73            .topics
74            .iter()
75            .map(|(filter, _)| 3 + filter.len())
76            .sum::<usize>()
77    }
78}
79
80impl Suback {
81    pub fn new(pid: Pid, topics: Vec<SubscribeReturnCode>) -> Self {
82        Self { pid, topics }
83    }
84
85    pub async fn decode_async<T: AsyncRead + Unpin>(
86        reader: &mut T,
87        mut remaining_len: usize,
88    ) -> Result<Self, Error> {
89        let pid = Pid::try_from(read_u16(reader).await?)?;
90        remaining_len = remaining_len
91            .checked_sub(2)
92            .ok_or(Error::InvalidRemainingLength)?;
93        let mut topics = Vec::new();
94        while remaining_len > 0 {
95            let value = read_u8(reader).await?;
96            let code = SubscribeReturnCode::from_u8(value)?;
97            topics.push(code);
98            remaining_len -= 1;
99        }
100        Ok(Suback { pid, topics })
101    }
102}
103
104impl Encodable for Suback {
105    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
106        write_u16(writer, self.pid.value())?;
107        for code in &self.topics {
108            write_u8(writer, *code as u8)?;
109        }
110        Ok(())
111    }
112    fn encode_len(&self) -> usize {
113        2 + self.topics.len()
114    }
115}
116
117impl Unsubscribe {
118    pub fn new(pid: Pid, topics: Vec<TopicFilter>) -> Self {
119        Self { pid, topics }
120    }
121
122    pub async fn decode_async<T: AsyncRead + Unpin>(
123        reader: &mut T,
124        mut remaining_len: usize,
125    ) -> Result<Self, Error> {
126        let pid = Pid::try_from(read_u16(reader).await?)?;
127        remaining_len = remaining_len
128            .checked_sub(2)
129            .ok_or(Error::InvalidRemainingLength)?;
130        if remaining_len == 0 {
131            return Err(Error::EmptySubscription);
132        }
133        let mut topics = Vec::new();
134        while remaining_len > 0 {
135            let topic_filter = TopicFilter::try_from(read_string(reader).await?)?;
136            remaining_len = remaining_len
137                .checked_sub(2 + topic_filter.len())
138                .ok_or(Error::InvalidRemainingLength)?;
139            topics.push(topic_filter);
140        }
141        Ok(Unsubscribe { pid, topics })
142    }
143}
144
145impl Encodable for Unsubscribe {
146    fn encode<W: SyncWrite>(&self, writer: &mut W) -> Result<(), Error> {
147        write_u16(writer, self.pid.value())?;
148        for topic_filter in &self.topics {
149            write_string(writer, topic_filter)?;
150        }
151        Ok(())
152    }
153
154    fn encode_len(&self) -> usize {
155        2 + self
156            .topics
157            .iter()
158            .map(|filter| 2 + filter.len())
159            .sum::<usize>()
160    }
161}
162
/// Subscribe return code type.
///
/// Each variant's discriminant is its on-the-wire byte. `Suback`'s encoder
/// writes `code as u8`, and that must produce the same byte that
/// `SubscribeReturnCode::from_u8` accepts back. Without explicit values,
/// `Failure` would implicitly be `3` and encode to a byte `from_u8` rejects.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
pub enum SubscribeReturnCode {
    /// Maximum QoS level 0 (byte `0x00`).
    MaxLevel0 = 0x00,
    /// Maximum QoS level 1 (byte `0x01`).
    MaxLevel1 = 0x01,
    /// Maximum QoS level 2 (byte `0x02`).
    MaxLevel2 = 0x02,
    /// Failure (byte `0x80`).
    Failure = 0x80,
}
172
173impl SubscribeReturnCode {
174    pub fn from_u8(value: u8) -> Result<SubscribeReturnCode, Error> {
175        match value {
176            0x80 => Ok(SubscribeReturnCode::Failure),
177            0 => Ok(SubscribeReturnCode::MaxLevel0),
178            1 => Ok(SubscribeReturnCode::MaxLevel1),
179            2 => Ok(SubscribeReturnCode::MaxLevel2),
180            _ => Err(Error::InvalidQos(value)),
181        }
182    }
183}
184
185impl From<QoS> for SubscribeReturnCode {
186    fn from(qos: QoS) -> SubscribeReturnCode {
187        match qos {
188            QoS::Level0 => SubscribeReturnCode::MaxLevel0,
189            QoS::Level1 => SubscribeReturnCode::MaxLevel1,
190            QoS::Level2 => SubscribeReturnCode::MaxLevel2,
191        }
192    }
193}