def _is_seq(x):
return isinstance(x, (list, tuple)) or (
hasattr(x, "tolist") and not isinstance(x, (str, bytes))
)
def _flatten(x):
if hasattr(x, "tolist") and not isinstance(x, (str, bytes)):
x = x.tolist()
if isinstance(x, (list, tuple)):
for v in x:
for leaf in _flatten(v):
yield leaf
else:
yield x
def _shape(x):
if hasattr(x, "shape"):
return tuple(x.shape)
if isinstance(x, (list, tuple)):
if not x:
return (0,)
first = _shape(x[0])
return (len(x),) + first
return ()
def _fail(msg, err_msg=None):
if err_msg:
msg = f"{err_msg}\n{msg}"
raise AssertionError(msg)
def assert_equal(actual, desired, err_msg="", verbose=True):
    """Raise ``AssertionError`` unless *actual* and *desired* are equal.

    Both inputs are compared by shape first (as reported by ``_shape``),
    then leaf by leaf after flattening with ``_flatten``.

    NaNs located at the same position on both sides are treated as equal,
    matching the documented behaviour of ``numpy.testing.assert_equal``.

    Parameters
    ----------
    actual, desired
        Scalars, nested lists/tuples, or objects exposing ``tolist``.
    err_msg : str
        Optional text placed ahead of the failure message.
    verbose : bool
        When true, the failure message includes the flat index of the first
        mismatching item.

    Raises
    ------
    AssertionError
        On shape mismatch, length mismatch, or the first unequal item.
    """
    # Compute each shape once; it is needed for the test and the message.
    actual_shape = _shape(actual)
    desired_shape = _shape(desired)
    if actual_shape != desired_shape:
        _fail(f"shape mismatch: {actual_shape} vs {desired_shape}", err_msg)

    a = list(_flatten(actual))
    d = list(_flatten(desired))
    if len(a) != len(d):
        _fail(f"length mismatch: {len(a)} vs {len(d)}", err_msg)

    for i, (x, y) in enumerate(zip(a, d)):
        if x != y:
            # NaN != NaN by IEEE rules; two NaNs in the same slot still
            # count as a match for testing purposes.
            both_nan = (
                isinstance(x, (float, complex))
                and isinstance(y, (float, complex))
                and x != x
                and y != y
            )
            if both_nan:
                continue
            extra = f" (item {i})" if verbose else ""
            _fail(f"{x!r} != {y!r}{extra}", err_msg)
def assert_array_equal(actual, desired, err_msg="", verbose=True):
    """Array-named alias of ``assert_equal``; all arguments are forwarded."""
    options = {"err_msg": err_msg, "verbose": verbose}
    assert_equal(actual, desired, **options)
def assert_allclose(
    actual, desired, rtol=1e-7, atol=0.0, equal_nan=True,
    err_msg="", verbose=True,
):
    """Raise ``AssertionError`` unless *actual* is close to *desired*.

    Items are compared pairwise after flattening. A pair passes when
    ``abs(x - y) <= atol + rtol * abs(y)``.

    Non-finite values are handled explicitly:

    * two NaNs match only when *equal_nan* is true; a NaN on one side only
      is always a failure;
    * an infinity matches only an identical infinity. Without this check
      the tolerance ``rtol * abs(inf)`` is itself infinite and every value
      would be accepted as "close" to an infinite *desired*.

    Parameters
    ----------
    actual, desired
        Scalars, nested lists/tuples, or objects exposing ``tolist``.
    rtol, atol : float
        Relative and absolute tolerance.
    equal_nan : bool
        Whether NaNs in matching positions compare as equal.
    err_msg : str
        Optional text placed ahead of the failure message.
    verbose : bool
        When true, the message includes index, difference and tolerance.

    Raises
    ------
    AssertionError
        On length mismatch or the first pair outside tolerance.
    """
    import math

    def _is_nan(v):
        return isinstance(v, (float, complex)) and v != v

    def _is_inf(v):
        return isinstance(v, float) and math.isinf(v)

    a = list(_flatten(actual))
    d = list(_flatten(desired))
    if len(a) != len(d):
        _fail(f"length mismatch: {len(a)} vs {len(d)}", err_msg)
    for i, (x, y) in enumerate(zip(a, d)):
        x_nan = _is_nan(x)
        y_nan = _is_nan(y)
        if x_nan and y_nan:
            if equal_nan:
                continue
            _fail(f"NaN at item {i} (equal_nan=False)", err_msg)
        if x_nan or y_nan:
            _fail(f"NaN mismatch at item {i}: {x!r} vs {y!r}", err_msg)
        if _is_inf(x) or _is_inf(y):
            # Infinite values must match exactly (same sign); tolerance
            # arithmetic is meaningless here.
            if x == y:
                continue
            _fail(f"inf mismatch at item {i}: {x!r} vs {y!r}", err_msg)
        diff = abs(x - y)
        tol = atol + rtol * abs(y)
        if diff > tol:
            extra = f" (item {i}, diff={diff}, tol={tol})" if verbose else ""
            _fail(f"{x!r} not close to {y!r}{extra}", err_msg)
