from __future__ import annotations
import argparse
import re
import sys
from dataclasses import dataclass
from pathlib import Path
# Accepts MAJOR.MINOR.PATCH with optional pre-release and build suffixes,
# optionally prefixed by a "v" (as used in git tags).
SEMVER_RE = re.compile(
    r"""
    ^v?                                 # optional leading v tag prefix
    (?P<version>
        \d+\.\d+\.\d+                   # MAJOR.MINOR.PATCH
        (?:-[0-9A-Za-z.-]+)?            # optional pre-release suffix
        (?:\+[0-9A-Za-z.-]+)?           # optional build metadata
    )
    $
    """,
    re.VERBOSE,
)
@dataclass(frozen=True)
class VersionFiles:
    """Locations of the files that each carry a copy of the release version.

    Instances are immutable (``frozen=True``), so a resolved set of paths
    cannot be altered after construction.
    """

    cargo_toml: Path  # Cargo manifest; the version lives in its [package] table
    readme: Path  # README whose "# Athena RS <version>" heading is tracked
    openapi: Path  # OpenAPI spec whose " version: " line is tracked
def normalize_version(raw: str) -> str:
    """Validate *raw* as a semantic version and return it without a leading ``v``.

    Surrounding whitespace is ignored, so ``" v1.2.3 "`` yields ``"1.2.3"``.

    Raises:
        ValueError: if *raw* is not ``MAJOR.MINOR.PATCH`` with optional
            pre-release / build suffixes (see ``SEMVER_RE``).
    """
    parsed = SEMVER_RE.fullmatch(raw.strip())
    if parsed is None:
        raise ValueError(f"Invalid version '{raw}'. Expected semver like 1.2.3 or v1.2.3.")
    return parsed["version"]
def replace_first_line(lines: list[str], predicate, replacement: str, path: Path) -> list[str]:
    """Overwrite, in place, the first element of *lines* accepted by *predicate*.

    Args:
        lines: Lines to search; mutated in place.
        predicate: Callable taking one line and returning a truthy value
            for the line to replace.
        replacement: New text for the matching line.
        path: File the lines came from; used only in the error message.

    Returns:
        The same *lines* list object, for convenient chaining.

    Raises:
        ValueError: if no line satisfies *predicate*.
    """
    hit = next(
        (position for position, candidate in enumerate(lines) if predicate(candidate)),
        None,
    )
    if hit is None:
        raise ValueError(f"Could not find expected version line in {path}")
    lines[hit] = replacement
    return lines
def update_cargo_toml(text: str, version: str, path: Path) -> str:
    """Return *text* with the ``[package]`` table's version set to *version*.

    Only lines inside ``[package]`` (up to the next ``[...]`` table header)
    are considered. The first line there starting with ``version = `` is
    replaced by the canonical ``version = "<version>"``. The result always
    ends with a newline.

    Raises:
        ValueError: if ``[package]`` has no ``version = `` line.
    """
    rows = text.splitlines()
    inside_package = False
    for position, raw_row in enumerate(rows):
        candidate = raw_row.strip()
        if candidate == "[package]":
            inside_package = True
        elif inside_package:
            # Any other table header ends the [package] section.
            if candidate.startswith("[") and candidate.endswith("]"):
                break
            if candidate.startswith("version = "):
                rows[position] = f'version = "{version}"'
                return "\n".join(rows) + "\n"
    raise ValueError(f"Could not find package version in {path}")
def update_readme(text: str, version: str, path: Path) -> str:
    """Return *text* with its first ``# Athena RS ...`` heading set to *version*.

    The first line starting with ``# Athena RS `` (anywhere in the file) is
    replaced; the result always ends with a newline.

    Raises:
        ValueError: propagated from ``replace_first_line`` when no such
            heading exists.
    """
    heading_prefix = "# Athena RS "
    rewritten = replace_first_line(
        text.splitlines(),
        lambda candidate: candidate.startswith(heading_prefix),
        f"{heading_prefix}{version}",
        path,
    )
    return "\n".join(rewritten) + "\n"
def update_openapi(text: str, version: str, path: Path) -> str:
    """Return *text* with its first `` version: `` line set to *version*.

    Only lines beginning with exactly ``" version: "`` are considered; the
    first one is replaced and the result always ends with a newline.

    Raises:
        ValueError: propagated from ``replace_first_line`` when no such
            line exists.
    """
    key_prefix = " version: "
    updated_lines = replace_first_line(
        text.splitlines(),
        lambda candidate: candidate.startswith(key_prefix),
        f"{key_prefix}{version}",
        path,
    )
    return "\n".join(updated_lines) + "\n"
