use crate::mls::{MlsError, Result};
use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
ChaCha20Poly1305, Nonce,
};
/// ChaCha20-Poly1305 AEAD context bound to one key and one base nonce.
///
/// Per-message nonces are derived from `base_nonce` and a caller-supplied
/// counter, so the same cipher can encrypt a sequence of messages.
///
/// `Debug` is implemented by hand so that formatting a cipher (e.g. in logs
/// or panic messages) never prints the secret key or base nonce bytes.
#[derive(Clone)]
pub struct MlsCipher {
    key: Vec<u8>,
    base_nonce: Vec<u8>,
}

impl std::fmt::Debug for MlsCipher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only lengths are shown; the bytes themselves are secret material.
        f.debug_struct("MlsCipher")
            .field("key", &format_args!("<redacted; {} bytes>", self.key.len()))
            .field(
                "base_nonce",
                &format_args!("<redacted; {} bytes>", self.base_nonce.len()),
            )
            .finish()
    }
}
impl MlsCipher {
#[must_use]
pub fn new(key: Vec<u8>, base_nonce: Vec<u8>) -> Self {
Self { key, base_nonce }
}
pub fn encrypt(&self, plaintext: &[u8], aad: &[u8], counter: u64) -> Result<Vec<u8>> {
let nonce = self.derive_nonce(counter);
let cipher = ChaCha20Poly1305::new_from_slice(&self.key)
.map_err(|e| MlsError::EncryptionError(format!("invalid key length: {}", e)))?;
let nonce_arr = Nonce::from_slice(&nonce[..12]);
let payload = Payload {
msg: plaintext,
aad,
};
cipher
.encrypt(nonce_arr, payload)
.map_err(|e| MlsError::EncryptionError(format!("encryption failed: {}", e)))
}
pub fn decrypt(&self, ciphertext: &[u8], aad: &[u8], counter: u64) -> Result<Vec<u8>> {
let nonce = self.derive_nonce(counter);
let cipher = ChaCha20Poly1305::new_from_slice(&self.key)
.map_err(|e| MlsError::EncryptionError(format!("invalid key length: {}", e)))?;
let nonce_arr = Nonce::from_slice(&nonce[..12]);
let payload = Payload {
msg: ciphertext,
aad,
};
cipher
.decrypt(nonce_arr, payload)
.map_err(|e| MlsError::DecryptionError(format!("decryption failed: {}", e)))
}
fn derive_nonce(&self, counter: u64) -> Vec<u8> {
let counter_bytes = counter.to_le_bytes();
let mut nonce = self.base_nonce.clone();
for (i, byte) in counter_bytes.iter().enumerate() {
if i + 4 < nonce.len() {
nonce[i + 4] ^= byte;
}
}
nonce
}
#[must_use]
pub fn key(&self) -> &[u8] {
&self.key
}
#[must_use]
pub fn base_nonce(&self) -> &[u8] {
&self.base_nonce
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the Poly1305 tag appended to every ciphertext.
    const TAG_LEN: usize = 16;

    /// All-zero 32-byte key; a deterministic input for the tests below.
    fn test_key() -> Vec<u8> {
        vec![0u8; 32]
    }

    /// All-zero 12-byte base nonce.
    fn test_nonce() -> Vec<u8> {
        vec![0u8; 12]
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"Hello, MLS!";
        let aad = b"additional data";
        let counter = 1;
        let ciphertext = cipher
            .encrypt(plaintext, aad, counter)
            .expect("encryption with a valid key and nonce succeeds");
        assert_eq!(ciphertext.len(), plaintext.len() + TAG_LEN);
        let decrypted = cipher
            .decrypt(&ciphertext, aad, counter)
            .expect("decryption of an untampered ciphertext succeeds");
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_authentication_tag_verification() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"secret message";
        let aad = b"context";
        let counter = 5;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();

        // Flip a bit in the encrypted body.
        let mut tampered_body = ciphertext.clone();
        tampered_body[0] ^= 0x01;
        let result = cipher.decrypt(&tampered_body, aad, counter);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));

