#[cfg(any(feature = "pure-rust", target_arch = "wasm32", target_arch = "wasm64"))]
use superboring as boring;
use boring::aes::{unwrap_key, wrap_key, AesKey};
use rand::RngCore;
use serde::{de::DeserializeOwned, Serialize};
use zeroize::Zeroize;
use crate::claims::*;
use crate::error::*;
use crate::jwe_header::JWEHeader;
use crate::jwe_token::{DecryptionOptions, EncryptionOptions, JWEToken, JWETokenMetadata};
/// Key-encryption key for the JWE `A256KW` algorithm (AES-256 Key Wrap).
///
/// The secret bytes are wiped by this type's `Drop` implementation.
#[derive(Clone)]
pub struct A256KWKey {
    /// Raw AES key material; the constructors only ever produce 32 bytes.
    key: Vec<u8>,
    /// Optional identifier used as the default `kid` header when encrypting.
    key_id: Option<String>,
}
impl std::fmt::Debug for A256KWKey {
    /// Debug output that deliberately omits the secret key bytes; only the
    /// optional key identifier is printed, followed by `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("A256KWKey");
        out.field("key_id", &self.key_id);
        out.finish_non_exhaustive()
    }
}
impl Drop for A256KWKey {
fn drop(&mut self) {
self.key.zeroize();
}
}
impl A256KWKey {
    /// Required length of the raw key material, in bytes (AES-256).
    const KEY_SIZE: usize = 32;
    /// Value of the JWE `alg` header for this key type.
    const ALG_NAME: &'static str = "A256KW";

    /// Creates a key from raw bytes.
    ///
    /// # Errors
    ///
    /// Returns `JWTError::InvalidEncryptionKey` unless `key` is exactly 32 bytes.
    pub fn from_bytes(key: &[u8]) -> Result<Self, Error> {
        ensure!(key.len() == Self::KEY_SIZE, JWTError::InvalidEncryptionKey);
        Ok(A256KWKey {
            key: key.to_vec(),
            key_id: None,
        })
    }

    /// Generates a new random 32-byte key, filled from `rand::thread_rng()`.
    pub fn generate() -> Self {
        let mut key = vec![0u8; Self::KEY_SIZE];
        rand::thread_rng().fill_bytes(&mut key);
        A256KWKey { key, key_id: None }
    }

