import os
import pytest
def _native_available() -> bool:
try:
import oxillama_py.oxillama_py return True
except ImportError:
return False
# Skip marker for tests that need only the compiled extension (no model file).
# The availability probe runs once, when this module is imported.
_REQUIRES_NATIVE = pytest.mark.skipif(
    not _native_available(),
    reason="Native extension not built (run `maturin develop`)",
)
@pytest.fixture(scope="session")
def model_path():
    """Return the model file path named by ``OXILLAMA_TEST_MODEL``.

    Skips every dependent test when the variable is unset or empty, or when it
    does not name an existing regular file.
    """
    candidate = os.getenv("OXILLAMA_TEST_MODEL")
    if not candidate:
        pytest.skip("OXILLAMA_TEST_MODEL environment variable not set")
    elif not os.path.isfile(candidate):
        pytest.skip(f"Model file not found: {candidate}")
    return candidate
@pytest.fixture(scope="session")
def engine(model_path):
    """Build one ``Engine`` for the whole session and load the model at ``model_path``.

    The model-backed tests that request this fixture are not decorated with
    ``_REQUIRES_NATIVE``. Without the check below, a set ``OXILLAMA_TEST_MODEL``
    combined with an unbuilt extension would make them error instead of skip.

    Returns:
        The ``oxillama_py.Engine`` after ``load_model()`` has been called.
    """
    if not _native_available():
        pytest.skip("Native extension not built (run `maturin develop`)")
    import oxillama_py

    config = oxillama_py.EngineConfig(model_path=model_path)
    eng = oxillama_py.Engine(config)
    eng.load_model()
    return eng
@_REQUIRES_NATIVE
def test_engine_config_defaults():
    """EngineConfig stores the given path, defaults to 4 threads and no context size."""
    import oxillama_py

    config = oxillama_py.EngineConfig(model_path="model.gguf")
    assert (config.model_path, config.num_threads) == ("model.gguf", 4)
    assert config.context_size is None
@_REQUIRES_NATIVE
def test_engine_config_context_size_override():
    """An explicit ``context_size`` keyword is stored on the config unchanged."""
    import oxillama_py

    requested = 8192
    config = oxillama_py.EngineConfig(model_path="x.gguf", context_size=requested)
    assert config.context_size == requested
@_REQUIRES_NATIVE
def test_sampler_config_defaults():
    """Default SamplerConfig: temperature 0.7, top_k 40, top_p 0.9, no seed, mirostat 0."""
    import oxillama_py

    # Float fields are compared with a tolerance; the binding may store them
    # at reduced precision (NOTE(review): confirm the field width).
    tol = 1e-5
    sampler = oxillama_py.SamplerConfig()
    assert abs(sampler.temperature - 0.7) < tol
    assert sampler.top_k == 40
    assert abs(sampler.top_p - 0.9) < tol
    assert sampler.seed is None
    assert sampler.mirostat == 0
@_REQUIRES_NATIVE
def test_sampler_config_greedy():
    """The ``greedy()`` constructor yields temperature 0.0 and top_k 1."""
    import oxillama_py

    greedy = oxillama_py.SamplerConfig.greedy()
    assert greedy.temperature == 0.0
    assert greedy.top_k == 1
@_REQUIRES_NATIVE
def test_engine_not_loaded_initially():
    """A freshly constructed Engine reports no loaded model before ``load_model()``."""
    import oxillama_py

    fresh = oxillama_py.Engine(oxillama_py.EngineConfig(model_path="nonexistent.gguf"))
    assert not fresh.is_loaded()
@_REQUIRES_NATIVE
def test_engine_hidden_size_none_before_load():
    """``hidden_size()`` is None while no model has been loaded."""
    import oxillama_py

    fresh = oxillama_py.Engine(oxillama_py.EngineConfig(model_path="nonexistent.gguf"))
    assert fresh.hidden_size() is None
