import os
from ..core import _config
from ..core.ops import builtin
from ..logger import get_logger
from ..utils.deprecation import deprecated
# Alias for the strategy type exposed by the built-in Convolution op.
Strategy = builtin.ops.Convolution.Strategy

# Warn anyone still setting the old, convolution-specific environment variable.
if os.getenv("MEGENGINE_CONV_EXECUTION_STRATEGY") is not None:
    get_logger().warning(
        "Environment variable `MEGENGINE_CONV_EXECUTION_STRATEGY` is deprecated, please use `MEGENGINE_EXECUTION_STRATEGY`"
    )
# Maps each accepted option name to the Strategy member of the same name.
_valid_string_option = {
    name: getattr(Strategy, name)
    for name in ("REPRODUCIBLE", "HEURISTIC", "PROFILE")
}
def get_execution_strategy() -> Strategy:
    """Return the execution strategy currently recorded in ``_config``.

    The result always contains exactly one of ``Strategy.PROFILE`` (when
    ``_config._benchmark_kernel`` is truthy) or ``Strategy.HEURISTIC``
    (otherwise), and additionally ``Strategy.REPRODUCIBLE`` when
    ``_config._deterministic_kernel`` is truthy.
    """
    # Start from a fresh empty value so no shared Strategy member is touched.
    combined = Strategy(0)
    combined |= Strategy.PROFILE if _config._benchmark_kernel else Strategy.HEURISTIC
    if _config._deterministic_kernel:
        combined |= Strategy.REPRODUCIBLE
    return combined
def set_execution_strategy(option):
    """Update ``_config`` kernel flags from an execution strategy.

    Args:
        option: either a ``Strategy`` value, or a string made of the names
            ``"REPRODUCIBLE"``, ``"HEURISTIC"`` and ``"PROFILE"`` joined with
            ``"_"`` (for example ``"PROFILE_REPRODUCIBLE"``).

    ``_config._benchmark_kernel`` becomes True when PROFILE is selected and
    ``_config._deterministic_kernel`` becomes True when REPRODUCIBLE is
    selected; both are False otherwise. HEURISTIC sets neither flag.

    Raises:
        ValueError: if a string option contains a name that is not one of
            the valid names. In that case ``_config`` is left unchanged.
        AssertionError: if ``option`` is neither a ``Strategy`` nor a ``str``.
    """
    if isinstance(option, Strategy):
        no_flags = Strategy(0)
        _config._benchmark_kernel = bool(
            option & _valid_string_option["PROFILE"] != no_flags
        )
        _config._deterministic_kernel = bool(
            option & _valid_string_option["REPRODUCIBLE"] != no_flags
        )
        return

    assert isinstance(option, str)

    # Accumulate into locals and commit only after every name has been
    # validated, so a bad option cannot leave the configuration half-updated.
    benchmark_kernel = False
    deterministic_kernel = False
    for name in option.split("_"):
        if name not in _valid_string_option:
            raise ValueError(
                "Valid option can only be one of {}, or combine them with '_'.".format(
                    _valid_string_option.keys()
                )
            )
        selected = _valid_string_option[name]
        benchmark_kernel |= selected == Strategy.PROFILE
        deterministic_kernel |= selected == Strategy.REPRODUCIBLE

    _config._benchmark_kernel = benchmark_kernel
    _config._deterministic_kernel = deterministic_kernel
@deprecated(version="1.3", reason="use get_execution_strategy() instead")
def get_conv_execution_strategy() -> Strategy:
    """Deprecated alias that forwards to :func:`get_execution_strategy`.

    Returns:
        Whatever :func:`get_execution_strategy` returns, i.e. a ``Strategy``
        value (not a string).
    """
    return get_execution_strategy()
@deprecated(version="1.3", reason="use set_execution_strategy() instead")
def set_conv_execution_strategy(option: str):
    """Deprecated alias that forwards ``option`` to :func:`set_execution_strategy`.

    The argument is passed through unchanged and the callee's result is
    returned as-is.
    """
    outcome = set_execution_strategy(option)
    return outcome
# Runs at import time: initialise the strategy from the environment,
# falling back to "HEURISTIC" when the variable is not set.
set_execution_strategy(
    os.environ.get("MEGENGINE_EXECUTION_STRATEGY", "HEURISTIC")
)