import gc
import os
import resource
import tempfile
import numpy as np
import pytest
from ztensor import Reader, Writer
from ztensor._ztensor import _Reader as RustReader
# Probe for an optional torch install; tests that need it are skipped otherwise.
TORCH_AVAILABLE = True
try:
    import torch  # noqa: F401
except ImportError:
    TORCH_AVAILABLE = False
def _get_rss_kb():
return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
def _walk_base_chain(arr):
obj = arr
while hasattr(obj, "base") and obj.base is not None:
obj = obj.base
yield obj
@pytest.fixture
def zt_file():
    """Write three small float32 tensors to a temp .zt file.

    Yields ``(path, tensors)`` and removes the file afterwards.
    """
    handle, path = tempfile.mkstemp(suffix=".zt")
    os.close(handle)
    expected = {
        "small": np.arange(100, dtype=np.float32),
        "matrix": np.random.randn(64, 128).astype(np.float32),
        "vector": np.linspace(-1.0, 1.0, 256, dtype=np.float32),
    }
    with Writer(path) as writer:
        for key, value in expected.items():
            writer.add(key, value)
    yield path, expected
    if os.path.exists(path):
        os.remove(path)
@pytest.fixture
def large_zt_file():
    """Write three ~MB-scale float32 tensors to a temp .zt file.

    Yields ``(path, tensors)`` and removes the file afterwards.
    """
    handle, path = tempfile.mkstemp(suffix=".zt")
    os.close(handle)
    expected = {
        "big": np.random.randn(512, 512).astype(np.float32),
        "bigger": np.random.randn(256, 1024).astype(np.float32),
        "biggest": np.random.randn(1024, 512).astype(np.float32),
    }
    with Writer(path) as writer:
        for key, value in expected.items():
            writer.add(key, value)
    yield path, expected
    if os.path.exists(path):
        os.remove(path)
class TestBaseChain:
    """Zero-copy arrays should keep the Rust reader alive via numpy's .base chain."""

    def test_zerocopy_base_chain_contains_reader(self, zt_file):
        """A single zero-copy read must have a _Reader somewhere in .base."""
        path, _ = zt_file
        reader = Reader(path)
        arr = reader["small"]
        found_reader = any(isinstance(b, RustReader) for b in _walk_base_chain(arr))
        assert found_reader, (
            f"Reader not found in base chain. "
            f"Chain types: {[type(b).__name__ for b in _walk_base_chain(arr)]}"
        )

    def test_copy_true_base_chain_excludes_reader(self, zt_file):
        """copy=True arrays must be independent — no _Reader in .base."""
        path, _ = zt_file
        reader = Reader(path)
        tensors = reader.read_numpy(reader.keys(), copy=True)
        for name, arr in tensors.items():
            found_reader = any(
                isinstance(b, RustReader) for b in _walk_base_chain(arr)
            )
            assert (
                not found_reader
            ), f"copy=True array '{name}' has Reader in base chain"

    def test_load_numpy_zerocopy_base_chain(self, zt_file):
        """Bulk copy=False reads must each anchor the _Reader via .base."""
        path, _ = zt_file
        reader = Reader(path)
        tensors = reader.read_numpy(reader.keys(), copy=False)
        for name, arr in tensors.items():
            found_reader = any(
                isinstance(b, RustReader) for b in _walk_base_chain(arr)
            )
            assert (
                found_reader
            ), f"copy=False array '{name}' missing _Reader in base chain"

    def test_view_preserves_base_chain(self, zt_file):
        """Slicing views must stay valid after the Reader and parent are dropped."""
        path, expected = zt_file
        reader = Reader(path)
        arr = reader["matrix"]
        # BUG FIX: the original fused these two assignments onto a single
        # line ("view = ... view2 = ..."), which is a SyntaxError.
        view = arr[10:20, ::2]
        view2 = view[:, :32]
        del reader
        del arr
        gc.collect()
        np.testing.assert_array_equal(view2, expected["matrix"][10:20, ::2][:, :32])

    def test_transpose_preserves_base_chain(self, zt_file):
        """A transpose view must stay valid after the Reader is dropped."""
        path, expected = zt_file
        reader = Reader(path)
        arr = reader["matrix"]
        transposed = arr.T
        del reader
        del arr
        gc.collect()
        np.testing.assert_array_equal(transposed, expected["matrix"].T)
