megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import collections
import heapq
from collections import OrderedDict
from typing import Dict, List, Tuple, Union

import numpy as np

from ..core import _imperative_rt
from ..core._imperative_rt import GraphProfiler
from ..core._imperative_rt import OperatorNode as _OpNode
from ..core._imperative_rt import VarNode as _VarNode
from ..core.tensor import megbrain_graph as G
from ..core.tensor.megbrain_graph import set_priority_to_id
from ..tensor import Tensor

__all__ = [
    "get_dep_vars",
    "get_owner_opr_inputs",
    "get_owner_opr_type",
    "get_opr_type",
    "graph_traversal",
    "get_oprs_seq",
    "replace_vars",
    "replace_oprs",
    "set_priority_to_id",
    "GraphInference",
]


def get_dep_vars(
    var: Union[_VarNode, List[_VarNode]], var_type: Union[str, List[str]] = None
) -> List[_VarNode]:
    r"""Returns :class:`.tensor.core.megbrain_graph.VarNode` of type ``var_type`` that input ``var``
    depends on. If ``var_type`` is None, returns all types.

    Args:
        var: a single var, or a list of vars, to start the search from.
        var_type: owner-opr type name (or list of names) to keep. ``None`` keeps
            every visited var.

    Returns:
        the matching vars in breadth-first visiting order. The starting vars
        themselves are visited too, so they are included when they match.
    """
    if isinstance(var, _VarNode):
        var = [var]

    if isinstance(var_type, str):
        var_type = [var_type]

    outputs = []
    memo = set()

    # a deque pops from the left in O(1); ``list.pop(0)`` shifts the whole list on every pop
    pending = collections.deque(var)
    while pending:
        v = pending.popleft()
        if v in memo:
            # already expanded through another path
            continue
        memo.add(v)
        pending.extend(get_owner_opr_inputs(v))
        if var_type is None or get_owner_opr_type(v) in var_type:
            outputs.append(v)

    return outputs


def get_owner_opr_inputs(var: _VarNode) -> List[_VarNode]:
    r"""Returns the ``inputs`` of the operator that owns ``var``."""
    owner_opr = var.owner
    return owner_opr.inputs


def get_owner_opr_type(var: _VarNode) -> str:
    r"""Returns the ``type`` of the operator that owns ``var``."""
    owner_opr = var.owner
    return owner_opr.type


def get_opr_type(opr: _OpNode) -> str:
    r"""Returns the ``type`` of ``opr``, which must be an operator node."""
    assert isinstance(opr, _OpNode)
    opr_type = opr.type
    return opr_type


class _OprStableOrderHeapq:
    r"""heap implementation for operator comparison in stable order

    Each operator is stored as the tuple ``extra_priority(opr) + (opr.id, opr.name, opr)``,
    so operators are ordered by the extra priority first and by ``(id, name)`` next.

    Args:
        extra_priority: callable that maps an operator to a tuple which is
            prepended to the ``(id, name, opr)`` heap key.
    """

    # heap storage managed through ``heapq``
    _list = None
    # callable producing the leading part of the heap key
    _extra_priority = None
    # maps every (id, name) pair seen so far to the operator that registered it
    _used_id_name_pairs = None

    def __init__(self, extra_priority):
        # ``collections.Callable`` was removed in Python 3.10; the builtin
        # ``callable`` performs the same check on every Python version.
        assert callable(extra_priority)
        self._list = []
        self._extra_priority = extra_priority
        self._used_id_name_pairs = {}

    def pop_min(self):
        r"""Removes and returns the operator with the smallest heap key."""
        return heapq.heappop(self._list)[-1]

    def add(self, opr):
        r"""Pushes ``opr`` onto the heap.

        Raises:
            RuntimeError: if a different operator with the same ``(id, name)``
                pair has been added before.
        """
        # named as add to mimic set() interface

        id_ = opr.id
        name = opr.name

        # a distinct (id, name) pair guarantees that tuple comparison never
        # has to compare two different operator objects
        other = self._used_id_name_pairs.setdefault((id_, name), opr)
        if other is not opr:
            raise RuntimeError(
                "duplicated (id, name) pair: opr0={} opr1={}".format(other, opr)
            )

        item = self._extra_priority(opr) + (id_, name, opr)
        heapq.heappush(self._list, item)

    def __bool__(self):
        return bool(self._list)


