use bstr::ByteSlice;
use scroll::NETWORK;
use scroll::Pread;
use scroll::Pwrite;
use crate::types::Error;
use crate::types::Result;
/// Decoding of the SSH wire data types (`byte`, `boolean`, `uint32`, `uint64`,
/// `string`, `mpint`, `name-list`) from a byte buffer.
///
/// Every method reads at `*offset` and, on success, moves `offset` past the
/// bytes it consumed, so successive calls walk through one buffer. Slice results
/// borrow from the buffer for `'buf`.
pub trait SshDecode<'buf> {
    /// Reads a single raw byte.
    fn read_byte(&self, offset: &mut usize) -> Result<u8>;

    /// Reads a `boolean` byte; zero is `false`, anything else is `true`.
    fn read_boolean(&self, offset: &mut usize) -> Result<bool>;

    /// Reads a `uint32` in network (big-endian) byte order.
    fn read_uint32(&self, offset: &mut usize) -> Result<u32>;

    /// Reads a `uint64` in network (big-endian) byte order.
    fn read_uint64(&self, offset: &mut usize) -> Result<u64>;

    /// Reads a `string`: a `uint32` length followed by that many bytes.
    fn read_byte_string(&'buf self, offset: &mut usize) -> Result<&'buf [u8]>;

    /// Reads exactly `len` bytes with no length prefix.
    fn read_bytes_exact(&'buf self, offset: &mut usize, len: usize) -> Result<&'buf [u8]>;

    /// Reads bytes up to and including the first `stop_char`, or through the end
    /// of the buffer when `stop_char` does not occur.
    fn read_bytes_until(&'buf self, offset: &mut usize, stop_char: u8) -> Result<&'buf [u8]>;

    /// Reads an `mpint`, returning its length-prefixed bytes as-is (not
    /// validated or normalized).
    fn read_mpint(&'buf self, offset: &mut usize) -> Result<&'buf [u8]>;

    /// Reads a `name-list` and yields its comma-separated names; an empty list
    /// yields no names.
    fn read_name_list(&'buf self, offset: &mut usize) -> Result<impl Iterator<Item = &'buf [u8]>>;
}
impl<'buf, B: AsRef<[u8]> + ?Sized> SshDecode<'buf> for B {
fn read_byte(&self, offset: &mut usize) -> Result<u8> {
self.as_ref()
.gread_with(offset, NETWORK)
.map_err(Into::into)
}
fn read_boolean(&self, offset: &mut usize) -> Result<bool> {
self.read_byte(offset).map(|byte| byte != 0)
}
fn read_uint32(&self, offset: &mut usize) -> Result<u32> {
self.as_ref()
.gread_with(offset, NETWORK)
.map_err(Into::into)
}
fn read_uint64(&self, offset: &mut usize) -> Result<u64> {
self.as_ref()
.gread_with(offset, NETWORK)
.map_err(Into::into)
}
fn read_byte_string(&'buf self, offset: &mut usize) -> Result<&'buf [u8]> {
let len = self.read_uint32(offset)?;
self.as_ref()
.gread_with(offset, len as usize)
.map_err(Into::into)
}
fn read_bytes_exact(&'buf self, offset: &mut usize, len: usize) -> Result<&'buf [u8]> {
self.as_ref().gread_with(offset, len).map_err(Into::into)
}
fn read_bytes_until(&'buf self, offset: &mut usize, stop_char: u8) -> Result<&'buf [u8]> {
let buf = self.as_ref();
let mut pos = *offset;
while pos < buf.len() {
if buf[pos] == stop_char {
break;
}
if pos == buf.len() - 1 {
break;
}
pos += 1;
}
let read = &buf[*offset..=pos];
*offset = pos + 1;
Ok(read)
}
fn read_mpint(&'buf self, offset: &mut usize) -> Result<&'buf [u8]> {
self.read_byte_string(offset)
}
fn read_name_list(&'buf self, offset: &mut usize) -> Result<impl Iterator<Item = &'buf [u8]>> {
let string = self.read_byte_string(offset)?;
let mut iter = string.split_str(",");
if string.is_empty() {
iter.next();
}
Ok(iter)
}
}
/// Encoding of the SSH wire data types into a mutable byte buffer.
///
/// Every method writes at `*offset`, moves `offset` past the bytes written on
/// success, and returns the number of bytes written.
pub trait SshEncode {
    /// Writes a single raw byte.
    fn write_byte(&mut self, data: u8, offset: &mut usize) -> Result<usize>;

    /// Writes a `boolean` as the byte `1` (`true`) or `0` (`false`).
    fn write_boolean(&mut self, data: bool, offset: &mut usize) -> Result<usize>;

    /// Writes a `uint32` in network (big-endian) byte order.
    fn write_uint32(&mut self, data: u32, offset: &mut usize) -> Result<usize>;

    /// Writes a `uint64` in network (big-endian) byte order.
    fn write_uint64(&mut self, data: u64, offset: &mut usize) -> Result<usize>;

    /// Writes a `string`: a `uint32` length prefix followed by `data`.
    fn write_byte_string(&mut self, data: &[u8], offset: &mut usize) -> Result<usize>;

    /// Writes `data` verbatim with no length prefix.
    fn write_bytes_exact(&mut self, data: &[u8], offset: &mut usize) -> Result<usize>;

    /// Writes an `mpint` from its already-encoded bytes, which are passed
    /// through without validation or normalization.
    fn write_mpint(&mut self, data: &[u8], offset: &mut usize) -> Result<usize>;
}
impl<B: AsRef<[u8]> + AsMut<[u8]>> SshEncode for B {
fn write_byte(&mut self, data: u8, offset: &mut usize) -> Result<usize> {
self.as_mut()
.gwrite_with(data, offset, NETWORK)
.map_err(Into::into)
}
fn write_boolean(&mut self, data: bool, offset: &mut usize) -> Result<usize> {
let byte = u8::from(data);
self.write_byte(byte, offset)
}
fn write_uint32(&mut self, data: u32, offset: &mut usize) -> Result<usize> {
self.as_mut()
.gwrite_with(data, offset, NETWORK)
.map_err(Into::into)
}
fn write_uint64(&mut self, data: u64, offset: &mut usize) -> Result<usize> {
self.as_mut()
.gwrite_with(data, offset, NETWORK)
.map_err(Into::into)
}
fn write_byte_string(&mut self, data: &[u8], offset: &mut usize) -> Result<usize> {
let mut wrote = self.write_uint32(data.len() as u32, offset)?;
wrote += self.as_mut().gwrite(data, offset).map_err(Error::from)?;
Ok(wrote)
}
fn write_bytes_exact(&mut self, data: &[u8], offset: &mut usize) -> Result<usize> {
self.as_mut().gwrite(data, offset).map_err(Into::into)
}
fn write_mpint(&mut self, data: &[u8], offset: &mut usize) -> Result<usize> {
self.write_byte_string(data, offset)
}
}
#[cfg(test)]
mod tests {
    use bstr::B;
    use rstest::rstest;

    use super::*;

    // Each roundtrip test decodes `input`, checks the value and that the whole
    // buffer was consumed, then re-encodes and checks both the bytes and the
    // byte count reported by the writer.

    #[rstest]
    #[case("00", u8::MIN)]
    #[case("ff", u8::MAX)]
    fn byte_roundtrip(#[case] input: &str, #[case] output: u8) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_byte(&mut offset).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_byte(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), input);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("00", false)]
    #[case("01", true)]
    #[case("ff", true)]
    fn boolean_roundtrip(#[case] input: &str, #[case] output: bool) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_boolean(&mut offset).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        // Any non-zero byte decodes as `true`, but `true` always encodes as 0x01.
        let should = match output {
            false => "00",
            true => "01",
        };
        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_boolean(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), should);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("00000000", u32::MIN)]
    #[case("ffffffff", u32::MAX)]
    #[case("29b7f4aa", 699921578)]
    fn uint32_roundtrip(#[case] input: &str, #[case] output: u32) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_uint32(&mut offset).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_uint32(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), input);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("0000000000000000", u64::MIN)]
    #[case("ffffffffffffffff", u64::MAX)]
    fn uint64_roundtrip(#[case] input: &str, #[case] output: u64) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_uint64(&mut offset).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_uint64(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), input);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("00000000", &[])]
    #[case("0000000809a378f9b2e332a7", &[0x09, 0xa3, 0x78, 0xf9, 0xb2, 0xe3, 0x32, 0xa7])]
    #[case("000000020080", &[0x00, 0x80])]
    #[case("00000002edcc", &[0xed, 0xcc])]
    #[case("00000005ff21524111", &[0xff, 0x21, 0x52, 0x41, 0x11])]
    fn mpint_roundtrip(#[case] input: &str, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_mpint(&mut offset).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_mpint(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), input);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("00000000", B(""))]
    #[case("0000000774657374696e67", B("testing"))]
    fn byte_string_roundtrip(#[case] input: &str, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_byte_string(&mut offset).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_byte_string(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), input);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("", B(""))]
    #[case("74657374696e67", B("testing"))]
    fn bytes_exact_roundtrip(#[case] input: &str, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_bytes_exact(&mut offset, output.len()).unwrap();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());

        let mut buf = vec![0u8; buf.len()];
        let wrote = buf.write_bytes_exact(output, &mut 0).unwrap();
        assert_eq!(hex::encode(&buf), input);
        assert_eq!(wrote, buf.len());
    }

    #[rstest]
    #[case("74657374696e67", b't', B("t"))]
    #[case("74657374696e67", b'i', B("testi"))]
    #[case("74657374696e67", b'z', B("testing"))]
    fn bytes_until_read(#[case] input: &str, #[case] stop_char: u8, #[case] output: &[u8]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_bytes_until(&mut offset, stop_char).unwrap();
        assert_eq!(got, output);
        // The stop byte (if found) is consumed along with everything before it.
        assert_eq!(offset, output.len());
    }

    #[rstest]
    #[case("00000000", &[])]
    #[case("000000047a6c6962", &[B("zlib")])]
    #[case("000000097a6c69622c6e6f6e65", &[B("zlib"), B("none")])]
    fn name_list_read(#[case] input: &str, #[case] output: &[&[u8]]) {
        let buf = hex::decode(input).unwrap();
        let mut offset = 0;
        let got = buf.read_name_list(&mut offset).unwrap().collect::<Vec<_>>();
        assert_eq!(got, output);
        assert_eq!(offset, buf.len());
    }
}