import os
from contextlib import contextmanager
from ._imperative_rt.core2 import get_option, set_option
# Module-level storage read and written by the properties defined below.
_benchmark_kernel = False
_deterministic_kernel = False
__compute_mode = "default"
__conv_format = "default"

# Names exported by ``from <this module> import *``.
__all__ = [
    "benchmark_kernel",
    "deterministic_kernel",
    "async_level",
    "_compute_mode",
    "_conv_format",
    "_override",
]
@property
def benchmark_kernel(mod):
    """Return the current value of the module-level ``_benchmark_kernel`` flag.

    ``mod`` is the object the property is looked up on and is not used.
    NOTE(review): a module-scope ``property`` only takes effect if the module
    object's class is replaced elsewhere — not visible in this chunk.
    """
    return _benchmark_kernel


@benchmark_kernel.setter
def benchmark_kernel(mod, option: bool):
    """Store ``option`` as the new module-level ``_benchmark_kernel`` flag."""
    global _benchmark_kernel
    _benchmark_kernel = option
@property
def deterministic_kernel(mod):
    """Return the current value of the module-level ``_deterministic_kernel`` flag.

    ``mod`` is the object the property is looked up on and is not used.
    """
    return _deterministic_kernel


@deterministic_kernel.setter
def deterministic_kernel(mod, option: bool):
    """Store ``option`` as the new module-level ``_deterministic_kernel`` flag."""
    global _deterministic_kernel
    _deterministic_kernel = option
@property
def async_level(mod) -> int:
    """Return the ``"async_level"`` option as reported by ``get_option``.

    ``mod`` is the object the property is looked up on and is not used.
    """
    return get_option("async_level")


@async_level.setter
def async_level(mod, level: int):
    """Validate ``level`` and forward it to ``set_option("async_level", ...)``.

    :param level: new async level; must satisfy ``0 <= level <= 2``.
    :raises AssertionError: if ``level`` lies outside ``[0, 2]``.
    """
    # Explicit check rather than ``assert`` so validation still happens under
    # ``python -O``. AssertionError is kept so callers that already catch it
    # keep working.
    if not 0 <= level <= 2:
        raise AssertionError("async_level should be 0, 1 or 2")
    set_option("async_level", level)
@property
def _compute_mode(mod):
    """Return the string stored in the module-level ``__compute_mode``.

    The double-underscore name is not mangled here because these functions
    are defined at module scope, not inside a class body.
    """
    return __compute_mode


@_compute_mode.setter
def _compute_mode(mod, _compute_mode: str):
    """Replace the module-level ``__compute_mode`` with the given value."""
    global __compute_mode
    __compute_mode = _compute_mode
@property
def _conv_format(mod):
    """Return the string stored in the module-level ``__conv_format``."""
    return __conv_format


@_conv_format.setter
def _conv_format(mod, format: str):
    """Replace the module-level ``__conv_format`` with ``format``."""
    global __conv_format
    __conv_format = format
def _reset_execution_config(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    """Apply the given execution-config values and return the previous ones.

    Any argument left as ``None`` leaves the corresponding setting unchanged.

    :return: ``(benchmark_kernel, deterministic_kernel, async_level,
        compute_mode, conv_format)`` as they were before this call, in
        parameter order, so the tuple can be splatted back into this function
        to restore the previous state.
    """
    global _benchmark_kernel, _deterministic_kernel, __compute_mode, __conv_format
    orig_flags = (
        _benchmark_kernel,
        _deterministic_kernel,
        get_option("async_level"),
        __compute_mode,
        __conv_format,
    )
    # ``set_option`` is the only call into the backend here; everything else is
    # a plain assignment. Running it first means that if it raises, none of the
    # module-level flags have been touched and the config is left unchanged.
    if async_level is not None:
        set_option("async_level", async_level)
    if benchmark_kernel is not None:
        _benchmark_kernel = benchmark_kernel
    if deterministic_kernel is not None:
        _deterministic_kernel = deterministic_kernel
    if compute_mode is not None:
        __compute_mode = compute_mode
    if conv_format is not None:
        __conv_format = conv_format
    return orig_flags
@contextmanager
def _override(
    benchmark_kernel=None,
    deterministic_kernel=None,
    async_level=None,
    compute_mode=None,
    conv_format=None,
):
    """Temporarily override execution-config values for a ``with`` block.

    Arguments left as ``None`` keep their current value. On exit, whether
    normal or via an exception, every setting is restored to its value on
    entry.
    """
    # Take the snapshot with an all-None call (which changes nothing) before
    # applying anything. That way the ``finally`` also restores state when
    # applying the overrides fails partway through.
    orig_flags = _reset_execution_config()
    try:
        _reset_execution_config(
            benchmark_kernel, deterministic_kernel, async_level, compute_mode, conv_format,
        )
        yield
    finally:
        _reset_execution_config(*orig_flags)
def _get_actual_op_param(function_param, config_param):
    """Choose the effective value for an op parameter.

    If ``config_param`` equals ``"default"``, ``function_param`` is used;
    otherwise ``config_param`` takes precedence.
    """
    if config_param == "default":
        return function_param
    return config_param