from __future__ import annotations
import builtins
import sys
from typing import Any
import pytest
torch = pytest.importorskip("torch")
from oxillama_py import torch_helper
def _make_mock_engine(logit_data: "list[float]", embed_data: "list[float]") -> type:
import numpy as np
class _MockEngine:
def logits_dlpack(self, text: str, **kwargs: Any): return np.array(logit_data, dtype=np.float32)
def embeddings_dlpack(self, text: str, **kwargs: Any): return np.array(embed_data, dtype=np.float32)
return _MockEngine
def test_logits_torch_method_exists():
    """The real extension's ``Engine`` must have been patched with ``logits_torch``.

    Skipped when ``oxillama_py.Engine`` cannot be imported or is ``None``
    (both treated as "extension not built").
    """
    # The import and the ``except`` clause must be on separate lines; fusing
    # them is a SyntaxError that breaks collection of the whole module.
    try:
        from oxillama_py import Engine
    except ImportError:
        Engine = None
    if Engine is None:
        pytest.skip("oxillama_py extension not built")
    assert hasattr(Engine, "logits_torch"), "Engine.logits_torch not patched"
def test_embeddings_torch_method_exists():
    """The real extension's ``Engine`` must have been patched with ``embeddings_torch``.

    Skipped when ``oxillama_py.Engine`` cannot be imported or is ``None``
    (both treated as "extension not built").
    """
    # The import and the ``except`` clause must be on separate lines; fusing
    # them is a SyntaxError that breaks collection of the whole module.
    try:
        from oxillama_py import Engine
    except ImportError:
        Engine = None
    if Engine is None:
        pytest.skip("oxillama_py extension not built")
    assert hasattr(Engine, "embeddings_torch"), "Engine.embeddings_torch not patched"
def test_patch_engine_class_adds_methods():
    """patch_engine_class must attach both torch accessors to the given class."""
    engine_cls = _make_mock_engine([1.0, 2.0, 3.0], [0.5, 0.5])
    torch_helper.patch_engine_class(engine_cls)
    expectations = (
        ("logits_torch", "logits_torch not added by patch_engine_class"),
        ("embeddings_torch", "embeddings_torch not added by patch_engine_class"),
    )
    for attr_name, failure_message in expectations:
        assert hasattr(engine_cls, attr_name), failure_message
def test_patch_engine_class_is_idempotent():
    """Patching the same class twice must still leave working torch accessors.

    Checking attribute presence alone would also pass if the second patch
    broke the methods (e.g. by wrapping an already-patched method). The
    accessors are therefore also called, and their values checked against the
    mock's data.
    """
    MockCls = _make_mock_engine([1.0], [0.5])
    torch_helper.patch_engine_class(MockCls)
    torch_helper.patch_engine_class(MockCls)
    assert hasattr(MockCls, "logits_torch")
    assert hasattr(MockCls, "embeddings_torch")
    engine = MockCls()
    # 1.0 and 0.5 are exactly representable in float32, so == is safe here.
    assert engine.logits_torch("test").tolist() == [1.0]
    assert engine.embeddings_torch("test").tolist() == [0.5]
def test_logits_torch_returns_tensor():
    """logits_torch must return a torch.Tensor.

    The mock's logits_dlpack yields a NumPy array; the patched method must
    convert it rather than pass it through.
    """
    engine_cls = _make_mock_engine([1.0, 2.0, 3.0], [0.5])
    torch_helper.patch_engine_class(engine_cls)
    tensor = engine_cls().logits_torch("test")
    assert isinstance(tensor, torch.Tensor), (
        f"logits_torch must return torch.Tensor, got {type(tensor)}"
    )
def test_logits_torch_shape_matches():
    """A flat list of N logits must come back as a tensor of shape (N,)."""
    logits = [1.0, 2.0, 3.0, 4.0]
    engine_cls = _make_mock_engine(logits, [0.0])
    torch_helper.patch_engine_class(engine_cls)
    tensor = engine_cls().logits_torch("test")
    expected_shape = (len(logits),)
    assert tensor.shape == expected_shape, (
        f"Expected shape ({len(logits)},), got {tensor.shape}"
    )
