p2o 0.1.1

A PaddlePaddle New IR (PIR) to ONNX model converter.
Documentation
from __future__ import annotations

import json
from pathlib import Path
from typing import Any

from .manifest import ModelSpec


def load_model_ir(model_json: Path) -> dict[str, Any]:
    """Read a serialized model IR file and return the parsed JSON document.

    Args:
        model_json: Path to the JSON file to load.

    Returns:
        The decoded JSON content.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8: without an encoding, read_text() falls back
    # to the platform locale, which breaks on non-ASCII content on systems
    # whose default is not UTF-8.
    return json.loads(model_json.read_text(encoding="utf-8"))


def get_model_ops(data: dict[str, Any]) -> list[dict[str, Any]]:
    """Return every op reachable from the first top-level region of the program.

    Ops are listed in pre-order: an op appears before the ops found in its
    own nested regions. Entries that are not dicts, or that have no ``"#"``
    key, are skipped together with anything nested under them.

    Args:
        data: Parsed model IR; ``data["program"]["regions"][0]`` is the
            region where the traversal starts.

    Returns:
        The collected op dicts (the same objects held by *data*, not copies).
    """

    def iter_ops(region: dict[str, Any]):
        # Depth-first walk that yields each op before descending into it.
        for block in region.get("blocks", []):
            for candidate in block.get("ops", []):
                if isinstance(candidate, dict) and "#" in candidate:
                    yield candidate
                    for nested in candidate.get("regions", []):
                        if isinstance(nested, dict):
                            yield from iter_ops(nested)

    root_region = data["program"]["regions"][0]
    return list(iter_ops(root_region))


def extract_struct_name(op: dict[str, Any]) -> str | None:
    """Look up the ``struct_name`` attribute of an op.

    The op's ``"A"`` entry is scanned in order for a dict whose ``"N"`` is
    ``"struct_name"``; the value is read from that dict's ``"AT"`` -> ``"D"``.
    Entries whose value is missing, empty, or not a string are passed over
    and the scan continues with the next entry.

    Args:
        op: A single op dict.

    Returns:
        The first non-empty string value found, or ``None`` when ``"A"`` is
        not a list or no usable ``struct_name`` entry exists.
    """
    attributes = op.get("A")
    if not isinstance(attributes, list):
        return None

    for entry in attributes:
        if not isinstance(entry, dict):
            continue
        if not entry.get("N") == "struct_name":
            continue
        payload = entry.get("AT")
        if not isinstance(payload, dict):
            continue
        name = payload.get("D")
        if isinstance(name, str) and name:
            return name
    return None


def collect_model_inventory(model_spec: ModelSpec, subgraph_targets: list[dict[str, str]]) -> dict[str, Any]:
    """Summarise the op types and subgraph matches found in one model's IR.

    Loads ``inference.json`` from ``model_spec.model_dir``, collects its ops,
    and reports which op types and struct names occur and which of the given
    subgraph targets match at least one op.

    Args:
        model_spec: The model to inspect; its ``model_dir`` and ``name``
            attributes are read.
        subgraph_targets: Targets with ``"id"``, ``"label"`` and ``"match"``
            keys. A target matches an op when ``target["match"]`` is a
            substring of the op's struct name (ops without a struct name are
            compared against the empty string).

    Returns:
        A dict with the model name and directory, the sorted distinct op
        types (``"operator_targets"``), the sorted distinct struct names, and
        one entry per matching subgraph target (``"subgraph_targets"``) with
        the matched op types and op count. Targets without any matching op
        are left out.
    """
    data = load_model_ir(model_spec.model_dir / "inference.json")
    ops = get_model_ops(data)

    # Resolve every op's struct name once up front instead of re-extracting
    # it for each (target, op) pair inside the loop below.
    named_ops = [(op, extract_struct_name(op) or "") for op in ops]

    subgraphs: list[dict[str, Any]] = []
    for target in subgraph_targets:
        matched_ops = [op for op, struct_name in named_ops if target["match"] in struct_name]
        if not matched_ops:
            continue
        subgraphs.append(
            {
                "id": target["id"],
                "label": target["label"],
                "match": target["match"],
                "op_types": sorted({op["#"] for op in matched_ops}),
                "op_count": len(matched_ops),
            }
        )

    return {
        "model": model_spec.name,
        "model_dir": str(model_spec.model_dir),
        "operator_targets": sorted({op["#"] for op in ops}),
        "struct_names": sorted({struct_name for _, struct_name in named_ops if struct_name}),
        "subgraph_targets": subgraphs,
    }


def build_inventory(
    model_specs: dict[str, ModelSpec],
    model_names: list[str],
    operator_targets: dict[str, dict[str, str]],
    subgraph_targets: list[dict[str, str]],
) -> dict[str, Any]:
    """Build a coverage inventory of operator and subgraph targets across models.

    Args:
        model_specs: Available model specs keyed by model name.
        model_names: Names of the models to inspect; each must be a key of
            *model_specs*.
        operator_targets: Operator targets keyed by op name; each value
            provides ``"kind"``, ``"impl"`` and ``"note"``.
        subgraph_targets: Subgraph targets with ``"id"``, ``"label"`` and
            ``"match"`` keys.

    Returns:
        A dict with the per-model inventories (``"models"``), one entry per
        operator target and per subgraph target listing the models that cover
        it, and a ``"summary"`` with the implemented/covered counts.

    Raises:
        KeyError: If a name in *model_names* is missing from *model_specs*,
            or a target lacks one of the keys listed above.
    """
    per_model = [collect_model_inventory(model_specs[name], subgraph_targets) for name in model_names]

    operator_target_items = []
    for op_name in sorted(operator_targets):
        meta = operator_targets[op_name]
        used_by = [model["model"] for model in per_model if op_name in model["operator_targets"]]
        operator_target_items.append(
            {
                "op_name": op_name,
                "kind": meta["kind"],
                "implementation": meta["impl"],
                "note": meta["note"],
                "covered_by_models": used_by,
                "currently_covered": bool(used_by),
            }
        )

    subgraph_target_items = []
    for target in subgraph_targets:
        matched_models = [
            {
                "model": model["model"],
                "op_count": subgraph["op_count"],
                "op_types": subgraph["op_types"],
            }
            for model in per_model
            for subgraph in model["subgraph_targets"]
            if subgraph["id"] == target["id"]
        ]
        subgraph_target_items.append(
            {
                "id": target["id"],
                "label": target["label"],
                "match": target["match"],
                "covered_by_models": matched_models,
                "currently_covered": bool(matched_models),
            }
        )

    return {
        "models": per_model,
        "operator_targets": operator_target_items,
        "subgraph_targets": subgraph_target_items,
        "summary": {
            "selected_models": model_names,
            "implemented_operator_target_count": len(operator_targets),
            # Count only declared operator targets that some model uses.
            # Counting every distinct op type seen in the models would include
            # ops that are not targets at all and could exceed the implemented
            # count; this mirrors how the subgraph count is derived.
            "covered_operator_target_count": sum(1 for item in operator_target_items if item["currently_covered"]),
            "implemented_subgraph_target_count": len(subgraph_targets),
            "covered_subgraph_target_count": sum(1 for item in subgraph_target_items if item["currently_covered"]),
        },
    }