use alloc::string::String;
use alloc::vec::Vec;
// `TryFrom` is only in the prelude from edition 2021 on; importing it explicitly keeps the
// checked length conversions building on older editions as well.
use core::convert::TryFrom;
use core::fmt;
/// AMQP 1.0 primitive-type format codes.
///
/// Constants are grouped by the high nibble of the code, which is what
/// `FormatCode::from_byte` uses to categorise a byte: `0x4`–`0x9` fixed width,
/// `0xA`/`0xB` variable width, `0xC`/`0xD` compound, `0xE`/`0xF` array.
/// Within each pair the first nibble carries a one-octet size prefix and the
/// second a four-octet one (see the `*8` / `*32` names).
pub mod codes {
    // High nibble 0x4: the value is fully described by the code byte itself.
    /// `null`.
    pub const NULL: u8 = 0x40;
    /// `boolean` `true`, with no payload.
    pub const BOOLEAN_TRUE: u8 = 0x41;
    /// `boolean` `false`, with no payload.
    pub const BOOLEAN_FALSE: u8 = 0x42;
    /// `uint` with value 0.
    pub const UINT0: u8 = 0x43;
    /// `ulong` with value 0.
    pub const ULONG0: u8 = 0x44;
    /// Empty `list`.
    pub const LIST0: u8 = 0x45;

    // High nibble 0x5: one-octet payload.
    /// `ubyte`.
    pub const UBYTE: u8 = 0x50;
    /// `byte`.
    pub const BYTE: u8 = 0x51;
    /// `uint` in the range 0..=255.
    pub const SMALLUINT: u8 = 0x52;
    /// `ulong` in the range 0..=255.
    pub const SMALLULONG: u8 = 0x53;
    /// `int` in the range -128..=127.
    pub const SMALLINT: u8 = 0x54;
    /// `long` in the range -128..=127.
    pub const SMALLLONG: u8 = 0x55;
    /// `boolean` carried in a one-octet payload.
    pub const BOOLEAN: u8 = 0x56;

    // High nibble 0x6.
    /// `ushort`.
    pub const USHORT: u8 = 0x60;
    /// `short`.
    pub const SHORT: u8 = 0x61;

    // High nibble 0x7.
    /// `uint`.
    pub const UINT: u8 = 0x70;
    /// `int`.
    pub const INT: u8 = 0x71;
    /// `float`.
    pub const FLOAT: u8 = 0x72;
    /// `char`.
    pub const CHAR: u8 = 0x73;
    /// `decimal32`.
    pub const DECIMAL32: u8 = 0x74;

    // High nibble 0x8: eight-octet payload.
    /// `ulong`.
    pub const ULONG: u8 = 0x80;
    /// `long`.
    pub const LONG: u8 = 0x81;
    /// `double`.
    pub const DOUBLE: u8 = 0x82;
    /// `timestamp`.
    pub const TIMESTAMP: u8 = 0x83;
    /// `decimal64`.
    pub const DECIMAL64: u8 = 0x84;

    // High nibble 0x9.
    /// `decimal128`.
    pub const DECIMAL128: u8 = 0x94;
    /// `uuid`.
    pub const UUID: u8 = 0x98;

    // High nibble 0xA: variable width, one-octet size prefix.
    /// `binary` up to 255 bytes.
    pub const VBIN8: u8 = 0xA0;
    /// UTF-8 `string` up to 255 bytes.
    pub const STR8: u8 = 0xA1;
    /// `symbol` up to 255 bytes.
    pub const SYM8: u8 = 0xA3;

    // High nibble 0xB: variable width, four-octet size prefix.
    /// `binary` with a 32-bit length.
    pub const VBIN32: u8 = 0xB0;
    /// UTF-8 `string` with a 32-bit length.
    pub const STR32: u8 = 0xB1;
    /// `symbol` with a 32-bit length.
    pub const SYM32: u8 = 0xB3;

    // High nibbles 0xC / 0xD: compound values.
    /// `list` with one-octet size/count.
    pub const LIST8: u8 = 0xC0;
    /// `map` with one-octet size/count.
    pub const MAP8: u8 = 0xC1;
    /// `list` with four-octet size/count.
    pub const LIST32: u8 = 0xD0;
    /// `map` with four-octet size/count.
    pub const MAP32: u8 = 0xD1;

