Skip to main content

rumqttc/mqttbytes/v5/
auth.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2
3use super::{
4    Error, FixedHeader, PropertyType, len_len, length, property, read_mqtt_bytes, read_mqtt_string,
5    read_u8, write_mqtt_bytes, write_mqtt_string, write_remaining_length,
6};
7
8/// Auth packet reason code
9#[derive(Debug, Clone, PartialEq, Eq)]
10pub enum AuthReasonCode {
11    Success,
12    Continue,
13    ReAuthenticate,
14}
15
16impl AuthReasonCode {
17    fn read(bytes: &mut Bytes) -> Result<Self, Error> {
18        let reason_code = read_u8(bytes)?;
19        let code = match reason_code {
20            0x00 => Self::Success,
21            0x18 => Self::Continue,
22            0x19 => Self::ReAuthenticate,
23            _ => return Err(Error::MalformedPacket),
24        };
25
26        Ok(code)
27    }
28
29    fn write(&self, buffer: &mut BytesMut) {
30        let reason_code = match self {
31            Self::Success => 0x00,
32            Self::Continue => 0x18,
33            Self::ReAuthenticate => 0x19,
34        };
35
36        buffer.put_u8(reason_code);
37    }
38}
39
40/// Used to perform extended authentication exchange
41#[derive(Debug, Clone, PartialEq, Eq)]
42pub struct Auth {
43    pub code: AuthReasonCode,
44    pub properties: Option<AuthProperties>,
45}
46
47impl Auth {
48    pub const fn new(code: AuthReasonCode, properties: Option<AuthProperties>) -> Self {
49        Self { code, properties }
50    }
51
52    fn len(&self) -> usize {
53        let mut len = 1;
54
55        if let Some(p) = &self.properties {
56            let properties_len = p.len();
57            let properties_len_len = len_len(properties_len);
58            len += properties_len_len + properties_len;
59        } else {
60            // just 1 byte representing 0 len
61            len += 1;
62        }
63
64        len
65    }
66
67    pub fn size(&self) -> usize {
68        let len = self.len();
69        let remaining_len_size = len_len(len);
70
71        1 + remaining_len_size + len
72    }
73
74    pub fn read(fixed_header: FixedHeader, mut bytes: Bytes) -> Result<Self, Error> {
75        let variable_header_index = fixed_header.header_len;
76        bytes.advance(variable_header_index);
77
78        let code = AuthReasonCode::read(&mut bytes)?;
79        let properties = AuthProperties::read(&mut bytes)?;
80        let auth = Self { code, properties };
81
82        Ok(auth)
83    }
84
85    pub fn write(&self, buffer: &mut BytesMut) -> Result<usize, Error> {
86        buffer.put_u8(0xF0);
87
88        let len = self.len();
89        let count = write_remaining_length(buffer, len)?;
90
91        self.code.write(buffer);
92        if let Some(p) = &self.properties {
93            p.write(buffer)?;
94        } else {
95            write_remaining_length(buffer, 0)?;
96        }
97
98        Ok(1 + count + len)
99    }
100}
101
102#[derive(Debug, Default, Clone, PartialEq, Eq)]
103pub struct AuthProperties {
104    pub method: Option<String>,
105    pub data: Option<Bytes>,
106    pub reason: Option<String>,
107    pub user_properties: Vec<(String, String)>,
108}
109
110impl AuthProperties {
111    fn len(&self) -> usize {
112        let mut len = 0;
113
114        if let Some(method) = &self.method {
115            let m_len = method.len();
116            len += 1 + 2 + m_len;
117        }
118
119        if let Some(data) = &self.data {
120            let d_len = data.len();
121            len += 1 + 2 + d_len;
122        }
123
124        if let Some(reason) = &self.reason {
125            let r_len = reason.len();
126            len += 1 + 2 + r_len;
127        }
128
129        for (key, value) in &self.user_properties {
130            let p_len = key.len() + value.len();
131            len += 1 + 4 + p_len;
132        }
133
134        len
135    }
136
137    pub fn read(bytes: &mut Bytes) -> Result<Option<Self>, Error> {
138        let (properties_len_len, properties_len) = length(bytes.iter())?;
139        bytes.advance(properties_len_len);
140        if properties_len == 0 {
141            return Ok(None);
142        }
143
144        let mut props = Self::default();
145
146        let mut cursor = 0;
147        // read until cursor reaches property length. properties_len = 0 will skip this loop
148        while cursor < properties_len {
149            let prop = read_u8(bytes)?;
150            cursor += 1;
151
152            match property(prop)? {
153                PropertyType::AuthenticationMethod => {
154                    let method = read_mqtt_string(bytes)?;
155                    cursor += 2 + method.len();
156                    props.method = Some(method);
157                }
158                PropertyType::AuthenticationData => {
159                    let data = read_mqtt_bytes(bytes)?;
160                    cursor += 2 + data.len();
161                    props.data = Some(data);
162                }
163                PropertyType::ReasonString => {
164                    let reason = read_mqtt_string(bytes)?;
165                    cursor += 2 + reason.len();
166                    props.reason = Some(reason);
167                }
168                PropertyType::UserProperty => {
169                    let key = read_mqtt_string(bytes)?;
170                    let value = read_mqtt_string(bytes)?;
171                    cursor += 2 + key.len() + 2 + value.len();
172                    props.user_properties.push((key, value));
173                }
174                _ => return Err(Error::InvalidPropertyType(prop)),
175            }
176        }
177
178        Ok(Some(props))
179    }
180
181    pub fn write(&self, buffer: &mut BytesMut) -> Result<(), Error> {
182        let len = self.len();
183        write_remaining_length(buffer, len)?;
184
185        if let Some(authentication_method) = &self.method {
186            buffer.put_u8(PropertyType::AuthenticationMethod as u8);
187            write_mqtt_string(buffer, authentication_method);
188        }
189
190        if let Some(authentication_data) = &self.data {
191            buffer.put_u8(PropertyType::AuthenticationData as u8);
192            write_mqtt_bytes(buffer, authentication_data);
193        }
194
195        if let Some(reason) = &self.reason {
196            buffer.put_u8(PropertyType::ReasonString as u8);
197            write_mqtt_string(buffer, reason);
198        }
199
200        for (key, value) in &self.user_properties {
201            buffer.put_u8(PropertyType::UserProperty as u8);
202            write_mqtt_string(buffer, key);
203            write_mqtt_string(buffer, value);
204        }
205
206        Ok(())
207    }
208}
209
#[cfg(test)]
mod test {
    use super::super::test::{USER_PROP_KEY, USER_PROP_VAL};
    use super::*;
    use bytes::BytesMut;
    use pretty_assertions::assert_eq;

    /// `Auth::size` must agree with both the byte count returned by
    /// `Auth::write` and the number of bytes actually appended to the buffer.
    #[test]
    fn length_calculation() {
        // The user property pads the packet so that its remaining-length field
        // is intended to need two bytes (i.e. more than ~128 bytes of content).
        let properties = AuthProperties {
            method: Some("Authentication Method".into()),
            data: Some("Authentication Data".into()),
            reason: None,
            user_properties: vec![(USER_PROP_KEY.into(), USER_PROP_VAL.into())],
        };
        let packet = Auth::new(AuthReasonCode::Continue, Some(properties));

        let predicted = packet.size();

        let mut buffer = BytesMut::new();
        let reported = packet.write(&mut buffer).unwrap();
        let written = buffer.len();

        assert_eq!(reported, written);
        assert_eq!(predicted, written);
    }
}