use bytes::{BufMut, Bytes, BytesMut};
use asynchronous_codec::{Decoder, Encoder};
use libp2p_core::Endpoint;
use std::{fmt, hash::{Hash, Hasher}, io, mem};
use unsigned_varint::{codec, encode};
/// Upper bound, in bytes, on the data length of a single mplex frame (1 MiB).
///
/// The decoder rejects frames announcing a larger length and the encoder
/// refuses to emit frames with a larger payload.
pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024;
/// Identifier of a substream as used for frames encoded by this side
/// (`Frame<LocalStreamId>` is the encoder's item type).
///
/// Consists of a stream number and an [`Endpoint`] role; the role selects the
/// flag written into the frame header for data, close and reset frames.
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct LocalStreamId {
/// Numeric part of the identifier.
num: u32,
/// Role attached to the identifier.
// NOTE(review): presumably the role the local node plays for this substream
// (dialer = the side that opened it) - confirm against the callers.
role: Endpoint,
}
impl fmt::Display for LocalStreamId {
    /// Renders the ID as `(<num>/initiator)` when the role is
    /// `Endpoint::Dialer` and as `(<num>/receiver)` when it is
    /// `Endpoint::Listener`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { num, role } = self;
        match role {
            Endpoint::Listener => write!(f, "({}/receiver)", num),
            Endpoint::Dialer => write!(f, "({}/initiator)", num),
        }
    }
}
/// Manual `Hash` implementation that feeds only the stream number into the
/// hasher; the role is deliberately left out.
///
/// Values that compare equal under the derived `PartialEq` have the same
/// `num`, so they still hash identically.
impl Hash for LocalStreamId {
// Silences the clippy lint about combining a hand-written `Hash` with a
// derived `PartialEq`.
#![allow(clippy::derive_hash_xor_eq)]
fn hash<H: Hasher>(&self, state: &mut H) {
// Exactly one `write_u32` call: the type is also marked
// `nohash_hasher::IsEnabled`, so the written value is the hash input.
state.write_u32(self.num);
}
}
// Opts `LocalStreamId` into `nohash_hasher`; its `Hash` impl performs a
// single `write_u32` of the stream number.
impl nohash_hasher::IsEnabled for LocalStreamId {}
/// Identifier of a substream as carried by frames produced by the decoder
/// (`Frame<RemoteStreamId>` is the decoder's item type).
///
/// Converting to a [`LocalStreamId`] keeps the number and inverts the role.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub struct RemoteStreamId {
/// Numeric part of the identifier.
num: u32,
/// Role attached to the identifier.
// NOTE(review): presumably the role of the remote node for this substream -
// confirm against the callers.
role: Endpoint,
}
impl LocalStreamId {
    /// Creates a local ID with the given number and the `Dialer` role.
    pub fn dialer(num: u32) -> Self {
        let role = Endpoint::Dialer;
        Self { num, role }
    }

    /// Creates a local ID with the given number and the `Listener` role
    /// (only available in test builds).
    #[cfg(test)]
    pub fn listener(num: u32) -> Self {
        let role = Endpoint::Listener;
        Self { num, role }
    }

    /// Returns the ID with the same role and the number incremented by one.
    ///
    /// # Panics
    ///
    /// Panics if the stream number is already `u32::MAX`.
    pub fn next(self) -> Self {
        let Self { num, role } = self;
        let num = num.checked_add(1).expect("Mplex substream ID overflowed");
        Self { num, role }
    }

    /// Converts into the corresponding remote ID: same number, inverted role
    /// (only available in test builds).
    #[cfg(test)]
    pub fn into_remote(self) -> RemoteStreamId {
        let Self { num, role } = self;
        RemoteStreamId { num, role: !role }
    }
}
impl RemoteStreamId {
    /// Creates a remote ID with the given number and the `Dialer` role.
    fn dialer(num: u32) -> Self {
        let role = Endpoint::Dialer;
        Self { num, role }
    }

    /// Creates a remote ID with the given number and the `Listener` role.
    fn listener(num: u32) -> Self {
        let role = Endpoint::Listener;
        Self { num, role }
    }

