use crate::schema::*;
use pretty_hash::fmt as pretty_fmt;
use prost::Message as _;
use std::fmt;
use std::io;
use crate::constants::MAX_MESSAGE_SIZE;
/// Error returned when a value cannot be encoded into the provided buffer.
#[derive(Debug)]
pub struct EncodeError {
    /// Number of bytes the encoding needed.
    required: usize,
}

impl fmt::Display for EncodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Surface the required size: it is the only diagnostic this error
        // carries and is otherwise never read.
        write!(
            f,
            "Cannot encode message: Write buffer is full (required {} bytes)",
            self.required
        )
    }
}

impl std::error::Error for EncodeError {}

impl EncodeError {
    /// Create an error recording that `required` bytes were needed.
    fn new(required: usize) -> Self {
        Self { required }
    }
}
/// Convert prost's encode error, keeping only the capacity it reports as
/// required.
impl From<prost::EncodeError> for EncodeError {
    fn from(err: prost::EncodeError) -> Self {
        let required = err.required_capacity();
        Self::new(required)
    }
}
/// Convert into an [`io::Error`] of kind [`io::ErrorKind::Other`] whose
/// message is the error's `Display` text.
impl From<EncodeError> for io::Error {
    fn from(e: EncodeError) -> Self {
        io::Error::new(io::ErrorKind::Other, e.to_string())
    }
}
/// Types that can serialize themselves into a caller-provided byte buffer.
pub trait Encoder: Sized + fmt::Debug {
    /// Write the encoding of `self` into the front of `buf`.
    ///
    /// Returns the number of bytes written, or an [`EncodeError`] when the
    /// value cannot be encoded (for example because `buf` is too small).
    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError>;

    /// Number of bytes the encoding of `self` occupies.
    fn encoded_len(&self) -> usize;
}
/// A byte slice encodes as itself, copied verbatim.
impl Encoder for &[u8] {
    fn encoded_len(&self) -> usize {
        self.len()
    }

    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
        let n = self.len();
        match buf.get_mut(..n) {
            Some(dst) => {
                dst.copy_from_slice(self);
                Ok(n)
            }
            // Destination shorter than the slice: report how much was needed.
            None => Err(EncodeError::new(n)),
        }
    }
}
/// Selects how [`Frame::decode`] interprets a buffer.
///
/// Fieldless, so it is cheap to copy and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FrameType {
    /// Keep the bytes as an opaque [`Frame::Raw`].
    Raw,
    /// Parse the bytes as a [`ChannelMessage`] into [`Frame::Message`].
    Message,
}
/// A unit of wire data: either raw bytes or a decoded channel message.
#[derive(Clone, PartialEq)]
pub enum Frame {
    /// Bytes kept as-is, without decoding.
    Raw(Vec<u8>),
    /// A decoded channel message.
    Message(ChannelMessage),
}

impl fmt::Debug for Frame {
    /// Raw frames print only their byte length; message frames print the
    /// message's `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Message(inner) => write!(f, "Frame({:?})", inner),
            Self::Raw(bytes) => write!(f, "Frame(Raw <{}>)", bytes.len()),
        }
    }
}

impl From<ChannelMessage> for Frame {
    fn from(message: ChannelMessage) -> Self {
        Frame::Message(message)
    }
}

impl From<Vec<u8>> for Frame {
    fn from(bytes: Vec<u8>) -> Self {
        Frame::Raw(bytes)
    }
}
impl Frame {
    /// Decode `buf` as a frame of the given type.
    ///
    /// `Raw` frames copy the bytes verbatim. `Message` frames are parsed with
    /// [`ChannelMessage::decode`], and its error is returned unchanged.
    pub fn decode(buf: &[u8], frame_type: &FrameType) -> Result<Self, io::Error> {
        let frame = match frame_type {
            FrameType::Message => Frame::Message(ChannelMessage::decode(buf)?),
            FrameType::Raw => Frame::Raw(Vec::from(buf)),
        };
        Ok(frame)
    }

