megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
from contextlib import contextmanager

from ._imperative_rt.core2 import get_option, set_option

# Module-level storage read and written by the ``_compute_mode`` and
# ``_conv_format`` properties below; both start out as "default".
__compute_mode = "default"
__conv_format = "default"
# Module-level storage behind the ``benchmark_kernel`` and
# ``deterministic_kernel`` properties below; both start out disabled.
# (``async_level`` has no variable here: it is read and written through
# ``get_option`` / ``set_option``.)
_benchmark_kernel = False
_deterministic_kernel = False

# Names exported by a star-import of this module. The underscore-prefixed
# entries are listed explicitly because a star-import would otherwise skip them.
__all__ = [
    "benchmark_kernel",
    "deterministic_kernel",
    "async_level",
    "_compute_mode",
    "_conv_format",
    "_override",
]


@property
def benchmark_kernel(mod) -> bool:
    r"""Whether candidate algorithms are benchmarked on the real device.

    When true, possible algorithms are run on the real device to find the
    best one. When false (the default), a heuristic is used to choose the
    fastest algorithm.

    Examples:
        .. code-block::

           import megengine as mge
           mge.config.benchmark_kernel = True
    """
    return _benchmark_kernel


@benchmark_kernel.setter
def benchmark_kernel(mod, option: bool) -> None:
    # The flag is kept in a module-level variable, so rebind it there.
    global _benchmark_kernel
    _benchmark_kernel = option


@property
def deterministic_kernel(mod) -> bool:
    r"""Whether the fastest algorithm chosen must be reproducible.

    The default is false, which means the chosen algorithm is not
    reproducible.

    Examples:
        .. code-block::

           import megengine as mge
           mge.config.deterministic_kernel = True
    """
    return _deterministic_kernel


@deterministic_kernel.setter
def deterministic_kernel(mod, option: bool) -> None:
    # The flag is kept in a module-level variable, so rebind it there.
    global _deterministic_kernel
    _deterministic_kernel = option


@property
def async_level(mod) -> int:
    r"""Get or set whether errors are raised exactly when an op is invoked.

    The default level is 2, which means both device and user side errors are
    async. The value is not stored in this module; it is read from and written
    to the ``"async_level"`` option via ``get_option`` / ``set_option``.

    Raises:
        AssertionError: on assignment, if the level is outside ``0..2``.

    Examples:
        .. code-block::

           import megengine as mge
           mge.config.async_level = 2
    """
    return get_option("async_level")


@async_level.setter
def async_level(mod, level: int):
    # Raise explicitly rather than using an ``assert`` statement: asserts are
    # removed under ``python -O``, which would let out-of-range levels through.
    # AssertionError is kept so existing callers see the same exception type.
    if not 0 <= level <= 2:
        raise AssertionError("async_level should be 0, 1 or 2")
    set_option("async_level", level)


@property
def _compute_mode(mod) -> str:
    r"""Get or set the precision used for intermediate results.

    The default is ``"default"``, which places no special requirement on the
    precision. When set to ``"float32"``, float32 is used for the accumulator
    and intermediate results; this is only effective when input and output
    are of float16 dtype.

    Examples:
        .. code-block::

           import megengine as mge
           mge.config._compute_mode = "default"
    """
    return __compute_mode


@_compute_mode.setter
def _compute_mode(mod, _compute_mode: str) -> None:
    # The mode is kept in a module-level variable, so rebind it there.
    global __compute_mode
    __compute_mode = _compute_mode


@property
def _conv_format(mod) -> str:
    r"""Get or set the convolution data/filter/output layout format.

    The default is ``"default"``, which places no special requirement on the
    format. The layout definitions are:

    ``NCHW`` layout: ``{N, C, H, W}``
    ``NHWC`` layout: ``{N, H, W, C}``
    ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}``
    ``NHWCD4I`` layout: with ``align_axis = 2``
    ``NCHW4`` layout: ``{N, C/4, H, W, 4}``
    ``NCHW88`` layout: ``{N, C/8, H, W, 8}``
    ``CHWN4`` layout: ``{C/4, H, W, N, 4}``
    ``NCHW64`` layout: ``{N, C/64, H, W, 64}``

    Examples:
        .. code-block::

           import megengine as mge
           mge.config._conv_format = "NHWC"
    """
    return __conv_format


@_conv_format.setter
def _conv_format(mod, format: str) -> None:
    # The format is kept in a module-level variable, so rebind it there.
    global __conv_format
    __conv_format = format


def _reset_execution_config(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    r"""Apply the given execution settings and return the previous ones.

    Every argument left as ``None`` keeps its current value.

    Args:
        benchmark_kernel: new value for the ``_benchmark_kernel`` module flag.
        deterministic_kernel: new value for the ``_deterministic_kernel``
            module flag.
        async_level: new value for the ``"async_level"`` option, written
            with ``set_option``.
        compute_mode: new value for the ``__compute_mode`` module variable.
        conv_format: new value for the ``__conv_format`` module variable.

    Returns:
        A 5-tuple with the values that were in effect before this call, in
        the same order as the parameters, so it can be unpacked straight back
        into this function to restore them.
    """
    # async_level is held in the option store (get_option / set_option), not
    # in a module variable, so it needs no ``global`` declaration.
    global _benchmark_kernel, _deterministic_kernel
    global __compute_mode, __conv_format

    # Snapshot the current state before anything is modified.
    orig_flags = (
        _benchmark_kernel,
        _deterministic_kernel,
        get_option("async_level"),
        __compute_mode,
        __conv_format,
    )
    if benchmark_kernel is not None:
        _benchmark_kernel = benchmark_kernel
    if deterministic_kernel is not None:
        _deterministic_kernel = deterministic_kernel
    if async_level is not None:
        set_option("async_level", async_level)
    if compute_mode is not None:
        __compute_mode = compute_mode
    if conv_format is not None:
        __conv_format = conv_format

    return orig_flags


@contextmanager
def _override(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    r"""Temporarily override the global execution config.

    Usable as a context manager or, as in the example, as a decorator. The
    given settings are applied on entry and the previous values are put back
    on exit, even if the wrapped code raises. Arguments left as ``None`` are
    not changed.

    Examples:
        .. code-block::

           import megengine as mge

           @mge.config._override(
               benchmark_kernel=True,
               deterministic_kernel=False,
               async_level=2,
               compute_mode="float32",
               conv_format="NHWC",
           )
           def train():
               ...
    """
    saved = _reset_execution_config(
        benchmark_kernel=benchmark_kernel,
        deterministic_kernel=deterministic_kernel,
        async_level=async_level,
        compute_mode=compute_mode,
        conv_format=conv_format,
    )
    try:
        yield
    finally:
        # Put back whatever was in effect before entering.
        _reset_execution_config(*saved)


def _get_actual_op_param(function_param, config_param):
    r"""Pick between a per-call parameter and the configured one.

    Returns ``function_param`` when ``config_param`` equals ``"default"``;
    otherwise returns ``config_param``.
    """
    if config_param == "default":
        return function_param
    return config_param