from __future__ import annotations
import argparse
import json
import os
import subprocess
import sys
import tomllib
from collections.abc import Set
from dataclasses import InitVar, dataclass, field
from pathlib import Path
from typing import IO, Any, ClassVar, cast
class CargoInheritanceError(Exception):
    """Base class for every error raised while checking cargo field inheritance."""
class MissingInheritedFieldError(CargoInheritanceError):
    """A workspace package does not inherit some required fields from the workspace root.

    Attributes:
        package: Name of the offending package.
        field_names: Copy of the required field names the package fails to inherit.
    """

    package: str
    field_names: list[str]

    def __init__(self, *, package: str, field_names: list[str]) -> None:
        self.package = package
        self.field_names = field_names.copy()
        # Name at most three fields so the message stays short; elide the rest.
        shown, hidden = field_names[:3], field_names[3:]
        listing = ", ".join(shown) + (", ..." if hidden else "")
        super().__init__(f"Package {package} does not inherit required fields from workspace root: {listing}")
class MetadataError(CargoInheritanceError, ValueError):
    """The ``cargo-inheritance`` workspace configuration is missing or malformed."""
@dataclass
class OurMetadata:
    """Configuration read from the workspace's ``cargo-inheritance`` metadata table.

    Attributes:
        required_fields: Dotted manifest keys (e.g. ``package.version``) that every
            non-ignored package must inherit from the workspace root.
        ignore_crates: Names of workspace packages exempt from the check.
        valid_crates: Init-only set of package names in the workspace; used only to
            reject unknown entries in ``ignore_crates``.
    """

    KEY: ClassVar[str] = "workspace.metadata.cargo-inheritance"

    required_fields: list[str]
    valid_crates: InitVar[Set[str]]
    ignore_crates: list[str] = field(default_factory=list)

    def __post_init__(self, valid_crates: Set[str]) -> None:
        """Validate the configuration.

        Raises:
            MetadataError: if ``required_fields`` contains duplicates or a malformed
                key, or ``ignore_crates`` names a package not in ``valid_crates``.
        """
        if len(set(self.required_fields)) != len(self.required_fields):
            raise MetadataError(f"Duplicate entries in {OurMetadata.KEY}.required-fields")
        for required_field in self.required_fields:
            try:
                # Looking the key up in an empty table only validates its syntax:
                # a malformed key raises ValueError, a well-formed one KeyError.
                _get_field({}, required_field)
            except ValueError as cause:
                raise MetadataError(
                    f"Invalid field {required_field!r} in {OurMetadata.KEY}.required-fields"
                ) from cause
            except KeyError:
                pass
        if invalid_ignore_crates := set(self.ignore_crates) - valid_crates:
            raise MetadataError(
                f"Invalid crates in {OurMetadata.KEY}.ignore-crates: {', '.join(sorted(invalid_ignore_crates))}"
            )

    @classmethod
    def deserialize(cls, *, cargo_metadata: dict[str, Any]) -> OurMetadata:
        """Build the configuration from decoded ``cargo metadata`` JSON output.

        Args:
            cargo_metadata: Decoded output of ``cargo metadata --format-version=1``.

        Raises:
            MetadataError: if the configuration table is missing or malformed.
        """
        # The top-level "metadata" entry may be null when the workspace declares no
        # [workspace.metadata] table; treat that the same as an empty table.
        workspace_metadata = cargo_metadata.get("metadata") or {}
        if "cargo-inheritance" not in workspace_metadata:
            raise MetadataError(f"Missing {cls.KEY} workspace configuration")
        valid_crates = {raw_pkg["name"] for raw_pkg in cargo_metadata["packages"]}
        raw_data: object = workspace_metadata["cargo-inheritance"]
        if not isinstance(raw_data, dict):
            raise MetadataError(f"Metadata entry {cls.KEY} must be a dictionary")
        known_fields = {"required-fields", "ignore-crates"}
        if unknown_fields := set(raw_data) - known_fields:
            raise MetadataError(f"Unknown config fields for {cls.KEY}: {', '.join(sorted(unknown_fields))}")

        def _check_str_list(value: object, field_name: str) -> list[str]:
            """Return ``value`` unchanged if it is a list of str, else raise MetadataError."""
            full_field_name = f"{cls.KEY}.{field_name}"
            if not isinstance(value, list):
                raise MetadataError(f"Expected {full_field_name} to be a list[str], got {type(value)}")
            for index, item in enumerate(value):
                if not isinstance(item, str):
                    raise MetadataError(
                        f"Expected entries of {full_field_name} to be a str, but found {type(item)} at index {index}"
                    )
            return cast(list[str], value)

        def _require_field(name: str) -> object:
            """Return ``raw_data[name]``, raising MetadataError when it is absent."""
            try:
                return raw_data[name]
            except KeyError:
                raise MetadataError(f"Missing required field {name} in {cls.KEY}") from None

        return cls(
            required_fields=_check_str_list(_require_field("required-fields"), "required-fields"),
            ignore_crates=_check_str_list(raw_data.get("ignore-crates", []), "ignore-crates"),
            valid_crates=valid_crates,
        )