def read_cargo_version(path: Path) -> str:
    """Return the ``version`` declared in the ``[package]`` table of *path*.

    Only the ``[package]`` table is inspected; scanning stops at the next
    ``[...]`` table header. The value must be a double-quoted TOML string,
    optionally followed by a ``#`` comment (``update_cargo_toml`` rewrites
    such a line, so the reader must accept it too).

    Raises:
        ValueError: if ``[package]`` has no ``version = `` line, or that
            line is not a plain double-quoted string.
    """
    in_package = False
    for line in path.read_text().splitlines():
        stripped = line.strip()
        if stripped == "[package]":
            in_package = True
            continue
        if not in_package:
            continue
        # Any other table header ends the [package] section.
        if stripped.startswith("[") and stripped.endswith("]"):
            break
        if stripped.startswith("version = "):
            match = re.match(r'^version = "([^"]+)"(?:\s*#.*)?$', stripped)
            if match is None:
                # Report the offending line instead of a misleading "not found".
                raise ValueError(f"Unsupported package version line in {path}: {stripped}")
            return match.group(1)
    raise ValueError(f"Could not find package version in {path}")
def read_readme_version(path: Path) -> str:
    """Return the version from the README's ``# Athena RS <version>`` heading.

    Uses the first line starting with ``# Athena RS ``. That is the same
    line ``update_readme`` rewrites, so a heading preceded by badges or
    blank lines is still found. Surrounding whitespace is stripped from the
    version text.

    Raises:
        ValueError: if the README is empty, has no such heading, or the
            heading carries no version text.
    """
    prefix = "# Athena RS "
    # Iterating (rather than indexing line 0) also makes an empty file a
    # ValueError instead of an uncaught IndexError.
    for line in path.read_text().splitlines():
        if line.startswith(prefix):
            version = line[len(prefix):].strip()
            if version:
                return version
            break
    raise ValueError(f"Could not find README version heading in {path}")
def read_openapi_version(path: Path) -> str:
    """Return the value of the first `` version: `` line in the OpenAPI file.

    Only lines beginning with exactly ``" version: "`` are considered (the
    same prefix ``update_openapi`` rewrites). Surrounding whitespace and a
    single matching pair of single or double quotes are stripped. This way a
    quoted YAML scalar such as ``"1.2.3"`` compares equal to the bare
    ``1.2.3`` in the other files instead of being reported as drift.

    Raises:
        ValueError: if no such line exists.
    """
    prefix = " version: "
    for line in path.read_text().splitlines():
        if line.startswith(prefix):
            value = line.removeprefix(prefix).strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            return value
    raise ValueError(f"Could not find OpenAPI version in {path}")
def get_version_files(root: Path) -> VersionFiles:
    """Return the paths of the version-bearing files located directly in *root*."""
    cargo_toml, readme, openapi = (
        root / filename for filename in ("Cargo.toml", "README.md", "openapi.yaml")
    )
    return VersionFiles(cargo_toml=cargo_toml, readme=readme, openapi=openapi)
def collect_versions(files: VersionFiles) -> dict[str, str]:
    """Read the version recorded in each tracked file.

    Returns:
        Mapping from display name (``"Cargo.toml"``, ``"README.md"``,
        ``"openapi.yaml"``, in that order) to the version string found.

    Raises:
        ValueError: propagated from the first reader that cannot find a version.
    """
    readers = (
        ("Cargo.toml", read_cargo_version, files.cargo_toml),
        ("README.md", read_readme_version, files.readme),
        ("openapi.yaml", read_openapi_version, files.openapi),
    )
    return {label: reader(location) for label, reader, location in readers}
def write_if_changed(path: Path, updated_text: str) -> bool:
    """Write *updated_text* to *path* only when it differs from the current contents.

    Returns:
        True if the file was rewritten, False if it already matched.
    """
    unchanged = path.read_text() == updated_text
    if not unchanged:
        path.write_text(updated_text)
    return not unchanged