@_REQUIRES_NATIVE
def test_engine_tokenize_raises_without_model():
    """Tokenizing on an engine with no loaded model must raise."""
    import oxillama_py

    fresh = oxillama_py.Engine(oxillama_py.EngineConfig(model_path="nonexistent.gguf"))
    # NOTE(review): the concrete exception type is not visible here, so this
    # accepts any Exception; narrow it once the binding's error type is known.
    with pytest.raises(Exception):
        fresh.tokenize("hello")
@_REQUIRES_NATIVE
def test_lora_load_raises_for_missing_file(tmp_path):
    """``Lora.load`` must raise when the adapter file does not exist.

    The path lives inside pytest's per-test ``tmp_path`` directory, which is
    created fresh and empty. The file is therefore guaranteed to be absent, and
    the test does not rely on a POSIX ``/tmp`` directory.
    """
    import oxillama_py

    missing = tmp_path / "oxillama_py_nonexistent_lora_xyz.gguf"
    # NOTE(review): exception type is not visible from here; any Exception is accepted.
    with pytest.raises(Exception):
        # Passed as str, matching the original call's argument type.
        oxillama_py.Lora.load(str(missing))
@_REQUIRES_NATIVE
def test_speculative_config_defaults():
    """A SpeculativeConfig built from target and draft configs defaults to 4 speculative tokens and no seed."""
    import oxillama_py

    target, draft = (
        oxillama_py.EngineConfig(model_path="target.gguf"),
        oxillama_py.EngineConfig(model_path="draft.gguf"),
    )
    spec = oxillama_py.SpeculativeConfig(target, draft)
    assert spec.num_speculative == 4
    assert spec.seed is None
def test_is_loaded(engine):
    """The session ``engine`` fixture reports its model as loaded."""
    loaded = engine.is_loaded()
    assert loaded
def test_tokenize(engine):
    """``tokenize()`` returns a non-empty list of ints."""
    ids = engine.tokenize("Hello world")
    assert isinstance(ids, list)
    assert ids
    assert all(isinstance(token_id, int) for token_id in ids)
def test_decode_token_roundtrip(engine):
    """Decoding the first token produced for "Hello" yields a non-empty str."""
    # NOTE(review): despite the name, this checks only that the first token
    # decodes to non-empty text, not that decode(tokenize(x)) == x.
    first = engine.tokenize("Hello")[0]
    piece = engine.decode_token(first)
    assert isinstance(piece, str)
    assert piece
def test_hidden_size(engine):
    """A loaded engine reports a hidden size that is present and positive."""
    size = engine.hidden_size()
    assert size is not None and size > 0
def test_embed_returns_float_list(engine):
    """``embed()`` returns a non-empty list whose elements are all floats."""
    vector = engine.embed("Hello world")
    assert isinstance(vector, list)
    assert vector
    assert all(isinstance(value, float) for value in vector)
def test_embed_l2_normalised(engine):
    """``embed()`` output has unit Euclidean (L2) norm, to within 1e-3."""
    import math

    vector = engine.embed("test sentence")
    squared = sum(component * component for component in vector)
    norm = math.sqrt(squared)
    assert abs(norm - 1.0) < 1e-3, f"expected unit norm, got {norm}"
def test_embed_different_inputs(engine):
    """Distinct inputs must not produce identical embedding lists."""
    cat_vec, philosophy_vec = engine.embed("cat"), engine.embed("philosophy")
    assert cat_vec != philosophy_vec, "different inputs should yield different embeddings"
def test_generate_returns_string(engine):
    """``generate()`` called with ``max_tokens=32`` returns a str."""
    assert isinstance(engine.generate("Hello", max_tokens=32), str)
def test_generate_streaming_collects_tokens(engine):
    """Pieces passed to the streaming callback, joined in order, equal the returned text."""
    pieces = []

    def on_token(tok):
        # Record each piece in the order the engine delivers it.
        pieces.append(tok)

    full_text = engine.generate_streaming("Hello", max_tokens=32, callback=on_token)
    assert "".join(pieces) == full_text
def test_is_eos(engine):
    """``is_eos()`` returns a bool for token id 0; only the type is checked, not the value."""
    assert isinstance(engine.is_eos(0), bool)