def assert_array_almost_equal(actual, desired, decimal=6, err_msg="", verbose=True):
    """Check *actual* against *desired* to *decimal* decimal places.

    Delegates to ``assert_allclose`` with no relative tolerance and an
    absolute tolerance of ``1.5 * 10**-decimal``.
    """
    absolute_tolerance = 1.5 * 10 ** (-decimal)
    assert_allclose(
        actual,
        desired,
        rtol=0.0,
        atol=absolute_tolerance,
        err_msg=err_msg,
        verbose=verbose,
    )
def assert_almost_equal(actual, desired, decimal=7, err_msg="", verbose=True):
    """Forward to ``assert_array_almost_equal`` with a default of 7 decimals."""
    forwarded = {"decimal": decimal, "err_msg": err_msg, "verbose": verbose}
    assert_array_almost_equal(actual, desired, **forwarded)
def assert_approx_equal(actual, desired, significant=7, err_msg=""):
    """Raise ``AssertionError`` unless two scalars agree to *significant* digits.

    The tolerance is ``abs(desired) * 10**(1 - significant)``, or
    ``10**-significant`` when *desired* is zero.

    Non-finite values are handled explicitly because ordinary comparisons
    silently accept them: any comparison against NaN is false, and the
    tolerance derived from an infinite *desired* is itself infinite.

    * equal values (including identical infinities) pass;
    * two NaNs pass; a single NaN fails;
    * an infinity against anything other than the same infinity fails.

    Parameters
    ----------
    actual, desired
        Scalar numbers to compare.
    significant : int
        Number of significant digits that must agree.
    err_msg : str
        Optional text placed ahead of the failure message.

    Raises
    ------
    AssertionError
        When the values differ by more than the tolerance.
    """
    message = (
        f"{actual!r} not approx equal to {desired!r} "
        f"to {significant} significant digits"
    )
    if actual == desired:
        return

    actual_nan = actual != actual
    desired_nan = desired != desired
    if actual_nan and desired_nan:
        return
    if actual_nan or desired_nan:
        _fail(message, err_msg)

    infinity = float("inf")
    if abs(actual) == infinity or abs(desired) == infinity:
        # Unequal values with at least one infinity can never be "close".
        _fail(message, err_msg)

    if desired == 0.0:
        tol = 10 ** -significant
    else:
        tol = abs(desired) * 10 ** (-significant + 1)
    if abs(actual - desired) > tol:
        _fail(message, err_msg)
