p2o 0.1.1

A PaddlePaddle New IR (PIR) to ONNX model converter.
Documentation
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import json
import shutil
import tarfile
import tempfile
import urllib.request
from pathlib import Path


# Archives are fetched from "<base_url>/<model_dir_name>.tar"; this is the
# default <base_url> when --base-url is not given on the command line.
DEFAULT_BASE_URL = (
    "https://paddle-model-ecology.bj.bcebos.com"
    "/paddlex/official_inference_model/paddle3.0.0"
)
# A model directory is considered complete only if all of these files exist in it.
REQUIRED_FILES = (
    "inference.json",
    "inference.pdiparams",
)


def load_manifest(manifest_path: Path) -> list[tuple[str, Path]]:
    """Parse a regression-model manifest into ``(name, model_dir)`` pairs.

    The manifest must be a JSON object with a non-empty ``"models"`` list.
    Each entry must be an object with non-empty string ``"name"`` and
    ``"model_dir"`` fields. A relative ``model_dir`` is resolved against the
    directory that contains the manifest; absolute ones are used as-is.

    Args:
        manifest_path: Path to the JSON manifest file.

    Returns:
        The models in manifest order.

    Raises:
        ValueError: If the manifest is not valid JSON (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or does not have the structure
            described above.
        OSError: If the manifest cannot be read.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    payload = json.loads(manifest_path.read_text(encoding="utf-8"))
    # Validate the top-level type before calling .get(): a JSON array or
    # scalar would otherwise surface as an opaque AttributeError.
    if not isinstance(payload, dict):
        raise ValueError(f"Manifest {manifest_path} must be a JSON object")

    model_items = payload.get("models")
    if not isinstance(model_items, list) or not model_items:
        raise ValueError(f"Manifest {manifest_path} must contain a non-empty 'models' list")

    models: list[tuple[str, Path]] = []
    for item in model_items:
        if not isinstance(item, dict):
            raise ValueError(f"Manifest {manifest_path} contains a non-object model entry")
        name = item.get("name")
        model_dir_raw = item.get("model_dir")
        if not isinstance(name, str) or not name:
            raise ValueError("Each model entry must define a non-empty string 'name'")
        if not isinstance(model_dir_raw, str) or not model_dir_raw:
            raise ValueError(f"Model '{name}' must define a non-empty string 'model_dir'")

        model_dir = Path(model_dir_raw)
        if not model_dir.is_absolute():
            # Relative paths are interpreted relative to the manifest, not the CWD.
            model_dir = (manifest_path.parent / model_dir).resolve()
        models.append((name, model_dir))
    return models


def select_models(
    models: list[tuple[str, Path]],
    requested: list[str] | None,
) -> list[tuple[str, Path]]:
    if not requested:
        return models

    available = {name for name, _ in models}
    unknown = [name for name in requested if name not in available]
    if unknown:
        raise ValueError(
            f"Unknown models: {', '.join(unknown)}. Available: {', '.join(sorted(available))}"
        )
    requested_set = set(requested)
    return [(name, model_dir) for name, model_dir in models if name in requested_set]


def has_required_files(model_dir: Path) -> bool:
    """Return True if *model_dir* is a directory containing every file in REQUIRED_FILES."""
    if not model_dir.is_dir():
        return False
    for filename in REQUIRED_FILES:
        if not (model_dir / filename).is_file():
            return False
    return True


def download_archive(url: str, archive_path: Path) -> None:
    with urllib.request.urlopen(url) as response, archive_path.open("wb") as handle:
        shutil.copyfileobj(response, handle)


def safe_extract_all(archive_path: Path, destination: Path) -> None:
    """Extract *archive_path* into *destination*, refusing anything unsafe.

    Only regular files and directories are written. Symlinks, hard links,
    device nodes, FIFOs and any other member type are rejected. Members whose
    resolved path would fall outside *destination* (absolute names or ``..``
    traversal) are also rejected. File permission bits are copied from the
    archive, masked to 0o777 so setuid/setgid/sticky bits are dropped.

    Args:
        archive_path: Tar archive to read (compression is auto-detected).
        destination: Directory to extract into; created if missing.

    Raises:
        ValueError: On an unsupported member type, a path outside
            *destination*, or a file member whose data cannot be read.
        tarfile.TarError: If the archive itself cannot be parsed.
    """
    destination = destination.resolve()
    destination.mkdir(parents=True, exist_ok=True)

    with tarfile.open(archive_path) as archive:
        for member in archive.getmembers():
            # Allow-list: only plain directories and regular files pass. One
            # check covers symlinks, hard links, devices, FIFOs and anything
            # more exotic. Links in particular could redirect later writes
            # outside the destination.
            if not (member.isdir() or member.isfile()):
                raise ValueError(
                    f"Refusing to extract unsupported tar member type: {member.name}"
                )

            # resolve() collapses ".." and honours absolute member names, so
            # an ancestry test catches path traversal. A directory entry may
            # name the destination itself (e.g. "./"); a file must sit
            # strictly below it.
            member_path = (destination / member.name).resolve()
            is_root_dir = member.isdir() and member_path == destination
            if not is_root_dir and destination not in member_path.parents:
                raise ValueError(
                    f"Refusing to extract path outside destination: {member.name}"
                )

            if member.isdir():
                member_path.mkdir(parents=True, exist_ok=True)
                continue

            member_path.parent.mkdir(parents=True, exist_ok=True)
            extracted = archive.extractfile(member)
            if extracted is None:
                raise ValueError(f"Could not read file from archive: {member.name}")
            with extracted, member_path.open("wb") as handle:
                shutil.copyfileobj(extracted, handle)
            # Keep the archive's rwx bits (e.g. executables). A zero mode is
            # skipped, leaving the permissions the file was created with.
            if member.mode:
                member_path.chmod(member.mode & 0o777)


def find_extracted_root(extract_dir: Path, expected_name: str) -> Path:
    """Locate the model directory inside a freshly extracted archive.

    Lookup order:
    1. ``extract_dir / expected_name``, if it is a directory. This is returned
       without checking its contents; the caller validates them afterwards.
    2. Otherwise, the unique descendant directory of *extract_dir* that passes
       ``has_required_files``.
    3. If no descendant matches, *extract_dir* itself when it passes
       ``has_required_files``. This supports archives that store the model
       files at their top level with no wrapping folder.

    Raises:
        FileNotFoundError: If no matching directory exists, or several match.
    """
    expected = extract_dir / expected_name
    if expected.is_dir():
        return expected

    candidates = [path for path in extract_dir.rglob("*") if path.is_dir() and has_required_files(path)]
    # rglob("*") yields only descendants, never extract_dir itself, so a flat
    # archive would otherwise never be found.
    if not candidates and has_required_files(extract_dir):
        candidates = [extract_dir]
    if len(candidates) == 1:
        return candidates[0]

    detail = f" ({len(candidates)} candidate directories match)" if candidates else ""
    raise FileNotFoundError(
        f"Could not locate extracted model directory '{expected_name}' under {extract_dir}{detail}"
    )


def remove_existing(path: Path) -> None:
    """Delete *path* if present: a file, a symlink (not its target), or a directory tree.

    A path that does not exist is left alone.
    """
    # Symlinks are tested first and unlinked directly, so a link that points
    # at a directory is never handed to rmtree.
    if not path.is_symlink() and not path.is_file():
        if path.exists():
            shutil.rmtree(path)
        return
    path.unlink()


def ensure_model(model_name: str, model_dir: Path, base_url: str, dry_run: bool) -> None:
    """Make sure *model_dir* contains a complete copy of *model_name*.

    If the required files are already present, nothing is fetched. Otherwise
    ``<base_url>/<model_dir.name>.tar`` is downloaded and extracted into a
    temporary directory next to *model_dir*. The model folder is located
    inside the extraction and moved over whatever currently sits at
    *model_dir*.

    Args:
        model_name: Manifest name; used only in progress messages.
        model_dir: Final location of the model directory.
        base_url: Archive URL prefix; a trailing "/" is tolerated.
        dry_run: If True, only print the planned download.

    Raises:
        FileNotFoundError: If the extracted model cannot be located, or is
            still missing required files after installation.
        ValueError: If the archive contains unsafe members.
        Network errors from the download propagate unchanged.
    """
    archive_name = f"{model_dir.name}.tar"
    archive_url = f"{base_url.rstrip('/')}/{archive_name}"

    if has_required_files(model_dir):
        print(f"[skip] {model_name}: {model_dir}")
        return

    if dry_run:
        print(f"[plan] {model_name}: {archive_url} -> {model_dir}")
        return

    print(f"[download] {model_name}: {archive_url}")
    model_dir.parent.mkdir(parents=True, exist_ok=True)

    # Stage the download and extraction beside the final location. The
    # temporary directory (archive included) is removed whether or not
    # extraction succeeds.
    with tempfile.TemporaryDirectory(
        prefix=f".download_{model_dir.name}_",
        dir=model_dir.parent,
    ) as temp_dir_raw:
        temp_dir = Path(temp_dir_raw)
        archive_path = temp_dir / archive_name
        extract_dir = temp_dir / "extract"
        extract_dir.mkdir()

        download_archive(archive_url, archive_path)
        safe_extract_all(archive_path, extract_dir)
        extracted_root = find_extracted_root(extract_dir, model_dir.name)

        # Clear whatever occupies the target, including a dangling symlink.
        # Path.exists() follows links and reports False for a broken one,
        # and leaving it in place would make the move below fail.
        if model_dir.exists() or model_dir.is_symlink():
            remove_existing(model_dir)
        shutil.move(str(extracted_root), str(model_dir))

    if not has_required_files(model_dir):
        raise FileNotFoundError(
            f"Downloaded model '{model_name}' is missing required files under {model_dir}"
        )
    print(f"[ready] {model_name}: {model_dir}")


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the model downloader.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous zero-argument behaviour. Passing an explicit list makes
            the parser usable from tests and other scripts.

    Returns:
        A namespace with ``manifest`` (Path), ``base_url`` (str), ``models``
        (list of names, or None if the flag was omitted) and ``dry_run`` (bool).
        Note that ``--models`` with no values yields ``[]``; ``select_models``
        treats that the same as "all models".
    """
    parser = argparse.ArgumentParser(
        description="Download and extract regression inference models referenced by a manifest."
    )
    parser.add_argument(
        "--manifest",
        type=Path,
        required=True,
        help="JSON manifest describing regression models.",
    )
    parser.add_argument(
        "--base-url",
        default=DEFAULT_BASE_URL,
        help="Base URL used to fetch <model_dir_name>.tar archives.",
    )
    parser.add_argument(
        "--models",
        nargs="*",
        help="Optional subset of model names from the manifest.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the resolved downloads without fetching archives.",
    )
    return parser.parse_args(argv)


def main() -> int:
    """Command-line entry point: ensure every selected manifest model is present.

    Returns 0 once all selected models are handled. Any failure propagates as
    an exception.
    """
    cli = parse_args()
    entries = load_manifest(cli.manifest.resolve())
    for name, target_dir in select_models(entries, cli.models):
        ensure_model(name, target_dir, cli.base_url, cli.dry_run)
    return 0


if __name__ == "__main__":
    # Pass main()'s integer result to the interpreter as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)