synta 0.1.3

ASN.1 parser, decoder, and encoder library with DER/BER support and C FFI
Documentation
#!/usr/bin/env python3
"""
ML-DSA (FIPS 204) certificate tests using LAMPS WG test vectors

Parses certificates from the LAMPS Working Group dilithium-certificates
repository:  https://github.com/lamps-wg/dilithium-certificates

Looks for pre-fetched vectors in:
  <workspace-root>/tests/vectors/dilithium-certificates/examples/

If not present, clones the repository on first run (sparse checkout,
examples/ only) into the same location.  Pre-fetch with:
  bash tests/vectors/fetch.sh

All tests are skipped automatically when the repository is unavailable (no
network access, git not installed, etc.).  Run them explicitly with:

  python -m pytest tests/python/test_mldsa.py -v
  python3 tests/python/test_mldsa.py          # manual runner
"""

import base64
import subprocess
import sys
from pathlib import Path
from typing import Optional

import pytest
import synta
import synta.oids as oids

# ── ML-DSA OIDs (FIPS 204 / draft-ietf-lamps-dilithium-certificates) ────────
# String form of each ML-DSA OID object exported by synta.oids.
ML_DSA_OID_44, ML_DSA_OID_65, ML_DSA_OID_87 = map(
    str, (oids.ML_DSA_44, oids.ML_DSA_65, oids.ML_DSA_87)
)
# OID string -> friendly algorithm name.
ML_DSA_NAMES = {
    ML_DSA_OID_44: "ML-DSA-44",
    ML_DSA_OID_65: "ML-DSA-65",
    ML_DSA_OID_87: "ML-DSA-87",
}
# All known ML-DSA OID strings (the keys of the mapping above).
ML_DSA_OIDS = set(ML_DSA_NAMES)
# Friendly names that the assertions in this module compare
# ``str(cert.signature_algorithm)`` against.
ML_DSA_FRIENDLY_NAMES = {*ML_DSA_NAMES.values()}

# ── Repository configuration (mirrors tests/test_utils/repo.rs) ──────────────
_LAMPS_REPO_URL = "https://github.com/lamps-wg/dilithium-certificates.git"
_LAMPS_DIR_NAME = "dilithium-certificates"
_LAMPS_CERT_PATH = "examples"

# This file lives in <workspace>/tests/python/, so the workspace root is the
# third parent of the resolved file path.
_WORKSPACE = Path(__file__).resolve().parent.parent.parent
# Canonical checkout location: <workspace>/tests/vectors/<dir_name>
# (the same path used by tests/test_utils/repo.rs and tests/vectors/fetch.sh).
_REPO_DIR = _WORKSPACE.joinpath("tests", "vectors", _LAMPS_DIR_NAME)
# Directory inside the checkout that holds the example certificates.
_CERT_DIR = _REPO_DIR.joinpath(_LAMPS_CERT_PATH)


# ── Repository setup ─────────────────────────────────────────────────────────

def _setup_repository() -> bool:
    """Make the LAMPS dilithium-certificates repo available.

    Checks the canonical location (tests/vectors/dilithium-certificates/)
    first — populated by ``bash tests/vectors/fetch.sh``.  If not present,
    clones into the same directory using the same shallow + sparse-checkout
    strategy as the Rust test harness.  A clone left behind by an earlier,
    interrupted run (``.git`` present but ``examples/`` missing) is reused
    instead of being cloned again.

    Environmental problems (git missing, no network, timeout, unwritable
    workspace) are reported through the return value rather than raised, so
    that importing this module results in skipped tests, not a crash.

    Returns:
        True when the certificate directory is available, False otherwise.
    """
    if _CERT_DIR.exists():
        return True

    try:
        # Kept inside the ``try``: a read-only or otherwise unwritable
        # workspace must lead to a skip, not to an import-time exception.
        _REPO_DIR.parent.mkdir(parents=True, exist_ok=True)

        # ``git clone`` refuses to write into an existing non-empty
        # directory, so only clone when there is no checkout to reuse.
        if not (_REPO_DIR / ".git").exists():
            clone = subprocess.run(
                [
                    "git", "clone",
                    "--depth=1", "--filter=blob:none", "--sparse",
                    _LAMPS_REPO_URL, _LAMPS_DIR_NAME,
                ],
                cwd=_REPO_DIR.parent,
                capture_output=True,
                timeout=120,
            )
            if clone.returncode != 0:
                return False

        # The exit status is intentionally not inspected: the existence
        # check below is the authoritative success criterion.
        subprocess.run(
            ["git", "sparse-checkout", "set", _LAMPS_CERT_PATH],
            cwd=_REPO_DIR,
            capture_output=True,
            timeout=30,
        )
        return _CERT_DIR.exists()

    except (subprocess.TimeoutExpired, OSError):
        # OSError also covers FileNotFoundError (git not installed) and
        # PermissionError (workspace not writable).
        return False


