use argon2::Argon2;
use blake3::Hasher;
use chacha20::{
cipher::{generic_array::GenericArray, NewCipher, StreamCipher},
XChaCha20, XNonce,
};
use rand::{Rng, RngCore};
use redacted::RedactedBytes;
use serde::{Deserialize, Serialize};
use snafu::{ensure, ResultExt};
use zeroize::{Zeroize, Zeroizing};
use crate::{
crypto::types::{CipherText, ClearText},
error::{Argon2Failure, BackendError, BadHMAC},
};
/// Access to the pair of 32-byte symmetric keys carried by a key type
/// (implemented below by [`RootKey`] and [`DerivedKey`]).
pub trait Key {
    /// Returns the key in the form accepted by the `chacha20` crate's ciphers.
    fn encryption_key(&self) -> &chacha20::Key;
    /// Returns the 32-byte key used for message authentication.
    // NOTE(review): presumably consumed as a BLAKE3 `keyed_hash` key by callers such as
    // `ClearText::encrypt` — confirm against `crypto::types`.
    fn hmac_key(&self) -> &[u8; 32];
}
/// A 24-byte nonce, sized for XChaCha20 (`XNonce`).
#[derive(Debug, Hash, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Nonce(pub(crate) RedactedBytes<24>);
impl Nonce {
    /// Creates a nonce whose 24 bytes are drawn from the thread-local RNG.
    pub fn random() -> Self {
        let mut rng = rand::thread_rng();
        let mut bytes = [0_u8; 24];
        rng.fill(&mut bytes[..]);
        let inner: RedactedBytes<24> = bytes.into();
        Self(inner)
    }
    /// Borrows the nonce bytes as the `XNonce` type expected by `XChaCha20::new`.
    pub fn nonce(&self) -> &XNonce {
        let raw: &[u8] = &self.0[..];
        GenericArray::from_slice(raw)
    }
}
/// Top-level key material from which all [`DerivedKey`]s are produced.
///
/// It holds a 32-byte encryption key, a 32-byte HMAC key, and 256 bytes of entropy. The
/// struct is zeroized when dropped (`#[zeroize(drop)]`).
///
/// Field order and names are significant: they feed the derived `Hash` and the serde
/// (CBOR) representation that [`RootKey::encrypt`] produces.
#[derive(Hash, Clone, Serialize, Deserialize, Zeroize)]
#[zeroize(drop)]
pub struct RootKey {
    // Returned by `Key::encryption_key`.
    encryption: RedactedBytes<32>,
    // Returned by `Key::hmac_key`.
    hmac: RedactedBytes<32>,
    // Input keying material hashed in `derive_with_context`.
    entropy: RedactedBytes<256>,
}
impl RootKey {
    /// Generates a root key whose encryption, HMAC and entropy fields are all filled from
    /// the thread-local RNG.
    pub fn random() -> Self {
        let mut rand = rand::thread_rng();
        let mut ret = Self::null();
        rand.fill(&mut ret.encryption[..]);
        rand.fill(&mut ret.hmac[..]);
        rand.fill(&mut ret.entropy[..]);
        ret
    }

    /// Returns a root key with every byte of every field set to zero.
    ///
    /// Used as the starting buffer for [`RootKey::random`]. An all-zero key offers no
    /// secrecy on its own.
    pub fn null() -> Self {
        RootKey {
            encryption: [0_u8; 32].into(),
            hmac: [0_u8; 32].into(),
            entropy: [0_u8; 256].into(),
        }
    }