class TestLifetime:
    """Zero-copy arrays must remain readable after their Reader is deleted."""

    def test_single_array_survives(self, zt_file):
        """One zero-copy array outlives the Reader that produced it."""
        path, expected = zt_file
        rd = Reader(path)
        small = rd["small"]
        del rd
        gc.collect()
        np.testing.assert_array_equal(small, expected["small"])

    def test_bulk_load_survives(self, zt_file):
        """A bulk copy=False dict of arrays outlives the Reader."""
        path, expected = zt_file
        rd = Reader(path)
        loaded = rd.read_numpy(rd.keys(), copy=False)
        del rd
        gc.collect()
        for key, reference in expected.items():
            np.testing.assert_array_equal(loaded[key], reference)

    def test_partial_deletion_order(self, zt_file):
        """Dropping one array then the Reader must not invalidate siblings."""
        path, expected = zt_file
        rd = Reader(path)
        a_small, a_matrix, a_vector = rd["small"], rd["matrix"], rd["vector"]
        del a_matrix
        del rd
        gc.collect()
        np.testing.assert_array_equal(a_small, expected["small"])
        np.testing.assert_array_equal(a_vector, expected["vector"])

    def test_mixed_single_and_bulk_reads(self, zt_file):
        """Single-key and bulk reads from the same Reader both survive it."""
        path, expected = zt_file
        rd = Reader(path)
        one = rd["small"]
        many = rd.read_numpy(rd.keys(), copy=False)
        del rd
        gc.collect()
        np.testing.assert_array_equal(one, expected["small"])
        for key, reference in expected.items():
            np.testing.assert_array_equal(many[key], reference)
class TestCOW:
    """Writes to zero-copy arrays must be copy-on-write, never hitting the file."""

    def test_write_triggers_cow_not_corruption(self, zt_file):
        """Mutating a zero-copy array must leave the on-disk bytes untouched."""
        path, expected = zt_file
        with open(path, "rb") as f:
            original_bytes = f.read()
        rd = Reader(path)
        arr = rd["small"]
        assert arr.flags[
            "WRITEABLE"
        ], "Zero-copy array from .zt file should be writable (MAP_PRIVATE)"
        # Overwrite the whole array; the write should land in private pages.
        arr[:] = 999.0
        np.testing.assert_array_equal(arr, np.full(100, 999.0, dtype=np.float32))
        del rd
        del arr
        gc.collect()
        with open(path, "rb") as f:
            after_bytes = f.read()
        assert (
            original_bytes == after_bytes
        ), "File corrupted by writing to zero-copy array"
        # A fresh Reader must still see the original data.
        rd2 = Reader(path)
        np.testing.assert_array_equal(rd2["small"], expected["small"])

    def test_two_readers_cow_isolation(self, zt_file):
        """A write through one Reader's array must not leak into another's."""
        path, expected = zt_file
        rd1, rd2 = Reader(path), Reader(path)
        first, second = rd1["small"], rd2["small"]
        if first.flags["WRITEABLE"]:
            first[:] = -1.0
        np.testing.assert_array_equal(second, expected["small"])
@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed")
class TestTorchZeroCopy:
    """Torch tensors returned by read_torch must not dangle once the Reader dies."""

    def test_torch_zerocopy_survives_reader_deletion(self, zt_file):
        """copy=False torch tensors stay valid after the Reader is dropped."""
        path, expected = zt_file
        rd = Reader(path)
        loaded = rd.read_torch(rd.keys(), copy=False)
        del rd
        gc.collect()
        for key, reference in expected.items():
            np.testing.assert_array_equal(loaded[key].numpy(), reference)

    def test_torch_copy_true_independent(self, zt_file):
        """copy=True torch tensors are fully independent of the Reader."""
        path, expected = zt_file
        rd = Reader(path)
        loaded = rd.read_torch(rd.keys(), copy=True)
        del rd
        gc.collect()
        for key, reference in expected.items():
            np.testing.assert_array_equal(loaded[key].numpy(), reference)

    def test_torch_zerocopy_numpy_roundtrip(self, zt_file):
        """tensor.numpy() views of zero-copy tensors remain correct."""
        path, expected = zt_file
        rd = Reader(path)
        loaded = rd.read_torch(rd.keys(), copy=False)
        del rd
        gc.collect()
        for key, reference in expected.items():
            as_numpy = loaded[key].numpy()
            np.testing.assert_array_equal(as_numpy, reference)
class TestFileDeletion:
    """mmap'd data must stay readable after the backing file is unlinked."""

    def test_file_deleted_while_arrays_alive(self):
        """Unlinking the file must not invalidate live zero-copy arrays."""
        fd, path = tempfile.mkstemp(suffix=".zt")
        os.close(fd)
        try:
            data = np.arange(1000, dtype=np.float32)
            with Writer(path) as w:
                w.add("tensor", data)
            reader = Reader(path)
            arr = reader["tensor"]
            os.remove(path)
            assert not os.path.exists(path)
            # The mapping, not the directory entry, backs the array.
            np.testing.assert_array_equal(arr, data)
            del reader
            gc.collect()
            np.testing.assert_array_equal(arr, data)
        finally:
            # ROBUSTNESS FIX: the original leaked the temp file if Writer or
            # Reader raised before the in-test os.remove() ran.
            if os.path.exists(path):
                os.remove(path)
