use super::TagMismatch;
use super::chacha20::ChaCha20;
use super::poly1305::Poly1305;
use crate::ct::ConstantTimeEq;
/// ChaCha20-Poly1305 authenticated encryption (the construction whose
/// RFC 8439 §2.8 length limit is enforced by `encrypt` and `decrypt`).
///
/// The value holds only the keyed ChaCha20 cipher; the per-message Poly1305
/// key is derived from it for every nonce rather than stored here.
#[derive(Clone)]
pub struct ChaCha20Poly1305 {
/// ChaCha20 instance built from the 32-byte key passed to `new`.
cipher: ChaCha20,
}
/// Feeds zero bytes into `mac` so that a section of `len` bytes ends on a
/// 16-byte boundary.
///
/// Nothing is written when `len` is already a multiple of 16; otherwise
/// `16 - (len % 16)` zero bytes are absorbed.
fn pad16(mac: &mut Poly1305, len: usize) {
    const ZEROS: [u8; 16] = [0u8; 16];
    match len % 16 {
        // Already aligned: no padding at all.
        0 => {}
        partial => {
            mac.update(&ZEROS[..16 - partial]);
        }
    }
}
impl ChaCha20Poly1305 {
    /// Largest buffer, in bytes, accepted by [`Self::encrypt`] and
    /// [`Self::decrypt`]: `2^32 − 1` blocks of 64 bytes each.
    pub const MAX_PLAINTEXT_LEN: u64 = (u32::MAX as u64) * 64;

    /// Creates an AEAD instance from a 32-byte key.
    pub fn new(key: &[u8; 32]) -> Self {
        Self {
            cipher: ChaCha20::new(key),
        }
    }

    /// Derives the one-time Poly1305 key for `nonce`: the first 32 bytes of
    /// the ChaCha20 block produced with counter 0.
    fn poly_key(&self, nonce: &[u8; 12]) -> [u8; 32] {
        let mut block0 = self.cipher.block(nonce, 0);
        let mut otk = [0u8; 32];
        {
            let (key_half, _unused) = block0.split_at(32);
            otk.copy_from_slice(key_half);
        }
        // Overwrite the local block copy, then route it through `black_box`
        // so the compiler is discouraged from eliding the dead store.
        block0.fill(0);
        let _ = core::hint::black_box(&block0);
        otk
    }

    /// Computes the 16-byte Poly1305 tag under the one-time key `otk`.
    ///
    /// The MAC input is, in order: `aad`, zero padding to a 16-byte
    /// boundary, `ct`, zero padding to a 16-byte boundary, then the lengths
    /// of `aad` and `ct` as two little-endian 64-bit integers.
    fn tag(&self, otk: &[u8; 32], aad: &[u8], ct: &[u8]) -> [u8; 16] {
        let mut mac = Poly1305::new(otk);

        mac.update(aad);
        pad16(&mut mac, aad.len());

        mac.update(ct);
        pad16(&mut mac, ct.len());

        // Trailer block: len(aad) || len(ct), each as a little-endian u64.
        let aad_len = (aad.len() as u64).to_le_bytes();
        let ct_len = (ct.len() as u64).to_le_bytes();
        let mut lens = [0u8; 16];
        {
            let (low, high) = lens.split_at_mut(8);
            low.copy_from_slice(&aad_len);
            high.copy_from_slice(&ct_len);
        }
        mac.update(&lens);

        mac.finish()
    }

    /// Encrypts `buffer` in place and returns the authentication tag.
    ///
    /// The keystream is applied with counter argument 1 (counter 0 is used
    /// for the Poly1305 key), and the tag is computed over `aad` and the
    /// resulting ciphertext.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is longer than [`Self::MAX_PLAINTEXT_LEN`] bytes.
    pub fn encrypt(&self, nonce: &[u8; 12], aad: &[u8], buffer: &mut [u8]) -> [u8; 16] {
        let byte_len = buffer.len() as u64;
        assert!(
            byte_len <= Self::MAX_PLAINTEXT_LEN,
            "ChaCha20-Poly1305 plaintext exceeds 2^32 − 1 blocks (RFC 8439 §2.8)"
        );

        let otk = self.poly_key(nonce);
        self.cipher.apply_keystream(nonce, 1, buffer);
        // `buffer` now holds ciphertext, which is what gets authenticated.
        self.tag(&otk, aad, buffer)
    }

