athena_rs 2.9.1

Database gateway API
Documentation
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import re
import sys
from dataclasses import dataclass
from pathlib import Path

# Accepts an optional leading "v" and captures the bare version in the
# "version" group. The inner pattern follows the Semantic Versioning 2.0.0
# grammar:
#   - numeric identifiers have no leading zeros;
#   - dot-separated pre-release/build identifiers are never empty.
# re.ASCII stops \d from matching non-ASCII digits that Cargo would reject.
SEMVER_RE = re.compile(
    r"""
    ^v?
    (?P<version>
        (?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)
        (?:-
            (?:0|[1-9]\d*|\d*[A-Za-z-][0-9A-Za-z-]*)
            (?:\.(?:0|[1-9]\d*|\d*[A-Za-z-][0-9A-Za-z-]*))*
        )?
        (?:\+[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*)?
    )
    $
    """,
    re.VERBOSE | re.ASCII,
)


@dataclass(frozen=True)
class VersionFiles:
    """Paths of the files whose release version this script keeps in sync."""

    # Cargo.toml; its [package] version is the default target when no version is given.
    cargo_toml: Path
    # README.md; carries the "# Athena RS <version>" heading.
    readme: Path
    # openapi.yaml; carries the "  version: <version>" line.
    openapi: Path


def normalize_version(raw: str) -> str:
    """Validate *raw* against SEMVER_RE and return the bare version string.

    Surrounding whitespace is ignored and an optional leading "v" is dropped,
    so both "1.2.3" and " v1.2.3 " yield "1.2.3".

    Raises:
        ValueError: if the trimmed input is not accepted by SEMVER_RE.
    """
    candidate = raw.strip()
    if (parsed := SEMVER_RE.fullmatch(candidate)) is None:
        raise ValueError(
            f"Invalid version '{raw}'. Expected semver like 1.2.3 or v1.2.3."
        )
    return parsed["version"]


def replace_first_line(lines: list[str], predicate, replacement: str, path: Path) -> list[str]:
    """Overwrite the first entry of *lines* accepted by *predicate*.

    The list is modified in place and the same list object is returned.

    Raises:
        ValueError: if no entry satisfies *predicate*; *path* is named in the
            message so the user knows which file lacked the expected line.
    """
    position = next(
        (offset for offset, candidate in enumerate(lines) if predicate(candidate)),
        None,
    )
    if position is None:
        raise ValueError(f"Could not find expected version line in {path}")
    lines[position] = replacement
    return lines


def update_cargo_toml(text: str, version: str, path: Path) -> str:
    """Return Cargo.toml *text* with the ``[package]`` version set to *version*.

    Only the first ``version = ...`` line inside the ``[package]`` table is
    rewritten; scanning stops at the next table header. The line's leading
    indentation is kept, as is any trailing whitespace/``# comment`` that
    follows a double-quoted value. The lines are re-joined with newlines and
    the result always ends with a newline.

    Raises:
        ValueError: if no ``version = `` line is found inside ``[package]``
            (including when the table is absent).
    """
    lines = text.splitlines()
    in_package = False

    for index, line in enumerate(lines):
        stripped = line.strip()
        if stripped == "[package]":
            in_package = True
            continue
        if in_package and stripped.startswith("[") and stripped.endswith("]"):
            # Reached the next table without finding a version key.
            break
        if in_package and stripped.startswith("version = "):
            indent = line[: len(line) - len(line.lstrip())]
            # Swap only the quoted value so an inline comment survives. If the
            # value is not a plain "..." string, the line is normalized.
            match = re.match(r'^\s*version = "[^"]*"(?P<tail>\s*(?:#.*)?)$', line)
            tail = match.group("tail") if match else ""
            lines[index] = f'{indent}version = "{version}"{tail}'
            return "\n".join(lines) + "\n"

    raise ValueError(f"Could not find package version in {path}")


def update_readme(text: str, version: str, path: Path) -> str:
    """Return README *text* with its first "# Athena RS ..." heading set to *version*.

    The lines are re-joined with newlines and the result is terminated with a
    newline. Raises ValueError (via replace_first_line) when no such heading
    exists.
    """
    heading = "# Athena RS "

    def is_heading(candidate: str) -> bool:
        return candidate.startswith(heading)

    rewritten = replace_first_line(text.splitlines(), is_heading, f"{heading}{version}", path)
    return "\n".join(rewritten) + "\n"


