nephele 0.0.2

A high-performance asynchronous programming runtime for Rust.
use bytes::{Buf, BufMut, Bytes, BytesMut};
use cynthia::future::swap::{AsyncRead, AsyncWrite};
use std::error::Error as StdError;
use std::io::{self, Cursor};
use std::{cmp, fmt};

use crate::common::codec::{Decoder, Encoder, Framed, FramedRead, FramedWrite};

/// Configuration for a length-delimited codec.
///
/// Holds the framing parameters that [`LengthDelimitedCodec`] consults when
/// decoding and encoding frames. The value is `Copy`, so a codec keeps its own
/// snapshot of the configuration.
#[derive(Debug, Clone, Copy)]
pub struct Builder {
    /// Largest accepted frame length; the decoder rejects a length-field value
    /// above it and the encoder rejects a payload longer than it.
    max_frame_len: usize,

    /// Number of bytes occupied by the length field (validated to be 1..=8 by
    /// `length_field_length`).
    length_field_len: usize,

    /// Number of bytes that precede the length field in the frame head.
    length_field_offset: usize,

    /// Signed delta added to the length-field value when decoding and
    /// subtracted from the payload length when encoding.
    length_adjustment: isize,

    /// Number of bytes dropped from the front of the buffer once the head has
    /// been parsed; `None` means `length_field_offset + length_field_len`.
    num_skip: Option<usize>,

    /// `true` to read/write the length field big-endian, `false` for
    /// little-endian.
    length_field_is_big_endian: bool,
}

/// Error payload used when a frame exceeds the configured maximum frame length.
///
/// It is wrapped in an `io::Error` by the codec. The private unit field keeps
/// code outside this module from constructing the type directly.
pub struct LengthDelimitedCodecError {
    _priv: (),
}

/// Codec that frames a byte stream using a length field in each frame head.
#[derive(Debug)]
pub struct LengthDelimitedCodec {
    /// Framing configuration this codec was created with.
    builder: Builder,
    /// Where the decoder currently is within the frame being read.
    state: DecodeState,
}

/// Progress of the decoder within the current frame.
#[derive(Debug, Clone, Copy)]
enum DecodeState {
    /// The frame head has not been parsed yet.
    Head,
    /// The head has been parsed; the payload is this many bytes long.
    Data(usize),
}

impl LengthDelimitedCodec {
    /// Creates a codec using the default [`Builder`] configuration.
    pub fn new() -> Self {
        Self {
            builder: Builder::new(),
            state: DecodeState::Head,
        }
    }

    /// Returns a fresh [`Builder`] for configuring a codec.
    pub fn builder() -> Builder {
        Builder::new()
    }

    /// Returns the current maximum frame length.
    pub fn max_frame_length(&self) -> usize {
        self.builder.max_frame_len
    }

    /// Replaces the maximum frame length used by subsequent `decode` and
    /// `encode` calls.
    pub fn set_max_frame_length(&mut self, val: usize) {
        self.builder.max_frame_length(val);
    }

    /// Parses the frame head at the front of `src`.
    ///
    /// Returns `Ok(None)` while fewer than the configured number of head bytes
    /// are buffered. Otherwise consumes the configured number of skipped bytes
    /// and returns the adjusted payload length.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if the length field exceeds the maximum frame length.
    /// * `InvalidInput` if applying the length adjustment overflows or
    ///   underflows.
    fn decode_head(&mut self, src: &mut BytesMut) -> io::Result<Option<usize>> {
        let head_len = self.builder.num_head_bytes();
        let field_len = self.builder.length_field_len;

        if src.len() < head_len {
            return Ok(None);
        }

        let n = {
            // Read through a cursor so the buffer itself is left untouched
            // until the head has been validated.
            let mut src = Cursor::new(&mut *src);

            src.advance(self.builder.length_field_offset);

            let n = if self.builder.length_field_is_big_endian {
                src.get_uint(field_len)
            } else {
                src.get_uint_le(field_len)
            };

            if n > self.builder.max_frame_len as u64 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    LengthDelimitedCodecError { _priv: () },
                ));
            }

            // Lossless: `n` was just bounded by `max_frame_len`, a `usize`.
            let n = n as usize;

            // `unsigned_abs` yields the magnitude without negating the signed
            // value first, which would overflow for `isize::MIN`.
            let adjustment = self.builder.length_adjustment;
            let adjusted = if adjustment < 0 {
                n.checked_sub(adjustment.unsigned_abs())
            } else {
                n.checked_add(adjustment.unsigned_abs())
            };

            adjusted.ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "provided length would overflow after adjustment",
                )
            })?
        };

        let num_skip = self.builder.get_num_skip();

        if num_skip > 0 {
            src.advance(num_skip);
        }

        // Only the part of the payload that is not buffered yet needs room.
        src.reserve(n.saturating_sub(src.len()));

        Ok(Some(n))
    }

    /// Splits an `n`-byte payload off the front of `src`, or returns
    /// `Ok(None)` while fewer than `n` bytes are buffered.
    fn decode_data(&self, n: usize, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
        if src.len() < n {
            return Ok(None);
        }

        Ok(Some(src.split_to(n)))
    }
}