def test_logits_torch_dtype_is_float32():
    """logits_torch must keep the float32 dtype of the mock's array."""
    engine_cls = _make_mock_engine([1.0, 2.0], [0.0])
    torch_helper.patch_engine_class(engine_cls)
    observed = engine_cls().logits_torch("test").dtype
    assert observed == torch.float32, f"Expected float32, got {observed}"
def test_logits_torch_values_match_source():
    """logits_torch must reproduce the engine's logits element by element.

    Values are compared with an absolute tolerance of 1e-4. The mock stores
    them as float32, and values such as 3.14 and 999.9 are not exactly
    representable in that type.
    """
    data = [3.14, -1.0, 0.0, 999.9]
    MockCls = _make_mock_engine(data, [0.0])
    torch_helper.patch_engine_class(MockCls)
    engine = MockCls()
    result = engine.logits_torch("test")
    result_list = result.tolist()
    # zip() stops at the shorter input. Without this length check, a truncated
    # or empty tensor would pass the element loop below vacuously.
    assert len(result_list) == len(data), (
        f"Length mismatch: got {len(result_list)}, expected {len(data)}"
    )
    for i, (got, expected) in enumerate(zip(result_list, data)):
        assert abs(got - expected) < 1e-4, (
            f"Value mismatch at index {i}: got {got}, expected {expected}"
        )
def test_embeddings_torch_returns_tensor():
    """embeddings_torch must return a torch.Tensor.

    The mock's embeddings_dlpack yields a NumPy array; the patched method must
    convert it rather than pass it through.
    """
    engine_cls = _make_mock_engine([1.0], [0.1, 0.2, 0.3])
    torch_helper.patch_engine_class(engine_cls)
    tensor = engine_cls().embeddings_torch("test")
    assert isinstance(tensor, torch.Tensor), (
        f"embeddings_torch must return torch.Tensor, got {type(tensor)}"
    )
def test_embeddings_torch_shape_matches():
    """A flat embedding list of length N must come back as shape (N,)."""
    vector = [0.1, 0.2, 0.3, 0.4]
    engine_cls = _make_mock_engine([1.0], vector)
    torch_helper.patch_engine_class(engine_cls)
    tensor = engine_cls().embeddings_torch("test")
    expected_shape = (len(vector),)
    assert tensor.shape == expected_shape, (
        f"Expected shape ({len(vector)},), got {tensor.shape}"
    )
def test_embeddings_torch_dtype_is_float32():
    """embeddings_torch must keep the float32 dtype of the mock's array."""
    engine_cls = _make_mock_engine([1.0], [0.5, 0.5])
    torch_helper.patch_engine_class(engine_cls)
    observed = engine_cls().embeddings_torch("test").dtype
    assert observed == torch.float32, f"Expected float32, got {observed}"
def test_no_torch_raises_helpful_error_for_logits(monkeypatch):
    """logits_torch must raise an ImportError mentioning torch when torch is missing.

    torch's absence is simulated by replacing ``builtins.__import__`` (undone
    by ``monkeypatch`` at teardown). This only has an effect if torch_helper
    imports torch at call time rather than relying on a module-level binding.
    """
    real_import = builtins.__import__

    def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
        # Block the top-level package and every ``torch.*`` submodule.
        # Otherwise e.g. ``import torch.utils.dlpack`` would slip past the
        # simulated absence, because torch is already loaded in this process.
        if name == "torch" or name.startswith("torch."):
            raise ImportError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    MockCls = _make_mock_engine([1.0, 2.0], [0.0])
    torch_helper.patch_engine_class(MockCls)
    engine = MockCls()
    monkeypatch.setattr(builtins, "__import__", mock_import)
    with pytest.raises(ImportError, match="torch"):
        engine.logits_torch("test")
