use cryptovec::CryptoVec;
pub use openssl::bn::{BigNum, BigNumRef};
use std::io;
use std::io::{Read, Result, Write};
use std::str;
use zeroize::{Zeroize, Zeroizing};
/// Largest accepted big-number magnitude, in bytes (16384 bits).
const MAX_BIGNUM: usize = 16_384 / 8;
/// In-memory byte buffer backed by a `CryptoVec`, paired with a read cursor so
/// the same value can be appended to (via `Write`) and consumed (via `Read`).
#[derive(Debug, Default)]
pub struct SshBuf {
    /// Offset of the next byte that `Read::read` will hand out.
    read_pos: usize,
    /// Backing storage; `Write::write` appends to it.
    buf: CryptoVec,
}
impl SshBuf {
    /// Creates an empty buffer whose read position is 0.
    pub fn new() -> SshBuf {
        Self::with_vec(CryptoVec::new())
    }

    /// Wraps existing storage `v`; reading starts at its first byte.
    pub fn with_vec(v: CryptoVec) -> SshBuf {
        SshBuf { buf: v, read_pos: 0 }
    }

    /// Current read offset into the buffer.
    pub fn position(&self) -> usize {
        self.read_pos
    }

    /// Moves the read offset to `offset`.
    ///
    /// # Panics
    ///
    /// Panics if `offset` is greater than the number of stored bytes.
    pub fn set_position(&mut self, offset: usize) {
        assert!(offset <= self.buf.len(), "Offset exceed length");
        self.read_pos = offset;
    }

    /// Consumes the wrapper and returns the backing storage.
    pub fn into_inner(self) -> CryptoVec {
        let SshBuf { buf, .. } = self;
        buf
    }

    /// Borrows the backing storage.
    pub fn get_ref(&self) -> &CryptoVec {
        &self.buf
    }

    /// Views the entire contents, read and unread alike, as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.buf
    }

    /// Total number of stored bytes, independent of the read position.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// `true` when nothing is stored at all, regardless of the read position.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }
}
impl Read for SshBuf {
    /// Hands the unread tail (from the read position onward) to
    /// `CryptoVec::write_all_from`, which writes into `buf`, then advances the
    /// read position by the byte count it reports. Returns `Ok(0)` once the
    /// read position has reached the end of the stored data.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        if self.read_pos < self.buf.len() {
            let copied = self.buf.write_all_from(self.read_pos, buf)?;
            self.read_pos += copied;
            Ok(copied)
        } else {
            Ok(0)
        }
    }
}
impl Write for SshBuf {
    /// Appends all of `buf` to the backing storage and reports the full
    /// length as written; the read position is left untouched.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let appended = buf.len();
        self.buf.extend(buf);
        Ok(appended)
    }

    /// No-op: the data already lives in memory.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
