use crate::schema::*;
use crate::util::{stat_uint24_le, write_uint24_le};
use hypercore::encoding::{
CompactEncoding, EncodingError, EncodingErrorKind, HypercoreState, State,
};
use pretty_hash::fmt as pretty_fmt;
use std::fmt;
use std::io;
/// Selects how `Frame::decode` and `Frame::decode_multiple` interpret incoming bytes.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum FrameType {
/// The bytes are opaque payloads and are copied through without parsing.
Raw,
/// The bytes carry channel messages that are parsed into `ChannelMessage` values.
Message,
}
/// A value that can report its encoded size and serialize itself into a caller-provided buffer.
///
/// Both methods take `&mut self`, so implementors may keep intermediate encoding state
/// between the two calls.
pub(crate) trait Encoder: Sized + fmt::Debug {
/// Returns the number of bytes the encoded value occupies.
fn encoded_len(&mut self) -> Result<usize, EncodingError>;
/// Writes the encoded value into `buf` and returns the number of bytes written.
fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError>;
}
impl Encoder for &[u8] {
    /// A raw byte slice encodes to exactly its own length.
    fn encoded_len(&mut self) -> Result<usize, EncodingError> {
        Ok(self.len())
    }

    /// Copies the slice verbatim to the front of `buf` and returns the number of bytes copied.
    ///
    /// # Errors
    ///
    /// Returns an `Overflow` error when `buf` is shorter than the slice.
    fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError> {
        let needed = self.encoded_len()?;
        let capacity = buf.len();
        match buf.get_mut(..needed) {
            Some(target) => {
                target.copy_from_slice(&self[..]);
                Ok(needed)
            }
            None => Err(EncodingError::new(
                EncodingErrorKind::Overflow,
                &format!("Length does not fit buffer, {} > {}", needed, capacity),
            )),
        }
    }
}
/// A decoded (or to-be-encoded) unit of protocol traffic.
#[derive(Clone, PartialEq)]
pub(crate) enum Frame {
/// One or more opaque byte payloads.
RawBatch(Vec<Vec<u8>>),
/// One or more messages, each tagged with the channel it belongs to.
MessageBatch(Vec<ChannelMessage>),
}
impl fmt::Debug for Frame {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Frame::RawBatch(batch) => write!(f, "Frame(RawBatch <{}>)", batch.len()),
Frame::MessageBatch(messages) => write!(f, "Frame({messages:?})"),
}
}
}
impl From<ChannelMessage> for Frame {
fn from(m: ChannelMessage) -> Self {
Self::MessageBatch(vec![m])
}
}
impl From<Vec<u8>> for Frame {
fn from(m: Vec<u8>) -> Self {
Self::RawBatch(vec![m])
}
}
impl Frame {
    /// Decodes a buffer that holds zero or more length-prefixed frames back to back.
    ///
    /// Every frame starts with a header understood by `stat_uint24_le`, followed by the body.
    /// For `FrameType::Raw` each body is copied as-is; for `FrameType::Message` each body is
    /// parsed into channel messages and all of them are collected into a single batch.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when a header cannot be read or when it announces a body that
    /// extends past the end of `buf`, and forwards any error from decoding a message body.
    pub(crate) fn decode_multiple(buf: &[u8], frame_type: &FrameType) -> Result<Self, io::Error> {
        match frame_type {
            FrameType::Raw => {
                let mut raw_batch: Vec<Vec<u8>> = vec![];
                let mut index = 0;
                while index < buf.len() {
                    // Zero bytes between frames are treated as padding and skipped.
                    if buf[index] == 0 {
                        index += 1;
                        continue;
                    }
                    let (body, next_index) =
                        Self::read_chunk(buf, index, "received invalid data in raw batch")?;
                    raw_batch.push(body.to_vec());
                    index = next_index;
                }
                Ok(Frame::RawBatch(raw_batch))
            }
            FrameType::Message => {
                let mut combined_messages: Vec<ChannelMessage> = vec![];
                let mut index = 0;
                while index < buf.len() {
                    // Zero bytes between frames are treated as padding and skipped.
                    if buf[index] == 0 {
                        index += 1;
                        continue;
                    }
                    let (body, next_index) = Self::read_chunk(
                        buf,
                        index,
                        "received invalid data in multi-message chunk",
                    )?;
                    let (messages, length) = Self::decode_message_batch(body)?;
                    if length != body.len() {
                        tracing::warn!(
                            "Did not know what to do with all the bytes, got {} but decoded {}. \
                             This may be because the peer implements a newer protocol version \
                             that has extra fields.",
                            body.len(),
                            length
                        );
                    }
                    combined_messages.extend(messages);
                    index = next_index;
                }
                Ok(Frame::MessageBatch(combined_messages))
            }
        }
    }

