use std::{
convert::TryInto,
error, fmt,
io::{self, Read, Write},
};
use serde::{Deserialize, Serialize};
use crate::{
config::Config,
hash::sha512,
io::{ReadFrom, WriteTo},
priv_util::ToHexString,
};
pub use crate::command::{Command, CommandKind, ParseCommandError};
/// Four-byte marker stored at the start of every [`Header`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Magic([u8; 4]);

impl Magic {
    /// Wraps the given four bytes as a `Magic` value.
    pub fn new(value: [u8; 4]) -> Self {
        Magic(value)
    }
}

impl fmt::Display for Magic {
    /// Renders the bytes using the crate's `ToHexString` helper.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex = self.0.to_hex_string();
        fmt::Display::fmt(&hex, f)
    }
}

impl AsRef<[u8]> for Magic {
    /// Borrows the raw four bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}

impl WriteTo for Magic {
    /// Writes the four raw bytes with no framing.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        let Magic(bytes) = self;
        bytes.write_to(w)
    }
}

impl ReadFrom for Magic {
    /// Reads exactly the four bytes that make up a `Magic`.
    fn read_from(r: &mut dyn Read) -> io::Result<Self>
    where
        Self: Sized,
    {
        <[u8; 4]>::read_from(r).map(Magic)
    }
}
/// Length of a packet payload in bytes, as carried in the [`Header`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize, Deserialize)]
pub struct PayloadLength(u32);

impl PayloadLength {
    /// Wraps a raw `u32` length.
    pub fn new(value: u32) -> Self {
        PayloadLength(value)
    }

    /// Returns the raw `u32` length.
    pub fn as_u32(self) -> u32 {
        let PayloadLength(raw) = self;
        raw
    }
}

impl fmt::Display for PayloadLength {
    /// Formats the length as a plain decimal number.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

impl WriteTo for PayloadLength {
    /// Delegates to the `u32` encoding provided by the crate's `WriteTo` impl.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        let PayloadLength(raw) = self;
        raw.write_to(w)
    }
}

impl ReadFrom for PayloadLength {
    /// Delegates to the `u32` decoding provided by the crate's `ReadFrom` impl.
    fn read_from(r: &mut dyn Read) -> io::Result<Self>
    where
        Self: Sized,
    {
        u32::read_from(r).map(PayloadLength)
    }
}
/// Four-byte payload checksum stored in a [`Header`].
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)]
pub struct Checksum([u8; 4]);

impl Checksum {
    /// Wraps the given four bytes as a `Checksum`.
    pub fn new(value: [u8; 4]) -> Self {
        Checksum(value)
    }
}

impl fmt::Display for Checksum {
    /// Renders the bytes using the crate's `ToHexString` helper.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hex = self.0.to_hex_string();
        fmt::Display::fmt(&hex, f)
    }
}

impl AsRef<[u8]> for Checksum {
    /// Borrows the raw four bytes.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}

impl WriteTo for Checksum {
    /// Writes the four raw bytes with no framing.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        let Checksum(bytes) = self;
        bytes.write_to(w)
    }
}

impl ReadFrom for Checksum {
    /// Reads exactly the four bytes that make up a `Checksum`.
    fn read_from(r: &mut dyn Read) -> io::Result<Self>
    where
        Self: Sized,
    {
        <[u8; 4]>::read_from(r).map(Checksum)
    }
}
/// Fixed-size packet header. Fields are serialized in declaration order:
/// magic, command, payload length, checksum.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Header {
    // Must equal `Header::MAGIC` for the header to validate.
    magic: Magic,
    command: Command,
    // Declared payload length; validated against the config limit and the real payload.
    length: PayloadLength,
    // Checksum of the payload, compared against `checksum(payload)` during validation.
    checksum: Checksum,
}

impl Header {
    /// Magic bytes that every valid header must carry.
    const MAGIC: Magic = Magic([0xe9, 0xbe, 0xb4, 0xd9]);
}
/// Reasons a [`Header`] or [`Packet`] can be rejected.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum PacketError {
    /// The header's magic bytes differ from `Header::MAGIC`.
    InvalidMagic { expected: Magic, actual: Magic },
    /// A payload (or a declared payload length) exceeds the configured maximum.
    TooLong { max: usize, len: usize },
    /// `expected` is the checksum computed from the payload; `actual` is the one in the header.
    InvalidChecksum {
        expected: Checksum,
        actual: Checksum,
    },
    /// `expected` is the header's declared length; `actual` is the payload's real length.
    InvalidLength { expected: usize, actual: usize },
}