    // High nibbles 0xE / 0xF: arrays.
    /// `array` with one-octet size/count.
    pub const ARRAY8: u8 = 0xE0;
    /// `array` with four-octet size/count.
    pub const ARRAY32: u8 = 0xF0;
}
/// Coarse category of a format-code byte, decided by its high nibble.
///
/// Every variant carries the original byte unchanged, so callers can still
/// match on the exact code after categorising it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FormatCode {
    /// High nibble `0x0`–`0x9`: fixed-width encodings (also the fallback for
    /// any byte below `0x40`; see [`FormatCode::from_byte`]).
    Fixed(u8),
    /// High nibble `0xA` or `0xB`: size-prefixed variable-width encodings.
    Variable(u8),
    /// High nibble `0xC` or `0xD`: compound encodings (lists and maps).
    Compound(u8),
    /// High nibble `0xE` or `0xF`: array encodings.
    Array(u8),
}
impl FormatCode {
    /// Categorises `b` by looking only at its upper four bits.
    ///
    /// NOTE(review): bytes below `0x40` are not primitive format codes
    /// (`0x00` introduces a described type in AMQP 1.0), yet they are reported
    /// as [`FormatCode::Fixed`] rather than rejected. Confirm callers never
    /// rely on this for validation.
    #[must_use]
    pub const fn from_byte(b: u8) -> Self {
        let high = b >> 4;
        if high >= 0xE {
            Self::Array(b)
        } else if high >= 0xC {
            Self::Compound(b)
        } else if high >= 0xA {
            Self::Variable(b)
        } else {
            Self::Fixed(b)
        }
    }
}
/// A decoded AMQP 1.0 primitive value, as returned by [`decode_value`].
///
/// Only the primitive types this module can decode are represented.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AmqpValue {
    /// `null`.
    Null,
    /// `boolean`, from the compact `0x41`/`0x42` codes or the one-octet `0x56` form.
    Boolean(bool),
    /// Unsigned 64-bit `ulong` (from `ulong0`, `smallulong` or `ulong`).
    Ulong(u64),
    /// Signed 64-bit `long` (from `smalllong` or `long`).
    Long(i64),
    /// Opaque `binary` bytes (from `vbin8` / `vbin32`).
    Binary(Vec<u8>),
    /// UTF-8 `string` (from `str8` / `str32`).
    String(String),
    /// `symbol` (from `sym8` / `sym32`). The encoder only accepts ASCII
    /// symbols; the decoder only checks that the payload is valid UTF-8.
    Symbol(String),
}
/// Errors produced while encoding or decoding AMQP primitive values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TypeError {
    /// The input was empty, or ended before the value announced by its format
    /// code (or size prefix) was complete.
    Truncated,
    /// The leading format code is not one this module decodes; carries the byte.
    UnsupportedFormatCode(u8),
    /// A string or symbol payload was not valid UTF-8.
    InvalidUtf8,
    /// A symbol handed to the encoder contained a non-ASCII character.
    NonAsciiSymbol,
    /// A payload was longer than `u32::MAX` bytes and cannot be size-prefixed.
    LengthTooLarge,
}
impl fmt::Display for TypeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except the format-code one has a fixed message, so pick
        // the text first and write it once.
        let message = match self {
            Self::UnsupportedFormatCode(c) => {
                return write!(f, "unsupported format code 0x{c:02X}");
            }
            Self::Truncated => "input truncated",
            Self::InvalidUtf8 => "invalid UTF-8 in str8/str32",
            Self::NonAsciiSymbol => "non-ASCII byte in symbol",
            Self::LengthTooLarge => "length exceeds u32::MAX",
        };
        f.write_str(message)
    }
}
// The `std::error::Error` impl is only available when the crate is built with `std`.
#[cfg(feature = "std")]
impl std::error::Error for TypeError {}
/// Encodes the AMQP `null` value: just its one-byte format code, no payload.
#[must_use]
pub fn encode_null() -> Vec<u8> {
    let mut out = Vec::with_capacity(1);
    out.push(codes::NULL);
    out
}
/// Encodes a boolean using the compact single-byte forms.
///
/// The value is carried entirely by the format code (`BOOLEAN_TRUE` or
/// `BOOLEAN_FALSE`), so no payload byte follows.
#[must_use]
pub fn encode_boolean(v: bool) -> Vec<u8> {
    let code = if v {
        codes::BOOLEAN_TRUE
    } else {
        codes::BOOLEAN_FALSE
    };
    alloc::vec![code]
}
/// Encodes an unsigned 64-bit integer using the most compact `ulong` form.
///
/// - `0` uses `ulong0` (code only).
/// - `1..=255` uses `smallulong` (code + one octet).
/// - Anything larger uses `ulong` (code + 8 big-endian octets).
#[must_use]
pub fn encode_ulong(v: u64) -> Vec<u8> {
    let be = v.to_be_bytes();
    match v {
        0 => alloc::vec![codes::ULONG0],
        // For values that fit in a byte, the last big-endian octet is the value.
        1..=0xFF => alloc::vec![codes::SMALLULONG, be[7]],
        _ => {
            let mut out = Vec::with_capacity(1 + be.len());
            out.push(codes::ULONG);
            out.extend_from_slice(&be);
            out
        }
    }
}
/// Encodes a signed 64-bit integer using the most compact `long` form.
///
/// Values in the `i8` range use `smalllong` (code + one two's-complement
/// octet). Everything else uses `long` (code + 8 big-endian octets).
#[must_use]
pub fn encode_long(v: i64) -> Vec<u8> {
    let be = v.to_be_bytes();
    match v {
        // `i8::MIN..=i8::MAX`: the last big-endian octet of the i64 is exactly the
        // value's two's-complement byte.
        -128..=127 => alloc::vec![codes::SMALLLONG, be[7]],
        _ => {
            let mut out = Vec::with_capacity(1 + be.len());
            out.push(codes::LONG);
            out.extend_from_slice(&be);
            out
        }
    }
}
pub fn encode_binary(data: &[u8]) -> Result<Vec<u8>, TypeError> {
let len = data.len();
if len > u32::MAX as usize {
return Err(TypeError::LengthTooLarge);
}
if len <= u8::MAX as usize {
let mut out = Vec::with_capacity(2 + len);
out.push(codes::VBIN8);
#[allow(clippy::cast_possible_truncation)]
out.push(len as u8);
out.extend_from_slice(data);
Ok(out)
} else {
let mut out = Vec::with_capacity(5 + len);
out.push(codes::VBIN32);
#[allow(clippy::cast_possible_truncation)]
out.extend_from_slice(&(len as u32).to_be_bytes());
out.extend_from_slice(data);
Ok(out)
}
}
pub fn encode_string(s: &str) -> Result<Vec<u8>, TypeError> {
let bytes = s.as_bytes();
let len = bytes.len();
if len > u32::MAX as usize {
return Err(TypeError::LengthTooLarge);
}
if len <= u8::MAX as usize {
let mut out = Vec::with_capacity(2 + len);
out.push(codes::STR8);
#[allow(clippy::cast_possible_truncation)]
out.push(len as u8);
out.extend_from_slice(bytes);
Ok(out)
} else {
let mut out = Vec::with_capacity(5 + len);
out.push(codes::STR32);
#[allow(clippy::cast_possible_truncation)]
out.extend_from_slice(&(len as u32).to_be_bytes());
out.extend_from_slice(bytes);
Ok(out)
}
}
pub fn encode_symbol(s: &str) -> Result<Vec<u8>, TypeError> {
if !s.is_ascii() {
return Err(TypeError::NonAsciiSymbol);
}
let bytes = s.as_bytes();
let len = bytes.len();
if len > u32::MAX as usize {
return Err(TypeError::LengthTooLarge);
}
if len <= u8::MAX as usize {
let mut out = Vec::with_capacity(2 + len);
out.push(codes::SYM8);
#[allow(clippy::cast_possible_truncation)]
out.push(len as u8);
out.extend_from_slice(bytes);
Ok(out)
} else {
let mut out = Vec::with_capacity(5 + len);
out.push(codes::SYM32);
#[allow(clippy::cast_possible_truncation)]
out.extend_from_slice(&(len as u32).to_be_bytes());
out.extend_from_slice(bytes);
Ok(out)
}
}
/// Decodes one AMQP 1.0 primitive value from the start of `bytes`.
///
/// Returns the value together with the number of bytes it occupied (format
/// code, any size prefix, and payload). Bytes after the first value are
/// ignored, so callers can advance by the returned count to decode a sequence.
///
/// # Errors
///
/// - [`TypeError::Truncated`] if `bytes` is empty, or ends before the value
///   announced by its format code / size prefix is complete.
/// - [`TypeError::UnsupportedFormatCode`] for any code other than null,
///   boolean, ulong, long, binary, string or symbol encodings.
/// - [`TypeError::InvalidUtf8`] if a string or symbol payload is not UTF-8.
pub fn decode_value(bytes: &[u8]) -> Result<(AmqpValue, usize), TypeError> {
    let &code = bytes.first().ok_or(TypeError::Truncated)?;
    match code {
        codes::NULL => Ok((AmqpValue::Null, 1)),
        codes::BOOLEAN_TRUE => Ok((AmqpValue::Boolean(true), 1)),
        codes::BOOLEAN_FALSE => Ok((AmqpValue::Boolean(false), 1)),
        codes::BOOLEAN => {
            // NOTE(review): the spec only defines 0x00/0x01 here; any non-zero octet
            // is leniently accepted as `true`.
            let [flag] = fixed_payload::<1>(bytes)?;
            Ok((AmqpValue::Boolean(flag != 0), 2))
        }
        codes::ULONG0 => Ok((AmqpValue::Ulong(0), 1)),
        codes::SMALLULONG => {
            let [value] = fixed_payload::<1>(bytes)?;
            Ok((AmqpValue::Ulong(u64::from(value)), 2))
        }
        codes::ULONG => {
            let raw = fixed_payload::<8>(bytes)?;
            Ok((AmqpValue::Ulong(u64::from_be_bytes(raw)), 9))
        }
        codes::SMALLLONG => {
            // Reinterpret the octet as two's complement, then sign-extend to i64.
            let raw = fixed_payload::<1>(bytes)?;
            Ok((AmqpValue::Long(i64::from(i8::from_be_bytes(raw))), 2))
        }
        codes::LONG => {
            let raw = fixed_payload::<8>(bytes)?;
            Ok((AmqpValue::Long(i64::from_be_bytes(raw)), 9))
        }
        codes::VBIN8 => {
            let (payload, consumed) = sized8_payload(bytes)?;
            Ok((AmqpValue::Binary(payload.to_vec()), consumed))
        }
        codes::VBIN32 => {
            let (payload, consumed) = sized32_payload(bytes)?;
            Ok((AmqpValue::Binary(payload.to_vec()), consumed))
        }
        codes::STR8 => decode_str8(bytes, AmqpValue::String),
        codes::STR32 => decode_str32(bytes, AmqpValue::String),
        codes::SYM8 => decode_str8(bytes, AmqpValue::Symbol),
        codes::SYM32 => decode_str32(bytes, AmqpValue::Symbol),
        other => Err(TypeError::UnsupportedFormatCode(other)),
    }
}

