mqtt5_protocol/encoding/
mqtt_string.rs1use crate::error::{MqttError, Result};
6use bebytes::BeBytes;
7
/// An MQTT UTF-8 Encoded String as it appears on the wire: a `u16` byte
/// count (serialized big-endian) followed by that many bytes of UTF-8 text.
///
/// Wire serialization/deserialization (`to_be_bytes` / `try_from_be_bytes`)
/// is derived through `bebytes::BeBytes`. The fields are private. Within this
/// module, [`MqttString::create`] builds values and sets `length` to
/// `data.len()`. Values decoded via `try_from_be_bytes` take `length` from the
/// input bytes.
#[derive(Debug, Clone, PartialEq, Eq, BeBytes)]
pub struct MqttString {
    /// Number of bytes in `data`; emitted as a big-endian two-byte prefix.
    #[bebytes(big_endian)]
    length: u16,

    /// The string payload. `size = "length"` ties the number of bytes
    /// read/written for this field to the `length` field above (per the
    /// `bebytes` size attribute — confirm against the crate documentation).
    #[bebytes(size = "length")]
    data: String,
}
19
20impl MqttString {
21 pub fn create(s: &str) -> Result<Self> {
26 let len = s.len();
27 if len > u16::MAX as usize {
28 return Err(MqttError::StringTooLong(len));
29 }
30
31 Ok(Self {
32 #[allow(clippy::cast_possible_truncation)]
33 length: len as u16, data: s.to_string(),
35 })
36 }
37
38 #[must_use]
40 pub fn as_str(&self) -> &str {
41 &self.data
42 }
43
44 #[must_use]
46 pub fn encoded_size(&self) -> usize {
47 2 + self.data.len()
48 }
49}
50
51impl TryFrom<&str> for MqttString {
52 type Error = MqttError;
53
54 fn try_from(s: &str) -> Result<Self> {
55 Self::create(s)
56 }
57}
58
59impl TryFrom<String> for MqttString {
60 type Error = MqttError;
61
62 fn try_from(s: String) -> Result<Self> {
63 Self::create(&s)
64 }
65}
66
67pub fn encode_string<B: bytes::BufMut>(buf: &mut B, string: &str) -> Result<()> {
78 if string.contains('\0') {
80 return Err(MqttError::MalformedPacket(
81 "String contains null character".to_string(),
82 ));
83 }
84
85 let mqtt_string = MqttString::create(string)?;
86 let encoded = mqtt_string.to_be_bytes();
87 buf.put_slice(&encoded);
88 Ok(())
89}
90
91pub fn decode_string<B: bytes::Buf>(buf: &mut B) -> Result<String> {
103 if buf.remaining() < 2 {
104 return Err(MqttError::MalformedPacket(
105 "Insufficient bytes for string length".to_string(),
106 ));
107 }
108
109 let len = buf.get_u16() as usize;
110
111 if buf.remaining() < len {
112 return Err(MqttError::MalformedPacket(format!(
113 "Insufficient bytes for string data: expected {}, got {}",
114 len,
115 buf.remaining()
116 )));
117 }
118
119 let mut bytes = vec![0u8; len];
120 buf.copy_to_slice(&mut bytes);
121
122 let string = String::from_utf8(bytes)
123 .map_err(|e| MqttError::MalformedPacket(format!("Invalid UTF-8: {e}")))?;
124
125 if string.contains('\0') {
127 return Err(MqttError::MalformedPacket(
128 "String contains null character".to_string(),
129 ));
130 }
131
132 Ok(string)
133}
134
/// Returns the number of bytes `string` occupies when written as an MQTT
/// UTF-8 Encoded String: the two-byte length prefix plus its UTF-8 bytes.
///
/// The count is in bytes, not characters. It does not check whether
/// `string` fits in the 16-bit length prefix.
#[must_use]
pub fn string_len(string: &str) -> usize {
    let prefix_bytes = std::mem::size_of::<u16>();
    string.len() + prefix_bytes
}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mqtt_string_encoding() {
        let mqtt_str = MqttString::create("hello").unwrap();
        let bytes = mqtt_str.to_be_bytes();

        assert_eq!(bytes, vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_mqtt_string_decoding() {
        let data = vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let (mqtt_str, consumed) = MqttString::try_from_be_bytes(&data).unwrap();

        assert_eq!(mqtt_str.as_str(), "hello");
        assert_eq!(consumed, 7);
    }

    #[test]
    fn test_mqtt_string_round_trip() {
        let original = MqttString::create("test/topic").unwrap();
        let bytes = original.to_be_bytes();
        let (decoded, _) = MqttString::try_from_be_bytes(&bytes).unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_empty_string() {
        let mqtt_str = MqttString::create("").unwrap();
        let bytes = mqtt_str.to_be_bytes();

        assert_eq!(bytes, vec![0x00, 0x00]);
    }

    #[test]
    fn test_string_too_long() {
        let long_string = "x".repeat(65536);
        let result = MqttString::create(&long_string);

        assert!(matches!(result, Err(MqttError::StringTooLong(65536))));
    }

    #[test]
    fn test_max_length_string() {
        let max_string = "x".repeat(usize::from(u16::MAX));
        let mqtt_str = MqttString::create(&max_string).unwrap();

        assert_eq!(mqtt_str.encoded_size(), 65537);
        let bytes = mqtt_str.to_be_bytes();
        assert_eq!(&bytes[..2], &[0xFF, 0xFF]);
        assert_eq!(bytes.len(), 65537);
    }

    #[test]
    fn test_encoded_size_matches_wire_length() {
        let mqtt_str = MqttString::create("hello").unwrap();

        assert_eq!(mqtt_str.encoded_size(), mqtt_str.to_be_bytes().len());
        assert_eq!(mqtt_str.encoded_size(), string_len("hello"));
    }

    #[test]
    fn test_encode_string_writes_prefix_and_data() {
        let mut out: Vec<u8> = Vec::new();
        encode_string(&mut out, "hello").unwrap();

        assert_eq!(out, vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_encode_string_rejects_null() {
        let mut out: Vec<u8> = Vec::new();
        let result = encode_string(&mut out, "a\0b");

        assert!(matches!(result, Err(MqttError::MalformedPacket(_))));
        assert!(out.is_empty());
    }

    #[test]
    fn test_encode_string_rejects_too_long() {
        let mut out: Vec<u8> = Vec::new();
        let result = encode_string(&mut out, &"x".repeat(65536));

        assert!(matches!(result, Err(MqttError::StringTooLong(65536))));
        assert!(out.is_empty());
    }

    #[test]
    fn test_encode_decode_round_trip_leaves_trailing_bytes() {
        let mut out: Vec<u8> = Vec::new();
        encode_string(&mut out, "a/b").unwrap();
        out.push(0xAA);

        let mut input: &[u8] = &out;
        assert_eq!(decode_string(&mut input).unwrap(), "a/b");
        assert_eq!(input, &[0xAA_u8][..]);
    }

    #[test]
    fn test_decode_empty_string() {
        let mut input: &[u8] = &[0x00, 0x00];

        assert_eq!(decode_string(&mut input).unwrap(), "");
        assert!(input.is_empty());
    }

    #[test]
    fn test_decode_truncated_length() {
        let mut input: &[u8] = &[0x00];

        assert!(matches!(
            decode_string(&mut input),
            Err(MqttError::MalformedPacket(_))
        ));
    }

    #[test]
    fn test_decode_truncated_data() {
        let mut input: &[u8] = &[0x00, 0x05, b'h'];

        assert!(matches!(
            decode_string(&mut input),
            Err(MqttError::MalformedPacket(_))
        ));
    }

    #[test]
    fn test_decode_invalid_utf8() {
        let mut input: &[u8] = &[0x00, 0x01, 0xFF];

        assert!(matches!(
            decode_string(&mut input),
            Err(MqttError::MalformedPacket(_))
        ));
    }

    #[test]
    fn test_decode_rejects_null() {
        let mut input: &[u8] = &[0x00, 0x01, 0x00];

        assert!(matches!(
            decode_string(&mut input),
            Err(MqttError::MalformedPacket(_))
        ));
    }
}