use byteorder::{ReadBytesExt, WriteBytesExt};
use crate::{
DataFormatIdentifier, Error, LengthFormatIdentifier, MemoryFormatIdentifier,
NegativeResponseCode, SingleValueWireFormat, WireFormat,
};
/// Negative response codes a RequestDownload request may be rejected with.
///
/// Exposed to callers through `RequestDownloadRequest::allowed_nack_codes`.
// NOTE(review): presumably mirrors the NRC list the diagnostic standard defines for this
// service — verify against the spec when adding or removing entries.
const REQUEST_DOWNLOAD_NEGATIVE_RESPONSE_CODES: [NegativeResponseCode; 6] = [
NegativeResponseCode::IncorrectMessageLengthOrInvalidFormat,
NegativeResponseCode::ConditionsNotCorrect,
NegativeResponseCode::RequestOutOfRange,
NegativeResponseCode::SecurityAccessDenied,
NegativeResponseCode::AuthenticationRequired,
NegativeResponseCode::UploadDownloadNotAccepted,
];
/// Body of a RequestDownload request.
///
/// Memory address and size are held as full-width integers. On the wire they are written
/// big-endian and truncated to the byte widths recorded in
/// `address_and_length_format_identifier`.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub struct RequestDownloadRequest {
/// Data format identifier; encoded as the first byte of the request.
data_format_identifier: DataFormatIdentifier,
/// Byte widths used when encoding/decoding `memory_address` and `memory_size`;
/// encoded as the second byte of the request.
address_and_length_format_identifier: MemoryFormatIdentifier,
/// Target memory address. Assigning this field directly does NOT update
/// `address_and_length_format_identifier`; encode keeps only the recorded number of
/// low-order bytes.
pub memory_address: u64,
/// Number of bytes to download. Same caveat as `memory_address`: encode keeps only the
/// recorded number of low-order bytes.
pub memory_size: u32,
}
impl RequestDownloadRequest {
pub(crate) fn new(
data_format_identifier: DataFormatIdentifier,
memory_address: u64,
memory_size: u32,
) -> Result<Self, Error> {
if memory_address > 0xFF_FFFF_FFFF {
return Err(Error::InvalidMemoryAddress(memory_address));
}
let address_and_length_format_identifier =
MemoryFormatIdentifier::from_values(memory_size, memory_address);
Ok(Self {
data_format_identifier,
address_and_length_format_identifier,
memory_address,
memory_size,
})
}
fn get_shortened_memory_address(&self) -> Vec<u8> {
self.memory_address
.to_be_bytes()
.iter()
.skip(
8 - self
.address_and_length_format_identifier
.memory_address_length as usize,
)
.copied()
.collect()
}
fn get_shortened_memory_size(&self) -> Vec<u8> {
self.memory_size
.to_be_bytes()
.iter()
.skip(4 - self.address_and_length_format_identifier.memory_size_length as usize)
.copied()
.collect()
}
#[must_use]
pub fn allowed_nack_codes() -> &'static [NegativeResponseCode] {
&REQUEST_DOWNLOAD_NEGATIVE_RESPONSE_CODES
}
}
impl WireFormat for RequestDownloadRequest {
    /// Decodes a request body from `reader`.
    ///
    /// Layout: one data-format byte, one address/length-format byte, then the memory
    /// address and the memory size as big-endian integers whose byte widths come from the
    /// address/length-format byte.
    ///
    /// # Errors
    /// - Propagates I/O errors from `reader` (e.g. a truncated message).
    /// - Propagates the error from `MemoryFormatIdentifier::try_from`.
    /// - Returns an `InvalidData` I/O error if the declared address width exceeds 8 bytes
    ///   or the declared size width exceeds 4 bytes. Such values cannot be held in the
    ///   `u64`/`u32` fields, and previously caused an arithmetic underflow panic.
    fn decode<T: std::io::Read>(reader: &mut T) -> Result<Option<Self>, Error> {
        const ADDRESS_WIDTH: usize = 8; // bytes in a u64
        const SIZE_WIDTH: usize = 4; // bytes in a u32

        let data_format_identifier = DataFormatIdentifier::from(reader.read_u8()?);
        let memory_identifier = MemoryFormatIdentifier::try_from(reader.read_u8()?)?;

        let address_len = memory_identifier.memory_address_length as usize;
        let size_len = memory_identifier.memory_size_length as usize;
        // The widths come straight from the (untrusted) header byte; reject them before
        // they are used to slice the fixed-size buffers below.
        if address_len > ADDRESS_WIDTH || size_len > SIZE_WIDTH {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "memory address/size length exceeds supported width",
            )
            .into());
        }

        // Reading into the tail of a zeroed buffer left-pads each big-endian field to the
        // full integer width without an intermediate allocation.
        let mut address_bytes = [0u8; ADDRESS_WIDTH];
        reader.read_exact(&mut address_bytes[ADDRESS_WIDTH - address_len..])?;
        let mut size_bytes = [0u8; SIZE_WIDTH];
        reader.read_exact(&mut size_bytes[SIZE_WIDTH - size_len..])?;

        Ok(Some(Self {
            data_format_identifier,
            address_and_length_format_identifier: memory_identifier,
            memory_address: u64::from_be_bytes(address_bytes),
            memory_size: u32::from_be_bytes(size_bytes),
        }))
    }

    /// Encoded length: the two identifier bytes plus `MemoryFormatIdentifier::len()`
    /// (presumably address width + size width — confirm against its definition).
    fn required_size(&self) -> usize {
        2 + self.address_and_length_format_identifier.len()
    }

    /// Writes the identifier bytes followed by the shortened address and size.
    ///
    /// Returns `required_size()` on success.
    ///
    /// # Errors
    /// Propagates any I/O error from `writer`.
    fn encode<T: std::io::Write>(&self, writer: &mut T) -> Result<usize, Error> {
        writer.write_u8(self.data_format_identifier.into())?;
        writer.write_u8(self.address_and_length_format_identifier.into())?;
        writer.write_all(&self.get_shortened_memory_address())?;
        writer.write_all(&self.get_shortened_memory_size())?;
        Ok(self.required_size())
    }
}

