import argparse
import enum
import os
from typing import TextIO, Optional, Union, Any, Callable, Iterable, List
from pathlib import Path
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, ec, ed25519, padding
from cryptography.hazmat.primitives.serialization import Encoding
from cryptography.hazmat.backends import default_backend
from cryptography.x509.oid import NameOID, ExtendedKeyUsageOID
import ipaddress
import datetime
import subprocess
# Shared 2048-bit RSA key pair (public exponent 65537), generated once per run.
# end_entity_cert() signs with ROOT_PRIVATE_KEY when no issuer_key is given,
# and the TLS-server / client-auth CA certificates use it as their subject key.
ROOT_PRIVATE_KEY: rsa.RSAPrivateKey = rsa.generate_private_key(
    65537, 2048, default_backend()
)
ROOT_PUBLIC_KEY: rsa.RSAPublicKey = ROOT_PRIVATE_KEY.public_key()
# Fixed one-minute validity window centred on 0x1FEDF00D seconds after the
# Unix epoch, so every generated certificate/CRL uses the same dates.
# Values are naive datetimes expressed in UTC. They are computed as
# epoch + timedelta (the documented equivalent) instead of with
# datetime.utcfromtimestamp(), which is deprecated since Python 3.12.
NOT_BEFORE: datetime.datetime = datetime.datetime(1970, 1, 1) + datetime.timedelta(
    seconds=0x1FEDF00D - 30
)
NOT_AFTER: datetime.datetime = datetime.datetime(1970, 1, 1) + datetime.timedelta(
    seconds=0x1FEDF00D + 30
)
# Private key types accepted by the certificate/CRL helpers in this file.
# Spelled as Union[A, B, C] (not Union[A | B | C]): the PEP 604 `|` form on
# classes only works at runtime on Python >= 3.10, and wrapping it in Union
# is redundant anyway.
ANY_PRIV_KEY = Union[
    ed25519.Ed25519PrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey
]
# Public-key counterparts of ANY_PRIV_KEY.
ANY_PUB_KEY = Union[
    ed25519.Ed25519PublicKey, ec.EllipticCurvePublicKey, rsa.RSAPublicKey
]
# A signing callable: (private_key, message) -> signature bytes.
SIGNER = Callable[[Any, bytes], bytes]
def trim_top(
    file_name: str,
    marker: str = "// DO NOT EDIT BELOW: generated by tests/generate.py",
) -> TextIO:
    """Truncate *file_name* just after its generated-code marker line.

    Everything up to and including the first line equal to *marker* is kept,
    everything after it is discarded, and the file is returned open for
    writing (positioned after the marker) so callers can append freshly
    generated code. The caller owns the returned handle and must close it,
    typically with ``with trim_top(...) as output:``.

    Args:
        file_name: path of the text file to rewrite.
        marker: the marker line, without its line terminator.

    Returns:
        The file, opened for writing in text mode.

    Raises:
        ValueError: if no line equals *marker*. The file is left unmodified
            in that case.
    """
    with open(file_name, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Locate the marker before opening for writing, so a missing marker never
    # truncates the file. The terminator is ignored so a marker on the final
    # line without a trailing newline still matches.
    for index, line in enumerate(lines):
        if line.rstrip("\r\n") == marker:
            break
    else:
        raise ValueError(f"{file_name}: marker line {marker!r} not found")

    output = open(file_name, "w", encoding="utf-8")
    output.writelines(lines[: index + 1])
    return output
def key_or_generate(key: Optional[ANY_PRIV_KEY] = None) -> ANY_PRIV_KEY:
    """Return *key* unchanged, or a freshly generated RSA key if it is None.

    The generated key is 2048-bit RSA with public exponent 65537.
    """
    if key is not None:
        return key
    return rsa.generate_private_key(
        public_exponent=65537, key_size=2048, backend=default_backend()
    )
def write_der(path: str, content: bytes, force: bool) -> None:
    """Write *content* to *path* as raw bytes.

    An existing file is kept as-is unless *force* is true, so repeated runs
    do not rewrite fixtures that are already on disk. The parent directory
    must already exist.
    """
    target = Path(path)
    if force or not target.exists():
        target.write_bytes(content)
def end_entity_cert(
    *,
    subject_name: x509.Name,
    issuer_name: x509.Name,
    issuer_key: Optional[ANY_PRIV_KEY] = None,
    subject_key: Optional[ANY_PRIV_KEY] = None,
    sans: Optional[Iterable[x509.GeneralName]] = None,
    ekus: Optional[Iterable[x509.ObjectIdentifier]] = None,
    serial: Optional[int] = None,
    cert_dps: Optional[x509.DistributionPoint] = None,
) -> x509.Certificate:
    """Build and sign a non-CA (end-entity) certificate.

    Args:
        subject_name: subject DN of the new certificate.
        issuer_name: issuer DN written into the certificate.
        issuer_key: key that signs the certificate; ROOT_PRIVATE_KEY if None.
        subject_key: key being certified; a fresh RSA key from
            key_or_generate() if None.
        sans: names for a non-critical SubjectAlternativeName extension;
            the extension is omitted when None or empty.
        ekus: OIDs for a non-critical ExtendedKeyUsage extension; omitted
            when None or empty.
        serial: serial number; a random one is used when None.
        cert_dps: if given, added as a non-critical CRLDistributionPoints
            extension containing this single distribution point.

    Returns:
        The signed certificate. It is valid from NOT_BEFORE to NOT_AFTER and
        carries a critical BasicConstraints(ca=False) extension.
    """
    subject_priv_key = key_or_generate(subject_key)
    signing_key = issuer_key if issuer_key is not None else ROOT_PRIVATE_KEY
    # Materialize the Iterable arguments so a generator is consumed exactly
    # once and an empty one is treated like an empty list (no empty extension).
    san_list = list(sans) if sans is not None else []
    eku_list = list(ekus) if ekus is not None else []

    builder: x509.CertificateBuilder = (
        x509.CertificateBuilder()
        .subject_name(subject_name)
        .issuer_name(issuer_name)
        .not_valid_before(NOT_BEFORE)
        .not_valid_after(NOT_AFTER)
        .serial_number(x509.random_serial_number() if serial is None else serial)
        .public_key(subject_priv_key.public_key())
    )
    if san_list:
        builder = builder.add_extension(
            x509.SubjectAlternativeName(san_list), critical=False
        )
    if eku_list:
        builder = builder.add_extension(
            x509.ExtendedKeyUsage(eku_list), critical=False
        )
    if cert_dps is not None:
        builder = builder.add_extension(
            x509.CRLDistributionPoints([cert_dps]), critical=False
        )
    builder = builder.add_extension(
        x509.BasicConstraints(ca=False, path_length=None),
        critical=True,
    )
    # cryptography requires algorithm=None for Ed25519 signing keys.
    algorithm = (
        None
        if isinstance(signing_key, ed25519.Ed25519PrivateKey)
        else hashes.SHA256()
    )
    return builder.sign(
        private_key=signing_key,
        algorithm=algorithm,
        backend=default_backend(),
    )
def subject_name_for_test(subject_cn: str, test_name: str) -> x509.Name:
    """Return the DN ``CN=<subject_cn>, O=<test_name>`` (in that order)."""
    attributes = [
        x509.NameAttribute(NameOID.COMMON_NAME, subject_cn),
        x509.NameAttribute(NameOID.ORGANIZATION_NAME, test_name),
    ]
    return x509.Name(attributes)
def issuer_name_for_test(test_name: str) -> x509.Name:
    """Return the issuer DN ``CN=issuer.example.com, O=<test_name>``.

    This is the same shape as a test subject name, with a fixed common name.
    """
    return subject_name_for_test("issuer.example.com", test_name)
def ca_cert(
    *,
    subject_name: x509.Name,
    subject_key: Optional[ANY_PRIV_KEY] = None,
    issuer_name: Optional[x509.Name] = None,
    issuer_key: Optional[ANY_PRIV_KEY] = None,
    permitted_subtrees: Optional[Iterable[x509.GeneralName]] = None,
    excluded_subtrees: Optional[Iterable[x509.GeneralName]] = None,
    key_usage: Optional[x509.KeyUsage] = None,
    cert_dps: Optional[x509.DistributionPoint] = None,
) -> x509.Certificate:
    """Build and sign a CA certificate (critical BasicConstraints ca=True).

    Args:
        subject_name: subject DN of the CA.
        subject_key: key being certified; a fresh RSA key from
            key_or_generate() if None.
        issuer_name: issuer DN; defaults to *subject_name* (self-issued).
        issuer_key: signing key; defaults to the subject key (self-signed).
        permitted_subtrees / excluded_subtrees: if either is given, a
            critical NameConstraints extension is added.
        key_usage: if given, added as a critical extension.
        cert_dps: if given, added as a non-critical CRLDistributionPoints
            extension containing this single distribution point.

    Returns:
        The signed certificate, valid from NOT_BEFORE to NOT_AFTER, with a
        random serial number.
    """
    subject_priv_key = key_or_generate(subject_key)
    # Explicit None checks: an empty x509.Name has len() == 0 and is falsy,
    # so a truthiness test would silently turn it into a self-issued cert.
    if issuer_name is None:
        issuer_name = subject_name
    signing_key = issuer_key if issuer_key is not None else subject_priv_key

    builder: x509.CertificateBuilder = (
        x509.CertificateBuilder()
        .subject_name(subject_name)
        .issuer_name(issuer_name)
        .not_valid_before(NOT_BEFORE)
        .not_valid_after(NOT_AFTER)
        .serial_number(x509.random_serial_number())
        .public_key(subject_priv_key.public_key())
        .add_extension(
            x509.BasicConstraints(ca=True, path_length=None),
            critical=True,
        )
    )
    if permitted_subtrees is not None or excluded_subtrees is not None:
        builder = builder.add_extension(
            x509.NameConstraints(permitted_subtrees, excluded_subtrees),
            critical=True,
        )
    if key_usage is not None:
        builder = builder.add_extension(key_usage, critical=True)
    if cert_dps is not None:
        builder = builder.add_extension(
            x509.CRLDistributionPoints([cert_dps]), critical=False
        )
    # cryptography requires algorithm=None for Ed25519 signing keys.
    algorithm = (
        None
        if isinstance(signing_key, ed25519.Ed25519PrivateKey)
        else hashes.SHA256()
    )
    return builder.sign(
        private_key=signing_key,
        algorithm=algorithm,
        backend=default_backend(),
    )
def generate_tls_server_cert_test(
    output: TextIO,
    test_name: str,
    expected_error: Optional[str] = None,
    subject_common_name: Optional[str] = None,
    extra_subject_names: Optional[List[x509.NameAttribute]] = None,
    valid_names: Optional[List[str]] = None,
    invalid_names: Optional[List[str]] = None,
    sans: Optional[Iterable[x509.GeneralName]] = None,
    permitted_subtrees: Optional[Iterable[x509.GeneralName]] = None,
    excluded_subtrees: Optional[Iterable[x509.GeneralName]] = None,
    force: bool = False,
) -> None:
    """Emit one TLS server-name test: two DER fixtures plus a Rust ``#[test]``.

    The end-entity certificate has subject ``[CN=subject_common_name,]
    O=test_name`` followed by *extra_subject_names*, plus the optional
    *sans*. It is issued under issuer_name_for_test(test_name) and signed
    with ROOT_PRIVATE_KEY. The CA certificate is built by ca_cert() for
    ROOT_PRIVATE_KEY under the same name and carries the optional name
    constraints.

    Both certificates are written under ``tls_server_certs/`` (existing files
    are only replaced when *force* is true). A test calling ``check_cert``
    with *valid_names* and *invalid_names* is then printed to *output*. It
    expects ``Ok(())`` when *expected_error* is None, and
    ``Err(webpki::Error::<expected_error>)`` otherwise.
    """
    issuer_name: x509.Name = issuer_name_for_test(test_name)

    cn_attrs: List[x509.NameAttribute] = (
        [x509.NameAttribute(NameOID.COMMON_NAME, subject_common_name)]
        if subject_common_name
        else []
    )
    extra_attrs: List[x509.NameAttribute] = (
        extra_subject_names if extra_subject_names is not None else []
    )
    ee_subject = x509.Name(
        cn_attrs
        + [x509.NameAttribute(NameOID.ORGANIZATION_NAME, test_name)]
        + extra_attrs
    )
    ee_certificate = end_entity_cert(
        subject_name=ee_subject,
        issuer_name=issuer_name,
        sans=sans,
    )

    output_dir = "tls_server_certs"
    ee_cert_path = os.path.join(output_dir, f"{test_name}.ee.der")
    ca_cert_path = os.path.join(output_dir, f"{test_name}.ca.der")
    os.makedirs(output_dir, exist_ok=True)
    write_der(ee_cert_path, ee_certificate.public_bytes(Encoding.DER), force)

    ca = ca_cert(
        subject_name=issuer_name,
        subject_key=ROOT_PRIVATE_KEY,
        permitted_subtrees=permitted_subtrees,
        excluded_subtrees=excluded_subtrees,
    )
    write_der(ca_cert_path, ca.public_bytes(Encoding.DER), force)

    if expected_error is None:
        expected = "Ok(())"
    else:
        expected = "Err(webpki::Error::" + expected_error + ")"

    # Render the name lists as Rust string literals: "a", "b", ...
    valid_names_str = ", ".join(f'"{n}"' for n in (valid_names or []))
    invalid_names_str = ", ".join(f'"{n}"' for n in (invalid_names or []))

    template = """
#[test]
fn %(test_name)s() {
let ee = include_bytes!("%(ee_cert_path)s");
let ca = include_bytes!("%(ca_cert_path)s");
assert_eq!(
check_cert(ee, ca, &[%(valid_names_str)s], &[%(invalid_names_str)s]),
%(expected)s
);
}"""
    substitutions = {
        "test_name": test_name,
        "ee_cert_path": ee_cert_path,
        "ca_cert_path": ca_cert_path,
        "valid_names_str": valid_names_str,
        "invalid_names_str": invalid_names_str,
        "expected": expected,
    }
    print(template % substitutions, file=output)
def tls_server_certs(force: bool) -> None:
    """Regenerate the generated part of ``tls_server_certs.rs`` and its fixtures.

    Each case below is emitted by generate_tls_server_cert_test(). *force*
    is forwarded to every case, so existing DER fixtures are overwritten
    when it is true. Previously it was accepted but never passed on.
    """
    with trim_top("tls_server_certs.rs") as output:

        def _case(test_name: str, **kwargs: Any) -> None:
            # All cases share the output file and the caller's force flag.
            generate_tls_server_cert_test(output, test_name, force=force, **kwargs)

        _case(
            "no_name_constraints",
            subject_common_name="subject.example.com",
            valid_names=["dns.example.com"],
            invalid_names=["subject.example.com"],
            sans=[x509.DNSName("dns.example.com")],
        )
        _case(
            "additional_dns_labels",
            subject_common_name="subject.example.com",
            valid_names=["host1.example.com", "host2.example.com"],
            invalid_names=["subject.example.com"],
            sans=[x509.DNSName("host1.example.com"), x509.DNSName("host2.example.com")],
            permitted_subtrees=[x509.DNSName(".example.com")],
        )
        _case(
            "disallow_dns_san",
            expected_error="NameConstraintViolation",
            sans=[x509.DNSName("disallowed.example.com")],
            excluded_subtrees=[x509.DNSName("disallowed.example.com")],
        )
        _case(
            "allow_subject_common_name",
            subject_common_name="allowed.example.com",
            invalid_names=["allowed.example.com"],
            permitted_subtrees=[x509.DNSName("allowed.example.com")],
        )
        _case(
            "allow_dns_san",
            valid_names=["allowed.example.com"],
            sans=[x509.DNSName("allowed.example.com")],
            permitted_subtrees=[x509.DNSName("allowed.example.com")],
        )
        _case(
            "allow_dns_san_and_subject_common_name",
            valid_names=["allowed-san.example.com"],
            invalid_names=["allowed-cn.example.com"],
            sans=[x509.DNSName("allowed-san.example.com")],
            subject_common_name="allowed-cn.example.com",
            permitted_subtrees=[
                x509.DNSName("allowed-san.example.com"),
                x509.DNSName("allowed-cn.example.com"),
            ],
        )
        _case(
            "disallow_dns_san_and_allow_subject_common_name",
            expected_error="NameConstraintViolation",
            sans=[
                x509.DNSName("allowed-san.example.com"),
                x509.DNSName("disallowed-san.example.com"),
            ],
            subject_common_name="allowed-cn.example.com",
            permitted_subtrees=[
                x509.DNSName("allowed-san.example.com"),
                x509.DNSName("allowed-cn.example.com"),
            ],
            excluded_subtrees=[x509.DNSName("disallowed-san.example.com")],
        )
        _case(
            "we_incorrectly_ignore_name_constraints_on_name_in_subject",
            extra_subject_names=[
                x509.NameAttribute(NameOID.EMAIL_ADDRESS, "joe@notexample.com")
            ],
            permitted_subtrees=[x509.RFC822Name("example.com")],
        )
        _case(
            "reject_constraints_on_unimplemented_names",
            expected_error="NameConstraintViolation",
            sans=[x509.RFC822Name("joe@example.com")],
            permitted_subtrees=[x509.RFC822Name("example.com")],
        )
        _case(
            "we_ignore_constraints_on_names_that_do_not_appear_in_cert",
            sans=[x509.DNSName("notexample.com")],
            valid_names=["notexample.com"],
            invalid_names=["example.com"],
            permitted_subtrees=[x509.RFC822Name("example.com")],
        )
        _case(
            "wildcard_san_accepted_if_in_subtree",
            sans=[x509.DNSName("*.example.com")],
            valid_names=["bob.example.com", "jane.example.com"],
            invalid_names=["example.com", "uh.oh.example.com"],
            permitted_subtrees=[x509.DNSName("example.com")],
        )
        _case(
            "wildcard_san_rejected_if_in_excluded_subtree",
            expected_error="NameConstraintViolation",
            sans=[x509.DNSName("*.example.com")],
            excluded_subtrees=[x509.DNSName("example.com")],
        )
        _case(
            "ip4_address_san_rejected_if_in_excluded_subtree",
            expected_error="NameConstraintViolation",
            sans=[x509.IPAddress(ipaddress.ip_address("12.34.56.78"))],
            excluded_subtrees=[x509.IPAddress(ipaddress.ip_network("12.34.56.0/24"))],
        )
        _case(
            "ip4_address_san_allowed_if_outside_excluded_subtree",
            valid_names=["12.34.56.78"],
            sans=[x509.IPAddress(ipaddress.ip_address("12.34.56.78"))],
            excluded_subtrees=[x509.IPAddress(ipaddress.ip_network("12.34.56.252/30"))],
        )
        # ipaddress cannot parse a non-contiguous mask, so build a /24 network
        # and overwrite its netmask with the sparse value 255.255.255.1.
        sparse_net_addr = ipaddress.ip_network("12.34.56.78/24", strict=False)
        sparse_net_addr.netmask = ipaddress.ip_address("255.255.255.1")
        _case(
            "ip4_address_san_rejected_if_excluded_is_sparse_cidr_mask",
            expected_error="InvalidNetworkMaskConstraint",
            sans=[
                x509.IPAddress(ipaddress.ip_address("12.34.56.79")),
            ],
            excluded_subtrees=[x509.IPAddress(sparse_net_addr)],
        )
        _case(
            "ip4_address_san_allowed",
            valid_names=["12.34.56.78"],
            invalid_names=[
                "12.34.56.77",
                "12.34.56.79",
                "0000:0000:0000:0000:0000:ffff:0c22:384e",
            ],
            sans=[x509.IPAddress(ipaddress.ip_address("12.34.56.78"))],
            permitted_subtrees=[x509.IPAddress(ipaddress.ip_network("12.34.56.0/24"))],
        )
        _case(
            "ip6_address_san_rejected_if_in_excluded_subtree",
            expected_error="NameConstraintViolation",
            sans=[x509.IPAddress(ipaddress.ip_address("2001:db8::1"))],
            excluded_subtrees=[x509.IPAddress(ipaddress.ip_network("2001:db8::/48"))],
        )
        _case(
            "ip6_address_san_allowed_if_outside_excluded_subtree",
            valid_names=["2001:0db9:0000:0000:0000:0000:0000:0001"],
            sans=[x509.IPAddress(ipaddress.ip_address("2001:db9::1"))],
            excluded_subtrees=[x509.IPAddress(ipaddress.ip_network("2001:db8::/48"))],
        )
        _case(
            "ip6_address_san_allowed",
            valid_names=["2001:0db9:0000:0000:0000:0000:0000:0001"],
            invalid_names=["12.34.56.78"],
            sans=[x509.IPAddress(ipaddress.ip_address("2001:db9::1"))],
            permitted_subtrees=[x509.IPAddress(ipaddress.ip_network("2001:db9::/48"))],
        )
        _case(
            "ip46_mixed_address_san_allowed",
            valid_names=["12.34.56.78", "2001:0db9:0000:0000:0000:0000:0000:0001"],
            invalid_names=[
                "12.34.56.77",
                "12.34.56.79",
                "0000:0000:0000:0000:0000:ffff:0c22:384e",
            ],
            sans=[
                x509.IPAddress(ipaddress.ip_address("12.34.56.78")),
                x509.IPAddress(ipaddress.ip_address("2001:db9::1")),
            ],
            permitted_subtrees=[
                x509.IPAddress(ipaddress.ip_network("12.34.56.0/24")),
                x509.IPAddress(ipaddress.ip_network("2001:db9::/48")),
            ],
        )
        _case(
            "permit_directory_name_not_implemented",
            expected_error="NameConstraintViolation",
            permitted_subtrees=[
                x509.DirectoryName(
                    x509.Name([x509.NameAttribute(NameOID.COUNTRY_NAME, "CN")])
                )
            ],
        )
        _case(
            "exclude_directory_name_not_implemented",
            expected_error="NameConstraintViolation",
            excluded_subtrees=[
                x509.DirectoryName(
                    x509.Name([x509.NameAttribute(NameOID.COUNTRY_NAME, "CN")])
                )
            ],
        )
        _case(
            "invalid_dns_name_matching",
            valid_names=["dns.example.com"],
            subject_common_name="subject.example.com",
            sans=[
                x509.DNSName("{invalid}.example.com"),
                x509.DNSName("dns.example.com"),
            ],
        )
def signatures(force: bool) -> None:
    """Regenerate the generated part of ``signatures.rs`` and ``signatures/``.

    For every key type an end-entity certificate is written. For every
    webpki algorithm listed for that key type, a good signature over
    ``message.bin`` and a bad one (over the message with ``?`` appended) are
    written, each with a Rust test calling ``check_sig``. One extra test per
    key type asserts that all other covered algorithms are rejected. Finally,
    a 2048-bit RSA key signed with RSA_PKCS1_3072_8192_SHA384 is tested as
    rejected.

    Args:
        force: overwrite existing DER/signature fixtures when true.
    """
    rsa_pub_exponent: int = 0x10001
    backend: Any = default_backend()
    all_key_types: dict[str, ANY_PRIV_KEY] = {
        "ed25519": ed25519.Ed25519PrivateKey.generate(),
        "ecdsa_p256": ec.generate_private_key(ec.SECP256R1(), backend),
        "ecdsa_p384": ec.generate_private_key(ec.SECP384R1(), backend),
        "ecdsa_p521": ec.generate_private_key(ec.SECP521R1(), backend),
        "rsa_1024_not_supported": rsa.generate_private_key(
            rsa_pub_exponent, 1024, backend
        ),
        "rsa_2048": rsa.generate_private_key(rsa_pub_exponent, 2048, backend),
        "rsa_3072": rsa.generate_private_key(rsa_pub_exponent, 3072, backend),
        "rsa_4096": rsa.generate_private_key(rsa_pub_exponent, 4096, backend),
    }

    # Rust cfg predicates for algorithms that are only compiled in some
    # configurations; all other signature tests are gated on "alloc".
    feature_gates = {
        "ECDSA_P521_SHA512": 'all(not(feature = "ring"), feature = "aws_lc_rs")',
    }

    rsa_types: list[str] = [
        "RSA_PKCS1_2048_8192_SHA256",
        "RSA_PKCS1_2048_8192_SHA384",
        "RSA_PKCS1_2048_8192_SHA512",
        "RSA_PSS_2048_8192_SHA256_LEGACY_KEY",
        "RSA_PSS_2048_8192_SHA384_LEGACY_KEY",
        "RSA_PSS_2048_8192_SHA512_LEGACY_KEY",
    ]

    # Key type -> webpki algorithms expected to verify signatures by that key.
    webpki_algs: dict[str, Iterable[str]] = {
        "ed25519": ["ED25519"],
        "ecdsa_p256": ["ECDSA_P256_SHA384", "ECDSA_P256_SHA256"],
        "ecdsa_p384": ["ECDSA_P384_SHA384", "ECDSA_P384_SHA256"],
        "ecdsa_p521": ["ECDSA_P521_SHA512"],
        "rsa_2048": rsa_types,
        "rsa_3072": rsa_types + ["RSA_PKCS1_3072_8192_SHA384"],
        "rsa_4096": rsa_types + ["RSA_PKCS1_3072_8192_SHA384"],
    }

    pss_sha256: padding.PSS = padding.PSS(
        mgf=padding.MGF1(hashes.SHA256()), salt_length=32
    )
    pss_sha384: padding.PSS = padding.PSS(
        mgf=padding.MGF1(hashes.SHA384()), salt_length=48
    )
    pss_sha512: padding.PSS = padding.PSS(
        mgf=padding.MGF1(hashes.SHA512()), salt_length=64
    )

    # webpki algorithm name -> how to produce a matching signature.
    how_to_sign: dict[str, SIGNER] = {
        "ED25519": lambda key, message: key.sign(message),
        "ECDSA_P256_SHA256": lambda key, message: key.sign(
            message, ec.ECDSA(hashes.SHA256())
        ),
        "ECDSA_P256_SHA384": lambda key, message: key.sign(
            message, ec.ECDSA(hashes.SHA384())
        ),
        "ECDSA_P384_SHA256": lambda key, message: key.sign(
            message, ec.ECDSA(hashes.SHA256())
        ),
        "ECDSA_P384_SHA384": lambda key, message: key.sign(
            message, ec.ECDSA(hashes.SHA384())
        ),
        "ECDSA_P521_SHA512": lambda key, message: key.sign(
            message, ec.ECDSA(hashes.SHA512())
        ),
        "RSA_PKCS1_2048_8192_SHA256": lambda key, message: key.sign(
            message, padding.PKCS1v15(), hashes.SHA256()
        ),
        "RSA_PKCS1_2048_8192_SHA384": lambda key, message: key.sign(
            message, padding.PKCS1v15(), hashes.SHA384()
        ),
        "RSA_PKCS1_2048_8192_SHA512": lambda key, message: key.sign(
            message, padding.PKCS1v15(), hashes.SHA512()
        ),
        "RSA_PKCS1_3072_8192_SHA384": lambda key, message: key.sign(
            message, padding.PKCS1v15(), hashes.SHA384()
        ),
        "RSA_PSS_2048_8192_SHA256_LEGACY_KEY": lambda key, message: key.sign(
            message, pss_sha256, hashes.SHA256()
        ),
        "RSA_PSS_2048_8192_SHA384_LEGACY_KEY": lambda key, message: key.sign(
            message, pss_sha384, hashes.SHA384()
        ),
        "RSA_PSS_2048_8192_SHA512_LEGACY_KEY": lambda key, message: key.sign(
            message, pss_sha512, hashes.SHA512()
        ),
    }

    output_dir: str = "signatures"
    if not os.path.isdir(output_dir):
        os.mkdir(output_dir)

    message = b"hello world!"
    message_path: str = os.path.join(output_dir, "message.bin")
    write_der(message_path, message, force)

    def _cert_path(cert_type: str) -> str:
        # Path of the end-entity certificate for a given key type.
        return os.path.join(output_dir, f"{cert_type}.ee.der")

    # One end-entity certificate per key type, each signed by the default
    # issuer key (end_entity_cert() is called without issuer_key).
    for name, private_key in all_key_types.items():
        ee_subject = x509.Name(
            [x509.NameAttribute(NameOID.ORGANIZATION_NAME, name + " test")]
        )
        issuer_subject = x509.Name(
            [x509.NameAttribute(NameOID.ORGANIZATION_NAME, name + " issuer")]
        )
        certificate: x509.Certificate = end_entity_cert(
            subject_name=ee_subject,
            subject_key=private_key,
            issuer_name=issuer_subject,
        )
        write_der(_cert_path(name), certificate.public_bytes(Encoding.DER), force)

    def _test(
        test_name: str, cert_type: str, algorithm: str, signature: bytes, expected: str
    ) -> None:
        # Write *signature* to disk and print a check_sig test asserting *expected*.
        cert_path: str = _cert_path(cert_type)
        lower_test_name: str = test_name.lower()
        sig_path: str = os.path.join(output_dir, f"{lower_test_name}.sig.bin")
        write_der(sig_path, signature, force)
        feature_gate = feature_gates.get(algorithm, 'feature = "alloc"')
        # Explicit mapping instead of locals(): message_path comes from the
        # enclosing scope and would otherwise need a `nonlocal` to be visible.
        substitutions = {
            "feature_gate": feature_gate,
            "lower_test_name": lower_test_name,
            "cert_path": cert_path,
            "message_path": message_path,
            "sig_path": sig_path,
            "algorithm": algorithm,
            "expected": expected,
        }
        print(
            """
#[test]
#[cfg(%(feature_gate)s)]
fn %(lower_test_name)s() {
let ee = include_bytes!("%(cert_path)s");
let message = include_bytes!("%(message_path)s");
let signature = include_bytes!("%(sig_path)s");
assert_eq!(
check_sig(ee, %(algorithm)s, message, signature),
%(expected)s
);
}"""
            % substitutions,
            file=output,
        )

    def good_signature(
        test_name: str, cert_type: str, algorithm: str, signer: SIGNER
    ) -> None:
        # Valid signature over the message; the test expects Ok(()).
        signature: bytes = signer(all_key_types[cert_type], message)
        _test(test_name, cert_type, algorithm, signature, expected="Ok(())")

    def good_signature_but_rejected(
        test_name: str, cert_type: str, algorithm: str, signer: SIGNER
    ) -> None:
        # Valid signature that webpki must still reject for this key/algorithm.
        signature: bytes = signer(all_key_types[cert_type], message)
        _test(
            test_name,
            cert_type,
            algorithm,
            signature,
            expected="Err(webpki::Error::InvalidSignatureForPublicKey)",
        )

    def bad_signature(
        test_name: str, cert_type: str, algorithm: str, signer: SIGNER
    ) -> None:
        # Signature over a different message, so verification must fail.
        signature: bytes = signer(all_key_types[cert_type], message + b"?")
        _test(
            test_name,
            cert_type,
            algorithm,
            signature,
            expected="Err(webpki::Error::InvalidSignatureForPublicKey)",
        )

    def bad_algorithms_for_key(
        test_name: str, cert_type: str, unusable_algs: set[str]
    ) -> None:
        # Print a test asserting every algorithm in *unusable_algs* is
        # rejected for this key with UnsupportedSignatureAlgorithmForPublicKey.
        cert_path: str = _cert_path(cert_type)
        test_name_lower: str = test_name.lower()
        unusable_algs_str: str = ", ".join(sorted(unusable_algs))
        substitutions = {
            "test_name_lower": test_name_lower,
            "cert_path": cert_path,
            "unusable_algs_str": unusable_algs_str,
        }
        print(
            """
#[test]
#[cfg(feature = "alloc")]
fn %(test_name_lower)s() {
let ee = include_bytes!("%(cert_path)s");
for algorithm in &[ %(unusable_algs_str)s ] {
assert_eq!(
check_sig(ee, *algorithm, b"", b""),
Err(webpki::Error::UnsupportedSignatureAlgorithmForPublicKey)
);
}
}"""
            % substitutions,
            file=output,
        )

    with trim_top("signatures.rs") as output:
        # Every webpki algorithm covered by at least one key type.
        all_webpki_algs: set[str] = {
            alg for algs in webpki_algs.values() for alg in algs
        }

        for key_type, algs in webpki_algs.items():
            for alg in algs:
                signer: SIGNER = how_to_sign[alg]
                good_signature(
                    key_type + "_key_and_" + alg + "_good_signature",
                    cert_type=key_type,
                    algorithm=alg,
                    signer=signer,
                )
                bad_signature(
                    key_type + "_key_and_" + alg + "_detects_bad_signature",
                    cert_type=key_type,
                    algorithm=alg,
                    signer=signer,
                )

            unusable_algs = set(all_webpki_algs).difference(algs)
            # rsa_2048 + RSA_PKCS1_3072_8192_SHA384 gets its own test below.
            if key_type == "rsa_2048":
                unusable_algs.remove("RSA_PKCS1_3072_8192_SHA384")
            # Prefix feature-gated algorithms with their cfg attribute so the
            # generated array only names them when they are compiled in.
            unusable_algs = {
                (
                    "#[cfg(%s)] %s" % (feature_gates[alg], alg)
                    if alg in feature_gates
                    else alg
                )
                for alg in unusable_algs
            }
            bad_algorithms_for_key(
                key_type + "_key_rejected_by_other_algorithms",
                cert_type=key_type,
                unusable_algs=unusable_algs,
            )

        # Look the signer up explicitly rather than reusing whatever the loop
        # variable `signer` happened to hold after the last iteration.
        good_signature_but_rejected(
            "rsa_2048_key_rejected_by_RSA_PKCS1_3072_8192_SHA384",
            cert_type="rsa_2048",
            algorithm="RSA_PKCS1_3072_8192_SHA384",
            signer=how_to_sign["RSA_PKCS1_3072_8192_SHA384"],
        )
def generate_client_auth_test(
    output: TextIO,
    test_name: str,
    ekus: Optional[Iterable[x509.ObjectIdentifier]],
    expected_error: Optional[str] = None,
    force: bool = False,
) -> None:
    """Emit one client-auth EKU test: two DER fixtures plus a Rust ``#[test]``.

    The end-entity certificate has subject ``O=test_name`` and the optional
    *ekus*. It is issued under issuer_name_for_test(test_name) and signed with
    ROOT_PRIVATE_KEY. The CA certificate is built by ca_cert() for
    ROOT_PRIVATE_KEY under that issuer name.

    Both certificates are written under ``client_auth/`` (existing files are
    only replaced when *force* is true), and a ``check_cert(ee, ca)`` test is
    printed to *output*. The test expects ``Ok(())`` when *expected_error* is
    None, otherwise ``Err(webpki::Error::<expected_error>)``.
    """
    issuer_name: x509.Name = issuer_name_for_test(test_name)
    ee_certificate = end_entity_cert(
        subject_name=x509.Name(
            [x509.NameAttribute(NameOID.ORGANIZATION_NAME, test_name)]
        ),
        ekus=ekus,
        issuer_name=issuer_name,
    )

    output_dir = "client_auth"
    os.makedirs(output_dir, exist_ok=True)
    ee_cert_path = os.path.join(output_dir, f"{test_name}.ee.der")
    write_der(ee_cert_path, ee_certificate.public_bytes(Encoding.DER), force)

    ca = ca_cert(subject_name=issuer_name, subject_key=ROOT_PRIVATE_KEY)
    ca_cert_path = os.path.join(output_dir, f"{test_name}.ca.der")
    write_der(ca_cert_path, ca.public_bytes(Encoding.DER), force)

    if expected_error is None:
        expected = "Ok(())"
    else:
        expected = "Err(webpki::Error::" + expected_error + ")"

    template = """
#[test]
fn %(test_name)s() {
let ee = include_bytes!("%(ee_cert_path)s");
let ca = include_bytes!("%(ca_cert_path)s");
assert_eq!(
check_cert(ee, ca),
%(expected)s
);
}"""
    substitutions = {
        "test_name": test_name,
        "ee_cert_path": ee_cert_path,
        "ca_cert_path": ca_cert_path,
        "expected": expected,
    }
    print(template % substitutions, file=output)
def client_auth(force: bool) -> None:
    """Regenerate the generated part of ``client_auth.rs`` and its fixtures.

    Each case is emitted by generate_client_auth_test(). *force* is
    forwarded to every case so existing DER fixtures are overwritten when it
    is true. Previously it was accepted but never passed on.
    """
    with trim_top("client_auth.rs") as output:
        generate_client_auth_test(
            output,
            "cert_with_no_eku_accepted_for_client_auth",
            ekus=None,
            force=force,
        )
        generate_client_auth_test(
            output,
            "cert_with_clientauth_eku_accepted_for_client_auth",
            ekus=[ExtendedKeyUsageOID.CLIENT_AUTH],
            force=force,
        )
        generate_client_auth_test(
            output,
            "cert_with_both_ekus_accepted_for_client_auth",
            ekus=[ExtendedKeyUsageOID.CLIENT_AUTH, ExtendedKeyUsageOID.SERVER_AUTH],
            force=force,
        )
        generate_client_auth_test(
            output,
            "cert_with_serverauth_eku_rejected_for_client_auth",
            ekus=[ExtendedKeyUsageOID.SERVER_AUTH],
            expected_error="RequiredEkuNotFound",
            force=force,
        )
def client_auth_revocation(force: bool) -> None:
output_dir: str = "client_auth_revocation"
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
crl_sign_ku = x509.KeyUsage(
digital_signature=True,
key_cert_sign=True,
crl_sign=True, content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
)
no_crl_sign_ku = x509.KeyUsage(
digital_signature=True,
key_cert_sign=True,
crl_sign=False, content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
)
valid_cert_crl_dp = x509.DistributionPoint(
full_name=[
x509.DNSName("example.com"),
x509.UniformResourceIdentifier("http://example.com/another.crl"),
x509.UniformResourceIdentifier("http://example.com/valid.crl"),
],
crl_issuer=None,
relative_name=None,
reasons=None,
)
valid_crl_idp = x509.IssuingDistributionPoint(
full_name=[
x509.UniformResourceIdentifier("http://example.com/yet.another.crl"),
x509.UniformResourceIdentifier("http://example.com/valid.crl"),
],
indirect_crl=False,
only_contains_user_certs=False,
only_contains_ca_certs=False,
only_contains_attribute_certs=False,
only_some_reasons=None,
relative_name=None,
)
class ChainDepth(enum.Enum):
END_ENTITY = enum.auto()
CHAIN = enum.auto()
class StatusRequirement(enum.Enum):
ALLOW_UNKNOWN = enum.auto()
FORBID_UNKNOWN = enum.auto()
class ExpirationPolicy(enum.Enum):
ENFORCE = enum.auto()
IGNORE = enum.auto()
def _chain(
*,
chain_name: str,
key_usage: Optional[x509.KeyUsage],
cert_dps: Optional[x509.DistributionPoint],
) -> list[tuple[x509.Certificate, str, ANY_PRIV_KEY]]:
ee_subj: x509.Name = subject_name_for_test("test.example.com", chain_name)
int_a_subj: x509.Name = issuer_name_for_test(f"int.a.{chain_name}")
int_b_subj: x509.Name = issuer_name_for_test(f"int.b.{chain_name}")
ca_subj: x509.Name = issuer_name_for_test(f"ca.{chain_name}")
ee_key: ec.EllipticCurvePrivateKey = ec.generate_private_key(
ec.SECP256R1(), default_backend()
)
int_a_key: ec.EllipticCurvePrivateKey = ec.generate_private_key(
ec.SECP256R1(), default_backend()
)
int_b_key: ec.EllipticCurvePrivateKey = ec.generate_private_key(
ec.SECP256R1(), default_backend()
)
root_key: ec.EllipticCurvePrivateKey = ec.generate_private_key(
ec.SECP256R1(), default_backend()
)
ee_cert: x509.Certificate = end_entity_cert(
subject_name=ee_subj,
issuer_name=int_a_subj,
issuer_key=int_a_key,
cert_dps=cert_dps,
)
ee_cert_path: str = os.path.join(output_dir, f"{chain_name}.ee.der")
write_der(ee_cert_path, ee_cert.public_bytes(Encoding.DER), force)
ee_cert_topbit: x509.Certificate = end_entity_cert(
subject_name=ee_subj,
issuer_name=int_a_subj,
issuer_key=int_a_key,
serial=0x80DEADBEEFF00D,
cert_dps=cert_dps,
)
ee_cert_topbit_path: str = os.path.join(
output_dir, f"{chain_name}.topbit.ee.der"
)
write_der(ee_cert_topbit_path, ee_cert_topbit.public_bytes(Encoding.DER), force)
int_a_cert: x509.Certificate = ca_cert(
subject_name=int_a_subj,
subject_key=int_a_key,
issuer_name=int_b_subj,
issuer_key=int_b_key,
key_usage=key_usage,
cert_dps=cert_dps,
)
int_a_cert_path: str = os.path.join(output_dir, f"{chain_name}.int.a.ca.der")
write_der(int_a_cert_path, int_a_cert.public_bytes(Encoding.DER), force)
int_b_cert: x509.Certificate = ca_cert(
subject_name=int_b_subj,
subject_key=int_b_key,
issuer_name=ca_subj,
issuer_key=root_key,
key_usage=key_usage,
cert_dps=cert_dps,
)
int_b_cert_path: str = os.path.join(output_dir, f"{chain_name}.int.b.ca.der")
write_der(int_b_cert_path, int_b_cert.public_bytes(Encoding.DER), force)
root_cert: x509.Certificate = ca_cert(
subject_name=ca_subj,
subject_key=root_key,
key_usage=key_usage,
cert_dps=cert_dps,
)
root_cert_path: str = os.path.join(output_dir, f"{chain_name}.root.ca.der")
write_der(root_cert_path, root_cert.public_bytes(Encoding.DER), force)
return [
(ee_cert, ee_cert_path, ee_key),
(int_a_cert, int_a_cert_path, int_a_key),
(int_b_cert, int_b_cert_path, int_b_key),
(root_cert, root_cert_path, root_key),
(ee_cert_topbit, ee_cert_topbit_path, ee_key),
]
def _crl(
*,
serials: Iterable[int],
issuer_name: x509.Name,
issuer_key: Optional[ANY_PRIV_KEY],
issuing_dp: Optional[x509.IssuingDistributionPoint] = None,
not_after: Optional[datetime.datetime] = None,
) -> x509.CertificateRevocationList:
issuer_priv_key: ANY_PRIV_KEY = key_or_generate(issuer_key)
crl_builder: x509.CertificateRevocationListBuilder = (
x509.CertificateRevocationListBuilder()
)
crl_builder = crl_builder.issuer_name(issuer_name)
crl_builder = crl_builder.last_update(NOT_BEFORE)
if not_after is None:
not_after = NOT_AFTER
crl_builder = crl_builder.next_update(not_after)
for serial in serials:
revoked_cert_builder: x509.RevokedCertificateBuilder = (
x509.RevokedCertificateBuilder()
)
revoked_cert_builder = revoked_cert_builder.serial_number(serial)
revoked_cert_builder = revoked_cert_builder.revocation_date(NOT_BEFORE)
revoked_cert_builder = revoked_cert_builder.add_extension(
x509.CRLReason(x509.ReasonFlags.key_compromise), critical=False
)
crl_builder = crl_builder.add_revoked_certificate(
revoked_cert_builder.build()
)
if issuing_dp is not None:
crl_builder = crl_builder.add_extension(issuing_dp, critical=True)
crl_builder = crl_builder.add_extension(
x509.CRLNumber(x509.random_serial_number()), critical=False
)
return crl_builder.sign(
private_key=issuer_priv_key,
algorithm=hashes.SHA256(),
)
def _revocation_test(
*,
test_name: str,
chain: list[tuple[x509.Certificate, str, ANY_PRIV_KEY]],
crl_paths: list[str],
depth: ChainDepth,
policy: StatusRequirement,
expiration: ExpirationPolicy,
expected_error: Optional[str],
ee_topbit_serial: bool = False,
) -> None:
if len(chain) != 5:
raise RuntimeError("invalid chain length")
ee_cert, ee_cert_path, _ = chain[4] if ee_topbit_serial else chain[0]
int_a_cert, int_a_cert_path, _ = chain[1]
int_b_cert, int_b_cert_path, _ = chain[2]
root_cert, root_cert_path, _ = chain[3]
int_a_str = f'include_bytes!("{int_a_cert_path}").as_slice()'
int_b_str = f'include_bytes!("{int_b_cert_path}").as_slice()'
intermediates_str: str = f"&[{int_a_str}, {int_b_str}]"
def _write_revocation_test(*, owned: bool) -> None:
nonlocal crl_paths, expected_error, intermediates_str, test_name, ee_cert_path, root_cert_path
test_name = test_name if not owned else test_name + "_owned"
crl_includes: str = ""
if not owned:
crl_includes = "\n".join(
[
f"""
&webpki::CertRevocationList::Borrowed(
webpki::BorrowedCertRevocationList::from_der(include_bytes!("{path}").as_slice())
.unwrap()
),
"""
for path in crl_paths
]
)
else:
crl_includes = "\n".join(
[
f"""
&webpki::CertRevocationList::Owned(
webpki::OwnedCertRevocationList::from_der(include_bytes!("{path}").as_slice())
.unwrap()
),
"""
for path in crl_paths
]
)
if len(crl_paths) == 0:
revocation_setup = "let revocation = None;"
else:
revocation_setup = f"""
let crls = &[{crl_includes}];
let builder = RevocationOptionsBuilder::new(crls).unwrap();
"""
if depth == ChainDepth.END_ENTITY:
revocation_setup += """
let builder = builder.with_depth(RevocationCheckDepth::EndEntity);
"""
if policy == StatusRequirement.ALLOW_UNKNOWN:
revocation_setup += """
let builder = builder.with_status_policy(UnknownStatusPolicy::Allow);
"""
if expiration == ExpirationPolicy.ENFORCE:
revocation_setup += """
let builder = builder.with_expiration_policy(webpki::ExpirationPolicy::Enforce);
"""
revocation_setup += "let revocation = Some(builder.build());"
expected: str = (
f"Err(webpki::Error::{expected_error})" if expected_error else "Ok(())"
)
feature_gate = '#[cfg(feature = "alloc")]' if owned else ""
print(
"""
%(feature_gate)s
#[test]
fn %(test_name)s() {
let ee = include_bytes!("%(ee_cert_path)s");
let intermediates = %(intermediates_str)s;
let ca = include_bytes!("%(root_cert_path)s");
%(revocation_setup)s
assert_eq!(check_cert(ee, intermediates, ca, revocation), %(expected)s);
}
"""
% locals(),
file=output,
)
_write_revocation_test(owned=False)
_write_revocation_test(owned=True)
no_ku_chain = _chain(chain_name="no_ku_chain", key_usage=None, cert_dps=None)
no_crl_ku_chain = _chain(
chain_name="no_crl_ku_chain", key_usage=no_crl_sign_ku, cert_dps=None
)
crl_ku_chain = _chain(chain_name="ku_chain", key_usage=crl_sign_ku, cert_dps=None)
dp_chain = _chain(chain_name="dp_chain", key_usage=None, cert_dps=valid_cert_crl_dp)
def _ee_no_crls_test() -> None:
_revocation_test(
test_name="no_crls_test",
chain=no_ku_chain,
crl_paths=[],
depth=ChainDepth.END_ENTITY, policy=StatusRequirement.ALLOW_UNKNOWN, expiration=ExpirationPolicy.IGNORE,
expected_error=None,
)
def _no_relevant_crl_ee_depth_allow_unknown() -> None:
test_name = "no_relevant_crl_ee_depth_allow_unknown"
ee_cert = no_ku_chain[0][0]
no_match_crl = _crl(
serials=[ee_cert.serial_number],
issuer_name=subject_name_for_test("whatev", test_name),
issuer_key=None,
)
no_match_crl_path = os.path.join(output_dir, f"{test_name}.crl.der")
write_der(no_match_crl_path, no_match_crl.public_bytes(Encoding.DER), force)
_revocation_test(
test_name=test_name,
chain=no_ku_chain,
crl_paths=[no_match_crl_path],
depth=ChainDepth.END_ENTITY,
policy=StatusRequirement.ALLOW_UNKNOWN,
expiration=ExpirationPolicy.IGNORE,
expected_error=None,
)
def _no_relevant_crl_ee_depth_forbid_unknown() -> None:
test_name = "no_relevant_crl_ee_depth_forbid_unknown"
ee_cert = no_ku_chain[0][0]
no_match_crl = _crl(
serials=[ee_cert.serial_number],
issuer_name=subject_name_for_test("whatev", test_name),
issuer_key=None,
)
no_match_crl_path = os.path.join(output_dir, f"{test_name}.crl.der")
write_der(no_match_crl_path, no_match_crl.public_bytes(Encoding.DER), force)
_revocation_test(
test_name=test_name,
chain=no_ku_chain,
crl_paths=[no_match_crl_path],
depth=ChainDepth.END_ENTITY,
policy=StatusRequirement.FORBID_UNKNOWN,
expiration=ExpirationPolicy.IGNORE,
expected_error="UnknownRevocationStatus",
)
def _ee_not_revoked_ee_depth() -> None:
    """Emit an EE-depth case with a CRL from the EE's issuer that omits the EE.

    The CRL lists only the placeholder serial 12345 and is signed with
    intermediate A's key. The generated test expects no error.
    """
    name = "ee_not_revoked_ee_depth"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl = _crl(serials=[12345], issuer_name=leaf.issuer, issuer_key=signer_key)
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _ee_not_revoked_chain_depth() -> None:
    """Emit a chain-depth case with a CRL from the EE's issuer that omits the EE.

    The CRL lists only the placeholder serial 12345 and is signed with
    intermediate A's key. Under ALLOW_UNKNOWN the generated test expects no
    error.
    """
    name = "ee_not_revoked_chain_depth"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl = _crl(serials=[12345], issuer_name=leaf.issuer, issuer_key=signer_key)
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _ee_revoked_badsig_ee_depth() -> None:
    """Emit an EE-depth case whose revoking CRL has a bad signature.

    The CRL names the EE's issuer and lists the EE serial, but it is signed by
    a freshly generated P-256 key that is not part of the chain. The generated
    test expects InvalidCrlSignatureForPublicKey.
    """
    name = "ee_revoked_badsig_ee_depth"
    leaf = no_ku_chain[0][0]
    unrelated_key: ec.EllipticCurvePrivateKey = ec.generate_private_key(
        ec.SECP256R1(), default_backend()
    )
    crl = _crl(
        serials=[leaf.serial_number],
        issuer_name=leaf.issuer,
        issuer_key=unrelated_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="InvalidCrlSignatureForPublicKey",
    )
def _ee_revoked_wrong_ku_ee_depth() -> None:
    """Emit an EE-depth case where the CRL issuer lacks the CRL-sign key usage.

    This uses no_crl_ku_chain (built with no_crl_sign_ku), and the CRL revokes
    the EE. The generated test expects IssuerNotCrlSigner.
    """
    name = "ee_revoked_wrong_ku_ee_depth"
    leaf = no_crl_ku_chain[0][0]
    signer_key = no_crl_ku_chain[1][2]
    crl = _crl(
        serials=[leaf.serial_number],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_crl_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="IssuerNotCrlSigner",
    )
def _ee_not_revoked_wrong_ku_ee_depth() -> None:
    """Emit an EE-depth case with a non-revoking CRL from a non-CRL-signer issuer.

    The CRL lists only the placeholder serial 12345. Because the issuer in
    no_crl_ku_chain lacks CRL-sign usage, the generated test still expects
    IssuerNotCrlSigner.
    """
    name = "ee_not_revoked_wrong_ku_ee_depth"
    leaf = no_crl_ku_chain[0][0]
    signer_key = no_crl_ku_chain[1][2]
    crl = _crl(serials=[12345], issuer_name=leaf.issuer, issuer_key=signer_key)
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_crl_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="IssuerNotCrlSigner",
    )
