// metalssh 0.0.1
//
// Experimental SSH implementation
//! SSH binary packet

use core::cmp::max;
use core::fmt;

use bstr::ByteSlice;

use crate::types::Result;

// TODO: Include the peeked, decrypted packet len as an optional field so that
// ciphers like `chacha20poly1305` do not have to re-decrypt it.
/// A read/write wrapper around an SSH packet buffer.
///
/// The buffer is interpreted with the following layout:
///
/// ```text
/// packet_length  (4 bytes, big-endian)
/// padding_length (1 byte)
/// payload        (packet_length - padding_length - 1 bytes)
/// padding        (padding_length bytes)
/// mac            (mac_length bytes)
/// ```
#[derive(PartialEq, Eq, Clone)]
pub struct Packet<B: AsRef<[u8]>> {
    /// The underlying packet bytes; stored as-is, without validation.
    buffer: B,
    /// Length in bytes of the trailing MAC field.
    mac_length: u8,
}

impl<B: AsRef<[u8]> + ?Sized> fmt::Debug for Packet<&B> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Packet")
            .field("buffer", &self.buffer.as_ref().as_bstr())
            .field("mac_length", &self.mac_length)
            .field("packet_length", &self.packet_length())
            .field("padding_length", &self.padding_length())
            .field("payload", &self.payload().map(ByteSlice::as_bstr))
            .field("padding", &self.padding().map(ByteSlice::as_bstr))
            .field("mac", &self.mac().map(ByteSlice::as_bstr))
            .finish()
    }
}

impl<B: AsRef<[u8]>> Packet<B> {
    /// Imbues a byte buffer with SSH packet structure.
    ///
    /// The `mac_length` parameter specifies the length of the MAC field in the
    /// packet buffer, which depends on the MAC algorithm in use.
    ///
    /// The buffer is stored as-is; no length or consistency checks happen here.
    pub const fn new(buffer: B, mac_length: u8) -> Packet<B> {
        Packet { buffer, mac_length }
    }

    /// Creates a packet by copying a payload to a new owned byte buffer. Takes
    /// care of sizing the buffer properly for future use (encryption, mac,
    /// etc).
    ///
    /// Needs the following parameters to accurately size the packet buffer:
    /// - `payload`: the payload as bytes
    /// - `block_size`: the cipher block size (values below 8 are treated as 8)
    /// - `mac_length`: the MAC algorithm tag length
    ///
    /// The result will be a packet whose:
    /// - `packet_length` is set (via calculation)
    /// - `padding_length` is set (via calculation)
    /// - `payload` is set (via copy)
    /// - `padding` is properly sized, but unset (ie, all zeroes)
    /// - `mac` is properly sized, but unset (ie, all zeroes)
    ///
    /// NOTE(review): the padding SHOULD be random per RFC 4253; presumably the
    /// caller overwrites the zeroed padding before sending — confirm.
    #[cfg(feature = "alloc")]
    pub fn from_payload(payload: B, block_size: u8, mac_length: u8) -> Packet<alloc::vec::Vec<u8>> {
        let payload = payload.as_ref();
        let payload_len = payload.len();
        // NOTE(review): `as u32` truncates payloads longer than `u32::MAX`
        // bytes, far beyond any practical SSH packet size.
        let padding_len = calculate_padding_length(payload_len as u32, block_size);
        // `packet_length` covers the padding length byte, payload and padding.
        let packet_len = 1 + payload_len as u32 + u32::from(padding_len);

        // packet_length (4) + padding_length (1) + payload + padding + mac.
        let buffer_size = 4 + 1 + payload_len + usize::from(padding_len) + usize::from(mac_length);
        let mut buffer = alloc::vec![0u8; buffer_size];

        buffer[..4].copy_from_slice(&packet_len.to_be_bytes());
        buffer[4] = padding_len;
        buffer[5..5 + payload_len].copy_from_slice(payload);

        Packet { buffer, mac_length }
    }

    /// Consumes the packet, returning the underlying byte buffer.
    pub fn into_inner(self) -> B {
        self.buffer
    }

    /// Gets the packet length as a byte array.
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is shorter than 4 bytes.
    pub fn packet_length_bytes(&self) -> Result<[u8; 4]> {
        // `get` instead of indexing: a truncated buffer yields an empty slice
        // whose conversion fails with an error instead of panicking.
        let array = self.buffer.as_ref().get(..4).unwrap_or_default().try_into()?;
        Ok(array)
    }