pub trait SshReadExt {
fn read_bool(&mut self) -> Result<bool>;
fn read_uint8(&mut self) -> io::Result<u8>;
fn read_uint32(&mut self) -> io::Result<u32>;
fn read_uint64(&mut self) -> io::Result<u64>;
fn read_string(&mut self) -> io::Result<Vec<u8>>;
fn read_utf8(&mut self) -> io::Result<String>;
fn read_mpint(&mut self) -> io::Result<BigNum>;
}
impl<R: io::Read + ?Sized> SshReadExt for R {
fn read_bool(&mut self) -> io::Result<bool> {
let i = Zeroizing::new(self.read_uint8()?);
Ok(*i != 0)
}
fn read_uint8(&mut self) -> io::Result<u8> {
let mut buf = Zeroizing::new([0u8; 1]);
self.read_exact(&mut *buf)?;
Ok(buf[0])
}
fn read_uint32(&mut self) -> io::Result<u32> {
let mut buf = Zeroizing::new([0u8; 4]);
self.read_exact(&mut *buf)?;
Ok(u32::from_be_bytes(*buf))
}
fn read_uint64(&mut self) -> io::Result<u64> {
let mut buf = Zeroizing::new([0u8; 8]);
self.read_exact(&mut *buf)?;
Ok(u64::from_be_bytes(*buf))
}
fn read_string(&mut self) -> io::Result<Vec<u8>> {
let len = self.read_uint32()? as usize;
let mut buf = vec![0u8; len];
match self.read_exact(buf.as_mut_slice()) {
Ok(_) => Ok(buf),
Err(e) => {
buf.zeroize();
Err(e)
}
}
}
fn read_utf8(&mut self) -> io::Result<String> {
let mut buf = self.read_string()?;
match str::from_utf8(&buf) {
Ok(_) => unsafe {
Ok(String::from_utf8_unchecked(buf))
},
Err(_) => {
buf.zeroize();
Err(io::Error::new(
io::ErrorKind::InvalidData,
"Invalid UTF-8 sequence",
))
}
}
}
fn read_mpint(&mut self) -> io::Result<BigNum> {
let data = Zeroizing::new(self.read_string()?);
to_bignum(&data)
}
}
/// Decodes the big-endian payload of a big-number string into a `BigNum`.
///
/// # Errors
///
/// Returns `InvalidData` when:
/// - the top bit of the first byte is set ("Negative Big Number");
/// - the payload exceeds `MAX_BIGNUM + 1` bytes, or is exactly that long
///   without a leading `0x00` ("Big Number Too Long");
/// - OpenSSL rejects the bytes ("Invalid Big Number").
fn to_bignum(data: &[u8]) -> io::Result<BigNum> {
    let invalid = |reason: &'static str| io::Error::new(io::ErrorKind::InvalidData, reason);

    // An empty payload behaves like a leading zero byte for these checks.
    let lead = data.first().copied().unwrap_or(0);
    if lead & 0x80 != 0 {
        return Err(invalid("Negative Big Number"));
    }

    // One extra byte is tolerated only when it is the 0x00 sign pad.
    let ceiling = MAX_BIGNUM + 1;
    if data.len() > ceiling || (data.len() == ceiling && lead != 0) {
        return Err(invalid("Big Number Too Long"));
    }

    let zeros = data.iter().take_while(|&&byte| byte == 0).count();
    BigNum::from_slice(&data[zeros..]).map_err(|_| invalid("Invalid Big Number"))
}
/// Encoding of SSH wire-format primitives into a byte sink.
///
/// A blanket implementation covers every `W: io::Write + ?Sized`.
pub trait SshWriteExt {
    /// Writes `1` for `true` and `0` for `false`.
    fn write_bool(&mut self, value: bool) -> io::Result<()>;
    /// Writes a single byte.
    fn write_uint8(&mut self, value: u8) -> io::Result<()>;
    /// Writes `value` as 4 big-endian bytes.
    fn write_uint32(&mut self, value: u32) -> io::Result<()>;
    /// Writes `value` as 8 big-endian bytes.
    fn write_uint64(&mut self, value: u64) -> io::Result<()>;
    /// Writes a big-endian `u32` length prefix followed by `buf`.
    fn write_string(&mut self, buf: &[u8]) -> io::Result<()>;
    /// Writes the UTF-8 bytes of `value` as a length-prefixed string.
    fn write_utf8(&mut self, value: &str) -> io::Result<()>;
    /// Writes the big-endian magnitude of `value` as a string, adding a
    /// leading `0x00` when the first magnitude byte has its top bit set.
    fn write_mpint(&mut self, value: &BigNumRef) -> io::Result<()>;
}
impl<W: io::Write + ?Sized> SshWriteExt for W {
    fn write_bool(&mut self, value: bool) -> io::Result<()> {
        // `u8::from(bool)` is exactly 1 or 0.
        self.write_uint8(u8::from(value))
    }

    fn write_uint8(&mut self, value: u8) -> io::Result<()> {
        self.write_all(&[value])
    }

    fn write_uint32(&mut self, value: u32) -> io::Result<()> {
        let buf = Zeroizing::new(value.to_be_bytes());
        self.write_all(&*buf)
    }

    fn write_uint64(&mut self, value: u64) -> io::Result<()> {
        let buf = Zeroizing::new(value.to_be_bytes());
        self.write_all(&*buf)
    }

    /// Writes a big-endian `u32` length prefix followed by `buf`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` if `buf` is longer than `u32::MAX` bytes;
    /// otherwise propagates writer errors.
    fn write_string(&mut self, buf: &[u8]) -> io::Result<()> {
        // The prefix is only 32 bits wide. Truncating the length would emit a
        // prefix that disagrees with the payload and corrupt the rest of the
        // stream, so refuse instead.
        if buf.len() > u32::MAX as usize {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "String Too Long",
            ));
        }
        // Lossless: bounded by the check above.
        self.write_uint32(buf.len() as u32)?;
        self.write_all(buf)
    }

    fn write_utf8(&mut self, value: &str) -> io::Result<()> {
        self.write_string(value.as_bytes())
    }

    /// Writes `value` as a big-number string: the big-endian magnitude, with a
    /// `0x00` pad when its top bit is set so it is not read back as negative.
    /// Zero is written as the empty string.
    ///
    /// # Errors
    ///
    /// Returns `InvalidInput` for a negative `value`. `to_vec` yields only the
    /// magnitude, so a negative value would otherwise be written as its
    /// absolute value; the reading side (`to_bignum`) rejects negatives as
    /// well. Otherwise propagates writer errors.
    fn write_mpint(&mut self, value: &BigNumRef) -> io::Result<()> {
        if value.is_negative() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "Negative Big Number",
            ));
        }
        let magnitude = Zeroizing::new(value.to_vec());
        match magnitude.first() {
            // `to_vec` of zero is empty; indexing byte 1 here used to panic.
            None => self.write_string(&[]),
            Some(&msb) if msb & 0x80 != 0 => {
                let mut padded = Zeroizing::new(Vec::with_capacity(magnitude.len() + 1));
                padded.push(0u8);
                padded.extend_from_slice(&magnitude[..]);
                self.write_string(&padded[..])
            }
            Some(_) => self.write_string(&magnitude[..]),
        }
    }
}