use std::{
array::TryFromSliceError,
convert::TryInto,
fmt::{self, Debug, Formatter},
ops::Deref,
};
use aes::{Aes128, Aes192, Aes256};
use hmac::Hmac;
use rand::{rngs::OsRng, RngCore};
use sha1::Sha1;
use crate::{
packet::SeqNumber,
settings::{KeySettings, KeySize, Passphrase},
};
use super::wrap;
/// A 16-byte cryptographic salt.
///
/// Its first 14 bytes seed the per-packet stream IV (`generate_strean_iv_for`),
/// and its last 8 bytes are the PBKDF2 salt when deriving a [`KeyEncryptionKey`].
/// `Debug` prints the bytes as upper-case hex.
#[derive(Clone, Eq, PartialEq)]
pub struct Salt([u8; 16]);
impl Salt {
    /// Creates a salt filled with 16 bytes from the operating system's RNG.
    pub fn new_random() -> Self {
        let mut bytes = [0u8; 16];
        OsRng.fill_bytes(&mut bytes);
        Salt(bytes)
    }

    /// Builds a salt from exactly 16 bytes.
    ///
    /// # Errors
    /// Returns [`TryFromSliceError`] if `bytes` is not exactly 16 bytes long.
    pub fn try_from(bytes: &[u8]) -> Result<Salt, TryFromSliceError> {
        let array: [u8; 16] = bytes.try_into()?;
        Ok(Salt(array))
    }

    /// Derives the stream initialization vector for `seq_number`.
    ///
    /// The IV starts as the first 14 salt bytes followed by two zero bytes.
    /// The big-endian sequence number is then XOR-ed in starting at byte
    /// offset 10, which covers bytes 10..14 for a 32-bit sequence number.
    pub fn generate_strean_iv_for(&self, seq_number: SeqNumber) -> StreamInitializationVector {
        const SEQ_OFFSET: usize = 10;

        let mut iv = [0u8; 16];
        iv[..14].copy_from_slice(&self.0[..14]);
        for (offset, byte) in seq_number.0.to_be_bytes().iter().enumerate() {
            iv[SEQ_OFFSET + offset] ^= byte;
        }
        StreamInitializationVector(iv)
    }