def graph_traversal(outputs: List[_VarNode]):
    r"""Helper function to traverse the computing graph and return enough useful information.

    Args:
        outputs: model outputs.

    Returns:
        tuple (map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree)

        WHERE

        * map_oprs is dict from opr_id to actual opr
        * map_vars is dict from var_id to actual var
        * var2oprs is dict from var id to a list of (dest opr id, input index) pairs
        * opr2receivers is dict from current opr id to the ids of the oprs consuming its outputs
        * indegree2opr is dict from in_degree to opr in computing graph; key ``0``
          holds a heap of oprs, every other key holds a set of opr ids
        * opr2indegree is dict from opr id in computing graph to in_degree

        (indegree2opr, opr2indegree) are only used in topological sort in get_oprs_seq function
    """
    # meta information for comp graph
    map_oprs = collections.defaultdict(set)
    map_vars = collections.defaultdict(set)

    var2oprs = collections.defaultdict(list)
    opr2receivers = collections.defaultdict(list)

    # seed the work list with the distinct owner oprs of the outputs, in order
    queue = []
    for out_var in outputs:
        owner = out_var.owner
        if owner not in queue:
            queue.append(owner)
    visited = {opr.id for opr in queue}

    # iterate through whole comp_graph, fill in meta information
    indegree2opr = collections.defaultdict(set)
    indegree2opr[0] = _OprStableOrderHeapq(lambda op: (op.priority,))
    opr2indegree = {}

    # ``queue`` grows while it is scanned, hence the explicit index
    idx = 0
    while idx < len(queue):
        cur_opr = queue[idx]
        idx += 1
        map_oprs[cur_opr.id] = cur_opr

        cur_inputs = cur_opr.inputs
        for var_idx, var in enumerate(cur_inputs):
            map_vars[var.id] = var
            var2oprs[var.id].append((cur_opr.id, var_idx))

            pre_opr = var.owner
            if pre_opr.id not in visited:
                visited.add(pre_opr.id)
                queue.append(pre_opr)

            opr2receivers[pre_opr.id].append(cur_opr.id)

        # every input counts once, even when two inputs share the same producer
        indegree = len(cur_inputs)
        # the zero-indegree heap stores the opr itself; the other buckets store ids
        opr = cur_opr if indegree == 0 else cur_opr.id
        indegree2opr[indegree].add(opr)
        opr2indegree[cur_opr.id] = indegree

    return map_oprs, map_vars, var2oprs, opr2receivers, indegree2opr, opr2indegree