    /// Length of the frame body, excluding the varint length prefix.
    fn body_len(&self) -> usize {
        match self {
            Self::Message(message) => message.encoded_len(),
            Self::Raw(bytes) => bytes.len(),
        }
    }
}
/// A frame encodes as `varint(body_len) ++ body`.
impl Encoder for Frame {
    fn encoded_len(&self) -> usize {
        let body_len = self.body_len();
        body_len + varinteger::length(body_len as u64)
    }

    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
        // Compute the body length once. For `Message` frames this walks the
        // whole message, and calling `encoded_len()` here would do it again.
        let body_len = self.body_len();
        let header_len = varinteger::length(body_len as u64);
        let len = header_len + body_len;
        if buf.len() < len {
            return Err(EncodeError::new(len));
        }
        // Split the exact frame region into the length prefix and the body.
        let (header, body) = buf[..len].split_at_mut(header_len);
        varinteger::encode(body_len as u64, header);
        match self {
            Self::Raw(message) => message.as_slice().encode(body),
            Self::Message(message) => message.encode(body),
        }?;
        Ok(len)
    }
}
/// A protocol message body, with one variant per wire message type.
///
/// The numeric wire type of each variant is returned by [`Message::typ`].
#[derive(Debug, Clone, PartialEq)]
pub enum Message {
    /// `Open` message (wire type 0).
    Open(Open),
    /// `Options` message (wire type 1).
    Options(Options),
    /// `Status` message (wire type 2).
    Status(Status),
    /// `Have` message (wire type 3).
    Have(Have),
    /// `Unhave` message (wire type 4).
    Unhave(Unhave),
    /// `Want` message (wire type 5).
    Want(Want),
    /// `Unwant` message (wire type 6).
    Unwant(Unwant),
    /// `Request` message (wire type 7).
    Request(Request),
    /// `Cancel` message (wire type 8).
    Cancel(Cancel),
    /// `Data` message (wire type 9).
    Data(Data),
    /// `Close` message (wire type 10).
    Close(Close),
    /// Extension message (wire type 15).
    Extension(ExtensionMessage),
}
impl Message {
    /// Decode a message body of wire type `typ` from `buf`.
    ///
    /// Unknown type values yield an `InvalidData` error. Body decode errors
    /// are converted into `io::Error` via `?`.
    pub fn decode(buf: &[u8], typ: u64) -> io::Result<Self> {
        let message = match typ {
            0 => Self::Open(Open::decode(buf)?),
            1 => Self::Options(Options::decode(buf)?),
            2 => Self::Status(Status::decode(buf)?),
            3 => Self::Have(Have::decode(buf)?),
            4 => Self::Unhave(Unhave::decode(buf)?),
            5 => Self::Want(Want::decode(buf)?),
            6 => Self::Unwant(Unwant::decode(buf)?),
            7 => Self::Request(Request::decode(buf)?),
            8 => Self::Cancel(Cancel::decode(buf)?),
            9 => Self::Data(Data::decode(buf)?),
            10 => Self::Close(Close::decode(buf)?),
            15 => Self::Extension(ExtensionMessage::decode(buf)?),
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "Invalid message type",
                ))
            }
        };
        Ok(message)
    }

    /// The wire type id of this message, the inverse of [`Message::decode`].
    pub fn typ(&self) -> u64 {
        let id = match self {
            Self::Open(..) => 0,
            Self::Options(..) => 1,
            Self::Status(..) => 2,
            Self::Have(..) => 3,
            Self::Unhave(..) => 4,
            Self::Want(..) => 5,
            Self::Unwant(..) => 6,
            Self::Request(..) => 7,
            Self::Cancel(..) => 8,
            Self::Data(..) => 9,
            Self::Close(..) => 10,
            Self::Extension(..) => 15,
        };
        id
    }
}
/// Prost-backed variants delegate to prost. Extensions use their own
/// [`Encoder`] impl.
impl Encoder for Message {
    fn encoded_len(&self) -> usize {
        match self {
            Self::Extension(ext) => ext.encoded_len(),
            Self::Open(m) => m.encoded_len(),
            Self::Options(m) => m.encoded_len(),
            Self::Status(m) => m.encoded_len(),
            Self::Have(m) => m.encoded_len(),
            Self::Unhave(m) => m.encoded_len(),
            Self::Want(m) => m.encoded_len(),
            Self::Unwant(m) => m.encoded_len(),
            Self::Request(m) => m.encoded_len(),
            Self::Cancel(m) => m.encoded_len(),
            Self::Data(m) => m.encoded_len(),
            Self::Close(m) => m.encoded_len(),
        }
    }

    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
        match self {
            Self::Extension(ext) => ext.encode(buf),
            Self::Open(m) => encode_prost_message(m, buf),
            Self::Options(m) => encode_prost_message(m, buf),
            Self::Status(m) => encode_prost_message(m, buf),
            Self::Have(m) => encode_prost_message(m, buf),
            Self::Unhave(m) => encode_prost_message(m, buf),
            Self::Want(m) => encode_prost_message(m, buf),
            Self::Unwant(m) => encode_prost_message(m, buf),
            Self::Request(m) => encode_prost_message(m, buf),
            Self::Cancel(m) => encode_prost_message(m, buf),
            Self::Data(m) => encode_prost_message(m, buf),
            Self::Close(m) => encode_prost_message(m, buf),
        }
    }
}
fn encode_prost_message(
msg: &impl prost::Message,
mut buf: &mut [u8],
) -> Result<usize, EncodeError> {
let len = msg.encoded_len();
msg.encode(&mut buf)?;
Ok(len)
}
/// Compact human-readable form. `Open` and `Data` summarize their byte
/// fields by length; every other variant falls back to `Debug`.
impl fmt::Display for Message {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Open(open) => {
                let key = pretty_fmt(&open.discovery_key).unwrap();
                let capability_len = open.capability.as_ref().map_or(0, |cap| cap.len());
                write!(
                    f,
                    "Open(discovery_key: {}, capability <{}>)",
                    key, capability_len
                )
            }
            Self::Data(data) => {
                let value_len = data.value.as_ref().map_or(0, |v| v.len());
                let signature_len = data.signature.as_ref().map_or(0, |s| s.len());
                write!(
                    f,
                    "Data(index {}, value: <{}>, nodes: {}, signature <{}>)",
                    data.index,
                    value_len,
                    data.nodes.len(),
                    signature_len,
                )
            }
            other => write!(f, "{:?}", other),
        }
    }
}
/// A [`Message`] tagged with the channel number it belongs to.
#[derive(Clone, PartialEq)]
pub struct ChannelMessage {
    /// Channel number.
    pub channel: u64,
    /// Message payload.
    pub message: Message,
}