        // Flip a bit in the trailing authentication tag.
        let mut tampered_tag = ciphertext;
        let last = tampered_tag.len() - 1;
        tampered_tag[last] ^= 0x01;
        let result = cipher.decrypt(&tampered_tag, aad, counter);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
    }

    #[test]
    fn test_truncated_ciphertext_fails() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        // Inputs shorter than the tag cannot be authenticated.
        let result = cipher.decrypt(&[], b"aad", 0);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
        let result = cipher.decrypt(&[0u8; TAG_LEN - 1], b"aad", 0);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
    }

    #[test]
    fn test_wrong_aad_fails() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"secret";
        let aad = b"original aad";
        let wrong_aad = b"wrong aad";
        let counter = 10;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();
        let result = cipher.decrypt(&ciphertext, wrong_aad, counter);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
    }

    #[test]
    fn test_wrong_counter_fails() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"data";
        let aad = b"aad";
        let counter = 42;
        let wrong_counter = 43;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();
        let result = cipher.decrypt(&ciphertext, aad, wrong_counter);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
    }

    #[test]
    fn test_wrong_key_fails() {
        let nonce = test_nonce();
        let sender = MlsCipher::new(vec![1u8; 32], nonce.clone());
        let receiver = MlsCipher::new(vec![2u8; 32], nonce);
        let ciphertext = sender.encrypt(b"data", b"aad", 3).unwrap();
        let result = receiver.decrypt(&ciphertext, b"aad", 3);
        assert!(matches!(result, Err(MlsError::DecryptionError(_))));
    }

    #[test]
    fn test_different_counters_produce_different_ciphertexts() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"same message";
        let aad = b"same aad";
        let ct1 = cipher.encrypt(plaintext, aad, 1).unwrap();
        let ct2 = cipher.encrypt(plaintext, aad, 2).unwrap();
        let ct3 = cipher.encrypt(plaintext, aad, 100).unwrap();
        assert_ne!(ct1, ct2);
        assert_ne!(ct2, ct3);
        assert_ne!(ct1, ct3);
    }

    #[test]
    fn test_empty_plaintext() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"";
        let aad = b"aad";
        let counter = 0;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();
        assert_eq!(ciphertext.len(), TAG_LEN);
        let decrypted = cipher.decrypt(&ciphertext, aad, counter).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_empty_aad() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"message";
        let aad = b"";
        let counter = 7;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, aad, counter).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_large_plaintext() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = vec![0x42u8; 10000];
        let aad = b"large message aad";
        let counter = 1000;
        let ciphertext = cipher.encrypt(&plaintext, aad, counter).unwrap();
        assert_eq!(ciphertext.len(), plaintext.len() + TAG_LEN);
        let decrypted = cipher.decrypt(&ciphertext, aad, counter).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_counter_zero() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"first message";
        let aad = b"aad";
        let counter = 0;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, aad, counter).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_counter_max() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let plaintext = b"last message";
        let aad = b"aad";
        let counter = u64::MAX;
        let ciphertext = cipher.encrypt(plaintext, aad, counter).unwrap();
        let decrypted = cipher.decrypt(&ciphertext, aad, counter).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_cipher_accessors() {
        let key = test_key();
        let nonce = test_nonce();
        let cipher = MlsCipher::new(key.clone(), nonce.clone());
        assert_eq!(cipher.key(), key.as_slice());
        assert_eq!(cipher.base_nonce(), nonce.as_slice());
    }

    #[test]
    fn test_nonce_derivation_deterministic() {
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let nonce1 = cipher.derive_nonce(42);
        let nonce2 = cipher.derive_nonce(42);
        assert_eq!(nonce1, nonce2);
        assert_ne!(cipher.derive_nonce(1), cipher.derive_nonce(2));
    }

    #[test]
    fn test_nonce_derivation_layout() {
        // The counter's little-endian bytes land in bytes 4..12; bytes 0..4 stay untouched.
        let cipher = MlsCipher::new(test_key(), test_nonce());
        let nonce = cipher.derive_nonce(0x0102_0304_0506_0708);
        assert_eq!(nonce, vec![0u8, 0, 0, 0, 8, 7, 6, 5, 4, 3, 2, 1]);
    }

    #[test]
    fn test_nonce_derivation_xors_base_nonce() {
        let cipher = MlsCipher::new(test_key(), vec![0xFFu8; 12]);
        let mut expected = vec![0xFFu8; 12];
        expected[4] = 0xFE;
        assert_eq!(cipher.derive_nonce(1), expected);
    }

    #[test]
    fn test_nonce_derivation_short_base_nonce() {
        // Counter bytes beyond the end of a short base nonce are dropped, not a panic.
        let cipher = MlsCipher::new(test_key(), vec![0u8; 6]);
        assert_eq!(cipher.derive_nonce(u64::MAX), vec![0u8, 0, 0, 0, 0xFF, 0xFF]);
    }

    #[test]
    fn test_different_keys_produce_different_ciphertexts() {
        let key1 = vec![1u8; 32];
        let key2 = vec![2u8; 32];
        let nonce = test_nonce();
        let cipher1 = MlsCipher::new(key1, nonce.clone());
        let cipher2 = MlsCipher::new(key2, nonce);
        let plaintext = b"test";
        let aad = b"aad";
        let counter = 1;
        let ct1 = cipher1.encrypt(plaintext, aad, counter).unwrap();
        let ct2 = cipher2.encrypt(plaintext, aad, counter).unwrap();
        assert_ne!(ct1, ct2);
    }
}