use aes_gcm::aead::Aead;
use aes_gcm::{Aes256Gcm, KeyInit};
use chacha20poly1305::{ChaCha20Poly1305, Nonce};
use crate::config;
use crate::error::{CryptoError, SrxError};
/// An AEAD cipher whose algorithm is chosen at runtime.
///
/// Each variant wraps one already-keyed backend cipher; instances are built
/// with `AeadCipher::new`.
pub enum AeadCipher {
/// ChaCha20-Poly1305 backend.
ChaCha(ChaCha20Poly1305),
/// AES-256-GCM backend, held behind a `Box`.
// NOTE(review): presumably boxed to keep the enum itself small — confirm.
Aes(Box<Aes256Gcm>),
}
impl AeadCipher {
pub fn new(variant: config::AeadCipher, key: &[u8]) -> crate::error::Result<Self> {
if key.len() != 32 {
return Err(SrxError::Crypto(CryptoError::EncryptionFailed(format!(
"key must be 32 bytes, got {}",
key.len()
))));
}
match variant {
config::AeadCipher::ChaCha20Poly1305 => {
let cipher = ChaCha20Poly1305::new_from_slice(key)
.map_err(|e| SrxError::Crypto(CryptoError::EncryptionFailed(e.to_string())))?;
Ok(AeadCipher::ChaCha(cipher))
}
config::AeadCipher::Aes256Gcm => {
let cipher = Aes256Gcm::new_from_slice(key)
.map_err(|e| SrxError::Crypto(CryptoError::EncryptionFailed(e.to_string())))?;
Ok(AeadCipher::Aes(Box::new(cipher)))
}
}
}
pub fn encrypt(&self, nonce: &[u8; 12], plaintext: &[u8]) -> crate::error::Result<Vec<u8>> {
let n: Nonce = (*nonce).into();
match self {
AeadCipher::ChaCha(cipher) => cipher
.encrypt(&n, plaintext)
.map_err(|e| SrxError::Crypto(CryptoError::EncryptionFailed(e.to_string()))),
AeadCipher::Aes(cipher) => cipher
.encrypt(&n, plaintext)
.map_err(|e| SrxError::Crypto(CryptoError::EncryptionFailed(e.to_string()))),
}
}
pub fn decrypt(&self, nonce: &[u8; 12], ciphertext: &[u8]) -> crate::error::Result<Vec<u8>> {
let n: Nonce = (*nonce).into();
match self {
AeadCipher::ChaCha(cipher) => cipher
.decrypt(&n, ciphertext)
.map_err(|_| SrxError::Crypto(CryptoError::MacVerificationFailed)),
AeadCipher::Aes(cipher) => cipher
.decrypt(&n, ciphertext)
.map_err(|_| SrxError::Crypto(CryptoError::MacVerificationFailed)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::AeadCipher as AeadVariant;

    /// Key used by every test that does not need a second, different key.
    const KEY: [u8; 32] = [0x42; 32];
    /// Nonce used by every test that does not need a second, different nonce.
    const NONCE: [u8; 12] = [0x01; 12];

    /// Builds a cipher for `variant` with `key`, panicking if construction fails.
    fn cipher_with(variant: AeadVariant, key: &[u8]) -> AeadCipher {
        AeadCipher::new(variant, key).unwrap()
    }

    /// Encrypts then decrypts a fixed message with `variant` and checks that the
    /// ciphertext differs from the message and the plaintext is recovered.
    fn roundtrip(variant: AeadVariant) {
        let message: &[u8] = b"Hello, SRX protocol!";
        let cipher = cipher_with(variant, &KEY);

        let sealed = cipher.encrypt(&NONCE, message).unwrap();
        assert_ne!(sealed.as_slice(), message);

        let opened = cipher.decrypt(&NONCE, &sealed).unwrap();
        assert_eq!(opened, message);
    }

    #[test]
    fn test_chacha20_roundtrip() {
        roundtrip(AeadVariant::ChaCha20Poly1305);
    }

    #[test]
    fn test_aes256gcm_roundtrip() {
        roundtrip(AeadVariant::Aes256Gcm);
    }

    #[test]
    fn test_wrong_key_fails_decrypt() {
        let other_key = [0x43u8; 32];
        let sender = cipher_with(AeadVariant::ChaCha20Poly1305, &KEY);
        let receiver = cipher_with(AeadVariant::ChaCha20Poly1305, &other_key);

        let sealed = sender.encrypt(&NONCE, b"secret data").unwrap();
        assert!(receiver.decrypt(&NONCE, &sealed).is_err());
    }

    #[test]
    fn test_wrong_nonce_fails_decrypt() {
        let other_nonce = [0x02u8; 12];
        let cipher = cipher_with(AeadVariant::ChaCha20Poly1305, &KEY);

        let sealed = cipher.encrypt(&NONCE, b"secret data").unwrap();
        assert!(cipher.decrypt(&other_nonce, &sealed).is_err());
    }

    #[test]
    fn test_invalid_key_length() {
        let short_key = [0u8; 16];
        let outcome = AeadCipher::new(AeadVariant::ChaCha20Poly1305, &short_key);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_tampered_ciphertext() {
        let cipher = cipher_with(AeadVariant::Aes256Gcm, &KEY);

        // Flip every bit of the first byte of the encrypted output.
        let mut tampered = cipher.encrypt(&NONCE, b"data").unwrap();
        tampered[0] ^= 0xFF;

        assert!(cipher.decrypt(&NONCE, &tampered).is_err());
    }

    #[test]
    fn test_empty_plaintext() {
        let cipher = cipher_with(AeadVariant::ChaCha20Poly1305, &KEY);

        let sealed = cipher.encrypt(&NONCE, b"").unwrap();
        let opened = cipher.decrypt(&NONCE, &sealed).unwrap();
        assert!(opened.is_empty());
    }
}