from __future__ import annotations
import asyncio
import inspect
from typing import Any
import pytest
import oxillama_py
class MockEngine:
    """Synchronous engine stand-in that records the calls it receives.

    Both generation methods return the canned ``response`` text; the
    streaming variant additionally feeds it to the callback one character
    at a time.
    """

    def __init__(self, response: str = "hello") -> None:
        """Create a mock that always answers with *response*."""
        # Text returned by generate() and generate_streaming().
        self.response = response
        # One dict per generate() call: prompt, max_tokens and extra kwargs.
        self.generate_calls: list[dict[str, Any]] = []
        # One dict per generate_streaming() call; the callback is not stored.
        self.stream_calls: list[dict[str, Any]] = []

    def generate(self, prompt: str, max_tokens: int = 128, **kwargs: Any) -> str:
        """Record the call arguments and return the canned response."""
        self.generate_calls.append(
            {"prompt": prompt, "max_tokens": max_tokens, **kwargs}
        )
        return self.response

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Any = None,
        **kwargs: Any,
    ) -> str:
        """Record the call, stream the response to *callback*, return it.

        Args:
            prompt: Prompt text; only recorded.
            max_tokens: Token limit; only recorded.
            callback: Optional callable invoked once per character of
                ``self.response``, in order. Skipped when ``None``.
            **kwargs: Extra options, recorded alongside the prompt.

        Returns:
            The full canned response.
        """
        self.stream_calls.append(
            {"prompt": prompt, "max_tokens": max_tokens, **kwargs}
        )
        # The callback never changes mid-loop, so test it once; a str can be
        # iterated directly without first copying it into a list.
        if callback is not None:
            for tok in self.response:
                callback(tok)
        return self.response
class ErrorEngine:
    """Engine stand-in whose generation methods always raise ``RuntimeError``."""

    def generate(self, prompt: str, max_tokens: int = 128, **kwargs: Any) -> str:
        """Always fail with ``RuntimeError("generate failed")``."""
        failure = RuntimeError("generate failed")
        raise failure

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Any = None,
        **kwargs: Any,
    ) -> str:
        """Always fail with ``RuntimeError("generate_streaming failed")``."""
        failure = RuntimeError("generate_streaming failed")
        raise failure
class MultiTokenEngine:
    """Engine stand-in that emits a fixed sequence of tokens."""

    def __init__(self, tokens: list[str]) -> None:
        """Remember *tokens* as the sequence to emit."""
        self.tokens = tokens

    def generate(self, prompt: str, max_tokens: int = 128, **kwargs: Any) -> str:
        """Return all tokens joined into one string."""
        joined = "".join(self.tokens)
        return joined

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Any = None,
        **kwargs: Any,
    ) -> str:
        """Pass each token to *callback* (if given), then return the joined text."""
        for piece in self.tokens:
            if callback is None:
                continue
            callback(piece)
        full_text = "".join(self.tokens)
        return full_text
def test_async_engine_class_exists() -> None:
    """Check that ``oxillama_py`` exposes an ``AsyncEngine`` attribute."""
    exported = hasattr(oxillama_py, "AsyncEngine")
    assert exported, "AsyncEngine not found in oxillama_py"
def test_async_engine_class_is_type() -> None:
    """Check that ``oxillama_py.AsyncEngine`` is a class object."""
    candidate = oxillama_py.AsyncEngine
    is_class = isinstance(candidate, type)
    assert is_class, f"Expected a type, got {type(oxillama_py.AsyncEngine)}"
def test_async_engine_init_exists() -> None:
    """Check that ``AsyncEngine`` has a callable ``__init__``."""
    engine_cls = oxillama_py.AsyncEngine
    assert hasattr(engine_cls, "__init__")
    initializer = engine_cls.__init__
    assert callable(initializer)
