use sha2::{Digest, Sha256};
use tracing::{debug, error, info, warn};
use zeroize::Zeroize;
use crate::config::TeeKmsConfig;
use crate::error::{AppError, tee_attestation_error};
/// Secrets returned by `bootstrap_secrets`.
///
/// No `Debug` impl is derived for this type, and the `Drop` impl in this file zeroizes
/// every secret field when a value goes out of scope.
pub struct BootstrappedSecrets {
/// Master seed bytes. On first boot this is the first 32 bytes of the BIP-39 seed computed
/// from a freshly generated mnemonic; on later boots it is decrypted from the store.
pub seed: Vec<u8>,
/// 32-byte JWT signing key (random on first boot, decrypted from the store afterwards).
pub jwt_signing_key: [u8; 32],
/// Key derived from `seed` and the caller-supplied salt by `derive_storage_key`.
pub storage_key: [u8; 32],
/// Raw entropy the mnemonic was built from: `Some` only when the secrets were generated
/// during this call, `None` when they were decrypted from existing ciphertexts.
pub entropy: Option<[u8; 32]>,
/// `true` when new secrets were generated during this call instead of being decrypted.
pub is_first_boot: bool,
}
impl Drop for BootstrappedSecrets {
fn drop(&mut self) {
self.seed.zeroize();
self.jwt_signing_key.zeroize();
self.storage_key.zeroize();
if let Some(ref mut e) = self.entropy {
e.zeroize();
}
}
}
// Entry names used inside the "bootstrap" keyspace of the store.

