mqtt5_protocol/encoding/
mqtt_string.rs

1//! MQTT string implementation using `BeBytes` 2.3.0 size expressions
2//!
3//! MQTT strings are prefixed with a 2-byte length field in big-endian format.
4
5use crate::error::{MqttError, Result};
6use bebytes::BeBytes;
7
/// MQTT string with automatic size handling via `BeBytes` size expressions.
///
/// Wire layout: a big-endian `u16` byte count followed by exactly that many
/// bytes of UTF-8 data. The layout is presumably derived from the field
/// declaration order (length first, then data), so do not reorder the fields.
#[derive(Debug, Clone, PartialEq, Eq, BeBytes)]
pub struct MqttString {
    /// Length of the string in bytes (big-endian on the wire).
    ///
    /// Must equal `data.len()`; `MqttString::create` sets it from the input length.
    #[bebytes(big_endian)]
    length: u16,

    /// UTF-8 string data whose encoded size is taken from the `length` field.
    #[bebytes(size = "length")]
    data: String,
}
19
20impl MqttString {
21    /// Create a new MQTT string
22    ///
23    /// # Errors
24    /// Returns an error if the string is longer than 65535 bytes
25    pub fn create(s: &str) -> Result<Self> {
26        let len = s.len();
27        if len > u16::MAX as usize {
28            return Err(MqttError::StringTooLong(len));
29        }
30
31        Ok(Self {
32            #[allow(clippy::cast_possible_truncation)]
33            length: len as u16, // Safe: we checked len <= u16::MAX above
34            data: s.to_string(),
35        })
36    }
37
38    /// Get the string value
39    #[must_use]
40    pub fn as_str(&self) -> &str {
41        &self.data
42    }
43
44    /// Get the total encoded size (length field + data)
45    #[must_use]
46    pub fn encoded_size(&self) -> usize {
47        2 + self.data.len()
48    }
49}
50
51impl TryFrom<&str> for MqttString {
52    type Error = MqttError;
53
54    fn try_from(s: &str) -> Result<Self> {
55        Self::create(s)
56    }
57}
58
59impl TryFrom<String> for MqttString {
60    type Error = MqttError;
61
62    fn try_from(s: String) -> Result<Self> {
63        Self::create(&s)
64    }
65}
66
67/// Encodes a UTF-8 string with a 2-byte length prefix (compatibility function)
68///
69/// This function provides compatibility with the old string module API.
70/// Prefer using `MqttString::create(string)?.to_be_bytes()` for new code.
71///
72/// # Errors
73///
74/// Returns an error if:
75/// - The string contains null characters
76/// - The string length exceeds maximum string length
77pub fn encode_string<B: bytes::BufMut>(buf: &mut B, string: &str) -> Result<()> {
78    // Check for null characters
79    if string.contains('\0') {
80        return Err(MqttError::MalformedPacket(
81            "String contains null character".to_string(),
82        ));
83    }
84
85    let mqtt_string = MqttString::create(string)?;
86    let encoded = mqtt_string.to_be_bytes();
87    buf.put_slice(&encoded);
88    Ok(())
89}
90
91/// Decodes a UTF-8 string with a 2-byte length prefix (compatibility function)
92///
93/// This function provides compatibility with the old string module API.
94/// Prefer using `MqttString::try_from_be_bytes()` for new code.
95///
96/// # Errors
97///
98/// Returns an error if:
99/// - Insufficient bytes in buffer
100/// - String is not valid UTF-8
101/// - String contains null characters
102pub fn decode_string<B: bytes::Buf>(buf: &mut B) -> Result<String> {
103    if buf.remaining() < 2 {
104        return Err(MqttError::MalformedPacket(
105            "Insufficient bytes for string length".to_string(),
106        ));
107    }
108
109    let len = buf.get_u16() as usize;
110
111    if buf.remaining() < len {
112        return Err(MqttError::MalformedPacket(format!(
113            "Insufficient bytes for string data: expected {}, got {}",
114            len,
115            buf.remaining()
116        )));
117    }
118
119    let mut bytes = vec![0u8; len];
120    buf.copy_to_slice(&mut bytes);
121
122    let string = String::from_utf8(bytes)
123        .map_err(|e| MqttError::MalformedPacket(format!("Invalid UTF-8: {e}")))?;
124
125    // Check for null characters
126    if string.contains('\0') {
127        return Err(MqttError::MalformedPacket(
128            "String contains null character".to_string(),
129        ));
130    }
131
132    Ok(string)
133}
134
/// Returns the number of bytes `string` occupies once encoded: a 2-byte
/// length prefix followed by its UTF-8 bytes (compatibility function).
///
/// This does not check that the string fits in a `u16` length prefix.
#[must_use]
pub fn string_len(string: &str) -> usize {
    let prefix = std::mem::size_of::<u16>();
    string.len() + prefix
}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mqtt_string_encoding() {
        let mqtt_str = MqttString::create("hello").unwrap();
        let bytes = mqtt_str.to_be_bytes();

        // Check encoding: 2-byte length (0x00, 0x05) + "hello"
        assert_eq!(bytes, vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_mqtt_string_decoding() {
        let data = vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let (mqtt_str, consumed) = MqttString::try_from_be_bytes(&data).unwrap();

        assert_eq!(mqtt_str.as_str(), "hello");
        assert_eq!(consumed, 7);
    }

    #[test]
    fn test_mqtt_string_round_trip() {
        let original = MqttString::create("test/topic").unwrap();
        let bytes = original.to_be_bytes();
        let (decoded, _) = MqttString::try_from_be_bytes(&bytes).unwrap();

        assert_eq!(original, decoded);
    }

    #[test]
    fn test_empty_string() {
        let mqtt_str = MqttString::create("").unwrap();
        let bytes = mqtt_str.to_be_bytes();

        assert_eq!(bytes, vec![0x00, 0x00]);
    }

    #[test]
    fn test_string_too_long() {
        let long_string = "x".repeat(65536);
        let result = MqttString::create(&long_string);

        assert!(result.is_err());
    }

    #[test]
    fn test_max_length_string() {
        // Exactly u16::MAX bytes is the largest representable string.
        let max = "x".repeat(usize::from(u16::MAX));
        let mqtt_str = MqttString::create(&max).unwrap();

        assert_eq!(mqtt_str.encoded_size(), 2 + max.len());
    }

    #[test]
    fn test_encode_string_compat() {
        let mut buf: Vec<u8> = Vec::new();
        encode_string(&mut buf, "hello").unwrap();

        assert_eq!(buf, vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
        assert_eq!(buf.len(), string_len("hello"));
    }

    #[test]
    fn test_encode_string_rejects_null() {
        let mut buf: Vec<u8> = Vec::new();

        assert!(encode_string(&mut buf, "a\0b").is_err());
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_string_compat() {
        let data = [0x00, 0x05, b'h', b'e', b'l', b'l', b'o', 0xAA];
        let mut buf: &[u8] = &data;

        assert_eq!(decode_string(&mut buf).unwrap(), "hello");
        // Only the string is consumed; trailing bytes stay in the buffer.
        assert_eq!(buf.to_vec(), vec![0xAA_u8]);
    }

    #[test]
    fn test_encode_decode_round_trip() {
        let mut encoded: Vec<u8> = Vec::new();
        encode_string(&mut encoded, "test/topic").unwrap();

        let mut buf: &[u8] = &encoded;
        assert_eq!(decode_string(&mut buf).unwrap(), "test/topic");
        assert!(buf.is_empty());
    }

    #[test]
    fn test_decode_string_errors() {
        // Too short for the length prefix.
        let mut buf: &[u8] = &[0x00];
        assert!(decode_string(&mut buf).is_err());

        // Length prefix announces more bytes than are available.
        let mut buf: &[u8] = &[0x00, 0x05, b'h'];
        assert!(decode_string(&mut buf).is_err());

        // Invalid UTF-8 payload.
        let mut buf: &[u8] = &[0x00, 0x01, 0xFF];
        assert!(decode_string(&mut buf).is_err());

        // Embedded null character.
        let mut buf: &[u8] = &[0x00, 0x01, 0x00];
        assert!(decode_string(&mut buf).is_err());
    }

    #[test]
    fn test_string_len() {
        assert_eq!(string_len(""), 2);
        assert_eq!(string_len("hello"), 7);
        assert_eq!(string_len("héllo"), 8);
    }
}