def test_async_engine_has_generate() -> None:
    """Check that ``AsyncEngine`` defines a ``generate`` attribute."""
    has_method = hasattr(oxillama_py.AsyncEngine, "generate")
    assert has_method, "AsyncEngine.generate is missing"
def test_async_engine_has_stream() -> None:
    """Check that ``AsyncEngine`` defines a ``stream`` attribute."""
    has_method = hasattr(oxillama_py.AsyncEngine, "stream")
    assert has_method, "AsyncEngine.stream is missing"
def test_async_engine_generate_is_coroutine_function() -> None:
    """Check that ``AsyncEngine.generate`` is a coroutine function.

    ``inspect.iscoroutinefunction`` is used instead of the ``asyncio``
    helper of the same name, which is deprecated as of Python 3.14 and
    would emit a ``DeprecationWarning`` there.
    """
    assert inspect.iscoroutinefunction(oxillama_py.AsyncEngine.generate), (
        "AsyncEngine.generate is not an async coroutine function"
    )
def test_async_engine_stream_is_async_generator_function() -> None:
    """Check that ``AsyncEngine.stream`` is an async generator function."""
    stream_fn = oxillama_py.AsyncEngine.stream
    is_async_gen = inspect.isasyncgenfunction(stream_fn)
    assert is_async_gen, "AsyncEngine.stream is not an async generator function"
def test_async_engine_in_all() -> None:
    """Check that ``"AsyncEngine"`` is listed in ``oxillama_py.__all__``."""
    public_names = oxillama_py.__all__
    listed = "AsyncEngine" in public_names
    assert listed, "AsyncEngine is missing from oxillama_py.__all__"
def test_async_engine_accepts_mock_engine() -> None:
    """Check that ``AsyncEngine`` can be constructed around a ``MockEngine``."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    assert wrapper is not None
def test_async_engine_stores_engine_reference() -> None:
    """Check that ``_engine`` is the very object passed to the constructor."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    stored = wrapper._engine
    assert stored is backend
def test_async_engine_creates_thread_pool() -> None:
    """Check that a new ``AsyncEngine`` has a ``ThreadPoolExecutor`` in ``_pool``."""
    from concurrent.futures import ThreadPoolExecutor

    wrapper = oxillama_py.AsyncEngine(MockEngine())
    assert hasattr(wrapper, "_pool"), "_pool attribute missing"
    pool = wrapper._pool
    assert isinstance(pool, ThreadPoolExecutor), "_pool is not a ThreadPoolExecutor"
def test_async_engine_generate_returns_string() -> None:
    """Check that ``generate`` resolves to the wrapped engine's response text."""
    backend = MockEngine("hello world")
    wrapper = oxillama_py.AsyncEngine(backend)
    answer = asyncio.run(wrapper.generate("test prompt"))
    assert answer == "hello world"
def test_async_engine_generate_passes_prompt() -> None:
    """Check that one ``generate`` call reaches the engine once, with the prompt."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("my prompt"))
    recorded = backend.generate_calls
    assert len(recorded) == 1
    assert recorded[0]["prompt"] == "my prompt"
def test_async_engine_generate_passes_max_tokens() -> None:
    """Check that an explicit ``max_tokens`` is forwarded to the engine."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("prompt", max_tokens=256))
    first_call = backend.generate_calls[0]
    assert first_call["max_tokens"] == 256
def test_async_engine_generate_default_max_tokens() -> None:
    """Check that the engine sees ``max_tokens == 512`` when none is given."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x"))
    first_call = backend.generate_calls[0]
    assert first_call["max_tokens"] == 512
def test_async_engine_generate_passes_temperature() -> None:
    """Check that ``temperature`` is forwarded to the engine (within 1e-6)."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x", temperature=0.5))
    forwarded = backend.generate_calls[0]["temperature"]
    deviation = abs(forwarded - 0.5)
    assert deviation < 1e-6