def assert_array_less(x, y, err_msg=""):
    """Raise ``AssertionError`` unless every item of *x* is ``<`` its peer in *y*.

    Both inputs are flattened and compared pairwise. When exactly one side
    flattens to a single value it is compared against every item of the
    other side, so ``assert_array_less([1, 2, 3], 5)`` checks all three
    items. Any other length mismatch is a failure rather than a silent
    truncation to the shorter input.

    Parameters
    ----------
    x, y
        Scalars, nested lists/tuples, or objects exposing ``tolist``.
    err_msg : str
        Optional text placed ahead of the failure message.

    Raises
    ------
    AssertionError
        On an incompatible length mismatch or the first pair where
        ``p < q`` is false.
    """
    a = list(_flatten(x))
    b = list(_flatten(y))
    if len(a) != len(b):
        if len(a) == 1:
            a = a * len(b)
        elif len(b) == 1:
            b = b * len(a)
        else:
            _fail(f"length mismatch: {len(a)} vs {len(b)}", err_msg)
    for i, (p, q) in enumerate(zip(a, b)):
        if not (p < q):
            _fail(f"item {i}: {p!r} not < {q!r}", err_msg)
def assert_raises(exc_type, callable_, *args, **kwargs):
    """Call ``callable_(*args, **kwargs)`` and require that it raises *exc_type*.

    Parameters
    ----------
    exc_type : type or tuple of types
        Exception class, or tuple of classes, that is expected.
    callable_ : callable
        Function under test.
    *args, **kwargs
        Forwarded to *callable_*.

    Raises
    ------
    AssertionError
        When *callable_* returns without raising. Exceptions of any other
        type propagate unchanged.
    """
    try:
        callable_(*args, **kwargs)
    except exc_type:
        return
    # A tuple of exception classes has no ``__name__``; describe each member
    # so the failure report itself cannot crash.
    if isinstance(exc_type, tuple):
        names = ", ".join(getattr(t, "__name__", repr(t)) for t in exc_type)
        expected = f"({names})"
    else:
        expected = exc_type.__name__
    raise AssertionError(f"expected {expected}, no exception raised")
def assert_warns(*args, **kwargs):
    """Best-effort stand-in for ``assert_warns(warning_class, func, *args)``.

    This helper does not inspect emitted warnings; it only makes the call
    patterns work:

    * ``assert_warns(WarningClass, func, *a, **kw)`` calls
      ``func(*a, **kw)`` and returns its result. Previously the warning
      class itself, being callable, was instantiated and *func* never ran.
    * ``assert_warns(WarningClass)`` returns a no-op context manager so the
      ``with assert_warns(...):`` form is usable.
    * ``assert_warns(callable)`` (no warning class) keeps the earlier
      behaviour of calling it with no arguments.
    * Anything else returns ``None``.
    """
    import contextlib

    if not args:
        return None
    head = args[0]
    if isinstance(head, type) and issubclass(head, Warning):
        if len(args) == 1:
            return contextlib.nullcontext()
        func = args[1]
        return func(*args[2:], **kwargs)
    if callable(head):
        return head()
    return None
# --- Interpreter / platform probes -----------------------------------------
# Each probe falls back to a fixed default if its module cannot be imported.
try:
    import sys as _sys
except ImportError:
    IS_64BIT = True
else:
    IS_64BIT = _sys.maxsize > 2 ** 32

try:
    import platform as _platform

    IS_PYPY = _platform.python_implementation() == "PyPy"
    IS_WASM = _platform.machine().lower().startswith("wasm")
    if hasattr(_platform, "libc_ver"):
        IS_MUSL = "musl" in _platform.libc_ver()[0].lower()
    else:
        IS_MUSL = False
except ImportError:
    IS_PYPY = IS_WASM = IS_MUSL = False

# --- Fixed build/environment flags ------------------------------------------
# These are hard-coded constants here; nothing in this module computes them.
IS_PYSTON = False
IS_EDITABLE = False
IS_INSTALLED = True
NOGIL_BUILD = False
HAS_REFCOUNT = True
HAS_LAPACK64 = False
BLAS_SUPPORTS_FPE = True
NUMPY_ROOT = ""
verbose = 0
class IgnoreException(Exception):
    """Marker exception type; adds nothing beyond ``Exception``."""