    /// Encrypts this root key under `password`.
    ///
    /// The steps are:
    /// 1. Generate a random 32-byte salt.
    /// 2. Stretch `password` with Argon2 (default parameters) into 64 bytes. The first 32
    ///    bytes become the XChaCha20 key and the last 32 the BLAKE3 keyed-hash key.
    /// 3. CBOR-serialize the key and encrypt it under a fresh random [`Nonce`].
    /// 4. Authenticate the ciphertext with a BLAKE3 keyed hash.
    ///
    /// # Errors
    ///
    /// Returns an `Argon2Failure`-derived error if password hashing fails.
    ///
    /// # Panics
    ///
    /// Panics if CBOR serialization of the key fails, which is not expected for this type.
    pub fn encrypt(&self, password: &[u8]) -> Result<EncryptedRootKey, BackendError> {
        let mut salt = [0_u8; 32];
        rand::thread_rng().fill(&mut salt[..]);
        let argon = Argon2::default();
        // Zeroizing: the stretched key material is wiped when this function returns.
        let mut argon_output = Zeroizing::new([0_u8; 64]);
        argon
            .hash_password_into(password, &salt, &mut argon_output[..])
            .context(Argon2Failure)?;
        let encryption_key: &chacha20::Key = GenericArray::from_slice(&argon_output[0..32]);
        let mut hmac_key = Zeroizing::new([0_u8; 32]);
        hmac_key.copy_from_slice(&argon_output[32..]);
        let nonce = Nonce::random();
        // NOTE(review): `serial` holds plaintext key bytes until the keystream is applied.
        // Buffers that `Vec` growth inside `serde_cbor` moved away from are freed without
        // being zeroized — consider serializing into a pre-sized zeroizing buffer.
        let mut serial = serde_cbor::to_vec(self).expect("Infallible");
        let mut chacha = XChaCha20::new(encryption_key, nonce.nonce());
        chacha.apply_keystream(&mut serial[..]);
        // Encrypt-then-MAC over the ciphertext only.
        // NOTE(review): the nonce is not covered by the MAC. A tampered salt is still
        // rejected, because it changes the derived MAC key. A tampered nonce is only caught
        // if the resulting garbage fails to deserialize. Binding the nonce into the MAC
        // would change the stored format, so it is left unchanged here.
        let hmac: [u8; 32] = blake3::keyed_hash(&hmac_key, &serial).into();
        Ok(EncryptedRootKey {
            nonce,
            hmac: hmac.into(),
            salt: salt.into(),
            payload: serial,
        })
    }

    /// Derives a new [`DerivedKey`] for `namespace` using a fresh random nonce.
    ///
    /// The context string has the form
    /// `snapper-box nonce: <NONCE> namespace: <namespace>`. `<NONCE>` is 16 random bytes
    /// rendered as 32 uppercase hexadecimal characters.
    ///
    /// Because the nonce is random, repeated calls with the same `namespace` yield
    /// different keys. The returned key records its context string, and passing that
    /// string to [`RootKey::derive_with_context`] reproduces the same key material.
    pub fn derive(&self, namespace: &str) -> DerivedKey {
        let mut nonce_bytes = [0_u8; 16];
        rand::thread_rng().fill_bytes(&mut nonce_bytes[..]);
        // Uppercase hex, high nibble first: two ASCII characters per byte.
        let nonce: String = nonce_bytes
            .iter()
            .map(|byte| format!("{:02X}", byte))
            .collect();
        let context_string = format!("snapper-box nonce: {} namespace: {}", nonce, namespace);
        self.derive_with_context(context_string)
    }