/// Returns the `N` octets that follow the format code (fixed-width payloads and
/// size prefixes), or [`TypeError::Truncated`] if fewer are available.
fn fixed_payload<const N: usize>(bytes: &[u8]) -> Result<[u8; N], TypeError> {
    let src = bytes.get(1..=N).ok_or(TypeError::Truncated)?;
    let mut buf = [0u8; N];
    buf.copy_from_slice(src);
    Ok(buf)
}

/// Splits off the payload of a value with a one-octet size prefix
/// (`vbin8`, `str8`, `sym8`). Returns the payload and total bytes consumed.
fn sized8_payload(bytes: &[u8]) -> Result<(&[u8], usize), TypeError> {
    let [len] = fixed_payload::<1>(bytes)?;
    take_payload(bytes, 2, usize::from(len))
}

/// Splits off the payload of a value with a four-octet big-endian size prefix
/// (`vbin32`, `str32`, `sym32`). Returns the payload and total bytes consumed.
fn sized32_payload(bytes: &[u8]) -> Result<(&[u8], usize), TypeError> {
    let raw = u32::from_be_bytes(fixed_payload::<4>(bytes)?);
    // A size that does not fit in `usize` can never be backed by an in-memory
    // slice. Report it as truncation instead of letting an `as` cast cut it down
    // (which would happen on 16-bit targets).
    let len = usize::try_from(raw).map_err(|_| TypeError::Truncated)?;
    take_payload(bytes, 5, len)
}

