use std::io::{self, Read, Write};
/// Framing byte that opens every packet on the wire.
///
/// Inside a packet body it never appears literally: it is sent escaped
/// (see [`MARK_BYTE`]).
pub const SYNC_BYTE: u8 = 0xe0;
/// Escape marker.
///
/// A body byte equal to [`SYNC_BYTE`] or `MARK_BYTE` is transmitted as
/// `MARK_BYTE` followed by that byte minus one (wrapping); the receiver adds
/// the one back.
pub const MARK_BYTE: u8 = 0xd0;
/// Status code carried in a packet's report byte.
///
/// Wire codes `1..=4` map to the named variants; every other byte decodes as
/// [`Report::Unknown`] (see the `From<u8>` impl below).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Report {
    /// Wire code `1`.
    Normal = 1,
    /// Wire code `2`.
    IncorrectDataSize = 2,
    /// Wire code `3`.
    InvalidData = 3,
    /// Wire code `4`.
    Busy = 4,
    /// Any wire code other than `1..=4`.
    ///
    /// Its implicit discriminant is `5`, which also decodes back to `Unknown`.
    Unknown,
}

impl From<u8> for Report {
    /// Decodes a raw report byte; unrecognised values become [`Report::Unknown`].
    fn from(value: u8) -> Self {
        match value {
            1 => Report::Normal,
            2 => Report::IncorrectDataSize,
            3 => Report::InvalidData,
            4 => Report::Busy,
            _ => Report::Unknown,
        }
    }
}
/// Byte-buffer view of a framed packet.
///
/// Implementors provide the storage through `AsRef<[u8]>`/`AsMut<[u8]>` and the
/// field offsets below; the default methods read and write fields in place.
/// With `len` = [`Packet::len_of_packet`], the layout is:
///
/// - byte `0`: sync byte;
/// - byte `SIZE_INDEX`: number of bytes that follow it, checksum included;
/// - byte `DESTINATION_INDEX`: destination;
/// - bytes `DATA_BEGIN_INDEX..len - 1`: payload;
/// - byte `len - 1`: checksum.
///
/// Accessors index the buffer directly and panic if it is shorter than the
/// offsets or the size field require.
pub trait Packet: AsRef<[u8]> + AsMut<[u8]> {
    /// Index of the size byte.
    const SIZE_INDEX: usize;
    /// Index of the first payload byte.
    const DATA_BEGIN_INDEX: usize;
    /// Index of the destination byte.
    const DESTINATION_INDEX: usize;

    /// Total packet length in bytes (sync byte through checksum), derived from
    /// the size field.
    fn len_of_packet(&self) -> usize {
        Self::SIZE_INDEX + self.as_ref()[Self::SIZE_INDEX] as usize + 1
    }

    /// The packet bytes, without any unused buffer capacity past the checksum.
    fn as_slice(&self) -> &[u8] {
        &self.as_ref()[..self.len_of_packet()]
    }

    /// Mutable view of the packet bytes, without unused buffer capacity.
    fn as_mut_slice(&mut self) -> &mut [u8] {
        let len = self.len_of_packet();
        &mut self.as_mut()[..len]
    }

    /// Returns byte `0` (the sync byte slot).
    fn sync(&self) -> u8 {
        self.as_ref()[0]
    }

    /// Stores [`SYNC_BYTE`] in byte `0`.
    fn set_sync(&mut self) -> &mut Self {
        self.as_mut()[0] = SYNC_BYTE;
        self
    }

    /// Returns the raw size field.
    fn size(&self) -> u8 {
        self.as_ref()[Self::SIZE_INDEX]
    }

    /// Overwrites the raw size field.
    fn set_size(&mut self, size: u8) -> &mut Self {
        self.as_mut()[Self::SIZE_INDEX] = size;
        self
    }

    /// Returns the destination byte.
    fn dest(&self) -> u8 {
        self.as_ref()[Self::DESTINATION_INDEX]
    }

    /// Overwrites the destination byte.
    fn set_dest(&mut self, dest: u8) -> &mut Self {
        self.as_mut()[Self::DESTINATION_INDEX] = dest;
        self
    }

    /// Payload bytes: from `DATA_BEGIN_INDEX` up to (not including) the checksum.
    fn data(&self) -> &[u8] {
        &self.as_ref()[Self::DATA_BEGIN_INDEX..self.len_of_packet() - 1]
    }