def get_oprs_seq(
    outputs: List[_VarNode], prune_reshape=False, prune_immtensor=True
) -> List[_OpNode]:
    r"""Gets oprs in some topological order for a dumped model.

    The graph reachable from ``outputs`` is traversed, sorted topologically,
    reordered so that ``ImmutableTensor`` / ``Host2DeviceCopy`` oprs sit right
    before their first consumer, and optionally pruned.

    Args:
        outputs: model outputs.
        prune_reshape: whether to prune the useless operators used by Reshape opr during inference.
            Pruning only happens when this is exactly ``True``.
        prune_immtensor: whether to prune the ImmutableTensor opr.

    Returns:
        opr list with some correct execution order.
    """

    def topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree):
        # generate an execution order with topological sort algorithm
        oprs_seq = []
        nr_remain = len(map_oprs)
        # indegree2opr[0] is the heap of oprs whose inputs are all scheduled
        while indegree2opr[0]:
            opr = indegree2opr[0].pop_min()
            opr_id = opr.id
            nr_remain -= 1
            # ImmutableTensor oprs are still processed but left out of the result when pruned
            if opr.type != "ImmutableTensor" or not prune_immtensor:
                oprs_seq.append(opr)

            # one entry per consumed input: move each receiver one indegree bucket down
            for post_id in opr2receivers[opr_id]:
                indegree = opr2indegree[post_id]
                indegree2opr[indegree].remove(post_id)

                indegree -= 1
                if indegree == 0:
                    # the zero bucket stores the opr itself, the other buckets store ids
                    indegree2opr[indegree].add(map_oprs[post_id])
                else:
                    indegree2opr[indegree].add(post_id)
                opr2indegree[post_id] = indegree

        assert nr_remain == 0, "there are {} remaining nodes; cyclic graph?".format(
            nr_remain
        )
        return oprs_seq

    # reshape op definition: reshape(input_tensor, dest_shape) -> output_tensor
    # when inferencing, shape of output_tensor is already known, so one can prune some operators related to dest_shape in the loaded graph
    def prune_reshape_oprs(outputs, oprs_seq, var2oprs):
        # NOTE: despite its name this helper is recursive; it walks upstream from cur_opr
        def iterative_pruning(cur_opr, post_opr, marked_opr_ids, visited):
            useless = True
            for oup in cur_opr.outputs:
                # outputs whose name contains "workspace" are ignored
                if "workspace" not in oup.name:
                    # drop post_opr from the consumers of this output
                    var_idx = post_opr.inputs.index(oup)
                    var2oprs[oup.id].remove((post_opr.id, var_idx))
                    # cur_opr is useless only if none of these outputs has a consumer left
                    useless = useless and (len(var2oprs[oup.id]) == 0)

            if useless:
                marked_opr_ids.append(cur_opr.id)

                # continue with the producers of cur_opr's inputs, each (producer, consumer) edge once
                for opr in set([var.owner for var in cur_opr.inputs]):
                    if (opr.id, cur_opr.id) not in visited:
                        visited.add((opr.id, cur_opr.id))
                        iterative_pruning(opr, cur_opr, marked_opr_ids, visited)

        reshape_vars = get_dep_vars(outputs, "Reshape")
        reshape_oprs = [var.owner for var in reshape_vars]

        marked_opr_ids = []
        visited = set()
        for reshape_opr in reshape_oprs:
            # inputs[1] is the dest_shape operand of reshape (see the definition above)
            iterative_pruning(
                reshape_opr.inputs[1].owner, reshape_opr, marked_opr_ids, visited
            )

        # filter out all marked oprs
        return list(filter(lambda x: x.id not in marked_opr_ids, oprs_seq))

    # adjust the order of oprs, let param/data provider oprs close to the oprs which use them as inputs.
    def reorder_oprs_seq(oprs):
        rst = []
        param_or_data_provider_oprs = []
        other_oprs = []

        # split the provider oprs from everything else, keeping the relative order
        for o in oprs:
            if o.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                param_or_data_provider_oprs.append(o)
            else:
                other_oprs.append(o)

        # emit each provider right before the first opr that consumes it
        for o in other_oprs:
            for inp in o.inputs:
                if inp.owner.type in ["ImmutableTensor", "Host2DeviceCopy"]:
                    if inp.owner in param_or_data_provider_oprs:
                        rst.append(inp.owner)
                        param_or_data_provider_oprs.remove(inp.owner)
            rst.append(o)
        # providers that were not placed by the loop above go to the end
        rst = rst + param_or_data_provider_oprs
        assert len(rst) == len(oprs)
        return rst

    map_oprs, _, var2oprs, opr2receivers, indegree2opr, opr2indegree = graph_traversal(
        outputs
    )
    oprs_seq = topological_sort(map_oprs, opr2receivers, indegree2opr, opr2indegree)
    oprs_seq = reorder_oprs_seq(oprs_seq)
    if prune_reshape is True:
        # NOTE: copy() is shallow, so the per-var consumer lists are shared with var2oprs;
        # var2oprs is not read again after this call
        oprs_seq = prune_reshape_oprs(outputs, oprs_seq, var2oprs.copy())
    return oprs_seq


def replace_vars(
    dst: List[_VarNode], varmap: Dict[_VarNode, _VarNode]
) -> List[_VarNode]:
    r"""Substitutes vars inside the graph that ``dst`` depends on.

    Args:
        dst: target vars representing the graph.
        varmap: mapping from a var to its replacement. An object without an
            ``items`` attribute is iterated directly as ``(old, new)`` pairs.

    Returns:
        new vars that correspond to ``dst`` with all the dependencies replaced.
    """
    targets = []
    for target in dst:
        assert isinstance(target, _VarNode)
        targets.append(target)

    # use the mapping's items() when there is one, otherwise the object itself
    items_fn = getattr(varmap, "items", None)
    pairs = varmap if items_fn is None else items_fn()

    old_vars, new_vars = [], []
    for old_var, new_var in pairs:
        assert isinstance(old_var, _VarNode)
        assert isinstance(new_var, _VarNode)
        old_vars.append(old_var)
        new_vars.append(new_var)

    return _imperative_rt.graph._replace_vars(old_vars, new_vars, targets)


