use alloc::vec::Vec;
#[cfg_attr(
any(feature = "aead-chacha20", feature = "aead-aes-gcm"),
allow(unused_imports)
)]
use crate::error::{Error, Result};
#[cfg(feature = "aead-aes-gcm")]
mod aes_gcm;
#[cfg(feature = "aead-chacha20")]
mod chacha20;
/// Key length, in bytes, shared by every supported algorithm (256-bit keys).
pub const KEY_LEN: usize = 256 / 8;
/// Nonce length, in bytes, for ChaCha20-Poly1305 (96 bits).
pub const CHACHA20_NONCE_LEN: usize = 96 / 8;
/// Authentication-tag length, in bytes, for ChaCha20-Poly1305 (128 bits).
pub const CHACHA20_TAG_LEN: usize = 128 / 8;
/// Nonce length, in bytes, for AES-256-GCM (96 bits).
pub const AES_GCM_NONCE_LEN: usize = 96 / 8;
/// Authentication-tag length, in bytes, for AES-256-GCM (128 bits).
pub const AES_GCM_TAG_LEN: usize = 128 / 8;

/// The AEAD constructions this module can dispatch to.
///
/// `ChaCha20Poly1305` is the [`Default`]. The enum is `#[non_exhaustive]`, so
/// code outside this crate must include a wildcard arm when matching on it.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum Algorithm {
    /// ChaCha20 stream cipher with a Poly1305 authenticator (the default).
    #[default]
    ChaCha20Poly1305,
    /// AES with a 256-bit key in Galois/Counter Mode.
    Aes256Gcm,
}

impl Algorithm {
    /// Key size in bytes; every variant uses [`KEY_LEN`].
    #[must_use]
    pub const fn key_len(self) -> usize {
        match self {
            Self::Aes256Gcm | Self::ChaCha20Poly1305 => KEY_LEN,
        }
    }

    /// Nonce size in bytes for this algorithm.
    #[must_use]
    pub const fn nonce_len(self) -> usize {
        match self {
            Self::Aes256Gcm => AES_GCM_NONCE_LEN,
            Self::ChaCha20Poly1305 => CHACHA20_NONCE_LEN,
        }
    }

    /// Authentication-tag size in bytes for this algorithm.
    #[must_use]
    pub const fn tag_len(self) -> usize {
        match self {
            Self::Aes256Gcm => AES_GCM_TAG_LEN,
            Self::ChaCha20Poly1305 => CHACHA20_TAG_LEN,
        }
    }

