use super::arrayref::{DecryptError, EncryptError, KeyGenError};
use libcrux_secrets::{Classify, U8};
/// An AEAD interface that returns its outputs by value instead of writing into
/// caller-provided buffers.
///
/// Every type implementing `super::arrayref::Aead` gets this trait through the
/// blanket implementation that follows it.
///
/// Typing convention, shared by the fixed-size and the `alloc`-gated methods:
/// - `U8` (from `libcrux_secrets`, produced via `classify`) is used for keys,
///   nonces, tags and plaintexts.
/// - Plain `u8` is used for associated data and ciphertexts.
pub trait Aead<const KEY_LEN: usize, const TAG_LEN: usize, const NONCE_LEN: usize> {
    /// Generates a `KEY_LEN`-byte key from the caller-supplied randomness `rand`.
    ///
    /// # Errors
    /// Returns [`KeyGenError`] if the underlying key generation fails.
    fn keygen(rand: &[U8; KEY_LEN]) -> Result<[U8; KEY_LEN], KeyGenError>;

    /// Encrypts a fixed-size `plaintext` under `key` and `nonce`, with associated
    /// data `aad`.
    ///
    /// Returns the ciphertext, which has the same length (`MSG_LEN`) as the
    /// plaintext, together with a `TAG_LEN`-byte tag.
    ///
    /// # Errors
    /// Returns [`EncryptError`] if the underlying encryption fails.
    fn encrypt<const MSG_LEN: usize>(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        plaintext: &[U8; MSG_LEN],
    ) -> Result<([u8; MSG_LEN], [U8; TAG_LEN]), EncryptError>;

    /// Heap-allocating variant of [`Aead::encrypt`] for plaintexts whose length
    /// is only known at runtime.
    ///
    /// The returned ciphertext `Vec` has the same length as `plaintext`.
    ///
    /// # Errors
    /// Returns [`EncryptError`] if the underlying encryption fails.
    #[cfg(feature = "alloc")]
    fn encrypt_to_vec(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        plaintext: &[U8],
    ) -> Result<(alloc::vec::Vec<u8>, [U8; TAG_LEN]), EncryptError>;

    /// Decrypts a fixed-size `ciphertext` using `key`, `nonce`, associated data
    /// `aad` and `tag`.
    ///
    /// Returns the plaintext, which has the same length as the ciphertext.
    ///
    /// # Errors
    /// Returns [`DecryptError`] if the underlying decryption fails. Presumably
    /// this includes tag verification failure; confirm against the `arrayref`
    /// implementation.
    fn decrypt<const MSG_LEN: usize>(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        ciphertext: &[u8; MSG_LEN],
        tag: &[U8; TAG_LEN],
    ) -> Result<[U8; MSG_LEN], DecryptError>;

    /// Heap-allocating variant of [`Aead::decrypt`] for ciphertexts whose length
    /// is only known at runtime.
    ///
    /// The returned plaintext `Vec` has the same length as `ciphertext`.
    ///
    /// # Errors
    /// Returns [`DecryptError`] if the underlying decryption fails.
    #[cfg(feature = "alloc")]
    fn decrypt_to_vec(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        ciphertext: &[u8],
        tag: &[U8; TAG_LEN],
    ) -> Result<alloc::vec::Vec<U8>, DecryptError>;
}

// Blanket implementation: allocate the output buffers locally, let the
// buffer-based `arrayref` implementation fill them, then return them by value.
//
// Both traits define methods called `keygen`, `encrypt` and `decrypt`. The calls
// below therefore name the `arrayref` trait explicitly, so that it is obvious
// they delegate rather than recurse.
impl<
        const KEY_LEN: usize,
        const TAG_LEN: usize,
        const NONCE_LEN: usize,
        T: super::arrayref::Aead<KEY_LEN, TAG_LEN, NONCE_LEN>,
    > Aead<KEY_LEN, TAG_LEN, NONCE_LEN> for T
{
    fn keygen(rand: &[U8; KEY_LEN]) -> Result<[U8; KEY_LEN], KeyGenError> {
        let mut key = [0u8.classify(); KEY_LEN];
        <Self as super::arrayref::Aead<KEY_LEN, TAG_LEN, NONCE_LEN>>::keygen(&mut key, rand)?;
        Ok(key)
    }

    fn encrypt<const MSG_LEN: usize>(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        plaintext: &[U8; MSG_LEN],
    ) -> Result<([u8; MSG_LEN], [U8; TAG_LEN]), EncryptError> {
        let mut ciphertext = [0u8; MSG_LEN];
        let mut tag = [0u8.classify(); TAG_LEN];
        <Self as super::arrayref::Aead<KEY_LEN, TAG_LEN, NONCE_LEN>>::encrypt(
            &mut ciphertext,
            &mut tag,
            key,
            nonce,
            aad,
            plaintext,
        )
        .map(|()| (ciphertext, tag))
    }

    #[cfg(feature = "alloc")]
    fn encrypt_to_vec(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        plaintext: &[U8],
    ) -> Result<(alloc::vec::Vec<u8>, [U8; TAG_LEN]), EncryptError> {
        let mut ciphertext = alloc::vec![0u8; plaintext.len()];
        // The tag buffer must be classified, exactly as in `encrypt`, because
        // the same `arrayref` call fills it.
        let mut tag = [0u8.classify(); TAG_LEN];
        <Self as super::arrayref::Aead<KEY_LEN, TAG_LEN, NONCE_LEN>>::encrypt(
            &mut ciphertext,
            &mut tag,
            key,
            nonce,
            aad,
            plaintext,
        )
        .map(|()| (ciphertext, tag))
    }

    fn decrypt<const MSG_LEN: usize>(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        ciphertext: &[u8; MSG_LEN],
        tag: &[U8; TAG_LEN],
    ) -> Result<[U8; MSG_LEN], DecryptError> {
        let mut plaintext = [0u8.classify(); MSG_LEN];
        <Self as super::arrayref::Aead<KEY_LEN, TAG_LEN, NONCE_LEN>>::decrypt(
            &mut plaintext,
            key,
            nonce,
            aad,
            ciphertext,
            tag,
        )
        .map(|()| plaintext)
    }

    #[cfg(feature = "alloc")]
    fn decrypt_to_vec(
        key: &[U8; KEY_LEN],
        nonce: &[U8; NONCE_LEN],
        aad: &[u8],
        ciphertext: &[u8],
        tag: &[U8; TAG_LEN],
    ) -> Result<alloc::vec::Vec<U8>, DecryptError> {
        // The plaintext buffer is classified to match the declared `Vec<U8>`
        // return type and the buffer that `decrypt` hands to `arrayref`.
        let mut plaintext = alloc::vec![0u8.classify(); ciphertext.len()];
        <Self as super::arrayref::Aead<KEY_LEN, TAG_LEN, NONCE_LEN>>::decrypt(
            &mut plaintext,
            key,
            nonce,
            aad,
            ciphertext,
            tag,
        )
        .map(|()| plaintext)
    }
}