metalssh 0.0.1

Experimental SSH implementation
use bstr::ByteSlice;
use scroll::NETWORK;
use scroll::Pread;
use scroll::Pwrite;

use crate::types::Error;
use crate::types::Result;

/// Decoding of the SSH wire data types described in RFC 4251, section 5.
///
/// Each method reads starting at `offset` and, on success, moves `offset`
/// forward past the bytes it consumed. Methods returning slices borrow directly
/// from the underlying buffer; nothing is copied.
pub trait SshDecode<'buf> {
    /// Reads a single `byte`.
    fn read_byte(&self, offset: &mut usize) -> Result<u8>;

    /// Reads a one-byte `boolean`: zero decodes to `false`, every other value
    /// decodes to `true`.
    fn read_boolean(&self, offset: &mut usize) -> Result<bool>;

    /// Reads a `uint32` stored in network (big-endian) byte order.
    fn read_uint32(&self, offset: &mut usize) -> Result<u32>;

    /// Reads a `uint64` stored in network (big-endian) byte order.
    fn read_uint64(&self, offset: &mut usize) -> Result<u64>;

    /// Reads a `string`, i.e. a `uint32` length followed by that many bytes.
    ///
    /// Only the payload is returned; the length prefix is consumed but dropped.
    fn read_byte_string(&'buf self, offset: &mut usize) -> Result<&'buf [u8]>;

    /// Reads exactly `len` raw bytes that carry no length prefix.
    fn read_bytes_exact(&'buf self, offset: &mut usize, len: usize) -> Result<&'buf [u8]>;

    /// Reads bytes up to and including the first occurrence of `stop_char`.
    ///
    /// If `stop_char` never occurs, everything from `offset` to the end of the
    /// buffer is returned. At least one unread byte must remain at `offset`.
    fn read_bytes_until(&'buf self, offset: &mut usize, stop_char: u8) -> Result<&'buf [u8]>;

    /// Reads an `mpint` and returns its raw two's-complement, big-endian bytes
    /// exactly as encoded, without validating or normalizing them.
    fn read_mpint(&'buf self, offset: &mut usize) -> Result<&'buf [u8]>;

    /// Reads a `name-list`: a `string` whose payload holds comma-separated names.
    ///
    /// The iterator yields each name in order; an empty list yields nothing.
    fn read_name_list(&'buf self, offset: &mut usize) -> Result<impl Iterator<Item = &'buf [u8]>>;
}

impl<'buf, B: AsRef<[u8]> + ?Sized> SshDecode<'buf> for B {
    fn read_byte(&self, offset: &mut usize) -> Result<u8> {
        self.as_ref()
            .gread_with(offset, NETWORK)
            .map_err(Into::into)
    }

    fn read_boolean(&self, offset: &mut usize) -> Result<bool> {
        self.read_byte(offset).map(|byte| byte != 0)
    }

    fn read_uint32(&self, offset: &mut usize) -> Result<u32> {
        self.as_ref()
            .gread_with(offset, NETWORK)
            .map_err(Into::into)
    }

    fn read_uint64(&self, offset: &mut usize) -> Result<u64> {
        self.as_ref()
            .gread_with(offset, NETWORK)
            .map_err(Into::into)
    }

    fn read_byte_string(&'buf self, offset: &mut usize) -> Result<&'buf [u8]> {
        let len = self.read_uint32(offset)?;
        self.as_ref()
            .gread_with(offset, len as usize)
            .map_err(Into::into)
    }

    fn read_bytes_exact(&'buf self, offset: &mut usize, len: usize) -> Result<&'buf [u8]> {
        self.as_ref().gread_with(offset, len).map_err(Into::into)
    }

    fn read_bytes_until(&'buf self, offset: &mut usize, stop_char: u8) -> Result<&'buf [u8]> {
        let buf = self.as_ref();
        let mut pos = *offset;

        while pos < buf.len() {
            if buf[pos] == stop_char {
                break;
            }
            if pos == buf.len() - 1 {
                break;
            }
            pos += 1;
        }

        let read = &buf[*offset..=pos];
        *offset = pos + 1;
        Ok(read)
    }

    fn read_mpint(&'buf self, offset: &mut usize) -> Result<&'buf [u8]> {
        self.read_byte_string(offset)
    }

    fn read_name_list(&'buf self, offset: &mut usize) -> Result<impl Iterator<Item = &'buf [u8]>> {
        let string = self.read_byte_string(offset)?;
        let mut iter = string.split_str(",");
        if string.is_empty() {
            iter.next();
        }
        Ok(iter)
    }
}

/// Encoding of the SSH wire data types described in RFC 4251, section 5.
///
/// Each method writes starting at `offset`, returns the number of bytes it
/// wrote, and on success moves `offset` forward by that same count.
pub trait SshEncode {
    /// Writes a single `byte`.
    fn write_byte(&mut self, data: u8, offset: &mut usize) -> Result<usize>;

    /// Writes a one-byte `boolean`: `1` for `true`, `0` for `false`.
    fn write_boolean(&mut self, data: bool, offset: &mut usize) -> Result<usize>;

    /// Writes `data` as a `uint32` in network (big-endian) byte order.
    fn write_uint32(&mut self, data: u32, offset: &mut usize) -> Result<usize>;

    /// Writes `data` as a `uint64` in network (big-endian) byte order.
    fn write_uint64(&mut self, data: u64, offset: &mut usize) -> Result<usize>;

    /// Writes a `string`: the length of `data` as a `uint32`, then `data` itself.
    ///
    /// The length prefix is 32 bits wide, so `data` must not exceed `u32::MAX`
    /// bytes.
    fn write_byte_string(&mut self, data: &[u8], offset: &mut usize) -> Result<usize>;

    /// Writes `data` verbatim, without any length prefix.
    fn write_bytes_exact(&mut self, data: &[u8], offset: &mut usize) -> Result<usize>;

    /// Writes an `mpint` whose two's-complement, big-endian bytes are `data`.
    ///
    /// `data` is emitted as given; producing the minimal encoding RFC 4251
    /// requires is the caller's job.
    fn write_mpint(&mut self, data: &[u8], offset: &mut usize) -> Result<usize>;
}

impl<B: AsRef<[u8]> + AsMut<[u8]>> SshEncode for B {
    fn write_byte(&mut self, data: u8, offset: &mut usize) -> Result<usize> {
        self.as_mut()
            .gwrite_with(data, offset, NETWORK)
            .map_err(Into::into)
    }

    fn write_boolean(&mut self, data: bool, offset: &mut usize) -> Result<usize> {
        let byte = u8::from(data);
        self.write_byte(byte, offset)
    }

    fn write_uint32(&mut self, data: u32, offset: &mut usize) -> Result<usize> {
        self.as_mut()
            .gwrite_with(data, offset, NETWORK)
            .map_err(Into::into)
    }

    fn write_uint64(&mut self, data: u64, offset: &mut usize) -> Result<usize> {
        self.as_mut()
            .gwrite_with(data, offset, NETWORK)
            .map_err(Into::into)
    }

    fn write_byte_string(&mut self, data: &[u8], offset: &mut usize) -> Result<usize> {
        let mut wrote = self.write_uint32(data.len() as u32, offset)?;
        wrote += self.as_mut().gwrite(data, offset).map_err(Error::from)?;
        Ok(wrote)
    }

    fn write_bytes_exact(&mut self, data: &[u8], offset: &mut usize) -> Result<usize> {
        self.as_mut().gwrite(data, offset).map_err(Into::into)
    }

    fn write_mpint(&mut self, data: &[u8], offset: &mut usize) -> Result<usize> {
        self.write_byte_string(data, offset)
    }
}

#[cfg(test)]
mod tests {
    use bstr::B;
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case("00", u8::MIN)]
    #[case("ff", u8::MAX)]
    fn byte_roundtrip(#[case] input: &str, #[case] output: u8) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_byte(&mut offset).unwrap(), output);
        assert_eq!(offset, buf.len());

        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_byte(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), input);
    }

    #[rstest]
    #[case("00", false)]
    #[case("01", true)]
    #[case("ff", true)]
    fn boolean_roundtrip(#[case] input: &str, #[case] output: bool) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_boolean(&mut offset).unwrap(), output);
        assert_eq!(offset, buf.len());

        // Any nonzero byte decodes to `true`, but `true` always encodes as `01`.
        let canonical = if output { "01" } else { "00" };
        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_boolean(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), canonical);
    }

    #[rstest]
    #[case("00000000", u32::MIN)]
    #[case("ffffffff", u32::MAX)]
    #[case("29b7f4aa", 699921578)]
    fn uint32_roundtrip(#[case] input: &str, #[case] output: u32) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_uint32(&mut offset).unwrap(), output);
        assert_eq!(offset, buf.len());

        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_uint32(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), input);
    }

    #[rstest]
    #[case("0000000000000000", u64::MIN)]
    #[case("ffffffffffffffff", u64::MAX)]
    fn uint64_roundtrip(#[case] input: &str, #[case] output: u64) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_uint64(&mut offset).unwrap(), output);
        assert_eq!(offset, buf.len());

        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_uint64(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), input);
    }

    #[rstest]
    #[case("00000000", &[])]
    #[case("0000000809a378f9b2e332a7", &[0x09, 0xa3, 0x78, 0xf9, 0xb2, 0xe3, 0x32, 0xa7])]
    #[case("000000020080", &[0x00, 0x80])]
    #[case("00000002edcc", &[0xed, 0xcc])]
    #[case("00000005ff21524111", &[0xff, 0x21, 0x52, 0x41, 0x11])]
    fn mpint_roundtrip(#[case] input: &str, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_mpint(&mut offset).unwrap(), output);
        assert_eq!(offset, buf.len());

        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_mpint(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), input);
    }

    #[rstest]
    #[case("00000000", B(""))]
    #[case("0000000774657374696e67", B("testing"))]
    fn byte_string_roundtrip(#[case] input: &str, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_byte_string(&mut offset).unwrap(), output);
        assert_eq!(offset, buf.len());

        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_byte_string(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), input);
    }

    #[test]
    fn byte_string_read_truncated_payload() {
        // The prefix declares 5 payload bytes, but only 2 follow.
        let buf = hex::decode("000000056162").unwrap();
        assert!(buf.read_byte_string(&mut 0).is_err());
    }

    #[rstest]
    #[case("", B(""))]
    #[case("74657374696e67", B("testing"))]
    fn bytes_exact_roundtrip(#[case] input: &str, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_bytes_exact(&mut offset, output.len()).unwrap(), output);
        assert_eq!(offset, buf.len());

        let mut out = vec![0u8; buf.len()];
        let wrote = out.write_bytes_exact(output, &mut 0).unwrap();
        assert_eq!(wrote, out.len());
        assert_eq!(hex::encode(&out), input);
    }

    #[rstest]
    #[case("74657374696e67", b't', B("t"))]
    #[case("74657374696e67", b'i', B("testi"))]
    #[case("74657374696e67", b'z', B("testing"))]
    fn bytes_until_read(#[case] input: &str, #[case] stop_char: u8, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        assert_eq!(buf.read_bytes_until(&mut offset, stop_char).unwrap(), output);
        assert_eq!(offset, output.len());
    }

    #[test]
    fn bytes_until_consecutive_reads() {
        let buf = B("a\nbc\n");
        let mut offset = 0;
        assert_eq!(buf.read_bytes_until(&mut offset, b'\n').unwrap(), B("a\n"));
        assert_eq!(offset, 2);
        assert_eq!(buf.read_bytes_until(&mut offset, b'\n').unwrap(), B("bc\n"));
        assert_eq!(offset, buf.len());
    }

    #[rstest]
    #[case("")]
    #[case("000000")]
    fn uint32_read_short_buffer(#[case] input: &str) {
        let buf = hex::decode(input).unwrap();
        assert!(buf.read_uint32(&mut 0).is_err());
    }

    #[test]
    fn uint32_write_short_buffer() {
        let mut buf = vec![0u8; 3];
        assert!(buf.write_uint32(1, &mut 0).is_err());
    }

    #[rstest]
    #[case("00000000", &[])]
    #[case("000000047a6c6962", &[B("zlib")])]
    #[case("000000097a6c69622c6e6f6e65", &[B("zlib"), B("none")])]
    #[case("00000005612c622c63", &[B("a"), B("b"), B("c")])]
    fn name_list_read(#[case] input: &str, #[case] output: &[&[u8]]) {
        let buf = hex::decode(input).unwrap();
        let got = buf.read_name_list(&mut 0).unwrap().collect::<Vec<_>>();
        assert_eq!(got, output);
    }
}