#![cfg_attr(fuzzing, allow(dead_code))]
#![cfg_attr(fuzzing, allow(unused_variables))]
#[allow(unused_imports)]
use {
crate::error::{Error, Result, TrapBug},
log::{debug, error, info, log, trace, warn},
};
use core::fmt;
use core::fmt::Debug;
use core::num::Wrapping;
use aes::cipher::{BlockSizeUser, KeyIvInit, KeySizeUser, StreamCipher};
use hmac::Mac;
use sha2::Digest as Sha2DigestForTrait;
use zeroize::ZeroizeOnDrop;
use crate::*;
use kex::{self, SessId};
use ssh_chapoly::SSHChaPoly;
use sshnames::*;
/// AES-256 in counter mode, used for the `aes256-ctr` cipher.
///
/// RFC 4344 §4 treats the entire 128-bit cipher block as a big-endian
/// counter that is incremented modulo 2^128, so the 128-bit counter flavour
/// is required. A 32-bit flavour only increments (and wraps) the last four
/// bytes of the IV without carrying into the rest of the block. That diverges
/// from other SSH implementations as soon as those bytes overflow, which,
/// with a random IV, can happen after just a few blocks.
///
/// The alias keeps its historical name so existing uses in this module are
/// unchanged; despite the name it is NOT a 32-bit counter.
type Aes256Ctr32BE = ctr::Ctr128BE<aes::Aes256>;
/// HMAC with SHA-256, used for the `hmac-sha2-256` integrity algorithm.
type HmacSha256 = hmac::Hmac<sha2::Sha256>;
/// Minimum number of random padding bytes in every packet.
const SSH_MIN_PADLEN: usize = 4;
/// Padding alignment used when the cipher has no larger block size
/// (ChaPoly and cleartext).
const SSH_MIN_BLOCK: usize = 8;
/// Size of the big-endian `packet_length` field at the start of a packet.
pub const SSH_LENGTH_SIZE: usize = 4;
/// Offset of the payload: after `packet_length` and the padding-length byte.
pub const SSH_PAYLOAD_START: usize = SSH_LENGTH_SIZE + 1;
/// Size of the scratch buffer that receives derived IVs.
const MAX_IV_LEN: usize = 32;
/// Size of the scratch buffer that receives derived cipher/MAC keys.
const MAX_KEY_LEN: usize = 64;
/// Transport keys for both directions plus the packet sequence numbers of
/// one connection.
#[derive(Debug)]
pub(crate) struct KeyState {
/// Keys used to pad, MAC and encrypt outgoing packets.
enc: KeysSend,
/// Keys used to decrypt and verify incoming packets.
dec: KeysRecv,
/// Sequence number of the next outgoing packet; wraps modulo 2^32.
pub seq_encrypt: Wrapping<u32>,
/// Sequence number of the next incoming packet; wraps modulo 2^32.
pub seq_decrypt: Wrapping<u32>,
/// Strict kex mode. It can only be switched on during the first key exchange;
/// while set, both sequence numbers restart at zero on every rekey.
strict_kex: bool,
/// Set once keys have been installed for the first time.
done_first_kex: bool,
}
impl KeyState {
    /// Creates the pre-kex state: no cipher, no MAC and both sequence numbers
    /// at zero.
    pub fn new_cleartext() -> Self {
        KeyState {
            enc: KeysSend::new_cleartext(),
            dec: KeysRecv::new_cleartext(),
            seq_encrypt: Wrapping(0),
            seq_decrypt: Wrapping(0),
            strict_kex: false,
            done_first_kex: false,
        }
    }

    /// Returns `true` while outgoing packets are sent without encryption.
    pub fn is_send_cleartext(&self) -> bool {
        matches!(self.enc.cipher, EncKey::NoCipher)
    }

