import pytest
import sys
import os
from pathlib import Path
# Directory one level above the directory that contains this file.
project_root = Path(__file__).parent.parent
# Put "<project_root>/python" at the front of sys.path so it is searched
# before every other import location.
sys.path.insert(0, str(project_root / "python"))
class MockEvlib:
    """Stand-in object exposing mock sub-namespaces as attributes.

    Each attribute (``formats``, ``representations``, ``visualization``,
    ``augmentation``, ``processing``) holds a fresh instance of the
    corresponding mock class.
    """

    def __init__(self):
        # (attribute name, mock class) pairs, instantiated in this order.
        submodules = (
            ("formats", MockFormats),
            ("representations", MockRepresentations),
            ("visualization", MockVisualization),
            ("augmentation", MockAugmentation),
            ("processing", MockProcessing),
        )
        for attr_name, mock_cls in submodules:
            setattr(self, attr_name, mock_cls())
class MockFormats:
    """Mock of the ``formats`` namespace: fabricates events instead of doing I/O."""

    def load_events(self, file_path, **kwargs):
        """Return ten random events as ``(xs, ys, ts, ps)`` arrays.

        ``file_path`` and ``kwargs`` are accepted but ignored. x is drawn
        from [0, 640), y from [0, 480), timestamps are evenly spaced over
        [0, 1], and polarities are drawn from {-1, 1}.
        """
        import numpy as np

        count = 10
        x_coords = np.random.randint(0, 640, size=count)
        y_coords = np.random.randint(0, 480, size=count)
        timestamps = np.linspace(0, 1, num=count)
        polarities = np.random.choice([-1, 1], size=count)
        return x_coords, y_coords, timestamps, polarities

    def load_events_filtered(self, file_path, **kwargs):
        """Delegate to :meth:`load_events`; no filtering is applied."""
        events = self.load_events(file_path, **kwargs)
        return events

    def save_events_to_hdf5(self, xs, ys, ts, ps, file_path):
        """Accept the arguments and do nothing (nothing is written)."""
        return None
class MockRepresentations:
    """Mock of the ``representations`` namespace returning random voxel grids."""

    def events_to_voxel_grid(self, xs, ys, ts, ps, n_bins, shape):
        """Return a random float32 grid plus its shape tuple, twice.

        The event arrays are ignored. ``shape`` is unpacked as
        ``(height, width)`` and the grid has shape
        ``(n_bins, height, width)``.
        """
        import numpy as np

        height, width = shape
        first_dims = (n_bins, height, width)
        second_dims = (n_bins, height, width)
        grid = np.random.rand(*first_dims).astype(np.float32)
        return grid, first_dims, second_dims

    def events_to_smooth_voxel_grid(self, xs, ys, ts, ps, n_bins, shape):
        """Delegate to :meth:`events_to_voxel_grid`; no smoothing is applied."""
        result = self.events_to_voxel_grid(xs, ys, ts, ps, n_bins, shape)
        return result
class MockVisualization:
    """Mock of the ``visualization`` namespace."""

    def draw_events_to_image(self, xs, ys, ps, width, height):
        """Return a random ``(height, width)`` array; the events are ignored."""
        import numpy as np

        image = np.random.rand(height, width)
        return image
class MockAugmentation:
    """Mock of the ``augmentation`` namespace.

    Both methods index ``shape`` as ``(x extent, y extent)``: ``shape[0]``
    bounds the x coordinates and ``shape[1]`` bounds the y coordinates.
    """

    def flip_events_x(self, xs, ys, ts, ps, shape):
        """Mirror the x coordinates inside ``[0, shape[0])``.

        Args:
            xs: Array of x coordinates (must support array arithmetic).
            ys, ts, ps: Returned unchanged.
            shape: Sequence whose first element is the x extent.

        Returns:
            Tuple ``(xs_flipped, ys, ts, ps)``.
        """
        # Subtract from the last valid column (extent - 1), not the extent
        # itself; otherwise x == 0 would land one past the last column.
        xs_flipped = (shape[0] - 1) - xs
        return xs_flipped, ys, ts, ps

    def add_random_events(self, xs, ys, ts, ps, n_events, shape):
        """Append ``n_events`` random events to the given event arrays.

        New x values are drawn from ``[0, shape[0])``, y values from
        ``[0, shape[1])``, polarities from {-1, 1}, and timestamps
        uniformly between the smallest and largest existing timestamp.
        When ``ts`` is empty there is no existing time window, so the
        unit interval [0, 1) is used instead of failing.

        Returns:
            Tuple ``(new_xs, new_ys, new_ts, new_ps)`` with the random
            events appended after the original ones.
        """
        import numpy as np

        # min()/max() raise on zero-size arrays, so guard the empty case.
        if ts.size:
            t_low, t_high = ts.min(), ts.max()
        else:
            t_low, t_high = 0.0, 1.0

        new_xs = np.concatenate([xs, np.random.randint(0, shape[0], n_events)])
        new_ys = np.concatenate([ys, np.random.randint(0, shape[1], n_events)])
        new_ts = np.concatenate([ts, np.random.uniform(t_low, t_high, n_events)])
        new_ps = np.concatenate([ps, np.random.choice([-1, 1], n_events)])
        return new_xs, new_ys, new_ts, new_ps