class KnownFailureException(Exception):
    """Marker exception type; adds nothing beyond ``Exception``."""
class SkipTest(Exception):
    """Marker exception type; adds nothing beyond ``Exception``.

    NOTE(review): this is a plain ``Exception`` subclass, not
    ``unittest.SkipTest``, so test runners presumably will not treat it as
    a skip. Confirm whether that is intended.
    """
class TestCase:
    """Minimal base class offering a few ``unittest``-style assertions.

    Every check raises ``AssertionError`` on failure, using *msg* when it is
    truthy and a generated description otherwise.
    """

    def assertEqual(self, a, b, msg=None):
        """Fail when ``a != b``."""
        differs = a != b
        if differs:
            raise AssertionError(msg if msg else f"{a!r} != {b!r}")

    def assertNotEqual(self, a, b, msg=None):
        """Fail when ``a == b``."""
        same = a == b
        if same:
            raise AssertionError(msg if msg else f"{a!r} == {b!r}")

    def assertTrue(self, x, msg=None):
        """Fail when *x* is falsy."""
        if x:
            return
        raise AssertionError(msg if msg else f"{x!r} is not truthy")

    def assertFalse(self, x, msg=None):
        """Fail when *x* is truthy."""
        if not x:
            return
        raise AssertionError(msg if msg else f"{x!r} is not falsy")

    def assertRaises(self, exc, fn, *args, **kw):
        """Delegate to the module-level ``assert_raises``."""
        assert_raises(exc, fn, *args, **kw)
def assert_(condition, msg=""):
    """Raise ``AssertionError`` when *condition* is falsy.

    The error text is *msg* when it is truthy, otherwise a generic message.
    """
    if condition:
        return
    text = msg if msg else "assertion failed"
    raise AssertionError(text)
def assert_array_compare(comparison, x, y, err_msg="", verbose=True,
                         header="", strict=False, equal_nan=True):
    """Require ``comparison(xi, yi)`` to be truthy for every flattened pair.

    *verbose*, *header*, *strict* and *equal_nan* are accepted for signature
    compatibility but are not used.

    Raises ``AssertionError`` when the flattened lengths differ or at the
    first pair for which *comparison* returns a falsy value.
    """
    del verbose, header, strict, equal_nan  # accepted but intentionally unused

    lhs = list(_flatten(x))
    rhs = list(_flatten(y))
    n_lhs, n_rhs = len(lhs), len(rhs)
    if n_lhs != n_rhs:
        _fail(f"length mismatch: {n_lhs} vs {n_rhs}", err_msg)

    for index, pair in enumerate(zip(lhs, rhs)):
        left, right = pair
        if comparison(left, right):
            continue
        _fail(
            f"item {index}: comparison(xi={left!r}, yi={right!r}) is false",
            err_msg,
        )
def assert_array_almost_equal_nulp(actual, desired, nulp=1):
    """Require each flattened pair to differ by at most *nulp* ULPs.

    For every unequal pair the unit in the last place is taken from the
    larger magnitude (via ``math.ulp``) and the absolute difference must not
    exceed ``nulp`` times that unit.

    When exactly one side flattens to a single value it is compared against
    every item of the other side. Any other length mismatch is a failure
    rather than a silent truncation to the shorter input.

    Raises
    ------
    AssertionError
        On an incompatible length mismatch or the first pair too far apart.
    """
    import math

    a = list(_flatten(actual))
    d = list(_flatten(desired))
    if len(a) != len(d):
        if len(a) == 1:
            a = a * len(d)
        elif len(d) == 1:
            d = d * len(a)
        else:
            raise AssertionError(f"length mismatch: {len(a)} vs {len(d)}")
    for i, (x, y) in enumerate(zip(a, d)):
        if x == y:
            continue
        ulp = math.ulp(max(abs(float(x)), abs(float(y))))
        if abs(float(x) - float(y)) > nulp * ulp:
            raise AssertionError(f"item {i}: {x!r} vs {y!r} differs by > {nulp} ULP")