def test_async_engine_generate_omits_none_temperature() -> None:
    """Check that ``temperature=None`` is not forwarded to the engine."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x", temperature=None))
    first_call = backend.generate_calls[0]
    assert "temperature" not in first_call
def test_async_engine_generate_passes_top_p() -> None:
    """Check that ``top_p`` is forwarded to the engine (within 1e-6)."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x", top_p=0.9))
    forwarded = backend.generate_calls[0]["top_p"]
    deviation = abs(forwarded - 0.9)
    assert deviation < 1e-6
def test_async_engine_generate_passes_top_k() -> None:
    """Check that ``top_k`` is forwarded to the engine unchanged."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x", top_k=40))
    first_call = backend.generate_calls[0]
    assert first_call["top_k"] == 40
def test_async_engine_generate_passes_seed() -> None:
    """Check that ``seed`` is forwarded to the engine unchanged."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x", seed=42))
    first_call = backend.generate_calls[0]
    assert first_call["seed"] == 42
def test_async_engine_generate_passes_kwargs() -> None:
    """Check that an arbitrary extra keyword argument reaches the engine."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    asyncio.run(wrapper.generate("x", custom_flag=True))
    forwarded_flag = backend.generate_calls[0]["custom_flag"]
    assert forwarded_flag is True
def test_async_engine_generate_error_propagation() -> None:
    """Check that an engine ``RuntimeError`` surfaces from ``generate``."""
    failing = oxillama_py.AsyncEngine(ErrorEngine())
    expect_failure = pytest.raises(RuntimeError, match="generate failed")
    with expect_failure:
        asyncio.run(failing.generate("x"))
def test_async_engine_generate_multiple_calls() -> None:
    """Check that three sequential ``generate`` calls each hit the engine once."""
    backend = MockEngine("token")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _run() -> list[str]:
        # Awaited one after another, in prompt order.
        return [await wrapper.generate(prompt) for prompt in ("a", "b", "c")]

    answers = asyncio.run(_run())
    assert answers == ["token", "token", "token"]
    assert len(backend.generate_calls) == 3
def test_async_engine_stream_yields_tokens() -> None:
    """Check that ``stream`` yields the engine's tokens in order."""
    expected = ["h", "e", "l", "l", "o"]
    wrapper = oxillama_py.AsyncEngine(MultiTokenEngine(expected))

    async def _gather() -> list[str]:
        seen: list[str] = []
        async for piece in wrapper.stream("hi"):
            seen.append(piece)
        return seen

    assert asyncio.run(_gather()) == expected
def test_async_engine_stream_concatenated_equals_full_text() -> None:
    """Check that joining everything ``stream`` yields rebuilds the full text."""
    full_text = "hello world"
    wrapper = oxillama_py.AsyncEngine(MultiTokenEngine(list(full_text)))

    async def _gather() -> str:
        pieces: list[str] = []
        async for piece in wrapper.stream("hi"):
            pieces.append(piece)
        return "".join(pieces)

    rebuilt = asyncio.run(_gather())
    assert rebuilt == full_text
def test_async_engine_stream_empty_response() -> None:
    """Check that ``stream`` yields nothing when the engine has no tokens."""
    wrapper = oxillama_py.AsyncEngine(MultiTokenEngine([]))

    async def _gather() -> list[str]:
        seen: list[str] = []
        async for piece in wrapper.stream("hi"):
            seen.append(piece)
        return seen

    assert asyncio.run(_gather()) == []
def test_async_engine_stream_single_token() -> None:
    """Check that a one-token engine produces exactly that one token."""
    wrapper = oxillama_py.AsyncEngine(MultiTokenEngine(["only"]))

    async def _gather() -> list[str]:
        seen: list[str] = []
        async for piece in wrapper.stream("hi"):
            seen.append(piece)
        return seen

    assert asyncio.run(_gather()) == ["only"]
