networky 0.1.1

Networking library for indigo, with NaCl (Curve25519) encrypted connections and an async progress monitor.
Documentation
pub use tweetnacly::*;
pub use expry::*;

use bytes::{BytesMut, Buf};
use tokio_util::codec::{Decoder, Framed, FramedRead};
use tokio::io::{AsyncWriteExt as _, AsyncWrite, AsyncRead};
use tokio_stream::Stream;

/// Frame-length limit (tag + ciphertext, in bytes) that `NaClCodec` checks while decoding; it is
/// also the minimum read-buffer capacity the codec reserves while waiting for a frame.
pub const MAX_NACL_RECEIVE_BUFFER: u64 = 64 * 1024;
/// Largest plaintext chunk, in bytes, that `send_frame` encrypts into a single frame.
pub const MAX_NACL_SEND_BUFFER: usize = 32 * 1024;

/// Decoder for NaCl `crypto_box`-encrypted, length-prefixed frames.
///
/// Holds the precomputed box state for one session and the nonce expected on the next
/// incoming frame.
pub struct NaClCodec {
    /// Result of `crypto_box_prepare` for our secret key and the peer's public key.
    recv_cache: CryptoBoxCache,
    /// Nonce the next received frame must have been sealed with.
    recv_nonce: Nonce,
    /// `true` on the server side; selects the direction in which `recv_nonce` advances.
    server: bool,
}

impl NaClCodec {
    /// Creates the server-side decoder; the expected nonce starts at all zero bytes.
    pub fn new_server(sk: SecretBoxKey, public_key_of_session: &PublicBoxKey) -> Self {
        Self::with_recv_nonce(sk, public_key_of_session, u8::MIN, true)
    }

    /// Creates the client-side decoder; the expected nonce starts at all `0xFF` bytes.
    pub fn new_client(sk: SecretBoxKey, public_key_of_session: &PublicBoxKey) -> Self {
        Self::with_recv_nonce(sk, public_key_of_session, u8::MAX, false)
    }

    /// Shared constructor: prepares the box state and fills every nonce byte with `fill`.
    fn with_recv_nonce(sk: SecretBoxKey, public_key_of_session: &PublicBoxKey, fill: u8, server: bool) -> Self {
        let recv_cache = crypto_box_prepare(&sk, public_key_of_session);
        let recv_nonce = Nonce {
            data: [fill; tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize],
        };
        Self { recv_cache, recv_nonce, server }
    }
}

impl tokio_util::codec::Decoder for NaClCodec {
    type Item = BytesMut;
    type Error = std::io::Error;