    /// Copies `data` into the payload area and updates the size field so that
    /// the checksum byte directly follows the payload.
    ///
    /// The checksum itself is not recomputed; call
    /// [`Packet::calculate_checksum`] afterwards if needed.
    ///
    /// # Panics
    ///
    /// If the buffer cannot hold the payload, or if the resulting size field
    /// would not fit in a `u8`.
    fn set_data(&mut self, data: &[u8]) -> &mut Self {
        let data_end = Self::DATA_BEGIN_INDEX + data.len();
        // The checksum sits at `data_end`, so `len_of_packet() == data_end + 1`;
        // solving `SIZE_INDEX + size + 1 == data_end + 1` gives the size field.
        let size = data_end - Self::SIZE_INDEX;
        // A plain `as u8` would silently wrap here and corrupt the frame length.
        assert!(
            size <= usize::from(u8::MAX),
            "packet payload too long: size field {} does not fit in a u8",
            size
        );
        self.as_mut()[Self::DATA_BEGIN_INDEX..data_end].copy_from_slice(data);
        self.set_size(size as u8)
    }

    /// Recomputes the checksum and stores it in the last packet byte.
    ///
    /// The checksum is the wrapping sum of every byte between the sync byte and
    /// the checksum byte, i.e. indices `1..len_of_packet() - 1`.
    fn calculate_checksum(&mut self) -> &mut Self {
        let body_end = self.len_of_packet() - 1;
        let checksum = self.as_slice()[1..body_end]
            .iter()
            .fold(0u8, |acc, &b| acc.wrapping_add(b));
        self.set_checksum(checksum)
    }

    /// Returns the stored checksum (last packet byte).
    fn checksum(&self) -> u8 {
        self.as_ref()[self.len_of_packet() - 1]
    }

    /// Overwrites the stored checksum (last packet byte).
    fn set_checksum(&mut self, checksum: u8) -> &mut Self {
        let len = self.len_of_packet();
        self.as_mut()[len - 1] = checksum;
        self
    }
}
pub trait ReportField: Packet {
const REPORT_INDEX: usize;
fn report(&self) -> Report {
self.as_ref()[Self::REPORT_INDEX].into()
}
fn report_raw(&self) -> u8 {
self.as_ref()[Self::REPORT_INDEX]
}
fn set_report(&mut self, report: impl Into<u8>) -> &mut Self {
self.as_mut()[Self::REPORT_INDEX] = report.into();
self
}
}
/// Byte-level read helpers implementing the receive side of the escaping scheme.
pub trait ReadByteExt: Read {
    /// Reads exactly one byte; fails with `UnexpectedEof` if the stream ends first.
    fn read_u8(&mut self) -> io::Result<u8> {
        let mut byte = [0u8];
        self.read_exact(&mut byte).map(|()| byte[0])
    }

    /// Reads one logical byte, undoing the escape applied when writing.
    ///
    /// A [`MARK_BYTE`] means the real value is the next byte plus one
    /// (wrapping); any other byte is returned unchanged.
    fn read_u8_escaped(&mut self) -> io::Result<u8> {
        let first = self.read_u8()?;
        if first != MARK_BYTE {
            return Ok(first);
        }
        // The writer sent MARK_BYTE followed by (value - 1).
        self.read_u8().map(|next| next.wrapping_add(1))
    }
}

impl<R: Read + ?Sized> ReadByteExt for R {}
/// Byte-level write helpers implementing the transmit side of the escaping scheme.
pub trait WriteByteExt: Write {
    /// Writes `b` as-is, with no escaping.
    fn write_u8(&mut self, b: u8) -> io::Result<()> {
        let single = [b];
        self.write_all(&single)
    }

    /// Writes `b`, escaping it if it collides with a framing byte.
    ///
    /// [`SYNC_BYTE`] and [`MARK_BYTE`] are sent as `MARK_BYTE, b - 1`
    /// (wrapping); every other value is sent unchanged. Returns how many bytes
    /// were handed to the writer: `2` for an escaped byte, `1` otherwise.
    fn write_u8_escaped(&mut self, b: u8) -> io::Result<usize> {
        let reserved = [SYNC_BYTE, MARK_BYTE];
        let (encoded, count) = if reserved.contains(&b) {
            ([MARK_BYTE, b.wrapping_sub(1)], 2)
        } else {
            ([b, 0], 1)
        };
        self.write_all(&encoded[..count])?;
        Ok(count)
    }
}

