mqtt5_protocol/encoding/
mqtt_string.rs

//! MQTT string implementation using `BeBytes` 2.3.0 size expressions.
//!
//! MQTT strings are prefixed with a 2-byte length field in big-endian format.
//! The length counts UTF-8 bytes, not characters, so one string carries at most 65,535 bytes.
4
5use crate::error::{MqttError, Result};
6use crate::prelude::{format, vec, String, ToString, Vec};
7use bebytes::BeBytes;
8
/// MQTT string with automatic size handling via `BeBytes` size expressions.
///
/// On the wire this is a big-endian `u16` byte count followed by that many bytes of
/// UTF-8 data (e.g. `"hello"` encodes as `00 05 68 65 6c 6c 6f`).
///
/// Invariant: `length == data.len()`. Both fields are private; [`MqttString::create`]
/// establishes the invariant when building from a `&str` and enforces the MQTT rules
/// (no U+0000, at most 65535 bytes).
///
/// NOTE(review): values produced by the derived `try_from_be_bytes` do not pass through
/// `create`, so the null-character check is presumably not applied on that path — confirm
/// before relying on it for untrusted input.
#[derive(Debug, Clone, PartialEq, Eq, BeBytes)]
pub struct MqttString {
    /// Number of bytes in `data`, encoded big-endian.
    ///
    /// Declared before `data` because the derived encoding appears to follow field
    /// order (the tests expect the length prefix first) — do not reorder.
    #[bebytes(big_endian)]
    length: u16,