class TestMemoryLeaks:
    """RSS should stay bounded across repeated open/read/close cycles."""

    def _growth_mb(self, path, copy):
        """Return RSS growth in MB over 20 open/read/close cycles.

        Extracted helper — the original duplicated this warm-up plus
        measurement loop verbatim in both tests. Three warm-up cycles run
        first so allocator/mmap caches don't count as growth.
        """
        for _ in range(3):
            r = Reader(path)
            t = r.read_numpy(r.keys(), copy=copy)
            del t, r
            gc.collect()
        rss_before = _get_rss_kb()
        for _ in range(20):
            reader = Reader(path)
            tensors = reader.read_numpy(reader.keys(), copy=copy)
            del tensors
            del reader
            gc.collect()
        rss_after = _get_rss_kb()
        return (rss_after - rss_before) / 1024

    def test_rss_stable_over_open_close_cycles(self, large_zt_file):
        """Zero-copy (mmap) reads must not leak mappings across cycles."""
        path, _ = large_zt_file
        per_file_mb = os.path.getsize(path) / (1024 * 1024)
        growth_mb = self._growth_mb(path, copy=False)
        assert growth_mb < per_file_mb * 2, (
            f"RSS grew by {growth_mb:.1f} MB over 20 iterations "
            f"(file is {per_file_mb:.1f} MB). Possible mmap leak."
        )

    def test_rss_stable_copy_true(self, large_zt_file):
        """copy=True reads must free their heap buffers across cycles."""
        path, _ = large_zt_file
        per_file_mb = os.path.getsize(path) / (1024 * 1024)
        growth_mb = self._growth_mb(path, copy=True)
        assert growth_mb < per_file_mb * 2, (
            f"RSS grew by {growth_mb:.1f} MB over 20 copy=True iterations "
            f"(file is {per_file_mb:.1f} MB). Possible leak in copy path."
        )
class TestEdgeCases:
    """Degenerate tensors and unusual reader configurations."""

    def test_empty_tensor_zerocopy(self):
        """A zero-length tensor round-trips and survives Reader deletion."""
        fd, path = tempfile.mkstemp(suffix=".zt")
        os.close(fd)
        try:
            no_elements = np.array([], dtype=np.float32)
            with Writer(path) as writer:
                writer.add("empty", no_elements)
            rd = Reader(path)
            loaded = rd["empty"]
            del rd
            gc.collect()
            assert loaded.size == 0
        finally:
            if os.path.exists(path):
                os.remove(path)

    def test_scalar_tensor_zerocopy(self):
        """A 0-d tensor round-trips and survives Reader deletion."""
        fd, path = tempfile.mkstemp(suffix=".zt")
        os.close(fd)
        try:
            with Writer(path) as writer:
                writer.add("scalar", np.array(42.0, dtype=np.float32))
            rd = Reader(path)
            loaded = rd["scalar"]
            del rd
            gc.collect()
            np.testing.assert_equal(float(loaded.flat[0]), 42.0)
        finally:
            if os.path.exists(path):
                os.remove(path)

    def test_concurrent_readers_same_file(self, zt_file):
        """Two Readers on one file: arrays stay valid as each Reader dies."""
        path, expected = zt_file
        rd_a, rd_b = Reader(path), Reader(path)
        arr_a, arr_b = rd_a["small"], rd_b["small"]
        del rd_a
        gc.collect()
        np.testing.assert_array_equal(arr_b, expected["small"])
        np.testing.assert_array_equal(arr_a, expected["small"])
        del rd_b
        gc.collect()
        np.testing.assert_array_equal(arr_a, expected["small"])
        np.testing.assert_array_equal(arr_b, expected["small"])

    def test_multiple_dtypes_zerocopy(self):
        """Mixed-dtype files load correctly with copy=False."""
        fd, path = tempfile.mkstemp(suffix=".zt")
        os.close(fd)
        try:
            samples = {
                "f32": np.array([1.0, 2.0, 3.0], dtype=np.float32),
                "f64": np.array([4.0, 5.0], dtype=np.float64),
                "i32": np.array([10, 20, 30, 40], dtype=np.int32),
                "u8": np.array([0, 128, 255], dtype=np.uint8),
            }
            with Writer(path) as writer:
                for key, value in samples.items():
                    writer.add(key, value)
            rd = Reader(path)
            loaded = rd.read_numpy(rd.keys(), copy=False)
            del rd
            gc.collect()
            for key, reference in samples.items():
                np.testing.assert_array_equal(loaded[key], reference)
        finally:
            if os.path.exists(path):
                os.remove(path)