#[allow(deprecated)]
use aes_gcm_siv::aead::generic_array::GenericArray;
use aes_gcm_siv::aead::AeadMutInPlace;
use aes_gcm_siv::{Aes256GcmSiv, KeyInit};
use elements::bitcoin::hashes::{sha256t_hash_newtype, Hash};
use rand::{thread_rng, Rng};
/// Length in bytes of the nonce that the encryption helpers in this module prepend to
/// their output (and that `decrypt_with_nonce_prefix` splits off before decrypting).
/// Matches the 96-bit nonce size of AES-GCM-SIV.
const NONCE_LEN: usize = 12;
// Tagged SHA-256 hash type used by `encrypt_with_deterministic_nonce` to derive a nonce
// from the plaintext. The tag string is an input to every derived hash: changing it
// changes every deterministic nonce (and the values pinned in the tests of this file).
// NOTE(review): `#[hash_newtype(forward)]` presumably selects forward (non-reversed) byte
// order for hex display/parsing — confirm against the bitcoin_hashes documentation.
sha256t_hash_newtype! {
struct DeterministicNonceTag = hash_str("LWK-Deterministic-Nonce/1.0");
#[hash_newtype(forward)]
struct DeterministicNonceHash(_);
}
/// Errors returned by the encryption and decryption helpers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EncryptError {
    /// The data passed to `decrypt_with_nonce_prefix` is shorter than the nonce prefix.
    MissingNonce,
    /// The underlying AEAD operation failed; holds the AEAD error's string form.
    ///
    /// On decryption this is what a failed authentication check (wrong key, tampered
    /// or truncated data) surfaces as.
    Aead(String),
}

impl std::fmt::Display for EncryptError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Static message: no formatting machinery needed.
            Self::MissingNonce => f.write_str("Encrypted data too short - missing nonce"),
            Self::Aead(err) => write!(f, "Aead error: {err}"),
        }
    }
}

