use crate::requests::common::MAGIC_NUMBER;
use crate::requests::{ResponseStatus, Result};
#[cfg(feature = "fuzz")]
use arbitrary::Arbitrary;
use bincode::Options;
use log::error;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::io::{Read, Write};
/// Major wire protocol version. `write_to_stream` emits it right after the header size,
/// and `read_from_stream` rejects any other major version.
const WIRE_PROTOCOL_VERSION_MAJ: u8 = 1;
/// Minor wire protocol version. It is emitted right after the major version, and
/// `read_from_stream` rejects any other minor version.
const WIRE_PROTOCOL_VERSION_MIN: u8 = 0;
/// Value carried in the header-size field: the number of bytes that follow that field.
/// This is the two version bytes plus the 28-byte fixed-width encoding of the
/// `WireHeader` fields (2 + 1 + 8 + 1 + 1 + 1 + 4 + 2 + 4 + 2 + 1 + 1).
const REQUEST_HDR_SIZE: u16 = 30;
/// Fixed-size common header fields.
///
/// On the wire these follow the magic number, the header size and the two version bytes.
///
/// The fields are encoded with bincode (little-endian, fixed-width integers) in
/// declaration order. Reordering, adding or retyping fields therefore changes the wire
/// format, and the total must stay consistent with `REQUEST_HDR_SIZE`.
#[cfg_attr(feature = "fuzz", derive(Arbitrary))]
#[derive(Copy, Clone, Debug, Serialize, Deserialize)]
pub struct WireHeader {
/// Header flags; not interpreted in this file.
pub flags: u16,
/// Provider identifier. Presumably selects the backend that handles the request;
/// confirm against the wire protocol specification.
pub provider: u8,
/// Session handle; not interpreted in this file.
pub session: u64,
/// Content type of the body; not interpreted in this file.
pub content_type: u8,
/// Accepted content type for a reply; not interpreted in this file.
pub accept_type: u8,
/// Authentication type; not interpreted in this file.
pub auth_type: u8,
/// Body length. Presumably in bytes; confirm against the specification.
pub body_len: u32,
/// Authentication field length. Presumably in bytes; confirm against the specification.
pub auth_len: u16,
/// Operation code; not interpreted in this file.
pub opcode: u32,
/// Status code; not interpreted in this file.
pub status: u16,
/// Reserved; `read_from_stream` rejects a header where this is non-zero.
pub reserved1: u8,
/// Reserved; `read_from_stream` rejects a header where this is non-zero.
pub reserved2: u8,
}
impl WireHeader {
#[cfg(feature = "testing")]
#[allow(clippy::new_without_default)]
pub fn new() -> WireHeader {
WireHeader {
flags: 0,
provider: 0,
session: 0,
content_type: 0,
accept_type: 0,
auth_type: 0,
body_len: 0,
auth_len: 0,
opcode: 0,
status: 0,
reserved1: 0,
reserved2: 0,
}
}
pub fn write_to_stream<W: Write>(&self, stream: &mut W) -> Result<()> {
let serializer = bincode::DefaultOptions::new()
.with_little_endian()
.with_fixint_encoding();
stream.write_all(&serializer.serialize(&MAGIC_NUMBER)?)?;
stream.write_all(&serializer.serialize(&REQUEST_HDR_SIZE)?)?;
stream.write_all(&serializer.serialize(&WIRE_PROTOCOL_VERSION_MAJ)?)?;
stream.write_all(&serializer.serialize(&WIRE_PROTOCOL_VERSION_MIN)?)?;
stream.write_all(&serializer.serialize(&self)?)?;
Ok(())
}
pub fn read_from_stream<R: Read>(mut stream: &mut R) -> Result<WireHeader> {
let magic_number = get_from_stream!(stream, u32);
if magic_number != MAGIC_NUMBER {
error!(
"Expected magic number {}, got {}",
MAGIC_NUMBER, magic_number
);
return Err(ResponseStatus::InvalidHeader);
}
let hdr_size = get_from_stream!(stream, u16);
let mut bytes = vec![0_u8; usize::try_from(hdr_size)?];
stream.read_exact(&mut bytes)?;
if hdr_size != REQUEST_HDR_SIZE {
error!(
"Expected request header size {}, got {}",
REQUEST_HDR_SIZE, hdr_size
);
return Err(ResponseStatus::InvalidHeader);
}
let version_maj = bytes.remove(0); let version_min = bytes.remove(0); if version_maj != WIRE_PROTOCOL_VERSION_MAJ || version_min != WIRE_PROTOCOL_VERSION_MIN {
error!(
"Expected wire protocol version {}.{}, got {}.{} instead",
WIRE_PROTOCOL_VERSION_MAJ, WIRE_PROTOCOL_VERSION_MIN, version_maj, version_min
);
return Err(ResponseStatus::WireProtocolVersionNotSupported);
}
let deserializer = bincode::DefaultOptions::new()
.with_little_endian()
.with_fixint_encoding();
let wire_header: WireHeader = deserializer.deserialize(&bytes)?;
if wire_header.reserved1 != 0x00 || wire_header.reserved2 != 0x00 {
Err(ResponseStatus::InvalidHeader)
} else {
Ok(wire_header)
}
}
}