use core::cmp::max;
use core::fmt;
use bstr::ByteSlice;
use crate::types::Result;
/// A view of one SSH binary packet stored in `buffer`.
///
/// The accessors in this file read the buffer as a 4-byte big-endian `packet_length`,
/// a 1-byte `padding_length`, the payload, the padding, and finally `mac_length` bytes
/// of MAC starting at offset `4 + packet_length`.
#[derive(Clone, PartialEq, Eq)]
pub struct Packet<B>
where
    B: AsRef<[u8]>,
{
    /// Backing storage holding the raw packet bytes.
    buffer: B,
    /// Number of MAC bytes that follow the packet body.
    mac_length: u8,
}
impl<B: AsRef<[u8]> + ?Sized> fmt::Debug for Packet<&B> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Packet")
.field("buffer", &self.buffer.as_ref().as_bstr())
.field("mac_length", &self.mac_length)
.field("packet_length", &self.packet_length())
.field("padding_length", &self.padding_length())
.field("payload", &self.payload().map(ByteSlice::as_bstr))
.field("padding", &self.padding().map(ByteSlice::as_bstr))
.field("mac", &self.mac().map(ByteSlice::as_bstr))
.finish()
}
}
impl<B: AsRef<[u8]>> Packet<B> {
    /// Wraps `buffer` as a packet whose trailing MAC is `mac_length` bytes long.
    ///
    /// The buffer is not inspected here; the accessors validate what they read.
    pub const fn new(buffer: B, mac_length: u8) -> Packet<B> {
        Packet { buffer, mac_length }
    }

    /// Builds a freshly allocated packet around `payload`.
    ///
    /// The result holds the 4-byte big-endian packet length, the padding length byte,
    /// a copy of the payload, and then zero-filled space for the padding and for
    /// `mac_length` bytes of MAC. The amount of padding is chosen by
    /// `calculate_padding_length` from the payload length and `block_size`.
    #[cfg(feature = "alloc")]
    pub fn from_payload(payload: B, block_size: u8, mac_length: u8) -> Packet<alloc::vec::Vec<u8>> {
        let payload = payload.as_ref();
        let payload_len = payload.len();
        // NOTE(review): the length is narrowed to u32 because the wire field is 32 bits;
        // payloads longer than u32::MAX are truncated here — confirm callers never pass them.
        let padding_len = calculate_padding_length(payload_len as u32, block_size);
        // packet_length counts the padding_length byte, the payload, and the padding.
        let packet_len = 1 + payload_len as u32 + u32::from(padding_len);
        let buffer_size = 4 + 1 + payload_len + usize::from(padding_len) + usize::from(mac_length);
        let mut buffer = alloc::vec![0u8; buffer_size];
        buffer[..4].copy_from_slice(&packet_len.to_be_bytes());
        buffer[4] = padding_len;
        buffer[5..5 + payload_len].copy_from_slice(payload);
        Packet { buffer, mac_length }
    }

    /// Consumes the packet and returns the underlying buffer.
    pub fn into_inner(self) -> B {
        self.buffer
    }

    /// Returns the raw 4-byte packet length field from the start of the buffer.
    ///
    /// # Errors
    ///
    /// Returns the slice-to-array conversion error when the buffer holds fewer than
    /// 4 bytes.
    pub fn packet_length_bytes(&self) -> Result<[u8; 4]> {
        let bytes = self.buffer.as_ref();
        // On a short buffer fall back to the whole (too short) slice, so the array
        // conversion fails and `?` reports it instead of an index panic.
        let field = bytes.get(..4).unwrap_or(bytes);
        let array: [u8; 4] = field.try_into()?;
        Ok(array)
    }

    /// Returns the packet length field decoded as a big-endian `u32`.
    ///
    /// # Errors
    ///
    /// Fails under the same condition as [`Packet::packet_length_bytes`].
    pub fn packet_length(&self) -> Result<u32> {
        let array = self.packet_length_bytes()?;
        Ok(u32::from_be_bytes(array))
    }