    /// Human-readable name, e.g. `"AES-256-GCM"`.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Aes256Gcm => "AES-256-GCM",
            Self::ChaCha20Poly1305 => "ChaCha20-Poly1305",
        }
    }
}
/// AEAD front-end that encrypts and decrypts byte buffers with one [`Algorithm`].
///
/// The only state is the selected algorithm; keys and associated data are
/// supplied on every call. `Hash` is derived alongside `Eq` to match
/// [`Algorithm`], so a `Crypt` can be used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Crypt {
    /// Backend that the `encrypt*` / `decrypt*` methods dispatch to.
    algorithm: Algorithm,
}
impl Crypt {
    /// Creates a `Crypt` that uses ChaCha20-Poly1305.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_algorithm(Algorithm::ChaCha20Poly1305)
    }

    /// Creates a `Crypt` bound to `algorithm`.
    ///
    /// This never fails, even if the Cargo feature that provides `algorithm`
    /// is disabled. In that case every `encrypt*` / `decrypt*` call returns
    /// [`Error::AlgorithmNotEnabled`].
    #[must_use]
    pub const fn with_algorithm(algorithm: Algorithm) -> Self {
        Self { algorithm }
    }

    /// Creates a `Crypt` that uses AES-256-GCM.
    ///
    /// Only compiled when the `aead-aes-gcm` feature is enabled.
    #[cfg(feature = "aead-aes-gcm")]
    #[must_use]
    pub const fn aes_256_gcm() -> Self {
        Self::with_algorithm(Algorithm::Aes256Gcm)
    }

    /// Returns the algorithm this instance dispatches to.
    #[must_use]
    pub const fn algorithm(&self) -> Algorithm {
        self.algorithm
    }

    /// Encrypts `plaintext` under `key` with empty associated data.
    ///
    /// # Errors
    ///
    /// Same as [`Crypt::encrypt_with_aad`].
    pub fn encrypt(&self, key: &[u8], plaintext: &[u8]) -> Result<Vec<u8>> {
        self.encrypt_with_aad(key, plaintext, &[])
    }

    /// Encrypts `plaintext` under `key`, binding `aad` into the authentication
    /// tag so that decryption only succeeds with the same `aad`.
    ///
    /// # Errors
    ///
    /// - [`Error::AlgorithmNotEnabled`] (carrying the missing feature name) if
    ///   the selected backend was not compiled in.
    /// - Any error reported by the backend, e.g. [`Error::InvalidKey`] when
    ///   `key` is not [`KEY_LEN`] bytes long.
    pub fn encrypt_with_aad(&self, key: &[u8], plaintext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
        match self.algorithm {
            #[cfg(feature = "aead-chacha20")]
            Algorithm::ChaCha20Poly1305 => chacha20::encrypt(key, plaintext, aad),
            #[cfg(not(feature = "aead-chacha20"))]
            Algorithm::ChaCha20Poly1305 => {
                // Backend compiled out: consume the arguments to keep
                // unused-variable lints quiet, then report the missing feature.
                let _ = (key, plaintext, aad);
                Err(Error::AlgorithmNotEnabled("aead-chacha20"))
            }
            #[cfg(feature = "aead-aes-gcm")]
            Algorithm::Aes256Gcm => aes_gcm::encrypt(key, plaintext, aad),
            #[cfg(not(feature = "aead-aes-gcm"))]
            Algorithm::Aes256Gcm => {
                let _ = (key, plaintext, aad);
                Err(Error::AlgorithmNotEnabled("aead-aes-gcm"))
            }
        }
    }

    /// Decrypts a buffer produced by [`Crypt::encrypt`] (empty associated data).
    ///
    /// # Errors
    ///
    /// Same as [`Crypt::decrypt_with_aad`].
    pub fn decrypt(&self, key: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>> {
        self.decrypt_with_aad(key, ciphertext, &[])
    }

    /// Decrypts and authenticates `ciphertext` under `key` and `aad`.
    ///
    /// # Errors
    ///
    /// - [`Error::AlgorithmNotEnabled`] if the selected backend was not compiled in.
    /// - Backend errors, for example:
    ///   - [`Error::InvalidKey`] for a key of the wrong length;
    ///   - [`Error::InvalidCiphertext`] for input shorter than nonce + tag;
    ///   - [`Error::AuthenticationFailed`] when the key, ciphertext, tag or `aad`
    ///     do not match what was used for encryption.
    pub fn decrypt_with_aad(&self, key: &[u8], ciphertext: &[u8], aad: &[u8]) -> Result<Vec<u8>> {
        match self.algorithm {
            #[cfg(feature = "aead-chacha20")]
            Algorithm::ChaCha20Poly1305 => chacha20::decrypt(key, ciphertext, aad),
            #[cfg(not(feature = "aead-chacha20"))]
            Algorithm::ChaCha20Poly1305 => {
                let _ = (key, ciphertext, aad);
                Err(Error::AlgorithmNotEnabled("aead-chacha20"))
            }
            #[cfg(feature = "aead-aes-gcm")]
            Algorithm::Aes256Gcm => aes_gcm::decrypt(key, ciphertext, aad),
            #[cfg(not(feature = "aead-aes-gcm"))]
            Algorithm::Aes256Gcm => {
                let _ = (key, ciphertext, aad);
                Err(Error::AlgorithmNotEnabled("aead-aes-gcm"))
            }
        }
    }
}
impl Default for Crypt {
fn default() -> Self {
Self::new()
}
}
/// Behavioural tests for the ChaCha20-Poly1305 backend.
#[cfg(all(test, feature = "aead-chacha20"))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use alloc::vec;

    // Every test in this module exercises the default (ChaCha20-Poly1305) instance.
    fn chacha() -> Crypt {
        Crypt::new()
    }

    #[test]
    fn algorithm_metadata_matches_constants() {
        let alg = Algorithm::default();
        assert_eq!(alg, Algorithm::ChaCha20Poly1305);
        assert_eq!(alg.name(), "ChaCha20-Poly1305");
        assert_eq!(alg.key_len(), KEY_LEN);
        assert_eq!(alg.nonce_len(), CHACHA20_NONCE_LEN);
        assert_eq!(alg.tag_len(), CHACHA20_TAG_LEN);
    }

    #[test]
    fn crypt_defaults_to_chacha20() {
        for built in [Crypt::new(), Crypt::default()] {
            assert_eq!(built.algorithm(), Algorithm::ChaCha20Poly1305);
        }
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let crypt = chacha();
        let key = [0x11u8; 32];
        let sealed = crypt.encrypt(&key, b"").unwrap();
        // With nothing to encrypt, the output is exactly nonce + tag.
        assert_eq!(sealed.len(), CHACHA20_NONCE_LEN + CHACHA20_TAG_LEN);
        assert!(crypt.decrypt(&key, &sealed).unwrap().is_empty());
    }

    #[test]
    fn round_trip_short_plaintext() {
        let crypt = chacha();
        let key = [0x22u8; 32];
        let message: &[u8] = b"hello, world!";
        let sealed = crypt.encrypt(&key, message).unwrap();
        let opened = crypt.decrypt(&key, &sealed).unwrap();
        assert_eq!(opened.as_slice(), message);
    }

    #[test]
    fn round_trip_one_megabyte() {
        let crypt = chacha();
        let key = [0x33u8; 32];
        let payload = vec![0xa5u8; 1 << 20];
        let sealed = crypt.encrypt(&key, &payload).unwrap();
        assert_eq!(crypt.decrypt(&key, &sealed).unwrap(), payload);
    }

    #[test]
    fn two_encryptions_of_same_plaintext_differ() {
        let crypt = chacha();
        let key = [0u8; 32];
        let message = b"deterministic? no.";
        let (first, second) = (
            crypt.encrypt(&key, message).unwrap(),
            crypt.encrypt(&key, message).unwrap(),
        );
        assert_ne!(first, second, "nonce-prepended outputs must differ across calls");
    }

    #[test]
    fn wrong_key_fails_authentication() {
        let crypt = chacha();
        let sealed = crypt.encrypt(&[0x44u8; 32], b"secret").unwrap();
        assert_eq!(
            crypt.decrypt(&[0x55u8; 32], &sealed).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn tampered_ciphertext_fails_authentication() {
        let crypt = chacha();
        let key = [0x66u8; 32];
        let mut sealed = crypt.encrypt(&key, b"hands off").unwrap();
        let middle = sealed.len() / 2;
        sealed[middle] ^= 0x01;
        assert_eq!(
            crypt.decrypt(&key, &sealed).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn tampered_tag_fails_authentication() {
        let crypt = chacha();
        let key = [0x77u8; 32];
        let mut sealed = crypt.encrypt(&key, b"sign me").unwrap();
        // Flip every bit of the final byte.
        *sealed.last_mut().unwrap() ^= 0xff;
        assert_eq!(
            crypt.decrypt(&key, &sealed).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn truncated_ciphertext_is_rejected() {
        let crypt = chacha();
        let key = [0u8; 32];
        let minimum = CHACHA20_NONCE_LEN + CHACHA20_TAG_LEN;
        for len in 0..minimum {
            let short = vec![0u8; len];
            let err = crypt.decrypt(&key, &short).unwrap_err();
            assert!(
                matches!(err, Error::InvalidCiphertext(_)),
                "len={len} should error"
            );
        }
    }

    #[test]
    fn aad_round_trip() {
        let crypt = chacha();
        let key = [0x88u8; 32];
        let (message, aad): (&[u8], &[u8]) = (b"plaintext", b"associated");
        let sealed = crypt.encrypt_with_aad(&key, message, aad).unwrap();
        let opened = crypt.decrypt_with_aad(&key, &sealed, aad).unwrap();
        assert_eq!(opened.as_slice(), message);
    }

    #[test]
    fn aad_mismatch_fails_authentication() {
        let crypt = chacha();
        let key = [0x99u8; 32];
        let sealed = crypt
            .encrypt_with_aad(&key, b"body", b"original-aad")
            .unwrap();
        assert_eq!(
            crypt
                .decrypt_with_aad(&key, &sealed, b"tampered-aad")
                .unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn encrypt_with_aad_then_decrypt_without_aad_fails() {
        let crypt = chacha();
        let key = [0xaau8; 32];
        let sealed = crypt.encrypt_with_aad(&key, b"body", b"required").unwrap();
        assert_eq!(
            crypt.decrypt(&key, &sealed).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn invalid_key_length_rejected_on_encrypt() {
        let expected = Error::InvalidKey {
            expected: 32,
            actual: 16,
        };
        assert_eq!(chacha().encrypt(&[0u8; 16], b"x").unwrap_err(), expected);
    }

    #[test]
    fn invalid_key_length_rejected_on_decrypt() {
        let crypt = chacha();
        let sealed = crypt.encrypt(&[0u8; 32], b"x").unwrap();
        let expected = Error::InvalidKey {
            expected: 32,
            actual: 16,
        };
        assert_eq!(crypt.decrypt(&[0u8; 16], &sealed).unwrap_err(), expected);
    }
}
/// Behavioural tests for the AES-256-GCM backend.
///
/// This covers the same cases as the ChaCha20-Poly1305 suite, including the
/// AAD-omission and decrypt-side key-length checks.
#[cfg(all(test, feature = "aead-aes-gcm"))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod aes_gcm_tests {
    use super::*;
    use alloc::vec;

    // Every test in this module exercises an AES-256-GCM instance.
    fn aes() -> Crypt {
        Crypt::aes_256_gcm()
    }

    #[test]
    fn algorithm_metadata_matches_constants() {
        let a = Algorithm::Aes256Gcm;
        assert_eq!(a.key_len(), KEY_LEN);
        assert_eq!(a.nonce_len(), AES_GCM_NONCE_LEN);
        assert_eq!(a.tag_len(), AES_GCM_TAG_LEN);
        assert_eq!(a.name(), "AES-256-GCM");
    }

    #[test]
    fn aes_256_gcm_constructor_selects_algorithm() {
        let c = aes();
        assert_eq!(c.algorithm(), Algorithm::Aes256Gcm);
        let alt = Crypt::with_algorithm(Algorithm::Aes256Gcm);
        assert_eq!(c, alt);
    }

    #[test]
    fn round_trip_empty_plaintext() {
        let crypt = aes();
        let key = [0x11u8; 32];
        let ciphertext = crypt.encrypt(&key, b"").unwrap();
        // With nothing to encrypt, the output is exactly nonce + tag.
        assert_eq!(ciphertext.len(), AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN);
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert!(recovered.is_empty());
    }

    #[test]
    fn round_trip_short_plaintext() {
        let crypt = aes();
        let key = [0x22u8; 32];
        let plaintext = b"hello, world!";
        let ciphertext = crypt.encrypt(&key, plaintext).unwrap();
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    #[test]
    fn round_trip_one_megabyte() {
        let crypt = aes();
        let key = [0x33u8; 32];
        let plaintext = vec![0xa5u8; 1024 * 1024];
        let ciphertext = crypt.encrypt(&key, &plaintext).unwrap();
        let recovered = crypt.decrypt(&key, &ciphertext).unwrap();
        assert_eq!(recovered, plaintext);
    }

    #[test]
    fn two_encryptions_of_same_plaintext_differ() {
        let crypt = aes();
        let key = [0u8; 32];
        let plaintext = b"deterministic? no.";
        let a = crypt.encrypt(&key, plaintext).unwrap();
        let b = crypt.encrypt(&key, plaintext).unwrap();
        assert_ne!(a, b, "nonce-prepended outputs must differ across calls");
    }

    #[test]
    fn wrong_key_fails_authentication() {
        let crypt = aes();
        let key = [0x44u8; 32];
        let wrong = [0x55u8; 32];
        let ciphertext = crypt.encrypt(&key, b"secret").unwrap();
        let err = crypt.decrypt(&wrong, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_ciphertext_fails_authentication() {
        let crypt = aes();
        let key = [0x66u8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"hands off").unwrap();
        let i = ciphertext.len() / 2;
        ciphertext[i] ^= 0x01;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn tampered_tag_fails_authentication() {
        let crypt = aes();
        let key = [0x77u8; 32];
        let mut ciphertext = crypt.encrypt(&key, b"sign me").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0xff;
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn truncated_ciphertext_is_rejected() {
        let crypt = aes();
        let key = [0u8; 32];
        for len in 0..(AES_GCM_NONCE_LEN + AES_GCM_TAG_LEN) {
            let err = crypt.decrypt(&key, &vec![0u8; len]).unwrap_err();
            assert!(
                matches!(err, Error::InvalidCiphertext(_)),
                "len={len} should error"
            );
        }
    }

    #[test]
    fn aad_round_trip() {
        let crypt = aes();
        let key = [0x88u8; 32];
        let plaintext = b"plaintext";
        let aad = b"associated";
        let ciphertext = crypt.encrypt_with_aad(&key, plaintext, aad).unwrap();
        let recovered = crypt.decrypt_with_aad(&key, &ciphertext, aad).unwrap();
        assert_eq!(&*recovered, plaintext);
    }

    #[test]
    fn aad_mismatch_fails_authentication() {
        let crypt = aes();
        let key = [0x99u8; 32];
        let ciphertext = crypt
            .encrypt_with_aad(&key, b"body", b"original-aad")
            .unwrap();
        let err = crypt
            .decrypt_with_aad(&key, &ciphertext, b"tampered-aad")
            .unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    // Added for parity with the ChaCha20-Poly1305 suite: dropping the AAD on
    // decrypt must be detected just like changing it.
    #[test]
    fn encrypt_with_aad_then_decrypt_without_aad_fails() {
        let crypt = aes();
        let key = [0xaau8; 32];
        let ciphertext = crypt.encrypt_with_aad(&key, b"body", b"required").unwrap();
        let err = crypt.decrypt(&key, &ciphertext).unwrap_err();
        assert_eq!(err, Error::AuthenticationFailed);
    }

    #[test]
    fn invalid_key_length_rejected_on_encrypt() {
        let crypt = aes();
        let err = crypt.encrypt(&[0u8; 16], b"x").unwrap_err();
        assert_eq!(
            err,
            Error::InvalidKey {
                expected: 32,
                actual: 16
            }
        );
    }

    // Added for parity with the ChaCha20-Poly1305 suite: the decrypt path must
    // validate key length too, not only the encrypt path.
    #[test]
    fn invalid_key_length_rejected_on_decrypt() {
        let crypt = aes();
        let ciphertext = crypt.encrypt(&[0u8; 32], b"x").unwrap();
        let err = crypt.decrypt(&[0u8; 16], &ciphertext).unwrap_err();
        assert_eq!(
            err,
            Error::InvalidKey {
                expected: 32,
                actual: 16
            }
        );
    }
}
/// Tests that need both backends: ciphertexts must not be interchangeable.
#[cfg(all(test, feature = "aead-chacha20", feature = "aead-aes-gcm"))]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod cross_algorithm_tests {
    use super::*;

    // Seals `b"message"` with `sealer` and asserts that `opener` rejects the
    // result as unauthenticated.
    fn assert_not_interchangeable(key: &[u8; 32], sealer: Crypt, opener: Crypt) {
        let ct = sealer.encrypt(key, b"message").unwrap();
        assert_eq!(
            opener.decrypt(key, &ct).unwrap_err(),
            Error::AuthenticationFailed
        );
    }

    #[test]
    fn chacha_ciphertext_does_not_decrypt_as_aes() {
        assert_not_interchangeable(&[0xcdu8; 32], Crypt::new(), Crypt::aes_256_gcm());
    }

    #[test]
    fn aes_ciphertext_does_not_decrypt_as_chacha() {
        assert_not_interchangeable(&[0xefu8; 32], Crypt::aes_256_gcm(), Crypt::new());
    }

    #[test]
    fn algorithm_name_table_is_unique() {
        let names = [
            Algorithm::ChaCha20Poly1305.name(),
            Algorithm::Aes256Gcm.name(),
        ];
        // Compare each unordered pair exactly once.
        for (i, a) in names.iter().enumerate() {
            for b in &names[i + 1..] {
                assert_ne!(a, b, "algorithm names must be distinct");
            }
        }
    }
}