    /// Returns a copy of the raw key bytes.
    ///
    /// The returned `Vec` is a plain copy and is NOT wiped on drop (unlike
    /// `self`), so the caller is responsible for handling it carefully.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.key.clone()
    }

    /// Sets the key identifier, used as the default `kid` header of tokens
    /// encrypted with this key.
    pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
        self.key_id = Some(key_id.into());
        self
    }

    /// Returns the key identifier, if one was set.
    pub fn key_id(&self) -> Option<&str> {
        self.key_id.as_deref()
    }

    /// Wraps a content-encryption key (CEK) with AES Key Wrap (RFC 3394),
    /// using the default IV (`None`). The result is `cek.len() + 8` bytes.
    ///
    /// # Errors
    ///
    /// Returns `JWTError::InvalidEncryptionKey` if the CEK is not a whole
    /// number of 64-bit blocks (at least two), or if the AES backend fails.
    pub(crate) fn wrap_key(&self, cek: &[u8]) -> Result<Vec<u8>, Error> {
        // RFC 3394 section 2 requires n >= 2 plaintext blocks of 64 bits.
        // Enforce this here so every AES backend (boring or superboring)
        // only ever sees well-formed input.
        ensure!(
            cek.len() >= 16 && cek.len() % 8 == 0,
            JWTError::InvalidEncryptionKey
        );
        let aes_key = AesKey::new_encrypt(&self.key).map_err(|_| JWTError::InvalidEncryptionKey)?;
        let mut wrapped = vec![0u8; cek.len() + 8];
        wrap_key(&aes_key, None, &mut wrapped, cek).map_err(|_| JWTError::InvalidEncryptionKey)?;
        Ok(wrapped)
    }

    /// Unwraps a CEK previously wrapped with AES Key Wrap (RFC 3394).
    /// The result is `wrapped.len() - 8` bytes.
    ///
    /// # Errors
    ///
    /// Returns `JWTError::KeyUnwrapFailed` if `wrapped` is not a valid
    /// ciphertext length or fails integrity checking, and
    /// `JWTError::InvalidEncryptionKey` if the AES key cannot be set up.
    pub(crate) fn unwrap_key(&self, wrapped: &[u8]) -> Result<Vec<u8>, Error> {
        // A wrapped key is n + 1 >= 3 blocks of 64 bits (RFC 3394 section 2.2.2).
        // Anything shorter or not block-aligned can never be valid, so reject it
        // before handing it to the backend.
        ensure!(
            wrapped.len() >= 24 && wrapped.len() % 8 == 0,
            JWTError::KeyUnwrapFailed
        );
        let aes_key = AesKey::new_decrypt(&self.key).map_err(|_| JWTError::InvalidEncryptionKey)?;
        let mut cek = vec![0u8; wrapped.len() - 8];
        unwrap_key(&aes_key, None, &mut cek, wrapped).map_err(|_| JWTError::KeyUnwrapFailed)?;
        Ok(cek)
    }

    /// Encrypts `claims` using `EncryptionOptions::default()`.
    pub fn encrypt<CustomClaims: Serialize>(
        &self,
        claims: JWTClaims<CustomClaims>,
    ) -> Result<String, Error> {
        self.encrypt_with_options(claims, &EncryptionOptions::default())
    }

    /// Encrypts `claims` into a JWE token with `A256KW` key management.
    ///
    /// Header fields:
    /// - `enc` comes from `options.content_encryption`.
    /// - `kid` is `options.key_id` if set, otherwise this key's own identifier.
    /// - `cty` is `options.content_type` if set.
    ///
    /// Token assembly is delegated to `JWEToken::build_from_claims`. It receives
    /// a callback that wraps the CEK it is given with this key.
    pub fn encrypt_with_options<CustomClaims: Serialize>(
        &self,
        claims: JWTClaims<CustomClaims>,
        options: &EncryptionOptions,
    ) -> Result<String, Error> {
        let content_encryption = options.content_encryption;
        let mut header = JWEHeader::new(Self::ALG_NAME, content_encryption.alg_name());
        if let Some(key_id) = &self.key_id {
            header.key_id = Some(key_id.clone());
        }
        // An explicit per-call key id overrides the key's own identifier.
        if let Some(key_id) = &options.key_id {
            header.key_id = Some(key_id.clone());
        }
        if let Some(cty) = &options.content_type {
            header.content_type = Some(cty.clone());
        }
        JWEToken::build_from_claims(&header, &claims, content_encryption, |cek| {
            self.wrap_key(cek)
        })
    }

    /// Decrypts a JWE token and returns its claims.
    ///
    /// `"A256KW"` is passed to `JWEToken::decrypt` as the expected algorithm.
    /// The token's encrypted key is unwrapped with this key. `options` is
    /// forwarded unchanged.
    pub fn decrypt_token<CustomClaims: DeserializeOwned>(
        &self,
        token: &str,
        options: Option<DecryptionOptions>,
    ) -> Result<JWTClaims<CustomClaims>, Error> {
        JWEToken::decrypt(Self::ALG_NAME, token, options, |_header, encrypted_key| {
            self.unwrap_key(encrypted_key)
        })
    }

    /// Decodes a token's metadata without using any key material.
    ///
    /// Delegates to `JWEToken::decode_metadata`.
    /// NOTE(review): the result is presumably unauthenticated; confirm before
    /// trusting it (e.g. only use it to select a key).
    pub fn decode_metadata(token: &str) -> Result<JWETokenMetadata, Error> {
        JWEToken::decode_metadata(token)
    }
}
/// Key-encryption key for the JWE `A128KW` algorithm (AES-128 Key Wrap).
///
/// The secret bytes are wiped by this type's `Drop` implementation.
#[derive(Clone)]
pub struct A128KWKey {
    /// Raw AES key material; the constructors only ever produce 16 bytes.
    key: Vec<u8>,
    /// Optional identifier used as the default `kid` header when encrypting.
    key_id: Option<String>,
}
impl std::fmt::Debug for A128KWKey {
    /// Debug output that deliberately omits the secret key bytes; only the
    /// optional key identifier is printed, followed by `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("A128KWKey");
        out.field("key_id", &self.key_id);
        out.finish_non_exhaustive()
    }
}
impl Drop for A128KWKey {
fn drop(&mut self) {
self.key.zeroize();
}
}
impl A128KWKey {
    /// Required length of the raw key material, in bytes (AES-128).
    const KEY_SIZE: usize = 16;
    /// Value of the JWE `alg` header for this key type.
    const ALG_NAME: &'static str = "A128KW";

    /// Creates a key from raw bytes.
    ///
    /// # Errors
    ///
    /// Returns `JWTError::InvalidEncryptionKey` unless `key` is exactly 16 bytes.
    pub fn from_bytes(key: &[u8]) -> Result<Self, Error> {
        ensure!(key.len() == Self::KEY_SIZE, JWTError::InvalidEncryptionKey);
        Ok(A128KWKey {
            key: key.to_vec(),
            key_id: None,
        })
    }

    /// Generates a new random 16-byte key, filled from `rand::thread_rng()`.
    pub fn generate() -> Self {
        let mut key = vec![0u8; Self::KEY_SIZE];
        rand::thread_rng().fill_bytes(&mut key);
        A128KWKey { key, key_id: None }
    }