def _ee_revoked_no_ku_ee_depth() -> None:
    """Emit an EE-depth case where the EE is revoked and the issuer has no KU.

    This uses no_ku_chain (built with key_usage=None). The generated test
    expects CertRevoked.
    """
    name = "ee_revoked_no_ku_ee_depth"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl = _crl(
        serials=[leaf.serial_number],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="CertRevoked",
    )
def _ee_revoked_crl_ku_ee_depth() -> None:
    """Emit an EE-depth case where the EE is revoked by a CRL-sign-capable issuer.

    This uses crl_ku_chain (built with crl_sign_ku). The generated test expects
    CertRevoked.
    """
    name = "ee_revoked_crl_ku_ee_depth"
    leaf = crl_ku_chain[0][0]
    signer_key = crl_ku_chain[1][2]
    crl = _crl(
        serials=[leaf.serial_number],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=crl_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="CertRevoked",
    )
def _no_crls_test_chain_depth() -> None:
    """Emit a chain-depth case where no CRLs are supplied.

    With ALLOW_UNKNOWN the generated test expects no error.
    """
    supplied_crls = []
    _revocation_test(
        test_name="no_crls_test_chain_depth",
        chain=no_ku_chain,
        crl_paths=supplied_crls,
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _no_relevant_crl_chain_depth_allow_unknown() -> None:
    """Emit a chain-depth case whose only CRL comes from an unrelated issuer.

    The CRL lists intermediate A's serial but names a made-up "whatev" issuer.
    With ALLOW_UNKNOWN the generated test expects no error.
    """
    name = "no_relevant_crl_chain_depth_allow_unknown"
    listed_serial = no_ku_chain[1][0].serial_number
    crl = _crl(
        serials=[listed_serial],
        issuer_name=subject_name_for_test("whatev", name),
        issuer_key=None,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _no_relevant_crl_chain_depth_forbid_unknown() -> None:
    """Emit a chain-depth case whose only CRL comes from an unrelated issuer.

    The setup is the same as the allow-unknown variant, but under FORBID_UNKNOWN
    the generated test expects UnknownRevocationStatus.
    """
    name = "no_relevant_crl_chain_depth_forbid_unknown"
    listed_serial = no_ku_chain[1][0].serial_number
    crl = _crl(
        serials=[listed_serial],
        issuer_name=subject_name_for_test("whatev", name),
        issuer_key=None,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _int_not_revoked_chain_depth_allow_unknown() -> None:
    """Emit a chain-depth case with a non-revoking CRL for intermediate A.

    The CRL is issued under intermediate A's issuer name, signed with
    intermediate B's key, and lists only the placeholder serial 12345. With
    ALLOW_UNKNOWN the generated test expects no error.
    """
    # NOTE: the generated test name has no "_allow_unknown" suffix, unlike this
    # function's name. It is kept as-is so the generated test name and fixture
    # path stay stable.
    name = "int_not_revoked_chain_depth"
    intermediate = no_ku_chain[1][0]
    signer_key = no_ku_chain[2][2]
    crl = _crl(
        serials=[12345],
        issuer_name=intermediate.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _int_not_revoked_chain_depth_forbid_unknown() -> None:
    """Emit a chain-depth FORBID_UNKNOWN case where every issuer publishes a CRL.

    Three non-revoking CRLs are written:
    - one for the EE's issuer (signed with intermediate A's key),
    - one for intermediate A's issuer (signed with intermediate B's key),
    - one for intermediate B's issuer (signed with the root key).
    None of them lists a chain serial, so the generated test expects no error.
    """
    test_name = "int_not_revoked_chain_depth_forbid_unknown"

    def _write_not_revoked_crl(suffix, serial, issuer_name, issuer_key):
        # Build one CRL, write it as `<test_name><suffix>.crl.der`, and return
        # its path. Each CRL is written right after it is built, as before.
        crl = _crl(serials=[serial], issuer_name=issuer_name, issuer_key=issuer_key)
        path = os.path.join(output_dir, f"{test_name}{suffix}.crl.der")
        write_der(path, crl.public_bytes(Encoding.DER), force)
        return path

    leaf = no_ku_chain[0][0]
    int_a = no_ku_chain[1][0]
    int_a_key = no_ku_chain[1][2]
    int_b_key = no_ku_chain[2][2]
    int_b = no_ku_chain[2][0]
    root_key = no_ku_chain[3][2]

    crl_paths = [
        _write_not_revoked_crl("_ee", 9999, leaf.issuer, int_a_key),
        _write_not_revoked_crl("", 12345, int_a.issuer, int_b_key),
        _write_not_revoked_crl("_b", 12345, int_b.issuer, root_key),
    ]
    _revocation_test(
        test_name=test_name,
        chain=no_ku_chain,
        crl_paths=crl_paths,
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _int_revoked_badsig_chain_depth() -> None:
    """Emit a chain-depth case revoking intermediate A with a badly signed CRL.

    The CRL names intermediate A's issuer and lists its serial, but it is signed
    by a freshly generated P-256 key that is not part of the chain. The
    generated test expects InvalidCrlSignatureForPublicKey.
    """
    name = "int_revoked_badsig_chain_depth"
    intermediate = no_ku_chain[1][0]
    unrelated_key: ec.EllipticCurvePrivateKey = ec.generate_private_key(
        ec.SECP256R1(), default_backend()
    )
    crl = _crl(
        serials=[intermediate.serial_number],
        issuer_name=intermediate.issuer,
        issuer_key=unrelated_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="InvalidCrlSignatureForPublicKey",
    )
def _int_revoked_wrong_ku_chain_depth() -> None:
    """Emit a chain-depth case revoking intermediate A via a non-CRL-signer issuer.

    This uses no_crl_ku_chain (built with no_crl_sign_ku). The CRL is signed with
    intermediate B's key. The generated test expects IssuerNotCrlSigner.
    """
    name = "int_revoked_wrong_ku_chain_depth"
    intermediate = no_crl_ku_chain[1][0]
    signer_key = no_crl_ku_chain[2][2]
    crl = _crl(
        serials=[intermediate.serial_number],
        issuer_name=intermediate.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_crl_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="IssuerNotCrlSigner",
    )
def _ee_revoked_chain_depth() -> None:
    """Emit a chain-depth case where the EE is revoked by its issuer's CRL.

    The CRL is signed with intermediate A's key. The generated test expects
    CertRevoked.
    """
    name = "ee_revoked_chain_depth"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl = _crl(
        serials=[leaf.serial_number],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="CertRevoked",
    )
def _int_revoked_no_ku_chain_depth() -> None:
    """Emit a chain-depth case revoking intermediate A in the no-KU chain.

    The CRL is signed with intermediate B's key. The generated test expects
    CertRevoked.
    """
    name = "int_revoked_no_ku_chain_depth"
    intermediate = no_ku_chain[1][0]
    signer_key = no_ku_chain[2][2]
    crl = _crl(
        serials=[intermediate.serial_number],
        issuer_name=intermediate.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="CertRevoked",
    )
def _int_revoked_crl_ku_chain_depth() -> None:
    """Emit a chain-depth case revoking intermediate A in the CRL-sign-KU chain.

    The CRL is signed with intermediate B's key from crl_ku_chain. The generated
    test expects CertRevoked.
    """
    name = "int_revoked_crl_ku_chain_depth"
    intermediate = crl_ku_chain[1][0]
    signer_key = crl_ku_chain[2][2]
    crl = _crl(
        serials=[intermediate.serial_number],
        issuer_name=intermediate.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=crl_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="CertRevoked",
    )
def _ee_with_top_bit_set_serial_revoked() -> None:
    """Emit a case revoking the top-bit-serial EE stored at crl_ku_chain[4].

    The entry at index 4 is presumably an EE whose serial has its top bit set
    (TODO confirm against _chain). The CRL is signed with intermediate A's key,
    and ee_topbit_serial=True is passed so the generated test uses that EE. The
    generated test expects CertRevoked.
    """
    name = "ee_with_top_bit_set_serial_revoked"
    topbit_leaf = crl_ku_chain[4][0]
    signer_key = crl_ku_chain[1][2]
    crl = _crl(
        serials=[topbit_leaf.serial_number],
        issuer_name=topbit_leaf.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=crl_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        ee_topbit_serial=True,
        expected_error="CertRevoked",
    )
def _ee_no_dp_crl_idp() -> None:
    """Emit a case where the EE has no CRL DP but the CRL carries an IDP.

    no_ku_chain is built with cert_dps=None, and the CRL uses valid_crl_idp and
    lists only serial 0xFFFF. Under FORBID_UNKNOWN the generated test expects
    no error.
    """
    name = "ee_no_dp_crl_idp"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=valid_crl_idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _ee_crl_no_idp_unknown_status() -> None:
    """Emit a case where the EE has a CRL DP but the CRL carries no IDP.

    dp_chain is built with cert_dps=valid_cert_crl_dp. The CRL passes no
    issuing_dp. Under FORBID_UNKNOWN the generated test expects
    UnknownRevocationStatus.
    """
    name = "ee_crl_no_idp_unknown_status"
    leaf = dp_chain[0][0]
    signer_key = dp_chain[1][2]
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=dp_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _ee_crl_mismatched_idp_unknown_status() -> None:
    """Emit a case where the CRL's IDP URI does not match the EE's CRL DP.

    The IDP's full name is an unrelated URI. Under FORBID_UNKNOWN the generated
    test expects UnknownRevocationStatus.
    """
    name = "ee_crl_mismatched_idp_unknown_status"
    leaf = dp_chain[0][0]
    signer_key = dp_chain[1][2]
    mismatched_idp = x509.IssuingDistributionPoint(
        full_name=[
            x509.UniformResourceIdentifier("http://does.not.match.example.com")
        ],
        indirect_crl=False,
        relative_name=None,
        only_contains_attribute_certs=False,
        only_contains_ca_certs=False,
        only_contains_user_certs=False,
        only_some_reasons=None,
    )
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=mismatched_idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=dp_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _ee_indirect_dp_unknown_status() -> None:
    """Emit a case where the EE's CRL DP names a separate CRL issuer.

    The EE's DP sets crl_issuer to a DNS name. The CRL's IDP reuses the same
    full name but is a direct CRL (indirect_crl=False). Under FORBID_UNKNOWN the
    generated test expects UnknownRevocationStatus.
    """
    name = "ee_indirect_dp_unknown_status"
    indirect_dp = x509.DistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        relative_name=None,
        reasons=None,
        crl_issuer=[x509.DNSName("indirect.example.com")],
    )
    chain = _chain(
        chain_name="indirect_dp_chain",
        key_usage=None,
        cert_dps=indirect_dp,
    )
    leaf = chain[0][0]
    signer_key = chain[1][2]
    idp = x509.IssuingDistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        indirect_crl=False,
        relative_name=None,
        only_contains_attribute_certs=False,
        only_contains_ca_certs=False,
        only_contains_user_certs=False,
        only_some_reasons=None,
    )
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _ee_reasons_dp_unknown_status() -> None:
    """Emit a case where the EE's CRL DP is restricted to specific reasons.

    The EE's DP limits reasons to key_compromise. The CRL's IDP has the same
    full name with no reason restriction. Under FORBID_UNKNOWN the generated
    test expects UnknownRevocationStatus.
    """
    name = "ee_reasons_dp_unknown_status"
    reasons_dp = x509.DistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        relative_name=None,
        reasons=frozenset([x509.ReasonFlags.key_compromise]),
        crl_issuer=None,
    )
    chain = _chain(
        chain_name="reasons_dp_chain",
        key_usage=None,
        cert_dps=reasons_dp,
    )
    leaf = chain[0][0]
    signer_key = chain[1][2]
    idp = x509.IssuingDistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        indirect_crl=False,
        relative_name=None,
        only_contains_attribute_certs=False,
        only_contains_ca_certs=False,
        only_contains_user_certs=False,
        only_some_reasons=None,
    )
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _ee_nofullname_dp_unknown_status() -> None:
    """Emit a case where the EE's CRL DP uses a relative name, not a full name.

    The EE's DP carries only a relative name (CN=example.com). The CRL's IDP
    uses valid_cert_crl_dp's full name. Under FORBID_UNKNOWN the generated test
    expects UnknownRevocationStatus.
    """
    name = "ee_nofullname_dp_unknown_status"
    relative_only_dp = x509.DistributionPoint(
        full_name=None,
        relative_name=x509.RelativeDistinguishedName(
            [x509.NameAttribute(NameOID.COMMON_NAME, "example.com")]
        ),
        reasons=None,
        crl_issuer=None,
    )
    chain = _chain(
        chain_name="nofullname_dp_chain",
        key_usage=None,
        cert_dps=relative_only_dp,
    )
    leaf = chain[0][0]
    signer_key = chain[1][2]
    idp = x509.IssuingDistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        indirect_crl=False,
        relative_name=None,
        only_contains_attribute_certs=False,
        only_contains_ca_certs=False,
        only_contains_user_certs=False,
        only_some_reasons=None,
    )
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _ee_dp_idp_match() -> None:
    """Emit a case where the CRL's IDP full name matches the EE's CRL DP.

    The CRL lists only serial 0xFFFF. Under FORBID_UNKNOWN the generated test
    expects no error.
    """
    name = "ee_dp_idp_match"
    leaf = dp_chain[0][0]
    signer_key = dp_chain[1][2]
    matching_idp = x509.IssuingDistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        indirect_crl=False,
        relative_name=None,
        only_contains_attribute_certs=False,
        only_contains_ca_certs=False,
        only_contains_user_certs=False,
        only_some_reasons=None,
    )
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=matching_idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=dp_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _ee_dp_invalid() -> None:
    """Emit a case whose EE carries a distribution point with no name at all.

    The DP keeps a reasons restriction but has neither full_name nor
    relative_name (and no crl_issuer). Under FORBID_UNKNOWN the generated test
    expects UnknownRevocationStatus.
    """
    name = "ee_dp_invalid"
    nameless_dp = x509.DistributionPoint(
        full_name=valid_cert_crl_dp.full_name,
        relative_name=None,
        reasons=frozenset([x509.ReasonFlags.key_compromise]),
        crl_issuer=None,
    )
    # Clear the full name after construction by poking a private attribute.
    # NOTE(review): presumably this is done because the public constructor
    # refuses a DP with no name and no crl_issuer. It relies on cryptography
    # internals, so it may break on library upgrades.
    nameless_dp._full_name = None
    chain = _chain(
        chain_name="invalid_dp_chain",
        key_usage=None,
        cert_dps=nameless_dp,
    )
    leaf = chain[0][0]
    signer_key = chain[1][2]
    crl = _crl(
        serials=[0xFFFF],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        issuing_dp=valid_crl_idp,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=chain,
        crl_paths=[crl_path],
        depth=ChainDepth.END_ENTITY,
        policy=StatusRequirement.FORBID_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error="UnknownRevocationStatus",
    )