/// Takes `len` payload bytes starting at offset `header`.
///
/// Bounds are checked with `get` rather than by comparing `bytes.len()` against
/// `header + len`. For a hostile size prefix such as `0xFFFF_FFFF` that sum
/// overflows `usize` on 32-bit targets, turning a truncated input into a
/// slice-index panic.
fn take_payload(bytes: &[u8], header: usize, len: usize) -> Result<(&[u8], usize), TypeError> {
    let payload = bytes
        .get(header..)
        .and_then(|rest| rest.get(..len))
        .ok_or(TypeError::Truncated)?;
    // Cannot overflow: `payload` existing proves `header + len <= bytes.len()`.
    Ok((payload, header + len))
}

/// Copies a UTF-8 payload into an owned `String`.
fn utf8_string(payload: &[u8]) -> Result<String, TypeError> {
    core::str::from_utf8(payload)
        .map(String::from)
        .map_err(|_| TypeError::InvalidUtf8)
}

/// Decodes a `str8`/`sym8` value; `wrap` selects the resulting [`AmqpValue`] variant.
///
/// NOTE(review): symbol payloads are only checked for UTF-8 here, not for ASCII.
fn decode_str8(
    bytes: &[u8],
    wrap: fn(String) -> AmqpValue,
) -> Result<(AmqpValue, usize), TypeError> {
    let (payload, consumed) = sized8_payload(bytes)?;
    Ok((wrap(utf8_string(payload)?), consumed))
}

