megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
#!/usr/bin/env python
# -*-coding=utf-8-*-

from megengine.logger import get_logger

logger = get_logger(__name__)

try:
    from tensorboardX import SummaryWriter
    from tensorboardX.proto.attr_value_pb2 import AttrValue
    from tensorboardX.proto.graph_pb2 import GraphDef
    from tensorboardX.proto.node_def_pb2 import NodeDef
    from tensorboardX.proto.plugin_text_pb2 import TextPluginData
    from tensorboardX.proto.step_stats_pb2 import (
        DeviceStepStats,
        RunMetadata,
        StepStats,
    )
    from tensorboardX.proto.summary_pb2 import Summary, SummaryMetadata
    from tensorboardX.proto.tensor_pb2 import TensorProto
    from tensorboardX.proto.tensor_shape_pb2 import TensorShapeProto
    from tensorboardX.proto.versions_pb2 import VersionDef
except ImportError:
    logger.error(
        "TensorBoard and TensorboardX are required for visualize.", exc_info=True,
    )


def tensor_shape_proto(shape):
    """Build a ``TensorShapeProto`` with one ``Dim`` per entry of ``shape``.

    The message layout follows
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto

    Args:
        shape: iterable of dimension sizes, in order.

    Returns:
        A ``TensorShapeProto`` whose ``dim`` field lists those sizes.
    """
    dims = []
    for extent in shape:
        dims.append(TensorShapeProto.Dim(size=extent))
    return TensorShapeProto(dim=dims)


def attr_value_proto(shape, dtype, attr):
    """Assemble the ``attr`` mapping of a NodeDef as a dict of ``AttrValue``.

    The messages follow
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
    and the chosen keys were reverse engineered from data logged by
    standard TensorBoard.

    Args:
        shape: output shape (iterable of sizes), stored under
            ``"_output_shapes"``; skipped when ``None``.
        dtype (string): output dtype name, stored UTF-8 encoded under
            ``"dtype"``; skipped when ``None``.
        attr (dict): extra string attributes; every value is stored UTF-8
            encoded under its own key; skipped when ``None``.

    Returns:
        dict mapping attribute name to ``AttrValue``.
    """
    result = {}

    if shape is not None:
        shape_msg = tensor_shape_proto(shape)
        shape_list = AttrValue.ListValue(shape=[shape_msg])
        result["_output_shapes"] = AttrValue(list=shape_list)

    if dtype is not None:
        result["dtype"] = AttrValue(s=dtype.encode(encoding="utf-8"))

    if attr is not None:
        for key, value in attr.items():
            result[key] = AttrValue(s=value.encode(encoding="utf-8"))

    return result


def node_proto(
    name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes=None
):
    """Creates an object matching
    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto

    Args:
        name (string): node name; stored UTF-8 encoded.
        op (string): operator class name.
        input: name(s) of the input nodes. ``None`` means no inputs, a list
            or tuple is taken as a sequence of names, and any other value is
            treated as a single input name.
        outputshape (list): output shape, or ``None`` to omit it.
        dtype (string): output dtype name, or ``None`` to omit it.
        attributes (dict): extra string attributes, or ``None`` for none.

    Returns:
        The populated ``NodeDef`` message.
    """
    # ``attributes`` defaults to None rather than ``{}`` so that no dict object
    # is shared between calls; ``attr_value_proto`` skips a None value.
    if input is None:
        inputs = []
    elif isinstance(input, (list, tuple)):
        # Copy so the message is built from a plain list even for tuples.
        inputs = list(input)
    else:
        inputs = [input]
    return NodeDef(
        name=name.encode(encoding="utf_8"),
        op=op,
        input=inputs,
        attr=attr_value_proto(outputshape, dtype, attributes),
    )