def _expired_crl_ignore_expiration() -> None:
    """Emit a chain-depth case with an expired, non-revoking CRL and IGNORE.

    The CRL's not_after is 10 seconds past 0x1FEDF00D minus 20, i.e.
    0x1FEDF00D - 10, which falls between the file's NOT_BEFORE (-30) and
    reference timestamp. It presumably precedes the generated test's
    validation time (TODO confirm). Because expiration is ignored, the
    generated test expects no error.
    """
    name = "expired_crl_ignore_expiration"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl_not_after = datetime.datetime.utcfromtimestamp(0x1FEDF00D - 10)
    crl = _crl(
        serials=[12345],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        not_after=crl_not_after,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.IGNORE,
        expected_error=None,
    )
def _expired_crl_enforce_expiration() -> None:
    """Emit a chain-depth case with an expired, non-revoking CRL and ENFORCE.

    The CRL's not_after is 0x1FEDF00D - 10, which presumably precedes the
    generated test's validation time (TODO confirm). With expiration enforced,
    the generated test expects CrlExpired.
    """
    name = "expired_crl_enforce_expiration"
    leaf = no_ku_chain[0][0]
    signer_key = no_ku_chain[1][2]
    crl_not_after = datetime.datetime.utcfromtimestamp(0x1FEDF00D - 10)
    crl = _crl(
        serials=[12345],
        issuer_name=leaf.issuer,
        issuer_key=signer_key,
        not_after=crl_not_after,
    )
    crl_path = os.path.join(output_dir, f"{name}.crl.der")
    crl_der = crl.public_bytes(Encoding.DER)
    write_der(crl_path, crl_der, force)
    _revocation_test(
        test_name=name,
        chain=no_ku_chain,
        crl_paths=[crl_path],
        depth=ChainDepth.CHAIN,
        policy=StatusRequirement.ALLOW_UNKNOWN,
        expiration=ExpirationPolicy.ENFORCE,
        expected_error="CrlExpired",
    )