def assert_array_max_ulp(a, b, maxulp=1, dtype=None):
    """Check *a* against *b* within *maxulp* ULPs and return ``[maxulp]``.

    The comparison is delegated to ``assert_array_almost_equal_nulp``.
    *dtype* is accepted but not used. The returned list holds the limit that
    was applied, not a measured difference.
    """
    del dtype  # accepted but intentionally unused
    assert_array_almost_equal_nulp(a, b, nulp=maxulp)
    limit = [maxulp]
    return limit
def assert_no_gc_cycles(*args, **kwargs):
    """Best-effort stand-in that runs the callable without any GC checks.

    * ``assert_no_gc_cycles(func, *a, **kw)`` returns ``func(*a, **kw)``.
    * ``assert_no_gc_cycles()`` returns a no-op context manager so the
      ``with assert_no_gc_cycles():`` form works instead of failing on
      ``None``.
    * A non-callable first argument yields ``None``.

    No garbage-collector inspection is performed.
    """
    import contextlib

    if not args:
        return contextlib.nullcontext()
    if callable(args[0]):
        return args[0](*args[1:], **kwargs)
    return None
def assert_no_warnings(*args, **kwargs):
    """Best-effort stand-in that runs the callable without checking warnings.

    * ``assert_no_warnings(func, *a, **kw)`` returns ``func(*a, **kw)``.
    * ``assert_no_warnings()`` returns a no-op context manager so the
      ``with assert_no_warnings():`` form works instead of failing on
      ``None``.
    * A non-callable first argument yields ``None``.

    Emitted warnings are not recorded or inspected.
    """
    import contextlib

    if not args:
        return contextlib.nullcontext()
    if callable(args[0]):
        return args[0](*args[1:], **kwargs)
    return None
def assert_raises_regex(exc_type, regex, callable_, *args, **kwargs):
    """Require *callable_* to raise *exc_type* with a message matching *regex*.

    Parameters
    ----------
    exc_type : type or tuple of types
        Exception class, or tuple of classes, that is expected.
    regex : str or compiled pattern
        Pattern searched (``re.search``) in ``str(exception)``.
    callable_ : callable
        Function under test; called with ``*args, **kwargs``.

    Raises
    ------
    AssertionError
        When nothing is raised, or the message does not match *regex*.
        Exceptions of any other type propagate unchanged.
    """
    import re

    try:
        callable_(*args, **kwargs)
    except exc_type as e:
        if not re.search(regex, str(e)):
            raise AssertionError(
                f"exception message {str(e)!r} did not match regex {regex!r}"
            )
        return
    # A tuple of exception classes has no ``__name__``; describe each member
    # so the failure report itself cannot crash.
    if isinstance(exc_type, tuple):
        names = ", ".join(getattr(t, "__name__", repr(t)) for t in exc_type)
        expected = f"({names})"
    else:
        expected = exc_type.__name__
    raise AssertionError(f"expected {expected}, none raised")
def assert_string_equal(actual, desired):
    """Raise ``AssertionError`` showing both values when ``actual != desired``."""
    mismatch = actual != desired
    if not mismatch:
        return
    raise AssertionError(f"strings differ:\n actual: {actual!r}\n desired: {desired!r}")
def break_cycles():
    """Do nothing; present only so callers have a function to invoke."""
    return None
def build_err_msg(arrays, err_msg, header="Arrays are not equal",
                  verbose=True, names=("ACTUAL", "DESIRED"), precision=8):
    """Assemble a multi-line failure message.

    The result is *header*, then *err_msg* if it is truthy, then one
    ``NAME: repr(value)`` line per ``(name, array)`` pair, joined with
    newlines. Pairing stops at the shorter of *names* and *arrays*.
    *verbose* and *precision* are accepted but not used.
    """
    del verbose, precision  # accepted but intentionally unused

    lines = [header, err_msg] if err_msg else [header]
    lines.extend(f"{label}: {value!r}" for label, value in zip(names, arrays))
    return "\n".join(lines)