    /// Reads, authenticates and decrypts one frame from the front of `src`.
    ///
    /// Frame layout: a var-u64 length prefix (read with `RawReader::read_var_u64`), then
    /// `length` bytes made of the `crypto_box_MACBYTES`-byte authentication tag followed by the
    /// ciphertext. On success the prefix and frame are removed from `src` and the decrypted
    /// payload (tag stripped) is returned. The receive nonce then advances by one: it is
    /// incremented on the server and decremented on the client.
    ///
    /// Returns `Ok(None)` when the length prefix cannot be read yet or the frame is still
    /// incomplete. In the incomplete case, buffer space is reserved for the rest of the frame.
    ///
    /// # Errors
    /// Returns `ErrorKind::Other` when any of these hold:
    /// - the announced length exceeds [`MAX_NACL_RECEIVE_BUFFER`];
    /// - the frame is too short to hold the tag;
    /// - authentication or decryption fails.
    fn decode(
        &mut self,
        src: &mut BytesMut
    ) -> Result<Option<Self::Item>, Self::Error> {
        const MAC_BYTES: usize = tweetnacly::bindings::crypto_box_MACBYTES as usize;

        let mut reader = RawReader::with(src);
        let frame_size = match reader.read_var_u64() {
            Ok(frame_size) => frame_size,
            // Presumably the length prefix has not fully arrived yet; wait for more bytes.
            Err(_) => return Ok(None),
        };
        let remaining = reader.len();

        // Check the limit before looking at how much data has arrived. Otherwise an oversized
        // frame would be accepted whenever the transport happened to deliver it in one piece.
        if frame_size > MAX_NACL_RECEIVE_BUFFER {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, "max frame size exceeded"));
        }
        // Lossless: bounded by MAX_NACL_RECEIVE_BUFFER above.
        let frame_size = frame_size as usize;

        if remaining < frame_size {
            // Incomplete frame: ensure at least MAX_NACL_RECEIVE_BUFFER bytes of capacity, plus
            // room for the rest of this frame, before waiting for more input.
            if src.capacity() < MAX_NACL_RECEIVE_BUFFER as usize {
                src.reserve(MAX_NACL_RECEIVE_BUFFER as usize - src.len());
            }
            src.reserve(frame_size - remaining);
            return Ok(None);
        }

        src.advance(src.len() - remaining); // skip the frame size
        let mut data = src.split_to(frame_size);

        if data.len() < MAC_BYTES {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, "NaCl packet did not contain enough bytes for MAC"));
        }
        let mut raw_tag = [0u8; MAC_BYTES];
        raw_tag.copy_from_slice(&data[..MAC_BYTES]);
        let tag = AuthenticationTag { data: raw_tag };
        data.advance(MAC_BYTES);

        crypto_box_open_in_place(&mut data, &tag, &self.recv_nonce, &self.recv_cache)
            .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, "crypto error"))?;
        // Mirrors `send_frame`: the client (server == false) sends with increasing nonces and
        // the server with decreasing ones.
        if self.server {
            increase_nonce(&mut self.recv_nonce.data);
        } else {
            decrease_nonce(&mut self.recv_nonce.data);
        }
        // Invariant from `split_to(frame_size)` followed by `advance(MAC_BYTES)`.
        debug_assert_eq!(frame_size, MAC_BYTES + data.len());

        Ok(Some(data))
    }
}

pub struct FramedReadWrapper<T,U> 
where
    T: AsyncRead,
    U: Decoder,
{
    framed: FramedRead<T,U>,
    current_read: Option<BytesMut>,
}
impl<T,U> FramedReadWrapper<T,U>
where
    T: AsyncRead,
    U: Decoder,
{
    pub fn new(framed: FramedRead<T,U>) -> Self {
        Self {
            framed,
            current_read: None,
        }
    }
    pub fn get_mut(&mut self) -> &mut T {
        self.framed.get_mut()
    }
}

impl<T,U> AsyncRead for FramedReadWrapper<T,U>
where
    T: AsyncRead,
    U: Decoder<Item=BytesMut,Error=std::io::Error>,
{
    /// Copies decoded frame payloads into `buf`, exposing the frame stream as a byte stream.
    ///
    /// Any leftover from a frame that did not fit last time is delivered first. After that,
    /// frames are pulled until `buf` is full or the frame stream has nothing more ready.
    ///
    /// Possible results:
    /// - `Ready(Ok(()))` once any bytes were copied;
    /// - `Ready(Ok(()))` with nothing copied when the frame stream has ended (EOF);
    /// - `Pending` when no decoded data is available yet.
    ///
    /// # Errors
    /// Errors from the wrapped `FramedRead` are returned as-is. Bytes already copied into `buf`
    /// during the same call are then not reported as read.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let mut retval = std::task::Poll::Pending;
        // SAFETY: `framed` is treated as structurally pinned. It is only accessed through
        // `Pin::new_unchecked` below and never moved out of `self`. `current_read` is never pinned.
        let this = unsafe { self.get_unchecked_mut() };

        if let Some(current_data) = this.current_read.as_mut() {
            let len = std::cmp::min(current_data.len(), buf.remaining());
            buf.put_slice(&current_data[..len]);
            if len < current_data.len() {
                current_data.advance(len);
                return std::task::Poll::Ready(Ok(()));
            }
            retval = std::task::Poll::Ready(Ok(()));
            this.current_read = None;
        }

        loop {
            // SAFETY: see above; `framed` stays at its pinned location.
            let framed = unsafe { std::pin::Pin::new_unchecked(&mut this.framed) };
            match Stream::poll_next(framed, cx) {
                std::task::Poll::Pending => return retval,
                // End of the frame stream: report EOF (or the bytes copied so far). Returning
                // `Pending` here would stall the reader forever, because `poll_next` returned
                // `Ready` and nothing guarantees a later wake-up.
                std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
                std::task::Poll::Ready(Some(Err(err))) => return std::task::Poll::Ready(Err(err)),
                std::task::Poll::Ready(Some(Ok(mut current_data))) => {
                    let len = std::cmp::min(current_data.len(), buf.remaining());
                    buf.put_slice(&current_data[..len]);
                    retval = std::task::Poll::Ready(Ok(()));
                    if len < current_data.len() {
                        // Keep the rest of this frame for the next call.
                        current_data.advance(len);
                        this.current_read = Some(current_data);
                        return retval;
                    }
                }
            }
        }
    }
}