    /// Installs new keys for the send direction.
    ///
    /// `enable_strict` is only honoured on the first key exchange. Once strict
    /// kex is active, the outgoing sequence number restarts at zero on every
    /// rekey.
    pub fn rekey_send(&mut self, keys: KeysSend, enable_strict: bool) {
        self.enc = keys;
        if enable_strict && !self.done_first_kex {
            self.strict_kex = true;
        }
        self.done_first_kex = true;
        if self.strict_kex {
            self.seq_encrypt = Wrapping(0);
        }
    }

    /// Installs new keys for the receive direction.
    ///
    /// Must follow [`KeyState::rekey_send`] for the same exchange. The
    /// strict-kex decision made there also decides whether the incoming
    /// sequence number restarts at zero here.
    pub fn rekey_recv(&mut self, keys: KeysRecv) {
        debug_assert!(
            !matches!(self.enc.cipher, EncKey::NoCipher),
            "Should have already performed rekey_send"
        );
        self.dec = keys;
        self.done_first_kex = true;
        if self.strict_kex {
            self.seq_decrypt = Wrapping(0);
        }
    }

    /// Sequence number the next incoming packet will be processed with.
    pub fn recv_seq(&self) -> u32 {
        self.seq_decrypt.0
    }

    /// Decrypts the first block of an incoming packet in place and returns
    /// the packet's total size on the wire (length field, body and MAC/tag).
    ///
    /// Does not advance the receive sequence number.
    pub fn decrypt_first_block(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
        self.dec.decrypt_first_block(buf, self.seq_decrypt.0)
    }

    /// Decrypts and verifies a complete packet in place, returning the
    /// payload length.
    ///
    /// The receive sequence number advances even when this fails.
    pub fn decrypt(&mut self, buf: &mut [u8]) -> Result<usize, Error> {
        let e = self.dec.decrypt(buf, self.seq_decrypt.0);
        self.seq_decrypt += 1;
        e
    }

    /// Pads, MACs and encrypts the packet whose `payload_len`-byte payload is
    /// already placed at `SSH_PAYLOAD_START` in `buf`. Returns the number of
    /// bytes written.
    ///
    /// The send sequence number advances even when this fails.
    pub fn encrypt(
        &mut self,
        payload_len: usize,
        buf: &mut [u8],
    ) -> Result<usize, Error> {
        let e = self.enc.encrypt(payload_len, buf, self.seq_encrypt.0);
        self.seq_encrypt += 1;
        e
    }

    /// Block size of the receive cipher, i.e. the minimum number of bytes
    /// `decrypt_first_block` needs.
    pub fn size_block_dec(&self) -> usize {
        self.dec.cipher.size_block()
    }