    /// Borrows the 16 raw salt bytes.
    pub fn as_slice(&self) -> &[u8] {
        &self.0[..]
    }
}
impl Debug for Salt {
    /// Formats as `Salt(0x…)` with the salt bytes in upper-case hex.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let digits = hex::encode_upper(self.0);
        write!(f, "Salt(0x{})", digits)
    }
}
/// A 16-byte per-packet stream initialization vector.
///
/// It is derived from a [`Salt`] and a packet sequence number by
/// `Salt::generate_strean_iv_for`. `Debug` prints it as `StreamIV(0x…)`.
#[derive(Clone, Eq, PartialEq)]
pub struct StreamInitializationVector([u8; 16]);
impl StreamInitializationVector {
pub fn try_from(slice: &[u8]) -> Result<Self, TryFromSliceError> {
Ok(StreamInitializationVector(slice[..].try_into()?))
}
pub fn as_bytes(&self) -> &[u8] {
&self.0
}
}
impl Debug for StreamInitializationVector {
    /// Formats as `StreamIV(0x…)` with the IV bytes in upper-case hex.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let digits = hex::encode_upper(self.0);
        write!(f, "StreamIV(0x{})", digits)
    }
}
/// The 8-byte integrity value recovered while AES-unwrapping keys.
///
/// `KeyEncryptionKey::decrypt_wrapped_keys` returns it as its error when the
/// value does not equal `wrap::DEFAULT_IV`. `Debug` prints it as `KeyIV(0x…)`.
#[derive(Clone, Eq, PartialEq)]
pub struct WrapInitializationVector([u8; 8]);
impl WrapInitializationVector {
pub fn try_from(slice: &[u8]) -> Result<Self, TryFromSliceError> {
Ok(WrapInitializationVector(slice[..].try_into()?))
}
}
impl Debug for WrapInitializationVector {
    /// Formats as `KeyIV(0x…)` with the IV bytes in upper-case hex.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let digits = hex::encode_upper(self.0);
        write!(f, "KeyIV(0x{})", digits)
    }
}
/// An AES key of 128, 192 or 256 bits.
///
/// `Debug` deliberately prints only the variant name, never the key bytes.
#[derive(Clone, Eq, PartialEq)]
pub enum EncryptionKey {
    /// 16-byte (AES-128) key.
    Bytes16([u8; 16]),
    /// 24-byte (AES-192) key.
    Bytes24([u8; 24]),
    /// 32-byte (AES-256) key.
    Bytes32([u8; 32]),
}
impl EncryptionKey {
pub fn new_random(size: KeySize) -> Self {
use EncryptionKey::*;
fn new_key<const N: usize>() -> [u8; N] {
let mut key = [0u8; N];
OsRng.fill_bytes(&mut key[..]);
key
}
match size {
KeySize::AES128 => Bytes16(new_key()),
KeySize::AES192 => Bytes24(new_key()),
KeySize::AES256 => Bytes32(new_key()),
KeySize::Unspecified => Bytes16(new_key()),
}
}
pub fn try_from(bytes: &[u8]) -> Result<EncryptionKey, TryFromSliceError> {
use EncryptionKey::*;
match bytes.len() {
16 => Ok(Bytes16(bytes[..].try_into()?)),
24 => Ok(Bytes24(bytes[..].try_into()?)),
_ => Ok(Bytes32(bytes[..].try_into()?)),
}
}
#[allow(clippy::len_without_is_empty)]
pub fn len(&self) -> usize {
use EncryptionKey::*;
match self {
Bytes16(key) => key.len(),
Bytes24(key) => key.len(),
Bytes32(key) => key.len(),
}
}
pub fn as_bytes(&self) -> &[u8] {
use EncryptionKey::*;
match self {
Bytes16(key) => &key[..],
Bytes24(key) => &key[..],
Bytes32(key) => &key[..],
}
}
}
impl fmt::Debug for EncryptionKey {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
use EncryptionKey::*;
match self {
Bytes16(_) => f.debug_struct("EncryptionKey::Bytes16"),
Bytes24(_) => f.debug_struct("EncryptionKey::Bytes24"),
Bytes32(_) => f.debug_struct("EncryptionKey::Bytes32"),
}
.finish()
}
}
/// Key-encryption key (KEK).
///
/// This is an [`EncryptionKey`] derived from a passphrase with
/// PBKDF2-HMAC-SHA1. It is used to AES-key-wrap and unwrap other keys.
/// `Debug` prints only the variant name, never the key bytes.
#[derive(Clone, Eq, PartialEq)]
pub struct KeyEncryptionKey(EncryptionKey);
impl KeyEncryptionKey {
    /// Derives a key-encryption key from `key_settings.passphrase` and `salt`.
    ///
    /// Runs PBKDF2-HMAC-SHA1 for 2048 rounds over the passphrase. The PBKDF2
    /// salt is the last 8 bytes of `salt`. The output length follows
    /// `key_settings.key_size`, and `KeySize::Unspecified` yields a 16-byte key.
    pub fn new(key_settings: &KeySettings, salt: &Salt) -> Self {
        let key_size = key_settings.key_size;
        let passphrase = &key_settings.passphrase;
        fn calculate_pbkdf2(passphrase: &Passphrase, salt: &Salt, key: &mut [u8]) {
            let salt = salt.0;
            const ROUNDS: u32 = 2048;
            // Only the trailing 8 bytes of the salt are used for PBKDF2.
            let salt_len = usize::min(8, salt.len());
            pbkdf2::pbkdf2::<Hmac<Sha1>>(
                passphrase.as_bytes(),
                &salt[salt.len() - salt_len..],
                ROUNDS,
                &mut *key,
            );
        }
        fn new_key<const N: usize>(passphrase: &Passphrase, salt: &Salt) -> [u8; N] {
            let mut key = [0u8; N];
            calculate_pbkdf2(passphrase, salt, &mut key);
            key
        }
        use EncryptionKey::*;
        let key = match key_size {
            KeySize::AES128 => Bytes16(new_key(passphrase, salt)),
            KeySize::AES192 => Bytes24(new_key(passphrase, salt)),
            KeySize::AES256 => Bytes32(new_key(passphrase, salt)),
            KeySize::Unspecified => Bytes16(new_key(passphrase, salt)),
        };
        KeyEncryptionKey(key)
    }

    /// AES-key-wraps `keys` under this KEK.
    ///
    /// The returned buffer is 8 bytes longer than `keys`; the extra bytes hold
    /// the wrap's integrity block. `wrap::aes_wrap` receives `None` for the IV.
    /// Presumably it then applies its default IV; confirm in `wrap`.
    pub fn encrypt_wrapped_keys(&self, keys: &[u8]) -> Vec<u8> {
        let mut encrypted_keys = vec![0; keys.len() + 8];
        use aes::NewBlockCipher;
        use EncryptionKey::*;
        match &self.0 {
            Bytes16(key) => wrap::aes_wrap(
                &Aes128::new(key[..].into()),
                None,
                &mut encrypted_keys,
                keys,
            ),
            Bytes24(key) => wrap::aes_wrap(
                &Aes192::new(key[..].into()),
                None,
                &mut encrypted_keys,
                keys,
            ),
            Bytes32(key) => wrap::aes_wrap(
                &Aes256::new(key[..].into()),
                None,
                &mut encrypted_keys,
                keys,
            ),
        }
        encrypted_keys
    }