def test_async_engine_stream_error_propagation() -> None:
    """Check that an engine ``RuntimeError`` surfaces while draining ``stream``."""
    failing = oxillama_py.AsyncEngine(ErrorEngine())

    async def _drain() -> None:
        async for _piece in failing.stream("x"):
            continue

    expect_failure = pytest.raises(RuntimeError, match="generate_streaming failed")
    with expect_failure:
        asyncio.run(_drain())
def test_async_engine_stream_passes_max_tokens() -> None:
    """Check that ``stream`` forwards an explicit ``max_tokens`` to the engine."""
    backend = MockEngine("ab")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _piece in wrapper.stream("x", max_tokens=64):
            continue

    asyncio.run(_drain())
    first_call = backend.stream_calls[0]
    assert first_call["max_tokens"] == 64
def test_async_engine_stream_default_max_tokens() -> None:
    """Check that the engine sees ``max_tokens == 512`` when ``stream`` gets none."""
    backend = MockEngine("x")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _piece in wrapper.stream("y"):
            continue

    asyncio.run(_drain())
    first_call = backend.stream_calls[0]
    assert first_call["max_tokens"] == 512
def test_async_engine_stream_passes_temperature() -> None:
    """Check that ``stream`` forwards ``temperature`` to the engine (within 1e-6)."""
    backend = MockEngine("x")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _piece in wrapper.stream("y", temperature=0.8):
            continue

    asyncio.run(_drain())
    forwarded = backend.stream_calls[0]["temperature"]
    deviation = abs(forwarded - 0.8)
    assert deviation < 1e-6
def test_async_engine_stream_omits_none_temperature() -> None:
    """Check that ``stream`` does not forward ``temperature=None`` to the engine."""
    backend = MockEngine("x")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _piece in wrapper.stream("y", temperature=None):
            continue

    asyncio.run(_drain())
    first_call = backend.stream_calls[0]
    assert "temperature" not in first_call
def _native_available() -> bool:
    """Report whether the ``oxillama_py.oxillama_py`` submodule can be imported.

    Returns:
        ``True`` when the import succeeds, ``False`` when it raises
        ``ImportError``. Any other exception raised by the import propagates.
    """
    try:
        # Imported only to probe availability; the bound name is not used.
        import oxillama_py.oxillama_py  # noqa: F401
    except ImportError:
        return False
    return True
# Marker for tests that are skipped when _native_available() returns False.
# The availability check runs once, when this module is imported.
_REQUIRES_NATIVE = pytest.mark.skipif(
    not _native_available(),
    reason="Native extension not built (run `maturin develop`)",
)
@_REQUIRES_NATIVE
def test_engine_has_async_engine_method() -> None:
    """Check that ``oxillama_py.Engine`` defines ``async_engine``."""
    has_method = hasattr(oxillama_py.Engine, "async_engine")
    assert has_method, "Engine.async_engine method is missing"
@_REQUIRES_NATIVE
def test_engine_async_engine_method_callable() -> None:
    """Check that ``oxillama_py.Engine.async_engine`` is callable."""
    method = oxillama_py.Engine.async_engine
    assert callable(method)
@_REQUIRES_NATIVE
def test_engine_async_engine_returns_async_engine_instance() -> None:
    """Check that ``Engine.async_engine()`` returns an ``AsyncEngine``."""
    config = oxillama_py.EngineConfig(model_path="dummy.gguf")
    wrapper = oxillama_py.Engine(config).async_engine()
    is_wrapper = isinstance(wrapper, oxillama_py.AsyncEngine)
    assert is_wrapper, f"async_engine() returned {type(wrapper)}, expected AsyncEngine"
@_REQUIRES_NATIVE
def test_engine_async_engine_wraps_same_instance() -> None:
    """Check that the wrapper's ``_engine`` is the engine that created it."""
    config = oxillama_py.EngineConfig(model_path="dummy.gguf")
    native_engine = oxillama_py.Engine(config)
    wrapper = native_engine.async_engine()
    wrapped = wrapper._engine
    assert wrapped is native_engine, (
        "async_engine()._engine does not point back to the caller"
    )