pub struct FramedWrapper<T,U> 
where
    T: AsyncRead + AsyncWrite,
    U: Decoder,
{
    framed: Framed<T,U>,
    current_read: Option<BytesMut>,
}
impl<T,U> FramedWrapper<T,U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder,
{
    pub fn new(framed: Framed<T,U>) -> Self {
        Self {
            framed,
            current_read: None,
        }
    }
    pub fn get_mut(&mut self) -> &mut T {
        self.framed.get_mut()
    }
}

impl<T,U> AsyncRead for FramedWrapper<T,U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder<Item=BytesMut,Error=std::io::Error>,
{
    /// Copies decoded frame payloads into `buf`, exposing the frame stream as a byte stream.
    ///
    /// Any leftover from a frame that did not fit last time is delivered first. After that,
    /// frames are pulled until `buf` is full or the frame stream has nothing more ready.
    ///
    /// Possible results:
    /// - `Ready(Ok(()))` once any bytes were copied;
    /// - `Ready(Ok(()))` with nothing copied when the frame stream has ended (EOF);
    /// - `Pending` when no decoded data is available yet.
    ///
    /// # Errors
    /// Errors from the wrapped `Framed` are returned as-is. Bytes already copied into `buf`
    /// during the same call are then not reported as read.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let mut retval = std::task::Poll::Pending;
        // SAFETY: `framed` is treated as structurally pinned. It is only accessed through
        // `Pin::new_unchecked` below and never moved out of `self`. `current_read` is never pinned.
        let this = unsafe { self.get_unchecked_mut() };

        if let Some(current_data) = this.current_read.as_mut() {
            let len = std::cmp::min(current_data.len(), buf.remaining());
            buf.put_slice(&current_data[..len]);
            if len < current_data.len() {
                current_data.advance(len);
                return std::task::Poll::Ready(Ok(()));
            }
            retval = std::task::Poll::Ready(Ok(()));
            this.current_read = None;
        }

        loop {
            // SAFETY: see above; `framed` stays at its pinned location.
            let framed = unsafe { std::pin::Pin::new_unchecked(&mut this.framed) };
            match Stream::poll_next(framed, cx) {
                std::task::Poll::Pending => return retval,
                // End of the frame stream: report EOF (or the bytes copied so far). Returning
                // `Pending` here would stall the reader forever, because `poll_next` returned
                // `Ready` and nothing guarantees a later wake-up.
                std::task::Poll::Ready(None) => return std::task::Poll::Ready(Ok(())),
                std::task::Poll::Ready(Some(Err(err))) => return std::task::Poll::Ready(Err(err)),
                std::task::Poll::Ready(Some(Ok(mut current_data))) => {
                    let len = std::cmp::min(current_data.len(), buf.remaining());
                    buf.put_slice(&current_data[..len]);
                    retval = std::task::Poll::Ready(Ok(()));
                    if len < current_data.len() {
                        // Keep the rest of this frame for the next call.
                        current_data.advance(len);
                        this.current_read = Some(current_data);
                        return retval;
                    }
                }
            }
        }
    }
}

