pub mod types;
pub use types::{KeyAttestationVerifyRequest, KeyAttestationVerifyResponse};
use base64::{engine::general_purpose::STANDARD, Engine};
use der_parser::ber::{BerObject, BerObjectContent};
use reqwest::Client;
use std::sync::Arc;
use tokio::sync::RwLock;
use x509_parser::prelude::*;
use super::error::{Result, SignError};
use types::*;
// OID arcs (1.3.6.1.4.1.11129.2.1.17) of the certificate extension that
// `parse_attestation_extension` searches for on the leaf certificate.
const ATTESTATION_EXTENSION_OID_COMPONENTS: &[u64] = &[1, 3, 6, 1, 4, 1, 11129, 2, 1, 17];
// Endpoint that `fetch_crl_entries` downloads the certificate status list from.
const GOOGLE_ATTESTATION_STATUS_URL: &str = "https://android.googleapis.com/attestation/status";
// DER-encoded root certificates embedded at compile time from files next to this
// module. `KeyAttestationVerifier::new` loads every one of them as a trust anchor.
const GOOGLE_ROOT_2016: &[u8] = include_bytes!("google_hardware_attestation_root_2016.der");
const GOOGLE_ROOT_2019: &[u8] = include_bytes!("google_hardware_attestation_root_2019.der");
const GOOGLE_ROOT_2021: &[u8] = include_bytes!("google_hardware_attestation_root_2021.der");
const GOOGLE_ROOT_2022: &[u8] = include_bytes!("google_hardware_attestation_root_2022.der");
const GOOGLE_ROOT_2025_EC: &[u8] = include_bytes!("google_hardware_attestation_root_2025_ec.der");
/// Persistence interface used by the verifier for challenges and device registrations.
///
/// Implementations must be `Send + Sync`; every method is async and fallible.
#[async_trait::async_trait]
pub trait KeyAttestationStore: Send + Sync {
/// Reports whether `challenge` is currently acceptable.
///
/// `verify_attestation` rejects the request with "Invalid or expired challenge"
/// when this returns `Ok(false)`.
async fn validate_challenge(&self, challenge: &str) -> Result<bool>;
/// Called by `verify_attestation` once all attestation checks have passed and
/// before the registration is stored.
///
/// NOTE(review): presumably marks the challenge as used so it cannot be replayed —
/// confirm against the implementations.
async fn consume_challenge(&self, challenge: &str) -> Result<()>;
/// Persists a verified registration.
///
/// As called from `verify_attestation`: `public_key_base64` is the base64 of the
/// leaf certificate's SubjectPublicKeyInfo, `security_level` is the attestation
/// security level string, `attestation_chain` is the JSON-serialized certificate
/// chain from the request, and `os_version` / `os_patch_level` are the formatted
/// values when the attestation extension provided them.
async fn store_registration(
&self,
device_id: &str,
public_key_base64: &str,
security_level: &str,
attestation_chain: &str,
os_version: Option<&str>,
os_patch_level: Option<&str>,
) -> Result<()>;
/// Not called in this file.
///
/// NOTE(review): presumably returns the stored public key and signing counter for
/// `device_id`, or `None` for an unknown device — confirm against the callers.
async fn get_device_signing_info(&self, device_id: &str) -> Result<Option<(String, i64)>>;
/// Not called in this file.
///
/// NOTE(review): presumably persists `new_counter` as the device's signing counter —
/// confirm against the callers.
async fn update_device_counter(&self, device_id: &str, new_counter: i64) -> Result<()>;
}
// How long (in seconds) a failed CRL fetch is remembered; while it is fresh,
// `check_revocation` fails immediately instead of retrying the download.
const NEGATIVE_CACHE_SECS: u64 = 300;
/// Result of the most recent CRL fetch attempt, kept in memory by the verifier.
struct CachedCrl {
// The "entries" value of the status list; an empty JSON object for a negative entry.
entries: serde_json::Value,
// When the fetch attempt finished; compared against the applicable TTL.
fetched_at: std::time::Instant,
// True when the fetch failed and this entry only records that failure.
is_negative: bool,
}
/// Verifies key attestation certificate chains against the embedded trusted roots
/// and the downloaded certificate status list.
pub struct KeyAttestationVerifier {
// HTTP client used to download the status list.
client: Client,
// DER bytes of the trusted root certificates.
google_roots: Vec<Vec<u8>>,
// Last CRL fetch result (positive or negative), shared behind an async RwLock.
cached_crl: Arc<RwLock<Option<CachedCrl>>>,
// Attestations below this security level are rejected.
min_security_level: SecurityLevel,
// Lifetime, in hours, of a successfully fetched CRL in the cache.
crl_cache_hours: u64,
}
impl KeyAttestationVerifier {
pub fn new(min_security_level: SecurityLevel) -> Result<Self> {
let client = Client::new();
let crl_cache_hours: u64 = std::env::var("KEY_ATTESTATION_CRL_CACHE_HOURS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(24);
Ok(Self {
client,
google_roots: vec![
GOOGLE_ROOT_2016.to_vec(),
GOOGLE_ROOT_2019.to_vec(),
GOOGLE_ROOT_2021.to_vec(),
GOOGLE_ROOT_2022.to_vec(),
GOOGLE_ROOT_2025_EC.to_vec(),
],
cached_crl: Arc::new(RwLock::new(None)),
min_security_level,
crl_cache_hours,
})
}
/// Builds a verifier whose minimum security level comes from the
/// `KEY_ATTESTATION_MIN_SECURITY_LEVEL` environment variable.
///
/// When the variable cannot be read, `"TEE"` is used.
///
/// # Errors
/// Returns `SignError::ConfigError` when the value is not accepted by
/// `SecurityLevel::parse`, and propagates any error from [`Self::new`].
pub fn from_env() -> Result<Self> {
    let raw_level = match std::env::var("KEY_ATTESTATION_MIN_SECURITY_LEVEL") {
        Ok(value) => value,
        Err(_) => String::from("TEE"),
    };
    match SecurityLevel::parse(&raw_level) {
        Some(level) => Self::new(level),
        None => Err(SignError::ConfigError(format!(
            "Invalid KEY_ATTESTATION_MIN_SECURITY_LEVEL: {}",
            raw_level
        ))),
    }
}
/// Verifies a key attestation request end to end and records the registration.
///
/// Steps, in order: challenge validation through `store`, base64 decoding and X.509
/// parsing of the chain, chain ordering, signature validation, root trust check,
/// revocation check, attestation extension parsing, minimum security level check,
/// challenge comparison, challenge consumption, and finally `store.store_registration`.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when any of the checks fails; errors returned
/// by `store` are propagated.
pub async fn verify_attestation(
&self,
request: &KeyAttestationVerifyRequest,
store: &dyn KeyAttestationStore,
) -> Result<KeyAttestationVerifyResponse> {
// Reject unknown or expired challenges before doing any certificate work.
if !store.validate_challenge(&request.challenge).await? {
return Err(SignError::KeyAttestation(
"Invalid or expired challenge".into(),
));
}
let cert_ders = parse_certificate_chain(&request.certificate_chain)?;
log::info!(
"Key attestation: received {} certificates in chain",
cert_ders.len()
);
// The parsed certificates borrow from `cert_ders`, which must outlive them.
let certs: Vec<X509Certificate> = cert_ders
.iter()
.enumerate()
.map(|(i, der)| {
let (_, cert) = X509Certificate::from_der(der).map_err(|e| {
SignError::KeyAttestation(format!("Failed to parse X.509 certificate: {}", e))
})?;
log::info!(
"Key attestation: cert[{}] subject={}, issuer={}, serial={}",
i,
cert.subject(),
cert.issuer(),
cert.serial.to_str_radix(16)
);
Ok(cert)
})
.collect::<Result<Vec<_>>>()?;
if certs.len() < 2 {
return Err(SignError::KeyAttestation(
"Certificate chain must contain at least 2 certificates".into(),
));
}
// `order_chain` only succeeds with at least two certificates, so the
// `first()` / `last()` unwraps below cannot fail.
let ordered = order_chain(&certs)?;
log::info!(
"Key attestation: ordered chain has {} certs, root subject={}",
ordered.len(),
ordered.last().unwrap().subject()
);
validate_chain_signatures(&ordered)?;
self.verify_root_trust(ordered.last().unwrap())?;
self.check_revocation(&ordered).await?;
let ext = parse_attestation_extension(ordered.first().unwrap())?;
if ext.attestation_security_level < self.min_security_level {
return Err(SignError::KeyAttestation(format!(
"Attestation security level {} does not meet minimum requirement {}",
ext.attestation_security_level.as_str(),
self.min_security_level.as_str()
)));
}
// The request carries the challenge hex-encoded; the extension holds raw bytes.
let challenge_bytes = hex::decode(&request.challenge).map_err(|e| {
SignError::KeyAttestation(format!("Invalid challenge hex encoding: {}", e))
})?;
if ext.attestation_challenge != challenge_bytes {
return Err(SignError::KeyAttestation(
"Attestation challenge does not match server challenge".into(),
));
}
// Consume the challenge only after every check has passed.
store.consume_challenge(&request.challenge).await?;
let leaf = ordered.first().unwrap();
let spki_der = leaf.public_key().raw;
let public_key_base64 = STANDARD.encode(spki_der);
// The version integer is decoded as major * 10000 + minor * 100 + patch;
// a zero patch component is omitted from the string.
let os_version = ext.os_version.map(|v| {
let major = v / 10000;
let minor = (v % 10000) / 100;
let patch = v % 100;
if patch == 0 {
format!("{}.{}", major, minor)
} else {
format!("{}.{}.{}", major, minor, patch)
}
});
// The patch level integer is decoded as year * 100 + month and rendered "YYYY-MM".
let os_patch_level = ext.os_patch_level.map(|v| {
let year = v / 100;
let month = v % 100;
format!("{}-{:02}", year, month)
});
let security_level = ext.attestation_security_level.as_str().to_string();
// The chain is stored exactly as received (base64 strings), serialized as JSON.
let chain_json = serde_json::to_string(&request.certificate_chain)
.map_err(|e| SignError::KeyAttestation(format!("Failed to serialize chain: {}", e)))?;
store
.store_registration(
&request.device_id,
&public_key_base64,
&security_level,
&chain_json,
os_version.as_deref(),
os_patch_level.as_deref(),
)
.await?;
Ok(KeyAttestationVerifyResponse {
success: true,
message: "Key attestation verified successfully".into(),
security_level,
public_key_base64,
os_version,
os_patch_level,
})
}
/// Checks that `root` corresponds to one of the embedded trusted roots.
///
/// A trusted root matches when both its SubjectPublicKeyInfo bytes and its subject
/// equal those of `root`. Embedded roots that fail to parse are logged and skipped.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when no embedded root matches.
fn verify_root_trust(&self, root: &X509Certificate) -> Result<()> {
    let candidate_spki = root.public_key().raw;
    let candidate_subject = root.subject();
    log::info!(
        "Key attestation: verifying root trust - subject={}, serial={}, self_signed={}, spki_len={}",
        candidate_subject,
        root.serial.to_str_radix(16),
        root.subject() == root.issuer(),
        candidate_spki.len()
    );

    for (index, anchor_der) in self.google_roots.iter().enumerate() {
        let anchor = match X509Certificate::from_der(anchor_der) {
            Ok((_, parsed)) => parsed,
            Err(e) => {
                log::warn!(
                    "Key attestation: failed to parse trusted root[{}]: {}",
                    index,
                    e
                );
                continue;
            }
        };

        let spki_match = candidate_spki == anchor.public_key().raw;
        let subject_match = candidate_subject == anchor.subject();
        log::info!(
            "Key attestation: comparing with trusted root[{}] subject={}, serial={}, spki_match={}, subject_match={}",
            index,
            anchor.subject(),
            anchor.serial.to_str_radix(16),
            spki_match,
            subject_match
        );
        if !(spki_match && subject_match) {
            continue;
        }

        log::info!(
            "Key attestation: root trust verified via trusted root[{}]",
            index
        );
        return Ok(());
    }

    Err(SignError::KeyAttestation(
        "Root certificate does not match any trusted Google attestation root".into(),
    ))
}
/// Rejects the chain if any of its certificates is listed as revoked or suspended.
///
/// A cached status list is used while it is fresh: successful fetches live for
/// `crl_cache_hours` hours, failed fetches for `NEGATIVE_CACHE_SECS` seconds. While a
/// failed fetch is cached, verification fails immediately without contacting the
/// endpoint again. Otherwise the list is downloaded and the cache is replaced with
/// the outcome (positive or negative).
///
/// # Errors
/// Returns `SignError::KeyAttestation` when a certificate is revoked/suspended, when
/// the list cannot be fetched, or while a fetch failure is cached.
async fn check_revocation(&self, chain: &[&X509Certificate<'_>]) -> Result<()> {
    {
        let cache = self.cached_crl.read().await;
        if let Some(ref cached) = *cache {
            let ttl_secs = if cached.is_negative {
                NEGATIVE_CACHE_SECS
            } else {
                // The hour count comes from the environment; saturate instead of
                // overflowing (panic in debug, wrap-around to a tiny TTL in release).
                self.crl_cache_hours.saturating_mul(3600)
            };
            if cached.fetched_at.elapsed() < std::time::Duration::from_secs(ttl_secs) {
                if cached.is_negative {
                    return Err(SignError::KeyAttestation(
                        "CRL temporarily unavailable (cached failure)".into(),
                    ));
                }
                return check_serial_numbers_against_crl(&cached.entries, chain);
            }
        }
        // The read guard is dropped here, before the network request and the write lock.
    }
    match self.fetch_crl_entries().await {
        Ok(entries) => {
            let result = check_serial_numbers_against_crl(&entries, chain);
            let mut cache = self.cached_crl.write().await;
            *cache = Some(CachedCrl {
                entries,
                fetched_at: std::time::Instant::now(),
                is_negative: false,
            });
            result
        }
        Err(e) => {
            // Remember the failure so subsequent requests fail fast for a while.
            let mut cache = self.cached_crl.write().await;
            *cache = Some(CachedCrl {
                entries: serde_json::Value::Object(serde_json::Map::new()),
                fetched_at: std::time::Instant::now(),
                is_negative: true,
            });
            Err(e)
        }
    }
}
/// Downloads the status list and returns its `"entries"` value.
///
/// When the response JSON has no `"entries"` key, an empty JSON object is returned.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the request fails, the HTTP status is not
/// a success, or the body is not valid JSON.
async fn fetch_crl_entries(&self) -> Result<serde_json::Value> {
    let request = self.client.get(GOOGLE_ATTESTATION_STATUS_URL);
    let response = match request.send().await {
        Ok(resp) => resp,
        Err(e) => {
            return Err(SignError::KeyAttestation(format!(
                "Failed to fetch attestation CRL: {}",
                e
            )));
        }
    };

    if !response.status().is_success() {
        return Err(SignError::KeyAttestation(format!(
            "CRL fetch failed with status: {}",
            response.status()
        )));
    }

    let body: serde_json::Value = match response.json().await {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(SignError::KeyAttestation(format!(
                "Failed to parse CRL response: {}",
                e
            )));
        }
    };

    match body.get("entries") {
        Some(entries) => Ok(entries.clone()),
        None => Ok(serde_json::Value::Object(serde_json::Map::new())),
    }
}
}
/// Looks up every certificate's serial number in the status list entries.
///
/// Serials are rendered as lowercase hexadecimal and used as keys into `entries`.
/// An entry whose `"status"` is `"REVOKED"` or `"SUSPENDED"` — or that has no string
/// `"status"` at all — rejects the chain; any other status is accepted. When
/// `entries` is not a JSON object, nothing is checked and the chain is accepted.
///
/// # Errors
/// Returns `SignError::KeyAttestation` naming the first offending serial.
fn check_serial_numbers_against_crl(
    entries: &serde_json::Value,
    chain: &[&X509Certificate],
) -> Result<()> {
    let listed = match entries.as_object() {
        Some(map) => map,
        None => return Ok(()),
    };

    for cert in chain {
        let serial = cert.serial.to_str_radix(16).to_lowercase();
        let entry = match listed.get(&serial) {
            Some(found) => found,
            None => continue,
        };
        // A listed serial without a readable status is treated as revoked.
        let status = match entry.get("status").and_then(|s| s.as_str()) {
            Some(text) => text,
            None => "REVOKED",
        };
        match status {
            "REVOKED" | "SUSPENDED" => {
                return Err(SignError::KeyAttestation(format!(
                    "Certificate with serial {} is {}",
                    serial, status
                )));
            }
            _ => {}
        }
    }

    Ok(())
}
/// Decodes each standard-alphabet base64 string into raw certificate bytes.
///
/// The output preserves the input order; an empty input yields an empty vector.
///
/// # Errors
/// Returns `SignError::KeyAttestation` for the first string that is not valid base64.
fn parse_certificate_chain(certs_base64: &[String]) -> Result<Vec<Vec<u8>>> {
    let mut decoded = Vec::with_capacity(certs_base64.len());
    for encoded in certs_base64 {
        match STANDARD.decode(encoded) {
            Ok(der) => decoded.push(der),
            Err(e) => {
                return Err(SignError::KeyAttestation(format!(
                    "Failed to decode base64 certificate: {}",
                    e
                )));
            }
        }
    }
    Ok(decoded)
}
/// Orders an arbitrarily arranged certificate list leaf-first by following
/// issuer/subject names.
///
/// The leaf is the first certificate that is not the self-issued root (subject equal
/// to issuer; the last such certificate is taken as the root) and whose subject is
/// not the issuer of any other certificate. Starting at the leaf, each step appends
/// the not-yet-used certificate whose subject equals the current issuer, until no
/// such certificate exists or all have been used. Certificates that are not
/// reachable this way are dropped from the result.
///
/// Only names are compared here; signatures are not verified.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when no leaf can be identified or when fewer
/// than two certificates could be linked.
fn order_chain<'a>(certs: &'a [X509Certificate<'a>]) -> Result<Vec<&'a X509Certificate<'a>>> {
    // The last self-issued certificate is treated as the root.
    let root_idx = certs
        .iter()
        .rposition(|cert| cert.subject() == cert.issuer());

    let leaf_idx = certs
        .iter()
        .enumerate()
        .position(|(i, cert)| {
            Some(i) != root_idx
                && !certs
                    .iter()
                    .enumerate()
                    .any(|(j, other)| j != i && other.issuer() == cert.subject())
        })
        .ok_or_else(|| {
            SignError::KeyAttestation("Could not identify leaf certificate in chain".into())
        })?;

    let mut visited = vec![false; certs.len()];
    let mut ordered = Vec::with_capacity(certs.len());
    let mut current_idx = leaf_idx;
    loop {
        // `current_idx` is always unvisited here: the leaf on the first pass, and
        // afterwards an index selected below only when not yet visited. This also
        // guarantees termination on cyclic issuer/subject names.
        visited[current_idx] = true;
        ordered.push(&certs[current_idx]);
        if ordered.len() == certs.len() {
            break;
        }

        let issuer = certs[current_idx].issuer();
        let next = certs
            .iter()
            .enumerate()
            .find(|(i, cert)| !visited[*i] && cert.subject() == issuer);
        match next {
            Some((i, _)) => current_idx = i,
            None => break,
        }
    }

    if ordered.len() < 2 {
        return Err(SignError::KeyAttestation(
            "Could not build a valid certificate chain".into(),
        ));
    }
    Ok(ordered)
}
/// Verifies a leaf-first certificate chain link by link.
///
/// For every adjacent pair the child's issuer must equal the parent's subject and
/// the child's signature must verify under the parent's public key. The last
/// certificate must additionally carry a valid self-signature.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the chain is empty, when a link has
/// mismatching names, or when any signature fails to verify.
fn validate_chain_signatures(chain: &[&X509Certificate]) -> Result<()> {
    // An empty chain is reported as an error instead of underflowing `len() - 1`
    // and panicking on `last().unwrap()`.
    let root = match chain.last() {
        Some(root) => root,
        None => {
            return Err(SignError::KeyAttestation(
                "Certificate chain is empty".into(),
            ));
        }
    };

    for (i, link) in chain.windows(2).enumerate() {
        let (child, parent) = (link[0], link[1]);
        if child.issuer() != parent.subject() {
            return Err(SignError::KeyAttestation(format!(
                "Chain link {} -> {} has issuer/subject mismatch",
                i,
                i + 1
            )));
        }
        child
            .verify_signature(Some(parent.public_key()))
            .map_err(|e| {
                SignError::KeyAttestation(format!(
                    "Signature verification failed at chain position {}: {}",
                    i, e
                ))
            })?;
    }

    // `None` makes the certificate verify against its own public key.
    root.verify_signature(None).map_err(|e| {
        SignError::KeyAttestation(format!("Root certificate self-signature invalid: {}", e))
    })?;
    Ok(())
}
/// Locates the attestation extension on `leaf` and parses its contents.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the OID cannot be built, when the leaf
/// has no extension with that OID, when the extension value is not parseable BER,
/// or when `parse_attestation_sequence` rejects the parsed structure.
fn parse_attestation_extension(leaf: &X509Certificate) -> Result<AttestationExtension> {
    let wanted_oid =
        match x509_parser::oid_registry::Oid::from(ATTESTATION_EXTENSION_OID_COMPONENTS) {
            Ok(oid) => oid,
            Err(_) => {
                return Err(SignError::KeyAttestation(
                    "Invalid attestation extension OID".into(),
                ));
            }
        };

    let extension = match leaf
        .extensions()
        .iter()
        .find(|candidate| candidate.oid == wanted_oid)
    {
        Some(found) => found,
        None => {
            return Err(SignError::KeyAttestation(
                "Leaf certificate does not contain attestation extension".into(),
            ));
        }
    };

    match der_parser::ber::parse_ber(extension.value) {
        Ok((_, parsed)) => parse_attestation_sequence(&parsed),
        Err(e) => Err(SignError::KeyAttestation(format!(
            "Failed to parse attestation extension ASN.1: {}",
            e
        ))),
    }
}
/// Decodes the attestation extension SEQUENCE into an `AttestationExtension`.
///
/// The first six elements are read positionally: attestation version, attestation
/// security level, keymaster version, keymaster security level, attestation
/// challenge and unique id. OS version (context tag 705) and OS patch level
/// (context tag 706) are optional: they are taken from element 7 (the TEE list)
/// when present there, and otherwise from element 6 (the software list).
///
/// Tagged values that are negative or do not fit in `u32` are ignored rather than
/// truncated.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the object is not a SEQUENCE, has fewer
/// than six elements, a positional field has the wrong type, or a security level is
/// unknown.
fn parse_attestation_sequence(seq: &BerObject) -> Result<AttestationExtension> {
    let items = seq.as_sequence().map_err(|e| {
        SignError::KeyAttestation(format!("Attestation extension is not a SEQUENCE: {}", e))
    })?;
    if items.len() < 6 {
        return Err(SignError::KeyAttestation(format!(
            "Attestation extension SEQUENCE has {} items, expected at least 6",
            items.len()
        )));
    }
    let attestation_version = get_integer(&items[0], "attestationVersion")?;
    let attestation_security_level_raw = get_enum_value(&items[1], "attestationSecurityLevel")?;
    let keymaster_version = get_integer(&items[2], "keymasterVersion")?;
    let keymaster_security_level_raw = get_enum_value(&items[3], "keymasterSecurityLevel")?;
    let attestation_challenge = get_octet_string(&items[4], "attestationChallenge")?;
    let unique_id = get_octet_string(&items[5], "uniqueId")?;
    let attestation_security_level = SecurityLevel::from_u64(attestation_security_level_raw)
        .ok_or_else(|| {
            SignError::KeyAttestation(format!(
                "Unknown attestation security level: {}",
                attestation_security_level_raw
            ))
        })?;
    let keymaster_security_level = SecurityLevel::from_u64(keymaster_security_level_raw)
        .ok_or_else(|| {
            SignError::KeyAttestation(format!(
                "Unknown keymaster security level: {}",
                keymaster_security_level_raw
            ))
        })?;

    let mut os_version = None;
    let mut os_patch_level = None;
    // TEE-enforced values take precedence; within that list a later entry with the
    // same tag replaces an earlier one.
    if let Some(tee_list) = items.get(7) {
        scan_os_fields(tee_list, true, &mut os_version, &mut os_patch_level);
    }
    // The software list only fills in what the TEE list did not provide.
    if os_version.is_none() || os_patch_level.is_none() {
        if let Some(sw_list) = items.get(6) {
            scan_os_fields(sw_list, false, &mut os_version, &mut os_patch_level);
        }
    }

    Ok(AttestationExtension {
        attestation_version,
        attestation_security_level,
        keymaster_version,
        keymaster_security_level,
        attestation_challenge,
        unique_id,
        os_version,
        os_patch_level,
    })
}

/// Scans one authorization list for the OS version (705) and OS patch level (706)
/// entries. With `overwrite` unset, already populated outputs are left untouched.
fn scan_os_fields(
    list: &BerObject,
    overwrite: bool,
    os_version: &mut Option<u32>,
    os_patch_level: &mut Option<u32>,
) {
    let entries = match list.as_sequence() {
        Ok(entries) => entries,
        Err(_) => return,
    };
    for item in entries {
        let slot = match context_tag_number(item) {
            Some(705) => &mut *os_version,
            Some(706) => &mut *os_patch_level,
            _ => continue,
        };
        if !overwrite && slot.is_some() {
            continue;
        }
        if let Some(value) = context_tagged_u32(item) {
            *slot = Some(value);
        }
    }
}

/// Returns the tag number of a context-specific element, or `None` for anything else.
fn context_tag_number(item: &BerObject) -> Option<u32> {
    match &item.content {
        BerObjectContent::Tagged(class, tag, _)
            if *class == der_parser::ber::Class::ContextSpecific =>
        {
            Some(tag.0)
        }
        // NOTE(review): the generic BER parser is expected to hand back
        // context-specific elements as `Unknown` rather than `Tagged` — confirm
        // against the der-parser version in use.
        BerObjectContent::Unknown(any)
            if any.header.class() == der_parser::ber::Class::ContextSpecific =>
        {
            Some(any.header.tag().0)
        }
        _ => None,
    }
}

/// Reads the INTEGER wrapped by an explicitly tagged element as a `u32`.
/// Returns `None` when the payload is not an INTEGER or is outside the `u32` range.
fn context_tagged_u32(item: &BerObject) -> Option<u32> {
    let value = match &item.content {
        BerObjectContent::Tagged(..) => extract_tagged_integer(item).ok()?,
        BerObjectContent::Unknown(any) => {
            // The raw payload of an explicit tag is the encoding of the inner value.
            let (_, inner) = der_parser::ber::parse_ber(any.data).ok()?;
            get_integer(&inner, "tagged_integer").ok()?
        }
        _ => return None,
    };
    // Checked conversion: `as u32` would silently wrap negative or oversized values.
    <u32 as std::convert::TryFrom<i64>>::try_from(value).ok()
}
/// Reads an ASN.1 INTEGER (or ENUMERATED) as a signed 64-bit value.
///
/// INTEGER content is interpreted as big-endian two's complement; empty content
/// yields 0. `field` is only used in error messages.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the object is neither INTEGER nor
/// ENUMERATED, or when the value does not fit in an `i64`.
fn get_integer(obj: &BerObject, field: &str) -> Result<i64> {
    match &obj.content {
        BerObjectContent::Integer(bytes) => {
            // More than eight content bytes would be silently truncated by the
            // shift-and-or accumulation below.
            if bytes.len() > 8 {
                return Err(SignError::KeyAttestation(format!(
                    "{} does not fit in a 64-bit integer",
                    field
                )));
            }
            // Seed with all ones for negative values so the sign is extended.
            let seed: i64 = match bytes.first() {
                Some(first) if first & 0x80 != 0 => -1,
                _ => 0,
            };
            Ok(bytes
                .iter()
                .fold(seed, |acc, &b| (acc << 8) | i64::from(b)))
        }
        BerObjectContent::Enum(v) => {
            <i64 as std::convert::TryFrom<u64>>::try_from(*v).map_err(|_| {
                SignError::KeyAttestation(format!(
                    "{} does not fit in a 64-bit integer",
                    field
                ))
            })
        }
        _ => Err(SignError::KeyAttestation(format!(
            "{} is not an INTEGER",
            field
        ))),
    }
}
/// Reads an ASN.1 ENUMERATED (or non-negative INTEGER) as an unsigned 64-bit value.
///
/// Empty INTEGER content yields 0. `field` is only used in error messages.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the object is neither ENUMERATED nor
/// INTEGER, when an INTEGER is negative, or when the value does not fit in a `u64`.
fn get_enum_value(obj: &BerObject, field: &str) -> Result<u64> {
    match &obj.content {
        BerObjectContent::Enum(v) => Ok(*v),
        BerObjectContent::Integer(bytes) => {
            if bytes.is_empty() {
                return Ok(0);
            }
            if bytes[0] & 0x80 != 0 {
                return Err(SignError::KeyAttestation(format!(
                    "{} has negative value, expected non-negative ENUMERATED",
                    field
                )));
            }
            // Leading zero bytes carry no value; what remains must fit in eight
            // bytes, otherwise the accumulation below would silently drop high bits.
            let first_significant = bytes
                .iter()
                .position(|&b| b != 0)
                .unwrap_or(bytes.len());
            let significant = &bytes[first_significant..];
            if significant.len() > 8 {
                return Err(SignError::KeyAttestation(format!(
                    "{} does not fit in a 64-bit unsigned integer",
                    field
                )));
            }
            Ok(significant
                .iter()
                .fold(0u64, |acc, &b| (acc << 8) | u64::from(b)))
        }
        _ => Err(SignError::KeyAttestation(format!(
            "{} is not an ENUMERATED",
            field
        ))),
    }
}
/// Copies the content of an ASN.1 OCTET STRING into an owned vector.
///
/// `field` is only used in the error message.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when the object is not an OCTET STRING.
fn get_octet_string(obj: &BerObject, field: &str) -> Result<Vec<u8>> {
    if let BerObjectContent::OctetString(bytes) = &obj.content {
        return Ok(bytes.to_vec());
    }
    Err(SignError::KeyAttestation(format!(
        "{} is not an OCTET STRING",
        field
    )))
}
/// Reads the integer wrapped inside a `Tagged` BER object via `get_integer`.
///
/// # Errors
/// Returns `SignError::KeyAttestation` when `obj` is not a `Tagged` object, or when
/// `get_integer` rejects the wrapped value.
fn extract_tagged_integer(obj: &BerObject) -> Result<i64> {
    match &obj.content {
        BerObjectContent::Tagged(_, _, inner) => get_integer(inner.as_ref(), "tagged_integer"),
        _ => Err(SignError::KeyAttestation("Expected tagged value".into())),
    }
}
// Unit tests for the pure helpers of this module. Certificates are generated on the
// fly with rcgen; no test touches the network or a `KeyAttestationStore`.
#[cfg(test)]
mod tests {
use super::types::SecurityLevel;
use super::*;
use base64::{engine::general_purpose::STANDARD, Engine};
// Security levels order as Software < TrustedEnvironment < StrongBox.
#[test]
fn test_security_level_ordering() {
assert!(SecurityLevel::Software < SecurityLevel::TrustedEnvironment);
assert!(SecurityLevel::TrustedEnvironment < SecurityLevel::StrongBox);
assert!(SecurityLevel::Software < SecurityLevel::StrongBox);
}
// Numeric values 0, 1, 2 map to the three levels; anything else is rejected.
#[test]
fn test_security_level_from_u64() {
assert_eq!(SecurityLevel::from_u64(0), Some(SecurityLevel::Software));
assert_eq!(
SecurityLevel::from_u64(1),
Some(SecurityLevel::TrustedEnvironment)
);
assert_eq!(SecurityLevel::from_u64(2), Some(SecurityLevel::StrongBox));
assert_eq!(SecurityLevel::from_u64(3), None);
}
// Textual names accepted by `SecurityLevel::parse`.
#[test]
fn test_security_level_from_str() {
assert_eq!(
SecurityLevel::parse("TEE"),
Some(SecurityLevel::TrustedEnvironment)
);
assert_eq!(
SecurityLevel::parse("StrongBox"),
Some(SecurityLevel::StrongBox)
);
assert_eq!(
SecurityLevel::parse("Software"),
Some(SecurityLevel::Software)
);
assert_eq!(SecurityLevel::parse("invalid"), None);
}
// Mirrors the minimum-level comparison performed in `verify_attestation`.
#[test]
fn test_security_level_validation() {
let min = SecurityLevel::TrustedEnvironment;
assert!(SecurityLevel::Software < min);
assert!(SecurityLevel::TrustedEnvironment >= min);
assert!(SecurityLevel::StrongBox >= min);
}
// `parse_certificate_chain` only base64-decodes; the bytes need not be a full certificate.
#[test]
fn test_parse_certificate_chain_valid() {
let der_bytes = vec![0x30, 0x82, 0x01, 0x00];
let b64 = STANDARD.encode(&der_bytes);
let result = parse_certificate_chain(&[b64]);
assert!(result.is_ok());
assert_eq!(result.unwrap()[0], der_bytes);
}
// Invalid base64 input surfaces the decoding error message.
#[test]
fn test_parse_certificate_chain_invalid_base64() {
let result = parse_certificate_chain(&["not-valid-base64!!!".to_string()]);
assert!(result.is_err());
let err = format!("{}", result.unwrap_err());
assert!(err.contains("Failed to decode base64"));
}
// An empty input is not an error at this stage.
#[test]
fn test_parse_certificate_chain_empty() {
let result = parse_certificate_chain(&[]);
assert!(result.is_ok());
assert!(result.unwrap().is_empty());
}
// Output order follows input order.
#[test]
fn test_parse_certificate_chain_multiple() {
let cert1 = STANDARD.encode(&[1, 2, 3]);
let cert2 = STANDARD.encode(&[4, 5, 6]);
let result = parse_certificate_chain(&[cert1, cert2]).unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0], vec![1, 2, 3]);
assert_eq!(result[1], vec![4, 5, 6]);
}
// Helper: DER of a self-signed P-256 certificate with CN "Test" and the given serial.
fn generate_self_signed_cert(serial: &[u8]) -> Vec<u8> {
let mut params = rcgen::CertificateParams::default();
params.serial_number = Some(rcgen::SerialNumber::from_slice(serial));
params.distinguished_name = rcgen::DistinguishedName::new();
params
.distinguished_name
.push(rcgen::DnType::CommonName, "Test");
let key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
let cert = params.self_signed(&key).unwrap();
cert.der().to_vec()
}
// An empty entries object revokes nothing.
#[test]
fn test_crl_no_revocations() {
let der = generate_self_signed_cert(&[0x42]);
let (_, cert) = X509Certificate::from_der(&der).unwrap();
let entries = serde_json::json!({});
let result = check_serial_numbers_against_crl(&entries, &[&cert]);
assert!(result.is_ok());
}
// A serial listed as REVOKED is rejected and the status appears in the error.
#[test]
fn test_crl_revoked_certificate() {
let der = generate_self_signed_cert(&[0x42]);
let (_, cert) = X509Certificate::from_der(&der).unwrap();
let serial_hex = cert.serial.to_str_radix(16).to_lowercase();
let entries = serde_json::json!({ serial_hex: { "status": "REVOKED" } });
let result = check_serial_numbers_against_crl(&entries, &[&cert]);
assert!(result.is_err());
let err = format!("{}", result.unwrap_err());
assert!(err.contains("REVOKED"));
}
// SUSPENDED is treated like REVOKED.
#[test]
fn test_crl_suspended_certificate() {
let der = generate_self_signed_cert(&[0x43]);
let (_, cert) = X509Certificate::from_der(&der).unwrap();
let serial_hex = cert.serial.to_str_radix(16).to_lowercase();
let entries = serde_json::json!({ serial_hex: { "status": "SUSPENDED" } });
let result = check_serial_numbers_against_crl(&entries, &[&cert]);
assert!(result.is_err());
let err = format!("{}", result.unwrap_err());
assert!(err.contains("SUSPENDED"));
}
// A listed serial with any other status is accepted.
#[test]
fn test_crl_non_revoked_status_passes() {
let der = generate_self_signed_cert(&[0x44]);
let (_, cert) = X509Certificate::from_der(&der).unwrap();
let serial_hex = cert.serial.to_str_radix(16).to_lowercase();
let entries = serde_json::json!({ serial_hex: { "status": "VALID" } });
let result = check_serial_numbers_against_crl(&entries, &[&cert]);
assert!(result.is_ok());
}
// Pins the current behavior: entries that are not a JSON object are ignored.
#[test]
fn test_crl_not_an_object_passes() {
let der = generate_self_signed_cert(&[0x45]);
let (_, cert) = X509Certificate::from_der(&der).unwrap();
let entries = serde_json::json!("not an object");
let result = check_serial_numbers_against_crl(&entries, &[&cert]);
assert!(result.is_ok());
}
// Helper: builds root -> intermediate -> leaf and returns their DER encodings
// in the order (root, intermediate, leaf).
fn generate_chain() -> (Vec<u8>, Vec<u8>, Vec<u8>) {
let mut root_params = rcgen::CertificateParams::default();
root_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
root_params.serial_number = Some(rcgen::SerialNumber::from_slice(&[0x01]));
root_params.distinguished_name = rcgen::DistinguishedName::new();
root_params
.distinguished_name
.push(rcgen::DnType::CommonName, "Test Root CA");
let root_key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
let root_cert = root_params.self_signed(&root_key).unwrap();
let mut int_params = rcgen::CertificateParams::default();
int_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
int_params.serial_number = Some(rcgen::SerialNumber::from_slice(&[0x02]));
int_params.distinguished_name = rcgen::DistinguishedName::new();
int_params
.distinguished_name
.push(rcgen::DnType::CommonName, "Test Intermediate CA");
let int_key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
let int_cert = int_params
.signed_by(&int_key, &root_cert, &root_key)
.unwrap();
let mut leaf_params = rcgen::CertificateParams::default();
leaf_params.serial_number = Some(rcgen::SerialNumber::from_slice(&[0x03]));
leaf_params.distinguished_name = rcgen::DistinguishedName::new();
leaf_params
.distinguished_name
.push(rcgen::DnType::CommonName, "Test Leaf");
let leaf_key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
let leaf_cert = leaf_params
.signed_by(&leaf_key, &int_cert, &int_key)
.unwrap();
(
root_cert.der().to_vec(),
int_cert.der().to_vec(),
leaf_cert.der().to_vec(),
)
}
// Input already leaf-first: the result starts with a non-self-issued certificate
// and ends with the self-issued root.
#[test]
fn test_order_chain_already_ordered() {
let (root_der, int_der, leaf_der) = generate_chain();
let (_, root) = X509Certificate::from_der(&root_der).unwrap();
let (_, int) = X509Certificate::from_der(&int_der).unwrap();
let (_, leaf) = X509Certificate::from_der(&leaf_der).unwrap();
let certs = vec![leaf, int, root];
let ordered = order_chain(&certs).unwrap();
assert_eq!(ordered.len(), 3);
assert_ne!(ordered[0].subject(), ordered[0].issuer());
assert_eq!(ordered[2].subject(), ordered[2].issuer());
}
// Input root-first: ordering must still produce leaf first, root last.
#[test]
fn test_order_chain_shuffled() {
let (root_der, int_der, leaf_der) = generate_chain();
let (_, root) = X509Certificate::from_der(&root_der).unwrap();
let (_, int) = X509Certificate::from_der(&int_der).unwrap();
let (_, leaf) = X509Certificate::from_der(&leaf_der).unwrap();
let certs = vec![root, int, leaf];
let ordered = order_chain(&certs).unwrap();
assert_eq!(ordered.len(), 3);
assert_ne!(ordered[0].subject(), ordered[0].issuer());
assert_eq!(ordered[2].subject(), ordered[2].issuer());
}
// Minimal chain: a leaf issued directly by the root.
#[test]
fn test_order_chain_two_certs() {
let mut root_params = rcgen::CertificateParams::default();
root_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
root_params.serial_number = Some(rcgen::SerialNumber::from_slice(&[0x01]));
root_params.distinguished_name = rcgen::DistinguishedName::new();
root_params
.distinguished_name
.push(rcgen::DnType::CommonName, "Test Root CA");
let root_key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
let root_cert = root_params.self_signed(&root_key).unwrap();
let mut leaf_params = rcgen::CertificateParams::default();
leaf_params.serial_number = Some(rcgen::SerialNumber::from_slice(&[0x10]));
leaf_params.distinguished_name = rcgen::DistinguishedName::new();
leaf_params
.distinguished_name
.push(rcgen::DnType::CommonName, "Direct Leaf");
let leaf_key = rcgen::KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).unwrap();
let leaf_cert = leaf_params
.signed_by(&leaf_key, &root_cert, &root_key)
.unwrap();
let root_der = root_cert.der().to_vec();
let leaf_der = leaf_cert.der().to_vec();
let (_, r) = X509Certificate::from_der(&root_der).unwrap();
let (_, l) = X509Certificate::from_der(&leaf_der).unwrap();
let certs = vec![r, l];
let ordered = order_chain(&certs).unwrap();
assert_eq!(ordered.len(), 2);
assert_ne!(ordered[0].subject(), ordered[0].issuer());
assert_eq!(ordered[1].subject(), ordered[1].issuer());
}
// A lone self-signed certificate has no leaf and must be rejected.
#[test]
fn test_order_chain_single_cert_fails() {
let der = generate_self_signed_cert(&[0x01]);
let (_, cert) = X509Certificate::from_der(&der).unwrap();
let certs = vec![cert];
let result = order_chain(&certs);
assert!(result.is_err());
}
// A correctly issued three-certificate chain validates.
#[test]
fn test_validate_chain_signatures_valid() {
let (root_der, int_der, leaf_der) = generate_chain();
let (_, root) = X509Certificate::from_der(&root_der).unwrap();
let (_, int) = X509Certificate::from_der(&int_der).unwrap();
let (_, leaf) = X509Certificate::from_der(&leaf_der).unwrap();
let certs = vec![leaf, int, root];
let ordered = order_chain(&certs).unwrap();
let result = validate_chain_signatures(&ordered);
assert!(result.is_ok());
}
// Dropping the intermediate makes the leaf's issuer differ from the root's
// subject, so validation must fail.
#[test]
fn test_validate_chain_signatures_tampered() {
let (root_der, _int_der, leaf_der) = generate_chain();
let (_, root) = X509Certificate::from_der(&root_der).unwrap();
let (_, leaf) = X509Certificate::from_der(&leaf_der).unwrap();
let chain: Vec<&X509Certificate> = vec![&leaf, &root];
let result = validate_chain_signatures(&chain);
assert!(result.is_err());
}
// The byte vectors below are hand-written TLV encodings: tag 0x02 is INTEGER,
// 0x0a is ENUMERATED, 0x04 is OCTET STRING; the second byte is the content length.
#[test]
fn test_get_integer_positive() {
let der = vec![0x02, 0x01, 0x03]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_integer(&obj, "test").unwrap(), 3);
}
// Two content bytes 0x01 0x00 decode big-endian to 256.
#[test]
fn test_get_integer_positive_multibyte() {
let der = vec![0x02, 0x02, 0x01, 0x00]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_integer(&obj, "test").unwrap(), 256);
}
// A set high bit in the first content byte means a negative value (0x80 -> -128).
#[test]
fn test_get_integer_negative() {
let der = vec![0x02, 0x01, 0x80]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_integer(&obj, "test").unwrap(), -128);
}
// 0xFF is the one-byte two's complement encoding of -1.
#[test]
fn test_get_integer_negative_minus_one() {
let der = vec![0x02, 0x01, 0xFF]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_integer(&obj, "test").unwrap(), -1);
}
#[test]
fn test_get_integer_zero() {
let der = vec![0x02, 0x01, 0x00]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_integer(&obj, "test").unwrap(), 0);
}
// The leading 0x00 keeps 0x80 positive, so the value is 128 rather than -128.
#[test]
fn test_get_integer_large_positive() {
let der = vec![0x02, 0x02, 0x00, 0x80]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_integer(&obj, "test").unwrap(), 128);
}
#[test]
fn test_get_enum_value_zero() {
let der = vec![0x0a, 0x01, 0x00]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_enum_value(&obj, "test").unwrap(), 0);
}
#[test]
fn test_get_enum_value_one() {
let der = vec![0x0a, 0x01, 0x01]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_enum_value(&obj, "test").unwrap(), 1);
}
#[test]
fn test_get_enum_value_two() {
let der = vec![0x0a, 0x01, 0x02]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_enum_value(&obj, "test").unwrap(), 2);
}
// OCTET STRING content is returned verbatim.
#[test]
fn test_get_octet_string_valid() {
let der = vec![0x04, 0x05, b'h', b'e', b'l', b'l', b'o']; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_octet_string(&obj, "test").unwrap(), b"hello".to_vec());
}
// A zero-length OCTET STRING yields an empty vector.
#[test]
fn test_get_octet_string_empty() {
let der = vec![0x04, 0x00]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert_eq!(get_octet_string(&obj, "test").unwrap(), Vec::<u8>::new());
}
// An INTEGER is not accepted where an OCTET STRING is expected.
#[test]
fn test_get_octet_string_wrong_type() {
let der = vec![0x02, 0x01, 0x01]; let (_, obj) = der_parser::ber::parse_ber(&der).unwrap();
assert!(get_octet_string(&obj, "test").is_err());
}
}