# Evaluated once, at import time.  When the vectors are not already on disk
# this call attempts a git clone (see _setup_repository above).
_repo_available = _setup_repository()

# Skip every test in this module when the repository cannot be set up.
# ``pytestmark`` is the module-level hook pytest reads to apply a marker to
# all tests defined in the file.
pytestmark = pytest.mark.skipif(
    not _repo_available,
    reason=(
        "LAMPS dilithium-certificates repository not available "
        "(requires git and network access)"
    ),
)


# ── Helpers ──────────────────────────────────────────────────────────────────

def _parse_pem(data: bytes) -> bytes:
    """Decode the first PEM block found in *data* and return its raw bytes.

    Scanning stops at the first line that starts with ``-----END``.  Lines
    are collected only after a ``-----BEGIN`` line has been seen; boundary
    lines themselves are never part of the payload.  Non-ASCII bytes in the
    input are dropped before scanning.  When no payload is collected the
    result is ``b""``.
    """
    text = data.decode("ascii", errors="ignore")
    payload_chunks = []
    inside_block = False
    for raw_line in text.splitlines():
        if raw_line.startswith("-----END"):
            # First END boundary terminates the scan, whether or not a
            # BEGIN boundary was seen before it.
            break
        if raw_line.startswith("-----BEGIN"):
            inside_block = True
        elif inside_block:
            payload_chunks.append(raw_line.strip())
    return base64.b64decode("".join(payload_chunks))


def _load_cert_der(path: Path) -> bytes:
    """Load a certificate file as DER bytes (handles both PEM and DER).

    The file is treated as PEM when it starts with ``-----`` or when any of
    its lines starts with a ``-----BEGIN`` boundary.  The second condition
    covers PEM files that carry explanatory text or blank lines before the
    boundary; such files would otherwise be returned verbatim as if they
    were DER.  Anything else is returned unchanged.

    Args:
        path: Location of the certificate file.

    Returns:
        The decoded PEM payload, or the raw file contents for non-PEM input.

    Raises:
        OSError: If the file cannot be read.
        binascii.Error: If the PEM payload is not valid base64.
    """
    data = path.read_bytes()
    looks_like_pem = data.startswith(b"-----") or any(
        line.startswith(b"-----BEGIN") for line in data.splitlines()
    )
    if looks_like_pem:
        return _parse_pem(data)
    return data


def _find_cert_files():
    """Collect the candidate certificate files of the LAMPS ML-DSA repo.

    Searches the certificate directory recursively for ``*.pem``, ``*.der``
    and ``*.crt`` files and returns the matches as a sorted list of paths.
    """
    patterns = ("*.pem", "*.der", "*.crt")
    return sorted(
        match
        for pattern in patterns
        for match in _CERT_DIR.rglob(pattern)
    )


def _assert_mldsa_cert(cert, variant: Optional[str] = None):
    """Run the checks shared by every ML-DSA X.509 certificate test.

    Args:
        cert: Parsed certificate object exposing ``signature_algorithm``,
            ``version``, ``serial_number``, ``not_before``, ``not_after``,
            ``subject``, ``issuer`` and ``public_key``.
        variant: Optional key of ``ML_DSA_NAMES``; when given, the
            signature algorithm must be exactly the name mapped to it.

    Raises:
        AssertionError: If any check fails.
    """
    sig_name = str(cert.signature_algorithm)
    assert sig_name in ML_DSA_FRIENDLY_NAMES, \
        f"Expected an ML-DSA signature algorithm name, got {sig_name!r}"
    if variant is not None:
        expected_name = ML_DSA_NAMES[variant]
        assert sig_name == expected_name, \
            f"Expected {expected_name} ({variant}), got {sig_name!r}"

    assert cert.version == 2, \
        f"Expected X.509 v3 (version field = 2), got {cert.version!r}"
    assert isinstance(cert.serial_number, int), \
        "serial_number must be a Python int"

    # Validity bounds and both names only need to be truthy.
    for field in ("not_before", "not_after", "subject", "issuer"):
        assert getattr(cert, field), f"{field} must be a non-empty string"

    assert len(cert.public_key) > 0, "public_key bytes must not be empty"


# ── Tests ────────────────────────────────────────────────────────────────────