    /// Converts into the corresponding local ID: same number, inverted role.
    pub fn into_local(self) -> LocalStreamId {
        let Self { num, role } = self;
        LocalStreamId { num, role: !role }
    }
}
/// An mplex frame, generic over the stream identifier type `T`
/// (`LocalStreamId` when encoding, `RemoteStreamId` when decoding).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame<T> {
/// Frame opening the stream `stream_id`.
Open { stream_id: T },
/// Frame carrying the payload `data` for the stream `stream_id`.
Data { stream_id: T, data: Bytes },
/// Close frame for the stream `stream_id`.
Close { stream_id: T },
/// Reset frame for the stream `stream_id`.
Reset { stream_id: T },
}
impl Frame<RemoteStreamId> {
    /// Returns the remote stream ID the frame refers to, whatever its variant.
    pub fn remote_id(&self) -> RemoteStreamId {
        match self {
            Frame::Open { stream_id }
            | Frame::Data { stream_id, .. }
            | Frame::Close { stream_id }
            | Frame::Reset { stream_id } => *stream_id,
        }
    }
}
/// Encoder and decoder for mplex frames.
///
/// A frame on the wire is a varint header (`stream number << 3 | flag`),
/// followed by a varint payload length and the payload itself.
pub struct Codec {
    /// Varint decoder used for both the header and the length field.
    ///
    /// It is 64 bits wide because the encoder emits headers of the form
    /// `u64::from(num) << 3 | flag` with a `u32` stream number, which needs up
    /// to 35 bits. A 32-bit decoder would fail on every stream number at or
    /// above 2^29.
    varint_decoder: codec::Uvi<u64>,
    /// Progress of the frame currently being decoded.
    decoder_state: CodecDecodeState,
}

/// Decoding progress within a single frame.
#[derive(Debug, Clone)]
enum CodecDecodeState {
    /// Nothing of the next frame has been read yet.
    Begin,
    /// The header varint has been read.
    HasHeader(u64),
    /// The header and the (already validated) payload length have been read.
    HasHeaderAndLen(u64, usize),
    /// Placeholder stored while a decoding step runs; it stays in place if the
    /// step returns an error, making every later `decode` call fail.
    Poisoned,
}

impl Codec {
    /// Creates a codec ready to decode the start of a frame.
    pub fn new() -> Codec {
        Codec {
            varint_decoder: codec::Uvi::default(),
            decoder_state: CodecDecodeState::Begin,
        }
    }
}

impl Decoder for Codec {
    type Item = Frame<RemoteStreamId>;
    type Error = io::Error;