def update_openapi(text: str, version: str, path: Path) -> str:
    """Return OpenAPI *text* with the first two-space-indented ``version:`` line set to *version*.

    The lines are re-joined with newlines and the result is terminated with a
    newline. Raises ValueError (via replace_first_line) when no such line exists.
    """
    key_prefix = "  version: "
    spec_lines = text.splitlines()
    replace_first_line(
        spec_lines,
        lambda candidate: candidate.startswith(key_prefix),
        f"{key_prefix}{version}",
        path,
    )
    return "\n".join(spec_lines) + "\n"


def read_cargo_version(path: Path) -> str:
    """Return the quoted ``version`` value from the ``[package]`` table of *path*.

    The file is decoded as UTF-8 (Cargo manifests are UTF-8) rather than with
    the platform's locale encoding. A trailing ``# comment`` after the value is
    allowed. Scanning stops at the first table header after ``[package]``.

    Raises:
        ValueError: if no ``version = "..."`` line is found inside ``[package]``.
    """
    text = path.read_text(encoding="utf-8")
    in_package = False

    for line in text.splitlines():
        stripped = line.strip()
        if stripped == "[package]":
            in_package = True
            continue
        if in_package and stripped.startswith("[") and stripped.endswith("]"):
            break
        if in_package and stripped.startswith("version = "):
            # Accept an optional inline comment so annotated version lines are read
            # instead of being skipped.
            match = re.match(r'^version = "([^"]+)"\s*(?:#.*)?$', stripped)
            if match:
                return match.group(1)

    raise ValueError(f"Could not find package version in {path}")


def read_readme_version(path: Path) -> str:
    """Return the version from the first "# Athena RS <version>" heading in *path*.

    Looks for the same heading that update_readme rewrites: the first line
    starting with "# Athena RS ". That line need not be the first line of the
    file. Surrounding whitespace is trimmed from the version. The file is
    decoded as UTF-8.

    Raises:
        ValueError: if the file is empty, has no such heading, or the heading
            carries no version text.
    """
    heading_prefix = "# Athena RS "
    for line in path.read_text(encoding="utf-8").splitlines():
        if line.startswith(heading_prefix):
            version = line.removeprefix(heading_prefix).strip()
            if version:
                return version
            # Only the first heading counts, mirroring update_readme.
            break
    raise ValueError(f"Could not find README version heading in {path}")


def read_openapi_version(path: Path) -> str:
    """Return the value of the first line starting with "  version: " in *path*.

    Surrounding whitespace is trimmed. A value wrapped in matching single or
    double quotes is unquoted, because YAML treats ``"2.9.1"`` and ``2.9.1``
    as the same string. The file is decoded as UTF-8.

    Raises:
        ValueError: if no such line exists.
    """
    for line in path.read_text(encoding="utf-8").splitlines():
        if line.startswith("  version: "):
            value = line.removeprefix("  version: ").strip()
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            return value
    raise ValueError(f"Could not find OpenAPI version in {path}")


def get_version_files(root: Path) -> VersionFiles:
    """Build the VersionFiles bundle for files located directly under *root*."""
    cargo, readme, spec = (
        root / name for name in ("Cargo.toml", "README.md", "openapi.yaml")
    )
    return VersionFiles(cargo_toml=cargo, readme=readme, openapi=spec)


def collect_versions(files: VersionFiles) -> dict[str, str]:
    """Read the version recorded in each tracked file, keyed by file name.

    Files are read in the order Cargo.toml, README.md, openapi.yaml. Any error
    raised by a reader propagates immediately.
    """
    readers = (
        ("Cargo.toml", read_cargo_version, files.cargo_toml),
        ("README.md", read_readme_version, files.readme),
        ("openapi.yaml", read_openapi_version, files.openapi),
    )
    return {label: reader(location) for label, reader, location in readers}


