p2o 0.1.1

A PaddlePaddle New IR (PIR) to ONNX model converter.
Documentation
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any


# Path to the target catalog read by load_target_catalog(): the file
# "regression_targets.json" located in the parent of the directory that
# contains this module (computed from the fully resolved module path).
TARGETS_CONFIG = Path(__file__).resolve().parents[1] / "regression_targets.json"


@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one model entry read from a manifest.

    Instances are built by ``load_manifest`` from the objects in the
    manifest's ``models`` list. Only ``name``, ``model_dir`` and
    ``default_opsets`` are required; every other field defaults to ``None``
    when the manifest entry omits it.
    """

    # Non-empty identifier of the entry; load_manifest also uses it as the
    # key of the returned mapping.
    name: str
    # Model directory. load_manifest resolves relative manifest values
    # against the directory containing the manifest file.
    model_dir: Path
    # Non-empty tuple of ints; load_manifest falls back to (17,) when the
    # entry has no "default_opsets". Presumably ONNX opset versions -- confirm
    # against the consumer of this field.
    default_opsets: tuple[int, ...]
    # Optional free-form string from the entry's "input_kind"; the set of
    # accepted values is not defined in this module.
    input_kind: str | None = None
    # Optional mapping from the entry's "sample_inputs"; only checked to be a
    # dict, its inner schema is not validated here.
    sample_inputs: dict[str, dict[str, Any]] | None = None
    # Optional values taken from "diff_tolerance.rtol" / "diff_tolerance.atol".
    # NOTE(review): presumably relative/absolute tolerances for numeric output
    # comparison -- confirm against the code that consumes them.
    diff_rtol: float | None = None
    diff_atol: float | None = None
    # Optional mapping from the entry's "compare_config"; only checked to be a
    # dict, its keys are not interpreted here.
    compare_config: dict[str, Any] | None = None
    # Optional path from the entry's "semantic_baseline"; relative values are
    # resolved against the manifest directory, like model_dir.
    semantic_baseline: Path | None = None


def load_target_catalog() -> tuple[dict[str, dict[str, str]], list[dict[str, str]]]:
    """Read the target catalog stored at ``TARGETS_CONFIG``.

    Returns:
        A ``(operator_targets, subgraph_targets)`` pair: the non-empty object
        stored under ``"operator_targets"`` and the non-empty list stored
        under ``"subgraph_targets"``. Only the container types are checked;
        the shape of the nested entries is not validated here.

    Raises:
        ValueError: If the top-level JSON value is not an object, or either
            section is missing, empty, or of the wrong container type.
        OSError: If the catalog file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON interchange is UTF-8; decode explicitly instead of relying on the
    # platform's locale encoding (which is not UTF-8 on every system).
    payload = json.loads(TARGETS_CONFIG.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        # Without this guard a top-level array/scalar would surface as an
        # AttributeError on ``.get`` rather than a catalog validation error.
        raise ValueError(f"Target catalog {TARGETS_CONFIG} must be a JSON object")

    operator_targets = payload.get("operator_targets")
    if not isinstance(operator_targets, dict) or not operator_targets:
        raise ValueError(f"Target catalog {TARGETS_CONFIG} must contain a non-empty 'operator_targets' object")

    subgraph_targets = payload.get("subgraph_targets")
    if not isinstance(subgraph_targets, list) or not subgraph_targets:
        raise ValueError(f"Target catalog {TARGETS_CONFIG} must contain a non-empty 'subgraph_targets' list")

    return operator_targets, subgraph_targets


def _resolve_manifest_path(raw: str, base_dir: Path) -> Path:
    """Convert a manifest path string to a Path, anchoring relative values at *base_dir*."""
    path = Path(raw)
    if not path.is_absolute():
        path = (base_dir / path).resolve()
    return path


def _is_plain_int(value: object) -> bool:
    """Return True for integers, excluding bool (which subclasses int)."""
    return isinstance(value, int) and not isinstance(value, bool)


def _parse_tolerance(model_name: str, key: str, raw: object) -> float | None:
    """Validate one optional ``diff_tolerance`` entry and return it as a float."""
    if raw is None:
        return None
    # JSON true/false would otherwise pass the numeric check and silently
    # become 1.0 / 0.0.
    if isinstance(raw, bool) or not isinstance(raw, (int, float)):
        raise ValueError(f"Model '{model_name}' has invalid diff_tolerance.{key}")
    return float(raw)


def load_manifest(manifest_path: Path) -> dict[str, ModelSpec]:
    """Parse a model manifest JSON file into ``ModelSpec`` objects.

    The manifest must be a JSON object with a non-empty ``"models"`` list.
    Each entry requires non-empty string ``"name"`` and ``"model_dir"`` values
    and may define ``"default_opsets"`` (defaults to ``[17]``),
    ``"input_kind"``, ``"sample_inputs"``, ``"compare_config"``,
    ``"semantic_baseline"`` and ``"diff_tolerance"`` (with optional ``rtol`` /
    ``atol`` numbers). Relative ``model_dir`` and ``semantic_baseline`` paths
    are resolved against the manifest's directory.

    Args:
        manifest_path: Location of the manifest file.

    Returns:
        Mapping from model name to its spec, in manifest order. If a name
        occurs more than once, the last entry wins.

    Raises:
        ValueError: If the manifest or any entry fails validation.
        OSError: If the manifest file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8 rather than with the locale default encoding.
    payload = json.loads(manifest_path.read_text(encoding="utf-8"))
    if not isinstance(payload, dict):
        # A top-level array/scalar has no ``.get``; report it as a manifest
        # validation error instead of an AttributeError.
        raise ValueError(f"Manifest {manifest_path} must be a JSON object with a non-empty 'models' list")

    model_items = payload.get("models")
    if not isinstance(model_items, list) or not model_items:
        raise ValueError(f"Manifest {manifest_path} must contain a non-empty 'models' list")

    base_dir = manifest_path.parent
    specs: dict[str, ModelSpec] = {}
    for item in model_items:
        if not isinstance(item, dict):
            raise ValueError(f"Manifest {manifest_path} contains a non-object model entry")

        name = item.get("name")
        model_dir_raw = item.get("model_dir")
        if not isinstance(name, str) or not name:
            raise ValueError("Each model entry must define a non-empty string 'name'")
        if not isinstance(model_dir_raw, str) or not model_dir_raw:
            raise ValueError(f"Model '{name}' must define a non-empty string 'model_dir'")
        model_dir = _resolve_manifest_path(model_dir_raw, base_dir)

        default_opsets_raw = item.get("default_opsets", [17])
        if (
            not isinstance(default_opsets_raw, list)
            or not default_opsets_raw
            or not all(_is_plain_int(value) for value in default_opsets_raw)
        ):
            raise ValueError(f"Model '{name}' must define 'default_opsets' as a non-empty int list")

        input_kind = item.get("input_kind")
        if input_kind is not None and not isinstance(input_kind, str):
            raise ValueError(f"Model '{name}' has invalid 'input_kind'")

        sample_inputs = item.get("sample_inputs")
        if sample_inputs is not None and not isinstance(sample_inputs, dict):
            raise ValueError(f"Model '{name}' has invalid 'sample_inputs'")

        compare_config = item.get("compare_config")
        if compare_config is not None and not isinstance(compare_config, dict):
            raise ValueError(f"Model '{name}' has invalid 'compare_config'")

        semantic_baseline_raw = item.get("semantic_baseline")
        semantic_baseline: Path | None = None
        if semantic_baseline_raw is not None:
            if not isinstance(semantic_baseline_raw, str) or not semantic_baseline_raw:
                raise ValueError(f"Model '{name}' has invalid 'semantic_baseline'")
            semantic_baseline = _resolve_manifest_path(semantic_baseline_raw, base_dir)

        diff_tolerance = item.get("diff_tolerance")
        diff_rtol: float | None = None
        diff_atol: float | None = None
        if diff_tolerance is not None:
            if not isinstance(diff_tolerance, dict):
                raise ValueError(f"Model '{name}' has invalid 'diff_tolerance'")
            diff_rtol = _parse_tolerance(name, "rtol", diff_tolerance.get("rtol"))
            diff_atol = _parse_tolerance(name, "atol", diff_tolerance.get("atol"))

        specs[name] = ModelSpec(
            name=name,
            model_dir=model_dir,
            default_opsets=tuple(default_opsets_raw),
            input_kind=input_kind,
            sample_inputs=sample_inputs,
            diff_rtol=diff_rtol,
            diff_atol=diff_atol,
            compare_config=compare_config,
            semantic_baseline=semantic_baseline,
        )
    return specs


def select_model_names(model_specs: dict[str, ModelSpec], requested: list[str] | None) -> list[str]:
    """Pick the model names to operate on.

    Args:
        model_specs: Known models keyed by name.
        requested: Names asked for by the caller; ``None`` or an empty list
            selects every known model.

    Returns:
        A new list of all keys of ``model_specs`` (in mapping order) when
        nothing was requested; otherwise ``requested`` itself, unchanged.

    Raises:
        ValueError: If any requested name is not a key of ``model_specs``.
    """
    if requested:
        missing = [candidate for candidate in requested if candidate not in model_specs]
        if not missing:
            return requested
        missing_text = ", ".join(missing)
        known_text = ", ".join(sorted(model_specs))
        raise ValueError(f"Unknown models: {missing_text}. Available: {known_text}")
    # Nothing requested: fall back to every known model.
    return list(model_specs)