use crate::mem::constant_time_eq;
use crate::mac::Poly1305;
mod chacha20;
use self::chacha20::Chacha20;
/// Authenticated cipher state built from two independently keyed [`Chacha20`]
/// instances; authentication tags are produced with `Poly1305`.
#[derive(Clone)]
pub struct Chacha20Poly1305 {
    /// Keyed with the second half of the key passed to `new`; applied to the
    /// `PKT_OCTETS_LEN`-byte packet-length field.
    c1: Chacha20,
    /// Keyed with the first half of the key passed to `new`; called with a
    /// second argument of `0` to derive the per-packet Poly1305 key and with
    /// `1` to process the payload.
    c2: Chacha20,
}
impl core::fmt::Debug for Chacha20Poly1305 {
    /// Writes only the type name. No field is emitted, so the state of the
    /// two inner ciphers never reaches the formatted output.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("Chacha20Poly1305")
    }
}
// NOTE(review): the 4-byte length field, the `u32` packet sequence number and
// the two-key split look like the `chacha20-poly1305@openssh.com` construction;
// confirm against the protocol document before relying on that.
impl Chacha20Poly1305 {
    /// Total key length in bytes: two `Chacha20` keys back to back.
    pub const KEY_LEN: usize = Chacha20::KEY_LEN * 2;
    /// Block length in bytes, taken from `Chacha20::BLOCK_LEN`.
    pub const BLOCK_LEN: usize = Chacha20::BLOCK_LEN;
    /// Nonce length in bytes, taken from `Chacha20::NONCE_LEN`.
    pub const NONCE_LEN: usize = Chacha20::NONCE_LEN;
    /// Authentication tag length in bytes, taken from `Poly1305::TAG_LEN`.
    pub const TAG_LEN: usize = Poly1305::TAG_LEN;

    /// Largest payload length accepted by the `*_detached` methods (checked in
    /// debug builds). The value equals `64 * (2^32 - 1)`.
    #[cfg(target_pointer_width = "64")]
    pub const P_MAX: usize = 274877906880;
    /// Largest payload length accepted by the `*_detached` methods (checked in
    /// debug builds).
    ///
    /// The 64-bit limit does not fit in `usize` on this target, so the bound
    /// is the largest value that still keeps `C_MAX` representable.
    #[cfg(not(target_pointer_width = "64"))]
    pub const P_MAX: usize = usize::MAX - Self::TAG_LEN;
    /// `P_MAX` plus the tag length.
    pub const C_MAX: usize = Self::P_MAX + Self::TAG_LEN;
    /// Minimum nonce length; identical to `NONCE_LEN`.
    pub const N_MIN: usize = Self::NONCE_LEN;
    /// Maximum nonce length; identical to `NONCE_LEN`.
    pub const N_MAX: usize = Self::NONCE_LEN;
    /// Size in bytes of the packet-length field at the front of a packet.
    pub const PKT_OCTETS_LEN: usize = 4;

    /// Builds the cipher from a `KEY_LEN`-byte key.
    ///
    /// The first `Chacha20::KEY_LEN` bytes key the cipher used for the
    /// Poly1305 key and the payload; the remaining bytes key the cipher used
    /// for the packet-length field.
    ///
    /// # Panics
    ///
    /// Panics if `key.len() != Self::KEY_LEN`, or if half of `KEY_LEN` differs
    /// from `Poly1305::KEY_LEN`.
    pub fn new(key: &[u8]) -> Self {
        assert_eq!(Self::KEY_LEN / 2, Poly1305::KEY_LEN);
        assert_eq!(key.len(), Self::KEY_LEN);

        let (k2, k1) = key.split_at(Chacha20::KEY_LEN);
        let c2 = Chacha20::new(k2);
        let c1 = Chacha20::new(k1);

        Self { c1, c2 }
    }

    /// Encrypts a packet in place.
    ///
    /// `aead_pkt` is laid out as `PKT_OCTETS_LEN` length bytes, then the
    /// payload, then `TAG_LEN` bytes that receive the tag.
    ///
    /// # Panics
    ///
    /// Panics if `aead_pkt` is shorter than `PKT_OCTETS_LEN + TAG_LEN`.
    pub fn encrypt_slice(&mut self, pkt_seq_num: u32, aead_pkt: &mut [u8]) {
        // Checked in every build profile: the subtraction below would
        // otherwise underflow on an undersized buffer.
        assert!(
            aead_pkt.len() >= Self::PKT_OCTETS_LEN + Self::TAG_LEN,
            "aead_pkt is too short to hold the length field and the tag"
        );

        let (pkt_len, payload_and_tag) = aead_pkt.split_at_mut(Self::PKT_OCTETS_LEN);
        let plen = payload_and_tag.len() - Self::TAG_LEN;
        let (payload, tag_out) = payload_and_tag.split_at_mut(plen);

        self.encrypt_slice_detached(pkt_seq_num, pkt_len, payload, tag_out)
    }