def test_parse_all_mldsa_certificates():
    """Parse every certificate file in the LAMPS ML-DSA repo.

    Counts certificates per ML-DSA variant.  Each file ends up in exactly
    one of three buckets:

    * parsed   – ``synta.Certificate.from_der`` accepted it;
    * failed   – the file carries a PEM ``CERTIFICATE`` block, i.e. it
      declares itself to be a certificate, yet decoding was rejected;
    * skipped  – the file could not be loaded, or it failed to decode and
      does not declare itself a certificate (e.g. a bare
      SubjectPublicKeyInfo or an unlabelled DER blob).

    The test fails when the failed bucket is non-empty, when nothing was
    parsed, or when no parsed certificate uses an ML-DSA signature.
    """
    cert_files = _find_cert_files()
    assert cert_files, f"No certificate files found under {_CERT_DIR}"

    parsed_ok = []
    skipped = []   # not a full certificate (e.g. bare public key)
    failed = []    # genuine parse errors

    ml_dsa_counts = {"ML-DSA-44": 0, "ML-DSA-65": 0, "ML-DSA-87": 0}

    for path in cert_files:
        try:
            der = _load_cert_der(path)
        except Exception as exc:
            skipped.append((path.name, f"unreadable: {exc}"))
            continue

        try:
            cert = synta.Certificate.from_der(der)
        except Exception as exc:
            # Only a PEM "CERTIFICATE" label tells us the file was meant to
            # be a certificate; without it the blob may legitimately be a
            # public key or another ASN.1 structure, so it is skipped.
            # The file was read successfully just above, so re-reading it
            # here to inspect the label is safe.
            if b"-----BEGIN CERTIFICATE-----" in path.read_bytes():
                failed.append((path.name, exc))
            else:
                skipped.append((path.name, f"not a certificate: {exc}"))
            continue

        sig_name = str(cert.signature_algorithm)
        if sig_name in ml_dsa_counts:
            ml_dsa_counts[sig_name] += 1

        parsed_ok.append(path.name)

    total_attempted = len(parsed_ok) + len(failed)
    print(f"\nML-DSA certificate parsing: {len(parsed_ok)}/{total_attempted} OK"
          f"  ({len(skipped)} skipped as non-certificate)")
    for name, count in ml_dsa_counts.items():
        if count:
            print(f"  {name}: {count} certificate(s)")

    if failed:
        details = "\n".join(f"  {n}: {e}" for n, e in failed)
        raise AssertionError(
            f"Failed to parse {len(failed)} certificate(s):\n{details}"
        )

    assert parsed_ok, \
        "Should have successfully parsed at least one certificate"
    assert any(ml_dsa_counts.values()), \
        f"No recognised ML-DSA signature OIDs found; counts: {ml_dsa_counts}"


def test_mldsa_44_certificate():
    """At least one ML-DSA-44 certificate must parse correctly.

    Files whose name suggests ML-DSA-44 are tried first; all remaining
    certificate files are then scanned as a fallback, so a misleading file
    name cannot hide a valid ML-DSA-44 certificate stored elsewhere.  The
    first certificate whose signature algorithm is ML-DSA-44 is validated
    with ``_assert_mldsa_cert`` and ends the search.
    """
    name_hints = ("44", "mldsa44", "dilithium2")
    # Single directory scan; the stable sort moves name matches to the
    # front (key False sorts before True) and keeps the original order
    # within each group.
    candidates = sorted(
        _find_cert_files(),
        key=lambda p: not any(hint in p.name.lower() for hint in name_hints),
    )
    expected_name = ML_DSA_NAMES[ML_DSA_OID_44]

    found = False
    for path in candidates:
        try:
            cert = synta.Certificate.from_der(_load_cert_der(path))
        except Exception:
            # Unreadable or non-certificate files are simply not candidates.
            continue
        if str(cert.signature_algorithm) != expected_name:
            continue
        _assert_mldsa_cert(cert, variant=ML_DSA_OID_44)
        print(f"\n  ML-DSA-44 certificate found: {path.name}")
        print(f"  Subject:    {cert.subject}")
        print(f"  Issuer:     {cert.issuer}")
        print(f"  Not before: {cert.not_before}")
        print(f"  Not after:  {cert.not_after}")
        found = True
        break

    assert found, \
        "No ML-DSA-44 certificate found or parseable in the LAMPS repo"


