erl_dist 0.5.1

Rust implementation of the Erlang Distribution Protocol
Documentation
use byteorder::{BigEndian, ByteOrder as _, WriteBytesExt};
use eetf::{DecodeError, EncodeError, FixInteger, Term, Tuple};
use futures::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
use std::io::{Read, Write};

/// Wrapper around an async byte stream that adds big-endian integer helpers,
/// UTF-8 string helpers, and entry points for framed handshake messages.
#[derive(Debug)]
pub struct Connection<T> {
    /// The wrapped stream; every helper below ultimately delegates to it.
    inner: T,
}

impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Wraps `inner` in a new `Connection`.
    pub fn new(inner: T) -> Self {
        Connection { inner }
    }

    /// Consumes the connection and returns the wrapped stream.
    pub fn into_inner(self) -> T {
        let Connection { inner } = self;
        inner
    }

    /// Starts a buffered handshake message. Bytes written to the returned writer
    /// are kept in memory until its `finish` method sends them.
    pub fn handshake_message_writer(&mut self) -> HandshakeMessageWriter<T> {
        let buf = Vec::new();
        HandshakeMessageWriter {
            connection: self,
            buf,
        }
    }

    /// Reads the big-endian `u16` length prefix of the next handshake message and
    /// returns a reader that is limited to that many bytes.
    pub async fn handshake_message_reader(
        &mut self,
    ) -> std::io::Result<HandshakeMessageReader<'_, T>> {
        let size = usize::from(self.read_u16().await?);
        Ok(HandshakeMessageReader {
            connection: self,
            size,
        })
    }

    /// Writes a single byte.
    pub async fn write_u8(&mut self, v: u8) -> std::io::Result<()> {
        self.write_all(&[v]).await
    }

    /// Writes `v` as 2 big-endian bytes.
    pub async fn write_u16(&mut self, v: u16) -> std::io::Result<()> {
        let mut encoded = [0u8; 2];
        BigEndian::write_u16(&mut encoded, v);
        self.write_all(&encoded).await
    }

    /// Writes `v` as 4 big-endian bytes.
    pub async fn write_u32(&mut self, v: u32) -> std::io::Result<()> {
        let mut encoded = [0u8; 4];
        BigEndian::write_u32(&mut encoded, v);
        self.write_all(&encoded).await
    }

    /// Writes every byte of `buf` to the underlying stream.
    pub async fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        self.inner.write_all(buf).await
    }

    /// Flushes the underlying stream.
    pub async fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush().await
    }

    /// Reads exactly `N` bytes into a fixed-size array.
    async fn read_array<const N: usize>(&mut self) -> std::io::Result<[u8; N]> {
        let mut bytes = [0u8; N];
        self.inner.read_exact(&mut bytes).await?;
        Ok(bytes)
    }

    /// Reads a single byte.
    pub async fn read_u8(&mut self) -> std::io::Result<u8> {
        let [byte] = self.read_array::<1>().await?;
        Ok(byte)
    }

    /// Reads a big-endian `u16`.
    pub async fn read_u16(&mut self) -> std::io::Result<u16> {
        let bytes = self.read_array::<2>().await?;
        Ok(BigEndian::read_u16(&bytes))
    }

    /// Reads a big-endian `u32`.
    pub async fn read_u32(&mut self) -> std::io::Result<u32> {
        let bytes = self.read_array::<4>().await?;
        Ok(BigEndian::read_u32(&bytes))
    }

    /// Reads a big-endian `u64`.
    pub async fn read_u64(&mut self) -> std::io::Result<u64> {
        let bytes = self.read_array::<8>().await?;
        Ok(BigEndian::read_u64(&bytes))
    }

    /// Fills `buf` completely from the underlying stream.
    pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        self.inner.read_exact(buf).await
    }

    /// Reads from the stream until end-of-file and returns the data as a `String`.
    pub async fn read_string(&mut self) -> std::io::Result<String> {
        let mut text = String::new();
        self.inner.read_to_string(&mut text).await?;
        Ok(text)
    }

    /// Reads exactly `size` bytes and decodes them as UTF-8.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if the bytes are not valid UTF-8.
    pub async fn read_stringn(&mut self, size: usize) -> std::io::Result<String> {
        let mut bytes = vec![0; size];
        self.read_exact(&mut bytes).await?;
        Self::bytes_to_string(bytes)
    }

    /// Reads a big-endian `u16` length followed by that many bytes.
    pub async fn read_u16_bytes(&mut self) -> std::io::Result<Vec<u8>> {
        let len = usize::from(self.read_u16().await?);
        let mut bytes = vec![0; len];
        self.read_exact(&mut bytes).await?;
        Ok(bytes)
    }

    /// Reads a `u16`-length-prefixed byte string and decodes it as UTF-8.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if the payload is not valid UTF-8.
    pub async fn read_u16_string(&mut self) -> std::io::Result<String> {
        let bytes = self.read_u16_bytes().await?;
        Self::bytes_to_string(bytes)
    }

    /// Converts `bytes` to a `String`, mapping invalid UTF-8 to an `InvalidData` I/O error.
    fn bytes_to_string(bytes: Vec<u8>) -> std::io::Result<String> {
        match String::from_utf8(bytes) {
            Ok(text) => Ok(text),
            Err(_) => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "stream did not contain valid UTF-8",
            )),
        }
    }
}