    /// UTF-8 string data; the `size = "length"` expression ties how many bytes
    /// belong to this field to the `length` field.
    #[bebytes(size = "length")]
    data: String,
}
20
21impl MqttString {
22    /// Create a new MQTT string
23    ///
24    /// # Errors
25    /// Returns an error if:
26    /// - The string is longer than 65535 bytes
27    /// - The string contains null characters (forbidden by MQTT spec)
28    pub fn create(s: &str) -> Result<Self> {
29        if s.contains('\0') {
30            return Err(MqttError::MalformedPacket(
31                "String contains null character".to_string(),
32            ));
33        }
34
35        let len = s.len();
36        if len > u16::MAX as usize {
37            return Err(MqttError::StringTooLong(len));
38        }
39
40        Ok(Self {
41            #[allow(clippy::cast_possible_truncation)]
42            length: len as u16,
43            data: s.to_string(),
44        })
45    }
46
47    /// Get the string value
48    #[must_use]
49    pub fn as_str(&self) -> &str {
50        &self.data
51    }
52
53    /// Get the total encoded size (length field + data)
54    #[must_use]
55    pub fn encoded_size(&self) -> usize {
56        2 + self.data.len()
57    }
58}
59
60impl TryFrom<&str> for MqttString {
61    type Error = MqttError;
62
63    fn try_from(s: &str) -> Result<Self> {
64        Self::create(s)
65    }
66}
67
68impl TryFrom<String> for MqttString {
69    type Error = MqttError;
70
71    fn try_from(s: String) -> Result<Self> {
72        Self::create(&s)
73    }
74}
75
76/// Encodes a UTF-8 string with a 2-byte length prefix (compatibility function)
77///
78/// This function provides compatibility with the old string module API.
79/// Prefer using `MqttString::create(string)?.to_be_bytes()` for new code.
80///
81/// # Errors
82///
83/// Returns an error if:
84/// - The string contains null characters
85/// - The string length exceeds maximum string length
86pub fn encode_string<B: bytes::BufMut>(buf: &mut B, string: &str) -> Result<()> {
87    // Check for null characters
88    if string.contains('\0') {
89        return Err(MqttError::MalformedPacket(
90            "String contains null character".to_string(),
91        ));
92    }
93
94    let mqtt_string = MqttString::create(string)?;
95    let encoded = mqtt_string.to_be_bytes();
96    buf.put_slice(&encoded);
97    Ok(())
98}
99
100/// Decodes a UTF-8 string with a 2-byte length prefix (compatibility function)
101///
102/// This function provides compatibility with the old string module API.
103/// Prefer using `MqttString::try_from_be_bytes()` for new code.
104///
105/// # Errors
106///
107/// Returns an error if:
108/// - Insufficient bytes in buffer
109/// - String is not valid UTF-8
110/// - String contains null characters
111pub fn decode_string<B: bytes::Buf>(buf: &mut B) -> Result<String> {
112    if buf.remaining() < 2 {
113        return Err(MqttError::MalformedPacket(
114            "Insufficient bytes for string length".to_string(),
115        ));
116    }
117
118    let len = buf.get_u16() as usize;
119
120    if buf.remaining() < len {
121        return Err(MqttError::MalformedPacket(format!(
122            "Insufficient bytes for string data: expected {}, got {}",
123            len,
124            buf.remaining()
125        )));
126    }
127
128    let mut bytes = vec![0u8; len];
129    buf.copy_to_slice(&mut bytes);
130
131    let string = String::from_utf8(bytes)
132        .map_err(|e| MqttError::MalformedPacket(format!("Invalid UTF-8: {e}")))?;
133
134    // Check for null characters
135    if string.contains('\0') {
136        return Err(MqttError::MalformedPacket(
137            "String contains null character".to_string(),
138        ));
139    }
140
141    Ok(string)
142}
143
/// Returns how many bytes `string` occupies when encoded (compatibility function):
/// its UTF-8 byte length plus the 2-byte length prefix.
///
/// No validation is performed; the figure only describes strings that
/// `encode_string` would accept.
#[must_use]
pub fn string_len(string: &str) -> usize {
    const LENGTH_PREFIX: usize = 2;
    string.len() + LENGTH_PREFIX
}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mqtt_string_encoding() {
        let mqtt_str = MqttString::create("hello").unwrap();
        let bytes = mqtt_str.to_be_bytes();

        // Check encoding: 2-byte length (0x00, 0x05) + "hello"
        assert_eq!(bytes, vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_mqtt_string_decoding() {
        let data = vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let (mqtt_str, consumed) = MqttString::try_from_be_bytes(&data).unwrap();

        assert_eq!(mqtt_str.as_str(), "hello");
        assert_eq!(consumed, 7);
    }

    #[test]
    fn test_mqtt_string_round_trip() {
        let original = MqttString::create("test/topic").unwrap();
        let bytes = original.to_be_bytes();
        let (decoded, _) = MqttString::try_from_be_bytes(&bytes).unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_empty_string() {
        let mqtt_str = MqttString::create("").unwrap();
        let bytes = mqtt_str.to_be_bytes();

        assert_eq!(bytes, vec![0x00, 0x00]);
    }

    #[test]
    fn test_string_too_long() {
        let long_string = "x".repeat(65536);
        let result = MqttString::create(&long_string);

        assert!(matches!(result, Err(MqttError::StringTooLong(65536))));
    }

    #[test]
    fn test_max_length_string_accepted() {
        let max_string = "x".repeat(65535);
        let mqtt_str = MqttString::create(&max_string).unwrap();

        assert_eq!(mqtt_str.encoded_size(), 65537);
    }

    #[test]
    fn test_null_character_rejected() {
        assert!(matches!(
            MqttString::create("a\0b"),
            Err(MqttError::MalformedPacket(_))
        ));
    }

    #[test]
    fn test_compat_round_trip() {
        let mut buf: Vec<u8> = Vec::new();
        encode_string(&mut buf, "sensor/temp").unwrap();
        assert_eq!(buf.len(), string_len("sensor/temp"));

        let mut reader: &[u8] = &buf;
        assert_eq!(decode_string(&mut reader).unwrap(), "sensor/temp");
        assert!(reader.is_empty());
    }

    #[test]
    fn test_encode_rejects_null_without_writing() {
        let mut buf: Vec<u8> = Vec::new();

        assert!(encode_string(&mut buf, "a\0").is_err());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_rejects_malformed_input() {
        // Too short for the length prefix.
        let mut reader: &[u8] = &[0x00];
        assert!(decode_string(&mut reader).is_err());

        // Declared length exceeds the available data.
        let mut reader: &[u8] = &[0x00, 0x05, b'h'];
        assert!(decode_string(&mut reader).is_err());

        // Invalid UTF-8 sequence.
        let mut reader: &[u8] = &[0x00, 0x02, 0xC3, 0x28];
        assert!(decode_string(&mut reader).is_err());

        // Embedded U+0000.
        let mut reader: &[u8] = &[0x00, 0x01, 0x00];
        assert!(decode_string(&mut reader).is_err());
    }
}