/// Decodes a `str32`/`sym32` value; `wrap` selects the resulting [`AmqpValue`] variant.
///
/// NOTE(review): symbol payloads are only checked for UTF-8 here, not for ASCII.
fn decode_str32(
    bytes: &[u8],
    wrap: fn(String) -> AmqpValue,
) -> Result<(AmqpValue, usize), TypeError> {
    let (payload, consumed) = sized32_payload(bytes)?;
    Ok((wrap(utf8_string(payload)?), consumed))
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;
    #[test]
    fn null_encodes_to_single_byte_0x40() {
        assert_eq!(encode_null(), alloc::vec![0x40]);
    }
    #[test]
    fn boolean_uses_compact_format_codes() {
        assert_eq!(encode_boolean(true), alloc::vec![0x41]);
        assert_eq!(encode_boolean(false), alloc::vec![0x42]);
    }
    #[test]
    fn ulong_zero_uses_ulong0_format() {
        assert_eq!(encode_ulong(0), alloc::vec![0x44]);
    }
    #[test]
    fn ulong_small_uses_smallulong_format() {
        assert_eq!(encode_ulong(255), alloc::vec![0x53, 0xFF]);
    }
    #[test]
    fn ulong_large_uses_full_8_byte_format() {
        let bytes = encode_ulong(0x1122_3344_5566_7788);
        assert_eq!(bytes[0], 0x80);
        assert_eq!(&bytes[1..], &0x1122_3344_5566_7788_u64.to_be_bytes());
    }
    #[test]
    fn ulong_boundary_switches_to_full_format_at_256() {
        assert_eq!(encode_ulong(256)[0], 0x80);
        for v in [255u64, 256] {
            let bytes = encode_ulong(v);
            assert_eq!(decode_value(&bytes), Ok((AmqpValue::Ulong(v), bytes.len())));
        }
    }
    #[test]
    fn long_small_uses_smalllong_format() {
        assert_eq!(encode_long(-1), alloc::vec![0x55, 0xFF]);
        assert_eq!(encode_long(127), alloc::vec![0x55, 0x7F]);
    }
    #[test]
    fn long_large_uses_full_8_byte_format() {
        let bytes = encode_long(i64::MIN);
        assert_eq!(bytes[0], 0x81);
    }
    #[test]
    fn long_boundaries_round_trip() {
        assert_eq!(encode_long(-128), alloc::vec![0x55, 0x80]);
        assert_eq!(encode_long(128)[0], 0x81);
        assert_eq!(encode_long(-129)[0], 0x81);
        for v in [-129i64, -128, 127, 128] {
            let bytes = encode_long(v);
            assert_eq!(decode_value(&bytes), Ok((AmqpValue::Long(v), bytes.len())));
        }
    }
    #[test]
    fn binary_short_uses_vbin8_format() {
        let bytes = encode_binary(&[1, 2, 3]).expect("encode");
        assert_eq!(bytes, alloc::vec![0xA0, 0x03, 1, 2, 3]);
    }
    #[test]
    fn binary_long_uses_vbin32_format() {
        let data = alloc::vec![0xAA; 300];
        let bytes = encode_binary(&data).expect("encode");
        assert_eq!(bytes[0], 0xB0);
        assert_eq!(&bytes[1..5], &300u32.to_be_bytes());
    }
    #[test]
    fn binary_switches_to_vbin32_above_255_bytes() {
        let short = encode_binary(&[0u8; 255]).expect("encode");
        assert_eq!(short[0], 0xA0);
        assert_eq!(short[1], 0xFF);
        assert_eq!(short.len(), 257);
        let long = encode_binary(&[0u8; 256]).expect("encode");
        assert_eq!(long[0], 0xB0);
        assert_eq!(&long[1..5], &256u32.to_be_bytes());
        assert_eq!(long.len(), 261);
    }
    #[test]
    fn string_short_uses_str8_format() {
        let bytes = encode_string("hi").expect("encode");
        assert_eq!(bytes, alloc::vec![0xA1, 0x02, b'h', b'i']);
    }
    #[test]
    fn string_unicode_round_trip() {
        let bytes = encode_string("Käfer").expect("encode");
        let (parsed, consumed) = decode_value(&bytes).expect("decode");
        assert_eq!(consumed, bytes.len());
        match parsed {
            AmqpValue::String(s) => assert_eq!(s, "Käfer"),
            _ => panic!("expected string"),
        }
    }
    #[test]
    fn long_string_and_symbol_use_32_bit_forms_and_round_trip() {
        let text = "x".repeat(300);
        let bytes = encode_string(&text).expect("encode");
        assert_eq!(bytes[0], 0xB1);
        assert_eq!(decode_value(&bytes), Ok((AmqpValue::String(text), bytes.len())));

        let name = "s".repeat(256);
        let bytes = encode_symbol(&name).expect("encode");
        assert_eq!(bytes[0], 0xB3);
        assert_eq!(decode_value(&bytes), Ok((AmqpValue::Symbol(name), bytes.len())));
    }
    #[test]
    fn symbol_rejects_non_ascii() {
        assert_eq!(encode_symbol("Käfer"), Err(TypeError::NonAsciiSymbol));
    }
    #[test]
    fn symbol_short_uses_sym8_format() {
        let bytes = encode_symbol("hello").expect("encode");
        assert_eq!(bytes[0], 0xA3);
        assert_eq!(bytes[1], 0x05);
        assert_eq!(&bytes[2..], b"hello");
    }
    #[test]
    fn round_trip_all_primitive_values() {
        let values = alloc::vec![
            (encode_null(), AmqpValue::Null),
            (encode_boolean(true), AmqpValue::Boolean(true)),
            (encode_boolean(false), AmqpValue::Boolean(false)),
            (encode_ulong(0), AmqpValue::Ulong(0)),
            (encode_ulong(42), AmqpValue::Ulong(42)),
            (
                encode_ulong(0x1234_5678_9ABC_DEF0),
                AmqpValue::Ulong(0x1234_5678_9ABC_DEF0)
            ),
            (encode_long(-100), AmqpValue::Long(-100)),
            (encode_long(i64::MIN), AmqpValue::Long(i64::MIN)),
            (
                encode_binary(&[1, 2, 3]).expect("ok"),
                AmqpValue::Binary(alloc::vec![1, 2, 3])
            ),
            (
                encode_binary(&alloc::vec![0u8; 500]).expect("ok"),
                AmqpValue::Binary(alloc::vec![0u8; 500])
            ),
            (
                encode_string("foo").expect("ok"),
                AmqpValue::String("foo".into())
            ),
            (
                encode_symbol("bar").expect("ok"),
                AmqpValue::Symbol("bar".into())
            ),
        ];
        for (bytes, expected) in values {
            let (parsed, consumed) = decode_value(&bytes).expect("decode");
            assert_eq!(parsed, expected);
            assert_eq!(consumed, bytes.len());
        }
    }
    #[test]
    fn boolean_one_octet_form_decodes() {
        assert_eq!(decode_value(&[0x56, 0x00]), Ok((AmqpValue::Boolean(false), 2)));
        assert_eq!(decode_value(&[0x56, 0x01]), Ok((AmqpValue::Boolean(true), 2)));
        assert_eq!(decode_value(&[0x56]), Err(TypeError::Truncated));
    }
    #[test]
    fn decode_ignores_trailing_bytes() {
        assert_eq!(decode_value(&[0x40, 0x41]), Ok((AmqpValue::Null, 1)));
    }
    #[test]
    fn unsupported_format_code_yields_error() {
        assert_eq!(
            decode_value(&[0xFF]),
            Err(TypeError::UnsupportedFormatCode(0xFF))
        );
    }
    #[test]
    fn truncated_inputs_yield_error() {
        assert_eq!(decode_value(&[]), Err(TypeError::Truncated));
        assert_eq!(decode_value(&[0xA0]), Err(TypeError::Truncated));
        assert_eq!(decode_value(&[0xA0, 5, 1]), Err(TypeError::Truncated));
        assert_eq!(decode_value(&[0x80, 0, 0, 0]), Err(TypeError::Truncated));
        assert_eq!(decode_value(&[0xB1, 0, 0]), Err(TypeError::Truncated));
    }
    #[test]
    fn oversized_length_prefix_yields_truncated_not_panic() {
        assert_eq!(
            decode_value(&[0xB0, 0xFF, 0xFF, 0xFF, 0xFF]),
            Err(TypeError::Truncated)
        );
        assert_eq!(
            decode_value(&[0xB1, 0xFF, 0xFF, 0xFF, 0xFF, b'a']),
            Err(TypeError::Truncated)
        );
        assert_eq!(
            decode_value(&[0xB3, 0xFF, 0xFF, 0xFF, 0xFC, 0, 0]),
            Err(TypeError::Truncated)
        );
    }
    #[test]
    fn invalid_utf8_in_str_yields_error() {
        assert_eq!(
            decode_value(&[0xA1, 0x01, 0xFF]),
            Err(TypeError::InvalidUtf8)
        );
    }
    #[test]
    fn format_code_categorizes_correctly() {
        assert!(matches!(FormatCode::from_byte(0x40), FormatCode::Fixed(_)));
        assert!(matches!(FormatCode::from_byte(0x80), FormatCode::Fixed(_)));
        assert!(matches!(
            FormatCode::from_byte(0xA0),
            FormatCode::Variable(_)
        ));
        assert!(matches!(
            FormatCode::from_byte(0xB0),
            FormatCode::Variable(_)
        ));
        assert!(matches!(
            FormatCode::from_byte(0xC0),
            FormatCode::Compound(_)
        ));
        assert!(matches!(
            FormatCode::from_byte(0xD0),
            FormatCode::Compound(_)
        ));
        assert!(matches!(FormatCode::from_byte(0xE0), FormatCode::Array(_)));
        assert!(matches!(FormatCode::from_byte(0xF0), FormatCode::Array(_)));
    }
}