use crate::mqtt::result_code::MqttError;
use alloc::string::String;
use alloc::vec::Vec;
use serde::{Serialize, Serializer};
#[cfg(feature = "std")]
use std::io::IoSlice;
/// Capacity in bytes of the inline `MqttString::Small` buffer, including its 2-byte length
/// prefix.
///
/// Selected from the enabled SSO features, largest first:
/// `sso-lv20` → 48, `sso-lv10` or `sso-min-64bit` → 24, `sso-min-32bit` → 12,
/// and 0 when no SSO feature is enabled (the `Small` variant is then compiled out).
#[allow(dead_code)] // Only referenced when an SSO feature enables the `Small` variant.
const SSO_BUFFER_SIZE: usize = if cfg!(feature = "sso-lv20") {
    48
} else if cfg!(any(feature = "sso-lv10", feature = "sso-min-64bit")) {
    24
} else if cfg!(feature = "sso-min-32bit") {
    12
} else {
    0
};

/// Longest payload, in bytes, that is stored inline: the buffer minus its 2-byte prefix.
#[cfg(any(
    feature = "sso-min-32bit",
    feature = "sso-min-64bit",
    feature = "sso-lv10",
    feature = "sso-lv20"
))]
const SSO_DATA_THRESHOLD: usize = SSO_BUFFER_SIZE - 2;
/// An MQTT "UTF-8 Encoded String" kept in its wire form: a 2-byte big-endian length
/// prefix followed by the UTF-8 payload.
///
/// NOTE(review): the derived `PartialEq`/`Ord` compare the raw encoded bytes, and derived
/// `Ord` ranks variants by declaration order (`Small` before `Large`). Because the
/// big-endian length prefix comes first, values built by `new`/`decode` sort by payload
/// length first and only then byte-wise (e.g. `"b"` < `"aa"`), which differs from `str`
/// ordering. Do not reorder the variants without checking callers that sort or use ordered
/// collections.
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)]
// `Small` can be larger than the `Vec` header (up to 48 bytes vs. three words); keeping it
// inline is the point of the small-string optimization, so the lint is silenced on purpose.
#[allow(clippy::large_enum_variant)]
pub enum MqttString {
    /// Inline storage used for payloads of at most `SSO_DATA_THRESHOLD` bytes:
    /// `[len_hi, len_lo, payload..., zero padding]`.
    #[cfg(any(
        feature = "sso-min-32bit",
        feature = "sso-min-64bit",
        feature = "sso-lv10",
        feature = "sso-lv20"
    ))]
    Small([u8; SSO_BUFFER_SIZE]),
    /// Heap storage holding exactly the length prefix followed by the payload.
    Large(Vec<u8>),
}
impl MqttString {
    /// Creates an MQTT UTF-8 Encoded String holding `s`.
    ///
    /// The value is stored in wire form: a 2-byte big-endian length prefix followed by the
    /// UTF-8 bytes of `s`. With an SSO feature enabled, payloads of at most
    /// `SSO_DATA_THRESHOLD` bytes are kept inline in the `Small` variant; everything else is
    /// heap-allocated in `Large`.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::MalformedPacket`] if `s` is longer than 65535 bytes, the largest
    /// length the 2-byte prefix can express.
    pub fn new(s: impl AsRef<str>) -> Result<Self, MqttError> {
        let payload = s.as_ref().as_bytes();
        if payload.len() > usize::from(u16::MAX) {
            return Err(MqttError::MalformedPacket);
        }
        Ok(Self::from_payload(payload))
    }