    /// Returns a copy of the raw key bytes.
    ///
    /// The returned `Vec` is a plain copy and is NOT wiped on drop (unlike
    /// `self`), so the caller is responsible for handling it carefully.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.key.clone()
    }

    /// Sets the key identifier, used as the default `kid` header of tokens
    /// encrypted with this key.
    pub fn with_key_id(mut self, key_id: impl Into<String>) -> Self {
        self.key_id = Some(key_id.into());
        self
    }

    /// Returns the key identifier, if one was set.
    pub fn key_id(&self) -> Option<&str> {
        self.key_id.as_deref()
    }

    /// Wraps a content-encryption key (CEK) with AES Key Wrap (RFC 3394),
    /// using the default IV (`None`). The result is `cek.len() + 8` bytes.
    ///
    /// # Errors
    ///
    /// Returns `JWTError::InvalidEncryptionKey` if the CEK is not a whole
    /// number of 64-bit blocks (at least two), or if the AES backend fails.
    pub(crate) fn wrap_key(&self, cek: &[u8]) -> Result<Vec<u8>, Error> {
        // RFC 3394 section 2 requires n >= 2 plaintext blocks of 64 bits.
        // Enforce this here so every AES backend (boring or superboring)
        // only ever sees well-formed input.
        ensure!(
            cek.len() >= 16 && cek.len() % 8 == 0,
            JWTError::InvalidEncryptionKey
        );
        let aes_key = AesKey::new_encrypt(&self.key).map_err(|_| JWTError::InvalidEncryptionKey)?;
        let mut wrapped = vec![0u8; cek.len() + 8];
        wrap_key(&aes_key, None, &mut wrapped, cek).map_err(|_| JWTError::InvalidEncryptionKey)?;
        Ok(wrapped)
    }

    /// Unwraps a CEK previously wrapped with AES Key Wrap (RFC 3394).
    /// The result is `wrapped.len() - 8` bytes.
    ///
    /// # Errors
    ///
    /// Returns `JWTError::KeyUnwrapFailed` if `wrapped` is not a valid
    /// ciphertext length or fails integrity checking, and
    /// `JWTError::InvalidEncryptionKey` if the AES key cannot be set up.
    pub(crate) fn unwrap_key(&self, wrapped: &[u8]) -> Result<Vec<u8>, Error> {
        // A wrapped key is n + 1 >= 3 blocks of 64 bits (RFC 3394 section 2.2.2).
        // Anything shorter or not block-aligned can never be valid, so reject it
        // before handing it to the backend.
        ensure!(
            wrapped.len() >= 24 && wrapped.len() % 8 == 0,
            JWTError::KeyUnwrapFailed
        );
        let aes_key = AesKey::new_decrypt(&self.key).map_err(|_| JWTError::InvalidEncryptionKey)?;
        let mut cek = vec![0u8; wrapped.len() - 8];
        unwrap_key(&aes_key, None, &mut cek, wrapped).map_err(|_| JWTError::KeyUnwrapFailed)?;
        Ok(cek)
    }

    /// Encrypts `claims` using `EncryptionOptions::default()`.
    pub fn encrypt<CustomClaims: Serialize>(
        &self,
        claims: JWTClaims<CustomClaims>,
    ) -> Result<String, Error> {
        self.encrypt_with_options(claims, &EncryptionOptions::default())
    }

    /// Encrypts `claims` into a JWE token with `A128KW` key management.
    ///
    /// Header fields:
    /// - `enc` comes from `options.content_encryption`.
    /// - `kid` is `options.key_id` if set, otherwise this key's own identifier.
    /// - `cty` is `options.content_type` if set.
    ///
    /// Token assembly is delegated to `JWEToken::build_from_claims`. It receives
    /// a callback that wraps the CEK it is given with this key.
    pub fn encrypt_with_options<CustomClaims: Serialize>(
        &self,
        claims: JWTClaims<CustomClaims>,
        options: &EncryptionOptions,
    ) -> Result<String, Error> {
        let content_encryption = options.content_encryption;
        let mut header = JWEHeader::new(Self::ALG_NAME, content_encryption.alg_name());
        if let Some(key_id) = &self.key_id {
            header.key_id = Some(key_id.clone());
        }
        // An explicit per-call key id overrides the key's own identifier.
        if let Some(key_id) = &options.key_id {
            header.key_id = Some(key_id.clone());
        }
        if let Some(cty) = &options.content_type {
            header.content_type = Some(cty.clone());
        }
        JWEToken::build_from_claims(&header, &claims, content_encryption, |cek| {
            self.wrap_key(cek)
        })
    }

    /// Decrypts a JWE token and returns its claims.
    ///
    /// `"A128KW"` is passed to `JWEToken::decrypt` as the expected algorithm.
    /// The token's encrypted key is unwrapped with this key. `options` is
    /// forwarded unchanged.
    pub fn decrypt_token<CustomClaims: DeserializeOwned>(
        &self,
        token: &str,
        options: Option<DecryptionOptions>,
    ) -> Result<JWTClaims<CustomClaims>, Error> {
        JWEToken::decrypt(Self::ALG_NAME, token, options, |_header, encrypted_key| {
            self.unwrap_key(encrypted_key)
        })
    }

    /// Decodes a token's metadata without using any key material.
    ///
    /// Delegates to `JWEToken::decode_metadata`.
    /// NOTE(review): the result is presumably unauthenticated; confirm before
    /// trusting it (e.g. only use it to select a key).
    pub fn decode_metadata(token: &str) -> Result<JWETokenMetadata, Error> {
        JWEToken::decode_metadata(token)
    }
}