def replace_oprs(dst: List[_VarNode], oprmap: Dict[_OpNode, _OpNode]) -> List[_VarNode]:
    """Substitutes operators inside the graph that ``dst`` depends on.

    Args:
        dst: target vars representing the graph.
        oprmap: mapping from an operator to its replacement. An object without
            an ``items`` attribute is iterated directly as ``(old, new)`` pairs.

    Returns:
        new vars that correspond to ``dst`` with all the dependencies replaced.
    """
    targets = []
    for target in dst:
        assert isinstance(target, _VarNode)
        targets.append(target)

    # use the mapping's items() when there is one, otherwise the object itself
    items_fn = getattr(oprmap, "items", None)
    pairs = oprmap if items_fn is None else items_fn()

    old_oprs, new_oprs = [], []
    for old_opr, new_opr in pairs:
        assert isinstance(old_opr, _OpNode)
        assert isinstance(new_opr, _OpNode)
        old_oprs.append(old_opr)
        new_oprs.append(new_opr)

    return _imperative_rt.graph._replace_oprs(old_oprs, new_oprs, targets)


def find_vars_by_name(dst: List[_VarNode], names: List[str]) -> List[_VarNode]:
    r"""Gets VarNode list by names in the graph.

    When several vars share a requested name, the first one met while walking
    the dependencies of ``dst`` is returned. A name that appears more than once
    in ``names`` yields the same var at each of its positions.

    Args:
        dst: target vars representing the graph.
        names: name list for target VarNode.

    Returns:
        results found by names, in the same order as ``names``.

    Raises:
        AssertionError: if at least one requested name is not found.
    """
    # a set gives O(1) membership tests and tolerates repeated names
    pending = set(names)
    found = {}
    # get_dep_vars visits the vars in ``dst`` themselves as well as their dependencies
    for var in get_dep_vars(dst):
        if not pending:
            break
        var_name = var.name
        if var_name in pending:
            found[var_name] = var
            pending.discard(var_name)

    # report the missing names in the order the caller asked for them
    missing = [name for name in names if name not in found]
    assert len(missing) == 0, "Can not find varnode {} in this model".format(missing)
    return [found[name] for name in names]


