megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
#! /usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import argparse
import logging
import re
from collections import namedtuple

import numpy as np
from tqdm import tqdm

import megengine as mge
from megengine.core.tensor.dtype import is_quantize
from megengine.logger import _imperative_rt_logger, get_logger, set_mgb_log_level
from megengine.utils.module_stats import (
    enable_receptive_field,
    get_activation_stats,
    get_op_stats,
    get_param_stats,
    print_activations_stats,
    print_op_stats,
    print_param_stats,
    print_summary,
    sizeof_fmt,
    sum_activations_stats,
    sum_op_stats,
    sum_param_stats,
)
from megengine.utils.network import Network

logger = get_logger(__name__)


def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    Args:
      model_path: dir path for megengine dumped model.
      log_path: dir path for tensorboard graph log.
      input: user defined input data for running model and calculating stats,
        alternative with inp_dict, used when the model has only one input.
      inp_dict: input dict for running model and calculating stats, alternative with
        input, used when the model has more than one input.
        When both input and inp_dict are None, a random input will be used.
      cal_params: whether calculate and record params size.
      cal_flops: whether calculate and record op flops.
      cal_activations: whether calculate and record op activations.
      logging_to_stdout: whether print all calculated statistic details.
      bar_length_max: size of bar indicating max flops or parameter size in net stats.
      model_path: str:
      log_path: str:
      input: np.ndarray:
      inp_dict: dict:
      cal_params: bool:
      cal_flops: bool:
      cal_activations: bool:
      logging_to_stdout: bool:
      bar_length_max: int:
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, "The graph needs more than one input."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # nodes that start with point or contain float const will lead to display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple(
        "total_stats", ["param_size", "param_dims", "flops", "act_size", "act_dims"]
    )
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                if log_path and hasattr(flops_stats, "flops_num"):
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup, has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node_oup)
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )
    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    if cal_flops and cal_params:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size,
            param_dims=total_param_dims,
            flops=total_flops,
            act_size=total_act_size,
            act_dims=total_act_dims,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )


def main():
    parser = argparse.ArgumentParser(
        description="load a megengine dumped model and export log file for tensorboard visualization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("model_path", help="dumped model path.")
    parser.add_argument("--log_path", help="tensorboard log path.")
    parser.add_argument(
        "--load_input_data",
        help="load input data from pickle file; it should be a numpy array or a dict of numpy array",
    )
    parser.add_argument(
        "--bar_length_max",
        type=int,
        default=20,
        help="size of bar indicating max flops or parameter size in net stats.",
    )
    parser.add_argument(
        "--cal_params",
        action="store_true",
        help="whether calculate and record params size.",
    )
    parser.add_argument(
        "--cal_flops",
        action="store_true",
        help="whether calculate and record op flops.",
    )
    parser.add_argument(
        "--cal_activations",
        action="store_true",
        help="whether calculate and record op activations.",
    )
    parser.add_argument(
        "--logging_to_stdout",
        action="store_true",
        help="whether print all calculated statistic details.",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="whether print all stats. Tensorboard logs will be placed in './log' if not specified.",
    )
    args = parser.parse_args()
    if args.load_input_data:
        logger.info("load data from {}".format(args.load_input_data))
        data = mge.load(args.load_input_data)
        if isinstance(data, dict):
            for v in data.values():
                assert isinstance(
                    v, np.ndarray
                ), "data should provide ndarray; got {} instead".format(v)
            args.inp_dict = data
        elif isinstance(data, np.ndarray):
            args.input = data
        else:
            logger.error("input data should be a numpy array or a dict of numpy array")
    if args.all:
        args.cal_params = True
        args.cal_flops = True
        args.cal_activations = True
        args.logging_to_stdout = True
        if not args.log_path:
            args.log_path = "./log"
    kwargs = vars(args)
    kwargs.pop("all")
    kwargs.pop("load_input_data")
    visualize(**kwargs)


if __name__ == "__main__":
    main()