impl Decoder for LengthDelimitedCodec {
    type Item = BytesMut;
    type Error = io::Error;

    /// Attempts to decode one frame from `src`.
    ///
    /// Returns `Ok(None)` when more bytes are needed. The parsed payload
    /// length is remembered in `self.state`, so a later call resumes at the
    /// payload instead of re-parsing the head.
    fn decode(&mut self, src: &mut BytesMut) -> io::Result<Option<BytesMut>> {
        let n = match self.state {
            DecodeState::Data(n) => n,
            DecodeState::Head => match self.decode_head(src)? {
                Some(n) => {
                    self.state = DecodeState::Data(n);
                    n
                }
                None => return Ok(None),
            },
        };

        let frame = self.decode_data(n, src)?;

        if frame.is_some() {
            self.state = DecodeState::Head;

            // Make room for the next head, counting bytes already buffered so
            // the buffer is not grown more than necessary.
            src.reserve(self.builder.num_head_bytes().saturating_sub(src.len()));
        }

        Ok(frame)
    }
}

impl Encoder<Bytes> for LengthDelimitedCodec {
    type Error = io::Error;

    /// Appends `data` to `dst` as one frame: the length field followed by the
    /// payload bytes.
    ///
    /// The value written to the length field is the payload length with the
    /// configured length adjustment reversed.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` error when
    /// * the payload is longer than the maximum frame length,
    /// * reversing the length adjustment overflows or underflows, or
    /// * the resulting length does not fit in the configured length field.
    fn encode(&mut self, data: Bytes, dst: &mut BytesMut) -> Result<(), io::Error> {
        let payload_len = data.len();

        if payload_len > self.builder.max_frame_len {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                LengthDelimitedCodecError { _priv: () },
            ));
        }

        // `unsigned_abs` yields the magnitude without negating the signed
        // value first, which would overflow for `isize::MIN`.
        let adjustment = self.builder.length_adjustment;
        let adjusted = if adjustment < 0 {
            payload_len.checked_add(adjustment.unsigned_abs())
        } else {
            payload_len.checked_sub(adjustment.unsigned_abs())
        };

        let n = adjusted.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "provided length would overflow after adjustment",
            )
        })?;

        let field_len = self.builder.length_field_len;
        let wire_len = n as u64;

        // A length that needs more bytes than the field provides cannot be
        // represented on the wire; reject it instead of emitting a frame whose
        // length field does not match its payload.
        if field_len < 8 && (wire_len >> (8 * field_len)) != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                LengthDelimitedCodecError { _priv: () },
            ));
        }

        // Reserve for what is actually written: the field plus the payload.
        dst.reserve(field_len + payload_len);

        if self.builder.length_field_is_big_endian {
            dst.put_uint(wire_len, field_len);
        } else {
            dst.put_uint_le(wire_len, field_len);
        }

        dst.extend_from_slice(&data[..]);

        Ok(())
    }
}

impl Default for LengthDelimitedCodec {
    fn default() -> Self {
        Self::new()
    }
}