def convert_inputs(
    dst: List[_VarNode], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Swaps ``Host2DeviceCopy`` vars for :class:`~.InputNode` in the graph so that
    data can be fed with :meth:`~.InputNode.set_value` before running.

    Args:
        dst: target vars representing the graph.
        inputs: the input vars to replace. When omitted, every
            ``Host2DeviceCopy`` var that ``dst`` depends on is replaced.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and an
        ordered dict from input name to the created input node. Each new var
        takes the name of the var it replaces in ``dst``.
    """
    h2d_vars = get_dep_vars(dst, "Host2DeviceCopy") if inputs is None else inputs

    name2node = OrderedDict()
    var2repl = {}
    for h2d in h2d_vars:
        # build the input node from the properties of the var it replaces
        node = G.InputNode(
            device=h2d.comp_node, dtype=h2d.dtype, shape=h2d.shape, graph=h2d.graph
        )
        node.name = h2d.name
        name2node[h2d.name] = node
        var2repl[h2d] = node.outputs[0]

    replaced = replace_vars(dst, var2repl)
    # carry the original output names over to the rebuilt vars
    for src_var, new_var in zip(dst, replaced):
        new_var.name = src_var.name
    return replaced, name2node


def convert_outputs(dst: List[_VarNode]) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Wraps every var of ``dst`` in an :class:`~.OutputNode` so that results can be
    read with :meth:`~.OutputNode.get_value`.

    Args:
        dst: target vars representing the graph.

    Returns:
        the first output var of each created output node, and an ordered dict
        from var name to that output node.
    """
    name2node = OrderedDict()
    for var in dst:
        key = var.name
        name2node[key] = G.OutputNode(var)

    wrapped_vars = [node.outputs[0] for node in name2node.values()]
    return wrapped_vars, name2node


def embed_inputs(
    dst: List[_VarNode], data: List[np.ndarray], inputs: List[_VarNode] = None
) -> Tuple[List[_VarNode], Dict[str, _VarNode]]:
    r"""Bakes ``data`` into the graph in place of the inputs of ``dst``.

    Args:
        dst: target vars representing the graph.
        data: data to be embedded, one item per input and in the same order.
        inputs: the input vars to replace. When omitted, every
            ``Host2DeviceCopy`` var that ``dst`` depends on is replaced.

    Returns:
        new vars that correspond to ``dst`` with all inputs replaced, and an
        ordered dict from input name to the var that now holds its data. Each
        new var takes the name of the var it replaces in ``dst``.
    """
    if inputs is None:
        inputs = get_dep_vars(dst, "Host2DeviceCopy")
    assert len(data) == len(inputs)

    name2var = OrderedDict()
    var2repl = {}
    for old_inp, value in zip(inputs, data):
        shared = _imperative_rt.make_shared(old_inp.graph, Tensor(value)._dev_tensor())
        shared.name = old_inp.name
        name2var[old_inp.name] = shared
        var2repl[old_inp] = shared

    replaced = replace_vars(dst, var2repl)
    # carry the original output names over to the rebuilt vars
    for src_var, new_var in zip(dst, replaced):
        new_var.name = src_var.name
    return replaced, name2var


class GraphInference:
    r"""Loads a serialized computing graph as a GraphInference object which can be used
    to execute the computing graph.

    Args:
        file: could be file object or filename.
        outputs: only compile the subgraph with outputs as its endpoints.
        profiling: accepted, but not referenced by this constructor.
        optimize_for_inference: accepted, but not referenced by this constructor.
        kwargs: accepted, but not referenced by this constructor.
    """

    def __init__(
        self,
        file,
        outputs: List[str] = None,
        profiling: bool = False,
        optimize_for_inference: bool = False,
        **kwargs
    ):
        loaded = G.load_graph(file)
        self._graph = loaded.graph
        endpoints = loaded.output_vars_list
        if outputs is not None:
            # restrict the graph to the requested named endpoints
            endpoints = find_vars_by_name(endpoints, outputs)
        self._origin_outputs = endpoints

        # replace inputs with `InputNode`
        endpoints, self._inp_dict = convert_inputs(endpoints)

        # replace outputs with `OutputNode`
        endpoints, self._oup_dict = convert_outputs(endpoints)

        self._func = self._graph.compile(endpoints)

    def run(
        self, *inp_args: np.ndarray, inp_dict: Dict[str, np.ndarray] = None
    ) -> Dict[str, np.ndarray]:
        r"""Feeds the inputs, executes the compiled graph and collects the outputs.

        Args:
            inp_args: list of input datas, matched to the input names by position.
            inp_dict: dict of named input datas; overrides positional data given
                for the same name.

        Returns:
            a dict {output_name: output_value}.

        Raises:
            AssertionError: if more positional inputs are given than the model
                has, or if the provided input names differ from the model's.

        Note:
            Note that the order of the Graph's input nodes may be different from the order of the origin traced function's arguments.
            It is recommended to use ``inp_dict`` to provide input data by name.
        """
        expected = self._inp_dict
        assert len(inp_args) <= len(expected), "This model expects {} inputs".format(
            len(expected)
        )

        # positional data first, then named data on top of it
        inp_keys = list(expected.keys())
        feed = {inp_keys[pos]: value for pos, value in enumerate(inp_args)}
        if inp_dict is not None:
            feed.update(inp_dict)
        assert (
            feed.keys() == expected.keys()
        ), "This model expects inputs {}, but gets inputs {}".format(
            list(expected.keys()), list(feed.keys())
        )

        for name, node in expected.items():
            # place the data on the device of the input node it feeds
            node.set_value(Tensor(feed[name], device=node.device)._dev_tensor())
        self._func.execute()
        self._func.wait()

        return OrderedDict(
            (name, node.get_value().numpy()) for name, node in self._oup_dict.items()
        )