def test_mldsa_65_certificate():
    """At least one ML-DSA-65 certificate must parse correctly.

    Files whose name suggests ML-DSA-65 are tried first; all remaining
    certificate files are then scanned as a fallback, so a misleading file
    name cannot hide a valid ML-DSA-65 certificate stored elsewhere.  The
    first certificate whose signature algorithm is ML-DSA-65 is validated
    with ``_assert_mldsa_cert`` and ends the search.
    """
    name_hints = ("65", "mldsa65", "dilithium3")
    # Single directory scan; the stable sort moves name matches to the
    # front (key False sorts before True) and keeps the original order
    # within each group.
    candidates = sorted(
        _find_cert_files(),
        key=lambda p: not any(hint in p.name.lower() for hint in name_hints),
    )
    expected_name = ML_DSA_NAMES[ML_DSA_OID_65]

    found = False
    for path in candidates:
        try:
            cert = synta.Certificate.from_der(_load_cert_der(path))
        except Exception:
            # Unreadable or non-certificate files are simply not candidates.
            continue
        if str(cert.signature_algorithm) != expected_name:
            continue
        _assert_mldsa_cert(cert, variant=ML_DSA_OID_65)
        print(f"\n  ML-DSA-65 certificate found: {path.name}")
        print(f"  Subject:    {cert.subject}")
        print(f"  Issuer:     {cert.issuer}")
        print(f"  Not before: {cert.not_before}")
        print(f"  Not after:  {cert.not_after}")
        found = True
        break

    assert found, \
        "No ML-DSA-65 certificate found or parseable in the LAMPS repo"


def test_mldsa_87_certificate():
    """At least one ML-DSA-87 certificate must parse correctly.

    Files whose name suggests ML-DSA-87 are tried first; all remaining
    certificate files are then scanned as a fallback, so a misleading file
    name cannot hide a valid ML-DSA-87 certificate stored elsewhere.  The
    first certificate whose signature algorithm is ML-DSA-87 is validated
    with ``_assert_mldsa_cert`` and ends the search.
    """
    name_hints = ("87", "mldsa87", "dilithium5")
    # Single directory scan; the stable sort moves name matches to the
    # front (key False sorts before True) and keeps the original order
    # within each group.
    candidates = sorted(
        _find_cert_files(),
        key=lambda p: not any(hint in p.name.lower() for hint in name_hints),
    )
    expected_name = ML_DSA_NAMES[ML_DSA_OID_87]

    found = False
    for path in candidates:
        try:
            cert = synta.Certificate.from_der(_load_cert_der(path))
        except Exception:
            # Unreadable or non-certificate files are simply not candidates.
            continue
        if str(cert.signature_algorithm) != expected_name:
            continue
        _assert_mldsa_cert(cert, variant=ML_DSA_OID_87)
        print(f"\n  ML-DSA-87 certificate found: {path.name}")
        print(f"  Subject:    {cert.subject}")
        print(f"  Issuer:     {cert.issuer}")
        print(f"  Not before: {cert.not_before}")
        print(f"  Not after:  {cert.not_after}")
        found = True
        break

    assert found, \
        "No ML-DSA-87 certificate found or parseable in the LAMPS repo"


# ── Manual runner (python3 tests/python/test_mldsa.py) ──────────────────────

def main():
    """Run the tests of this module without pytest and print a summary.

    When the test vectors are unavailable the run is reported as skipped
    and the process exits with status 0.  Otherwise every test function is
    executed in turn; an exception raised by a test (including
    ``AssertionError``) marks it as failed without stopping the remaining
    tests.  The process exits with status 1 if at least one test failed and
    returns normally when all tests passed.
    """
    if not _repo_available:
        print("SKIP: LAMPS dilithium-certificates repository not available.")
        print("      Requires git and network access.")
        print(f"      Expected path: {_CERT_DIR}")
        sys.exit(0)

    tests = [
        test_parse_all_mldsa_certificates,
        test_mldsa_44_certificate,
        test_mldsa_65_certificate,
        test_mldsa_87_certificate,
    ]

    print("=" * 60)
    print("ML-DSA Certificate Tests (LAMPS WG test vectors)")
    print("=" * 60)

    failed_count = 0
    for test in tests:
        print(f"\nRunning {test.__name__} ...")
        try:
            test()
            print(f"{test.__name__}: OK")
        except Exception as exc:
            print(f"{test.__name__} FAILED: {exc}")
            failed_count += 1

    print()
    print("=" * 60)
    if failed_count:
        print(f"{failed_count}/{len(tests)} tests FAILED")
    else:
        print(f"All {len(tests)} tests passed ✓")
    # Close the summary banner on both outcomes before deciding the exit
    # status, so a failing run is framed the same way as a passing one.
    print("=" * 60)
    if failed_count:
        sys.exit(1)


if __name__ == "__main__":
    main()