def node(
    name, op="UnSpecified", input=None, outputshape=None, dtype=None, attributes=None
):
    """Build a graph node message; thin alias that forwards to ``node_proto``.

    Args:
        name (string): node name.
        op (string): operator class name.
        input: name(s) of the input nodes, or ``None`` for no inputs.
        outputshape (list): output shape, or ``None`` to omit it.
        dtype (string): output dtype name, or ``None`` to omit it.
        attributes (dict): extra string attributes, or ``None`` for none.

    Returns:
        Whatever ``node_proto`` returns for the same arguments.
    """
    # None instead of a ``{}`` default avoids sharing one dict across calls.
    return node_proto(name, op, input, outputshape, dtype, attributes)


def graph(node_list):
    """Wrap node messages into a graph definition plus run metadata.

    Args:
        node_list: node messages to place in the ``GraphDef``.

    Returns:
        Tuple ``(graph_def, run_metadata)``: a ``GraphDef`` with producer
        version 22, and a ``RunMetadata`` whose step stats contain a single
        ``"/device:CPU:0"`` device entry.
    """
    cpu_stats = DeviceStepStats(device="/device:CPU:0")
    run_metadata = RunMetadata(step_stats=StepStats(dev_stats=[cpu_stats]))
    graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))
    return graph_def, run_metadata


def text(tag, text):
    """Build a ``Summary`` holding a 1-D string tensor for the text plugin.

    Args:
        tag (string): data identifier attached to the summary value.
        text: sized iterable of strings; each item is UTF-8 encoded and
            becomes one element of the tensor.

    Returns:
        A ``Summary`` with a single value tagged ``tag``.
    """
    encoded_items = [item.encode(encoding="utf_8") for item in text]
    shape = TensorShapeProto(dim=[TensorShapeProto.Dim(size=len(text))])
    tensor = TensorProto(
        dtype="DT_STRING", string_val=encoded_items, tensor_shape=shape
    )

    # Plugin metadata marks the value as belonging to the "text" plugin.
    plugin_content = TextPluginData(version=0).SerializeToString()
    metadata = SummaryMetadata(
        plugin_data=SummaryMetadata.PluginData(
            plugin_name="text", content=plugin_content
        )
    )

    summary_value = Summary.Value(tag=tag, metadata=metadata, tensor=tensor)
    return Summary(value=[summary_value])


class NodeRaw:
    """Mutable record of one graph node, kept until the graph is built.

    Each constructor argument is stored unchanged on the attribute of the
    same name: ``name``, ``op``, ``input``, ``outputshape``, ``dtype`` and
    ``attributes``.
    """

    def __init__(self, name, op, input, outputshape, dtype, attributes):
        # Identity of the node.
        self.name, self.op = name, op
        # Incoming connection(s).
        self.input = input
        # Description of the node's output.
        self.outputshape, self.dtype = outputshape, dtype
        # Free-form extra information.
        self.attributes = attributes