    /// Gets the packet length.
    ///
    /// This consists of the length of the following fields concatenated:
    /// - padding length (1 byte)
    /// - payload (variable)
    /// - padding (variable)
    ///
    /// Notably, it does **not** include:
    /// - packet length (ie, this field itself)
    /// - mac
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is shorter than 4 bytes.
    pub fn packet_length(&self) -> Result<u32> {
        let array = self.packet_length_bytes()?;
        Ok(u32::from_be_bytes(array))
    }

    /// Gets the padding length (the byte at offset 4).
    ///
    /// # Errors
    ///
    /// Returns an error if the buffer is shorter than 5 bytes.
    pub fn padding_length(&self) -> Result<u8> {
        // Same `get`-based bounds handling as `packet_length_bytes`.
        let array = self.buffer.as_ref().get(4..5).unwrap_or_default().try_into()?;
        Ok(u8::from_be_bytes(array))
    }

    /// Gets the MAC length.
    pub const fn mac_length(&self) -> u8 {
        self.mac_length
    }
}

impl<'b, B: AsRef<[u8]> + ?Sized> Packet<&'b B> {
    /// Gets the payload bytes.
    ///
    /// The payload starts right after the 5-byte header (`packet_length` and
    /// `padding_length`) and is `packet_length - padding_length - 1` bytes
    /// long.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the header, and returns an error if
    /// the header lengths are inconsistent (`padding_length + 1 >
    /// packet_length`) or the buffer is too short for the declared payload.
    /// Header values may originate from an untrusted peer, so they are
    /// checked rather than trusted.
    pub fn payload(&self) -> Result<&'b [u8]> {
        let packet_len = self.packet_length()? as usize;
        let padding_len = usize::from(self.padding_length()?);
        // Guard against underflow on malformed headers instead of panicking
        // (debug builds) or wrapping to a huge length (release builds).
        match packet_len.checked_sub(padding_len + 1) {
            Some(payload_len) => Self::sub_slice(self.buffer.as_ref(), 5, payload_len),
            None => Self::out_of_bounds(),
        }
    }

    /// Gets the padding bytes, which directly follow the payload.
    ///
    /// # Errors
    ///
    /// Returns an error if the payload cannot be located (see `payload`) or
    /// the buffer is too short for the declared padding.
    pub fn padding(&self) -> Result<&'b [u8]> {
        let padding_offset = 4 + 1 + self.payload()?.len();
        let padding_len = usize::from(self.padding_length()?);
        Self::sub_slice(self.buffer.as_ref(), padding_offset, padding_len)
    }

    /// Gets the MAC bytes, located at offset `4 + packet_length`.
    ///
    /// # Errors
    ///
    /// Propagates any error from reading the header, and returns an error if
    /// the buffer is too short to hold `mac_length` bytes at that offset.
    pub fn mac(&self) -> Result<&'b [u8]> {
        let mac_len = usize::from(self.mac_length);
        // `checked_add` keeps the offset computation safe on 32-bit targets.
        match (self.packet_length()? as usize).checked_add(4) {
            Some(mac_offset) => Self::sub_slice(self.buffer.as_ref(), mac_offset, mac_len),
            None => Self::out_of_bounds(),
        }
    }

    /// Borrows `len` bytes of `bytes` starting at `start`, returning an error
    /// instead of panicking when that range does not fit in the buffer.
    fn sub_slice(bytes: &'b [u8], start: usize, len: usize) -> Result<&'b [u8]> {
        match start.checked_add(len).and_then(|end| bytes.get(start..end)) {
            Some(slice) => Ok(slice),
            None => Self::out_of_bounds(),
        }
    }

    /// Builds the error reported for out-of-bounds or inconsistent lengths.
    ///
    /// NOTE(review): the crate's error type is not visible from this module,
    /// so this reuses its existing `TryFromSliceError` conversion (the same
    /// one the header accessors rely on through `?`). A dedicated
    /// "malformed packet" error variant would be clearer.
    #[cold]
    fn out_of_bounds<T>() -> Result<T> {
        let empty: &[u8] = &[];
        match <[u8; 1]>::try_from(empty) {
            Err(err) => Err(err.into()),
            Ok(_) => unreachable!("an empty slice never converts into a one-byte array"),
        }
    }
}