    /// Returns the padding length byte stored at offset 4.
    ///
    /// # Errors
    ///
    /// Returns the slice-to-array conversion error when the buffer holds fewer than
    /// 5 bytes.
    pub fn padding_length(&self) -> Result<u8> {
        // A missing byte becomes an empty slice, which cannot convert to `[u8; 1]`.
        let field = self.buffer.as_ref().get(4..5).unwrap_or(&[]);
        let array: [u8; 1] = field.try_into()?;
        Ok(u8::from_be_bytes(array))
    }

    /// Returns the MAC length this packet was created with.
    pub const fn mac_length(&self) -> u8 {
        self.mac_length
    }
}
impl<'b, B: AsRef<[u8]> + ?Sized> Packet<&'b B> {
    /// Returns the payload: the bytes after the 5-byte header, `packet_length -
    /// padding_length - 1` bytes long.
    ///
    /// # Errors
    ///
    /// Fails when a header field cannot be read, when `packet_length` is smaller than
    /// `padding_length + 1`, or when the payload range lies outside the buffer.
    pub fn payload(&self) -> Result<&'b [u8]> {
        let packet_len = self.packet_length()? as usize;
        let padding_len = self.padding_length()? as usize;
        let bytes: &'b [u8] = self.buffer.as_ref();
        // packet_length also counts the padding_length byte and the padding itself;
        // a malformed header can claim less than that, hence the checked subtraction.
        let region = packet_len
            .checked_sub(padding_len + 1)
            .and_then(|payload_len| Self::checked_region(bytes, 5, payload_len));
        Self::region_or_err(region)
    }

    /// Returns the padding: `padding_length` bytes directly after the payload.
    ///
    /// # Errors
    ///
    /// Fails when the payload cannot be located or when the padding range lies
    /// outside the buffer.
    pub fn padding(&self) -> Result<&'b [u8]> {
        let padding_offset = 4 + 1 + self.payload()?.len();
        let padding_len = self.padding_length()? as usize;
        let bytes: &'b [u8] = self.buffer.as_ref();
        Self::region_or_err(Self::checked_region(bytes, padding_offset, padding_len))
    }

    /// Returns the MAC: `mac_length` bytes starting at offset `4 + packet_length`.
    ///
    /// # Errors
    ///
    /// Fails when the packet length cannot be read or when the MAC range lies outside
    /// the buffer.
    pub fn mac(&self) -> Result<&'b [u8]> {
        let packet_len = self.packet_length()? as usize;
        let mac_length = usize::from(self.mac_length);
        let bytes: &'b [u8] = self.buffer.as_ref();
        let region = 4usize
            .checked_add(packet_len)
            .and_then(|mac_offset| Self::checked_region(bytes, mac_offset, mac_length));
        Self::region_or_err(region)
    }

    /// Returns `bytes[start..start + len]`, or `None` when that range overflows or
    /// reaches past the end of `bytes`.
    fn checked_region(bytes: &'b [u8], start: usize, len: usize) -> Option<&'b [u8]> {
        bytes.get(start..start.checked_add(len)?)
    }

    /// Turns a missing region into the error the header accessors already propagate.
    fn region_or_err(region: Option<&'b [u8]>) -> Result<&'b [u8]> {
        if let Some(bytes) = region {
            return Ok(bytes);
        }
        // The only error this file is known to convert into the crate error is the
        // slice-to-array conversion failure, so produce one: an empty slice never
        // converts to a one-byte array, which makes the `?` below always return `Err`.
        // NOTE(review): switch to a dedicated "malformed packet" error if the crate has one.
        let empty: &'b [u8] = &[];
        let unreachable_ok = <[u8; 1]>::try_from(empty).map(|_| empty)?;
        Ok(unreachable_ok)
    }
}
impl<B: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&mut B> {
    /// Returns the whole underlying buffer as a mutable byte slice.
    ///
    /// The borrow lasts only as long as the returned slice is in use, so the packet
    /// can be read or mutably borrowed again afterwards. (Taking `self` for the full
    /// buffer lifetime would allow exactly one call per packet.)
    ///
    /// # Errors
    ///
    /// This implementation always returns `Ok`.
    pub fn packet_mut(&mut self) -> Result<&mut [u8]> {
        Ok(self.buffer.as_mut())
    }
}
/// Exposes the packet's entire underlying buffer as a byte slice.
impl<T> AsRef<[u8]> for Packet<T>
where
    T: AsRef<[u8]>,
{
    fn as_ref(&self) -> &[u8] {
        <T as AsRef<[u8]>>::as_ref(&self.buffer)
    }
}
/// Exposes the packet's entire underlying buffer as a mutable byte slice.
impl<T> AsMut<[u8]> for Packet<T>
where
    T: AsRef<[u8]> + AsMut<[u8]>,
{
    fn as_mut(&mut self) -> &mut [u8] {
        <T as AsMut<[u8]>>::as_mut(&mut self.buffer)
    }
}
/// Computes how many padding bytes follow a payload of `payload_len` bytes.
///
/// The padding makes `4 + 1 + payload_len + padding` a multiple of the effective block
/// size and is never shorter than 4 bytes. Block sizes below 8 are treated as 8.
///
/// Block sizes above 252 can need more than 255 padding bytes, which does not fit the
/// `u8` return type; for those inputs the final addition overflows.
fn calculate_padding_length(payload_len: u32, block_size: u8) -> u8 {
    let block_size = max(block_size, 8);
    // Widened to u64 so that adding the 5 header bytes cannot overflow for payload
    // lengths close to u32::MAX.
    let unpadded_len = 4 + 1 + u64::from(payload_len);
    // The remainder is below block_size, so it always fits in a u8.
    let remainder = (unpadded_len % u64::from(block_size)) as u8;
    let padding_len = block_size - remainder;
    if padding_len < 4 {
        // Too short: add one more full block to reach the 4-byte minimum.
        padding_len + block_size
    } else {
        padding_len
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::constants::msg::*;

    /// Parses a hex-encoded packet and checks every field accessor against the
    /// expected value.
    #[rstest]
    #[case::newkeys(
        "0000000c0a1500000000000000000000",
        0,
        [0x00, 0x00, 0x00, 0x0c],
        12,
        10,
        "15",
        "00000000000000000000",
        ""
    )]
    fn packet_read_works(
        #[case] input_data: &str,
        #[case] input_mac_length: u8,
        #[case] should_packet_length_bytes: [u8; 4],
        #[case] should_packet_length: u32,
        #[case] should_padding_length: u8,
        #[case] should_payload: &str,
        #[case] should_padding: &str,
        #[case] should_mac: &str,
    ) {
        let buf = hex::decode(input_data).unwrap();
        let packet = Packet::new(&buf, input_mac_length);
        let got_packet_length_bytes = packet.packet_length_bytes().unwrap();
        let got_packet_length = packet.packet_length().unwrap();
        let got_padding_length = packet.padding_length().unwrap();
        let got_payload = hex::encode(packet.payload().unwrap());
        let got_padding = hex::encode(packet.padding().unwrap());
        let got_mac = hex::encode(packet.mac().unwrap());
        assert_eq!(got_packet_length_bytes, should_packet_length_bytes);
        assert_eq!(got_packet_length, should_packet_length);
        assert_eq!(got_padding_length, should_padding_length);
        assert_eq!(got_payload, should_payload);
        assert_eq!(got_padding, should_padding);
        assert_eq!(got_mac, should_mac);
    }

    /// Reads packet fixtures from `testdata/none-exec` and checks the packet length,
    /// padding length, first payload byte (the message code), and MAC of each one.
    #[rustfmt::skip]
    #[rstest]
    #[case("testdata/none-exec/02-client-kexinit.bin", 0, 1356, 4, SSH_MSG_KEXINIT, "")]
    #[case("testdata/none-exec/03-server-kexinit.bin", 0, 828, 9, SSH_MSG_KEXINIT, "")]
    #[case("testdata/none-exec/04-client-kexdh_init.bin", 0, 1228, 6, SSH_MSG_KEX_ECDH_INIT, "")]
    #[case("testdata/none-exec/05-server-kexdh_reply.bin", 0, 1276, 8, SSH_MSG_KEX_ECDH_REPLY, "")]
    #[case("testdata/none-exec/06-server-newkeys.bin", 0, 12, 10, SSH_MSG_NEWKEYS, "")]
    #[case("testdata/none-exec/07-server-ext_info.bin", 8, 264, 11, SSH_MSG_EXT_INFO, "e89b5df177489bb7")]
    #[case("testdata/none-exec/08-client-newkeys.bin", 0, 12, 10, SSH_MSG_NEWKEYS, "")]
    #[case("testdata/none-exec/09-client-ext_info.bin", 8, 48, 5, SSH_MSG_EXT_INFO, "e8e665da71a89a5a")]
    #[case("testdata/none-exec/10-client-service_request.bin", 8, 24, 6, SSH_MSG_SERVICE_REQUEST, "9149e02f1eaeab17")]
    #[case("testdata/none-exec/11-server-service_accept.bin", 8, 24, 6, SSH_MSG_SERVICE_ACCEPT, "e54fbf63bd58851c")]
    #[case("testdata/none-exec/12-client-userauth_request.bin", 8, 40, 4, SSH_MSG_USERAUTH_REQUEST, "be423d89f3dc729d")]
    #[case("testdata/none-exec/13-server-ext_info.bin", 8, 192, 4, SSH_MSG_EXT_INFO, "b20d9663faeca278")]
    #[case("testdata/none-exec/14-server-userauth_failure.bin", 8, 48, 11, SSH_MSG_USERAUTH_FAILURE, "df6e2ff2d2273011")]
    #[case("testdata/none-exec/15-client-userauth_request.bin", 8, 120, 8, SSH_MSG_USERAUTH_REQUEST, "11393a00d1dc7990")]
    #[case("testdata/none-exec/16-server-userauth_pk_ok.bin", 8, 80, 8, SSH_MSG_USERAUTH_PK_OK, "60f47fe6e7d44b9e")]
    #[case("testdata/none-exec/17-client-userauth_request.bin", 8, 288, 8, SSH_MSG_USERAUTH_REQUEST, "1d93895c9a72fd74")]
    #[case("testdata/none-exec/18-server-userauth_success.bin", 8, 8, 6, SSH_MSG_USERAUTH_SUCCESS, "95b8b19e0b9770f8")]
    #[case("testdata/none-exec/19-client-channel_open.bin", 8, 32, 7, SSH_MSG_CHANNEL_OPEN, "b02d1efdd5ef24a7")]
    #[case("testdata/none-exec/20-client-global_request.bin", 8, 40, 5, SSH_MSG_GLOBAL_REQUEST, "bc37d0ed73b10b28")]
    #[case("testdata/none-exec/21-server-global_request.bin", 8, 808, 8, SSH_MSG_GLOBAL_REQUEST, "3d7f905963e88a8f")]
    #[case("testdata/none-exec/22-server-debug.bin", 8, 128, 11, SSH_MSG_DEBUG, "d5f9403eb96abb6b")]
    #[case("testdata/none-exec/23-server-debug.bin", 8, 128, 11, SSH_MSG_DEBUG, "b6de49eedb1e5c05")]
    #[case("testdata/none-exec/24-server-channel_open_confirmation.bin", 8, 24, 6, SSH_MSG_CHANNEL_OPEN_CONFIRMATION, "1642db3f071b3784")]
    #[case("testdata/none-exec/25-client-channel_request.bin", 8, 48, 11, SSH_MSG_CHANNEL_REQUEST, "de9d702e8157d450")]
    #[case("testdata/none-exec/26-client-channel_request.bin", 8, 32, 7, SSH_MSG_CHANNEL_REQUEST, "2f2443c64a7b0a25")]
    #[case("testdata/none-exec/27-server-channel_window_adjust.bin", 8, 16, 6, SSH_MSG_CHANNEL_WINDOW_ADJUST, "9dd9ec9efd85682c")]
    #[case("testdata/none-exec/28-server-channel_success.bin", 8, 16, 10, SSH_MSG_CHANNEL_SUCCESS, "09175fff7ae9d128")]
    #[case("testdata/none-exec/29-server-channel_extended_data.bin", 8, 56, 8, SSH_MSG_CHANNEL_EXTENDED_DATA, "c543158bc98d37a8")]
    #[case("testdata/none-exec/30-server-channel_extended_data.bin", 8, 200, 7, SSH_MSG_CHANNEL_EXTENDED_DATA, "f26806b090dcbe8d")]
    #[case("testdata/none-exec/31-server-channel_extended_data.bin", 8, 104, 10, SSH_MSG_CHANNEL_EXTENDED_DATA, "37ba892ac53c5974")]
    #[case("testdata/none-exec/32-server-channel_data.bin", 8, 24, 9, SSH_MSG_CHANNEL_DATA, "4305f7ae1c568618")]
    #[case("testdata/none-exec/33-server-channel_eof.bin", 8, 16, 10, SSH_MSG_CHANNEL_EOF, "0f051f4ac868bc36")]
    #[case("testdata/none-exec/34-server-channel_request.bin", 8, 32, 6, SSH_MSG_CHANNEL_REQUEST, "ca763139dbcfef97")]
    #[case("testdata/none-exec/35-server-channel_request.bin", 8, 32, 6, SSH_MSG_CHANNEL_REQUEST, "47ab0063a6b851ed")]
    #[case("testdata/none-exec/36-server-channel_close.bin", 8, 16, 10, SSH_MSG_CHANNEL_CLOSE, "d5cea55189965da4")]
    #[case("testdata/none-exec/37-client-channel_close.bin", 8, 16, 10, SSH_MSG_CHANNEL_CLOSE, "6eedb67af8145c45")]
    #[case("testdata/none-exec/38-client-disconnect.bin", 8, 40, 6, SSH_MSG_DISCONNECT, "b7371119d6e47da2")]
    fn packet_read_file_works(
        #[case] packet_file: &str,
        #[case] mac_len: u8,
        #[case] packet_len: u32,
        #[case] padding_len: u8,
        #[case] message_code: u8,
        #[case] mac: &str,
    ) {
        let bytes = std::fs::read(packet_file).unwrap();
        let packet = Packet::new(&bytes, mac_len);
        let got_packet_len = packet.packet_length().unwrap();
        let got_padding_len = packet.padding_length().unwrap();
        let got_message_code = packet.payload().unwrap()[0];
        let got_mac = hex::encode(packet.mac().unwrap());
        assert_eq!(got_packet_len, packet_len);
        assert_eq!(got_padding_len, padding_len);
        assert_eq!(got_message_code, message_code);
        assert_eq!(got_mac, mac);
    }

    /// Builds a packet from a hex-encoded payload and compares the encoded bytes.
    ///
    /// Gated on `alloc` because `Packet::from_payload` only exists with that feature;
    /// without the gate the whole test module fails to compile when it is disabled.
    #[cfg(feature = "alloc")]
    #[rstest]
    #[case("15", 0, 0, "0000000c0a1500000000000000000000")]
    fn packet_write_works(
        #[case] input_payload: &str,
        #[case] input_block_size: u8,
        #[case] input_mac_length: u8,
        #[case] should_bytes: &str,
    ) {
        let input_payload = hex::decode(input_payload).unwrap();
        let packet = Packet::from_payload(input_payload, input_block_size, input_mac_length);
        let got_bytes = hex::encode(packet.into_inner());
        assert_eq!(got_bytes, should_bytes);
    }

    /// Checks the padding length chosen for several payload lengths. A block size of 0
    /// exercises the minimum effective block size of 8.
    #[rstest]
    #[case(1, 0, 10)]
    #[case(37, 0, 6)]
    #[case(179, 0, 8)]
    #[case(699, 0, 8)]
    #[case(1038, 0, 5)]
    fn calculate_padding_length_works(
        #[case] payload_len: u32,
        #[case] block_size: u8,
        #[case] should: u8,
    ) {
        let got = calculate_padding_length(payload_len, block_size);
        assert_eq!(got, should);
    }
}