impl fmt::Debug for ChannelMessage {
    /// Formats as `ChannelMessage(<channel>, <message Display>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { channel, message } = self;
        write!(f, "ChannelMessage({}, {})", channel, message)
    }
}
impl ChannelMessage {
    /// Create a message for `channel`.
    pub fn new(channel: u64, message: Message) -> Self {
        Self { channel, message }
    }

    /// Consume the message and return its `(channel, message)` parts.
    pub fn into_split(self) -> (u64, Message) {
        (self.channel, self.message)
    }

    /// Decode a channel message from `buf`.
    ///
    /// `buf` starts with a varint header. The low 4 bits of the header are
    /// the message type and the remaining bits are the channel number. The
    /// rest of `buf` is the message body.
    ///
    /// # Errors
    ///
    /// - `UnexpectedEof` if `buf` is empty or the header varint is truncated.
    /// - `InvalidData` if the header varint does not fit in a `u64`.
    /// - Any error from [`Message::decode`] for the body.
    pub fn decode(buf: &[u8]) -> io::Result<Self> {
        if buf.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "received empty message",
            ));
        }
        // Checked decode: malformed input must produce an error, not an
        // out-of-bounds panic or integer overflow.
        let (header, headerlen) = Self::decode_varint(buf)?;
        let channel = header >> 4;
        let typ = header & 0b1111;
        let message = Message::decode(&buf[headerlen..], typ)?;
        Ok(Self { channel, message })
    }

    /// Header value: the channel number shifted left by 4, OR'd with the
    /// message type.
    fn header(&self) -> u64 {
        (self.channel << 4) | self.message.typ()
    }

    /// Decode an unsigned LEB128 varint from the start of `buf`.
    ///
    /// Returns the value and the number of bytes it occupied. Truncated input
    /// yields `UnexpectedEof`. A varint that cannot fit in a `u64` yields
    /// `InvalidData`.
    fn decode_varint(buf: &[u8]) -> io::Result<(u64, usize)> {
        // A u64 needs at most 10 varint bytes (7 payload bits each).
        const MAX_LEN: usize = 10;
        let mut value = 0u64;
        for (i, &byte) in buf.iter().take(MAX_LEN).enumerate() {
            let bits = u64::from(byte & 0x7f);
            // The 10th byte may only carry the single remaining top bit.
            if i == MAX_LEN - 1 && bits > 1 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "varint overflows u64",
                ));
            }
            value |= bits << (7 * i);
            if byte & 0x80 == 0 {
                return Ok((value, i + 1));
            }
        }
        if buf.len() >= MAX_LEN {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint overflows u64",
            ))
        } else {
            Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "truncated varint",
            ))
        }
    }
}
/// A channel message encodes as `varint(header) ++ message body`.
impl Encoder for ChannelMessage {
    fn encoded_len(&self) -> usize {
        let body = self.message.encoded_len();
        varinteger::length(self.header()) + body
    }

    /// Fails with [`EncodeError`] if `buf` is too small or the encoded size
    /// exceeds `MAX_MESSAGE_SIZE`.
    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
        let header = self.header();
        let header_len = varinteger::length(header);
        let total = header_len + self.message.encoded_len();
        let fits = total <= buf.len() && total <= MAX_MESSAGE_SIZE as usize;
        if !fits {
            return Err(EncodeError::new(total));
        }
        let (head, body) = buf[..total].split_at_mut(header_len);
        varinteger::encode(header, head);
        self.message.encode(body)?;
        Ok(total)
    }
}
/// An extension message: a varint extension `id` followed by opaque bytes.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtensionMessage {
    /// Extension identifier, written as a varint prefix.
    pub id: u64,
    /// Payload bytes following the id.
    pub message: Vec<u8>,
}