/// KMS-encrypted data key (the ciphertext blob returned by KMS `GenerateDataKey`).
const BOOTSTRAP_DK_CT_KEY: &str = "bootstrap:data_key_ciphertext";
/// Master seed encrypted under the data key by `aes_gcm_encrypt` (nonce-prefixed AES-256-GCM).
const BOOTSTRAP_SEED_CT_KEY: &str = "bootstrap:seed_ciphertext";
/// JWT signing key encrypted under the data key by `aes_gcm_encrypt`.
const BOOTSTRAP_JWT_CT_KEY: &str = "bootstrap:jwt_ciphertext";
/// Hex fingerprint of the JWT signing key, written by `store_jwt_fingerprint`.
const BOOTSTRAP_JWT_FINGERPRINT_KEY: &str = "bootstrap:jwt_fingerprint";
pub async fn bootstrap_secrets(
kms_config: &TeeKmsConfig,
storage_key_salt: &str,
store: &crate::store::Store,
) -> Result<BootstrappedSecrets, AppError> {
let bs_ks = store.keyspace("bootstrap")?;
let dk_ct = bs_ks.get_raw(BOOTSTRAP_DK_CT_KEY).await?;
let seed_ct = bs_ks.get_raw(BOOTSTRAP_SEED_CT_KEY).await?;
let jwt_ct = bs_ks.get_raw(BOOTSTRAP_JWT_CT_KEY).await?;
if let (Some(dk_ciphertext), Some(seed_ciphertext), Some(jwt_ciphertext)) =
(dk_ct, seed_ct, jwt_ct)
{
info!("found existing secret ciphertexts in store — decrypting via KMS");
match kms_decrypt_data_key(kms_config, &dk_ciphertext).await {
Ok(data_key) => {
let seed = aes_gcm_decrypt(&data_key, &seed_ciphertext)?;
let jwt_bytes = aes_gcm_decrypt(&data_key, &jwt_ciphertext)?;
let jwt_key: [u8; 32] = jwt_bytes
.try_into()
.map_err(|_| tee_attestation_error("JWT key must be exactly 32 bytes"))?;
verify_jwt_fingerprint(&bs_ks, &jwt_key).await?;
info!("secrets decrypted from KMS — subsequent boot");
let storage_key = derive_storage_key(&seed, storage_key_salt);
return Ok(BootstrappedSecrets {
seed,
jwt_signing_key: jwt_key,
storage_key,
entropy: None,
is_first_boot: false,
});
}
Err(e) => {
warn!(
error = %e,
"KMS decrypt of existing ciphertexts failed — clearing stale \
bootstrap data and starting fresh. This is expected after an \
image rebuild with a new PCR0. The VTA will generate a new \
identity."
);
bs_ks.remove(BOOTSTRAP_DK_CT_KEY).await?;
bs_ks.remove(BOOTSTRAP_SEED_CT_KEY).await?;
bs_ks.remove(BOOTSTRAP_JWT_CT_KEY).await?;
bs_ks.remove(BOOTSTRAP_JWT_FINGERPRINT_KEY).await?;
store.persist().await?;
}
}
}
info!("first boot — generating new secrets in TEE");
let mut entropy = [0u8; 32];
rand::fill(&mut entropy);
let mnemonic = bip39::Mnemonic::from_entropy(&entropy)
.map_err(|e| tee_attestation_error(format!("failed to generate mnemonic: {e}")))?;
info!("master seed generated inside TEE (mnemonic NOT displayed)");
info!("to export the mnemonic, restart with VTA_MNEMONIC_EXPORT_WINDOW=<seconds>");
let full_seed = mnemonic.to_seed("").to_vec();
let seed = full_seed[..32].to_vec();
let mut jwt_key_bytes = [0u8; 32];
rand::fill(&mut jwt_key_bytes);
let jwt_key = jwt_key_bytes;
let (dk_ciphertext, data_key) = kms_generate_data_key(kms_config).await?;
let seed_ciphertext = aes_gcm_encrypt(&data_key, &seed)?;
let jwt_ciphertext = aes_gcm_encrypt(&data_key, &jwt_key)?;
bs_ks.insert_raw(BOOTSTRAP_DK_CT_KEY, dk_ciphertext).await?;
bs_ks
.insert_raw(BOOTSTRAP_SEED_CT_KEY, seed_ciphertext)
.await?;
bs_ks
.insert_raw(BOOTSTRAP_JWT_CT_KEY, jwt_ciphertext)
.await?;
store_jwt_fingerprint(&bs_ks, &jwt_key).await?;
store.persist().await?;
info!("secrets generated and encrypted to KMS — ciphertexts stored");
Ok(BootstrappedSecrets {
storage_key: derive_storage_key(&seed, storage_key_salt),
seed,
jwt_signing_key: jwt_key,
entropy: Some(entropy),
is_first_boot: true,
})
}
pub async fn re_encrypt_bootstrap_secrets(
kms_config: &TeeKmsConfig,
store: &crate::store::Store,
seed: &[u8],
jwt_key: &[u8; 32],
) -> Result<(), AppError> {
let bs_ks = store.keyspace("bootstrap")?;
let _ = bs_ks.remove(BOOTSTRAP_DK_CT_KEY).await;
let _ = bs_ks.remove(BOOTSTRAP_SEED_CT_KEY).await;
let _ = bs_ks.remove(BOOTSTRAP_JWT_CT_KEY).await;
let _ = bs_ks.remove(BOOTSTRAP_JWT_FINGERPRINT_KEY).await;
let (dk_ciphertext, data_key) = kms_generate_data_key(kms_config).await?;
let seed_ciphertext = aes_gcm_encrypt(&data_key, seed)?;
let jwt_ciphertext = aes_gcm_encrypt(&data_key, jwt_key)?;
bs_ks.insert_raw(BOOTSTRAP_DK_CT_KEY, dk_ciphertext).await?;
bs_ks
.insert_raw(BOOTSTRAP_SEED_CT_KEY, seed_ciphertext)
.await?;
bs_ks
.insert_raw(BOOTSTRAP_JWT_CT_KEY, jwt_ciphertext)
.await?;
store_jwt_fingerprint(&bs_ks, jwt_key).await?;
store.persist().await?;
info!("imported secrets re-encrypted to KMS — stored in bootstrap keyspace");
Ok(())
}
/// Computes the fingerprint of a JWT key: the first 16 bytes of its SHA-256 digest,
/// rendered as 32 lowercase hex characters.
fn jwt_fingerprint(key: &[u8; 32]) -> String {
    let digest = Sha256::digest(key);
    let truncated = &digest[..16];
    hex::encode(truncated)
}
/// Writes the fingerprint of `key` into the bootstrap keyspace as UTF-8 hex bytes.
///
/// # Errors
/// Propagates the store's insert failure.
async fn store_jwt_fingerprint(
    bs_ks: &crate::store::KeyspaceHandle,
    key: &[u8; 32],
) -> Result<(), AppError> {
    let fp_hex = jwt_fingerprint(key);
    let fp_bytes = fp_hex.as_bytes().to_vec();
    bs_ks
        .insert_raw(BOOTSTRAP_JWT_FINGERPRINT_KEY, fp_bytes)
        .await?;
    debug!(fingerprint = %fp_hex, "JWT key fingerprint stored");
    Ok(())
}
/// Checks that `key` matches the fingerprint stored in the bootstrap keyspace.
///
/// When no fingerprint has been stored yet, one is written for `key` and the check
/// succeeds. The stored value is decoded lossily as UTF-8 and trimmed before comparison.
///
/// # Errors
/// Returns an attestation error when the fingerprints differ, and propagates store
/// failures.
async fn verify_jwt_fingerprint(
    bs_ks: &crate::store::KeyspaceHandle,
    key: &[u8; 32],
) -> Result<(), AppError> {
    let Some(stored_bytes) = bs_ks.get_raw(BOOTSTRAP_JWT_FINGERPRINT_KEY).await? else {
        warn!("no JWT fingerprint found — storing one now (first boot after upgrade)");
        return store_jwt_fingerprint(bs_ks, key).await;
    };

    let stored_text = String::from_utf8_lossy(&stored_bytes);
    let stored = stored_text.trim();
    let computed = jwt_fingerprint(key);

    if stored == computed {
        debug!(fingerprint = %computed, "JWT key fingerprint verified");
        return Ok(());
    }

    error!(
        stored = %stored,
        computed = %computed,
        "JWT key fingerprint MISMATCH — possible key tampering or KMS key rotation"
    );
    Err(tee_attestation_error(
        "JWT key fingerprint mismatch — the decrypted JWT key does not match the key \
         used on first boot. This could indicate tampering with the ciphertext \
         or a KMS key change. If this is intentional (e.g., disaster recovery), \
         clear the bootstrap keyspace and restart.",
    ))
}
/// Derives the 32-byte storage key from `seed` and `salt`.
///
/// This is an HKDF-SHA256 style extract-then-expand producing a single output block:
/// `PRK = HMAC(salt, seed)`, then `OKM = HMAC(PRK, "aes-256-gcm-storage" || 0x01)`.
/// The same `seed` and `salt` always yield the same key.
pub(crate) fn derive_storage_key(seed: &[u8], salt: &str) -> [u8; 32] {
    use hmac::{Hmac, Mac};
    type HmacSha256 = Hmac<Sha256>;

    // Extract: the salt keys the MAC, the seed is the input keying material.
    let mut extract =
        HmacSha256::new_from_slice(salt.as_bytes()).expect("HMAC accepts any key length");
    extract.update(seed);
    let pseudo_random_key = extract.finalize().into_bytes();

    // Expand: one block — the info label followed by the block counter 0x01.
    let mut expand =
        HmacSha256::new_from_slice(&pseudo_random_key).expect("HMAC accepts any key length");
    expand.update(b"aes-256-gcm-storage");
    expand.update(&[0x01]);
    let output = expand.finalize().into_bytes();

    let mut storage_key = [0u8; 32];
    storage_key.copy_from_slice(&output);
    storage_key
}
/// Builds a KMS client for the region named in `config`, loading the remaining AWS
/// settings from the SDK defaults.
async fn kms_client(config: &TeeKmsConfig) -> aws_sdk_kms::Client {
    let region = aws_config::Region::new(config.region.clone());
    let loader = aws_config::defaults(aws_config::BehaviorVersion::latest()).region(region);
    let sdk_config = loader.load().await;
    aws_sdk_kms::Client::new(&sdk_config)
}
/// Creates a fresh RSA-2048 key pair and a KMS `RecipientInfo` carrying an NSM attestation
/// document requested for the DER-encoded public key.
///
/// The private key is returned so the caller can unwrap the `CiphertextForRecipient`
/// that KMS sends back.
///
/// # Errors
/// Returns an attestation error when key generation or DER encoding fails, and propagates
/// the failure of the NSM attestation request.
fn nsm_attested_recipient()
-> Result<(rsa::RsaPrivateKey, aws_sdk_kms::types::RecipientInfo), AppError> {
    use aws_sdk_kms::types::{KeyEncryptionMechanism, RecipientInfo};
    use rsa::pkcs8::EncodePublicKey;

    let mut rng = rsa::rand_core::OsRng;
    let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048)
        .map_err(|e| tee_attestation_error(format!("RSA key generation failed: {e}")))?;

    let public_key = private_key.to_public_key();
    let public_key_der = public_key
        .to_public_key_der()
        .map_err(|e| tee_attestation_error(format!("RSA public key DER encoding failed: {e}")))?;

    let attestation_doc = super::nitro::request_nsm_attestation_for_kms(public_key_der.as_ref())?;
    let attestation_blob = aws_sdk_kms::primitives::Blob::new(attestation_doc);

    let recipient = RecipientInfo::builder()
        .attestation_document(attestation_blob)
        .key_encryption_algorithm(KeyEncryptionMechanism::RsaesOaepSha256)
        .build();
    Ok((private_key, recipient))
}
/// Decrypts the `CiphertextForRecipient` blob of a KMS response with `private_key`.
///
/// # Errors
/// Returns an attestation error when the response carried no such blob, and propagates
/// any failure from `decrypt_cms_envelope`.
fn unwrap_cms_response(
    cms_blob: Option<&aws_sdk_kms::primitives::Blob>,
    private_key: &rsa::RsaPrivateKey,
) -> Result<Vec<u8>, AppError> {
    match cms_blob {
        Some(blob) => decrypt_cms_envelope(blob.as_ref(), private_key),
        None => Err(tee_attestation_error(
            "KMS response missing CiphertextForRecipient — \
             the KMS key may not support attestation-based operations",
        )),
    }
}
/// Decrypts a KMS-encrypted data key.
///
/// When `/dev/nsm` exists the attested `Decrypt` is tried first; if it fails (or the
/// device is absent) a plain `Decrypt` is issued instead.
///
/// # Errors
/// Returns the error of the direct `Decrypt` call when that call fails.
async fn kms_decrypt_data_key(
    config: &TeeKmsConfig,
    ciphertext: &[u8],
) -> Result<Vec<u8>, AppError> {
    let nsm_present = std::path::Path::new("/dev/nsm").exists();
    if !nsm_present {
        return kms_decrypt_direct(config, ciphertext).await;
    }

    match kms_decrypt_attested(config, ciphertext).await {
        Ok(plaintext) => {
            info!("KMS Decrypt succeeded with Nitro attestation");
            Ok(plaintext)
        }
        Err(e) => {
            warn!(
                error = %e,
                "attestation-based KMS Decrypt failed — falling back to direct Decrypt"
            );
            kms_decrypt_direct(config, ciphertext).await
        }
    }
}
/// Issues a KMS `Decrypt` with an attested recipient and unwraps the returned CMS
/// envelope with the matching ephemeral private key.
///
/// # Errors
/// Propagates attestation, KMS (classified as `Decrypt(attested)`) and CMS failures.
async fn kms_decrypt_attested(
    config: &TeeKmsConfig,
    ciphertext: &[u8],
) -> Result<Vec<u8>, AppError> {
    let (private_key, recipient) = nsm_attested_recipient()?;
    let client = kms_client(config).await;

    let request = client
        .decrypt()
        .ciphertext_blob(aws_sdk_kms::primitives::Blob::new(ciphertext))
        .key_id(&config.key_arn)
        .recipient(recipient);
    let response = request
        .send()
        .await
        .map_err(|e| classify_kms_error("Decrypt(attested)", e))?;

    unwrap_cms_response(response.ciphertext_for_recipient(), &private_key)
}
/// Issues a plain (non-attested) KMS `Decrypt` and returns the plaintext bytes.
///
/// # Errors
/// Propagates the classified KMS failure, or an attestation error when the response
/// carries no plaintext.
async fn kms_decrypt_direct(config: &TeeKmsConfig, ciphertext: &[u8]) -> Result<Vec<u8>, AppError> {
    let client = kms_client(config).await;

    let request = client
        .decrypt()
        .ciphertext_blob(aws_sdk_kms::primitives::Blob::new(ciphertext))
        .key_id(&config.key_arn);
    let response = request
        .send()
        .await
        .map_err(|e| classify_kms_error("Decrypt", e))?;

    match response.plaintext() {
        Some(blob) => Ok(blob.as_ref().to_vec()),
        None => Err(tee_attestation_error("KMS Decrypt returned no plaintext")),
    }
}
/// Obtains a new AES-256 data key from KMS, returned as
/// `(KMS ciphertext blob, 32-byte plaintext key)`.
///
/// When `/dev/nsm` exists the attested `GenerateDataKey` is tried first; if it fails
/// (or the device is absent) the direct call is used instead.
///
/// # Errors
/// Returns the error of the direct call when that call fails.
async fn kms_generate_data_key(config: &TeeKmsConfig) -> Result<(Vec<u8>, [u8; 32]), AppError> {
    let nsm_present = std::path::Path::new("/dev/nsm").exists();
    if !nsm_present {
        return kms_generate_data_key_direct(config).await;
    }

    match kms_generate_data_key_attested(config).await {
        Ok(key_pair) => {
            info!("KMS GenerateDataKey succeeded with Nitro attestation");
            Ok(key_pair)
        }
        Err(e) => {
            warn!(
                error = %e,
                "attestation-based GenerateDataKey failed — falling back to direct"
            );
            kms_generate_data_key_direct(config).await
        }
    }
}
/// Issues an attested KMS `GenerateDataKey` (AES-256) and unwraps the plaintext key from
/// the returned CMS envelope.
///
/// Returns `(KMS ciphertext blob, 32-byte plaintext key)`.
///
/// # Errors
/// Propagates attestation, KMS (classified as `GenerateDataKey(attested)`) and CMS
/// failures; also fails when the response lacks a ciphertext blob or the unwrapped key
/// is not exactly 32 bytes.
async fn kms_generate_data_key_attested(
    config: &TeeKmsConfig,
) -> Result<(Vec<u8>, [u8; 32]), AppError> {
    let (private_key, recipient) = nsm_attested_recipient()?;
    let client = kms_client(config).await;

    let request = client
        .generate_data_key()
        .key_id(&config.key_arn)
        .key_spec(aws_sdk_kms::types::DataKeySpec::Aes256)
        .recipient(recipient);
    let response = request
        .send()
        .await
        .map_err(|e| classify_kms_error("GenerateDataKey(attested)", e))?;

    let Some(ciphertext_blob) = response.ciphertext_blob() else {
        return Err(tee_attestation_error(
            "GenerateDataKey returned no CiphertextBlob",
        ));
    };
    let kms_ciphertext = ciphertext_blob.as_ref().to_vec();

    let unwrapped = unwrap_cms_response(response.ciphertext_for_recipient(), &private_key)?;
    let data_key = <[u8; 32]>::try_from(unwrapped)
        .map_err(|_| tee_attestation_error("data key is not 32 bytes"))?;

    debug!(
        kms_ct_len = kms_ciphertext.len(),
        "obtained attested data key"
    );
    Ok((kms_ciphertext, data_key))
}
/// Issues a plain (non-attested) KMS `GenerateDataKey` for an AES-256 key.
///
/// Returns `(KMS ciphertext blob, 32-byte plaintext key)`.
///
/// # Errors
/// Propagates the classified KMS failure; also fails when the response lacks the
/// ciphertext blob or plaintext, or the plaintext is not exactly 32 bytes.
async fn kms_generate_data_key_direct(
    config: &TeeKmsConfig,
) -> Result<(Vec<u8>, [u8; 32]), AppError> {
    let client = kms_client(config).await;

    let request = client
        .generate_data_key()
        .key_id(&config.key_arn)
        .key_spec(aws_sdk_kms::types::DataKeySpec::Aes256);
    let response = request
        .send()
        .await
        .map_err(|e| classify_kms_error("GenerateDataKey", e))?;

    let Some(ciphertext_blob) = response.ciphertext_blob() else {
        return Err(tee_attestation_error(
            "GenerateDataKey returned no CiphertextBlob",
        ));
    };
    let kms_ciphertext = ciphertext_blob.as_ref().to_vec();

    let Some(plaintext) = response.plaintext() else {
        return Err(tee_attestation_error(
            "GenerateDataKey returned no Plaintext",
        ));
    };
    let data_key: [u8; 32] = plaintext
        .as_ref()
        .try_into()
        .map_err(|_| tee_attestation_error("data key is not 32 bytes"))?;

    Ok((kms_ciphertext, data_key))
}
/// Encrypts `plaintext` with AES-256-GCM under `key` using a fresh random 12-byte nonce.
///
/// The returned blob is the nonce followed by the cipher's output.
///
/// # Errors
/// Returns an attestation error when the cipher reports a failure.
fn aes_gcm_encrypt(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>, AppError> {
    use aes_gcm::aead::generic_array::GenericArray;
    use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};

    const NONCE_LEN: usize = 12;

    let mut nonce_bytes = [0u8; NONCE_LEN];
    rand::fill(&mut nonce_bytes);

    let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
    let sealed = cipher
        .encrypt(GenericArray::from_slice(&nonce_bytes), plaintext)
        .map_err(|e| tee_attestation_error(format!("AES-GCM encryption failed: {e}")))?;

    // Layout: nonce || sealed output.
    let mut blob = Vec::with_capacity(NONCE_LEN + sealed.len());
    blob.extend_from_slice(&nonce_bytes);
    blob.extend_from_slice(&sealed);
    Ok(blob)
}
/// Decrypts a blob produced by `aes_gcm_encrypt`: a 12-byte nonce followed by the
/// AES-256-GCM output.
///
/// # Errors
/// Returns an attestation error when `key` is not 32 bytes, when `blob` is shorter than
/// 13 bytes, or when the cipher rejects the input.
fn aes_gcm_decrypt(key: &[u8], blob: &[u8]) -> Result<Vec<u8>, AppError> {
    use aes_gcm::aead::generic_array::GenericArray;
    use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};

    const KEY_LEN: usize = 32;
    const NONCE_LEN: usize = 12;

    if key.len() != KEY_LEN {
        return Err(tee_attestation_error(format!(
            "data key is {} bytes, expected 32",
            key.len()
        )));
    }
    if blob.len() <= NONCE_LEN {
        return Err(tee_attestation_error("AES-GCM blob too short"));
    }

    let (nonce_bytes, sealed) = blob.split_at(NONCE_LEN);
    let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
    cipher
        .decrypt(GenericArray::from_slice(nonce_bytes), sealed)
        .map_err(|e| tee_attestation_error(format!("AES-GCM decryption failed: {e}")))
}
fn decrypt_cms_envelope(
cms_bytes: &[u8],
private_key: &rsa::RsaPrivateKey,
) -> Result<Vec<u8>, AppError> {
debug!(
cms_len = cms_bytes.len(),
cms_hex_head = %hex::encode(cms_bytes),
"raw CMS envelope"
);
let fields = cms_der::parse_enveloped_data(cms_bytes)?;
use rsa::Oaep;
let cek = private_key
.decrypt(Oaep::new::<sha2::Sha256>(), &fields.encrypted_key)
.or_else(|_| {
let padding = Oaep::new_with_mgf_hash::<sha2::Sha256, sha1::Sha1>();
private_key.decrypt(padding, &fields.encrypted_key)
})
.map_err(|e| tee_attestation_error(format!("RSA-OAEP decryption of CEK failed: {e}")))?;
debug!(
cek_len = cek.len(),
oid_hex = %hex::encode(&fields.content_encryption_oid),
iv_hex = %hex::encode(&fields.iv),
ciphertext_len = fields.ciphertext.len(),
encrypted_key_len = fields.encrypted_key.len(),
"CMS envelope fields extracted"
);
if cek.len() != 32 {
return Err(tee_attestation_error(format!(
"unexpected CEK length: {} (expected 32 for AES-256)",
cek.len()
)));
}
let aes_256_cbc_oid: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x01, 0x2a];
let aes_256_gcm_oid: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x01, 0x2e];
let plaintext = if fields.content_encryption_oid == aes_256_cbc_oid {
use cbc::cipher::{BlockDecryptMut, KeyIvInit};
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
if fields.iv.len() != 16 {
return Err(tee_attestation_error(format!(
"AES-256-CBC IV must be 16 bytes, got {}",
fields.iv.len()
)));
}
let mut buf = fields.ciphertext.clone();
let decryptor = Aes256CbcDec::new_from_slices(&cek, &fields.iv)
.map_err(|e| tee_attestation_error(format!("AES-256-CBC init failed: {e}")))?;
let plaintext = decryptor
.decrypt_padded_mut::<cbc::cipher::block_padding::Pkcs7>(&mut buf)
.map_err(|e| {
tee_attestation_error(format!("AES-256-CBC decryption of CMS content failed: {e}"))
})?;
plaintext.to_vec()
} else if fields.content_encryption_oid == aes_256_gcm_oid {
use aes_gcm::aead::generic_array::GenericArray;
use aes_gcm::{AesGcm, KeyInit, aead::Aead};
match fields.iv.len() {
12 => {
let cipher = AesGcm::<aes_gcm::aes::Aes256, aes_gcm::aead::consts::U12>::new(
GenericArray::from_slice(&cek),
);
cipher
.decrypt(
GenericArray::from_slice(&fields.iv),
fields.ciphertext.as_ref(),
)
.map_err(|e| tee_attestation_error(format!("AES-GCM decryption failed: {e}")))?
}
16 => {
let cipher = AesGcm::<aes_gcm::aes::Aes256, aes_gcm::aead::consts::U16>::new(
GenericArray::from_slice(&cek),
);
cipher
.decrypt(
GenericArray::from_slice(&fields.iv),
fields.ciphertext.as_ref(),
)
.map_err(|e| tee_attestation_error(format!("AES-GCM decryption failed: {e}")))?
}
n => {
return Err(tee_attestation_error(format!(
"unsupported GCM nonce length: {n}"
)));
}
}
} else {
return Err(tee_attestation_error(format!(
"unsupported content encryption algorithm OID: {}",
hex::encode(&fields.content_encryption_oid)
)));
};
debug!(
plaintext_len = plaintext.len(),
"CMS envelope decrypted successfully"
);
Ok(plaintext)
}
mod cms_der {
use crate::error::{AppError, tee_attestation_error};
/// The parts of a CMS `EnvelopedData` structure that `parse_enveloped_data` extracts.
pub(super) struct CmsFields {
/// Value of the `encryptedKey` field of the first recipient info entry.
pub encrypted_key: Vec<u8>,
/// Content bytes of the content-encryption algorithm OID (tag and length stripped).
pub content_encryption_oid: Vec<u8>,
/// IV / nonce taken from the algorithm parameters.
pub iv: Vec<u8>,
/// The encrypted content bytes.
pub ciphertext: Vec<u8>,
}
/// Walks a CMS `ContentInfo` wrapping `EnvelopedData` and extracts the fields needed
/// for decryption.
///
/// Elements are read strictly by position: content type, `[0]` content, then inside
/// `EnvelopedData` the version, the recipient-info set and the encrypted-content info.
/// Tags are not validated.
///
/// # Errors
/// Returns an attestation error when any element is missing, truncated or malformed.
pub(super) fn parse_enveloped_data(data: &[u8]) -> Result<CmsFields, AppError> {
    // Outer ContentInfo.
    let mut outer_cursor = 0;
    let (_, content_info) = read_tlv(data, &mut outer_cursor, "ContentInfo")?;

    // ContentInfo body: the content-type OID, then the [0] wrapper.
    let mut info_cursor = 0;
    read_tlv(content_info, &mut info_cursor, "contentType OID")?;
    let (_, wrapped) = read_tlv(content_info, &mut info_cursor, "[0] content")?;

    let mut wrapped_cursor = 0;
    let (_, enveloped) = read_tlv(wrapped, &mut wrapped_cursor, "EnvelopedData")?;

    // EnvelopedData body: version, recipient infos, encrypted content info.
    let mut env_cursor = 0;
    read_tlv(enveloped, &mut env_cursor, "EnvelopedData version")?;
    let (_, recipient_infos) = read_tlv(enveloped, &mut env_cursor, "recipientInfos SET")?;
    let (_, encrypted_content_info) =
        read_tlv(enveloped, &mut env_cursor, "encryptedContentInfo")?;

    let encrypted_key = parse_key_trans_ri(recipient_infos)?;
    let (content_encryption_oid, iv, ciphertext) =
        parse_encrypted_content_info(encrypted_content_info)?;

    Ok(CmsFields {
        encrypted_key,
        content_encryption_oid,
        iv,
        ciphertext,
    })
}
/// Extracts the `encryptedKey` value from the first entry of a recipient-info set.
///
/// The entry's version, recipient identifier and key-encryption algorithm are read
/// and skipped by position.
///
/// # Errors
/// Returns an attestation error when any of the elements cannot be read.
fn parse_key_trans_ri(set_data: &[u8]) -> Result<Vec<u8>, AppError> {
    let mut set_cursor = 0;
    let (_, recipient) = read_tlv(set_data, &mut set_cursor, "KeyTransRI")?;

    let mut cursor = 0;
    for skipped in [
        "KeyTransRI version",
        "KeyTransRI rid",
        "KeyTransRI keyEncAlg",
    ] {
        read_tlv(recipient, &mut cursor, skipped)?;
    }

    let (_, encrypted_key) = read_tlv(recipient, &mut cursor, "encryptedKey")?;
    Ok(encrypted_key.to_vec())
}
/// `(content-encryption OID bytes, IV / nonce, ciphertext)`.
type EncryptedContentParts = (Vec<u8>, Vec<u8>, Vec<u8>);

/// Parses the body of an `EncryptedContentInfo`: content type (skipped), the
/// content-encryption algorithm identifier and the `[0]` encrypted content.
///
/// The encrypted content is decoded by hand rather than with `read_tlv` because, in the
/// indefinite-length form, its payload is taken as "everything up to the trailing
/// end-of-contents marker" instead of being walked as nested TLVs.
///
/// If the `[0]` element is constructed or indefinite-length and its payload begins with
/// an OCTET STRING header, the first inner OCTET STRING is returned as the ciphertext.
/// A primitive, definite-length `[0]` always carries the raw ciphertext, so its first
/// byte is never interpreted as a tag.
///
/// # Errors
/// Returns an attestation error (never panics) when an element is missing, a length is
/// truncated, does not fit in `usize`, or points past the end of the buffer.
fn parse_encrypted_content_info(eci_data: &[u8]) -> Result<EncryptedContentParts, AppError> {
    let mut pos = 0;
    let _ = read_tlv(eci_data, &mut pos, "ECI contentType")?;
    let (_, alg_body) = read_tlv(eci_data, &mut pos, "ECI algorithm")?;

    if pos >= eci_data.len() {
        return Err(tee_attestation_error(
            "CMS: missing encryptedContent in EncryptedContentInfo",
        ));
    }
    let tag = eci_data[pos];
    pos += 1;
    if pos >= eci_data.len() {
        return Err(tee_attestation_error(
            "CMS: truncated encryptedContent length",
        ));
    }
    let first_len = eci_data[pos];
    pos += 1;

    // `pos <= eci_data.len()` holds here, so this slice cannot panic.
    let rest = &eci_data[pos..];
    let indefinite = first_len == 0x80;

    let ct_value: &[u8] = if first_len < 0x80 {
        // Short form: the byte is the length itself.
        rest.get(..usize::from(first_len)).ok_or_else(|| {
            tee_attestation_error("CMS: encryptedContent length exceeds available data")
        })?
    } else if indefinite {
        // Indefinite form: take everything, minus a trailing end-of-contents marker.
        match rest {
            [body @ .., 0x00, 0x00] => body,
            _ => rest,
        }
    } else {
        // Long form: the low 7 bits give the number of big-endian length bytes.
        let num_bytes = usize::from(first_len & 0x7F);
        let len_bytes = rest
            .get(..num_bytes)
            .ok_or_else(|| tee_attestation_error("CMS: truncated encryptedContent length"))?;
        let len = len_bytes
            .iter()
            .try_fold(0usize, |acc, &b| {
                acc.checked_mul(256)?.checked_add(usize::from(b))
            })
            .ok_or_else(|| {
                tee_attestation_error("CMS: encryptedContent length does not fit in usize")
            })?;
        rest[num_bytes..].get(..len).ok_or_else(|| {
            tee_attestation_error("CMS: encryptedContent length exceeds available data")
        })?
    };

    let (oid, iv) = parse_content_encryption_params(alg_body)?;

    // Bit 0x20 of the tag marks a constructed encoding. Only constructed (or
    // indefinite-length) content can wrap an inner OCTET STRING; for primitive content a
    // leading 0x04 is just the first ciphertext byte.
    let constructed = tag & 0x20 != 0;
    let may_wrap_octet_string = constructed || indefinite;
    let ciphertext = if may_wrap_octet_string && ct_value.len() > 2 && ct_value[0] == 0x04 {
        let mut inner_pos = 0;
        let (_, inner) = read_tlv(ct_value, &mut inner_pos, "inner encryptedContent")?;
        inner.to_vec()
    } else {
        ct_value.to_vec()
    };

    Ok((oid, iv, ciphertext))
}
/// Splits a content-encryption `AlgorithmIdentifier` body into `(OID bytes, IV)`.
///
/// When the parameters element is an OCTET STRING (tag `0x04`) its value is the IV.
/// Otherwise the parameters are treated as a container whose first element holds the
/// nonce (the GCM-parameters form).
///
/// # Errors
/// Returns an attestation error when the OID, the parameters or the nested nonce
/// cannot be read.
fn parse_content_encryption_params(alg_data: &[u8]) -> Result<(Vec<u8>, Vec<u8>), AppError> {
    const OCTET_STRING_TAG: u8 = 0x04;

    let mut cursor = 0;
    let (_, oid) = read_tlv(alg_data, &mut cursor, "algorithm OID")?;
    let (param_tag, params) = read_tlv(alg_data, &mut cursor, "algorithm parameters")?;

    let iv = match param_tag {
        OCTET_STRING_TAG => params.to_vec(),
        _ => {
            let mut inner_cursor = 0;
            let (_, nonce) = read_tlv(params, &mut inner_cursor, "GCM nonce")?;
            nonce.to_vec()
        }
    };

    Ok((oid.to_vec(), iv))
}
/// Reads one BER/DER tag-length-value element from `data` starting at `*pos`.
///
/// Returns the tag byte and a slice of the value, and advances `*pos` past the element.
/// Supported length encodings:
/// * short form (`< 0x80`),
/// * long form (`0x81..`), with the length accumulated using overflow-checked arithmetic,
/// * indefinite form (`0x80`): child elements are skipped until an end-of-contents
///   marker (`00 00`) is found; the returned value excludes the marker.
///
/// Only single-byte tags are handled. `context` names the element in error messages.
///
/// # Errors
/// Returns an attestation error — never panics — when the data is truncated, a child of
/// an indefinite-length element is malformed, no end-of-contents marker is found, or
/// the declared length does not fit in `usize` or exceeds the remaining data.
fn read_tlv<'a>(
    data: &'a [u8],
    pos: &mut usize,
    context: &str,
) -> Result<(u8, &'a [u8]), AppError> {
    if *pos >= data.len() {
        return Err(tee_attestation_error(format!(
            "CMS: unexpected end of data reading {context}"
        )));
    }
    let tag = data[*pos];
    *pos += 1;
    if *pos >= data.len() {
        return Err(tee_attestation_error(format!(
            "CMS: truncated length for {context}"
        )));
    }
    let first_len = data[*pos];
    *pos += 1;
    // From here on `*pos <= data.len()`, so `data.len() - *pos` cannot underflow.

    let len: usize = if first_len < 0x80 {
        usize::from(first_len)
    } else if first_len == 0x80 {
        let content_start = *pos;
        while *pos + 1 < data.len() {
            if data[*pos] == 0x00 && data[*pos + 1] == 0x00 {
                let value = &data[content_start..*pos];
                *pos += 2;
                return Ok((tag, value));
            }
            skip_ber_tlv(data, pos).map_err(|_| {
                tee_attestation_error(format!(
                    "CMS: malformed BER child element inside {context}"
                ))
            })?;
        }
        return Err(tee_attestation_error(format!(
            "CMS: no EOC marker found for indefinite-length {context}"
        )));
    } else {
        let num_bytes = usize::from(first_len & 0x7F);
        if num_bytes > data.len() - *pos {
            return Err(tee_attestation_error(format!(
                "CMS: truncated length bytes for {context}"
            )));
        }
        // A plain shift would silently drop high bits for lengths wider than usize.
        let len = data[*pos..*pos + num_bytes]
            .iter()
            .try_fold(0usize, |acc, &b| {
                acc.checked_mul(256)?.checked_add(usize::from(b))
            })
            .ok_or_else(|| {
                tee_attestation_error(format!(
                    "CMS: length of {context} does not fit in usize"
                ))
            })?;
        *pos += num_bytes;
        len
    };

    // Compare against the remaining byte count instead of computing `*pos + len`,
    // which could overflow for hostile lengths.
    if len > data.len() - *pos {
        return Err(tee_attestation_error(format!(
            "CMS: value overflows buffer for {context} (need {len} bytes at offset {pos}, have {})",
            data.len()
        )));
    }
    let value = &data[*pos..*pos + len];
    *pos += len;
    Ok((tag, value))
}
/// Advances `*pos` past one BER element (single-byte tag) without returning its value.
///
/// Handles short-form, long-form and indefinite-length encodings; indefinite-length
/// elements are skipped by recursively skipping children until the end-of-contents
/// marker (`00 00`).
///
/// Returns `Err(())` — never panics — when fewer than two bytes remain, the length bytes
/// are truncated, the declared length does not fit in `usize` or runs past the end of
/// `data`, or an indefinite-length element has no end-of-contents marker.
fn skip_ber_tlv(data: &[u8], pos: &mut usize) -> Result<(), ()> {
    // Both the tag byte and the first length byte must be present.
    if pos.saturating_add(1) >= data.len() {
        return Err(());
    }
    *pos += 1; // tag
    let first_len = data[*pos];
    *pos += 1;
    // From here on `*pos <= data.len()`, so `data.len() - *pos` cannot underflow.

    let len = if first_len < 0x80 {
        usize::from(first_len)
    } else if first_len == 0x80 {
        while *pos + 1 < data.len() {
            if data[*pos] == 0x00 && data[*pos + 1] == 0x00 {
                *pos += 2;
                return Ok(());
            }
            skip_ber_tlv(data, pos)?;
        }
        return Err(());
    } else {
        let num_bytes = usize::from(first_len & 0x7F);
        let len_bytes = data
            .get(*pos..)
            .and_then(|rest| rest.get(..num_bytes))
            .ok_or(())?;
        // A plain shift would silently drop high bits for lengths wider than usize.
        let len = len_bytes
            .iter()
            .try_fold(0usize, |acc, &b| {
                acc.checked_mul(256)?.checked_add(usize::from(b))
            })
            .ok_or(())?;
        *pos += num_bytes;
        len
    };

    // Compare against the remaining byte count; `*pos + len` could overflow.
    if len > data.len() - *pos {
        return Err(());
    }
    *pos += len;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Short-form length: the value is the three bytes after the header.
    #[test]
    fn test_read_tlv_short_form() {
        let encoded = [0x04, 0x03, 0x01, 0x02, 0x03];
        let mut cursor = 0;
        let (tag, value) = read_tlv(&encoded, &mut cursor, "test").unwrap();
        assert_eq!(tag, 0x04);
        assert_eq!(value, &[0x01, 0x02, 0x03]);
        assert_eq!(cursor, 5);
    }

    /// Long-form length (`0x81 0x80`): a 128-byte value.
    #[test]
    fn test_read_tlv_long_form() {
        let header = [0x04u8, 0x81, 0x80];
        let payload = [0xAAu8; 128];
        let encoded = [header.as_slice(), payload.as_slice()].concat();
        let mut cursor = 0;
        let (tag, value) = read_tlv(&encoded, &mut cursor, "test").unwrap();
        assert_eq!(tag, 0x04);
        assert_eq!(value.len(), 128);
        assert_eq!(cursor, 131);
    }

    /// The header announces five bytes but only one follows.
    #[test]
    fn test_read_tlv_truncated() {
        let encoded = [0x04, 0x05, 0x01];
        let mut cursor = 0;
        let outcome = read_tlv(&encoded, &mut cursor, "test");
        assert!(outcome.is_err());
    }
}
}
/// Converts a KMS SDK error into an `AppError` with a coarse classification tag.
///
/// The error and its whole `source()` chain are rendered into one message, which is
/// then matched against known substrings to pick the classification. The result has
/// the form `KMS <operation> failed [<classification>]: <message>`; the operation and
/// classification are also logged at error level.
fn classify_kms_error<E: std::error::Error>(operation: &str, err: E) -> AppError {
    // Render the error followed by every cause in its source chain.
    let mut full_msg = err.to_string();
    let causes = std::iter::successors(std::error::Error::source(&err), |cause| cause.source());
    for cause in causes {
        full_msg.push_str(&format!("\n caused by: {cause}"));
    }

    let mentions = |needle: &str| full_msg.contains(needle);
    let classification = if mentions("AccessDeniedException") {
        "ACCESS_DENIED — check KMS key policy allows this action and PCR conditions match"
    } else if mentions("NotFoundException") || mentions("not found") {
        "KEY_NOT_FOUND — verify the KMS key ARN in config.toml"
    } else if mentions("InvalidCiphertextException") {
        "INVALID_CIPHERTEXT — ciphertext may be corrupt or encrypted with a different key"
    } else if mentions("KMSInternalException") {
        "KMS_INTERNAL — transient AWS error, retry may help"
    } else if mentions("connect") || mentions("timeout") {
        "NETWORK — cannot reach KMS endpoint, check vsock proxy and allowlist"
    } else {
        "UNKNOWN"
    };

    let msg = format!("KMS {operation} failed [{classification}]: {full_msg}");
    error!(operation, classification, "KMS error");
    tee_attestation_error(msg)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a CMS envelope around an AES-256-GCM ciphertext whose key is wrapped with
    /// RSA-OAEP-SHA256, then checks that `decrypt_cms_envelope` recovers the plaintext.
    #[test]
    fn test_cms_envelope_roundtrip() {
        use aes_gcm::aead::generic_array::GenericArray;
        use aes_gcm::{Aes256Gcm, KeyInit, aead::Aead};
        use rsa::pkcs8::EncodePublicKey;
        use rsa::{Oaep, RsaPrivateKey};

        let mut rng = rsa::rand_core::OsRng;
        let private_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
        let original_plaintext = b"this is a secret seed value!!!!!";

        // Content-encryption key and nonce.
        let mut cek = [0u8; 32];
        rand::fill(&mut cek);
        let mut nonce_bytes = [0u8; 12];
        rand::fill(&mut nonce_bytes);

        let aes_ciphertext = Aes256Gcm::new(GenericArray::from_slice(&cek))
            .encrypt(
                GenericArray::from_slice(&nonce_bytes),
                original_plaintext.as_ref(),
            )
            .unwrap();

        let encrypted_cek = private_key
            .to_public_key()
            .encrypt(&mut rng, Oaep::new::<sha2::Sha256>(), &cek)
            .unwrap();

        let cms_bytes = build_test_cms_envelope(&encrypted_cek, &nonce_bytes, &aes_ciphertext);
        let recovered = decrypt_cms_envelope(&cms_bytes, &private_key).unwrap();
        assert_eq!(recovered, original_plaintext);
    }

    /// Assembles a minimal DER-encoded `ContentInfo { EnvelopedData }` for the test above.
    fn build_test_cms_envelope(
        encrypted_cek: &[u8],
        nonce: &[u8],
        aes_ciphertext: &[u8],
    ) -> Vec<u8> {
        // Complete OID TLVs (tag 0x06, length 0x09, content).
        let enveloped_data_oid: &[u8] = &[
            0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x07, 0x03,
        ];
        let data_oid: &[u8] = &[
            0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x07, 0x01,
        ];
        let aes_256_gcm_oid: &[u8] = &[
            0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x01, 0x2E,
        ];
        let rsaes_oaep_oid: &[u8] = &[
            0x06, 0x09, 0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x07,
        ];
        // INTEGER 0, used for both version fields.
        let version_zero: &[u8] = &[0x02, 0x01, 0x00];

        // Content-encryption algorithm: the OID plus parameters wrapping the nonce.
        let gcm_params = der_sequence(&der_octet_string(nonce));
        let alg_id = der_sequence(&[aes_256_gcm_oid, gcm_params.as_slice()].concat());

        // EncryptedContentInfo.
        let encrypted_content = der_context_implicit(0, aes_ciphertext);
        let eci = der_sequence(
            &[data_oid, alg_id.as_slice(), encrypted_content.as_slice()].concat(),
        );

        // Recipient info carrying the wrapped key.
        let fake_rid = der_sequence(&[0x30, 0x00, 0x02, 0x01, 0x01]);
        let key_enc_alg = der_sequence(rsaes_oaep_oid);
        let wrapped_key = der_octet_string(encrypted_cek);
        let ktri = der_sequence(
            &[
                version_zero,
                fake_rid.as_slice(),
                key_enc_alg.as_slice(),
                wrapped_key.as_slice(),
            ]
            .concat(),
        );
        let ri_set = der_set(&ktri);

        // EnvelopedData wrapped in [0], inside ContentInfo.
        let enveloped_data =
            der_sequence(&[version_zero, ri_set.as_slice(), eci.as_slice()].concat());
        let ctx0 = der_context_explicit(0, &enveloped_data);
        der_sequence(&[enveloped_data_oid, ctx0.as_slice()].concat())
    }

    /// SEQUENCE (tag `0x30`).
    fn der_sequence(content: &[u8]) -> Vec<u8> {
        der_tlv(0x30, content)
    }

    /// SET (tag `0x31`).
    fn der_set(content: &[u8]) -> Vec<u8> {
        der_tlv(0x31, content)
    }

    /// OCTET STRING (tag `0x04`).
    fn der_octet_string(content: &[u8]) -> Vec<u8> {
        der_tlv(0x04, content)
    }

    /// Context-specific constructed tag (`0xA0 | tag_num`).
    fn der_context_explicit(tag_num: u8, content: &[u8]) -> Vec<u8> {
        der_tlv(0xA0 | tag_num, content)
    }

    /// Context-specific primitive tag (`0x80 | tag_num`).
    fn der_context_implicit(tag_num: u8, content: &[u8]) -> Vec<u8> {
        der_tlv(0x80 | tag_num, content)
    }

    /// Encodes one TLV with a definite length of up to three length bytes.
    fn der_tlv(tag: u8, content: &[u8]) -> Vec<u8> {
        let len = content.len();
        let mut buf = vec![tag];
        match len {
            0..=0x7F => buf.push(len as u8),
            0x80..=0xFF => buf.extend_from_slice(&[0x81, len as u8]),
            0x100..=0xFFFF => buf.extend_from_slice(&[0x82, (len >> 8) as u8, len as u8]),
            _ => buf.extend_from_slice(&[
                0x83,
                (len >> 16) as u8,
                (len >> 8) as u8,
                len as u8,
            ]),
        }
        buf.extend_from_slice(content);
        buf
    }

    #[test]
    fn test_derive_storage_key_deterministic() {
        let seed = [0x42u8; 32];
        let first = derive_storage_key(&seed, "test-salt");
        let second = derive_storage_key(&seed, "test-salt");
        assert_eq!(first, second, "same seed + salt must produce same key");
    }

    #[test]
    fn test_derive_storage_key_different_salts() {
        let seed = [0x42u8; 32];
        let with_salt_a = derive_storage_key(&seed, "salt-a");
        let with_salt_b = derive_storage_key(&seed, "salt-b");
        assert_ne!(
            with_salt_a, with_salt_b,
            "different salts must produce different keys"
        );
    }

    #[test]
    fn test_derive_storage_key_different_seeds() {
        let from_seed_one = derive_storage_key(&[0x01u8; 32], "same-salt");
        let from_seed_two = derive_storage_key(&[0x02u8; 32], "same-salt");
        assert_ne!(
            from_seed_one, from_seed_two,
            "different seeds must produce different keys"
        );
    }

    /// The fingerprint is stable and is 16 bytes rendered as 32 hex characters.
    #[test]
    fn test_jwt_fingerprint_deterministic() {
        let key = [0xABu8; 32];
        let first = jwt_fingerprint(&key);
        let second = jwt_fingerprint(&key);
        assert_eq!(first, second);
        assert_eq!(first.len(), 32);
    }
}