    /// Decodes a single frame body (the length header must already be stripped).
    ///
    /// A raw frame is copied verbatim; a message frame is parsed and the number of consumed
    /// bytes is discarded.
    pub(crate) fn decode(buf: &[u8], frame_type: &FrameType) -> Result<Self, io::Error> {
        match frame_type {
            FrameType::Raw => Ok(Frame::RawBatch(vec![buf.to_vec()])),
            FrameType::Message => Self::decode_message(buf).map(|(frame, _)| frame),
        }
    }

    /// Returns the body of the length-prefixed chunk that starts at `index`, together with the
    /// index of the first byte after that body.
    ///
    /// `index` must be within `buf`. All offsets are bounds-checked so that a length header
    /// announcing more bytes than `buf` contains yields `InvalidData` (carrying
    /// `error_message`) instead of a slice-index panic.
    fn read_chunk<'a>(
        buf: &'a [u8],
        index: usize,
        error_message: &'static str,
    ) -> io::Result<(&'a [u8], usize)> {
        let invalid = || io::Error::new(io::ErrorKind::InvalidData, error_message);
        let (header_len, body_len) = stat_uint24_le(&buf[index..]).ok_or_else(invalid)?;
        let body_start = index.checked_add(header_len).ok_or_else(invalid)?;
        // NOTE(review): the cast mirrors the previous code; presumably lossless because the
        // header carries a 24-bit length. Confirm against `stat_uint24_le`.
        let body_end = body_start
            .checked_add(body_len as usize)
            .ok_or_else(invalid)?;
        let body = buf.get(body_start..body_end).ok_or_else(invalid)?;
        Ok((body, body_end))
    }

    /// Decodes one frame body into a message batch plus the number of bytes consumed.
    fn decode_message(buf: &[u8]) -> Result<(Self, usize), io::Error> {
        let (messages, length) = Self::decode_message_batch(buf)?;
        Ok((Frame::MessageBatch(messages), length))
    }

    /// Parses one frame body into its channel messages and the number of bytes consumed.
    ///
    /// Layouts handled, selected by the leading bytes:
    /// * `0x00 0x00`: a batch, i.e. a channel id followed by length-prefixed messages, where
    ///   a `0x00` byte after a message introduces a new channel id;
    /// * `0x00 0x01`: an Open message;
    /// * `0x00 0x03`: a Close message;
    /// * anything else of at least two bytes: a channel id followed by one message.
    fn decode_message_batch(buf: &[u8]) -> io::Result<(Vec<ChannelMessage>, usize)> {
        if buf.len() >= 3 && buf[0] == 0x00 {
            match buf[1] {
                0x00 => {
                    let mut messages: Vec<ChannelMessage> = vec![];
                    let mut state = State::new_with_start_and_end(2, buf.len());
                    let mut current_channel: u64 = state.decode(buf)?;
                    while state.start() < state.end() {
                        let channel_message_length: usize = state.decode(buf)?;
                        let message_start = state.start();
                        let state_end = state.end();
                        // Checked addition: the length comes from the peer and must not be
                        // able to overflow the offset computation.
                        let message_end = match message_start.checked_add(channel_message_length)
                        {
                            Some(end) if end <= state_end => end,
                            _ => {
                                return Err(io::Error::new(
                                    io::ErrorKind::InvalidData,
                                    format!(
                                        "received invalid message length, {} + {} > {}",
                                        message_start, channel_message_length, state_end
                                    ),
                                ));
                            }
                        };
                        let (channel_message, _) = ChannelMessage::decode(
                            &buf[message_start..message_end],
                            current_channel,
                        )?;
                        messages.push(channel_message);
                        state.add_start(channel_message_length)?;
                        // A zero byte after a message announces a channel switch.
                        if state.start() < state.end() && buf[state.start()] == 0x00 {
                            state.add_start(1)?;
                            current_channel = state.decode(buf)?;
                        }
                    }
                    Ok((messages, state.start()))
                }
                0x01 => {
                    let (channel_message, length) =
                        ChannelMessage::decode_open_message(&buf[2..])?;
                    Ok((vec![channel_message], length + 2))
                }
                0x03 => {
                    let (channel_message, length) =
                        ChannelMessage::decode_close_message(&buf[2..])?;
                    Ok((vec![channel_message], length + 2))
                }
                _ => Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "received invalid special message",
                )),
            }
        } else if buf.len() >= 2 {
            let mut state = State::from_buffer(buf);
            let channel: u64 = state.decode(buf)?;
            let (channel_message, length) =
                ChannelMessage::decode(&buf[state.start()..], channel)?;
            Ok((vec![channel_message], state.start() + length))
        } else {
            Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("received too short message, {buf:02X?}"),
            ))
        }
    }

    /// Grows `state`'s end by the number of body bytes this frame needs (excluding the 3-byte
    /// length header of message frames) and returns the resulting end offset.
    ///
    /// The layout accounted for here must stay in sync with `Encoder::encode` for `Frame`.
    fn preencode(&mut self, state: &mut State) -> Result<usize, EncodingError> {
        match self {
            Self::RawBatch(raw_batch) => {
                for raw in raw_batch.iter() {
                    state.add_end(raw.len())?;
                }
            }
            Self::MessageBatch(messages) => match messages.as_mut_slice() {
                // An empty batch has no body.
                [] => {}
                [single] => {
                    if matches!(single.message, Message::Open(_) | Message::Close(_)) {
                        // Two marker bytes precede an Open/Close message.
                        state.add_end(2 + single.encoded_len()?)?;
                    } else {
                        state.preencode(&single.channel)?;
                        state.add_end(single.encoded_len()?)?;
                    }
                }
                batch => {
                    // Two zero bytes mark a batch, followed by the first channel id.
                    state.add_end(2)?;
                    let mut current_channel: u64 = batch[0].channel;
                    state.preencode(&current_channel)?;
                    for message in batch.iter_mut() {
                        if message.channel != current_channel {
                            // One zero byte plus the new channel id announce a switch.
                            state.add_end(1)?;
                            state.preencode(&message.channel)?;
                            current_channel = message.channel;
                        }
                        let message_length = message.encoded_len()?;
                        state.preencode(&message_length)?;
                        state.add_end(message_length)?;
                    }
                }
            },
        }
        Ok(state.end())
    }
}
impl Encoder for Frame {
    /// Returns the total encoded size: the body for raw batches, or the body plus the 3-byte
    /// length header for message batches.
    fn encoded_len(&mut self) -> Result<usize, EncodingError> {
        let body_len = self.preencode(&mut State::new())?;
        match self {
            Self::RawBatch(_) => Ok(body_len),
            Self::MessageBatch(_) => Ok(3 + body_len),
        }
    }