with trim_top("client_auth_revocation.rs") as output:
    # `output` is the name the nested case generators print into (via
    # `file=output`), so it must keep this name. Cases are emitted in this
    # exact order.
    for generate_case in (
        _ee_no_crls_test,
        _no_relevant_crl_ee_depth_allow_unknown,
        _no_relevant_crl_ee_depth_forbid_unknown,
        _ee_not_revoked_ee_depth,
        _ee_not_revoked_chain_depth,
        _ee_revoked_badsig_ee_depth,
        _ee_revoked_wrong_ku_ee_depth,
        _ee_not_revoked_wrong_ku_ee_depth,
        _ee_revoked_no_ku_ee_depth,
        _ee_revoked_crl_ku_ee_depth,
        _no_crls_test_chain_depth,
        _no_relevant_crl_chain_depth_allow_unknown,
        _no_relevant_crl_chain_depth_forbid_unknown,
        _int_not_revoked_chain_depth_allow_unknown,
        _int_not_revoked_chain_depth_forbid_unknown,
        _int_revoked_badsig_chain_depth,
        _int_revoked_wrong_ku_chain_depth,
        _ee_revoked_chain_depth,
        _int_revoked_no_ku_chain_depth,
        _int_revoked_crl_ku_chain_depth,
        _ee_with_top_bit_set_serial_revoked,
        _ee_no_dp_crl_idp,
        _ee_crl_no_idp_unknown_status,
        _ee_crl_mismatched_idp_unknown_status,
        _ee_indirect_dp_unknown_status,
        _ee_reasons_dp_unknown_status,
        _ee_nofullname_dp_unknown_status,
        _ee_dp_idp_match,
        _ee_dp_invalid,
        _expired_crl_ignore_expiration,
        _expired_crl_enforce_expiration,
    ):
        generate_case()