impl std::error::Error for EncryptError {}
#[allow(deprecated)]
pub fn cipher_from_key_bytes(key_bytes: [u8; 32]) -> Aes256GcmSiv {
let key = GenericArray::from_slice(&key_bytes);
Aes256GcmSiv::new(key)
}
/// Encrypts `plaintext` under `nonce_bytes` with empty associated data.
///
/// Returns `nonce_bytes` followed by the output of `encrypt_in_place`. This is the layout
/// that `decrypt_with_nonce_prefix` expects.
///
/// # Errors
/// Returns [`EncryptError::Aead`] if the AEAD encryption fails.
// NOTE(review): presumably silences deprecation warnings from `GenericArray::from_slice`
// — confirm.
#[allow(deprecated)]
fn encrypt_with_nonce(
    cipher: &mut Aes256GcmSiv,
    plaintext: &[u8],
    nonce_bytes: [u8; NONCE_LEN],
) -> Result<Vec<u8>, EncryptError> {
    let nonce = GenericArray::from_slice(nonce_bytes.as_slice());
    // Encrypted in place: the buffer starts as the plaintext and ends as the AEAD output.
    let mut sealed = Vec::from(plaintext);
    if let Err(err) = cipher.encrypt_in_place(nonce, b"", &mut sealed) {
        return Err(EncryptError::Aead(err.to_string()));
    }
    Ok([nonce_bytes.as_slice(), sealed.as_slice()].concat())
}
/// Encrypts `plaintext` under a freshly generated random nonce.
///
/// The output is the nonce followed by the AEAD output; decrypt it with
/// `decrypt_with_nonce_prefix`. Repeated calls on the same plaintext produce different
/// outputs with overwhelming probability.
///
/// # Errors
/// Returns [`EncryptError::Aead`] if the AEAD encryption fails.
// NOTE(review): no `GenericArray` is used here; the allow presumably covers `thread_rng`
// (deprecated in newer `rand` releases) — confirm.
#[allow(deprecated)]
pub fn encrypt_with_random_nonce(
    cipher: &mut Aes256GcmSiv,
    plaintext: &[u8],
) -> Result<Vec<u8>, EncryptError> {
    let nonce_bytes = {
        let mut fresh = [0u8; NONCE_LEN];
        thread_rng().fill(&mut fresh);
        fresh
    };
    encrypt_with_nonce(cipher, plaintext, nonce_bytes)
}
/// Decrypts data laid out as `nonce || AEAD output`, as produced by this module's
/// encryption helpers, using empty associated data.
///
/// # Errors
/// - [`EncryptError::MissingNonce`] if `ciphertext` is shorter than the nonce prefix.
/// - [`EncryptError::Aead`] if AEAD decryption fails (e.g. wrong key, tampered data, or a
///   body too short to contain the authentication tag).
// NOTE(review): presumably silences deprecation warnings from `GenericArray::from_slice`
// — confirm.
#[allow(deprecated)]
pub fn decrypt_with_nonce_prefix(
    cipher: &mut Aes256GcmSiv,
    ciphertext: &[u8],
) -> Result<Vec<u8>, EncryptError> {
    if ciphertext.len() < NONCE_LEN {
        return Err(EncryptError::MissingNonce);
    }
    // The guard above guarantees `split_at` cannot panic and that the prefix is exactly
    // NONCE_LEN bytes long.
    let (nonce_part, sealed) = ciphertext.split_at(NONCE_LEN);
    let mut opened = sealed.to_vec();
    cipher
        .decrypt_in_place(GenericArray::from_slice(nonce_part), b"", &mut opened)
        .map_err(|err| EncryptError::Aead(err.to_string()))?;
    Ok(opened)
}
/// Encrypts `plaintext` with a nonce derived from the plaintext itself.
///
/// The nonce is the first `NONCE_LEN` bytes of `DeterministicNonceHash` over `plaintext`.
/// The same key and plaintext therefore always yield the same output: the nonce
/// followed by the AEAD output.
///
/// # Security
/// The nonce is an *unkeyed* hash of the plaintext and is emitted in the clear as the
/// output prefix. Anyone holding the output can therefore check a guessed plaintext
/// without the key. Two outputs also reveal whether they encrypt the same plaintext.
/// Only use this when plaintexts are high-entropy.
///
/// # Errors
/// Returns [`EncryptError::Aead`] if the AEAD encryption fails.
// NOTE(review): nothing deprecated is evident in this body; the allow may be unnecessary
// — confirm.
#[allow(deprecated)]
pub fn encrypt_with_deterministic_nonce(
    cipher: &mut Aes256GcmSiv,
    plaintext: &[u8],
) -> Result<Vec<u8>, EncryptError> {
    let digest = DeterministicNonceHash::hash(plaintext);
    // NOTE(review): confirm every caller encrypts only high-entropy data (see # Security).
    let mut nonce_bytes = [0u8; NONCE_LEN];
    nonce_bytes.copy_from_slice(&digest.as_byte_array()[..NONCE_LEN]);
    encrypt_with_nonce(cipher, plaintext, nonce_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cipher with a fixed test key; recreated per operation to show that no state is shared.
    fn test_cipher() -> Aes256GcmSiv {
        cipher_from_key_bytes([7u8; 32])
    }

    #[test]
    fn random_nonce_roundtrip() {
        let mut cipher = test_cipher();
        let plaintext = b"example plaintext";
        let encrypted = encrypt_with_random_nonce(&mut cipher, plaintext).unwrap();
        assert!(encrypted.len() > NONCE_LEN);
        let mut cipher = test_cipher();
        let decrypted = decrypt_with_nonce_prefix(&mut cipher, &encrypted).unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);
    }

    #[test]
    fn random_nonces_differ() {
        let plaintext = b"example plaintext";
        let mut cipher = test_cipher();
        let first = encrypt_with_random_nonce(&mut cipher, plaintext).unwrap();
        let second = encrypt_with_random_nonce(&mut cipher, plaintext).unwrap();
        // A collision of two random 96-bit nonces is negligibly likely.
        assert_ne!(&first[..NONCE_LEN], &second[..NONCE_LEN]);
    }

    #[test]
    fn empty_plaintext_roundtrip() {
        let mut cipher = test_cipher();
        let encrypted = encrypt_with_random_nonce(&mut cipher, b"").unwrap();
        let decrypted = decrypt_with_nonce_prefix(&mut cipher, &encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn deterministic_nonce_is_stable() {
        let plaintext = b"deterministic payload";
        let mut cipher = test_cipher();
        let encrypted1 = encrypt_with_deterministic_nonce(&mut cipher, plaintext).unwrap();
        assert!(encrypted1.len() > NONCE_LEN);
        assert_eq!(
            &encrypted1[..NONCE_LEN],
            &[109, 114, 166, 63, 192, 58, 90, 214, 13, 78, 153, 17]
        );
        let mut cipher = test_cipher();
        let decrypted1 = decrypt_with_nonce_prefix(&mut cipher, &encrypted1).unwrap();
        assert_eq!(&plaintext[..], &decrypted1[..]);
        let mut cipher = test_cipher();
        let encrypted2 = encrypt_with_deterministic_nonce(&mut cipher, plaintext).unwrap();
        assert_eq!(encrypted1, encrypted2);
        let mut cipher = test_cipher();
        let decrypted2 = decrypt_with_nonce_prefix(&mut cipher, &encrypted2).unwrap();
        assert_eq!(&plaintext[..], &decrypted2[..]);
    }

    #[test]
    fn deterministic_nonce_hash_empty_input_regression() {
        let got = DeterministicNonceHash::hash(b"").to_string();
        let exp = "953d329ceafdac1fa531eefd38f397af65097c560d844c3451bcf376cb511ff7";
        assert_eq!(got, exp);
    }

    #[test]
    fn decrypt_rejects_input_shorter_than_nonce() {
        let mut cipher = test_cipher();
        let err = decrypt_with_nonce_prefix(&mut cipher, &[0u8; NONCE_LEN - 1]).unwrap_err();
        assert!(matches!(err, EncryptError::MissingNonce));
        let err = decrypt_with_nonce_prefix(&mut cipher, &[]).unwrap_err();
        assert!(matches!(err, EncryptError::MissingNonce));
    }

    #[test]
    fn decrypt_rejects_tampered_ciphertext() {
        let mut cipher = test_cipher();
        let mut encrypted =
            encrypt_with_random_nonce(&mut cipher, b"example plaintext").unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        let err = decrypt_with_nonce_prefix(&mut cipher, &encrypted).unwrap_err();
        assert!(matches!(err, EncryptError::Aead(_)));
    }

    #[test]
    fn decrypt_rejects_wrong_key() {
        let mut cipher = test_cipher();
        let encrypted = encrypt_with_random_nonce(&mut cipher, b"example plaintext").unwrap();
        let mut other = cipher_from_key_bytes([8u8; 32]);
        let err = decrypt_with_nonce_prefix(&mut other, &encrypted).unwrap_err();
        assert!(matches!(err, EncryptError::Aead(_)));
    }
}