    /// Largest payload whose finished outgoing packet fits in `total_avail`
    /// bytes.
    pub fn max_enc_payload(&self, total_avail: usize) -> usize {
        self.enc.max_enc_payload(total_avail)
    }
}
/// Keys for the outgoing direction: a cipher and a MAC.
#[derive(Debug)]
pub(crate) struct KeysSend {
/// Cipher applied to outgoing packets.
cipher: EncKey,
/// MAC applied to outgoing packets (`ChaPoly` when the AEAD supplies the
/// tag, `NoInteg` in cleartext).
integ: IntegKey,
}
impl KeysSend {
fn new_cleartext() -> Self {
Self { cipher: EncKey::NoCipher, integ: IntegKey::NoInteg }
}
pub fn new<CS: CliServ>(
kex_out: &kex::KexOutput,
sess_id: &SessId,
algos: &kex::Algos<CS>,
) -> Self {
let mut key = [0u8; MAX_KEY_LEN];
let mut iv = [0u8; MAX_IV_LEN];
let [iv_e, k_e, i_e] =
if CS::is_client() { ['A', 'C', 'E'] } else { ['B', 'D', 'F'] };
let cipher = {
let ci = kex_out.compute_key(
iv_e,
algos.cipher_enc.iv_len(),
&mut iv,
sess_id,
);
let ck = kex_out.compute_key(
k_e,
algos.cipher_enc.key_len(),
&mut key,
sess_id,
);
EncKey::from_cipher(&algos.cipher_enc, ck, ci).unwrap()
};
let integ = {
let ck = kex_out.compute_key(
i_e,
algos.integ_enc.key_len(),
&mut key,
sess_id,
);
IntegKey::from_integ(&algos.integ_enc, ck).unwrap()
};
Self { cipher, integ }
}
fn calc_encrypt_pad(&self, payload_len: usize) -> usize {
let size_block = self.cipher.size_block();
let len = 1
+ payload_len
+ if self.cipher.is_aead() { 0 } else { SSH_LENGTH_SIZE };
let mut padlen = size_block - len % size_block;
if padlen < SSH_MIN_PADLEN {
padlen += size_block
}
padlen
}
pub fn max_enc_payload(&self, total_avail: usize) -> usize {
let total_avail = total_avail.saturating_sub(self.integ.size_out());
let overhead = SSH_LENGTH_SIZE + 1 + SSH_MIN_PADLEN;
let mut space = total_avail;
let enc_len = if self.cipher.is_aead() {
total_avail.saturating_sub(SSH_LENGTH_SIZE)
} else {
total_avail
};
let extra_block = enc_len % self.cipher.size_block();
if extra_block != 0 {
space = space.saturating_sub(extra_block);
}
space.saturating_sub(overhead)
}
fn encrypt(
&mut self,
payload_len: usize,
buf: &mut [u8],
seq: u32,
) -> Result<usize, Error> {
let size_block = self.cipher.size_block();
let size_integ = self.integ.size_out();
let padlen = self.calc_encrypt_pad(payload_len);
let len = SSH_LENGTH_SIZE + 1 + payload_len + padlen;
if self.cipher.is_aead() {
debug_assert_eq!((len - SSH_LENGTH_SIZE) % size_block, 0);
} else {
debug_assert_eq!(len % size_block, 0);
};
if len + size_integ > buf.len() {
error!("Output buffer {} is too small for packet", buf.len());
return error::NoRoom.fail();
}
let blen = ((len - SSH_LENGTH_SIZE) as u32).to_be_bytes();
buf[..SSH_LENGTH_SIZE].copy_from_slice(&blen);
buf[SSH_LENGTH_SIZE] = padlen as u8;
let pad_start = SSH_LENGTH_SIZE + 1 + payload_len;
debug_assert_eq!(pad_start + padlen, len);
random::fill_random(&mut buf[pad_start..pad_start + padlen])?;
let (enc, rest) = buf.split_at_mut(len);
let (mac, _) = rest.split_at_mut(size_integ);
match self.integ {
IntegKey::ChaPoly => {}
IntegKey::NoInteg => {}
IntegKey::HmacSha256(k) => {
let mut h = HmacSha256::new_from_slice(&k).trap()?;
h.update(&seq.to_be_bytes());
h.update(enc);
let result = h.finalize();
mac.copy_from_slice(&result.into_bytes());
}
}
match &mut self.cipher {
EncKey::ChaPoly(k) => k.encrypt(seq, enc, mac).trap()?,
EncKey::Aes256Ctr(a) => {
a.apply_keystream(enc);
}
EncKey::NoCipher => {}
}
Ok(len + size_integ)
}
}
/// Keys for the incoming direction: a cipher and a MAC.
#[derive(Debug)]
pub(crate) struct KeysRecv {
/// Cipher used to decrypt incoming packets.
cipher: DecKey,
/// MAC used to verify incoming packets (`ChaPoly` when the AEAD tag is
/// used, `NoInteg` in cleartext).
integ: IntegKey,
}
impl KeysRecv {
fn new_cleartext() -> Self {
Self { cipher: DecKey::NoCipher, integ: IntegKey::NoInteg }
}
pub fn new<CS: CliServ>(
kex_out: &kex::KexOutput,
sess_id: &SessId,
algos: &kex::Algos<CS>,
) -> Self {
let mut key = [0u8; MAX_KEY_LEN];
let mut iv = [0u8; MAX_IV_LEN];
let [iv_d, k_d, i_d] =
if CS::is_client() { ['B', 'D', 'F'] } else { ['A', 'C', 'E'] };
let cipher = {
let ci = kex_out.compute_key(
iv_d,
algos.cipher_dec.iv_len(),
&mut iv,
sess_id,
);
let ck = kex_out.compute_key(
k_d,
algos.cipher_dec.key_len(),
&mut key,
sess_id,
);
DecKey::from_cipher(&algos.cipher_dec, ck, ci).unwrap()
};
let integ = {
let ck = kex_out.compute_key(
i_d,
algos.integ_dec.key_len(),
&mut key,
sess_id,
);
IntegKey::from_integ(&algos.integ_dec, ck).unwrap()
};
Self { cipher, integ }
}
fn decrypt_first_block(
&mut self,
buf: &mut [u8],
seq: u32,
) -> Result<usize, Error> {
if buf.len() < self.cipher.size_block() {
return Err(Error::bug());
}
#[cfg(fuzzing)]
let len = u32::from_be_bytes(buf[..SSH_LENGTH_SIZE].try_into().unwrap());
#[cfg(not(fuzzing))]
let len = match &mut self.cipher {
DecKey::ChaPoly(k) => k.packet_length(seq, buf).trap()?,
DecKey::Aes256Ctr(a) => {
a.apply_keystream(&mut buf[..16]);
u32::from_be_bytes(buf[..SSH_LENGTH_SIZE].try_into().unwrap())
}
DecKey::NoCipher => {
u32::from_be_bytes(buf[..SSH_LENGTH_SIZE].try_into().unwrap())
}
};
let total_len = len
.checked_add((SSH_LENGTH_SIZE + self.integ.size_out()) as u32)
.ok_or(Error::BadDecrypt)?;
Ok(total_len as usize)
}
fn decrypt(&mut self, buf: &mut [u8], seq: u32) -> Result<usize, Error> {
let size_block = self.cipher.size_block();
let size_integ = self.integ.size_out();
if buf.len() < size_block + size_integ {
debug!("Bad packet, {} smaller than block size", buf.len());
return error::SSHProto.fail();
}
let sublength = if self.cipher.is_aead() { SSH_LENGTH_SIZE } else { 0 };
let len = buf.len() - size_integ - sublength;
if !len.is_multiple_of(size_block) {
debug!("Bad packet, not multiple of block size");
return error::SSHProto.fail();
}
let (data, mac) = buf.split_at_mut(buf.len() - size_integ);
debug_assert!(data.len() >= size_block);
#[cfg(not(fuzzing))]
match &mut self.cipher {
DecKey::ChaPoly(k) => {
k.decrypt(seq, data, mac).map_err(|_| Error::BadDecrypt)?;
}
DecKey::Aes256Ctr(a) => {
a.apply_keystream(&mut data[16..]);
}
DecKey::NoCipher => {}
}
#[cfg(not(fuzzing))]
match self.integ {
IntegKey::ChaPoly => {}
IntegKey::NoInteg => {}
IntegKey::HmacSha256(k) => {
let mut h = HmacSha256::new_from_slice(&k).trap()?;
h.update(&seq.to_be_bytes());
h.update(data);
h.verify_slice(mac).map_err(|_| Error::BadDecrypt)?;
}
}
let padlen = data[SSH_LENGTH_SIZE] as usize;
if padlen < SSH_MIN_PADLEN {
debug!("Packet padding too short");
return error::SSHProto.fail();
}
let payload_len = buf
.len()
.checked_sub(SSH_LENGTH_SIZE + 1 + size_integ + padlen)
.ok_or_else(|| {
debug!("Bad padding length");
error::SSHProto.build()
})?;
Ok(payload_len)
}
}
/// A negotiated cipher algorithm, before any key material is attached.
#[derive(Debug, Clone)]
pub(crate) enum Cipher {
/// ChaCha20-Poly1305 AEAD (`SSH_NAME_CHAPOLY`); it also supplies the
/// integrity tag (see `Cipher::integ`).
ChaPoly,
/// AES-256-CTR (`SSH_NAME_AES256_CTR`); needs a separately negotiated MAC.
Aes256Ctr,
}
impl fmt::Display for Cipher {
    /// Writes the SSH algorithm name of this cipher.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Aes256Ctr => SSH_NAME_AES256_CTR,
            Self::ChaPoly => SSH_NAME_CHAPOLY,
        })
    }
}
impl Cipher {
pub fn from_name(name: &'static str) -> Result<Self, Error> {
match name {
SSH_NAME_CHAPOLY => Ok(Cipher::ChaPoly),
SSH_NAME_AES256_CTR => Ok(Cipher::Aes256Ctr),
_ => Err(Error::bug()),
}
}
pub fn key_len(&self) -> usize {
match self {
Cipher::ChaPoly => SSHChaPoly::KEY_LEN,
Cipher::Aes256Ctr => aes::Aes256::key_size(),
}
}
pub fn iv_len(&self) -> usize {
match self {
Cipher::ChaPoly => 0,
Cipher::Aes256Ctr => aes::Aes256::block_size(),
}
}
pub fn integ(&self) -> Option<Integ> {
match self {
Cipher::ChaPoly => Some(Integ::ChaPoly),
Cipher::Aes256Ctr => None,
}
}
}
/// Keyed cipher state for the send direction.
///
/// Derives `ZeroizeOnDrop`, so the key state is wiped when the value is
/// dropped. `Debug` is implemented by hand to print only the variant name.
#[derive(Clone, ZeroizeOnDrop)]
pub(crate) enum EncKey {
/// ChaCha20-Poly1305 AEAD state.
ChaPoly(SSHChaPoly),
/// AES-256-CTR keystream state.
Aes256Ctr(Aes256Ctr32BE),
/// No encryption (the state before the first key exchange).
NoCipher,
}
impl Debug for EncKey {
    /// Prints only the variant name; the keyed cipher state is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant = match self {
            Self::NoCipher => "NoCipher",
            Self::ChaPoly(_) => "ChaPoly",
            Self::Aes256Ctr(_) => "Aes256Ctr",
        };
        f.write_str("EncKey::")?;
        f.write_str(variant)
    }
}
impl EncKey {
    /// Builds send-direction cipher state for `cipher` from derived `key`
    /// and `iv` bytes.
    ///
    /// `iv` is only used by AES-256-CTR. Length errors from the underlying
    /// constructors are converted with `trap()`.
    pub fn from_cipher<'a>(
        cipher: &Cipher,
        key: &'a [u8],
        iv: &'a [u8],
    ) -> Result<Self, Error> {
        let enc = match cipher {
            Cipher::ChaPoly => {
                let state = SSHChaPoly::new_from_slice(key).trap()?;
                EncKey::ChaPoly(state)
            }
            Cipher::Aes256Ctr => {
                let state = Aes256Ctr32BE::new_from_slices(key, iv).trap()?;
                EncKey::Aes256Ctr(state)
            }
        };
        Ok(enc)
    }

    /// Whether the cipher is an AEAD that supplies its own integrity tag.
    pub fn is_aead(&self) -> bool {
        matches!(self, EncKey::ChaPoly(_))
    }

    /// Block size used for padding alignment.
    pub fn size_block(&self) -> usize {
        match self {
            EncKey::Aes256Ctr(_) => aes::Aes256::block_size(),
            EncKey::ChaPoly(_) | EncKey::NoCipher => SSH_MIN_BLOCK,
        }
    }
}
/// Keyed cipher state for the receive direction.
///
/// Derives `ZeroizeOnDrop`, so the key state is wiped when the value is
/// dropped. `Debug` is implemented by hand to print only the variant name.
#[derive(Clone, ZeroizeOnDrop)]
pub(crate) enum DecKey {
/// ChaCha20-Poly1305 AEAD state.
ChaPoly(SSHChaPoly),
/// AES-256-CTR keystream state.
Aes256Ctr(Aes256Ctr32BE),
/// No decryption (the state before the first key exchange).
NoCipher,
}
impl Debug for DecKey {
    /// Prints only the variant name; the keyed cipher state is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant = match self {
            Self::NoCipher => "NoCipher",
            Self::ChaPoly(_) => "ChaPoly",
            Self::Aes256Ctr(_) => "Aes256Ctr",
        };
        f.write_str("DecKey::")?;
        f.write_str(variant)
    }
}
impl DecKey {
    /// Builds receive-direction cipher state for `cipher` from derived `key`
    /// and `iv` bytes.
    ///
    /// `iv` is only used by AES-256-CTR. Length errors from the underlying
    /// constructors are converted with `trap()`.
    pub fn from_cipher<'a>(
        cipher: &Cipher,
        key: &'a [u8],
        iv: &'a [u8],
    ) -> Result<Self, Error> {
        let dec = match cipher {
            Cipher::ChaPoly => {
                let state = SSHChaPoly::new_from_slice(key).trap()?;
                DecKey::ChaPoly(state)
            }
            Cipher::Aes256Ctr => {
                let state = Aes256Ctr32BE::new_from_slices(key, iv).trap()?;
                DecKey::Aes256Ctr(state)
            }
        };
        Ok(dec)
    }

    /// Whether the cipher is an AEAD that supplies its own integrity tag.
    pub fn is_aead(&self) -> bool {
        matches!(self, DecKey::ChaPoly(_))
    }

    /// Block size used for packet alignment checks.
    pub fn size_block(&self) -> usize {
        match self {
            DecKey::Aes256Ctr(_) => aes::Aes256::block_size(),
            DecKey::ChaPoly(_) | DecKey::NoCipher => SSH_MIN_BLOCK,
        }
    }
}
/// A negotiated integrity (MAC) algorithm, before any key is attached.
#[derive(Debug, Clone)]
pub(crate) enum Integ {
/// Integrity provided by the ChaPoly AEAD tag; no separate key.
ChaPoly,
/// HMAC-SHA2-256 (`SSH_NAME_HMAC_SHA256`).
HmacSha256,
}
impl Integ {
pub fn from_name(name: &'static str) -> Result<Self, Error> {
match name {
SSH_NAME_HMAC_SHA256 => Ok(Integ::HmacSha256),
_ => Err(Error::bug()),
}
}
fn key_len(&self) -> usize {
match self {
Integ::ChaPoly => 0,
Integ::HmacSha256 => 32,
}
}
}
impl fmt::Display for Integ {
    /// Writes the SSH algorithm name of this integrity algorithm.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::HmacSha256 => SSH_NAME_HMAC_SHA256,
            Self::ChaPoly => SSH_NAME_CHAPOLY,
        };
        f.write_str(name)
    }
}
#[derive(Clone)]
pub(crate) enum IntegKey {
ChaPoly,
HmacSha256([u8; 32]),
NoInteg,
}
impl Debug for IntegKey {
    /// Prints only the variant name; HMAC key bytes are never formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("IntegKey::")?;
        f.write_str(match self {
            Self::NoInteg => "NoInteg",
            Self::ChaPoly => "ChaPoly",
            Self::HmacSha256(_) => "HmacSha256",
        })
    }
}
impl IntegKey {
    /// Instantiates a MAC key for `integ` from the derived `key` bytes.
    ///
    /// `Integ::ChaPoly` ignores `key`. For HMAC-SHA2-256, `key` must be
    /// exactly 32 bytes; any other length is converted with `trap()`.
    pub fn from_integ(integ: &Integ, key: &[u8]) -> Result<Self, Error> {
        let integ_key = match integ {
            Integ::HmacSha256 => {
                let k: [u8; 32] = key.try_into().trap()?;
                IntegKey::HmacSha256(k)
            }
            Integ::ChaPoly => IntegKey::ChaPoly,
        };
        Ok(integ_key)
    }