    /// Serializes the frame into `buf` and returns the number of bytes written.
    ///
    /// Raw batches are written back to back with no header. Message batches are written as a
    /// 3-byte length header followed by the body in the layout sized by `Frame::preencode`.
    ///
    /// # Errors
    ///
    /// Returns an `Overflow` error when `buf` is too small, and forwards errors from encoding
    /// the individual entries.
    fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError> {
        let mut state = State::new();
        let header_len = if let Self::RawBatch(_) = self { 0 } else { 3 };
        let body_len = self.preencode(&mut state)?;
        let len = body_len + header_len;
        if buf.len() < len {
            return Err(EncodingError::new(
                EncodingErrorKind::Overflow,
                &format!("Length does not fit buffer, {} > {}", len, buf.len()),
            ));
        }
        match self {
            Self::RawBatch(raw_batch) => {
                // Each payload is written after the previous one; writing them all at the
                // start of `buf` would leave only the last payload in the output.
                let mut offset = 0;
                for raw in raw_batch.iter() {
                    offset += raw.as_slice().encode(&mut buf[offset..])?;
                }
            }
            Self::MessageBatch(messages) => {
                write_uint24_le(body_len, buf);
                // `buf.len() >= len >= 3` was verified above, so this slice cannot fail.
                let buf = &mut buf[3..];
                match messages.as_mut_slice() {
                    // An empty batch consists of the length header only.
                    [] => {}
                    [single] => {
                        match single.message {
                            Message::Open(_) => {
                                state.encode(&0_u8, buf)?;
                                state.encode(&1_u8, buf)?;
                            }
                            Message::Close(_) => {
                                state.encode(&0_u8, buf)?;
                                state.encode(&3_u8, buf)?;
                            }
                            _ => {
                                state.encode(&single.channel, buf)?;
                            }
                        }
                        state.add_start(single.encode(&mut buf[state.start()..])?)?;
                    }
                    batch => {
                        // Two zero bytes mark a batch, followed by the first channel id.
                        state.set_slice_to_buffer(&[0_u8, 0_u8], buf)?;
                        let mut current_channel: u64 = batch[0].channel;
                        state.encode(&current_channel, buf)?;
                        for message in batch.iter_mut() {
                            if message.channel != current_channel {
                                // A zero byte plus the new channel id announce a switch.
                                state.encode(&0_u8, buf)?;
                                state.encode(&message.channel, buf)?;
                                current_channel = message.channel;
                            }
                            let message_length = message.encoded_len()?;
                            state.encode(&message_length, buf)?;
                            state.add_start(message.encode(&mut buf[state.start()..])?)?;
                        }
                    }
                }
            }
        }
        Ok(len)
    }
}
/// A protocol message, one variant per message kind.
///
/// `Open` and `Close` are framed specially and have no numeric wire type; the variants from
/// `Synchronize` through `Extension` map to the wire types returned by `Message::typ`.
#[derive(Debug, Clone, PartialEq)]
#[allow(missing_docs)]
pub enum Message {
Open(Open),
Close(Close),
Synchronize(Synchronize),
Request(Request),
Cancel(Cancel),
Data(Data),
NoData(NoData),
Want(Want),
Unwant(Unwant),
Bitfield(Bitfield),
Range(Range),
Extension(Extension),
// Has no wire type and contributes zero bytes in `preencode`/`encode`.
LocalSignal((String, Vec<u8>)),
}
impl Message {
    /// Returns the numeric wire type of this message.
    ///
    /// # Panics
    ///
    /// Panics for variants without a wire type (`Open`, `Close` and `LocalSignal`).
    pub(crate) fn typ(&self) -> u64 {
        match self {
            Self::Synchronize(_) => 0,
            Self::Request(_) => 1,
            Self::Cancel(_) => 2,
            Self::Data(_) => 3,
            Self::NoData(_) => 4,
            Self::Want(_) => 5,
            Self::Unwant(_) => 6,
            Self::Bitfield(_) => 7,
            Self::Range(_) => 8,
            Self::Extension(_) => 9,
            untyped => unimplemented!("{} does not have a type", untyped),
        }
    }

