use crate::{Error, SingleValueWireFormat, WireFormat};
use byteorder::{ReadBytesExt, WriteBytesExt};
/// Selects bits 0-3 (the low nibble) of a byte.
const LOW_NIBBLE_MASK: u8 = 0x0F;
/// Selects bits 4-7 (the high nibble) of a byte.
const HIGH_NIBBLE_MASK: u8 = 0xF0;

// Per-field aliases so each conversion names the nibble it reads.

/// Nibble of a `MemoryFormatIdentifier` byte that carries the memory size length.
const MEMORY_SIZE_NIBBLE_MASK: u8 = HIGH_NIBBLE_MASK;
/// Nibble of a `MemoryFormatIdentifier` byte that carries the memory address length.
const MEMORY_ADDRESS_NIBBLE_MASK: u8 = LOW_NIBBLE_MASK;
/// Nibble of a `LengthFormatIdentifier` byte that carries the block length field.
const BLOCK_LENGTH_NIBBLE_MASK: u8 = HIGH_NIBBLE_MASK;
/// Nibble of a `DataFormatIdentifier` byte that carries the compression method.
const COMPRESSION_NIBBLE_MASK: u8 = HIGH_NIBBLE_MASK;
/// Nibble of a `DataFormatIdentifier` byte that carries the encryption method.
const ENCRYPTION_NIBBLE_MASK: u8 = LOW_NIBBLE_MASK;
/// Byte lengths of a memory size / memory address pair, as packed into one
/// format byte: the size length lives in the high nibble and the address
/// length in the low nibble.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct MemoryFormatIdentifier {
    /// Number of bytes used for the memory size value (high nibble).
    pub memory_size_length: u8,
    /// Number of bytes used for the memory address value (low nibble).
    pub memory_address_length: u8,
}
impl MemoryFormatIdentifier {
    /// Derives the smallest byte lengths able to hold `memory_size` and
    /// `memory_address`.
    ///
    /// Each length is the number of significant bits rounded up to whole
    /// bytes, with a floor of one byte: a value of zero still needs one byte
    /// on the wire, and a length of zero is rejected when the identifier byte
    /// is parsed back via `TryFrom<u8>`.
    ///
    /// NOTE(review): a `u64` address can need up to 8 bytes, which is more
    /// than the `TryFrom<u8>` conversion accepts — confirm whether callers
    /// are expected to keep addresses within the accepted range.
    // Both results are bounded by the byte width of the input type (at most
    // 4 and 8), so narrowing to `u8` cannot truncate.
    #[allow(clippy::cast_possible_truncation)]
    pub fn from_values(memory_size: u32, memory_address: u64) -> Self {
        let memory_address_length = (u64::BITS - memory_address.leading_zeros())
            .div_ceil(8)
            .max(1) as u8;
        let memory_size_length = (u32::BITS - memory_size.leading_zeros())
            .div_ceil(8)
            .max(1) as u8;
        Self {
            memory_size_length,
            memory_address_length,
        }
    }