def write_if_changed(path: Path, updated_text: str) -> bool:
    """Write *updated_text* to *path* only if it differs from the current contents.

    Both the read and the write use UTF-8 explicitly, so the result does not
    depend on the platform's locale encoding.

    Returns:
        True if the file was rewritten, False if it already matched.
    """
    if path.read_text(encoding="utf-8") == updated_text:
        return False
    path.write_text(updated_text, encoding="utf-8")
    return True


def sync_versions(root: Path, target_version: str) -> list[str]:
    """Set *target_version* in every tracked file under *root*.

    All three files are rendered before any of them is written. If one file
    lacks its version line, the updater's ValueError is raised and nothing on
    disk changes, instead of Cargo.toml being bumped while the others keep the
    old value. The writes themselves are not atomic.

    Returns:
        Names of the files that were actually modified, in the order
        Cargo.toml, README.md, openapi.yaml.
    """
    files = get_version_files(root)
    targets = (
        ("Cargo.toml", files.cargo_toml, update_cargo_toml),
        ("README.md", files.readme, update_readme),
        ("openapi.yaml", files.openapi, update_openapi),
    )

    rendered = [
        (label, path, updater(path.read_text(encoding="utf-8"), target_version, path))
        for label, path, updater in targets
    ]

    updated_paths: list[str] = []
    for label, path, updated_text in rendered:
        if write_if_changed(path, updated_text):
            updated_paths.append(label)
    return updated_paths


def run_check(root: Path, expected_version: str | None) -> int:
    """Verify the tracked files agree on one version and print a report.

    When *expected_version* is given, every file must carry exactly that
    version. Otherwise the files only need to agree with each other.
    Reader errors propagate to the caller.

    Returns:
        0 when everything is in sync, 1 on any mismatch or drift.
    """
    versions = collect_versions(get_version_files(root))

    if expected_version is not None:
        wrong = [
            (name, found) for name, found in versions.items() if found != expected_version
        ]
        if wrong:
            print("Version mismatch against expected release version:")
            for name, found in wrong:
                print(f"  {name}: {found} (expected {expected_version})")
            return 1

    distinct = set(versions.values())
    if len(distinct) != 1:
        print("Version drift detected:")
        for name, found in versions.items():
            print(f"  {name}: {found}")
        return 1

    (only_version,) = distinct
    print(f"Version metadata is in sync at {only_version}")
    return 0


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the version-sync tool.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        Namespace with ``version`` (str or None), ``check`` (bool) and
        ``root`` (Path; defaults to the parent of this script's directory).
    """
    parser = argparse.ArgumentParser(
        description="Sync the Athena release version across Cargo.toml, README.md, and openapi.yaml."
    )
    parser.add_argument(
        "version",
        nargs="?",
        help="Target version to write. Accepts 1.2.3 or v1.2.3. Defaults to the version already in Cargo.toml.",
    )
    parser.add_argument(
        "--check",
        action="store_true",
        help="Validate that all tracked files are in sync instead of writing changes.",
    )
    parser.add_argument(
        "--root",
        default=Path(__file__).resolve().parents[1],
        type=Path,
        help=argparse.SUPPRESS,
    )
    return parser.parse_args(argv)


def main() -> int:
    """Run the CLI: check or sync the tracked version metadata.

    Without ``--check``, the target version defaults to the one in Cargo.toml
    and is written to every tracked file.

    Returns:
        Process exit code. It is 0 on success. It is 1 when the check fails,
        the version is invalid, a version line is missing, or a tracked file
        cannot be read or written. Error text goes to stderr.
    """
    args = parse_args()
    root = args.root.resolve()

    try:
        target_version = normalize_version(args.version) if args.version else None
        if args.check:
            return run_check(root, target_version)

        if target_version is None:
            target_version = read_cargo_version(get_version_files(root).cargo_toml)
        updated_paths = sync_versions(root, target_version)
    except (OSError, ValueError) as error:
        # ValueError covers bad versions and missing version lines, plus
        # UnicodeDecodeError, which subclasses it. OSError covers missing or
        # unreadable files, which previously escaped as a raw traceback.
        print(error, file=sys.stderr)
        return 1

    if updated_paths:
        print(f"Updated version metadata to {target_version}: {', '.join(updated_paths)}")
    else:
        print(f"Version metadata already set to {target_version}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())