impl<W: Write + ?Sized> WriteByteExt for W {}
/// Reads whole packets from a byte stream.
pub trait ReadPacket: Read {
    /// Reads one packet into `packet`'s buffer, un-escaping every byte after
    /// the sync byte.
    ///
    /// The first byte must be [`SYNC_BYTE`]. The header up to and including the
    /// size field is read first; the size field then determines how many more
    /// bytes (payload and checksum) are read. The checksum is stored but not
    /// verified.
    ///
    /// Returns `len_of_packet()` cast to `u8`. NOTE: the cast wraps for packets
    /// longer than 255 bytes.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if `packet`'s buffer cannot hold the header; nothing is
    ///   read from the stream in that case.
    /// - `InvalidData` if the first byte is not the sync byte, or if the size
    ///   field announces more bytes than the buffer can hold (the header has
    ///   already been consumed then).
    /// - Any error from the underlying reader, e.g. `UnexpectedEof`.
    fn read_packet<P: Packet>(&mut self, packet: &mut P) -> io::Result<u8> {
        let buf = packet.as_mut();
        // Check the header fits before consuming anything from the stream.
        if buf.len() <= P::SIZE_INDEX {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Packet buffer of {} bytes cannot hold a {}-byte header",
                    buf.len(),
                    P::SIZE_INDEX + 1
                ),
            ));
        }

        let sync = self.read_u8()?;
        if sync != SYNC_BYTE {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("Expected SYNC byte (0xE0), found: {:#04x}", sync),
            ));
        }
        buf[0] = sync;

        for b in &mut buf[1..=P::SIZE_INDEX] {
            *b = self.read_u8_escaped()?;
        }

        // The size byte comes off the wire, so it must not be trusted to fit
        // the caller's buffer; slicing past the end would panic.
        let last = P::SIZE_INDEX + buf[P::SIZE_INDEX] as usize;
        if last >= buf.len() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Packet of {} bytes does not fit in a {}-byte buffer",
                    last + 1,
                    buf.len()
                ),
            ));
        }

        for b in &mut buf[P::SIZE_INDEX + 1..=last] {
            *b = self.read_u8_escaped()?;
        }
        Ok(packet.len_of_packet() as u8)
    }
}

impl<R: Read + ?Sized> ReadPacket for R {}
/// Checks that `packet` can be written and returns its total length.
///
/// Shared by the [`WritePacket`] methods so they reject the same inputs with
/// the same errors instead of panicking while slicing.
fn validated_packet_len<P: Packet>(packet: &P) -> io::Result<usize> {
    let len = packet.len_of_packet();
    let min_len = P::DATA_BEGIN_INDEX + 1;
    if len < min_len {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("The size of packet can't be less than {}", min_len),
        ));
    }
    let capacity = packet.as_ref().len();
    if capacity < len {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!(
                "Packet declares {} bytes but its buffer holds only {}",
                len, capacity
            ),
        ));
    }
    Ok(len)
}

/// Writes whole packets to a byte stream.
pub trait WritePacket: Write {
    /// Writes `packet` verbatim: the sync byte, then every following byte
    /// (stored checksum included) escaped. The checksum is not recomputed.
    ///
    /// Returns the number of bytes written to the stream, escape bytes included.
    ///
    /// # Errors
    ///
    /// `InvalidInput` if the packet is shorter than `DATA_BEGIN_INDEX + 1`
    /// bytes or its buffer is shorter than the size field claims; otherwise
    /// any error from the underlying writer.
    fn write_packet_unchecked<P: Packet>(&mut self, packet: &P) -> io::Result<usize> {
        let len = validated_packet_len(packet)?;
        self.write_u8(SYNC_BYTE)?;
        let mut bytes_written = 1;
        for &b in &packet.as_slice()[1..len] {
            bytes_written += self.write_u8_escaped(b)?;
        }
        Ok(bytes_written)
    }

    /// Writes `packet` with a freshly computed checksum in place of the stored one.
    ///
    /// The checksum is the wrapping sum of the bytes between the sync byte and
    /// the checksum slot. Returns the number of bytes written to the stream,
    /// escape bytes included (the checksum itself may need escaping).
    ///
    /// # Errors
    ///
    /// Same as [`WritePacket::write_packet_unchecked`].
    fn write_packet<P: Packet>(&mut self, packet: &P) -> io::Result<usize> {
        let len = validated_packet_len(packet)?;
        self.write_u8(SYNC_BYTE)?;
        let mut bytes_written = 1;
        let mut checksum: u8 = 0;
        // Skip the sync byte and the stored checksum.
        for &b in &packet.as_slice()[1..len - 1] {
            bytes_written += self.write_u8_escaped(b)?;
            checksum = checksum.wrapping_add(b);
        }
        // Count what was actually emitted: an escaped checksum takes two bytes.
        bytes_written += self.write_u8_escaped(checksum)?;
        Ok(bytes_written)
    }

    /// Alias for [`WritePacket::write_packet`].
    #[deprecated(since = "1.1.0", note = "use `write_packet`, which always recomputes the checksum")]
    fn write_packet_with_checksum<P: Packet>(&mut self, packet: &P) -> io::Result<usize> {
        self.write_packet(packet)
    }
}

impl<W: Write + ?Sized> WritePacket for W {}