use crate::{
error::{Error, Result},
utils,
};
use bytes::Bytes;
use log::trace;
use serde::{Deserialize, Serialize};
use std::{fmt, net::SocketAddr};
use unwrap::unwrap;
const MSG_HEADER_LEN: usize = 9;
const MSG_PROTOCOL_VERSION: u16 = 0x0001;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum WireMsg {
EndpointEchoReq,
EndpointEchoResp(SocketAddr),
UserMsg(Bytes),
}
const USER_MSG_FLAG: u8 = 0x00;
const ECHO_SRVC_MSG_FLAG: u8 = 0x01;
impl WireMsg {
pub async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result<Self> {
let mut header_bytes = [0; MSG_HEADER_LEN];
recv.read_exact(&mut header_bytes).await?;
let msg_header = MsgHeader::from_bytes(header_bytes);
let mut data: Vec<u8> = vec![0; msg_header.data_len()];
let msg_flag = msg_header.usr_msg_flag();
recv.read_exact(&mut data).await?;
trace!("Got new message with {} bytes.", data.len());
if data.is_empty() {
Err(Error::EmptyResponse)
} else if msg_flag == USER_MSG_FLAG {
Ok(WireMsg::UserMsg(From::from(data)))
} else if msg_flag == ECHO_SRVC_MSG_FLAG {
Ok(bincode::deserialize(&data)?)
} else {
Err(Error::InvalidWireMsgFlag)
}
}
pub async fn write_to_stream(&self, send_stream: &mut quinn::SendStream) -> Result<()> {
let (msg_bytes, msg_flag) = match self {
WireMsg::UserMsg(ref m) => (m.clone(), USER_MSG_FLAG),
_ => (
From::from(unwrap!(bincode::serialize(&self))),
!USER_MSG_FLAG,
),
};
trace!("Sending message to remote peer ({} bytes)", msg_bytes.len());
let msg_header = MsgHeader::new(&msg_bytes, msg_flag)?;
let header_bytes = msg_header.to_bytes();
send_stream.write_all(&header_bytes).await?;
send_stream.write_all(&msg_bytes[..]).await?;
Ok(())
}
}
impl fmt::Display for WireMsg {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
WireMsg::UserMsg(ref m) => {
write!(f, "WireMsg::UserMsg({})", utils::bin_data_format(&*m))
}
WireMsg::EndpointEchoReq => write!(f, "WireMsg::EndpointEchoReq"),
WireMsg::EndpointEchoResp(ref sa) => write!(f, "WireMsg::EndpointEchoResp({})", sa),
}
}
}
struct MsgHeader {
version: u16,
data_len: usize,
usr_msg_flag: u8,
#[allow(unused)]
reserved: [u8; 2],
}
impl MsgHeader {
pub fn new(msg: &Bytes, usr_msg_flag: u8) -> Result<Self> {
let data_len = msg.len();
if data_len > u32::MAX as usize {
return Err(Error::MaxLengthExceeded);
}
Ok(Self {
version: MSG_PROTOCOL_VERSION,
data_len,
usr_msg_flag,
reserved: [0, 0],
})
}
pub fn data_len(&self) -> usize {
self.data_len
}
pub fn usr_msg_flag(&self) -> u8 {
self.usr_msg_flag
}
pub fn to_bytes(&self) -> [u8; MSG_HEADER_LEN] {
let version = self.version.to_be_bytes();
let data_len = self.data_len.to_be_bytes();
[
version[0],
version[1],
data_len[4],
data_len[5],
data_len[6],
data_len[7],
self.usr_msg_flag,
0,
0,
]
}
pub fn from_bytes(bytes: [u8; MSG_HEADER_LEN]) -> Self {
let version = u16::from_be_bytes([bytes[0], bytes[1]]);
let data_len = usize::from_be_bytes([0, 0, 0, 0, bytes[2], bytes[3], bytes[4], bytes[5]]);
let usr_msg_flag = bytes[6];
Self {
version,
data_len,
usr_msg_flag,
reserved: [0, 0],
}
}
}