    /// Verifies `tag` over `aad` and the ciphertext in `buffer`, then
    /// decrypts `buffer` in place.
    ///
    /// The tag is checked with `ct_eq` before any keystream is applied, so a
    /// rejected message leaves `buffer` unmodified.
    ///
    /// # Errors
    ///
    /// Returns [`TagMismatch`] when the computed tag differs from `tag`.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is longer than [`Self::MAX_PLAINTEXT_LEN`] bytes.
    pub fn decrypt(
        &self,
        nonce: &[u8; 12],
        aad: &[u8],
        buffer: &mut [u8],
        tag: &[u8; 16],
    ) -> Result<(), TagMismatch> {
        let byte_len = buffer.len() as u64;
        assert!(
            byte_len <= Self::MAX_PLAINTEXT_LEN,
            "ChaCha20-Poly1305 ciphertext exceeds 2^32 − 1 blocks (RFC 8439 §2.8)"
        );

        let otk = self.poly_key(nonce);
        let expected = self.tag(&otk, aad, buffer);
        let authentic = bool::from(expected.ct_eq(tag));

        if authentic {
            self.cipher.apply_keystream(nonce, 1, buffer);
            Ok(())
        } else {
            Err(TagMismatch)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::from_hex;

    /// Inputs of the RFC 8439 §2.8.2 AEAD example, as
    /// `(key, nonce, aad, plaintext)`.
    fn vector() -> ([u8; 32], [u8; 12], [u8; 12], [u8; 114]) {
        let plaintext: [u8; 114] =
            *b"Ladies and Gentlemen of the class of '99: If I could offer you \
only one tip for the future, sunscreen would be it.";
        let key =
            from_hex::<32>("808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f");
        let nonce = from_hex::<12>("070000004041424344454647");
        let aad = from_hex::<12>("50515253c0c1c2c3c4c5c6c7");
        (key, nonce, aad, plaintext)
    }

    /// Encrypting the RFC vector must reproduce the published ciphertext
    /// and tag.
    #[test]
    fn rfc8439_seal() {
        let (key, nonce, aad, plaintext) = vector();
        let aead = ChaCha20Poly1305::new(&key);

        let mut data = plaintext;
        let tag = aead.encrypt(&nonce, &aad, &mut data);

        let want_ct = from_hex::<114>(
            "d31a8d34648e60db7b86afbc53ef7ec2a4aded51296e08fea9e2b5a736ee62d6\
3dbea45e8ca9671282fafb69da92728b1a71de0a9e060b2905d6a5b67ecd3b36\
92ddbd7f2d778b8c9803aee328091b58fab324e4fad675945585808b4831d7bc\
3ff4def08e4b7a9de576d26586cec64b6116",
        );
        let want_tag = from_hex::<16>("1ae10b594f09e26a7e902ecbd0600691");

        assert_eq!(data, want_ct);
        assert_eq!(tag, want_tag);
    }

    /// Decryption inverts encryption, and a tampered tag or AAD is rejected.
    #[test]
    fn roundtrip_and_reject() {
        let (key, nonce, aad, plaintext) = vector();
        let aead = ChaCha20Poly1305::new(&key);

        let mut work = plaintext;
        let tag = aead.encrypt(&nonce, &aad, &mut work);
        let ciphertext = work;

        // Happy path: the genuine tag restores the plaintext.
        aead.decrypt(&nonce, &aad, &mut work, &tag).unwrap();
        assert_eq!(work, plaintext);

        // A single flipped tag bit is rejected and the buffer is untouched.
        {
            let mut sealed = ciphertext;
            let mut forged_tag = tag;
            forged_tag[0] ^= 1;
            assert_eq!(
                aead.decrypt(&nonce, &aad, &mut sealed, &forged_tag),
                Err(TagMismatch)
            );
            assert_eq!(sealed, ciphertext);
        }

        // A single flipped AAD bit is rejected as well.
        {
            let mut sealed = ciphertext;
            let mut forged_aad = aad;
            forged_aad[0] ^= 1;
            assert_eq!(
                aead.decrypt(&nonce, &forged_aad, &mut sealed, &tag),
                Err(TagMismatch)
            );
        }
    }
}