def sync_versions(root: Path, target_version: str) -> list[str]:
    """Rewrite every tracked file so it declares *target_version*.

    All updated texts are computed before any file is written. A
    ``ValueError`` from one updater (for example a README without the
    expected heading) therefore leaves every file untouched rather than
    half-synced.

    Args:
        root: Directory containing Cargo.toml, README.md and openapi.yaml.
        target_version: Version string to write into each file.

    Returns:
        Display names of the files whose contents changed, in the order
        Cargo.toml, README.md, openapi.yaml.

    Raises:
        ValueError: if any file lacks the line its updater targets.
    """
    files = get_version_files(root)
    # Phase 1: compute every new text; any ValueError aborts before writing.
    planned = [
        (
            "Cargo.toml",
            files.cargo_toml,
            update_cargo_toml(files.cargo_toml.read_text(), target_version, files.cargo_toml),
        ),
        (
            "README.md",
            files.readme,
            update_readme(files.readme.read_text(), target_version, files.readme),
        ),
        (
            "openapi.yaml",
            files.openapi,
            update_openapi(files.openapi.read_text(), target_version, files.openapi),
        ),
    ]
    # Phase 2: write only the files whose contents actually differ.
    updated_paths: list[str] = []
    for name, path, new_text in planned:
        if write_if_changed(path, new_text):
            updated_paths.append(name)
    return updated_paths
def run_check(root: Path, expected_version: str | None) -> int:
    """Verify that the tracked files agree on one version, printing a report.

    When *expected_version* is given, every file must match it; mismatches
    are listed first and cause an immediate failure. Otherwise (or after
    passing that check) all files must share a single version.

    Returns:
        0 when the metadata is in sync, 1 on any mismatch or drift.

    Raises:
        ValueError: propagated when a file's version cannot be read.
    """
    versions = collect_versions(get_version_files(root))

    if expected_version is not None:
        mismatches = [
            (name, found) for name, found in versions.items() if found != expected_version
        ]
        if mismatches:
            print("Version mismatch against expected release version:")
            for name, found in mismatches:
                print(f" {name}: {found} (expected {expected_version})")
            return 1

    distinct = set(versions.values())
    if len(distinct) != 1:
        print("Version drift detected:")
        for name, found in versions.items():
            print(f" {name}: {found}")
        return 1

    (only_version,) = distinct
    print(f"Version metadata is in sync at {only_version}")
    return 0
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the version sync tool.

    Args:
        argv: Arguments to parse, excluding the program name. When None
            (the default), ``sys.argv[1:]`` is used, exactly as before.
            Passing an explicit list allows programmatic use and testing.

    Returns:
        Namespace with ``version`` (str or None), ``check`` (bool) and
        ``root`` (Path; hidden option defaulting to the parent of this
        script's directory).
    """
    parser = argparse.ArgumentParser(
        description="Sync the Athena release version across Cargo.toml, README.md, and openapi.yaml."
    )
    parser.add_argument(
        "version",
        nargs="?",
        help="Target version to write. Accepts 1.2.3 or v1.2.3. Defaults to the version already in Cargo.toml.",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Validate that all tracked files are in sync instead of writing changes.",
    )
    parser.add_argument(
        "--root",
        default=Path(__file__).resolve().parents[1],
        type=Path,
        help=argparse.SUPPRESS,
    )
    return parser.parse_args(argv)
def main() -> int:
    """Command-line entry point; returns the process exit status.

    With ``--check``, verifies the tracked files agree (and match the given
    version, if any). Otherwise writes the given version into every tracked
    file; when no version is supplied, the one already in Cargo.toml is used.

    Invalid versions, files missing the expected version line, and file
    system errors (e.g. a missing README.md) are reported on stderr with
    exit status 1 instead of a traceback.
    """
    args = parse_args()
    root = args.root.resolve()
    try:
        target_version = normalize_version(args.version) if args.version else None
        if args.check:
            return run_check(root, target_version)
        if target_version is None:
            target_version = read_cargo_version(get_version_files(root).cargo_toml)
        updated_paths = sync_versions(root, target_version)
    except (OSError, ValueError) as error:
        # OSError covers missing/unreadable/unwritable tracked files.
        print(error, file=sys.stderr)
        return 1
    if updated_paths:
        print(f"Updated version metadata to {target_version}: {', '.join(updated_paths)}")
    else:
        print(f"Version metadata already set to {target_version}")
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    sys.exit(main())