if __name__ == "__main__":
    # Command-line entry point. Each boolean flag supports --flag / --no-flag
    # (BooleanOptionalAction). Generation steps default to on; --force defaults
    # to off and is forwarded to every generator.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tls-server-certs",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Generate TLS server certificate testcases",
    )
    parser.add_argument(
        "--signatures",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Generate signature testcases",
    )
    parser.add_argument(
        "--client-auth",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Generate client auth testcases",
    )
    parser.add_argument(
        "--client-auth-revocation",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Generate client auth revocation testcases",
    )
    parser.add_argument(
        "--format",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Run cargo fmt post-generation",
    )
    parser.add_argument(
        "--test",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Run cargo test post-generation",
    )
    parser.add_argument(
        "--force",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Overwrite existing test keys/certs",
    )
    args = parser.parse_args()

    if args.tls_server_certs:
        tls_server_certs(args.force)
    if args.signatures:
        signatures(args.force)
    if args.client_auth:
        client_auth(args.force)
    if args.client_auth_revocation:
        client_auth_revocation(args.force)

    # Post-generation steps. Pass argv lists rather than shell command strings:
    # no shell features are needed, so there is no reason to spawn a shell
    # (shell=True). check=True makes a failing cargo step abort with
    # CalledProcessError, as before.
    if args.format:
        subprocess.run(["cargo", "fmt"], check=True)
    if args.test:
        subprocess.run(["cargo", "test"], check=True)