    /// Deterministically derives a [`DerivedKey`] from `context_string` and this key's
    /// entropy, using BLAKE3 in key-derivation mode.
    ///
    /// The first 32 bytes of extendable output become the derived encryption key and the
    /// next 32 bytes the derived HMAC key. The same root key and context string always
    /// yield the same derived key. The context string is stored in the result.
    pub fn derive_with_context(&self, context_string: String) -> DerivedKey {
        let mut hasher = Hasher::new_derive_key(&context_string);
        hasher.update(&self.entropy[..]);
        let mut ret = DerivedKey {
            encryption: [0_u8; 32].into(),
            hmac: [0_u8; 32].into(),
            context_string,
        };
        let mut output = hasher.finalize_xof();
        output.fill(&mut ret.encryption[..]);
        output.fill(&mut ret.hmac[..]);
        ret
    }
}
/// Exposes the root key's own encryption and HMAC fields through [`Key`].
impl Key for RootKey {
    fn hmac_key(&self) -> &[u8; 32] {
        self.hmac.as_ref()
    }
    fn encryption_key(&self) -> &chacha20::Key {
        let raw: &[u8] = &self.encryption[..];
        GenericArray::from_slice(raw)
    }
}
/// A [`RootKey`] encrypted under a password by [`RootKey::encrypt`].
///
/// Field order and names are part of the serialized form.
#[derive(Debug, Hash, Clone, Serialize, Deserialize)]
pub struct EncryptedRootKey {
    // XChaCha20 nonce used to encrypt `payload`.
    nonce: Nonce,
    // BLAKE3 keyed hash of `payload` (the ciphertext), under the Argon2-derived MAC key.
    hmac: RedactedBytes<32>,
    // Argon2 salt used to stretch the password.
    salt: RedactedBytes<32>,
    // XChaCha20 ciphertext of the CBOR-serialized `RootKey`. `serde_bytes` makes serde
    // treat it as a byte string rather than a sequence of integers.
    #[serde(with = "serde_bytes")]
    payload: Vec<u8>,
}
impl EncryptedRootKey {
    /// Recovers the [`RootKey`] using `password`.
    ///
    /// This mirrors [`RootKey::encrypt`]. It re-derives the 64 bytes of Argon2 key material
    /// from `password` and the stored salt, then verifies the BLAKE3 keyed hash of the
    /// ciphertext. Only after that does it decrypt with XChaCha20 and CBOR-deserialize.
    ///
    /// # Errors
    ///
    /// - An `Argon2Failure`-derived error if password hashing fails.
    /// - `BackendError::BadHMAC` if the MAC does not match, i.e. a wrong password or a
    ///   corrupted payload or salt.
    /// - `BackendError::KeyDeserialization` if the decrypted bytes are not a valid
    ///   serialized `RootKey`.
    pub fn decrypt(&self, password: &[u8]) -> Result<RootKey, BackendError> {
        let argon = Argon2::default();
        let mut argon_output = Zeroizing::new([0_u8; 64]);
        argon
            .hash_password_into(password, &self.salt, &mut argon_output[..])
            .context(Argon2Failure)?;
        let encryption_key: &chacha20::Key = GenericArray::from_slice(&argon_output[0..32]);
        let mut hmac_key = Zeroizing::new([0_u8; 32]);
        hmac_key.copy_from_slice(&argon_output[32..]);
        // Authenticate before decrypting. `blake3::Hash` equality is constant-time.
        let hmac = blake3::keyed_hash(&hmac_key, &self.payload);
        ensure!(hmac.eq(&*self.hmac), BadHMAC);
        // Decrypt a zeroizing copy so the plaintext key bytes are wiped on drop.
        let mut data = Zeroizing::new(self.payload.clone());
        let mut chacha = XChaCha20::new(encryption_key, self.nonce.nonce());
        chacha.apply_keystream(&mut data[..]);
        serde_cbor::from_slice(&data).map_err(|_| BackendError::KeyDeserialization)
    }
}
/// A key produced from a [`RootKey`] by [`RootKey::derive`] or
/// [`RootKey::derive_with_context`].
///
/// It is zeroized when dropped (`#[zeroize(drop)]`). Field order and names are part of the
/// serialized form.
#[derive(Hash, Clone, Serialize, Deserialize, Zeroize)]
#[zeroize(drop)]
pub struct DerivedKey {
    // First 32 bytes of the BLAKE3 derive-key output; returned by `Key::encryption_key`.
    encryption: RedactedBytes<32>,
    // Next 32 bytes of the BLAKE3 derive-key output; returned by `Key::hmac_key`.
    hmac: RedactedBytes<32>,
    // The BLAKE3 context string this key was derived with.
    context_string: String,
}
/// Exposes the derived encryption and HMAC keys through [`Key`].
impl Key for DerivedKey {
    fn hmac_key(&self) -> &[u8; 32] {
        self.hmac.as_ref()
    }
    fn encryption_key(&self) -> &chacha20::Key {
        let raw: &[u8] = &self.encryption[..];
        GenericArray::from_slice(raw)
    }
}
impl DerivedKey {
    /// Encrypts this derived key under `root_key`.
    ///
    /// The key is wrapped in a [`ClearText`] and encrypted with `root_key`, passing `None`
    /// as the second argument to `ClearText::encrypt`.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the `ClearText` or from encrypting it.
    pub fn encrypt(
        &self,
        root_key: &RootKey,
    ) -> Result<EncryptedDerivedKey<'static>, BackendError> {
        let encrypted_key = ClearText::new(self)?.encrypt(root_key, None)?;
        Ok(EncryptedDerivedKey { encrypted_key })
    }
}
/// A [`DerivedKey`] encrypted under a [`RootKey`] by [`DerivedKey::encrypt`].
#[derive(Hash, Clone, Serialize, Deserialize)]
pub struct EncryptedDerivedKey<'a> {
    // Ciphertext of the serialized `DerivedKey`.
    encrypted_key: CipherText<'a>,
}
impl EncryptedDerivedKey<'_> {
    /// Decrypts the wrapped ciphertext with `root_key` and deserializes the result back
    /// into a [`DerivedKey`].
    ///
    /// # Errors
    ///
    /// Propagates any decryption or deserialization error.
    pub fn decrypt(&self, root_key: &RootKey) -> Result<DerivedKey, BackendError> {
        self.encrypted_key.decrypt(root_key)?.deserialize()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    mod root_key {
        use super::*;
        #[test]
        fn null_is_zeros() {
            let key = RootKey::null();
            assert_eq!(key.encryption, [0_u8; 32].into());
            assert_eq!(key.hmac, [0_u8; 32].into());
            assert_eq!(key.entropy, [0_u8; 256].into());
        }
        #[test]
        fn random_is_not_zeros() {
            let key = RootKey::random();
            assert_ne!(key.encryption, [0_u8; 32].into());
            assert_ne!(key.hmac, [0_u8; 32].into());
            assert_ne!(key.entropy, [0_u8; 256].into());
        }
    }
    mod nonce {
        use super::*;
        #[test]
        fn non_zero() {
            let nonce = Nonce::random();
            assert_ne!(nonce.0, [0_u8; 24].into());
        }
    }
    mod encrypted_root_key {
        use super::*;
        #[test]
        fn round_trip() {
            let key = RootKey::random();
            let password = "password".as_bytes();
            let encrypted = key.encrypt(password).expect("Failed to encrypt key");
            let decrypted = encrypted.decrypt(password).expect("Failed to decrypt key");
            assert_eq!(decrypted.encryption, key.encryption);
            assert_eq!(decrypted.hmac, key.hmac);
            assert_eq!(decrypted.entropy, key.entropy);
        }
        #[test]
        fn bad_password_failure() {
            let key = RootKey::random();
            let password = "password".as_bytes();
            let wrong_password = "wrong password".as_bytes();
            let encrypted = key.encrypt(password).expect("Failed to encrypt key");
            let decrypted = encrypted.decrypt(wrong_password);
            // A wrong password derives a different MAC key, so verification must fail.
            assert!(matches!(decrypted, Err(BackendError::BadHMAC)));
        }
        #[test]
        fn corruption_failure() {
            let key = RootKey::random();
            let password = "password".as_bytes();
            let mut encrypted = key.encrypt(password).expect("Failed to encrypt key");
            encrypted.payload[0] = encrypted.payload[0].wrapping_add(1_u8);
            let decrypted = encrypted.decrypt(password);
            match decrypted {
                Ok(_) => panic!("Somehow decrypted corrupted data"),
                Err(e) => assert!(matches!(e, BackendError::BadHMAC)),
            }
        }
    }
    mod derived_key {
        use super::*;
        #[test]
        fn not_zero() {
            let root_key = RootKey::random();
            let derived_key = root_key.derive("namespace");
            assert_ne!(derived_key.encryption, [0_u8; 32].into());
            assert_ne!(derived_key.hmac, [0_u8; 32].into());
            assert_ne!(derived_key.encryption, derived_key.hmac);
        }
        #[test]
        fn non_repeatable() {
            let root_key = RootKey::random();
            let derived_key_1 = root_key.derive("namespace");
            let derived_key_2 = root_key.derive("namespace");
            assert_ne!(derived_key_1.encryption, derived_key_2.encryption);
            assert_ne!(derived_key_1.hmac, derived_key_2.hmac);
            assert_ne!(derived_key_1.context_string, derived_key_2.context_string);
        }
        #[test]
        fn repeatable() {
            let root_key = RootKey::random();
            let derived_key_1 = root_key.derive_with_context("Some context goes here".to_string());
            let derived_key_2 = root_key.derive_with_context("Some context goes here".to_string());
            assert_eq!(derived_key_1.encryption, derived_key_2.encryption);
            assert_eq!(derived_key_1.hmac, derived_key_2.hmac);
            assert_eq!(derived_key_1.context_string, derived_key_2.context_string);
        }
        /// The context string must keep its exact shape: 32 uppercase hex characters
        /// between the fixed prefix and the namespace.
        #[test]
        fn context_string_format() {
            let root_key = RootKey::random();
            let derived_key = root_key.derive("namespace");
            let nonce = derived_key
                .context_string
                .strip_prefix("snapper-box nonce: ")
                .and_then(|rest| rest.strip_suffix(" namespace: namespace"))
                .expect("Context string has unexpected shape");
            assert_eq!(nonce.len(), 32);
            assert!(nonce
                .bytes()
                .all(|b| matches!(b, b'0'..=b'9' | b'A'..=b'F')));
        }
        /// A key from `derive` can be re-created from its recorded context string.
        #[test]
        fn derive_is_reproducible_from_context() {
            let root_key = RootKey::random();
            let derived_key = root_key.derive("namespace");
            let rederived = root_key.derive_with_context(derived_key.context_string.clone());
            assert_eq!(derived_key.encryption, rederived.encryption);
            assert_eq!(derived_key.hmac, rederived.hmac);
        }
    }
    mod enc_derived_key {
        use super::*;
        #[test]
        fn round_trip() {
            let root_key = RootKey::random();
            let derived_key_orig = root_key.derive("testing");
            let enc_derived_key = derived_key_orig
                .encrypt(&root_key)
                .expect("Failed to encrypt key");
            let derived_key_deser = enc_derived_key
                .decrypt(&root_key)
                .expect("Failed to decrypt key");
            assert_eq!(derived_key_deser.encryption, derived_key_orig.encryption);
            assert_eq!(derived_key_deser.hmac, derived_key_orig.hmac);
            assert_eq!(
                derived_key_deser.context_string,
                derived_key_orig.context_string
            );
        }
    }
}