def check_support_sve():
    """Always report ``False``; no hardware detection is performed."""
    supported = False
    return supported
class clear_and_catch_warnings:
    """No-op context manager that accepts ``record`` and ``modules``.

    Entering yields the manager itself; exiting never suppresses an
    exception. Neither constructor argument is stored or used.
    """

    def __init__(self, record=False, modules=()):
        del record, modules  # accepted but intentionally unused

    def __enter__(self):
        manager = self
        return manager

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets any in-flight exception propagate.
        suppress = False
        return suppress
def decorate_methods(cls, decorator, testmatch=None):
_ = testmatch
for name in dir(cls):
if name.startswith("test_"):
setattr(cls, name, decorator(getattr(cls, name)))
def extbuild(*args, **kwargs):
    """Always raise ``NotImplementedError``; arguments are ignored."""
    reason = "extbuild requires a C toolchain"
    raise NotImplementedError(reason)
def jiffies(_proc_pid_stat="", _load_time=()):
    """Always return ``0``; both arguments are ignored."""
    elapsed = 0
    return elapsed
def measure(code_str, times=1, label=""):
    """Return the seconds taken to execute *code_str* *times* times.

    The snippet runs in the caller's namespace, so it can refer to the
    caller's variables. It is compiled once before timing starts, and a
    monotonic clock is used so wall-clock adjustments cannot skew the result.

    Parameters
    ----------
    code_str : str or code object
        Source to execute. It is passed to ``exec``, so it must come from a
        trusted source.
    times : int
        Number of executions.
    label : str
        Used only in the filename attached to the compiled snippet, which
        appears in tracebacks.

    Returns
    -------
    float
        Elapsed time in seconds.
    """
    import sys
    import time

    # Frame 1 is whoever called measure(); execute in that namespace.
    caller = sys._getframe(1)
    caller_globals, caller_locals = caller.f_globals, caller.f_locals

    if isinstance(code_str, str):
        code = compile(code_str, f"<measure {label}>", "exec")
    else:
        # Already-compiled code objects are executed as given.
        code = code_str

    start = time.perf_counter()
    for _ in range(times):
        exec(code, caller_globals, caller_locals)
    return time.perf_counter() - start
def memusage(_proc_pid_stat=""):
    """Always return ``0``; the argument is ignored."""
    usage = 0
    return usage
class overrides:
    """Namespace holder exposing an empty ``ARRAY_FUNCTIONS`` set."""

    # Starts empty; nothing in this module adds to it.
    ARRAY_FUNCTIONS = set()
def print_assert_equal(test_string, actual, desired):
    """Print *test_string*, then check *actual* against *desired*.

    The string is printed unconditionally before ``assert_equal`` runs.
    """
    banner = test_string
    print(banner)
    assert_equal(actual, desired)