impl Builder {
    /// Creates a builder with the default configuration: an 8 MiB maximum
    /// frame length and a 4-byte big-endian length field at offset 0, with no
    /// length adjustment and no explicit skip count.
    pub fn new() -> Builder {
        const DEFAULT_MAX_FRAME_LEN: usize = 8 * 1_024 * 1_024;

        Builder {
            max_frame_len: DEFAULT_MAX_FRAME_LEN,
            length_field_len: 4,
            length_field_offset: 0,
            length_adjustment: 0,
            num_skip: None,
            length_field_is_big_endian: true,
        }
    }

    /// Selects big-endian byte order for the length field.
    pub fn big_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = true;
        self
    }

    /// Selects little-endian byte order for the length field.
    pub fn little_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = false;
        self
    }

    /// Selects the byte order of the compilation target for the length field.
    pub fn native_endian(&mut self) -> &mut Self {
        self.length_field_is_big_endian = cfg!(target_endian = "big");
        self
    }

    /// Sets the maximum frame length.
    pub fn max_frame_length(&mut self, val: usize) -> &mut Self {
        self.max_frame_len = val;
        self
    }

    /// Sets the width of the length field in bytes.
    ///
    /// # Panics
    ///
    /// Panics unless `val` is between 1 and 8 inclusive.
    pub fn length_field_length(&mut self, val: usize) -> &mut Self {
        assert!((1..=8).contains(&val), "invalid length field length");
        self.length_field_len = val;
        self
    }

    /// Sets the number of bytes that precede the length field.
    pub fn length_field_offset(&mut self, val: usize) -> &mut Self {
        self.length_field_offset = val;
        self
    }

    /// Sets the signed delta applied to the length-field value.
    pub fn length_adjustment(&mut self, val: isize) -> &mut Self {
        self.length_adjustment = val;
        self
    }

    /// Sets the number of bytes dropped from the buffer after the head has
    /// been parsed.
    pub fn num_skip(&mut self, val: usize) -> &mut Self {
        self.num_skip = Some(val);
        self
    }

    /// Builds a codec holding a copy of this configuration, starting in the
    /// head-parsing state.
    pub fn new_codec(&self) -> LengthDelimitedCodec {
        let builder = *self;
        let state = DecodeState::Head;

        LengthDelimitedCodec { builder, state }
    }

    /// Wraps `upstream` in a `FramedRead` driven by a new codec.
    pub fn new_read<T: AsyncRead>(&self, upstream: T) -> FramedRead<T, LengthDelimitedCodec> {
        let codec = self.new_codec();
        FramedRead::new(upstream, codec)
    }

    /// Wraps `inner` in a `FramedWrite` driven by a new codec.
    pub fn new_write<T: AsyncWrite>(&self, inner: T) -> FramedWrite<T, LengthDelimitedCodec> {
        let codec = self.new_codec();
        FramedWrite::new(inner, codec)
    }

    /// Wraps `inner` in a `Framed` driven by a new codec.
    pub fn new_framed<T: AsyncRead + AsyncWrite>(
        &self,
        inner: T,
    ) -> Framed<T, LengthDelimitedCodec> {
        let codec = self.new_codec();
        Framed::new(inner, codec)
    }

    /// Number of bytes that must be buffered before the head can be parsed:
    /// the larger of the end of the length field and the explicit skip count.
    fn num_head_bytes(&self) -> usize {
        let field_end = self.length_field_offset + self.length_field_len;
        let explicit_skip = self.num_skip.unwrap_or(0);

        cmp::max(field_end, explicit_skip)
    }

    /// Number of bytes to drop after parsing the head: the explicit skip
    /// count if one was set, otherwise everything up to the end of the length
    /// field.
    fn get_num_skip(&self) -> usize {
        let field_end = self.length_field_offset + self.length_field_len;

        match self.num_skip {
            Some(explicit_skip) => explicit_skip,
            None => field_end,
        }
    }
}

impl Default for Builder {
    fn default() -> Self {
        Self::new()
    }
}

impl fmt::Debug for LengthDelimitedCodecError {
    /// Prints only the type name; the private marker field is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("LengthDelimitedCodecError")
    }
}

impl fmt::Display for LengthDelimitedCodecError {
    /// Writes the fixed human-readable message for this error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "frame size too big")
    }
}

// Marker impl: relies on the trait's default methods, so the error carries no
// underlying source. This lets the type be wrapped by `io::Error::new`.
impl StdError for LengthDelimitedCodecError {}