impl<'b, B: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'b mut B> {
    /// Gets the packet as mutable bytes, including the MAC.
    ///
    /// The packet is borrowed only as long as the returned slice is in use,
    /// so it stays usable once the slice is dropped. Tying the receiver to
    /// the buffer lifetime `'b` would lock the packet for the whole of `'b`.
    ///
    /// # Errors
    ///
    /// Never fails; the `Result` is kept for API compatibility.
    pub fn packet_mut(&mut self) -> Result<&mut [u8]> {
        Ok(self.buffer.as_mut())
    }
}

impl<B: AsRef<[u8]>> AsRef<[u8]> for Packet<B> {
    /// Borrows the whole underlying buffer, from the length header to the end.
    fn as_ref(&self) -> &[u8] {
        let Self { buffer, .. } = self;
        buffer.as_ref()
    }
}

impl<B: AsRef<[u8]> + AsMut<[u8]>> AsMut<[u8]> for Packet<B> {
    /// Mutably borrows the whole underlying buffer, from the length header to
    /// the end.
    fn as_mut(&mut self) -> &mut [u8] {
        let Self { buffer, .. } = self;
        buffer.as_mut()
    }
}

/// Calculates the padding length for a payload of `payload_len` bytes.
///
/// From RFC 4253 section 6: arbitrary-length padding, such that the total
/// length of `(packet_length || padding_length || payload || random padding)`
/// is a multiple of the cipher block size or 8, whichever is larger. There
/// MUST be at least four bytes of padding. The padding SHOULD consist of
/// random bytes. The maximum amount of padding is 255 bytes.
///
/// Returns the smallest padding length satisfying those rules, which lies in
/// `4..=block_size + 3`. A `block_size` below 8 (including 0) is treated as 8.
///
/// NOTE(review): for `block_size > 252` the result cannot always fit in a
/// `u8` and the final addition overflows; real cipher block sizes are far
/// smaller.
fn calculate_padding_length(payload_len: u32, block_size: u8) -> u8 {
    // Size of the `packet_length` (4) and `padding_length` (1) fields.
    const HEADER_LEN: u32 = 4 + 1;
    // Minimum padding mandated by the RFC.
    const MIN_PADDING_LEN: u8 = 4;

    let block_size = max(block_size, 8);
    let block = u32::from(block_size);
    // Reduce `payload_len` first so that adding the header cannot overflow
    // `u32` for very large payloads; the remainder is unchanged.
    let misalignment = (HEADER_LEN + payload_len % block) % block;
    // `misalignment < block <= 255`, so the cast is lossless.
    let padding_len = block_size - misalignment as u8;
    if padding_len < MIN_PADDING_LEN {
        padding_len + block_size
    } else {
        padding_len
    }
}