    /// Builds the stored representation for an already-validated payload (UTF-8, at most
    /// 65535 bytes), choosing inline storage when an SSO feature allows it.
    fn from_payload(payload: &[u8]) -> Self {
        let len = payload.len();
        debug_assert!(len <= usize::from(u16::MAX));
        // Callers guarantee `len <= u16::MAX`, so this cast cannot truncate.
        let prefix = (len as u16).to_be_bytes();
        #[cfg(any(
            feature = "sso-min-32bit",
            feature = "sso-min-64bit",
            feature = "sso-lv10",
            feature = "sso-lv20"
        ))]
        if len <= SSO_DATA_THRESHOLD {
            let mut buffer = [0u8; SSO_BUFFER_SIZE];
            buffer[..2].copy_from_slice(&prefix);
            buffer[2..2 + len].copy_from_slice(payload);
            return Self::Small(buffer);
        }
        let mut encoded = Vec::with_capacity(2 + len);
        encoded.extend_from_slice(&prefix);
        encoded.extend_from_slice(payload);
        Self::Large(encoded)
    }

    /// Returns the full wire encoding: the 2-byte big-endian length prefix followed by the
    /// payload bytes.
    pub fn as_bytes(&self) -> &[u8] {
        match self {
            MqttString::Large(encoded) => encoded,
            #[cfg(any(
                feature = "sso-min-32bit",
                feature = "sso-min-64bit",
                feature = "sso-lv10",
                feature = "sso-lv20"
            ))]
            MqttString::Small(buffer) => {
                // Only the first `2 + len` bytes are meaningful; the rest is zero padding.
                let len = usize::from(u16::from_be_bytes([buffer[0], buffer[1]]));
                &buffer[..2 + len]
            }
        }
    }

    /// Returns the payload as text, without the length prefix.
    ///
    /// # Panics
    ///
    /// Panics if the stored payload is not valid UTF-8. Values produced by
    /// [`MqttString::new`] or [`MqttString::decode`] always hold valid UTF-8; only a
    /// `Small`/`Large` variant constructed directly from arbitrary bytes can trigger this.
    pub fn as_str(&self) -> &str {
        // The enum's variants are public, so safe code can build one holding arbitrary bytes.
        // `from_utf8_unchecked` would make that undefined behavior; validate instead.
        core::str::from_utf8(&self.as_bytes()[2..])
            .expect("MqttString payload is not valid UTF-8")
    }

    /// Payload length in bytes, excluding the 2-byte length prefix.
    pub fn len(&self) -> usize {
        self.size() - 2
    }

    /// Returns `true` when the payload is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total encoded size in bytes: the 2-byte prefix plus the payload.
    pub fn size(&self) -> usize {
        self.as_bytes().len()
    }

    /// Returns the wire encoding as a single `IoSlice` for vectored writes.
    #[cfg(feature = "std")]
    pub fn to_buffers(&self) -> Vec<IoSlice<'_>> {
        vec![IoSlice::new(self.as_bytes())]
    }

    /// Copies the wire encoding (prefix + payload) into a new `Vec`.
    pub fn to_continuous_buffer(&self) -> Vec<u8> {
        self.as_bytes().to_vec()
    }

    /// Decodes an MQTT UTF-8 Encoded String from the start of `data`.
    ///
    /// Returns the string together with the number of bytes consumed (`2 + payload length`);
    /// bytes after that are ignored.
    ///
    /// # Errors
    ///
    /// Returns [`MqttError::MalformedPacket`] if:
    ///
    /// - `data` is shorter than the 2-byte prefix or than the length the prefix announces;
    /// - the payload is not well-formed UTF-8;
    /// - the payload contains U+0000, which MQTT 3.1.1 §1.5.3 and MQTT 5.0 §1.5.4 forbid in
    ///   UTF-8 Encoded Strings.
    pub fn decode(data: &[u8]) -> Result<(Self, usize), MqttError> {
        if data.len() < 2 {
            return Err(MqttError::MalformedPacket);
        }
        let payload_len = usize::from(u16::from_be_bytes([data[0], data[1]]));
        let consumed = 2 + payload_len;
        let payload = data.get(2..consumed).ok_or(MqttError::MalformedPacket)?;
        // In UTF-8 the byte 0x00 only ever encodes U+0000, so a byte search is sufficient.
        if core::str::from_utf8(payload).is_err() || payload.contains(&0) {
            return Err(MqttError::MalformedPacket);
        }
        Ok((Self::from_payload(payload), consumed))
    }

    /// Returns `true` if the payload contains `c`.
    pub fn contains(&self, c: char) -> bool {
        self.as_str().contains(c)
    }

    /// Returns `true` if the payload starts with `prefix`.
    pub fn starts_with(&self, prefix: &str) -> bool {
        self.as_str().starts_with(prefix)
    }

    /// Returns `true` if the payload ends with `suffix`.
    pub fn ends_with(&self, suffix: &str) -> bool {
        self.as_str().ends_with(suffix)
    }
}
impl AsRef<str> for MqttString {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl core::fmt::Display for MqttString {
    /// Writes the payload text, without the length prefix.
    ///
    /// Delegates to `str`'s `Display`, so width, alignment and precision flags (for example
    /// `{:>10}` or `{:.3}`) are honored exactly as for a plain `&str`. The previous
    /// `write!(f, "{}", ..)` form silently discarded them.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        core::fmt::Display::fmt(self.as_str(), f)
    }
}
impl core::ops::Deref for MqttString {
type Target = str;
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
impl Serialize for MqttString {
    /// Serializes as a plain string: the payload text only, without the length prefix.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let text = self.as_str();
        serializer.serialize_str(text)
    }
}
impl core::cmp::PartialEq<str> for MqttString {
fn eq(&self, other: &str) -> bool {
self.as_str() == other
}
}
impl core::cmp::PartialEq<&str> for MqttString {
    /// Compares the payload text with a string slice, enabling `mqtt_string == "literal"`.
    fn eq(&self, other: &&str) -> bool {
        let rhs: &str = other;
        rhs == self.as_str()
    }
}
impl core::cmp::PartialEq<String> for MqttString {
fn eq(&self, other: &String) -> bool {
self.as_str() == other.as_str()
}
}
impl core::hash::Hash for MqttString {
    /// Hashes exactly like the payload `str`, independent of which variant stores it.
    fn hash<H>(&self, state: &mut H)
    where
        H: core::hash::Hasher,
    {
        let text: &str = self.as_str();
        core::hash::Hash::hash(text, state);
    }
}
impl Default for MqttString {
    /// Returns the empty string, whose wire form is `[0x00, 0x00]`.
    fn default() -> Self {
        // `new` only fails for payloads over 65535 bytes, which "" can never be.
        Self::new("").expect("an empty string is always within the MQTT length limit")
    }
}
impl core::fmt::Debug for MqttString {
    /// Formats as `MqttString { value: "..." }`, showing only the payload text.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let text = self.as_str();
        let mut builder = f.debug_struct("MqttString");
        builder.field("value", &text);
        builder.finish()
    }
}
impl TryFrom<&str> for MqttString {
type Error = MqttError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
Self::new(s)
}
}
impl TryFrom<String> for MqttString {
type Error = MqttError;
fn try_from(s: String) -> Result<Self, Self::Error> {
Self::new(s)
}
}
impl TryFrom<&String> for MqttString {
type Error = MqttError;
fn try_from(s: &String) -> Result<Self, Self::Error> {
Self::new(s.as_str())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_string() {
        let string = MqttString::new("").unwrap();
        assert_eq!(string.len(), 0);
        assert!(string.is_empty());
        assert_eq!(string.as_str(), "");
        assert_eq!(string.as_bytes(), &[0x00, 0x00]);
    }

    #[test]
    fn test_small_string() {
        let string = MqttString::new("hello").unwrap();
        assert_eq!(string.len(), 5);
        assert!(!string.is_empty());
        assert_eq!(string.as_str(), "hello");
        assert_eq!(
            string.as_bytes(),
            &[0x00, 0x05, b'h', b'e', b'l', b'l', b'o']
        );
    }

    #[cfg(feature = "std")]
    #[test]
    fn test_to_buffers() {
        let data = "buffer test";
        let string = MqttString::new(data).unwrap();
        let buffers = string.to_buffers();
        assert_eq!(buffers.len(), 1);
        let buffer_data: &[u8] = &buffers[0];
        assert_eq!(buffer_data, string.as_bytes());
    }

    #[test]
    fn test_string_variants() {
        let small_data = "small";
        let string = MqttString::new(small_data).unwrap();
        #[cfg(any(
            feature = "sso-min-32bit",
            feature = "sso-min-64bit",
            feature = "sso-lv10",
            feature = "sso-lv20"
        ))]
        assert!(matches!(string, MqttString::Small(_)));
        #[cfg(not(any(
            feature = "sso-min-32bit",
            feature = "sso-min-64bit",
            feature = "sso-lv10",
            feature = "sso-lv20"
        )))]
        assert!(matches!(string, MqttString::Large(_)));

        // 30 bytes: above the 22- and 10-byte thresholds, within the 46-byte `sso-lv20` one.
        let medium_data = "abcdefghijklmnopqrstuvwxyz0123";
        assert_eq!(medium_data.len(), 30);
        let string = MqttString::new(medium_data).unwrap();
        #[cfg(feature = "sso-lv20")]
        assert!(matches!(string, MqttString::Small(_)));
        #[cfg(not(feature = "sso-lv20"))]
        assert!(matches!(string, MqttString::Large(_)));

        let very_large_data = "This is a very long string that exceeds even the largest SSO buffer size to ensure it's always stored in the Large variant";
        let string = MqttString::new(very_large_data).unwrap();
        assert!(matches!(string, MqttString::Large(_)));
    }

    #[cfg(any(
        feature = "sso-min-32bit",
        feature = "sso-min-64bit",
        feature = "sso-lv10",
        feature = "sso-lv20"
    ))]
    #[test]
    fn test_sso_threshold_boundary() {
        let at_limit = "x".repeat(SSO_DATA_THRESHOLD);
        let string = MqttString::new(&at_limit).unwrap();
        assert!(matches!(string, MqttString::Small(_)));
        assert_eq!(string.as_str(), at_limit);

        let over_limit = "x".repeat(SSO_DATA_THRESHOLD + 1);
        let string = MqttString::new(&over_limit).unwrap();
        assert!(matches!(string, MqttString::Large(_)));
        assert_eq!(string.len(), SSO_DATA_THRESHOLD + 1);
    }

    #[test]
    fn test_length_limit() {
        let max = "a".repeat(65535);
        assert_eq!(MqttString::new(&max).unwrap().len(), 65535);
        let over = "a".repeat(65536);
        assert!(MqttString::new(&over).is_err());
    }

    #[test]
    fn test_decode_round_trip() {
        let samples = [
            String::new(),
            String::from("a"),
            String::from("topic/with/levels"),
            String::from("caf\u{e9}/\u{1F600}"),
            "x".repeat(100),
        ];
        for text in &samples {
            let original = MqttString::new(text).unwrap();
            let mut wire = original.to_continuous_buffer();
            // Trailing bytes must be left untouched and not counted as consumed.
            wire.extend_from_slice(b"tail");
            let (decoded, consumed) = MqttString::decode(&wire).unwrap();
            assert_eq!(consumed, original.size());
            assert_eq!(decoded, original);
            assert_eq!(decoded, *text);
        }
    }

    #[test]
    fn test_decode_rejects_malformed_input() {
        // Missing or partial length prefix.
        assert!(MqttString::decode(&[]).is_err());
        assert!(MqttString::decode(&[0x00]).is_err());
        // Prefix announces more bytes than are present.
        assert!(MqttString::decode(&[0x00, 0x03, b'a', b'b']).is_err());
        // Ill-formed UTF-8 (0xC3 followed by a non-continuation byte).
        assert!(MqttString::decode(&[0x00, 0x02, 0xC3, 0x28]).is_err());
    }

    #[test]
    fn test_default_and_comparisons() {
        assert_eq!(MqttString::default(), MqttString::new("").unwrap());
        let string = MqttString::new("abc").unwrap();
        assert!(string == "abc");
        assert!(string != "abd");
        assert!(string == String::from("abc"));
        assert_eq!(&*string, "abc");
        assert!(string.contains('b'));
        assert!(string.starts_with("ab"));
        assert!(string.ends_with("bc"));
    }
}