class MockProcessing:
    """Mock of the ``processing`` namespace (no downloads, no inference)."""

    def download_model(self, model_name):
        """Return a fake filesystem path for ``model_name``; nothing is fetched."""
        fake_path = f"/mock/path/to/{model_name}"
        return fake_path

    def events_to_video(self, xs, ys, ts, ps, model_path, width, height):
        """Return a random ``(height, width, 3)`` array; inputs are ignored."""
        import numpy as np

        frame = np.random.rand(height, width, 3)
        return frame
# Shared name -> object mapping. It starts empty, is filled by the
# session-scoped fixture in this file, and is read by the per-test fixture
# and the pytest_runtest_setup hook.
_global_namespace = {}
@pytest.fixture(autouse=True, scope="session")
def setup_global_namespace():
    """Populate the shared namespace once per test session.

    Fills the module-level ``_global_namespace`` dict with:

    * ``plt`` / ``matplotlib`` when matplotlib can be imported. The
      ``Agg`` backend is selected and ``plt.show`` / ``plt.figure`` are
      replaced by do-nothing stand-ins.
    * ``evlib``: the real package if importable, else a ``MockEvlib``.
    * ``np`` / ``numpy`` and ``time``.
    * ``_original_cwd``: the working directory at fixture start.

    The working directory is switched to the directory one level above
    this file's directory for the duration of the session.

    Yields:
        dict: the populated ``_global_namespace``.

    On teardown the previous working directory is restored and the
    original ``plt.show`` / ``plt.figure`` are put back, so the patches
    do not outlive the session.
    """
    import numpy as np
    import time

    # Filled with (plt, original show, original figure) once pyplot is patched.
    pyplot_patch = None
    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        # matplotlib is optional; without it the plotting names are simply
        # absent from the namespace.
        pass
    else:

        def mock_show(*args, **kwargs):
            pass

        class MockFigure:
            def __init__(self, *args, **kwargs):
                pass

            def show(self):
                pass

        def mock_figure(*args, **kwargs):
            return MockFigure(*args, **kwargs)

        # Remember the real functions so teardown can undo the patch.
        pyplot_patch = (plt, plt.show, plt.figure)
        plt.show = mock_show
        plt.figure = mock_figure
        _global_namespace["plt"] = plt
        _global_namespace["matplotlib"] = matplotlib

    try:
        import evlib

        _global_namespace["evlib"] = evlib
    except ImportError:
        _global_namespace["evlib"] = MockEvlib()

    _global_namespace["np"] = np
    _global_namespace["numpy"] = np
    _global_namespace["time"] = time

    repo_root = Path(__file__).parent.parent
    original_cwd = os.getcwd()
    try:
        os.chdir(repo_root)
        _global_namespace["_original_cwd"] = original_cwd
        yield _global_namespace
    finally:
        # Undo process-wide side effects even if the generator is closed early.
        os.chdir(original_cwd)
        if pyplot_patch is not None:
            patched_plt, real_show, real_figure = pyplot_patch
            patched_plt.show = real_show
            patched_plt.figure = real_figure
@pytest.fixture(autouse=True)
def inject_global_namespace(setup_global_namespace):
    """Expose the shared namespace as builtins for the duration of one test.

    Every entry of ``setup_global_namespace`` whose name does not start
    with an underscore is set as an attribute on ``builtins`` before the
    test. Afterwards each such name still present on ``builtins`` is
    reset to its pre-test value, or removed if it did not exist before.
    """
    import builtins

    # Snapshot of builtins taken before anything is injected.
    builtins_before = dict(getattr(builtins, "__dict__", {}))

    for key, obj in setup_global_namespace.items():
        if key.startswith("_"):
            continue
        setattr(builtins, key, obj)

    yield

    # Iterate over a copy of the key list: the loop mutates builtins.
    for key in list(vars(builtins)):
        if key.startswith("_") or key not in setup_global_namespace:
            continue
        if key in builtins_before:
            setattr(builtins, key, builtins_before[key])
        else:
            delattr(builtins, key)
def pytest_configure(config):
    """Register this suite's custom markers on ``config``.

    Each marker description is added as one line of the ``markers`` ini
    option, in the order listed below.
    """
    marker_lines = (
        "docs: marks tests as documentation tests",
        "slow: marks tests as slow",
        "requires_data: marks tests requiring data files",
        "requires_evlib: marks tests requiring evlib",
        "requires_matplotlib: marks tests requiring matplotlib",
    )
    for marker_line in marker_lines:
        config.addinivalue_line("markers", marker_line)
def pytest_runtest_setup(item):
    """Per-test setup hook: tag docs tests and seed the test's globals.

    An item whose path string contains ``"docs"`` gets the ``docs``
    marker. If the item has an ``obj`` with a ``__globals__`` dict, that
    dict is updated with the contents of ``_global_namespace``.
    """
    item_path = str(item.fspath)
    if "docs" in item_path:
        item.add_marker(pytest.mark.docs)

    # Only items backed by an object with __globals__ can be seeded.
    if not hasattr(item, "obj"):
        return
    if not hasattr(item.obj, "__globals__"):
        return
    item.obj.__globals__.update(_global_namespace)