    /// Number of MAC/tag bytes appended to each packet.
    pub fn size_out(&self) -> usize {
        match self {
            Self::HmacSha256(_) => sha2::Sha256::output_size(),
            Self::ChaPoly => SSHChaPoly::TAG_LEN,
            Self::NoInteg => 0,
        }
    }
}
#[cfg(test)]
mod tests {
    use core::marker::PhantomData;

    use crate::encrypt::*;
    use crate::error::Error;
    use crate::kex::KexOutput;
    use crate::sshnames::SSH_NAME_CURVE25519;
    use crate::sunsetlog::*;
    #[allow(unused_imports)]
    use pretty_hex::PrettyHex;
    use sha2::Sha256;

    /// Sends packets carrying 0..80-byte payloads through `keys_enc` and
    /// `keys_dec`, checking each payload round-trips unchanged.
    ///
    /// With `corrupt` set, one byte after the header of the first packet is
    /// flipped. Decryption must then fail with `Error::BadDecrypt`, and the
    /// function returns without sending further packets.
    fn do_roundtrips(
        keys_enc: &mut KeyState,
        keys_dec: &mut KeyState,
        corrupt: bool,
    ) {
        for i in 0usize..80 {
            let mut v: std::vec::Vec<u8> = (0u8..i as u8 + 60).collect();
            let orig_payload =
                v[SSH_PAYLOAD_START..SSH_PAYLOAD_START + i].to_vec();
            let written = keys_enc.encrypt(i, v.as_mut_slice()).unwrap();
            v.truncate(written);
            if corrupt {
                v[SSH_PAYLOAD_START] ^= 4;
            }
            let l = keys_dec.decrypt_first_block(v.as_mut_slice()).unwrap();
            assert_eq!(l, v.len());
            let dec = keys_dec.decrypt(v.as_mut_slice());
            if corrupt {
                assert!(matches!(dec, Err(Error::BadDecrypt)));
                return;
            }
            let payload_len = dec.unwrap();
            assert_eq!(payload_len, i);
            let dec_payload =
                v[SSH_PAYLOAD_START..SSH_PAYLOAD_START + i].to_vec();
            assert_eq!(orig_payload, dec_payload);
        }
    }

