use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use futures::prelude::*;
use tetsy_libp2p_core::ProtocolName;
use minicbor::{Encode, Decode};
use std::io;
use super::RequestResponseCodec;
use unsigned_varint::{aio, io::ReadError};
#[derive(Debug, Default, Clone, PartialEq, Eq, Encode, Decode)]
#[cbor(map)]
pub struct Header {
#[n(0)] pub typ: Option<Type>,
#[n(1)] pub credit: Option<u16>,
#[n(2)] pub ident: Option<u64>
}
#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
pub enum Type {
#[n(0)] Request,
#[n(1)] Response,
#[n(2)] Credit,
#[n(3)] Ack
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Message<T> {
header: Header,
data: Option<T>
}
impl<T> Message<T> {
fn new(header: Header) -> Self {
Message { header, data: None }
}
pub fn request(data: T) -> Self {
let mut m = Message::new(Header { typ: Some(Type::Request), .. Header::default() });
m.data = Some(data);
m
}
pub fn response(data: T) -> Self {
let mut m = Message::new(Header { typ: Some(Type::Response), .. Header::default() });
m.data = Some(data);
m
}
pub fn credit(credit: u16, ident: u64) -> Self {
Message::new(Header { typ: Some(Type::Credit), credit: Some(credit), ident: Some(ident) })
}
pub fn ack(ident: u64) -> Self {
Message::new(Header { typ: Some(Type::Ack), credit: None, ident: Some(ident) })
}
pub fn header(&self) -> &Header {
&self.header
}
pub fn data(&self) -> Option<&T> {
self.data.as_ref()
}
pub fn into_parts(self) -> (Header, Option<T>) {
(self.header, self.data)
}
}
#[derive(Debug, Clone)]
pub struct ProtocolWrapper<P>(P, Bytes);
impl<P: ProtocolName> ProtocolWrapper<P> {
pub fn new(prefix: &[u8], p: P) -> Self {
let mut full = BytesMut::from(prefix);
full.extend_from_slice(p.protocol_name());
ProtocolWrapper(p, full.freeze())
}
}
impl<P> ProtocolName for ProtocolWrapper<P> {
fn protocol_name(&self) -> &[u8] {
self.1.as_ref()
}
}
#[derive(Debug, Clone)]
pub struct Codec<C> {
inner: C,
buffer: Vec<u8>,
max_header_len: u32
}
impl<C> Codec<C> {
pub fn new(c: C, max_header_len: u32) -> Self {
Codec { inner: c, buffer: Vec::new(), max_header_len }
}
async fn read_header<T, H>(&mut self, io: &mut T) -> io::Result<H>
where
T: AsyncRead + Unpin + Send,
H: for<'a> minicbor::Decode<'a>
{
let header_len = aio::read_u32(&mut *io).await
.map_err(|e| match e {
ReadError::Io(e) => e,
other => io::Error::new(io::ErrorKind::Other, other)
})?;
if header_len > self.max_header_len {
return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to read"))
}
self.buffer.resize(u32_to_usize(header_len), 0u8);
io.read_exact(&mut self.buffer).await?;
minicbor::decode(&self.buffer).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
}
async fn write_header<T, H>(&mut self, hdr: &H, io: &mut T) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send,
H: minicbor::Encode
{
self.buffer.clear();
minicbor::encode(hdr, &mut self.buffer).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
if self.buffer.len() > u32_to_usize(self.max_header_len) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to write"))
}
let mut b = unsigned_varint::encode::u32_buffer();
let header_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut b);
io.write_all(header_len).await?;
io.write_all(&self.buffer).await
}
}
#[async_trait]
impl<C> RequestResponseCodec for Codec<C>
where
C: RequestResponseCodec + Send,
C::Protocol: Sync
{
type Protocol = ProtocolWrapper<C::Protocol>;
type Request = Message<C::Request>;
type Response = Message<C::Response>;
async fn read_request<T>(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
where
T: AsyncRead + Unpin + Send
{
let mut msg = Message::new(self.read_header(io).await?);
match msg.header.typ {
Some(Type::Request) => {
msg.data = Some(self.inner.read_request(&p.0, io).await?);
Ok(msg)
}
Some(Type::Credit) => Ok(msg),
Some(Type::Response) | Some(Type::Ack) | None => {
log::debug!("unexpected {:?} when expecting request or credit grant", msg.header.typ);
Err(io::ErrorKind::InvalidData.into())
}
}
}
async fn read_response<T>(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result<Self::Response>
where
T: AsyncRead + Unpin + Send
{
let mut msg = Message::new(self.read_header(io).await?);
match msg.header.typ {
Some(Type::Response) => {
msg.data = Some(self.inner.read_response(&p.0, io).await?);
Ok(msg)
}
Some(Type::Ack) => Ok(msg),
Some(Type::Request) | Some(Type::Credit) | None => {
log::debug!("unexpected {:?} when expecting response or ack", msg.header.typ);
Err(io::ErrorKind::InvalidData.into())
}
}
}
async fn write_request<T>(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Request) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send
{
self.write_header(&r.header, io).await?;
if let Some(data) = r.data {
self.inner.write_request(&p.0, io, data).await?
}
Ok(())
}
async fn write_response<T>(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Response) -> io::Result<()>
where
T: AsyncWrite + Unpin + Send
{
self.write_header(&r.header, io).await?;
if let Some(data) = r.data {
self.inner.write_response(&p.0, io, data).await?
}
Ok(())
}
}
#[cfg(any(target_pointer_width = "64", target_pointer_width = "32"))]
fn u32_to_usize(n: u32) -> usize {
n as usize
}