def test_no_torch_raises_helpful_error_for_embeddings(monkeypatch):
    """embeddings_torch must raise an ImportError mentioning torch when torch is missing.

    torch's absence is simulated by replacing ``builtins.__import__`` (undone
    by ``monkeypatch`` at teardown). This only has an effect if torch_helper
    imports torch at call time rather than relying on a module-level binding.
    """
    real_import = builtins.__import__

    def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
        # Block the top-level package and every ``torch.*`` submodule.
        # Otherwise e.g. ``import torch.utils.dlpack`` would slip past the
        # simulated absence, because torch is already loaded in this process.
        if name == "torch" or name.startswith("torch."):
            raise ImportError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    MockCls = _make_mock_engine([1.0], [0.5])
    torch_helper.patch_engine_class(MockCls)
    engine = MockCls()
    monkeypatch.setattr(builtins, "__import__", mock_import)
    with pytest.raises(ImportError, match="torch"):
        engine.embeddings_torch("test")
def test_import_error_message_mentions_pip_install(monkeypatch):
    """The missing-torch ImportError must tell the user to ``pip install torch``.

    torch's absence is simulated by replacing ``builtins.__import__`` (undone
    by ``monkeypatch`` at teardown).
    """
    real_import = builtins.__import__

    def mock_import(name: str, *args: Any, **kwargs: Any) -> Any:
        # Block the top-level package and every ``torch.*`` submodule.
        # Otherwise e.g. ``import torch.utils.dlpack`` would slip past the
        # simulated absence, because torch is already loaded in this process.
        if name == "torch" or name.startswith("torch."):
            raise ImportError("No module named 'torch'")
        return real_import(name, *args, **kwargs)

    MockCls = _make_mock_engine([1.0], [0.0])
    torch_helper.patch_engine_class(MockCls)
    engine = MockCls()
    monkeypatch.setattr(builtins, "__import__", mock_import)
    with pytest.raises(ImportError) as exc_info:
        engine.logits_torch("test")
    assert "pip install torch" in str(exc_info.value), (
        "ImportError must include 'pip install torch' suggestion"
    )
def test_torch_helper_has_expected_functions():
    """torch_helper must expose ``patch_engine_class`` and ``try_patch``."""
    for public_name in ("patch_engine_class", "try_patch"):
        assert hasattr(torch_helper, public_name), (
            f"torch_helper must expose {public_name}"
        )
def test_try_patch_silently_handles_no_engine():
    """try_patch must not raise for a module that has no ``Engine`` attribute."""
    import types

    bare_module = types.ModuleType("fake_oxillama_py")
    # Success criterion: the call below completes without raising.
    torch_helper.try_patch(bare_module)
def test_try_patch_silently_handles_engine_is_none():
    """try_patch must not raise when the module's ``Engine`` is ``None``."""
    import types

    module_with_null_engine = types.ModuleType("fake_oxillama_py")
    setattr(module_with_null_engine, "Engine", None)
    # Success criterion: the call below completes without raising.
    torch_helper.try_patch(module_with_null_engine)
def test_try_patch_patches_engine_when_present():
    """try_patch must add both torch accessors to ``module.Engine`` when it is set."""
    import types

    class FakeEngine:
        pass

    host_module = types.ModuleType("fake_oxillama_py")
    host_module.Engine = FakeEngine
    torch_helper.try_patch(host_module)
    for accessor in ("logits_torch", "embeddings_torch"):
        assert hasattr(FakeEngine, accessor), (
            f"try_patch must add {accessor} to Engine"
        )
def test_torch_helper_lazy_import():
    """Check that torch_helper's public entry points exist and are callable.

    NOTE(review): despite its name, this test does not verify that torch is
    imported lazily. It only checks attribute presence and callability.
    """
    entry_points = ("patch_engine_class", "try_patch")
    for attr_name in entry_points:
        assert hasattr(torch_helper, attr_name)
    for attr_name in entry_points:
        assert callable(getattr(torch_helper, attr_name))