    /// Decodes the body of a message of wire type `typ` from the start of `buf`.
    ///
    /// Returns the message together with the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` for an unknown `typ` and forwards body decoding errors.
    pub(crate) fn decode(buf: &[u8], typ: u64) -> Result<(Self, usize), EncodingError> {
        let mut state = HypercoreState::from_buffer(buf);
        let message = match typ {
            0 => Self::Synchronize((*state).decode(buf)?),
            1 => Self::Request(state.decode(buf)?),
            2 => Self::Cancel((*state).decode(buf)?),
            3 => Self::Data(state.decode(buf)?),
            4 => Self::NoData((*state).decode(buf)?),
            5 => Self::Want((*state).decode(buf)?),
            6 => Self::Unwant((*state).decode(buf)?),
            7 => Self::Bitfield((*state).decode(buf)?),
            8 => Self::Range((*state).decode(buf)?),
            9 => Self::Extension((*state).decode(buf)?),
            _ => {
                return Err(EncodingError::new(
                    EncodingErrorKind::InvalidData,
                    &format!("Invalid message type to decode: {typ}"),
                ));
            }
        };
        Ok((message, state.start()))
    }

    /// Accounts for the message body in `state` and returns the resulting end offset.
    ///
    /// `Request` and `Data` go through the `HypercoreState` encoder, all other bodies through
    /// the wrapped plain state; `LocalSignal` contributes nothing.
    pub(crate) fn preencode(&self, state: &mut HypercoreState) -> Result<usize, EncodingError> {
        match self {
            Self::Open(body) => {
                state.0.preencode(body)?;
            }
            Self::Close(body) => {
                state.0.preencode(body)?;
            }
            Self::Synchronize(body) => {
                state.0.preencode(body)?;
            }
            Self::Request(body) => {
                state.preencode(body)?;
            }
            Self::Cancel(body) => {
                state.0.preencode(body)?;
            }
            Self::Data(body) => {
                state.preencode(body)?;
            }
            Self::NoData(body) => {
                state.0.preencode(body)?;
            }
            Self::Want(body) => {
                state.0.preencode(body)?;
            }
            Self::Unwant(body) => {
                state.0.preencode(body)?;
            }
            Self::Bitfield(body) => {
                state.0.preencode(body)?;
            }
            Self::Range(body) => {
                state.0.preencode(body)?;
            }
            Self::Extension(body) => {
                state.0.preencode(body)?;
            }
            Self::LocalSignal(_) => {}
        }
        Ok(state.end())
    }