    /// AES-unwraps `wrapped_keys` with this KEK and returns the plaintext keys.
    ///
    /// On success the result is 8 bytes shorter than `wrapped_keys`.
    ///
    /// # Errors
    /// Returns the recovered integrity value as a [`WrapInitializationVector`]
    /// when it differs from `wrap::DEFAULT_IV`. This typically means the KEK is
    /// wrong or the input is corrupted. Input that cannot be key-wrap
    /// ciphertext (shorter than 8 bytes, or not a whole number of 8-byte
    /// blocks) is rejected with an all-zero `WrapInitializationVector` instead
    /// of panicking.
    pub fn decrypt_wrapped_keys(
        &self,
        wrapped_keys: &[u8],
    ) -> Result<Vec<u8>, WrapInitializationVector> {
        use aes::NewBlockCipher;
        use EncryptionKey::*;
        // Key-wrap ciphertext is an 8-byte integrity block followed by whole
        // 8-byte key blocks. This data may come from a peer, so malformed
        // lengths are rejected here. A plain `len() - 8` would underflow and
        // panic on inputs shorter than 8 bytes.
        let keys_len = match wrapped_keys.len().checked_sub(8) {
            Some(len) if wrapped_keys.len() % 8 == 0 => len,
            _ => return Err(WrapInitializationVector([0; 8])),
        };
        let mut keys = vec![0; keys_len];
        let mut iv = [0; 8];
        match &self.0 {
            Bytes16(key) => wrap::aes_unwrap(
                &Aes128::new(key[..].into()),
                &mut iv,
                &mut keys,
                wrapped_keys,
            ),
            Bytes24(key) => wrap::aes_unwrap(
                &Aes192::new(key[..].into()),
                &mut iv,
                &mut keys,
                wrapped_keys,
            ),
            Bytes32(key) => wrap::aes_unwrap(
                &Aes256::new(key[..].into()),
                &mut iv,
                &mut keys,
                wrapped_keys,
            ),
        }
        // The recovered integrity block must match the expected default IV;
        // otherwise the unwrapped bytes cannot be trusted.
        if iv != wrap::DEFAULT_IV {
            return Err(WrapInitializationVector(iv));
        }
        Ok(keys)
    }
}
impl Deref for KeyEncryptionKey {
type Target = EncryptionKey;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl fmt::Debug for KeyEncryptionKey {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
use EncryptionKey::*;
match &self.0 {
Bytes16(_) => f.debug_struct("KeyEncryptionKey::Bytes16"),
Bytes24(_) => f.debug_struct("KeyEncryptionKey::Bytes24"),
Bytes32(_) => f.debug_struct("KeyEncryptionKey::Bytes32"),
}
.finish()
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Decodes a hex literal into raw bytes, panicking on malformed input.
    fn unhex(digits: &[u8]) -> Vec<u8> {
        hex::decode(digits).unwrap()
    }

    #[test]
    fn kek_generate() {
        let salt = Salt::try_from(&unhex(b"7D59759C2B1A3F0B06C7028790C81C7D")).unwrap();
        let key_settings = KeySettings {
            key_size: KeySize::AES128,
            passphrase: "password123".into(),
        };
        let kek = KeyEncryptionKey::new(&key_settings, &salt);

        assert_eq!(
            kek.0.as_bytes(),
            &unhex(b"08F2758F41E4244D00057C9CEBEB95FC")[..]
        );
        assert_eq!(format!("{:?}", kek), "KeyEncryptionKey::Bytes16");
        assert_eq!(format!("{:?}", *kek), "EncryptionKey::Bytes16");
        assert_ne!(Salt::new_random(), Salt::new_random());
    }

    #[test]
    fn generate_iv() {
        let salt = Salt::try_from(&unhex(b"87647f8a2361fb1a9e692de576985949")).unwrap();
        let expected_iv =
            StreamInitializationVector::try_from(&unhex(b"87647f8a2361fb1a9e6907af1b810000"))
                .unwrap();

        let iv = salt.generate_strean_iv_for(SeqNumber(709520665));

        assert_eq!(iv, expected_iv);
        assert_eq!(
            format!("{:?}", iv),
            "StreamIV(0x87647F8A2361FB1A9E6907AF1B810000)"
        );
        assert_eq!(
            format!("{:?}", salt),
            "Salt(0x87647F8A2361FB1A9E692DE576985949)"
        );
        assert_ne!(Salt::new_random(), Salt::new_random());
    }
}