/// Forwards writes, flushes and shutdowns directly to the I/O object inside the wrapped `Framed`.
///
/// NOTE(review): `U` is only a `Decoder`, so bytes written here are neither framed nor encoded.
/// They reach the underlying stream unchanged. With `NaClCodec`, outgoing data presumably has to
/// go through `send_frame` instead — confirm against callers.
impl<T,U> AsyncWrite for FramedWrapper<T,U>
where
    T: AsyncRead + AsyncWrite,
    U: Decoder<Item=BytesMut>,
{
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        // SAFETY: `framed` is structurally pinned (the `AsyncRead` impl pins it the same way) and
        // is never moved out of `self`. `Framed::get_pin_mut` then projects to the inner I/O object.
        let io = unsafe { self.map_unchecked_mut(|wrapper| &mut wrapper.framed) }.get_pin_mut();
        AsyncWrite::poll_write(io, cx, buf)
    }

    fn poll_flush(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
        // SAFETY: same structural-pinning argument as in `poll_write`.
        let io = unsafe { self.map_unchecked_mut(|wrapper| &mut wrapper.framed) }.get_pin_mut();
        AsyncWrite::poll_flush(io, cx)
    }

    fn poll_shutdown(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), std::io::Error>> {
        // SAFETY: same structural-pinning argument as in `poll_write`.
        let io = unsafe { self.map_unchecked_mut(|wrapper| &mut wrapper.framed) }.get_pin_mut();
        AsyncWrite::poll_shutdown(io, cx)
    }
}

/// Initial nonce for frames sent by the client: every byte is zero.
///
/// `send_frame` increments it per frame when called with `server == false`.
pub fn nonce_for_client() -> Nonce {
    let data = [u8::MIN; tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize];
    Nonce { data }
}
/// Initial nonce for frames sent by the server: every byte is `0xFF`.
///
/// `send_frame` decrements it per frame when called with `server == true`.
pub fn nonce_for_server() -> Nonce {
    let data = [u8::MAX; tweetnacly::bindings::crypto_box_curve25519xsalsa20poly1305_NONCEBYTES as usize];
    Nonce { data }
}

/// Decrements a little-endian multi-byte counter by one, in place.
///
/// `numbers[0]` is the least significant byte. A byte that wraps from `0` to `0xFF` borrows from
/// the next one, and an all-zero array wraps around to all `0xFF`.
pub fn decrease_nonce<const N: usize>(numbers: &mut [u8; N]) {
    for byte in numbers.iter_mut() {
        let (next, borrowed) = byte.overflowing_sub(1);
        *byte = next;
        if !borrowed {
            break;
        }
    }
}
/// Increments a little-endian multi-byte counter by one, in place.
///
/// `numbers[0]` is the least significant byte. Leading `0xFF` bytes roll over to zero and the
/// first other byte absorbs the carry. An all-`0xFF` array wraps around to all zero.
pub fn increase_nonce<const N: usize>(numbers: &mut [u8; N]) {
    match numbers.iter().position(|&byte| byte != u8::MAX) {
        Some(carry_at) => {
            numbers[..carry_at].fill(0);
            // Cannot overflow: this byte is not u8::MAX.
            numbers[carry_at] += 1;
        }
        None => numbers.fill(0),
    }
}

