use alloc::vec::Vec;
use chacha20poly1305::aead::{Aead, KeyInit, Payload};
use chacha20poly1305::{ChaCha20Poly1305, Key, Nonce};
use super::{CHACHA20_NONCE_LEN, CHACHA20_TAG_LEN, KEY_LEN};
use crate::error::{Error, Result};
pub(super) fn encrypt(key: &[u8], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
check_key_len(key)?;
let mut nonce_bytes = [0u8; CHACHA20_NONCE_LEN];
mod_rand::tier3::fill_bytes(&mut nonce_bytes)
.map_err(|_| Error::RandomFailure("mod_rand::tier3::fill_bytes"))?;
let cipher = ChaCha20Poly1305::new(Key::from_slice(key));
let nonce = Nonce::from_slice(&nonce_bytes);
let ct_and_tag = cipher
.encrypt(
nonce,
Payload {
msg: plaintext,
aad,
},
)
.map_err(|_| Error::AuthenticationFailed)?;
let mut out = Vec::with_capacity(CHACHA20_NONCE_LEN + ct_and_tag.len());
out.extend_from_slice(&nonce_bytes);
out.extend_from_slice(&ct_and_tag);
Ok(out)
}
pub(super) fn decrypt(key: &[u8], wire: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
check_key_len(key)?;
if wire.len() < CHACHA20_NONCE_LEN + CHACHA20_TAG_LEN {
return Err(Error::InvalidCiphertext(alloc::format!(
"buffer too short ({} bytes, need at least {})",
wire.len(),
CHACHA20_NONCE_LEN + CHACHA20_TAG_LEN
)));
}
let (nonce_bytes, ct_and_tag) = wire.split_at(CHACHA20_NONCE_LEN);
let cipher = ChaCha20Poly1305::new(Key::from_slice(key));
let nonce = Nonce::from_slice(nonce_bytes);
cipher
.decrypt(
nonce,
Payload {
msg: ct_and_tag,
aad,
},
)
.map_err(|_| Error::AuthenticationFailed)
}
#[inline]
fn check_key_len(key: &[u8]) -> Result<()> {
if key.len() == KEY_LEN {
Ok(())
} else {
Err(Error::InvalidKey {
expected: KEY_LEN,
actual: key.len(),
})
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use super::*;
#[test]
fn rfc8439_section_2_8_2_known_answer() {
let key = hex_to_bytes("808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f");
let nonce = hex_to_bytes("070000004041424344454647");
let aad = hex_to_bytes("50515253c0c1c2c3c4c5c6c7");
let plaintext = b"Ladies and Gentlemen of the class of '99: \
If I could offer you only one tip for the future, sunscreen would be it.";
let expected_ciphertext_and_tag = hex_to_bytes(
"d31a8d34648e60db7b86afbc53ef7ec2a4aded51296e08fea9e2b5a736ee62d6\
3dbea45e8ca9671282fafb69da92728b1a71de0a9e060b2905d6a5b67ecd3b36\
92ddbd7f2d778b8c9803aee328091b58fab324e4fad675945585808b4831d7bc\
3ff4def08e4b7a9de576d26586cec64b61161ae10b594f09e26a7e902ecbd060\
0691",
);
let cipher = ChaCha20Poly1305::new(Key::from_slice(&key));
let n = Nonce::from_slice(&nonce);
let got = cipher
.encrypt(
n,
Payload {
msg: plaintext.as_ref(),
aad: &aad,
},
)
.unwrap();
assert_eq!(got, expected_ciphertext_and_tag);
let recovered = cipher
.decrypt(
n,
Payload {
msg: &got,
aad: &aad,
},
)
.unwrap();
assert_eq!(recovered, plaintext);
}
#[test]
fn round_trip_via_module_wrapper() {
let key = [0xa1u8; 32];
let pt = b"the wrapper layers nonce-prepend on top of the upstream primitive";
let wire = encrypt(&key, pt, &[]).unwrap();
assert_eq!(wire.len(), CHACHA20_NONCE_LEN + pt.len() + CHACHA20_TAG_LEN);
let recovered = decrypt(&key, &wire, &[]).unwrap();
assert_eq!(recovered, pt);
}
#[test]
fn check_key_len_accepts_exactly_32() {
assert!(check_key_len(&[0u8; 32]).is_ok());
}
#[test]
fn check_key_len_rejects_off_by_one() {
assert!(check_key_len(&[0u8; 31]).is_err());
assert!(check_key_len(&[0u8; 33]).is_err());
}
fn hex_to_bytes(s: &str) -> alloc::vec::Vec<u8> {
hex::decode(s).expect("valid hex")
}
}