    /// Decodes at most one frame from `src`.
    ///
    /// Returns `Ok(None)` when `src` does not yet hold a complete frame; the
    /// progress made so far is remembered for the next call.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error when
    /// * the announced payload length exceeds `MAX_FRAME_SIZE`,
    /// * the stream number in the header does not fit into a `u32`,
    /// * the header carries an unknown flag value, or
    /// * a previous call already failed (the codec is poisoned).
    ///
    /// Errors of the underlying varint decoder are propagated unchanged.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        loop {
            // Take the state out and leave `Poisoned` behind: any early error
            // return below therefore leaves the codec poisoned.
            match mem::replace(&mut self.decoder_state, CodecDecodeState::Poisoned) {
                CodecDecodeState::Begin => match self.varint_decoder.decode(src)? {
                    Some(header) => {
                        self.decoder_state = CodecDecodeState::HasHeader(header);
                    }
                    None => {
                        self.decoder_state = CodecDecodeState::Begin;
                        return Ok(None);
                    }
                },
                CodecDecodeState::HasHeader(header) => match self.varint_decoder.decode(src)? {
                    Some(len) => {
                        // Compare in u64 before narrowing so that an oversized
                        // length can never wrap around on 32-bit targets.
                        if len > MAX_FRAME_SIZE as u64 {
                            let msg = format!("Mplex frame length {} exceeds maximum", len);
                            return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
                        }
                        // Lossless: `len` is bounded by `MAX_FRAME_SIZE`.
                        self.decoder_state =
                            CodecDecodeState::HasHeaderAndLen(header, len as usize);
                    }
                    None => {
                        self.decoder_state = CodecDecodeState::HasHeader(header);
                        return Ok(None);
                    }
                },
                CodecDecodeState::HasHeaderAndLen(header, len) => {
                    if src.len() < len {
                        // Payload incomplete: remember where we are and make
                        // room for the missing bytes.
                        self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len);
                        let to_reserve = len - src.len();
                        src.reserve(to_reserve);
                        return Ok(None);
                    }

                    let buf = src.split_to(len);

                    // Upper bits are the stream number, the low 3 bits the flag.
                    let wide_num = header >> 3;
                    if wide_num > u64::from(u32::MAX) {
                        let msg = format!("Mplex stream number {} exceeds maximum", wide_num);
                        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
                    }
                    // Lossless: checked against `u32::MAX` just above.
                    let num = wide_num as u32;

                    let out = match header & 7 {
                        0 => Frame::Open { stream_id: RemoteStreamId::dialer(num) },
                        1 => Frame::Data {
                            stream_id: RemoteStreamId::listener(num),
                            data: buf.freeze(),
                        },
                        2 => Frame::Data {
                            stream_id: RemoteStreamId::dialer(num),
                            data: buf.freeze(),
                        },
                        3 => Frame::Close { stream_id: RemoteStreamId::listener(num) },
                        4 => Frame::Close { stream_id: RemoteStreamId::dialer(num) },
                        5 => Frame::Reset { stream_id: RemoteStreamId::listener(num) },
                        6 => Frame::Reset { stream_id: RemoteStreamId::dialer(num) },
                        _ => {
                            let msg = format!("Invalid mplex header value 0x{:x}", header);
                            return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
                        }
                    };

                    self.decoder_state = CodecDecodeState::Begin;
                    return Ok(Some(out));
                }
                CodecDecodeState::Poisoned => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Mplex codec poisoned",
                    ));
                }
            }
        }
    }
}
impl Encoder for Codec {
type Item = Frame<LocalStreamId>;
type Error = io::Error;
fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> {
let (header, data) = match item {
Frame::Open { stream_id } => {
(u64::from(stream_id.num) << 3, Bytes::new())
},
Frame::Data { stream_id: LocalStreamId { num, role: Endpoint::Listener }, data } => {
(u64::from(num) << 3 | 1, data)
},
Frame::Data { stream_id: LocalStreamId { num, role: Endpoint::Dialer }, data } => {
(u64::from(num) << 3 | 2, data)
},
Frame::Close { stream_id: LocalStreamId { num, role: Endpoint::Listener } } => {
(u64::from(num) << 3 | 3, Bytes::new())
},
Frame::Close { stream_id: LocalStreamId { num, role: Endpoint::Dialer } } => {
(u64::from(num) << 3 | 4, Bytes::new())
},
Frame::Reset { stream_id: LocalStreamId { num, role: Endpoint::Listener } } => {
(u64::from(num) << 3 | 5, Bytes::new())
},
Frame::Reset { stream_id: LocalStreamId { num, role: Endpoint::Dialer } } => {
(u64::from(num) << 3 | 6, Bytes::new())
},
};
let mut header_buf = encode::u64_buffer();
let header_bytes = encode::u64(header, &mut header_buf);
let data_len = data.as_ref().len();
let mut data_buf = encode::usize_buffer();
let data_len_bytes = encode::usize(data_len, &mut data_buf);
if data_len > MAX_FRAME_SIZE {
return Err(io::Error::new(io::ErrorKind::InvalidData, "data size exceed maximum"));
}
dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len);
dst.put(header_bytes);
dst.put(data_len_bytes);
dst.put(data);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A data frame one byte above `MAX_FRAME_SIZE` must be rejected by the
    /// encoder, while a frame of exactly `MAX_FRAME_SIZE` bytes is accepted.
    #[test]
    fn encode_large_messages_fails() {
        let mut codec = Codec::new();
        let stream_id = LocalStreamId { num: 123, role: Endpoint::Dialer };
        let mut dst = BytesMut::new();

        let oversized = Frame::Data {
            stream_id,
            data: Bytes::from(vec![123u8; MAX_FRAME_SIZE + 1]),
        };
        match codec.encode(oversized, &mut dst) {
            Err(e) => assert_eq!(e.to_string(), "data size exceed maximum"),
            _ => panic!("Can't send a message bigger than MAX_FRAME_SIZE"),
        }

        let max_sized = Frame::Data {
            stream_id,
            data: Bytes::from(vec![123u8; MAX_FRAME_SIZE]),
        };
        assert!(codec.encode(max_sized, &mut dst).is_ok());
    }
}