megenginelite-sys 1.8.2

A safe megenginelite wrapper in Rust
Documentation
from ..core._imperative_rt.core2 import Const
from ..jit.tracing import is_tracing

# Module-wide store of already-built scalar constants; it is filled lazily by
# _get_scalar_tensor_with_value when not tracing and starts out empty.
small_tensor_cache = dict()


def _get_scalar_tensor_with_value(value, dtype=None, device=None):
    """Return a ``Const`` for ``value``, reusing a cached one outside tracing.

    While ``is_tracing()`` is true, a new ``Const`` is built on every call and
    the cache is neither read nor written. Otherwise the result is memoized
    in the module-level ``small_tensor_cache``, so repeated calls with the
    same arguments return the very same object (callers share it).

    Args:
        value: Python scalar to wrap. It must be hashable on the cached path.
        dtype: Forwarded to ``Const`` unchanged.
        device: Forwarded to ``Const`` unchanged.

    Returns:
        The object produced by ``Const(value, dtype, device, None)``.
    """
    if is_tracing():
        return Const(value, dtype, device, None)

    # ``1``, ``1.0`` and ``True`` compare equal and hash alike, so a key built
    # from the bare value would hand back whichever of them was cached first.
    # Keying on the value's type as well keeps them apart; this matters most
    # when ``dtype`` is None (presumably Const then derives the dtype from the
    # Python value -- confirm against the Const implementation).
    cache_key = (type(value), value, dtype, device)
    if cache_key not in small_tensor_cache:
        small_tensor_cache[cache_key] = Const(value, dtype, device, None)
    return small_tensor_cache[cache_key]


def get_scalar_zero(dtype=None, device=None):
    """Return the scalar constant ``0`` for the given ``dtype`` and ``device``.

    Delegates to ``_get_scalar_tensor_with_value``, so outside tracing the
    same cached object may be handed to several callers.
    """
    return _get_scalar_tensor_with_value(0, dtype=dtype, device=device)


def get_scalar_zero_point_five(dtype=None, device=None):
    """Return the scalar constant ``0.5`` for the given ``dtype`` and ``device``.

    Delegates to ``_get_scalar_tensor_with_value``, so outside tracing the
    same cached object may be handed to several callers.
    """
    return _get_scalar_tensor_with_value(0.5, dtype=dtype, device=device)


def get_scalar_one(dtype=None, device=None):
    """Return the scalar constant ``1`` for the given ``dtype`` and ``device``.

    Delegates to ``_get_scalar_tensor_with_value``, so outside tracing the
    same cached object may be handed to several callers.
    """
    return _get_scalar_tensor_with_value(1, dtype=dtype, device=device)


def get_scalar_two(dtype=None, device=None):
    """Return the scalar constant ``2`` for the given ``dtype`` and ``device``.

    Delegates to ``_get_scalar_tensor_with_value``, so outside tracing the
    same cached object may be handed to several callers.
    """
    return _get_scalar_tensor_with_value(2, dtype=dtype, device=device)