class SummaryWriterExtend(SummaryWriter):
    """``SummaryWriter`` with list-of-strings text summaries and raw graph nodes.

    Nodes are collected with :meth:`add_node_raw`, can be adjusted with
    :meth:`add_node_raw_name_suffix` and :meth:`add_node_raw_attributes`, and
    are written out as a graph by :meth:`add_graph_by_node_raw_list`.
    """

    def __init__(
        self,
        logdir=None,
        comment="",
        purge_step=None,
        max_queue=10,
        flush_secs=120,
        filename_suffix="",
        write_to_disk=True,
        log_dir=None,
        **kwargs
    ):
        """Create the writer; all arguments are forwarded to ``SummaryWriter``."""
        # Maps the name given to ``add_node_raw`` to its ``NodeRaw`` record.
        self.node_raw_dict = {}
        super().__init__(
            logdir,
            comment,
            purge_step,
            max_queue,
            flush_secs,
            filename_suffix,
            write_to_disk,
            log_dir,
            **kwargs,
        )

    def add_text(self, tag, text_string_list, global_step=None, walltime=None):
        """Add text data to summary.

        Args:
            tag (string): Data identifier
            text_string_list (string list): Strings to save
            global_step (int): Global step value to record
            walltime (float): Optional override default walltime (time.time())
                seconds after epoch of event

        Examples:

            .. code-block:: python

               # text can be divided into three levels by tag and global_step
               from writer import SummaryWriterExtend
               writer = SummaryWriterExtend()

               writer.add_text('level1.0/level2.0', ['text0'], 0)
               writer.add_text('level1.0/level2.0', ['text1'], 1)
               writer.add_text('level1.0/level2.1', ['text2'])
               writer.add_text('level1.1', ['text3'])
        """

        self._get_file_writer().add_summary(
            text(tag, text_string_list), global_step, walltime
        )

    def add_node_raw(
        self,
        name,
        op="UnSpecified",
        input=None,
        outputshape=None,
        dtype=None,
        attributes=None,
    ):
        """Add raw node data that can help build the graph.

        After adding all nodes, call ``add_graph_by_node_raw_list()`` to build
        the graph and add the graph data to the summary. Adding a node under
        a name that already exists replaces the earlier record.

        Args:
            name (string): opr name.
            op (string): opr class name.
            input (string or string list): input opr name(s). ``None`` means
                no inputs; a single name is wrapped into a one-element list.
            outputshape (list): output shape.
            dtype (string): output data dtype.
            attributes (dict): attributes info; copied, ``None`` means empty.

        Examples:

            .. code-block:: python

               from writer import SummaryWriterExtend
               writer = SummaryWriterExtend()

               writer.add_node_raw('node1', 'opr1', outputshape=[6, 2, 3], dtype="float32", attributes={
                       "peak_size": "12MB", "memory_alloc": "2MB, percent: 16.7%"})
               writer.add_node_raw('node2', 'opr2', outputshape=[6, 2, 3], dtype="float32", input="node1",  attributes={
                                   "peak_size": "12MB", "memory_alloc": "2MB, percent: 16.7%"})
               writer.add_graph_by_node_raw_list()

        """
        # Always store a private list of names. A bare string must not be
        # stored as-is: ``add_node_raw_name_suffix`` iterates ``input`` and
        # would otherwise split the name into single characters.
        if input is None:
            inputs = []
        elif isinstance(input, (list, tuple)):
            inputs = list(input)
        else:
            inputs = [input]
        self.node_raw_dict[name] = NodeRaw(
            name, op, inputs, outputshape, dtype, dict(attributes or {})
        )

    def add_node_raw_name_suffix(self, name, suffix):
        """Append a suffix to a node's name so it can be found via 'search nodes'.

        Every node that lists the old name as an input is updated to the new
        name. The key in ``node_raw_dict`` is left unchanged, so later calls
        still address the node by the name it was added under.

        Args:
            name (string): opr name the node was added under.
            suffix (string): name suffix.

        Raises:
            KeyError: if no node was added under ``name``.
        """
        target = self.node_raw_dict[name]
        old_name = target.name
        new_name = old_name + suffix
        target.name = new_name
        for raw in self.node_raw_dict.values():
            raw.input = [new_name if x == old_name else x for x in raw.input]

    def add_node_raw_attributes(self, name, attributes):
        """Add or overwrite attributes of an already added node.

        Args:
            name (string): opr name the node was added under.
            attributes (dict): attributes info that need to be added.

        Raises:
            KeyError: if ``attributes`` is non-empty and no node was added
                under ``name``.
        """
        # Skip the lookup for empty input so an unknown name is only an
        # error when there is actually something to store.
        if attributes:
            self.node_raw_dict[name].attributes.update(attributes)

    def add_graph_by_node_raw_list(self):
        """Build graph and add graph data to summary."""
        node_raw_list = [
            node(
                raw.name,
                raw.op,
                raw.input,
                raw.outputshape,
                raw.dtype,
                raw.attributes,
            )
            for raw in self.node_raw_dict.values()
        ]
        self._get_file_writer().add_graph(graph(node_raw_list))