use constant_time_eq::constant_time_eq;
use core::marker::PhantomData;
use crate::{alloc::Vec, Cipher, CipherOutput, MacMismatch};
/// A cipher that provides confidentiality only, with no integrity protection.
///
/// A single in-place transform is used for both directions: the same
/// [`seal_or_open`](Self::seal_or_open) call encrypts a plaintext or decrypts
/// a ciphertext.
pub trait UnauthenticatedCipher: 'static {
    /// Number of key bytes consumed by [`seal_or_open`](Self::seal_or_open).
    const KEY_LEN: usize;

    /// Expected length in bytes of the nonce passed to
    /// [`seal_or_open`](Self::seal_or_open).
    const NONCE_LEN: usize;

    /// Transforms `message` in place (encrypting or decrypting) under the
    /// given `nonce` and `key`.
    fn seal_or_open(message: &mut [u8], nonce: &[u8], key: &[u8]);
}
/// A keyed message authentication code.
pub trait Mac: 'static {
    /// Number of key bytes consumed by [`digest`](Self::digest).
    const KEY_LEN: usize;

    /// Length in bytes of the tag that [`digest`](Self::digest) is expected
    /// to return.
    const MAC_LEN: usize;

    /// Computes the authentication tag of `message` under `key`.
    fn digest(key: &[u8], message: &[u8]) -> Vec<u8>;
}
/// Marker type that pairs an [`UnauthenticatedCipher`] `C` with a [`Mac`] `M`.
///
/// It holds no data: both fields are `PhantomData` and exist only to mention
/// the type parameters, so the type is zero-sized.
#[derive(Debug)]
pub struct CipherWithMac<C, M> {
    /// Records the cipher type parameter.
    _cipher: PhantomData<C>,
    /// Records the MAC type parameter.
    _mac: PhantomData<M>,
}
/// Encrypt-then-MAC: the message is encrypted with `C`, then `M` is computed
/// over the resulting ciphertext.
///
/// The combined key is laid out as `cipher_key || mac_key`: the first
/// `C::KEY_LEN` bytes key the cipher and the remaining bytes key the MAC.
///
/// NOTE(review): only the ciphertext is passed to `M::digest`; the nonce is
/// not authenticated. Including it would change the MAC format, so it is left
/// as is — confirm callers never accept the nonce from an untrusted source.
impl<C, M> Cipher for CipherWithMac<C, M>
where
    C: UnauthenticatedCipher,
    M: Mac,
{
    const KEY_LEN: usize = C::KEY_LEN + M::KEY_LEN;
    const NONCE_LEN: usize = C::NONCE_LEN;
    const MAC_LEN: usize = M::MAC_LEN;

    /// Encrypts `message` under `nonce` and returns the ciphertext together
    /// with a MAC computed over that ciphertext.
    ///
    /// # Panics
    ///
    /// Panics if `key` is shorter than `C::KEY_LEN`. In debug builds, also
    /// panics if `key.len() != Self::KEY_LEN`.
    fn seal(message: &[u8], nonce: &[u8], key: &[u8]) -> CipherOutput {
        // Mirror `open`'s precondition so a mis-sized key is caught on both paths
        // instead of silently lengthening or shortening the MAC key.
        debug_assert_eq!(key.len(), Self::KEY_LEN);
        let (cipher_key, mac_key) = key.split_at(C::KEY_LEN);
        let mut ciphertext = message.to_vec();
        C::seal_or_open(&mut ciphertext, nonce, cipher_key);
        CipherOutput {
            // Evaluated before `ciphertext` is moved into the struct below.
            mac: M::digest(mac_key, &ciphertext),
            ciphertext,
        }
    }

    /// Verifies the MAC of `enc` and, only if it matches, decrypts
    /// `enc.ciphertext` into `output`.
    ///
    /// # Errors
    ///
    /// Returns [`MacMismatch`] if the recomputed MAC differs from `enc.mac`.
    /// In that case `output` is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `key` is shorter than `C::KEY_LEN`, or if
    /// `output.len() != enc.ciphertext.len()`. In debug builds, also panics if
    /// `key` or `enc.mac` does not have its declared length.
    fn open(
        output: &mut [u8],
        enc: &CipherOutput,
        nonce: &[u8],
        key: &[u8],
    ) -> Result<(), MacMismatch> {
        debug_assert_eq!(key.len(), Self::KEY_LEN);
        debug_assert_eq!(enc.mac.len(), Self::MAC_LEN);
        debug_assert_eq!(output.len(), enc.ciphertext.len());
        let (cipher_key, mac_key) = key.split_at(C::KEY_LEN);
        // Authenticate before decrypting so no plaintext is produced for forged
        // input; compare with `constant_time_eq` rather than `==` to avoid an
        // early-exit timing side channel.
        let expected_mac = M::digest(mac_key, &enc.ciphertext);
        if !constant_time_eq(&expected_mac, &enc.mac) {
            return Err(MacMismatch);
        }
        output.copy_from_slice(&enc.ciphertext);
        C::seal_or_open(output, nonce, cipher_key);
        Ok(())
    }
}