/// Buffers the body of one handshake message. `finish` sends the body prefixed
/// with its length as a big-endian `u16`.
#[derive(Debug)]
pub struct HandshakeMessageWriter<'a, T> {
    connection: &'a mut Connection<T>,
    /// Message body accumulated so far; the length prefix is not included.
    buf: Vec<u8>,
}

impl<'a, T> HandshakeMessageWriter<'a, T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends the buffered message and flushes the connection. The message is a
    /// big-endian `u16` body length followed by the body.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` without writing anything if the body is longer
    /// than `u16::MAX` bytes. Any I/O error from the connection is propagated.
    pub async fn finish(self) -> std::io::Result<()> {
        // `try_from` performs the bound check and the narrowing in one step, so
        // the length prefix can never be a silently truncated value.
        let len = u16::try_from(self.buf.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!(
                    "too large bytes: expected less than {}, but got {} bytes",
                    usize::from(u16::MAX) + 1,
                    self.buf.len()
                ),
            )
        })?;
        self.connection.write_u16(len).await?;
        self.connection.write_all(&self.buf).await?;
        self.connection.flush().await
    }

    /// Appends a single byte to the message body.
    pub fn write_u8(&mut self, v: u8) -> std::io::Result<()> {
        self.buf.write_u8(v)
    }

    /// Appends `v` to the message body as 2 big-endian bytes.
    pub fn write_u16(&mut self, v: u16) -> std::io::Result<()> {
        self.buf.write_u16::<BigEndian>(v)
    }

    /// Appends `v` to the message body as 4 big-endian bytes.
    pub fn write_u32(&mut self, v: u32) -> std::io::Result<()> {
        self.buf.write_u32::<BigEndian>(v)
    }

    /// Appends `v` to the message body as 8 big-endian bytes.
    pub fn write_u64(&mut self, v: u64) -> std::io::Result<()> {
        self.buf.write_u64::<BigEndian>(v)
    }

    /// Appends `bytes` verbatim to the message body. This always succeeds.
    pub fn write_all(&mut self, bytes: &[u8]) -> std::io::Result<()> {
        self.buf.extend_from_slice(bytes);
        Ok(())
    }
}

/// Reads the body of one length-prefixed handshake message.
///
/// Every read is checked against the number of bytes that remain in the
/// message, so a read can never run past the end of the message.
#[derive(Debug)]
pub struct HandshakeMessageReader<'a, T> {
    connection: &'a mut Connection<T>,
    /// Number of bytes of the current message that have not been consumed yet.
    size: usize,
}

impl<'a, T> HandshakeMessageReader<'a, T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Deducts `n` bytes from the remaining budget. If fewer than `n` bytes
    /// remain, fails with `UnexpectedEof` and leaves the budget unchanged.
    fn charge(&mut self, n: usize) -> std::io::Result<()> {
        match self.size.checked_sub(n) {
            Some(left) => {
                self.size = left;
                Ok(())
            }
            None => Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "unexpected eof",
            )),
        }
    }

    /// Reads one byte of the message.
    pub async fn read_u8(&mut self) -> std::io::Result<u8> {
        self.charge(1)?;
        self.connection.read_u8().await
    }

    /// Reads a big-endian `u16` from the message.
    pub async fn read_u16(&mut self) -> std::io::Result<u16> {
        self.charge(2)?;
        self.connection.read_u16().await
    }

    /// Reads a big-endian `u32` from the message.
    pub async fn read_u32(&mut self) -> std::io::Result<u32> {
        self.charge(4)?;
        self.connection.read_u32().await
    }

    /// Reads a big-endian `u64` from the message.
    pub async fn read_u64(&mut self) -> std::io::Result<u64> {
        self.charge(8)?;
        self.connection.read_u64().await
    }

    /// Reads all remaining message bytes as a UTF-8 string.
    pub async fn read_string(&mut self) -> std::io::Result<String> {
        let rest = std::mem::take(&mut self.size);
        self.connection.read_stringn(rest).await
    }

    /// Reads all remaining message bytes.
    pub async fn read_bytes(&mut self) -> std::io::Result<Vec<u8>> {
        let mut bytes = vec![0; self.size];
        self.read_exact(&mut bytes).await?;
        Ok(bytes)
    }

    /// Fills `buf` from the message. Fails with `UnexpectedEof` if the message
    /// has fewer than `buf.len()` bytes left.
    pub async fn read_exact(&mut self, buf: &mut [u8]) -> std::io::Result<()> {
        self.charge(buf.len())?;
        self.connection.read_exact(buf).await
    }

    /// Reads a `u16`-length-prefixed UTF-8 string from the message.
    pub async fn read_u16_string(&mut self) -> std::io::Result<String> {
        let len = usize::from(self.read_u16().await?);
        self.charge(len)?;
        self.connection.read_stringn(len).await
    }

    /// Reads and discards whatever is left of the message.
    pub async fn consume_remaining_bytes(&mut self) -> std::io::Result<()> {
        let mut discarded = vec![0; std::mem::take(&mut self.size)];
        self.connection.read_exact(&mut discarded).await
    }

    /// Finishes the message by discarding any unread bytes.
    pub async fn finish(mut self) -> std::io::Result<()> {
        self.consume_remaining_bytes().await
    }
}