    /// Writes the message body into `buf` at the position tracked by `state` and returns the
    /// resulting start offset.
    ///
    /// Mirrors `preencode`: `Request` and `Data` use the `HypercoreState` encoder, the rest
    /// use the wrapped plain state, and `LocalSignal` writes nothing.
    pub(crate) fn encode(
        &self,
        state: &mut HypercoreState,
        buf: &mut [u8],
    ) -> Result<usize, EncodingError> {
        match self {
            Self::Open(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Close(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Synchronize(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Request(body) => {
                state.encode(body, buf)?;
            }
            Self::Cancel(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Data(body) => {
                state.encode(body, buf)?;
            }
            Self::NoData(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Want(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Unwant(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Bitfield(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Range(body) => {
                state.0.encode(body, buf)?;
            }
            Self::Extension(body) => {
                state.0.encode(body, buf)?;
            }
            Self::LocalSignal(_) => {}
        }
        Ok(state.start())
    }
}
impl fmt::Display for Message {
    /// Writes a compact, log-friendly summary of the message.
    ///
    /// `Open` shows a shortened discovery key and the capability length, `Data` shows which
    /// optional parts are present, and every other variant falls back to its `Debug` output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Open(msg) => {
                let capability_len = msg.capability.as_ref().map_or(0, |c| c.len());
                // The discovery key may come from a remote peer, so formatting must not panic
                // when it cannot be prettified; fall back to a plain hex dump instead.
                match pretty_fmt(&msg.discovery_key) {
                    Ok(discovery_key) => write!(
                        f,
                        "Open(discovery_key: {}, capability <{}>)",
                        discovery_key, capability_len
                    ),
                    Err(_) => write!(
                        f,
                        "Open(discovery_key: {:02X?}, capability <{}>)",
                        msg.discovery_key, capability_len
                    ),
                }
            }
            Self::Data(msg) => write!(
                f,
                "Data(request: {}, fork: {}, block: {}, hash: {}, seek: {}, upgrade: {})",
                msg.request,
                msg.fork,
                msg.block.is_some(),
                msg.hash.is_some(),
                msg.seek.is_some(),
                msg.upgrade.is_some(),
            ),
            _ => write!(f, "{:?}", &self),
        }
    }
}
/// A `Message` together with the id of the channel it is sent on or was received from.
#[derive(Clone)]
pub(crate) struct ChannelMessage {
/// Channel the message belongs to.
pub(crate) channel: u64,
/// The message itself.
pub(crate) message: Message,
// Encoding state, filled in by `prepare_state` on first use and reused afterwards.
// It is `None` for freshly constructed or decoded messages and is ignored by `PartialEq`.
state: Option<HypercoreState>,
}
impl PartialEq for ChannelMessage {
    /// Two channel messages are equal when channel and message match; the cached encoding
    /// state is deliberately left out of the comparison.
    fn eq(&self, other: &Self) -> bool {
        if self.channel != other.channel {
            return false;
        }
        self.message == other.message
    }
}
impl fmt::Debug for ChannelMessage {
    /// Prints the channel id followed by the message's `Display` summary.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            channel, message, ..
        } = self;
        f.write_fmt(format_args!("ChannelMessage({}, {})", channel, message))
    }
}
impl ChannelMessage {
    /// Creates a channel message with no cached encoding state.
    pub(crate) fn new(channel: u64, message: Message) -> Self {
        let state = None;
        Self {
            channel,
            message,
            state,
        }
    }

    /// Consumes the value and returns its channel id and message.
    pub(crate) fn into_split(self) -> (u64, Message) {
        let Self {
            channel, message, ..
        } = self;
        (channel, message)
    }

    /// Decodes an Open message from the start of `buf`.
    ///
    /// The channel id is taken from the decoded message. Returns the message together with
    /// the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` when `buf` holds five bytes or fewer, and forwards decoding
    /// errors.
    pub(crate) fn decode_open_message(buf: &[u8]) -> io::Result<(Self, usize)> {
        if buf.len() <= 5 {
            let too_short = io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "received too short Open message",
            );
            return Err(too_short);
        }
        let mut state = State::new_with_start_and_end(0, buf.len());
        let open_msg: Open = state.decode(buf)?;
        let channel = open_msg.channel;
        let decoded = Self {
            channel,
            message: Message::Open(open_msg),
            state: None,
        };
        Ok((decoded, state.start()))
    }

    /// Decodes a Close message from the start of `buf`.
    ///
    /// The channel id is taken from the decoded message. Returns the message together with
    /// the number of bytes consumed.
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` when `buf` is empty, and forwards decoding errors.
    pub(crate) fn decode_close_message(buf: &[u8]) -> io::Result<(Self, usize)> {
        if buf.is_empty() {
            let too_short = io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "received too short Close message",
            );
            return Err(too_short);
        }
        let mut state = State::new_with_start_and_end(0, buf.len());
        let close_msg: Close = state.decode(buf)?;
        let channel = close_msg.channel;
        let decoded = Self {
            channel,
            message: Message::Close(close_msg),
            state: None,
        };
        Ok((decoded, state.start()))
    }

    /// Decodes a typed message for `channel`: a wire type followed by the message body.
    ///
    /// Returns the message together with the number of bytes consumed (type plus body).
    ///
    /// # Errors
    ///
    /// Returns `UnexpectedEof` when `buf` holds one byte or fewer, and forwards decoding
    /// errors.
    pub(crate) fn decode(buf: &[u8], channel: u64) -> io::Result<(Self, usize)> {
        if buf.len() <= 1 {
            let empty = io::Error::new(io::ErrorKind::UnexpectedEof, "received empty message");
            return Err(empty);
        }
        let mut state = State::from_buffer(buf);
        let typ: u64 = state.decode(buf)?;
        let type_len = state.start();
        let (message, body_len) = Message::decode(&buf[type_len..], typ)?;
        let decoded = Self {
            channel,
            message,
            state: None,
        };
        Ok((decoded, type_len + body_len))
    }

    /// Computes and caches the encoding state on first use; later calls are no-ops.
    ///
    /// Open and Close messages are sized by their body alone; every other message is sized
    /// as its wire type followed by its body.
    fn prepare_state(&mut self) -> Result<(), EncodingError> {
        if self.state.is_some() {
            return Ok(());
        }
        let mut fresh = HypercoreState::new();
        let has_type_prefix = !matches!(self.message, Message::Open(_) | Message::Close(_));
        if has_type_prefix {
            let typ = self.message.typ();
            (*fresh).preencode(&typ)?;
        }
        self.message.preencode(&mut fresh)?;
        self.state = Some(fresh);
        Ok(())
    }
}
impl Encoder for ChannelMessage {
    /// Returns the encoded size of the message, computing the cached state if necessary.
    fn encoded_len(&mut self) -> Result<usize, EncodingError> {
        self.prepare_state()?;
        let state = self.state.as_ref().unwrap();
        Ok(state.end())
    }

    /// Writes the message into `buf` using the cached state and returns the state's start
    /// offset after writing.
    ///
    /// Open and Close messages are written as their body only; every other message is
    /// written as its wire type followed by its body.
    fn encode(&mut self, buf: &mut [u8]) -> Result<usize, EncodingError> {
        self.prepare_state()?;
        let has_type_prefix = !matches!(self.message, Message::Open(_) | Message::Close(_));
        let state = self.state.as_mut().unwrap();
        if has_type_prefix {
            let typ = self.message.typ();
            state.0.encode(&typ, buf)?;
        }
        self.message.encode(state, buf)?;
        Ok(state.start())
    }
}
// Round-trip tests for the typed channel messages.
#[cfg(test)]
mod tests {
use super::*;
use hypercore::{
DataBlock, DataHash, DataSeek, DataUpgrade, Node, RequestBlock, RequestSeek, RequestUpgrade,
};
// For every message expression given: wrap it in a `ChannelMessage` on a random channel,
// encode it into a buffer sized by `encoded_len`, decode it again and assert that both the
// channel and the message survive the round trip unchanged.
macro_rules! message_enc_dec {
($( $msg:expr ),*) => {
$(
let channel = rand::random::<u8>() as u64;
let mut channel_message = ChannelMessage::new(channel, $msg);
let encoded_len = channel_message.encoded_len().expect("Failed to get encoded length");
let mut buf = vec![0u8; encoded_len];
let n = channel_message.encode(&mut buf[..]).expect("Failed to encode message");
let decoded = ChannelMessage::decode(&buf[..n], channel).expect("Failed to decode message").0.into_split();
assert_eq!(channel, decoded.0);
assert_eq!($msg, decoded.1);
)*
}
}
// Exercises one instance of each message kind that `ChannelMessage::decode` can produce
// (wire types 0 through 9); Open and Close are not covered here.
#[test]
fn message_encode_decode() {
message_enc_dec! {
Message::Synchronize(Synchronize{
fork: 0,
can_upgrade: true,
downloading: true,
uploading: true,
length: 5,
remote_length: 0,
}),
Message::Request(Request {
id: 1,
fork: 1,
block: Some(RequestBlock {
index: 5,
nodes: 10,
}),
hash: Some(RequestBlock {
index: 20,
nodes: 0
}),
seek: Some(RequestSeek {
bytes: 10
}),
upgrade: Some(RequestUpgrade {
start: 0,
length: 10
})
}),
Message::Cancel(Cancel {
request: 1,
}),
Message::Data(Data{
request: 1,
fork: 5,
block: Some(DataBlock {
index: 5,
nodes: vec![Node::new(1, vec![0x01; 32], 100)],
value: vec![0xFF; 10]
}),
hash: Some(DataHash {
index: 20,
nodes: vec![Node::new(2, vec![0x02; 32], 200)],
}),
seek: Some(DataSeek {
bytes: 10,
nodes: vec![Node::new(3, vec![0x03; 32], 300)],
}),
upgrade: Some(DataUpgrade {
start: 0,
length: 10,
nodes: vec![Node::new(4, vec![0x04; 32], 400)],
additional_nodes: vec![Node::new(5, vec![0x05; 32], 500)],
signature: vec![0xAB; 32]
})
}),
Message::NoData(NoData {
request: 2,
}),
Message::Want(Want {
start: 0,
length: 100,
}),
Message::Unwant(Unwant {
start: 10,
length: 2,
}),
Message::Bitfield(Bitfield {
start: 20,
bitfield: vec![0x89ABCDEF, 0x00, 0xFFFFFFFF],
}),
Message::Range(Range {
drop: true,
start: 12345,
length: 100000
}),
Message::Extension(Extension {
name: "custom_extension/v1/open".to_string(),
message: vec![0x44, 20]
})
};
}
}