/// Encrypts `data` with NaCl `crypto_box` and writes it to `stream` as one or more frames.
///
/// `data` is split into chunks of at most [`MAX_NACL_SEND_BUFFER`] bytes. Each chunk is
/// encrypted **in place**, so every chunk processed so far holds ciphertext afterwards, even on
/// error. Each chunk is then written as three parts:
/// - a var-u64 length prefix;
/// - the authentication tag;
/// - the ciphertext.
///
/// The length covers tag + ciphertext, which is the layout `NaClCodec::decode` expects.
/// `send_nonce` advances once per chunk: it is decremented when `server` is `true` and
/// incremented otherwise. An empty `data` writes nothing.
///
/// # Errors
/// Returns the first I/O error from `stream`. The nonce used for the failed chunk has already
/// been advanced past, so it is never reused by a later call.
pub async fn send_frame<Out: AsyncWrite + std::marker::Unpin>(data: &mut [u8], send_nonce: &mut Nonce, send_cache: &CryptoBoxCache, server: bool, stream: &mut Out) -> Result<(),std::io::Error> {
    const MAC_BYTES: usize = tweetnacly::bindings::crypto_box_MACBYTES as usize;
    for chunk in data.chunks_mut(MAX_NACL_SEND_BUFFER) {
        let tag = crypto_box_in_place(chunk, send_nonce, send_cache);
        // Advance the nonce before any `.await`. The writes below may fail, or this future may
        // be dropped mid-write after part of the ciphertext went out. Advancing first means the
        // nonce that sealed this chunk is never used again for different plaintext.
        if server {
            decrease_nonce(&mut send_nonce.data);
        } else {
            increase_nonce(&mut send_nonce.data);
        }

        // Room for the var-u64 length prefix (9 bytes reserved) followed by the tag.
        let mut header = [0u8; 9 + MAC_BYTES];
        let mut header_writer = RawWriter::with(&mut header);
        header_writer
            .write_var_u64((tag.data.len() + chunk.len()) as u64)
            .expect("header buffer has room for the length prefix");
        header_writer
            .write_bytes(&tag.data)
            .expect("header buffer has room for the authentication tag");
        let header = header_writer.build();
        stream.write_all(header).await?;
        stream.write_all(chunk).await?;
    }
    Ok(())
}

/// Decoder for plain (unencrypted) frames: a var-u64 length prefix followed by the payload.
pub struct WireCodec {
    /// Frame-length limit, in bytes, used by the decoder.
    max: usize,
}

impl WireCodec {
    /// Creates a codec whose frame-length limit is `max` bytes.
    pub fn new(max: usize) -> Self {
        WireCodec { max }
    }
}

impl tokio_util::codec::Decoder for WireCodec {
    type Item = BytesMut;
    type Error = std::io::Error;

    /// Splits one length-prefixed frame off the front of `src`.
    ///
    /// A frame is a var-u64 length (read with `RawReader::read_var_u64`) followed by that many
    /// payload bytes. The returned `BytesMut` is the payload without the prefix.
    ///
    /// Returns `Ok(None)` while the prefix or payload is incomplete. In that case buffer space is
    /// reserved for the rest of the frame, with at least 4096 bytes of capacity.
    ///
    /// # Errors
    /// Returns `ErrorKind::Other` when the announced length exceeds `self.max`. The limit is
    /// checked first, so it applies no matter how the transport segmented the bytes.
    fn decode(
        &mut self,
        src: &mut BytesMut
    ) -> Result<Option<Self::Item>, Self::Error> {
        let mut reader = RawReader::with(src);
        let frame_size = match reader.read_var_u64() {
            Ok(frame_size) => frame_size,
            // Presumably the length prefix has not fully arrived yet; wait for more bytes.
            Err(_) => return Ok(None),
        };
        let remaining = reader.len();

        if frame_size > self.max as u64 {
            return Err(std::io::Error::new(std::io::ErrorKind::Other, format!("max wire frame size exceeded: {}", frame_size)));
        }
        // Lossless: frame_size <= self.max, which is a usize.
        let frame_size = frame_size as usize;

        if remaining >= frame_size {
            src.advance(src.len() - remaining); // skip the frame size
            return Ok(Some(src.split_to(frame_size)));
        }

        // Incomplete frame: make room before waiting for more input.
        if src.capacity() < 4096 {
            src.reserve(4096 - src.len());
        }
        src.reserve(frame_size - remaining);
        Ok(None)
    }
}