    #[test]
    fn roundtrip_nocipher() {
        let mut ke = KeyState::new_cleartext();
        let mut kd = KeyState::new_cleartext();
        do_roundtrips(&mut ke, &mut kd, false);
    }

    /// Cleartext has no MAC, so corruption goes undetected and the
    /// `BadDecrypt` assertion inside `do_roundtrips` panics.
    #[test]
    #[should_panic]
    fn roundtrip_nocipher_corrupt() {
        let mut ke = KeyState::new_cleartext();
        let mut kd = KeyState::new_cleartext();
        do_roundtrips(&mut ke, &mut kd, true);
    }

    /// Client-side algorithm sets to exercise, followed by `None` for
    /// cleartext.
    fn algo_combos() -> impl Iterator<Item = Option<kex::Algos<Client>>> {
        const COMBOS: [(Cipher, Integ, Cipher, Integ); 4] = [
            (
                Cipher::Aes256Ctr,
                Integ::HmacSha256,
                Cipher::Aes256Ctr,
                Integ::HmacSha256,
            ),
            (Cipher::ChaPoly, Integ::ChaPoly, Cipher::ChaPoly, Integ::ChaPoly),
            (Cipher::Aes256Ctr, Integ::HmacSha256, Cipher::ChaPoly, Integ::ChaPoly),
            (Cipher::ChaPoly, Integ::ChaPoly, Cipher::Aes256Ctr, Integ::HmacSha256),
        ];
        COMBOS
            .iter()
            .map(|(ce, ie, cd, id)| {
                Some(kex::Algos {
                    kex: kex::SharedSecret::from_name(SSH_NAME_CURVE25519).unwrap(),
                    hostsig: sign::SigType::Ed25519,
                    cipher_enc: ce.clone(),
                    cipher_dec: cd.clone(),
                    integ_enc: ie.clone(),
                    integ_dec: id.clone(),
                    discard_next: false,
                    send_ext_info: true,
                    strict_kex: false,
                    _cs: PhantomData,
                })
            })
            .chain(core::iter::once(None))
    }