    /// Verifies and decrypts a packet in place.
    ///
    /// `aead_pkt` uses the layout produced by [`Self::encrypt_slice`]. Returns
    /// `true` when the tag matches, in which case the length field and payload
    /// have been decrypted. Returns `false` otherwise and leaves the buffer
    /// untouched. A packet shorter than `PKT_OCTETS_LEN + TAG_LEN` cannot
    /// carry a tag and is rejected with `false`.
    pub fn decrypt_slice(&mut self, pkt_seq_num: u32, aead_pkt: &mut [u8]) -> bool {
        if aead_pkt.len() < Self::PKT_OCTETS_LEN + Self::TAG_LEN {
            return false;
        }

        let (pkt_len, payload_and_tag) = aead_pkt.split_at_mut(Self::PKT_OCTETS_LEN);
        let clen = payload_and_tag.len() - Self::TAG_LEN;
        let (payload, tag_in) = payload_and_tag.split_at_mut(clen);

        self.decrypt_slice_detached(pkt_seq_num, pkt_len, payload, tag_in)
    }

    /// Encrypts the length field and payload in place and writes the tag to
    /// `tag_out`.
    ///
    /// The tag is computed over the encrypted length field followed by the
    /// encrypted payload.
    ///
    /// # Panics
    ///
    /// Panics if `tag_out.len() != Self::TAG_LEN`. In debug builds it also
    /// panics if `pkt_len` is not `PKT_OCTETS_LEN` bytes long or the payload
    /// exceeds `P_MAX`.
    pub fn encrypt_slice_detached(
        &mut self,
        pkt_seq_num: u32,
        pkt_len: &mut [u8],
        plaintext_in_ciphertext_out: &mut [u8],
        tag_out: &mut [u8],
    ) {
        debug_assert_eq!(pkt_len.len(), Self::PKT_OCTETS_LEN);
        debug_assert!(plaintext_in_ciphertext_out.len() <= Self::P_MAX);
        // Reject a wrongly sized tag buffer before any input is modified.
        assert_eq!(tag_out.len(), Self::TAG_LEN);

        // The per-packet Poly1305 key is the output of `c2` over zero bytes.
        let mut poly1305_key = [0u8; Poly1305::KEY_LEN];
        self.c2.encrypt_slice(pkt_seq_num, 0, &mut poly1305_key);

        self.c1.encrypt_slice(pkt_seq_num, 0, pkt_len);
        self.c2.encrypt_slice(pkt_seq_num, 1, plaintext_in_ciphertext_out);

        // Authenticate what will be transmitted: the ciphertext, not the plaintext.
        let mut poly1305 = Poly1305::new(&poly1305_key);
        poly1305.update(&*pkt_len);
        poly1305.update(&*plaintext_in_ciphertext_out);
        let tag = poly1305.finalize();

        tag_out.copy_from_slice(&tag[..Self::TAG_LEN]);
    }

    /// Verifies `tag_in` and, on success, decrypts the length field and
    /// payload in place.
    ///
    /// Returns `true` when the tag matches. Returns `false` otherwise,
    /// including when `tag_in` is not exactly `TAG_LEN` bytes long; the
    /// buffers are left untouched in that case.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `pkt_len` is not `PKT_OCTETS_LEN` bytes long
    /// or the payload exceeds `P_MAX`.
    pub fn decrypt_slice_detached(
        &mut self,
        pkt_seq_num: u32,
        pkt_len: &mut [u8],
        ciphertext_in_plaintext_out: &mut [u8],
        tag_in: &[u8],
    ) -> bool {
        debug_assert_eq!(pkt_len.len(), Self::PKT_OCTETS_LEN);
        debug_assert!(ciphertext_in_plaintext_out.len() <= Self::P_MAX);

        // A tag of the wrong size can never authenticate. Decide that here, in
        // every build profile, instead of relying on how the comparison helper
        // treats slices of unequal length.
        if tag_in.len() != Self::TAG_LEN {
            return false;
        }

        let mut poly1305_key = [0u8; Poly1305::KEY_LEN];
        self.c2.encrypt_slice(pkt_seq_num, 0, &mut poly1305_key);

        let mut poly1305 = Poly1305::new(&poly1305_key);
        poly1305.update(&*pkt_len);
        poly1305.update(&*ciphertext_in_plaintext_out);
        let tag = poly1305.finalize();

        let is_match = constant_time_eq(tag_in, &tag[..Self::TAG_LEN]);
        // Decrypt only after the tag has been verified.
        if is_match {
            self.c1.decrypt_slice(pkt_seq_num, 0, pkt_len);
            self.c2.decrypt_slice(pkt_seq_num, 1, ciphertext_in_plaintext_out);
        }

        is_match
    }
}