use crate::error::{MqttError, Result};
use crate::prelude::{format, vec, String, ToString};
use bebytes::BeBytes;
/// An MQTT length-prefixed string: a big-endian `u16` byte count followed by
/// that many bytes of string data (e.g. `"hello"` encodes as
/// `[0x00, 0x05, b'h', b'e', b'l', b'l', b'o']`).
///
/// The wire layout comes from the `BeBytes` derive. Build values with
/// [`MqttString::create`] (or `TryFrom`), which rejects U+0000 and inputs longer
/// than 65 535 bytes, and sets `length` from the input.
///
/// NOTE(review): values decoded through the derived `try_from_be_bytes` do not go
/// through `create`, so they do not get its null-character check. Confirm
/// whether `bebytes` validates UTF-8/nulls, or decode via `decode_string`.
#[derive(Debug, Clone, PartialEq, Eq, BeBytes)]
pub struct MqttString {
// Byte length of `data` (UTF-8 bytes, not chars), serialized big-endian.
#[bebytes(big_endian)]
length: u16,
// String payload; the derive sizes it from `length`.
#[bebytes(size = "length")]
data: String,
}
impl MqttString {
    /// Builds an MQTT string from `s`, validating it for the wire format.
    ///
    /// # Errors
    ///
    /// - [`MqttError::MalformedPacket`] if `s` contains the null character U+0000.
    /// - [`MqttError::StringTooLong`] (carrying the byte length) if `s` is longer
    ///   than `u16::MAX` (65 535) bytes and so cannot be described by the
    ///   2-byte length prefix.
    pub fn create(s: &str) -> Result<Self> {
        if s.contains('\0') {
            return Err(MqttError::MalformedPacket(
                "String contains null character".to_string(),
            ));
        }
        // Checked narrowing instead of an `as u16` cast: the conversion itself
        // proves the length fits, so no truncation lint has to be silenced.
        let len = s.len();
        let length = u16::try_from(len).map_err(|_| MqttError::StringTooLong(len))?;
        Ok(Self {
            length,
            data: s.to_string(),
        })
    }