    #[test]
    fn algo_roundtrips() {
        init_test_log();
        for algos in algo_combos() {
            let mut keys_enc = KeyState::new_cleartext();
            let mut keys_dec = KeyState::new_cleartext();
            // Captured before `algos` is consumed below; inspecting `algos`
            // afterwards would always see `None` and skip the corruption
            // check.
            let keyed = algos.is_some();
            if let Some(algos) = algos {
                let h = SessId::from_slice(&Sha256::digest(
                    "some exchange hash".as_bytes(),
                ))
                .unwrap();
                let sess_id =
                    SessId::from_slice(&Sha256::digest("some sessid".as_bytes()))
                        .unwrap();
                let sharedkey = b"hello";
                let ko = KexOutput::new_test(sharedkey, &h);
                let ko_b = KexOutput::new_test(sharedkey, &h);
                trace!("algos enc {algos:?}");
                let enc = KeysSend::new(&ko, &sess_id, &algos);
                keys_enc.rekey_send(enc, algos.strict_kex);
                let algos = algos.test_swap_to_server();
                trace!("algos dec {algos:?}");
                let e = KeysSend::new(&ko, &sess_id, &algos);
                keys_dec.rekey_send(e, algos.strict_kex);
                let dec = KeysRecv::new(&ko_b, &sess_id, &algos);
                keys_dec.rekey_recv(dec);
            } else {
                trace!("Trying cleartext");
            }
            do_roundtrips(&mut keys_enc, &mut keys_dec, false);
            // Corruption is only detectable once a MAC or AEAD tag is in use.
            if keyed {
                do_roundtrips(&mut keys_enc, &mut keys_dec, true);
            }
        }
    }