    /// Total number of bytes described by this identifier: the size length
    /// plus the address length.
    pub fn len(self) -> usize {
        usize::from(self.memory_size_length) + usize::from(self.memory_address_length)
    }
}
impl TryFrom<u8> for MemoryFormatIdentifier {
type Error = Error;
fn try_from(value: u8) -> Result<Self, Error> {
let memory_size_length = (value & MEMORY_SIZE_NIBBLE_MASK) >> 4;
let memory_address_length = value & MEMORY_ADDRESS_NIBBLE_MASK;
match memory_size_length {
1..4 => (),
_ => return Err(Error::IncorrectMessageLengthOrInvalidFormat),
}
match memory_address_length {
1..5 => (),
_ => return Err(Error::IncorrectMessageLengthOrInvalidFormat),
}
Ok(Self {
memory_size_length,
memory_address_length: value & MEMORY_ADDRESS_NIBBLE_MASK,
})
}
}
impl From<MemoryFormatIdentifier> for u8 {
fn from(memory_format_identifier: MemoryFormatIdentifier) -> u8 {
(memory_format_identifier.memory_size_length << 4)
| memory_format_identifier.memory_address_length
}
}
/// Format byte whose high nibble carries `max_number_of_block_length`.
///
/// The low nibble is dropped when decoding from a byte and is zero when
/// encoding back to one.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) struct LengthFormatIdentifier {
    /// Value of the high nibble of the format byte.
    pub max_number_of_block_length: u8,
}
impl From<u8> for LengthFormatIdentifier {
    /// Extracts the high nibble of `value`; the low nibble is ignored.
    fn from(value: u8) -> Self {
        let high_nibble = value & BLOCK_LENGTH_NIBBLE_MASK;
        let max_number_of_block_length = high_nibble >> 4;
        Self {
            max_number_of_block_length,
        }
    }
}
impl From<LengthFormatIdentifier> for u8 {
fn from(length_format_identifier: LengthFormatIdentifier) -> u8 {
length_format_identifier.max_number_of_block_length << 4
}
}
/// Encryption and compression method pair packed into one byte: compression
/// in the high nibble, encryption in the low nibble.
///
/// `DataFormatIdentifier::new` only accepts methods in `0..=15`, so each
/// fits its nibble.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct DataFormatIdentifier {
    /// Encryption method (low nibble of the packed byte).
    encryption_method: u8,
    /// Compression method (high nibble of the packed byte).
    compression_method: u8,
}
impl DataFormatIdentifier {
pub fn new(encryption_method: u8, compression_method: u8) -> Result<Self, Error> {
Ok(Self {
encryption_method: Self::check_value(encryption_method)?,
compression_method: Self::check_value(compression_method)?,
})
}
fn check_value(value: u8) -> Result<u8, Error> {
match value {
0..=15 => Ok(value),
_ => Err(Error::InvalidEncryptionCompressionMethod(value)),
}
}
}
impl From<u8> for DataFormatIdentifier {
    /// Unpacks a byte: the low nibble becomes the encryption method and the
    /// high nibble the compression method. Every byte value is accepted.
    fn from(value: u8) -> Self {
        Self {
            encryption_method: value & ENCRYPTION_NIBBLE_MASK,
            compression_method: (value & COMPRESSION_NIBBLE_MASK) >> 4,
        }
    }
}
impl From<DataFormatIdentifier> for u8 {
fn from(data_format_identifier: DataFormatIdentifier) -> u8 {
data_format_identifier.encryption_method | (data_format_identifier.compression_method << 4)
}
}
impl PartialEq<u8> for DataFormatIdentifier {
fn eq(&self, other: &u8) -> bool {
let other_data_format_identifier = DataFormatIdentifier::from(*other);
self == &other_data_format_identifier
}
}
impl WireFormat for DataFormatIdentifier {
    /// Reads a single byte from `reader` and unpacks it.
    ///
    /// Always yields `Ok(Some(_))` when the byte can be read; a failed read
    /// is propagated as an error through `?`.
    fn decode<T: std::io::Read>(reader: &mut T) -> Result<Option<Self>, Error> {
        let raw = reader.read_u8()?;
        Ok(Some(Self::from(raw)))
    }

    /// The encoded form is always exactly one byte.
    fn required_size(&self) -> usize {
        1
    }

    /// Writes the packed byte to `writer` and returns the number of bytes
    /// written (always 1); a failed write is propagated through `?`.
    fn encode<T: std::io::Write>(&self, writer: &mut T) -> Result<usize, Error> {
        let raw: u8 = (*self).into();
        writer.write_u8(raw)?;
        Ok(1)
    }
}
// Empty impl: no trait items are overridden here.
// NOTE(review): presumably marks the type as decoding to exactly one value —
// confirm against the `SingleValueWireFormat` definition.
impl SingleValueWireFormat for DataFormatIdentifier {}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid byte splits into size (high nibble) and address (low nibble)
    /// lengths and packs back to the same byte.
    #[test]
    fn memory_format_identifier() {
        let parsed = MemoryFormatIdentifier::try_from(0x23).unwrap();
        assert_eq!(parsed.memory_size_length, 2);
        assert_eq!(parsed.memory_address_length, 3);
        assert_eq!(u8::from(parsed), 0x23);
    }

    /// A byte with both nibbles zero is rejected.
    #[test]
    fn failed_memory_format_identifier() {
        assert!(matches!(
            MemoryFormatIdentifier::try_from(0x00),
            Err(Error::IncorrectMessageLengthOrInvalidFormat)
        ));
    }

    /// The block length field round-trips through the high nibble.
    #[test]
    fn length_format_identifier() {
        let parsed = LengthFormatIdentifier::from(0xF0);
        assert_eq!(parsed.max_number_of_block_length, 15);
        assert_eq!(u8::from(parsed), 0xF0);
    }

    /// Byte unpacking, packing and constructor validation for
    /// `DataFormatIdentifier`.
    #[test]
    fn data_format_identifier() {
        let decoded = DataFormatIdentifier::from(0x23);
        assert_eq!(decoded.encryption_method, 3);
        assert_eq!(decoded.compression_method, 2);
        assert_eq!(u8::from(decoded), 0x23);

        // Largest values that still fit a nibble are accepted.
        assert!(DataFormatIdentifier::new(0x0F, 0x0F).is_ok());

        // An encryption method above 15 is reported back in the error.
        let rejected = DataFormatIdentifier::new(0x1F, 0x0F);
        assert!(matches!(
            rejected,
            Err(Error::InvalidEncryptionCompressionMethod(0x1F))
        ));
    }
}