#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::constants::msg::*;

    // Decodes a hand-written packet and checks every accessor.
    #[rstest]
    #[case::newkeys(
        "0000000c0a1500000000000000000000",
        0,
        [0x00, 0x00, 0x00, 0x0c],
        12,
        10,
        "15",
        "00000000000000000000",
        ""
    )]
    fn packet_read_works(
        #[case] input_data: &str,
        #[case] input_mac_length: u8,
        #[case] should_packet_length_bytes: [u8; 4],
        #[case] should_packet_length: u32,
        #[case] should_padding_length: u8,
        #[case] should_payload: &str,
        #[case] should_padding: &str,
        #[case] should_mac: &str,
    ) {
        let buf = hex::decode(input_data).unwrap();
        let packet = Packet::new(&buf, input_mac_length);

        let got_packet_length_bytes = packet.packet_length_bytes().unwrap();
        let got_packet_length = packet.packet_length().unwrap();
        let got_padding_length = packet.padding_length().unwrap();
        let got_payload = hex::encode(packet.payload().unwrap());
        let got_padding = hex::encode(packet.padding().unwrap());
        let got_mac = hex::encode(packet.mac().unwrap());

        assert_eq!(got_packet_length_bytes, should_packet_length_bytes);
        assert_eq!(got_packet_length, should_packet_length);
        assert_eq!(got_padding_length, should_padding_length);
        assert_eq!(got_payload, should_payload);
        assert_eq!(got_padding, should_padding);
        assert_eq!(got_mac, should_mac);
    }

    // Decodes captured packets from a real session and checks header fields,
    // message code and MAC.
    #[rustfmt::skip]
    #[rstest]
    #[case("testdata/none-exec/02-client-kexinit.bin",                   0, 1356, 4,  SSH_MSG_KEXINIT,                      "")]
    #[case("testdata/none-exec/03-server-kexinit.bin",                   0, 828,  9,  SSH_MSG_KEXINIT,                      "")]
    #[case("testdata/none-exec/04-client-kexdh_init.bin",                0, 1228, 6,  SSH_MSG_KEX_ECDH_INIT,                "")]
    #[case("testdata/none-exec/05-server-kexdh_reply.bin",               0, 1276, 8,  SSH_MSG_KEX_ECDH_REPLY,               "")]
    #[case("testdata/none-exec/06-server-newkeys.bin",                   0, 12,   10, SSH_MSG_NEWKEYS,                      "")]
    #[case("testdata/none-exec/07-server-ext_info.bin",                  8, 264,  11, SSH_MSG_EXT_INFO,                     "e89b5df177489bb7")]
    #[case("testdata/none-exec/08-client-newkeys.bin",                   0, 12,   10, SSH_MSG_NEWKEYS,                      "")]
    #[case("testdata/none-exec/09-client-ext_info.bin",                  8, 48,   5,  SSH_MSG_EXT_INFO,                     "e8e665da71a89a5a")]
    #[case("testdata/none-exec/10-client-service_request.bin",           8, 24,   6,  SSH_MSG_SERVICE_REQUEST,              "9149e02f1eaeab17")]
    #[case("testdata/none-exec/11-server-service_accept.bin",            8, 24,   6,  SSH_MSG_SERVICE_ACCEPT,               "e54fbf63bd58851c")]
    #[case("testdata/none-exec/12-client-userauth_request.bin",          8, 40,   4,  SSH_MSG_USERAUTH_REQUEST,             "be423d89f3dc729d")]
    #[case("testdata/none-exec/13-server-ext_info.bin",                  8, 192,  4,  SSH_MSG_EXT_INFO,                     "b20d9663faeca278")]
    #[case("testdata/none-exec/14-server-userauth_failure.bin",          8, 48,   11, SSH_MSG_USERAUTH_FAILURE,             "df6e2ff2d2273011")]
    #[case("testdata/none-exec/15-client-userauth_request.bin",          8, 120,  8,  SSH_MSG_USERAUTH_REQUEST,             "11393a00d1dc7990")]
    #[case("testdata/none-exec/16-server-userauth_pk_ok.bin",            8, 80,   8,  SSH_MSG_USERAUTH_PK_OK,               "60f47fe6e7d44b9e")]
    #[case("testdata/none-exec/17-client-userauth_request.bin",          8, 288,  8,  SSH_MSG_USERAUTH_REQUEST,             "1d93895c9a72fd74")]
    #[case("testdata/none-exec/18-server-userauth_success.bin",          8, 8,    6,  SSH_MSG_USERAUTH_SUCCESS,             "95b8b19e0b9770f8")]
    #[case("testdata/none-exec/19-client-channel_open.bin",              8, 32,   7,  SSH_MSG_CHANNEL_OPEN,                 "b02d1efdd5ef24a7")]
    #[case("testdata/none-exec/20-client-global_request.bin",            8, 40,   5,  SSH_MSG_GLOBAL_REQUEST,               "bc37d0ed73b10b28")]
    #[case("testdata/none-exec/21-server-global_request.bin",            8, 808,  8,  SSH_MSG_GLOBAL_REQUEST,               "3d7f905963e88a8f")]
    #[case("testdata/none-exec/22-server-debug.bin",                     8, 128,  11, SSH_MSG_DEBUG,                        "d5f9403eb96abb6b")]
    #[case("testdata/none-exec/23-server-debug.bin",                     8, 128,  11, SSH_MSG_DEBUG,                        "b6de49eedb1e5c05")]
    #[case("testdata/none-exec/24-server-channel_open_confirmation.bin", 8, 24,   6,  SSH_MSG_CHANNEL_OPEN_CONFIRMATION,    "1642db3f071b3784")]
    #[case("testdata/none-exec/25-client-channel_request.bin",           8, 48,   11, SSH_MSG_CHANNEL_REQUEST,              "de9d702e8157d450")]
    #[case("testdata/none-exec/26-client-channel_request.bin",           8, 32,   7,  SSH_MSG_CHANNEL_REQUEST,              "2f2443c64a7b0a25")]
    #[case("testdata/none-exec/27-server-channel_window_adjust.bin",     8, 16,   6,  SSH_MSG_CHANNEL_WINDOW_ADJUST,        "9dd9ec9efd85682c")]
    #[case("testdata/none-exec/28-server-channel_success.bin",           8, 16,   10, SSH_MSG_CHANNEL_SUCCESS,              "09175fff7ae9d128")]
    #[case("testdata/none-exec/29-server-channel_extended_data.bin",     8, 56,   8,  SSH_MSG_CHANNEL_EXTENDED_DATA,        "c543158bc98d37a8")]
    #[case("testdata/none-exec/30-server-channel_extended_data.bin",     8, 200,  7,  SSH_MSG_CHANNEL_EXTENDED_DATA,        "f26806b090dcbe8d")]
    #[case("testdata/none-exec/31-server-channel_extended_data.bin",     8, 104,  10, SSH_MSG_CHANNEL_EXTENDED_DATA,        "37ba892ac53c5974")]
    #[case("testdata/none-exec/32-server-channel_data.bin",              8, 24,   9,  SSH_MSG_CHANNEL_DATA,                 "4305f7ae1c568618")]
    #[case("testdata/none-exec/33-server-channel_eof.bin",               8, 16,   10, SSH_MSG_CHANNEL_EOF,                  "0f051f4ac868bc36")]
    #[case("testdata/none-exec/34-server-channel_request.bin",           8, 32,   6,  SSH_MSG_CHANNEL_REQUEST,              "ca763139dbcfef97")]
    #[case("testdata/none-exec/35-server-channel_request.bin",           8, 32,   6,  SSH_MSG_CHANNEL_REQUEST,              "47ab0063a6b851ed")]
    #[case("testdata/none-exec/36-server-channel_close.bin",             8, 16,   10, SSH_MSG_CHANNEL_CLOSE,                "d5cea55189965da4")]
    #[case("testdata/none-exec/37-client-channel_close.bin",             8, 16,   10, SSH_MSG_CHANNEL_CLOSE,                "6eedb67af8145c45")]
    #[case("testdata/none-exec/38-client-disconnect.bin",                8, 40,   6,  SSH_MSG_DISCONNECT,                   "b7371119d6e47da2")]
    fn packet_read_file_works(
        #[case] packet_file: &str,
        #[case] mac_len: u8,
        #[case] packet_len: u32,
        #[case] padding_len: u8,
        #[case] message_code: u8,
        #[case] mac: &str,
    ) {
        let bytes = std::fs::read(packet_file).unwrap();
        let packet = Packet::new(&bytes, mac_len);

        let got_packet_len = packet.packet_length().unwrap();
        let got_padding_len = packet.padding_length().unwrap();
        let got_message_code = packet.payload().unwrap()[0];
        let got_mac = hex::encode(packet.mac().unwrap());

        assert_eq!(got_packet_len, packet_len);
        assert_eq!(got_padding_len, padding_len);
        assert_eq!(got_message_code, message_code);
        assert_eq!(got_mac, mac);
    }

    // `from_payload` only exists with the `alloc` feature, so this test is
    // gated the same way to keep the test module compiling without it.
    #[cfg(feature = "alloc")]
    #[rstest]
    #[case("15", 0, 0, "0000000c0a1500000000000000000000")]
    fn packet_write_works(
        #[case] input_payload: &str,
        #[case] input_block_size: u8,
        #[case] input_mac_length: u8,
        #[case] should_bytes: &str,
    ) {
        let input_payload = hex::decode(input_payload).unwrap();
        let packet = Packet::from_payload(input_payload, input_block_size, input_mac_length);
        let got_bytes = hex::encode(packet.into_inner());
        assert_eq!(got_bytes, should_bytes);
    }

    #[rstest]
    #[case(1, 0, 10)]
    #[case(37, 0, 6)]
    #[case(179, 0, 8)]
    #[case(699, 0, 8)]
    #[case(1038, 0, 5)]
    // Block sizes above the minimum of 8, e.g. 16-byte cipher blocks.
    #[case(11, 16, 16)]
    #[case(10, 16, 17)]
    #[case(7, 16, 4)]
    fn calculate_padding_length_works(
        #[case] payload_len: u32,
        #[case] block_size: u8,
        #[case] should: u8,
    ) {
        let got = calculate_padding_length(payload_len, block_size);
        assert_eq!(got, should);
    }
}