mqtt_proto/v3/
subscribe.rs

1use std::io;
2
3use tokio::io::AsyncRead;
4
5use crate::{
6    read_string, read_u16, read_u8, write_bytes, write_u16, write_u8, Encodable, Error, Pid, QoS,
7    TopicFilter,
8};
9
10/// Subscribe packet body type.
11#[derive(Debug, Clone, PartialEq, Eq, Hash)]
12#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
13pub struct Subscribe {
14    pub pid: Pid,
15    pub topics: Vec<(TopicFilter, QoS)>,
16}
17
18/// Suback packet body type.
19#[derive(Debug, Clone, PartialEq, Eq)]
20#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
21pub struct Suback {
22    pub pid: Pid,
23    pub topics: Vec<SubscribeReturnCode>,
24}
25
26/// Unsubscribe packet body type.
27#[derive(Debug, Clone, PartialEq, Eq, Hash)]
28#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
29pub struct Unsubscribe {
30    pub pid: Pid,
31    pub topics: Vec<TopicFilter>,
32}
33
34impl Subscribe {
35    pub fn new(pid: Pid, topics: Vec<(TopicFilter, QoS)>) -> Self {
36        Self { pid, topics }
37    }
38
39    pub async fn decode_async<T: AsyncRead + Unpin>(
40        reader: &mut T,
41        mut remaining_len: usize,
42    ) -> Result<Self, Error> {
43        let pid = Pid::try_from(read_u16(reader).await?)?;
44        remaining_len = remaining_len
45            .checked_sub(2)
46            .ok_or(Error::InvalidRemainingLength)?;
47        if remaining_len == 0 {
48            return Err(Error::EmptySubscription);
49        }
50        let mut topics = Vec::new();
51        while remaining_len > 0 {
52            let topic_filter = TopicFilter::try_from(read_string(reader).await?)?;
53            let max_qos = QoS::from_u8(read_u8(reader).await?)?;
54            remaining_len = remaining_len
55                .checked_sub(3 + topic_filter.len())
56                .ok_or(Error::InvalidRemainingLength)?;
57            topics.push((topic_filter, max_qos));
58        }
59        Ok(Subscribe { pid, topics })
60    }
61}
62
63impl Encodable for Subscribe {
64    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
65        write_u16(writer, self.pid.value())?;
66        for (topic_filter, max_qos) in &self.topics {
67            write_bytes(writer, topic_filter.as_bytes())?;
68            write_u8(writer, *max_qos as u8)?;
69        }
70        Ok(())
71    }
72
73    fn encode_len(&self) -> usize {
74        2 + self
75            .topics
76            .iter()
77            .map(|(filter, _)| 3 + filter.len())
78            .sum::<usize>()
79    }
80}
81
82impl Suback {
83    pub fn new(pid: Pid, topics: Vec<SubscribeReturnCode>) -> Self {
84        Self { pid, topics }
85    }
86
87    pub async fn decode_async<T: AsyncRead + Unpin>(
88        reader: &mut T,
89        mut remaining_len: usize,
90    ) -> Result<Self, Error> {
91        let pid = Pid::try_from(read_u16(reader).await?)?;
92        remaining_len = remaining_len
93            .checked_sub(2)
94            .ok_or(Error::InvalidRemainingLength)?;
95        let mut topics = Vec::new();
96        while remaining_len > 0 {
97            let value = read_u8(reader).await?;
98            let code = SubscribeReturnCode::from_u8(value)?;
99            topics.push(code);
100            remaining_len -= 1;
101        }
102        Ok(Suback { pid, topics })
103    }
104}
105
106impl Encodable for Suback {
107    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
108        write_u16(writer, self.pid.value())?;
109        for code in &self.topics {
110            write_u8(writer, *code as u8)?;
111        }
112        Ok(())
113    }
114    fn encode_len(&self) -> usize {
115        2 + self.topics.len()
116    }
117}
118
119impl Unsubscribe {
120    pub fn new(pid: Pid, topics: Vec<TopicFilter>) -> Self {
121        Self { pid, topics }
122    }
123
124    pub async fn decode_async<T: AsyncRead + Unpin>(
125        reader: &mut T,
126        mut remaining_len: usize,
127    ) -> Result<Self, Error> {
128        let pid = Pid::try_from(read_u16(reader).await?)?;
129        remaining_len = remaining_len
130            .checked_sub(2)
131            .ok_or(Error::InvalidRemainingLength)?;
132        if remaining_len == 0 {
133            return Err(Error::EmptySubscription);
134        }
135        let mut topics = Vec::new();
136        while remaining_len > 0 {
137            let topic_filter = TopicFilter::try_from(read_string(reader).await?)?;
138            remaining_len = remaining_len
139                .checked_sub(2 + topic_filter.len())
140                .ok_or(Error::InvalidRemainingLength)?;
141            topics.push(topic_filter);
142        }
143        Ok(Unsubscribe { pid, topics })
144    }
145}
146
147impl Encodable for Unsubscribe {
148    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
149        write_u16(writer, self.pid.value())?;
150        for topic_filter in &self.topics {
151            write_bytes(writer, topic_filter.as_bytes())?;
152        }
153        Ok(())
154    }
155
156    fn encode_len(&self) -> usize {
157        2 + self
158            .topics
159            .iter()
160            .map(|filter| 2 + filter.len())
161            .sum::<usize>()
162    }
163}
164
/// Return code carried in a SUBACK packet for each requested subscription.
///
/// The explicit `u8` discriminants are the on-the-wire byte values. `code as u8` therefore
/// yields the byte to transmit, and that byte round-trips through
/// [`SubscribeReturnCode::from_u8`]. Without them, `Failure` would implicitly be `3`
/// instead of `0x80`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[repr(u8)]
pub enum SubscribeReturnCode {
    /// Success with a granted maximum QoS of 0.
    MaxLevel0 = 0x00,
    /// Success with a granted maximum QoS of 1.
    MaxLevel1 = 0x01,
    /// Success with a granted maximum QoS of 2.
    MaxLevel2 = 0x02,
    /// The subscription was not granted.
    Failure = 0x80,
}
174
175impl SubscribeReturnCode {
176    pub fn from_u8(value: u8) -> Result<SubscribeReturnCode, Error> {
177        match value {
178            0x80 => Ok(SubscribeReturnCode::Failure),
179            0 => Ok(SubscribeReturnCode::MaxLevel0),
180            1 => Ok(SubscribeReturnCode::MaxLevel1),
181            2 => Ok(SubscribeReturnCode::MaxLevel2),
182            _ => Err(Error::InvalidQos(value)),
183        }
184    }
185}
186
187impl From<QoS> for SubscribeReturnCode {
188    fn from(qos: QoS) -> SubscribeReturnCode {
189        match qos {
190            QoS::Level0 => SubscribeReturnCode::MaxLevel0,
191            QoS::Level1 => SubscribeReturnCode::MaxLevel1,
192            QoS::Level2 => SubscribeReturnCode::MaxLevel2,
193        }
194    }
195}