impl ExtensionMessage {
    /// Create an extension message with the given id and payload.
    pub fn new(id: u64, message: Vec<u8>) -> Self {
        Self { id, message }
    }

    /// Decode an extension message: a varint id, then the remaining bytes
    /// as the payload.
    ///
    /// # Errors
    ///
    /// - `InvalidData` if `buf` is empty or the id varint does not fit in a
    ///   `u64`.
    /// - `UnexpectedEof` if the id varint is truncated.
    fn decode(buf: &[u8]) -> io::Result<Self> {
        if buf.is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Extension message may not be empty",
            ));
        }
        // Checked decode: malformed input must produce an error, not an
        // out-of-bounds panic or integer overflow.
        let (id, id_len) = Self::decode_varint(buf)?;
        Ok(Self {
            id,
            message: buf[id_len..].to_vec(),
        })
    }

    /// Decode an unsigned LEB128 varint from the start of `buf`.
    ///
    /// Returns the value and the number of bytes it occupied. Truncated input
    /// yields `UnexpectedEof`. A varint that cannot fit in a `u64` yields
    /// `InvalidData`.
    fn decode_varint(buf: &[u8]) -> io::Result<(u64, usize)> {
        // A u64 needs at most 10 varint bytes (7 payload bits each).
        const MAX_LEN: usize = 10;
        let mut value = 0u64;
        for (i, &byte) in buf.iter().take(MAX_LEN).enumerate() {
            let bits = u64::from(byte & 0x7f);
            // The 10th byte may only carry the single remaining top bit.
            if i == MAX_LEN - 1 && bits > 1 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "varint overflows u64",
                ));
            }
            value |= bits << (7 * i);
            if byte & 0x80 == 0 {
                return Ok((value, i + 1));
            }
        }
        if buf.len() >= MAX_LEN {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint overflows u64",
            ))
        } else {
            Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "truncated varint",
            ))
        }
    }
}
/// An extension message encodes as `varint(id) ++ message`.
impl Encoder for ExtensionMessage {
    fn encoded_len(&self) -> usize {
        varinteger::length(self.id) + self.message.len()
    }

    fn encode(&self, buf: &mut [u8]) -> Result<usize, EncodeError> {
        let id_len = varinteger::length(self.id);
        let total = id_len + self.message.len();
        if total > buf.len() {
            return Err(EncodeError::new(total));
        }
        let (id_part, rest) = buf.split_at_mut(id_len);
        varinteger::encode(self.id, id_part);
        rest[..self.message.len()].copy_from_slice(&self.message);
        Ok(total)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip each message through `ChannelMessage` encode/decode on a
    /// random channel. Checks that the full buffer is written and that the
    /// channel and message come back unchanged.
    macro_rules! message_enc_dec {
        ($( $msg:expr ),* $(,)?) => {
            $(
                // Evaluate the message expression once and compare against it.
                let msg = $msg;
                let channel = rand::random::<u8>() as u64;
                let channel_message = ChannelMessage::new(channel, msg.clone());
                let mut buf = vec![0u8; channel_message.encoded_len()];
                let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message");
                assert_eq!(n, buf.len());
                let decoded = ChannelMessage::decode(&buf[..n]).expect("Failed to decode message").into_split();
                assert_eq!(channel, decoded.0);
                assert_eq!(msg, decoded.1);
            )*
        }
    }

    #[test]
    fn encode_decode() {
        message_enc_dec! {
            Message::Open(Open{
                discovery_key: vec![2u8; 20],
                capability: None
            }),
            Message::Options(Options {
                extensions: vec!["test ext".to_string()],
                ack: None
            }),
            Message::Status(Status {
                uploading: Some(true),
                downloading: Some(false)
            }),
            Message::Have(Have {
                start: 0,
                length: Some(100),
                bitfield: None,
                ack: Some(true)
            }),
            Message::Unhave(Unhave {
                start: 0,
                length: Some(100),
            }),
            Message::Want(Want {
                start: 0,
                length: Some(100),
            }),
            Message::Unwant(Unwant {
                start: 0,
                length: Some(100),
            }),
            Message::Request(Request {
                index: 0,
                bytes: None,
                hash: Some(true),
                nodes: None
            }),
            Message::Cancel(Cancel{
                index: 10,
                bytes: Some(10),
                hash: Some(true)
            }),
            Message::Data(Data {
                index: 1,
                value: None,
                nodes: vec![],
                signature: None
            }),
            Message::Close(Close {
                discovery_key: Some(vec![1u8; 10])
            }),
            Message::Extension(ExtensionMessage::new(300, vec![0xde, 0xad, 0xbe, 0xef]))
        };
    }
}