impl fmt::Display for PacketError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            PacketError::InvalidMagic {
                expected: want,
                actual: got,
            } => write!(f, "magic must be {}, but {}", want, got),
            PacketError::TooLong { max: limit, len: got } => {
                write!(f, "length must be <={}, but {}", limit, got)
            }
            PacketError::InvalidChecksum {
                expected: want,
                actual: got,
            } => write!(f, "checksum should be {}, but {}", want, got),
            PacketError::InvalidLength {
                expected: want,
                actual: got,
            } => write!(f, "length should be {}, but {}", want, got),
        }
    }
}

impl error::Error for PacketError {}

impl From<PacketError> for io::Error {
    /// Wraps the packet error as an `io::ErrorKind::Other` I/O error.
    fn from(err: PacketError) -> Self {
        Self::new(io::ErrorKind::Other, err)
    }
}

/// Result type used by packet construction and validation.
pub type Result<T> = std::result::Result<T, PacketError>;
/// Computes a payload checksum from the first four bytes of `sha512(payload)`.
fn checksum(payload: impl AsRef<[u8]>) -> Checksum {
    let digest = sha512(payload);
    // A 0..4 slice always has length 4, so this conversion cannot fail.
    let prefix: [u8; 4] = digest[0..4].try_into().unwrap();
    Checksum(prefix)
}
impl Header {
    /// Encoded header size in bytes: magic (4) + command (12) + length (4) + checksum (4).
    pub const LEN_BM: usize = 4 + 12 + 4 + 4;

    /// Builds a header describing `payload`: the fixed magic, the given command,
    /// the payload length, and the payload checksum.
    ///
    /// # Errors
    /// Returns [`PacketError::TooLong`] if `payload` is longer than
    /// `config.max_payload_length()`.
    pub fn new(config: &Config, command: Command, payload: impl AsRef<[u8]>) -> Result<Self> {
        let bytes = payload.as_ref();
        let limit = config.max_payload_length().as_u32() as usize;
        if bytes.len() > limit {
            return Err(PacketError::TooLong {
                max: limit,
                len: bytes.len(),
            });
        }
        // The guard above keeps the length within the u32 limit, so the cast is lossless.
        let declared = PayloadLength(bytes.len() as u32);
        Ok(Self {
            magic: Self::MAGIC,
            command,
            length: declared,
            checksum: checksum(bytes),
        })
    }

    /// The command carried by this header.
    pub fn command(&self) -> Command {
        self.command
    }

    /// The payload length declared by this header.
    pub fn length(&self) -> PayloadLength {
        self.length
    }

    /// Checks the magic bytes first, then the declared length against the config limit.
    fn validate(&self, config: &Config) -> Result<()> {
        if self.magic != Self::MAGIC {
            return Err(PacketError::InvalidMagic {
                expected: Header::MAGIC,
                actual: self.magic,
            });
        }
        let limit = config.max_payload_length();
        if self.length > limit {
            return Err(PacketError::TooLong {
                max: limit.as_u32() as usize,
                len: self.length.as_u32() as usize,
            });
        }
        Ok(())
    }

    /// Runs [`Header::validate`], then checks that `payload` has the declared length
    /// and the declared checksum.
    fn validate_with_payload(&self, config: &Config, payload: impl AsRef<[u8]>) -> Result<()> {
        self.validate(config)?;
        let bytes = payload.as_ref();
        let declared = self.length.as_u32() as usize;
        if bytes.len() != declared {
            return Err(PacketError::InvalidLength {
                expected: declared,
                actual: bytes.len(),
            });
        }
        let computed = checksum(bytes);
        if computed != self.checksum {
            return Err(PacketError::InvalidChecksum {
                expected: computed,
                actual: self.checksum,
            });
        }
        Ok(())
    }
}
impl WriteTo for Header {
    /// Serializes the fields in wire order: magic, command, length, checksum.
    /// Stops at the first write error.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.magic.write_to(w)?;
        self.command.write_to(w)?;
        self.length.write_to(w)?;
        self.checksum.write_to(w)
    }
}
impl Header {
pub fn read_from_with_config(config: &Config, r: &mut dyn Read) -> io::Result<Self>
where
Self: Sized,
{
let result = Self {
magic: Magic::read_from(r)?,
command: Command::read_from(r)?,
length: PayloadLength::read_from(r)?,
checksum: Checksum::read_from(r)?,
};
result.validate(config)?;
Ok(result)
}
}
#[test]
fn test_header_write_to() {
    let config = Config::new();
    let payload = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
    let header = Header::new(&config, CommandKind::Version.into(), &payload).unwrap();

    let mut written: Vec<u8> = Vec::new();
    header.write_to(&mut written).unwrap();

    let sections: [&[u8]; 4] = [
        &[0xe9, 0xbe, 0xb4, 0xd9], // magic
        b"version\0\0\0\0\0",      // command, NUL-padded to 12 bytes
        &[0, 0, 0, 8],             // payload length
        &[0x65, 0x01, 0x61, 0x85], // checksum
    ];
    let expected: Vec<u8> = sections.concat();
    assert_eq!(written, expected);
}
#[test]
fn test_header_read_from() {
    use std::io::Cursor;
    let wire: [u8; 24] = [
        0xe9, 0xbe, 0xb4, 0xd9, // magic
        b'v', b'e', b'r', b's', b'i', b'o', b'n', 0, 0, 0, 0, 0, // command
        0, 0, 0, 8, // payload length
        0x65, 0x01, 0x61, 0x85, // checksum
    ];
    let config = Config::new();
    let decoded = Header::read_from_with_config(&config, &mut Cursor::new(wire)).unwrap();

    let payload = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
    let built = Header::new(&config, CommandKind::Version.into(), &payload).unwrap();
    assert_eq!(decoded, built);
}
/// A complete message: a [`Header`] followed by the payload bytes it describes.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct Packet {
    header: Header,
    payload: Vec<u8>,
}