    /// Returns the string payload, without the length prefix.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.data
    }

    /// Number of bytes this string occupies on the wire: the 2-byte length
    /// prefix plus the UTF-8 payload.
    #[must_use]
    pub fn encoded_size(&self) -> usize {
        2 + self.data.len()
    }
}
impl TryFrom<&str> for MqttString {
type Error = MqttError;
fn try_from(s: &str) -> Result<Self> {
Self::create(s)
}
}
impl TryFrom<String> for MqttString {
type Error = MqttError;
fn try_from(s: String) -> Result<Self> {
Self::create(&s)
}
}
/// Writes `string` to `buf` as an MQTT length-prefixed string: a big-endian
/// `u16` byte length followed by the UTF-8 bytes.
///
/// # Errors
///
/// Returns [`MqttError::MalformedPacket`] if `string` contains U+0000, or
/// [`MqttError::StringTooLong`] if it is longer than 65 535 bytes. Nothing is
/// written to `buf` when an error is returned.
pub fn encode_string<B: bytes::BufMut>(buf: &mut B, string: &str) -> Result<()> {
    // `MqttString::create` is the single place that validates the input
    // (null character first, then length), so no separate check is needed here.
    let mqtt_string = MqttString::create(string)?;
    buf.put_slice(&mqtt_string.to_be_bytes());
    Ok(())
}
/// Reads one MQTT length-prefixed string from `buf`: a big-endian `u16` byte
/// length followed by that many bytes of UTF-8 data. On success `buf` is
/// advanced past the whole string.
///
/// # Errors
///
/// Returns [`MqttError::MalformedPacket`] when:
/// - fewer than 2 bytes remain for the length prefix,
/// - fewer bytes remain than the prefix declares,
/// - the payload is not valid UTF-8, or
/// - the decoded string contains U+0000.
///
/// If the payload is short, the 2-byte prefix has already been consumed from
/// `buf` by the time the error is returned.
pub fn decode_string<B: bytes::Buf>(buf: &mut B) -> Result<String> {
    if buf.remaining() < 2 {
        return Err(MqttError::MalformedPacket(
            "Insufficient bytes for string length".to_string(),
        ));
    }

    let declared = usize::from(buf.get_u16());
    let available = buf.remaining();
    if available < declared {
        return Err(MqttError::MalformedPacket(format!(
            "Insufficient bytes for string data: expected {declared}, got {available}"
        )));
    }

    let mut raw = vec![0u8; declared];
    buf.copy_to_slice(&mut raw);

    let decoded = String::from_utf8(raw)
        .map_err(|err| MqttError::MalformedPacket(format!("Invalid UTF-8: {err}")))?;

    if decoded.contains('\0') {
        Err(MqttError::MalformedPacket(
            "String contains null character".to_string(),
        ))
    } else {
        Ok(decoded)
    }
}
/// Number of bytes an MQTT length-prefixed string occupies on the wire: the
/// 2-byte length prefix plus the UTF-8 byte length of `string` (not its char
/// count).
#[must_use]
pub fn string_len(string: &str) -> usize {
    let prefix = core::mem::size_of::<u16>();
    string.len() + prefix
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mqtt_string_encoding() {
        let mqtt_str = MqttString::create("hello").unwrap();
        let bytes = mqtt_str.to_be_bytes();
        assert_eq!(bytes, vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o']);
    }

    #[test]
    fn test_mqtt_string_decoding() {
        let data = vec![0x00, 0x05, b'h', b'e', b'l', b'l', b'o'];
        let (mqtt_str, consumed) = MqttString::try_from_be_bytes(&data).unwrap();
        assert_eq!(mqtt_str.as_str(), "hello");
        assert_eq!(consumed, 7);
    }

    #[test]
    fn test_mqtt_string_round_trip() {
        let original = MqttString::create("test/topic").unwrap();
        let bytes = original.to_be_bytes();
        let (decoded, _) = MqttString::try_from_be_bytes(&bytes).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_empty_string() {
        let mqtt_str = MqttString::create("").unwrap();
        let bytes = mqtt_str.to_be_bytes();
        assert_eq!(bytes, vec![0x00, 0x00]);
    }

    #[test]
    fn test_string_too_long() {
        let long_string = "x".repeat(65536);
        let result = MqttString::create(&long_string);
        // Pin the specific error, not just "some error".
        assert!(matches!(result, Err(MqttError::StringTooLong(65536))));
    }

    #[test]
    fn test_max_length_string_accepted() {
        // Exactly u16::MAX bytes is the largest length the prefix can express.
        let max = "x".repeat(usize::from(u16::MAX));
        let mqtt_str = MqttString::create(&max).unwrap();
        assert_eq!(mqtt_str.encoded_size(), 2 + 65535);
    }

    #[test]
    fn test_null_character_rejected() {
        assert!(MqttString::create("a\0b").is_err());
        assert!(MqttString::try_from("\0".to_string()).is_err());

        let mut out = bytes::BytesMut::new();
        assert!(encode_string(&mut out, "a\0b").is_err());
        // A rejected string must not leave partial output behind.
        assert!(out.is_empty());
    }

    #[test]
    fn test_encode_decode_string_round_trip() {
        let mut out = bytes::BytesMut::new();
        encode_string(&mut out, "a/b").unwrap();
        assert_eq!(&out[..], b"\x00\x03a/b");
        assert_eq!(string_len("a/b"), out.len());

        let mut input: &[u8] = &out[..];
        assert_eq!(decode_string(&mut input).unwrap(), "a/b");
        // The whole string, prefix included, is consumed.
        assert!(input.is_empty());
    }

    #[test]
    fn test_decode_string_rejects_malformed_input() {
        let cases: [&[u8]; 4] = [
            &[0x00],             // length prefix truncated
            &[0x00, 0x05, b'h'], // payload shorter than declared length
            &[0x00, 0x01, 0xFF], // invalid UTF-8
            &[0x00, 0x01, 0x00], // embedded U+0000
        ];
        for &case in &cases {
            let mut input = case;
            assert!(decode_string(&mut input).is_err());
        }
    }
}