def _get_field(data: dict[str, Any], key: str, /) -> object:
assert isinstance(key, str)
key_parts = key.split(".")
assert key_parts, key
for part in key_parts:
if not part.isalnum():
raise ValueError(f"Invalid key {key!r} (bad part {part!r})")
current_value = data
for current_index in range(len(key_parts)):
parent_key = ".".join(key_parts[:current_index]) or "."
current_part = key_parts[current_index]
if not isinstance(current_value, dict):
raise KeyError(f"Failed to load {key!r}: {parent_key} is not a dictionary")
try:
current_value = current_value[current_part]
except KeyError:
raise KeyError(f"Failed to load {key!r}: {parent_key} is missing field {current_part!r}") from None
return current_value
def verify_cargo_inheritance(
    *,
    workspace: Path | None = None,
) -> None:
    """Check that every workspace package inherits the configured fields.

    Runs ``cargo metadata --no-deps`` in ``workspace`` (the current directory when
    ``None``), loads the ``cargo-inheritance`` configuration from it, then checks
    that each package not listed in ``ignore-crates`` sets
    ``<field>.workspace = true`` for every required field.

    Raises:
        CargoInheritanceError: if the configuration is missing or invalid
            (chained from the underlying MetadataError).
        ExceptionGroup: of MissingInheritedFieldError, one per failing package.
        subprocess.CalledProcessError: if ``cargo metadata`` exits non-zero.
    """
    raw_metadata = json.loads(
        subprocess.check_output(
            ["cargo", "metadata", "--no-deps", "--format-version=1"],
            cwd=workspace,
            encoding="utf-8",
        )
    )
    try:
        our_metadata = OurMetadata.deserialize(cargo_metadata=raw_metadata)
    except MetadataError as e:
        raise CargoInheritanceError("Failed to load cargo-inheritance metadata from root Cargo.toml") from e
    ignored_crates = set(our_metadata.ignore_crates)

    failures: list[MissingInheritedFieldError] = []
    for raw_pkg in raw_metadata["packages"]:
        name = raw_pkg["name"]
        # Decide on ignored crates before reading their manifest at all.
        if name in ignored_crates:
            continue
        with open(Path(raw_pkg["manifest_path"]), "rb") as f:
            manifest_data = tomllib.load(f)
        missing_fields: list[str] = []
        for field_name in our_metadata.required_fields:
            try:
                # Only a literal `true` counts as inheriting the field.
                inherited = _get_field(manifest_data, field_name + ".workspace") is True
            except KeyError:
                inherited = False
            if not inherited:
                missing_fields.append(field_name)
        if missing_fields:
            failures.append(MissingInheritedFieldError(package=name, field_names=missing_fields))
    if failures:
        raise ExceptionGroup("Some cargo packages do not inherit required fields", failures)
def main(raw_args: list[str] | None = None, /) -> None:
parser = argparse.ArgumentParser(
prog="verify_cargo_inheritance",
description="Verify cargo packages inherit required fields from the workspace",
)
parser.add_argument(
"--workspace",
help="The workspace root directory",
default=str(Path.cwd()),
type=str,
)
args = parser.parse_args(raw_args)
workspace = Path(args.workspace)
if not workspace.is_dir() or not (workspace / "Cargo.toml").is_file():
print(f"ERROR: Does not look like a workspace: {workspace}", file=sys.stderr)
sys.exit(1)
try:
verify_cargo_inheritance(workspace=workspace)
except* MissingInheritedFieldError as group:
for entry in group.exceptions:
print_error(entry)
sys.exit(1)
def should_color(file: IO[str]) -> bool:
    """Decide whether ANSI colors should be written to ``file``.

    A non-empty ``NO_COLOR`` disables color. Otherwise, ``CLICOLOR_FORCE`` set to
    anything other than empty or ``"0"`` forces color. Failing both, color is used
    only when ``file`` is a terminal.
    """
    if os.getenv("NO_COLOR"):
        return False
    # CLICOLOR convention: force colors only when CLICOLOR_FORCE != 0.
    if os.getenv("CLICOLOR_FORCE", "") not in ("", "0"):
        return True
    return file.isatty()
def print_error(msg: Any) -> None:
    """Write ``ERROR: <msg>`` to stderr, with the label in bold red when colors are enabled."""
    label = "\x1b[1;31mERROR\x1b[0m" if should_color(sys.stderr) else "ERROR"
    print(f"{label}: {msg!s}", file=sys.stderr)
# Public API. CargoInheritanceError is exported because verify_cargo_inheritance
# raises it for configuration problems.
__all__ = (
    "CargoInheritanceError",
    "MissingInheritedFieldError",
    "verify_cargo_inheritance",
)

if __name__ == "__main__":
    main()