    #[test]
    fn max_enc_payload() {
        init_test_log();
        for algos in algo_combos() {
            let mut keys = KeyState::new_cleartext();
            if let Some(algos) = algos {
                let h = SessId::from_slice(&Sha256::digest(b"some exchange hash"))
                    .unwrap();
                let sess_id =
                    SessId::from_slice(&Sha256::digest(b"some sessid")).unwrap();
                let sharedkey = b"hello";
                let ko = KexOutput::new_test(sharedkey, &h);
                let enc = KeysSend::new(&ko, &sess_id, &algos);
                let dec = KeysRecv::new(&ko, &sess_id, &algos);
                keys.rekey_send(enc, algos.strict_kex);
                keys.rekey_recv(dec);
                trace!("algos {algos:?}");
                trace!("integ {}", keys.enc.integ.size_out());
            } else {
                trace!("cleartext");
            }
            let mut buf = [0u8; 100];
            for i in 1..80 {
                let p = keys.max_enc_payload(i);
                trace!("i {i} p {p}");
                if p > 0 {
                    let l = keys.encrypt(p, &mut buf).unwrap();
                    trace!("i {i} p {p} l {l}");
                    assert!(l <= i);
                    assert!(l >= i.saturating_sub(keys.enc.cipher.size_block()));
                    let l = keys.encrypt(p + 1, &mut buf).unwrap();
                    assert!(l > i);
                }
            }
        }
    }
}