impl Packet {
    /// Builds a packet and computes a fresh header for `payload`.
    ///
    /// # Errors
    /// Returns [`PacketError::TooLong`] if `payload` exceeds
    /// `config.max_payload_length()`. `Header::new` performs that check, so it is
    /// not repeated here.
    pub fn new(config: &Config, command: Command, payload: Vec<u8>) -> Result<Self> {
        let header = Header::new(config, command, &payload)?;
        Ok(Self { header, payload })
    }

    /// Pairs an existing header with a payload and checks that they agree
    /// (magic, length limit, declared length, checksum).
    pub fn compose(config: &Config, header: Header, payload: Vec<u8>) -> Result<Self> {
        let result = Self { header, payload };
        result.validate(config)?;
        Ok(result)
    }

    /// The packet header.
    pub fn header(&self) -> &Header {
        &self.header
    }

    /// The raw payload bytes.
    pub fn payload(&self) -> &[u8] {
        &self.payload
    }

    /// Validates the header on its own and against the payload.
    fn validate(&self, config: &Config) -> Result<()> {
        self.header.validate_with_payload(config, &self.payload)
    }
}

impl WriteTo for Packet {
    /// Writes the header, then the payload bytes verbatim.
    fn write_to(&self, w: &mut dyn Write) -> io::Result<()> {
        self.header.write_to(w)?;
        w.write_all(&self.payload)?;
        Ok(())
    }
}

impl Packet {
    /// Reads a header and then up to `header.length` payload bytes from `r`.
    ///
    /// # Errors
    /// Propagates read errors. Validation failures are converted from
    /// [`PacketError`] into an `io::Error`. A stream that ends early produces a short
    /// payload, which is reported as [`PacketError::InvalidLength`].
    pub fn read_from_with_config(config: &Config, r: &mut dyn Read) -> io::Result<Self>
    where
        Self: Sized,
    {
        // `Header::read_from_with_config` already validates magic and the length limit,
        // so the capacity below is bounded by `config.max_payload_length()`.
        let header = Header::read_from_with_config(config, r)?;
        let declared = header.length.as_u32();
        let mut payload = Vec::with_capacity(declared as usize);
        r.take(u64::from(declared)).read_to_end(&mut payload)?;
        let packet = Self { header, payload };
        packet.validate(config)?;
        Ok(packet)
    }
}
#[test]
fn test_packet_write_to() {
    let config = Config::new();
    let payload = vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
    let packet = Packet::new(&config, CommandKind::Version.into(), payload).unwrap();

    let mut written: Vec<u8> = Vec::new();
    packet.write_to(&mut written).unwrap();

    let sections: [&[u8]; 5] = [
        &[0xe9, 0xbe, 0xb4, 0xd9],                         // magic
        b"version\0\0\0\0\0",                              // command, NUL-padded to 12 bytes
        &[0, 0, 0, 8],                                     // payload length
        &[0x65, 0x01, 0x61, 0x85],                         // checksum
        &[0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef], // payload
    ];
    let expected: Vec<u8> = sections.concat();
    assert_eq!(written, expected);
}
#[test]
fn test_packet_read_from() {
    use std::io::Cursor;
    let wire: [u8; 32] = [
        0xe9, 0xbe, 0xb4, 0xd9, // magic
        b'v', b'e', b'r', b's', b'i', b'o', b'n', 0, 0, 0, 0, 0, // command
        0, 0, 0, 8, // payload length
        0x65, 0x01, 0x61, 0x85, // checksum
        0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, // payload
    ];
    let config = Config::new();
    let decoded = Packet::read_from_with_config(&config, &mut Cursor::new(wire)).unwrap();

    let payload = vec![0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
    let built = Packet::new(&config, CommandKind::Version.into(), payload).unwrap();
    assert_eq!(decoded, built);
}