/// Extension methods for decoding `eetf` terms from any [`Read`] implementor.
pub trait ReadTermExt: Read {
    /// Decodes one term and requires it to be a tuple.
    ///
    /// # Errors
    ///
    /// Propagates decoding errors. If the term decodes but cannot be converted
    /// to a [`Tuple`], returns `DecodeError::UnexpectedType`. That error carries
    /// the rejected value and `"Tuple"` as the expected type.
    fn read_tuple(&mut self) -> Result<Tuple, DecodeError> {
        self.read_term()?
            .try_into()
            .map_err(|rejected| DecodeError::UnexpectedType {
                expected: String::from("Tuple"),
                value: rejected,
            })
    }

    /// Decodes a single term from this reader.
    fn read_term(&mut self) -> Result<Term, DecodeError> {
        Term::decode(self)
    }
}

/// Blanket impl: every sized `Read` type gets the term-decoding helpers.
impl<T: Read> ReadTermExt for T {}

pub trait WriteTermExt: Write {
    fn write_tagged_tuple1(&mut self, tag: i32) -> Result<(), EncodeError> {
        let tuple = Tuple {
            elements: vec![Term::from(FixInteger { value: tag as i32 })],
        };
        self.write_term(tuple)
    }

    fn write_tagged_tuple3<T0, T1>(
        &mut self,
        tag: i32,
        term0: T0,
        term1: T1,
    ) -> Result<(), EncodeError>
    where
        Term: From<T0>,
        Term: From<T1>,
    {
        let tuple = Tuple {
            elements: vec![
                Term::from(FixInteger { value: tag as i32 }),
                Term::from(term0),
                Term::from(term1),
            ],
        };
        self.write_term(tuple)
    }

    fn write_tagged_tuple4<T0, T1, T2>(
        &mut self,
        tag: i32,
        term0: T0,
        term1: T1,
        term2: T2,
    ) -> Result<(), EncodeError>
    where
        Term: From<T0>,
        Term: From<T1>,
        Term: From<T2>,
    {
        let tuple = Tuple {
            elements: vec![
                Term::from(FixInteger { value: tag as i32 }),
                Term::from(term0),
                Term::from(term1),
                Term::from(term2),
            ],
        };
        self.write_term(tuple)
    }

    fn write_tagged_tuple5<T0, T1, T2, T3>(
        &mut self,
        tag: i32,
        term0: T0,
        term1: T1,
        term2: T2,
        term3: T3,
    ) -> Result<(), EncodeError>
    where
        Term: From<T0>,
        Term: From<T1>,
        Term: From<T2>,
        Term: From<T3>,
    {
        let tuple = Tuple {
            elements: vec![
                Term::from(FixInteger { value: tag as i32 }),
                Term::from(term0),
                Term::from(term1),
                Term::from(term2),
                Term::from(term3),
            ],
        };
        self.write_term(tuple)
    }

    fn write_tagged_tuple6<T0, T1, T2, T3, T4>(
        &mut self,
        tag: i32,
        term0: T0,
        term1: T1,
        term2: T2,
        term3: T3,
        term4: T4,
    ) -> Result<(), EncodeError>
    where
        Term: From<T0>,
        Term: From<T1>,
        Term: From<T2>,
        Term: From<T3>,
        Term: From<T4>,
    {
        let tuple = Tuple {
            elements: vec![
                Term::from(FixInteger { value: tag as i32 }),
                Term::from(term0),
                Term::from(term1),
                Term::from(term2),
                Term::from(term3),
                Term::from(term4),
            ],
        };
        self.write_term(tuple)
    }

    #[allow(clippy::too_many_arguments)]
    fn write_tagged_tuple7<T0, T1, T2, T3, T4, T5>(
        &mut self,
        tag: i32,
        term0: T0,
        term1: T1,
        term2: T2,
        term3: T3,
        term4: T4,
        term5: T5,
    ) -> Result<(), EncodeError>
    where
        Term: From<T0>,
        Term: From<T1>,
        Term: From<T2>,
        Term: From<T3>,
        Term: From<T4>,
        Term: From<T5>,
    {
        let tuple = Tuple {
            elements: vec![
                Term::from(FixInteger { value: tag as i32 }),
                Term::from(term0),
                Term::from(term1),
                Term::from(term2),
                Term::from(term3),
                Term::from(term4),
                Term::from(term5),
            ],
        };
        self.write_term(tuple)
    }

    fn write_term<T>(&mut self, term: T) -> Result<(), EncodeError>
    where
        Term: From<T>,
    {
        Term::from(term).encode(self)
    }
}

impl<T: Write> WriteTermExt for T {}