#[cfg(any(feature = "pure-rust", target_arch = "wasm32", target_arch = "wasm64"))]
use superboring as boring;
use boring::symm::{Cipher, Crypter, Mode};
use rand::RngCore;
use zeroize::Zeroize;
use crate::error::*;
/// AES-GCM content-encryption algorithms (JWE `enc` values) supported here.
///
/// The derived `Default` is [`ContentEncryption::A256GCM`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ContentEncryption {
    /// AES-GCM with a 256-bit (32-byte) key.
    #[default]
    A256GCM,
    /// AES-GCM with a 128-bit (16-byte) key.
    A128GCM,
}
impl ContentEncryption {
pub fn alg_name(&self) -> &'static str {
match self {
ContentEncryption::A256GCM => "A256GCM",
ContentEncryption::A128GCM => "A128GCM",
}
}
pub fn from_alg_name(name: &str) -> Result<Self, Error> {
match name {
"A256GCM" => Ok(ContentEncryption::A256GCM),
"A128GCM" => Ok(ContentEncryption::A128GCM),
_ => bail!(JWTError::UnsupportedContentEncryption(name.to_string())),
}
}
pub fn key_size(&self) -> usize {
match self {
ContentEncryption::A256GCM => 32,
ContentEncryption::A128GCM => 16,
}
}
pub fn iv_size(&self) -> usize {
12 }
pub fn tag_size(&self) -> usize {
16 }
pub fn generate_cek(&self) -> Vec<u8> {
let mut cek = vec![0u8; self.key_size()];
rand::thread_rng().fill_bytes(&mut cek);
cek
}
pub fn generate_iv(&self) -> Vec<u8> {
let mut iv = vec![0u8; self.iv_size()];
rand::thread_rng().fill_bytes(&mut iv);
iv
}
fn cipher(&self) -> Cipher {
match self {
ContentEncryption::A256GCM => Cipher::aes_256_gcm(),
ContentEncryption::A128GCM => Cipher::aes_128_gcm(),
}
}
pub fn encrypt(
&self,
cek: &[u8],
iv: &[u8],
aad: &[u8],
plaintext: &[u8],
) -> Result<(Vec<u8>, Vec<u8>), Error> {
ensure!(cek.len() == self.key_size(), JWTError::InvalidEncryptionKey);
ensure!(iv.len() == self.iv_size(), JWTError::InvalidIV);
let cipher = self.cipher();
let mut crypter = Crypter::new(cipher, Mode::Encrypt, cek, Some(iv))?;
crypter.aad_update(aad)?;
let mut ciphertext = vec![0u8; plaintext.len() + cipher.block_size()];
let mut count = crypter.update(plaintext, &mut ciphertext)?;
count += crypter.finalize(&mut ciphertext[count..])?;
ciphertext.truncate(count);
let mut tag = vec![0u8; self.tag_size()];
crypter.get_tag(&mut tag)?;
Ok((ciphertext, tag))
}
pub fn decrypt(
&self,
cek: &[u8],
iv: &[u8],
aad: &[u8],
ciphertext: &[u8],
tag: &[u8],
) -> Result<Vec<u8>, Error> {
ensure!(cek.len() == self.key_size(), JWTError::InvalidEncryptionKey);
ensure!(iv.len() == self.iv_size(), JWTError::InvalidIV);
ensure!(tag.len() == self.tag_size(), JWTError::InvalidJWEAuthTag);
let cipher = self.cipher();
let mut crypter = Crypter::new(cipher, Mode::Decrypt, cek, Some(iv))?;
crypter.aad_update(aad)?;
crypter.set_tag(tag)?;
let mut plaintext = vec![0u8; ciphertext.len() + cipher.block_size()];
let mut count = crypter.update(ciphertext, &mut plaintext)?;
count += crypter
.finalize(&mut plaintext[count..])
.map_err(|_| JWTError::DecryptionFailed)?;
plaintext.truncate(count);
Ok(plaintext)
}
}
#[derive(Clone)]
pub struct CEK {
key: Vec<u8>,
}
impl CEK {
pub fn new(key: Vec<u8>) -> Self {
CEK { key }
}
pub fn as_bytes(&self) -> &[u8] {
&self.key
}
}
impl Drop for CEK {
fn drop(&mut self) {
self.key.zeroize();
}
}
impl AsRef<[u8]> for CEK {
fn as_ref(&self) -> &[u8] {
&self.key
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const AAD: &[u8] = b"additional authenticated data";
    const PLAINTEXT: &[u8] = b"Hello, World!";

    /// Output of a successful `encrypt`, plus the key and IV needed to reverse it.
    struct Sealed {
        cek: Vec<u8>,
        iv: Vec<u8>,
        ciphertext: Vec<u8>,
        tag: Vec<u8>,
    }

    /// Encrypts `plaintext` with `AAD` under a freshly generated key and IV.
    fn seal(enc: ContentEncryption, plaintext: &[u8]) -> Sealed {
        let cek = enc.generate_cek();
        let iv = enc.generate_iv();
        let (ciphertext, tag) = enc.encrypt(&cek, &iv, AAD, plaintext).unwrap();
        Sealed {
            cek,
            iv,
            ciphertext,
            tag,
        }
    }

    /// Encrypts then decrypts `plaintext` and checks that the output matches and has the expected lengths.
    fn assert_roundtrip(enc: ContentEncryption, plaintext: &[u8]) {
        let s = seal(enc, plaintext);
        assert_eq!(s.ciphertext.len(), plaintext.len());
        assert_eq!(s.tag.len(), enc.tag_size());
        let decrypted = enc
            .decrypt(&s.cek, &s.iv, AAD, &s.ciphertext, &s.tag)
            .unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_a256gcm_roundtrip() {
        assert_roundtrip(ContentEncryption::A256GCM, PLAINTEXT);
    }

    #[test]
    fn test_a128gcm_roundtrip() {
        assert_roundtrip(ContentEncryption::A128GCM, PLAINTEXT);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        assert_roundtrip(ContentEncryption::A256GCM, b"");
        assert_roundtrip(ContentEncryption::A128GCM, b"");
    }

    #[test]
    fn test_generated_sizes() {
        for enc in [ContentEncryption::A256GCM, ContentEncryption::A128GCM] {
            assert_eq!(enc.generate_cek().len(), enc.key_size());
            assert_eq!(enc.generate_iv().len(), enc.iv_size());
        }
    }

    #[test]
    fn test_alg_name_roundtrip() {
        for enc in [ContentEncryption::A256GCM, ContentEncryption::A128GCM] {
            assert_eq!(ContentEncryption::from_alg_name(enc.alg_name()).unwrap(), enc);
        }
        assert!(ContentEncryption::from_alg_name("A192GCM").is_err());
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let enc = ContentEncryption::A256GCM;
        let mut s = seal(enc, PLAINTEXT);
        s.ciphertext[0] ^= 0xff;
        let result = enc.decrypt(&s.cek, &s.iv, AAD, &s.ciphertext, &s.tag);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_tag_fails() {
        let enc = ContentEncryption::A256GCM;
        let mut s = seal(enc, PLAINTEXT);
        s.tag[0] ^= 0x01;
        let result = enc.decrypt(&s.cek, &s.iv, AAD, &s.ciphertext, &s.tag);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_aad_fails() {
        let enc = ContentEncryption::A256GCM;
        let s = seal(enc, PLAINTEXT);
        let wrong_aad = b"wrong aad";
        let result = enc.decrypt(&s.cek, &s.iv, wrong_aad, &s.ciphertext, &s.tag);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_key_fails() {
        let enc = ContentEncryption::A256GCM;
        let s = seal(enc, PLAINTEXT);
        let wrong_cek = enc.generate_cek();
        let result = enc.decrypt(&wrong_cek, &s.iv, AAD, &s.ciphertext, &s.tag);
        assert!(result.is_err());
    }

    #[test]
    fn test_wrong_key_length_rejected() {
        let enc = ContentEncryption::A256GCM;
        let short_key = ContentEncryption::A128GCM.generate_cek();
        let iv = enc.generate_iv();
        assert!(enc.encrypt(&short_key, &iv, AAD, PLAINTEXT).is_err());
        let s = seal(enc, PLAINTEXT);
        assert!(enc
            .decrypt(&short_key, &s.iv, AAD, &s.ciphertext, &s.tag)
            .is_err());
    }

    #[test]
    fn test_wrong_iv_length_rejected() {
        let enc = ContentEncryption::A128GCM;
        let cek = enc.generate_cek();
        let long_iv = vec![0u8; enc.iv_size() + 4];
        assert!(enc.encrypt(&cek, &long_iv, AAD, PLAINTEXT).is_err());
        let s = seal(enc, PLAINTEXT);
        assert!(enc
            .decrypt(&s.cek, &long_iv, AAD, &s.ciphertext, &s.tag)
            .is_err());
    }

    #[test]
    fn test_wrong_tag_length_rejected() {
        let enc = ContentEncryption::A256GCM;
        let s = seal(enc, PLAINTEXT);
        let truncated_tag = &s.tag[..s.tag.len() - 1];
        assert!(enc
            .decrypt(&s.cek, &s.iv, AAD, &s.ciphertext, truncated_tag)
            .is_err());
    }
}