def run_threaded(func, n_threads=2, args=(), kwargs=None):
    """Run ``func(*args, **kwargs)`` concurrently in *n_threads* threads.

    All threads are started, then joined. An exception raised inside a
    worker is captured and the first one recorded is re-raised in the
    calling thread after every worker has finished. Without this a failing
    check inside *func* would be lost and the caller would pass.

    Parameters
    ----------
    func : callable
        Function executed by each thread.
    n_threads : int
        Number of threads to start.
    args : tuple
        Positional arguments for *func*.
    kwargs : dict or None
        Keyword arguments for *func*; ``None`` means none.

    Raises
    ------
    Exception
        The first exception captured from a worker thread, if any.
    """
    import threading

    call_kwargs = kwargs or {}
    errors = []

    def _worker():
        try:
            func(*args, **call_kwargs)
        except Exception as exc:  # collected here, re-raised by the caller
            errors.append(exc)

    threads = [threading.Thread(target=_worker) for _ in range(n_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if errors:
        raise errors[0]
def rundocs(filename=None, raise_on_error=True):
    """Always return ``True``; no doctests are run and arguments are ignored."""
    del filename, raise_on_error  # accepted but intentionally unused
    return True
def runstring(astr, dict_):
    """Execute *astr* with *dict_* as its global namespace.

    The code is passed to ``exec`` and must therefore be trusted. Names
    bound by the code end up in *dict_*.
    """
    namespace = dict_
    exec(astr, namespace)
class suppress_warnings:
    """No-op context manager exposing ``filter`` and ``record`` methods.

    Nothing is filtered or recorded: ``filter`` returns ``None`` and
    ``record`` returns a fresh empty list. Entering yields the manager
    itself; exiting never suppresses an exception.
    """

    def __init__(self, forwarding_rule="always"):
        del forwarding_rule  # accepted but intentionally unused

    def __enter__(self):
        manager = self
        return manager

    def __exit__(self, exc_type, exc, tb):
        # Returning False lets any in-flight exception propagate.
        suppress = False
        return suppress

    def filter(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        return None

    def record(self, *args, **kwargs):
        """Accept any arguments and return a new empty list."""
        captured = []
        return captured
def tempdir(*args, **kwargs):
    """Create and return a ``tempfile.TemporaryDirectory``.

    All arguments are forwarded to the ``TemporaryDirectory`` constructor.
    """
    from tempfile import TemporaryDirectory

    directory = TemporaryDirectory(*args, **kwargs)
    return directory
def temppath(*args, **kwargs):
    """Create and return a ``tempfile.NamedTemporaryFile``.

    All arguments are forwarded to ``NamedTemporaryFile``. The returned
    value is the open file object, not a path string.

    NOTE(review): the name suggests callers may expect a path rather than a
    file object. Confirm against call sites before changing the return type.
    """
    from tempfile import NamedTemporaryFile

    handle = NamedTemporaryFile(*args, **kwargs)
    return handle
def test(*args, **kwargs):
    """Always return ``True``; no tests are run and arguments are ignored."""
    del args, kwargs  # accepted but intentionally unused
    return True
# Public names exported by ``from <this module> import *``.
# Every entry must be defined above; private helpers (leading underscore)
# are deliberately left out.
__all__ = [
    # Assertion helpers and utilities
    "assert_equal",
    "assert_array_equal",
    "assert_allclose",
    "assert_array_almost_equal",
    "assert_almost_equal",
    "assert_approx_equal",
    "assert_array_less",
    "assert_raises",
    "assert_warns",
    "assert_",
    "assert_array_compare",
    "assert_array_almost_equal_nulp",
    "assert_array_max_ulp",
    "assert_no_gc_cycles",
    "assert_no_warnings",
    "assert_raises_regex",
    "assert_string_equal",
    "break_cycles",
    "build_err_msg",
    "check_support_sve",
    "clear_and_catch_warnings",
    "decorate_methods",
    "extbuild",
    "jiffies",
    "measure",
    "memusage",
    "overrides",
    "print_assert_equal",
    "run_threaded",
    "rundocs",
    "runstring",
    "suppress_warnings",
    "tempdir",
    "temppath",
    "test",
    # Exception types and the TestCase base class
    "IgnoreException",
    "KnownFailureException",
    "SkipTest",
    "TestCase",
    # Platform / build flags and module-level settings
    "IS_64BIT",
    "IS_PYPY",
    "IS_PYSTON",
    "IS_WASM",
    "IS_MUSL",
    "IS_EDITABLE",
    "IS_INSTALLED",
    "NOGIL_BUILD",
    "HAS_REFCOUNT",
    "HAS_LAPACK64",
    "BLAS_SUPPORTS_FPE",
    "NUMPY_ROOT",
    "verbose",
]