impl SingleValueWireFormat for RequestDownloadRequest {}
/// Body of a positive RequestDownload response.
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub struct RequestDownloadResponse {
/// Format byte; its `max_number_of_block_length` tells decode how many bytes follow.
length_format_identifier: LengthFormatIdentifier,
/// Raw max-number-of-block-length bytes as they appear on the wire.
// NOTE(review): encode writes this Vec verbatim without checking that its length matches
// `length_format_identifier`; keep them consistent when constructing or mutating.
pub max_number_of_block_length: Vec<u8>,
}
impl RequestDownloadResponse {
pub(crate) fn new(length_format_identifier: u8, max_number_of_block_length: Vec<u8>) -> Self {
Self {
length_format_identifier: LengthFormatIdentifier::from(length_format_identifier),
max_number_of_block_length,
}
}
}
impl WireFormat for RequestDownloadResponse {
    /// Decodes a response: one length-format byte, then as many block-length bytes as
    /// that byte's `max_number_of_block_length` declares.
    ///
    /// # Errors
    /// Propagates I/O errors from `reader` (e.g. fewer bytes than declared).
    fn decode<T: std::io::Read>(reader: &mut T) -> Result<Option<Self>, Error> {
        let format_byte = reader.read_u8()?;
        let length_format_identifier = LengthFormatIdentifier::from(format_byte);
        let declared_len = length_format_identifier.max_number_of_block_length as usize;

        let mut max_number_of_block_length = vec![0u8; declared_len];
        reader.read_exact(max_number_of_block_length.as_mut_slice())?;

        Ok(Some(Self {
            length_format_identifier,
            max_number_of_block_length,
        }))
    }

