#![allow(deprecated)]
use aes_gcm::aead::{Aead, KeyInit, Payload};
use aes_gcm::{Aes256Gcm, Nonce};
use aes_kw::Kek;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use p256::PublicKey;
use rand::RngCore;
use rsa::{Oaep, RsaPublicKey};
use serde_json::Value;
use sha1::Sha1;
use crate::errors::WebexError;
/// Encrypts `plaintext` as a compact JWE using `RSA-OAEP` key transport and
/// `A256GCM` content encryption.
///
/// A fresh random 256-bit content-encryption key (CEK) and 96-bit IV are drawn
/// from the thread-local RNG for every call. The CEK is encrypted to the RSA
/// public key described by `rsa_jwk` using OAEP with SHA-1. The base64url
/// protected header is authenticated as AES-GCM additional data. If `rsa_jwk`
/// carries a string `kid`, it is copied into the protected header.
///
/// # Errors
/// Returns a KMS error if `rsa_jwk` is not a usable RSA public key, or if the
/// RSA or AES-GCM step fails.
pub fn encrypt_rsa_oaep_a256gcm(
    plaintext: &[u8],
    rsa_jwk: &Value,
) -> Result<String, WebexError> {
    let recipient = parse_rsa_public_key(rsa_jwk)?;

    let mut rng = rand::thread_rng();
    let mut content_key = [0u8; 32];
    rng.fill_bytes(&mut content_key);
    let mut nonce_bytes = [0u8; 12];
    rng.fill_bytes(&mut nonce_bytes);

    let protected = match rsa_jwk.get("kid").and_then(Value::as_str) {
        Some(kid) => serde_json::json!({"alg": "RSA-OAEP", "enc": "A256GCM", "kid": kid}),
        None => serde_json::json!({"alg": "RSA-OAEP", "enc": "A256GCM"}),
    };
    let protected_b64 = URL_SAFE_NO_PAD.encode(protected.to_string());

    let wrapped_cek = recipient
        .encrypt(&mut rng, Oaep::new::<Sha1>(), &content_key)
        .map_err(|e| WebexError::kms(format!("RSA-OAEP encryption failed: {e}")))?;

    let sealed = Aes256Gcm::new_from_slice(&content_key)
        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?
        .encrypt(
            Nonce::from_slice(&nonce_bytes),
            Payload {
                msg: plaintext,
                aad: protected_b64.as_bytes(),
            },
        )
        .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
    // The AEAD output is ciphertext followed by the 16-byte authentication tag.
    let (ciphertext, tag) = sealed.split_at(sealed.len() - 16);

    let segments = [
        protected_b64,
        URL_SAFE_NO_PAD.encode(wrapped_cek),
        URL_SAFE_NO_PAD.encode(nonce_bytes),
        URL_SAFE_NO_PAD.encode(ciphertext),
        URL_SAFE_NO_PAD.encode(tag),
    ];
    Ok(segments.join("."))
}
/// Encrypts `plaintext` as a compact JWE using direct encryption (`alg: dir`)
/// with `A256GCM`. `cek` itself is the content-encryption key.
///
/// Because no key is transported, the Encrypted Key segment is empty. A
/// non-empty `kid` is added to the protected header. A fresh random 96-bit IV
/// is generated per call.
///
/// # Errors
/// Returns a KMS error if AES-GCM initialisation or encryption fails.
pub fn encrypt_dir_a256gcm(
    plaintext: &[u8],
    cek: &[u8; 32],
    kid: &str,
) -> Result<String, WebexError> {
    let mut nonce_bytes = [0u8; 12];
    rand::thread_rng().fill_bytes(&mut nonce_bytes);

    let protected = if kid.is_empty() {
        serde_json::json!({"alg": "dir", "enc": "A256GCM"})
    } else {
        serde_json::json!({"alg": "dir", "enc": "A256GCM", "kid": kid})
    };
    let protected_b64 = URL_SAFE_NO_PAD.encode(protected.to_string());

    let sealed = Aes256Gcm::new_from_slice(cek)
        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?
        .encrypt(
            Nonce::from_slice(&nonce_bytes),
            Payload {
                msg: plaintext,
                aad: protected_b64.as_bytes(),
            },
        )
        .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
    // The AEAD output is ciphertext followed by the 16-byte authentication tag.
    let (ciphertext, tag) = sealed.split_at(sealed.len() - 16);

    let segments = [
        protected_b64,
        String::new(), // empty Encrypted Key segment for `dir`
        URL_SAFE_NO_PAD.encode(nonce_bytes),
        URL_SAFE_NO_PAD.encode(ciphertext),
        URL_SAFE_NO_PAD.encode(tag),
    ];
    Ok(segments.join("."))
}
/// Encrypts `plaintext` as a compact JWE with `A256KW` key wrapping and
/// `A256GCM` content encryption.
///
/// A random 256-bit CEK and a random 96-bit IV are generated per call. The CEK
/// is wrapped under `wrapping_key` with AES Key Wrap and placed in the Encrypted
/// Key segment. The protected header carries only `alg` and `enc` (no `kid`).
///
/// # Errors
/// Returns a KMS error if key wrapping or AES-GCM encryption fails.
pub fn encrypt_a256kw_a256gcm(
    plaintext: &[u8],
    wrapping_key: &[u8; 32],
) -> Result<String, WebexError> {
    let mut rng = rand::thread_rng();
    let mut content_key = [0u8; 32];
    rng.fill_bytes(&mut content_key);
    let mut nonce_bytes = [0u8; 12];
    rng.fill_bytes(&mut nonce_bytes);

    let protected_b64 = URL_SAFE_NO_PAD
        .encode(serde_json::json!({"alg": "A256KW", "enc": "A256GCM"}).to_string());

    // AES Key Wrap output is the input plus one 8-byte integrity block.
    let mut wrapped_cek = [0u8; 32 + 8];
    let kek = Kek::from(*wrapping_key);
    kek.wrap(&content_key, &mut wrapped_cek)
        .map_err(|e| WebexError::kms(format!("AES key wrap failed: {e}")))?;

    let sealed = Aes256Gcm::new_from_slice(&content_key)
        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?
        .encrypt(
            Nonce::from_slice(&nonce_bytes),
            Payload {
                msg: plaintext,
                aad: protected_b64.as_bytes(),
            },
        )
        .map_err(|e| WebexError::kms(format!("AES-GCM encryption failed: {e}")))?;
    // The AEAD output is ciphertext followed by the 16-byte authentication tag.
    let (ciphertext, tag) = sealed.split_at(sealed.len() - 16);

    let segments = [
        protected_b64,
        URL_SAFE_NO_PAD.encode(wrapped_cek),
        URL_SAFE_NO_PAD.encode(nonce_bytes),
        URL_SAFE_NO_PAD.encode(ciphertext),
        URL_SAFE_NO_PAD.encode(tag),
    ];
    Ok(segments.join("."))
}
/// Decrypts an `A256KW` + `A256GCM` compact JWE with the given key-wrapping key.
///
/// The Encrypted Key segment is unwrapped with AES Key Wrap under
/// `wrapping_key`. The resulting CEK then decrypts the content, with the
/// base64url protected header as AAD. The header's `alg`/`enc` values are not
/// inspected.
///
/// # Errors
/// Returns a KMS error in these cases:
/// - the token is malformed;
/// - the encrypted key is too short to be a Key Wrap output;
/// - unwrapping fails (wrong key or tampered key);
/// - AES-GCM decryption fails.
pub fn decrypt_a256kw_a256gcm(
    token: &str,
    wrapping_key: &[u8; 32],
) -> Result<Vec<u8>, WebexError> {
    let parts = parse_jwe_compact(token)?;
    // Key Wrap output = wrapped key + 8-byte integrity block. The length comes
    // from the untrusted token, so a plain `len() - 8` could underflow and
    // panic; reject short values with an error instead.
    let cek_len = parts.encrypted_key.len().checked_sub(8).ok_or_else(|| {
        WebexError::kms("AES key unwrap failed: encrypted key is shorter than 8 bytes")
    })?;
    let kek = Kek::from(*wrapping_key);
    let mut cek = vec![0u8; cek_len];
    kek.unwrap(&parts.encrypted_key, &mut cek)
        .map_err(|e| WebexError::kms(format!("AES key unwrap failed: {e}")))?;
    decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
}
pub fn decrypt_dir_a256gcm(
token: &str,
cek: &[u8; 32],
) -> Result<Vec<u8>, WebexError> {
let parts = parse_jwe_compact(token)?;
decrypt_a256gcm(cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
}
/// Decrypts a compact JWE addressed to a 256-bit symmetric key.
///
/// The protected header's `alg` selects how `key` is used:
/// * `dir`: `key` is the AES-256-GCM content-encryption key itself.
/// * `A256KW`: `key` is an AES Key Wrap key that unwraps the Encrypted Key
///   segment to obtain the content-encryption key.
///
/// Content is decrypted as AES-256-GCM with the base64url protected header as AAD.
///
/// # Errors
/// Returns a KMS error in these cases:
/// - the token or header is malformed;
/// - `alg` is missing or unsupported;
/// - the encrypted key is too short;
/// - unwrapping fails;
/// - AES-GCM decryption fails.
pub fn decrypt_message_jwe(
    token: &str,
    key: &[u8; 32],
) -> Result<Vec<u8>, WebexError> {
    let parts = parse_jwe_compact(token)?;
    let header_json: Value = serde_json::from_slice(&parts.header_bytes)
        .map_err(|e| WebexError::kms(format!("Failed to parse JWE header: {e}")))?;
    // A missing or non-string `alg` falls through to the "unsupported" error.
    let alg = header_json
        .get("alg")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    match alg {
        "dir" => {
            decrypt_a256gcm(key, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
        }
        "A256KW" => {
            // Key Wrap output = wrapped key + 8-byte integrity block. The length
            // is attacker-controlled, so guard against `len() - 8` underflowing.
            let cek_len = parts.encrypted_key.len().checked_sub(8).ok_or_else(|| {
                WebexError::kms("AES key unwrap failed: encrypted key is shorter than 8 bytes")
            })?;
            let kek = Kek::from(*key);
            let mut cek = vec![0u8; cek_len];
            kek.unwrap(&parts.encrypted_key, &mut cek)
                .map_err(|e| WebexError::kms(format!("AES key unwrap failed: {e}")))?;
            decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
        }
        _ => Err(WebexError::kms(format!(
            "Unsupported message JWE algorithm: {alg}"
        ))),
    }
}
/// Decrypts a compact JWE whose key is agreed via ECDH-ES with the sender's
/// ephemeral P-256 key (`epk` header) and `local_private_key`.
///
/// Supported `alg` values:
/// * `ECDH-ES`: direct key agreement. The Concat KDF output, sized for `enc`,
///   is the content-encryption key.
/// * `ECDH-ES+A256KW`: the Concat KDF output is a 256-bit AES Key Wrap key that
///   unwraps the Encrypted Key segment.
///
/// Content is decrypted as AES-256-GCM with the base64url protected header as
/// AAD. A missing `enc` is treated as `A256GCM`.
///
/// # Errors
/// Returns a KMS error in these cases:
/// - the token or header is malformed;
/// - `epk` is missing or invalid;
/// - `apu`/`apv` are present but not valid base64url;
/// - `alg` is unsupported;
/// - unwrapping fails;
/// - AES-GCM decryption fails.
pub fn decrypt_ecdh_es(
    token: &str,
    local_private_key: &p256::SecretKey,
) -> Result<Vec<u8>, WebexError> {
    let parts = parse_jwe_compact(token)?;
    let header_json: Value = serde_json::from_slice(&parts.header_bytes)
        .map_err(|e| WebexError::kms(format!("Failed to parse JWE header: {e}")))?;
    let alg = header_json
        .get("alg")
        .and_then(|v| v.as_str())
        .unwrap_or("");
    let enc = header_json
        .get("enc")
        .and_then(|v| v.as_str())
        .unwrap_or("A256GCM");
    let epk = header_json
        .get("epk")
        .ok_or_else(|| WebexError::kms("No epk in ECDH-ES JWE header"))?;
    let server_public = parse_ec_public_key(epk)?;
    let shared_secret = p256::ecdh::diffie_hellman(
        local_private_key.to_nonzero_scalar(),
        server_public.as_affine(),
    );
    // PartyUInfo / PartyVInfo are optional. When present they must decode: a
    // silently emptied value would derive a different key than the sender and
    // surface later as an opaque AES-GCM failure.
    let party_info = |name: &str| -> Result<Vec<u8>, WebexError> {
        match header_json.get(name).and_then(|v| v.as_str()) {
            Some(encoded) => URL_SAFE_NO_PAD
                .decode(encoded)
                .map_err(|e| WebexError::kms(format!("Failed to decode JWE {name}: {e}"))),
            None => Ok(Vec::new()),
        }
    };
    let apu = party_info("apu")?;
    let apv = party_info("apv")?;
    match alg {
        "ECDH-ES" => {
            // Direct agreement: AlgorithmID is the `enc` value and keydatalen is
            // the content key size (RFC 7518 §4.6.2).
            let key_len = enc_key_length(enc);
            let cek = concat_kdf(
                shared_secret.raw_secret_bytes(),
                enc,
                &apu,
                &apv,
                (key_len * 8) as u32,
            )?;
            decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
        }
        "ECDH-ES+A256KW" => {
            // Key-wrapping mode: RFC 7518 §4.6.2 sets AlgorithmID to the `alg`
            // value itself ("ECDH-ES+A256KW"), and keydatalen to the KEK size.
            let kek_bytes = concat_kdf(
                shared_secret.raw_secret_bytes(),
                "ECDH-ES+A256KW",
                &apu,
                &apv,
                256,
            )?;
            let kek_arr: [u8; 32] = kek_bytes
                .try_into()
                .map_err(|_| WebexError::kms("Derived KEK is not 32 bytes"))?;
            // Key Wrap output = wrapped key + 8-byte integrity block. Guard the
            // attacker-controlled length against underflow.
            let cek_len = parts.encrypted_key.len().checked_sub(8).ok_or_else(|| {
                WebexError::kms("ECDH-ES+A256KW unwrap failed: encrypted key is shorter than 8 bytes")
            })?;
            let kek = Kek::from(kek_arr);
            let mut cek = vec![0u8; cek_len];
            kek.unwrap(&parts.encrypted_key, &mut cek)
                .map_err(|e| WebexError::kms(format!("ECDH-ES+A256KW unwrap failed: {e}")))?;
            decrypt_a256gcm(&cek, &parts.iv, &parts.ciphertext, &parts.tag, &parts.header_b64)
        }
        _ => Err(WebexError::kms(format!("Unsupported ECDH algorithm: {alg}"))),
    }
}
/// Decrypts a compact JWE with whichever key material `key` holds.
///
/// Symmetric keys are handled by [`decrypt_message_jwe`] (`dir` / `A256KW`).
/// P-256 private keys are handled by [`decrypt_ecdh_es`].
pub fn decrypt_jwe(token: &str, key: &JweKey) -> Result<Vec<u8>, WebexError> {
    match *key {
        JweKey::EcdhPrivate(ref secret) => decrypt_ecdh_es(token, secret),
        JweKey::Symmetric(ref bytes) => decrypt_message_jwe(token, bytes),
    }
}
/// Extracts the payload from a KMS response token, dispatching on its shape.
///
/// * Four dots (five segments): treated as a compact JWE and decrypted with
///   `key` via [`decrypt_jwe`].
/// * Two dots (three segments): treated as a compact JWS. The middle (payload)
///   segment is base64url-decoded and returned.
///
/// NOTE(review): the JWS branch never looks at the signature segment and does
/// not use `key`, so that payload is returned unauthenticated. Confirm callers
/// do not rely on it being verified.
///
/// # Errors
/// Returns a KMS error for any other segment count, an undecodable JWS
/// payload, or any JWE decryption failure.
pub fn unwrap_kms_response(token: &str, key: &JweKey) -> Result<Vec<u8>, WebexError> {
    let dots = token.matches('.').count();
    if dots == 4 {
        return decrypt_jwe(token, key);
    }
    if dots != 2 {
        return Err(WebexError::kms(format!(
            "Invalid KMS response format: expected 3 or 5 parts, got {} dots",
            dots
        )));
    }
    // Exactly two dots guarantees a middle segment exists.
    let payload_b64 = token.split('.').nth(1).unwrap_or_default();
    URL_SAFE_NO_PAD
        .decode(payload_b64)
        .map_err(|e| WebexError::kms(format!("Failed to decode JWS payload: {e}")))
}
/// Key material accepted by [`decrypt_jwe`] and [`unwrap_kms_response`].
pub enum JweKey {
    /// A 256-bit symmetric key, used for `dir` or `A256KW` JWEs.
    Symmetric([u8; 32]),
    /// A P-256 private key, used for `ECDH-ES` / `ECDH-ES+A256KW` JWEs.
    EcdhPrivate(p256::SecretKey),
}
/// The five segments of a compact JWE, as produced by `parse_jwe_compact`.
struct JweParts {
    /// Protected header exactly as it appeared in the token (base64url text).
    /// Its ASCII bytes are the AES-GCM additional authenticated data.
    header_b64: String,
    /// Decoded protected header (JSON bytes).
    header_bytes: Vec<u8>,

    /// Decoded JWE Encrypted Key segment (empty for `dir`).
    encrypted_key: Vec<u8>,

    /// Decoded initialization vector.
    iv: Vec<u8>,
    /// Decoded ciphertext, without the tag.
    ciphertext: Vec<u8>,
    /// Decoded authentication tag.
    tag: Vec<u8>,
}
/// Splits a compact JWE into its five dot-separated segments and base64url
/// (no padding) decodes each one.
///
/// Segments are decoded in token order, so the first undecodable segment
/// determines the error.
///
/// # Errors
/// Returns a KMS error if the token does not have exactly five segments or any
/// segment is not valid base64url.
fn parse_jwe_compact(token: &str) -> Result<JweParts, WebexError> {
    let segments: Vec<&str> = token.split('.').collect();
    let &[header_seg, key_seg, iv_seg, ct_seg, tag_seg] = segments.as_slice() else {
        return Err(WebexError::kms(format!(
            "Invalid JWE compact: expected 5 parts, got {}",
            segments.len()
        )));
    };

    let decode = |segment: &str, what: &str| {
        URL_SAFE_NO_PAD
            .decode(segment)
            .map_err(|e| WebexError::kms(format!("Failed to decode {what}: {e}")))
    };

    // Struct-literal fields are evaluated in the order written below.
    Ok(JweParts {
        header_b64: header_seg.to_string(),
        header_bytes: decode(header_seg, "JWE header")?,
        encrypted_key: decode(key_seg, "encrypted key")?,
        iv: decode(iv_seg, "IV")?,
        ciphertext: decode(ct_seg, "ciphertext")?,
        tag: decode(tag_seg, "tag")?,
    })
}
/// Decrypts and authenticates AES-256-GCM content taken from a JWE.
///
/// `aad` is the base64url protected header; its ASCII bytes are authenticated
/// as additional data. `iv` must be 12 bytes and `tag` 16 bytes, the sizes
/// RFC 7518 §5.3 fixes for `A256GCM`.
///
/// # Errors
/// Returns a KMS error in these cases:
/// - `cek` is not a valid AES-256 key;
/// - `iv` or `tag` has the wrong length;
/// - authentication fails.
fn decrypt_a256gcm(
    cek: &[u8],
    iv: &[u8],
    ciphertext: &[u8],
    tag: &[u8],
    aad: &str,
) -> Result<Vec<u8>, WebexError> {
    const IV_LEN: usize = 12;
    const TAG_LEN: usize = 16;

    let cipher = Aes256Gcm::new_from_slice(cek)
        .map_err(|e| WebexError::kms(format!("AES-GCM key init failed: {e}")))?;
    // `Nonce::from_slice` panics on a length mismatch, and `iv` comes from an
    // untrusted token, so validate it before building the nonce.
    if iv.len() != IV_LEN {
        return Err(WebexError::kms(format!(
            "AES-GCM decryption failed: invalid IV length {} (expected {IV_LEN})",
            iv.len()
        )));
    }
    // The cipher treats the trailing 16 bytes as the tag. Enforcing the tag
    // length keeps the ciphertext/tag segment boundary meaningful.
    if tag.len() != TAG_LEN {
        return Err(WebexError::kms(format!(
            "AES-GCM decryption failed: invalid tag length {} (expected {TAG_LEN})",
            tag.len()
        )));
    }
    let ct_with_tag = [ciphertext, tag].concat();
    cipher
        .decrypt(
            Nonce::from_slice(iv),
            Payload {
                msg: &ct_with_tag,
                aad: aad.as_bytes(),
            },
        )
        .map_err(|e| WebexError::kms(format!("AES-GCM decryption failed: {e}")))
}
/// Single-step Concat KDF with SHA-256, laid out as RFC 7518 §4.6.2 profiles
/// it for ECDH-ES.
///
/// Each round hashes three inputs in order:
/// - the 32-bit big-endian round counter (starting at 1);
/// - `shared_secret`;
/// - OtherInfo.
///
/// OtherInfo is `AlgorithmID || PartyUInfo || PartyVInfo || SuppPubInfo`. The
/// first three fields are each prefixed with their 32-bit big-endian length.
/// SuppPubInfo is `key_data_len_bits` as 32-bit big-endian.
///
/// The output is truncated to `key_data_len_bits / 8` bytes, so any remainder
/// bits are dropped. This function never returns `Err`.
fn concat_kdf(
    shared_secret: &[u8],
    algorithm_id: &str,
    apu: &[u8],
    apv: &[u8],
    key_data_len_bits: u32,
) -> Result<Vec<u8>, WebexError> {
    use sha2::{Digest, Sha256};

    let out_len = (key_data_len_bits / 8) as usize;
    let rounds = out_len.div_ceil(32) as u32;

    // OtherInfo is identical for every round, so assemble it once.
    let mut other_info = Vec::new();
    for field in [algorithm_id.as_bytes(), apu, apv] {
        other_info.extend_from_slice(&(field.len() as u32).to_be_bytes());
        other_info.extend_from_slice(field);
    }
    other_info.extend_from_slice(&key_data_len_bits.to_be_bytes());

    let mut okm = Vec::with_capacity(rounds as usize * 32);
    for round in 1..=rounds {
        let mut hasher = Sha256::new();
        hasher.update(round.to_be_bytes());
        hasher.update(shared_secret);
        hasher.update(&other_info);
        okm.extend_from_slice(&hasher.finalize());
    }
    okm.truncate(out_len);
    Ok(okm)
}
/// Content-encryption key length in bytes for a JWE `enc` algorithm name.
///
/// Unrecognised names fall back to 32 bytes (the `A256GCM` size).
fn enc_key_length(enc: &str) -> usize {
    const KEY_BYTES: [(&str, usize); 5] = [
        ("A128GCM", 16),
        ("A192GCM", 24),
        ("A256GCM", 32),
        ("A128CBC-HS256", 32),
        ("A256CBC-HS512", 64),
    ];
    KEY_BYTES
        .iter()
        .find(|(name, _)| *name == enc)
        .map_or(32, |&(_, len)| len)
}
fn parse_rsa_public_key(jwk: &Value) -> Result<RsaPublicKey, WebexError> {
let n = jwk
.get("n")
.and_then(|v| v.as_str())
.ok_or_else(|| WebexError::kms("Missing 'n' in RSA JWK"))?;
let e = jwk
.get("e")
.and_then(|v| v.as_str())
.ok_or_else(|| WebexError::kms("Missing 'e' in RSA JWK"))?;
let n_bytes = URL_SAFE_NO_PAD
.decode(n)
.map_err(|e| WebexError::kms(format!("Failed to decode RSA n: {e}")))?;
let e_bytes = URL_SAFE_NO_PAD
.decode(e)
.map_err(|e| WebexError::kms(format!("Failed to decode RSA e: {e}")))?;
let n_uint = rsa::BigUint::from_bytes_be(&n_bytes);
let e_uint = rsa::BigUint::from_bytes_be(&e_bytes);
RsaPublicKey::new(n_uint, e_uint)
.map_err(|e| WebexError::kms(format!("Invalid RSA public key: {e}")))
}
fn parse_ec_public_key(jwk: &Value) -> Result<PublicKey, WebexError> {
let x = jwk
.get("x")
.and_then(|v| v.as_str())
.ok_or_else(|| WebexError::kms("Missing 'x' in EC JWK"))?;
let y = jwk
.get("y")
.and_then(|v| v.as_str())
.ok_or_else(|| WebexError::kms("Missing 'y' in EC JWK"))?;
let x_bytes = URL_SAFE_NO_PAD
.decode(x)
.map_err(|e| WebexError::kms(format!("Failed to decode EC x: {e}")))?;
let y_bytes = URL_SAFE_NO_PAD
.decode(y)
.map_err(|e| WebexError::kms(format!("Failed to decode EC y: {e}")))?;
let mut uncompressed = vec![0x04];
uncompressed.extend_from_slice(&x_bytes);
uncompressed.extend_from_slice(&y_bytes);
PublicKey::from_sec1_bytes(&uncompressed)
.map_err(|e| WebexError::kms(format!("Invalid EC public key: {e}")))
}