    /// Encoded length: the format byte plus the stored block-length bytes.
    fn required_size(&self) -> usize {
        let payload = self.max_number_of_block_length.len();
        payload + 1
    }

    /// Writes the format byte followed by the stored block-length bytes verbatim.
    ///
    /// # Errors
    /// Propagates any I/O error from `writer`.
    fn encode<T: std::io::Write>(&self, writer: &mut T) -> Result<usize, Error> {
        let format_byte: u8 = self.length_format_identifier.into();
        writer.write_u8(format_byte)?;
        writer.write_all(self.max_number_of_block_length.as_slice())?;
        Ok(self.required_size())
    }
}

impl SingleValueWireFormat for RequestDownloadResponse {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_request() {
        let bytes: [u8; 7] = [0x00, 0x14, 0xF0, 0xFF, 0xFF, 0x67, 0x0A];
        let req = RequestDownloadRequest::decode(&mut bytes.as_slice())
            .unwrap()
            .unwrap();
        assert_eq!(u8::from(req.data_format_identifier), 0);
        assert_eq!(u8::from(req.address_and_length_format_identifier), 0x14);
        assert_eq!(req.address_and_length_format_identifier.memory_size_length, 1);
        assert_eq!(req.address_and_length_format_identifier.memory_address_length, 4);
        assert_eq!(req.memory_address, 0xF0FF_FF67);
        assert_eq!(req.memory_size, 0x0A);
        assert_eq!(req.get_shortened_memory_address(), vec![0xF0, 0xFF, 0xFF, 0x67]);
        assert_eq!(req.get_shortened_memory_size(), vec![0x0A]);
    }

    #[test]
    fn bad_request() {
        let bytes: [u8; 3] = [0x00, 0x11, 0x67];
        let req = RequestDownloadRequest::decode(&mut bytes.as_slice());
        assert!(matches!(req, Err(Error::IoError(_))));
    }

    #[test]
    fn read_memory_identifier() {
        let memory_format_identifier = MemoryFormatIdentifier::try_from(0x23).unwrap();
        assert_eq!(memory_format_identifier.memory_size_length, 2);
        assert_eq!(memory_format_identifier.memory_address_length, 3);
        assert_eq!(u8::from(memory_format_identifier), 0x23);
    }

    #[test]
    fn read_length_identifier() {
        let length_format_identifier = LengthFormatIdentifier::from(0xF0);
        assert_eq!(length_format_identifier.max_number_of_block_length, 15);
        assert_eq!(u8::from(length_format_identifier), 0xF0);
    }

    #[test]
    fn check_message_size() {
        let req = RequestDownloadRequest::new(0x00.into(), 0xF0_FF_FF_67, 0x0A).unwrap();
        let mut vec = vec![];
        req.encode(&mut vec).unwrap();
        assert_eq!(vec.len(), req.required_size());
    }

    /// Encoding a request and decoding the produced bytes must yield an equal value.
    #[test]
    fn request_round_trip() {
        let req = RequestDownloadRequest::new(0x00.into(), 0xF0_FF_FF_67, 0x0A).unwrap();
        let mut buf = vec![];
        let written = req.encode(&mut buf).unwrap();
        assert_eq!(written, buf.len());
        let decoded = RequestDownloadRequest::decode(&mut buf.as_slice())
            .unwrap()
            .unwrap();
        assert_eq!(decoded, req);
    }

    /// Encoding a response and decoding the produced bytes must yield an equal value.
    #[test]
    fn response_round_trip() {
        let resp = RequestDownloadResponse::new(0x20, vec![0x0F, 0xFF]);
        let mut buf = vec![];
        let written = resp.encode(&mut buf).unwrap();
        assert_eq!(written, resp.required_size());
        assert_eq!(buf, vec![0x20, 0x0F, 0xFF]);
        let decoded = RequestDownloadResponse::decode(&mut buf.as_slice())
            .unwrap()
            .unwrap();
        assert_eq!(decoded, resp);
    }
}