oxillama-py 0.1.3

Python bindings for OxiLLaMa LLM inference engine
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
"""Tests for the AsyncEngine Python class (Track D - OxiLLaMa v0.1.5).

These are pure-Python tests that do NOT require the native extension to be
built.  ``AsyncEngine`` is implemented in ``oxillama_py/__init__.py`` and
wraps any object that exposes ``generate`` / ``generate_streaming`` methods.

Tests cover:
  - Class attributes and coroutine introspection (no execution)
  - Functional tests using mock engine objects with ``asyncio.run``
  - Streaming via ``async for`` with mock generate_streaming callbacks
  - Error propagation from the underlying engine
  - Optional kwarg forwarding (temperature, top_p, top_k, seed)
  - Sentinel / completion signalling in the queue bridge
  - Thread-pool serialisation (single worker)
  - Presence of ``async_engine()`` on the native ``Engine`` class (skipped
    when the extension is not built)
"""

from __future__ import annotations

import asyncio
import inspect
from typing import Any

import pytest

import oxillama_py


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


class MockEngine:
    """Synchronous stand-in engine that records every call it receives.

    Provides the ``generate`` / ``generate_streaming`` pair that
    ``AsyncEngine`` wraps.  The arguments of each call are appended as a dict
    to ``generate_calls`` or ``stream_calls`` so tests can check exactly what
    was forwarded.
    """

    def __init__(self, response: str = "hello") -> None:
        # Text returned by both methods; streamed one character at a time.
        self.response = response
        # One dict per call: ``prompt``, ``max_tokens`` and any extra kwargs.
        self.generate_calls: list[dict[str, Any]] = []
        self.stream_calls: list[dict[str, Any]] = []

    def generate(self, prompt: str, max_tokens: int = 128, **kwargs: Any) -> str:
        """Record the call and return the canned response."""
        record: dict[str, Any] = {"prompt": prompt, "max_tokens": max_tokens}
        record.update(kwargs)
        self.generate_calls.append(record)
        return self.response

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Any = None,
        **kwargs: Any,
    ) -> str:
        """Record the call, emit the response to *callback*, return it whole.

        Each character of ``self.response`` is treated as one token.  When
        *callback* is ``None`` nothing is emitted.  The ``callback`` itself is
        not recorded.
        """
        record: dict[str, Any] = {"prompt": prompt, "max_tokens": max_tokens}
        record.update(kwargs)
        self.stream_calls.append(record)
        if callback is not None:
            for char in self.response:
                callback(char)
        return self.response


class ErrorEngine:
    """Stand-in engine whose every method fails with ``RuntimeError``.

    Lets tests confirm that exceptions raised inside the wrapped engine reach
    the ``AsyncEngine`` caller unchanged.
    """

    def generate(self, prompt: str, max_tokens: int = 128, **kwargs: Any) -> str:
        """Always raise ``RuntimeError("generate failed")``."""
        message = "generate failed"
        raise RuntimeError(message)

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Any = None,
        **kwargs: Any,
    ) -> str:
        """Always raise ``RuntimeError("generate_streaming failed")``.

        The error is raised before *callback* is ever invoked.
        """
        message = "generate_streaming failed"
        raise RuntimeError(message)


class MultiTokenEngine:
    """Stand-in engine that emits a fixed, caller-supplied token sequence."""

    def __init__(self, tokens: list[str]) -> None:
        # Tokens are streamed in this order; their concatenation is the
        # full response text.
        self.tokens = tokens

    def generate(self, prompt: str, max_tokens: int = 128, **kwargs: Any) -> str:
        """Return every token joined into a single string."""
        full_text = "".join(self.tokens)
        return full_text

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Any = None,
        **kwargs: Any,
    ) -> str:
        """Pass each token to *callback* in order, then return the joined text.

        Nothing is emitted when *callback* is ``None``.
        """
        if callback is not None:
            for piece in self.tokens:
                callback(piece)
        return "".join(self.tokens)


# ---------------------------------------------------------------------------
# Structural / introspection tests (no execution)
# ---------------------------------------------------------------------------


def test_async_engine_class_exists() -> None:
    """The top-level ``oxillama_py`` package must export ``AsyncEngine``."""
    exported = hasattr(oxillama_py, "AsyncEngine")
    assert exported, "AsyncEngine not found in oxillama_py"


def test_async_engine_class_is_type() -> None:
    """``AsyncEngine`` must be a real class, not ``None`` or a plain function."""
    candidate = oxillama_py.AsyncEngine
    assert isinstance(candidate, type), f"Expected a type, got {type(candidate)}"


def test_async_engine_init_exists() -> None:
    """AsyncEngine must define its own ``__init__`` method.

    Every class inherits a callable ``object.__init__``, so checking only
    ``hasattr``/``callable`` can never fail.  Also require that the
    initializer is overridden, since ``AsyncEngine`` is constructed with an
    engine argument.
    """
    init = oxillama_py.AsyncEngine.__init__
    assert callable(init)
    assert init is not object.__init__, (
        "AsyncEngine does not override object.__init__"
    )


def test_async_engine_has_generate() -> None:
    """``AsyncEngine`` must provide a ``generate`` member."""
    cls = oxillama_py.AsyncEngine
    assert hasattr(cls, "generate"), "AsyncEngine.generate is missing"


def test_async_engine_has_stream() -> None:
    """``AsyncEngine`` must provide a ``stream`` member."""
    cls = oxillama_py.AsyncEngine
    assert hasattr(cls, "stream"), "AsyncEngine.stream is missing"


def test_async_engine_generate_is_coroutine_function() -> None:
    """AsyncEngine.generate must be an async coroutine function.

    Uses :func:`inspect.iscoroutinefunction` rather than the ``asyncio``
    variant, which is deprecated as of Python 3.14 and emits a
    ``DeprecationWarning`` there.
    """
    assert inspect.iscoroutinefunction(oxillama_py.AsyncEngine.generate), (
        "AsyncEngine.generate is not an async coroutine function"
    )


def test_async_engine_stream_is_async_generator_function() -> None:
    """``AsyncEngine.stream`` must be defined as an ``async def`` generator."""
    stream_fn = oxillama_py.AsyncEngine.stream
    is_async_gen = inspect.isasyncgenfunction(stream_fn)
    assert is_async_gen, "AsyncEngine.stream is not an async generator function"


def test_async_engine_in_all() -> None:
    """``AsyncEngine`` must be declared in the package's public ``__all__``."""
    public_names = oxillama_py.__all__
    assert "AsyncEngine" in public_names, (
        "AsyncEngine is missing from oxillama_py.__all__"
    )


# ---------------------------------------------------------------------------
# Construction tests
# ---------------------------------------------------------------------------


def test_async_engine_accepts_mock_engine() -> None:
    """Any engine-like object, including a plain mock, can be wrapped."""
    wrapped = oxillama_py.AsyncEngine(MockEngine())
    assert wrapped is not None


def test_async_engine_stores_engine_reference() -> None:
    """The wrapped object is kept, by identity, on the private ``_engine``."""
    backend = MockEngine()
    wrapper = oxillama_py.AsyncEngine(backend)
    assert wrapper._engine is backend


def test_async_engine_creates_thread_pool() -> None:
    """Each AsyncEngine owns a private ``ThreadPoolExecutor`` as ``_pool``."""
    from concurrent.futures import ThreadPoolExecutor

    wrapper = oxillama_py.AsyncEngine(MockEngine())
    assert hasattr(wrapper, "_pool"), "_pool attribute missing"
    assert isinstance(wrapper._pool, ThreadPoolExecutor), (
        "_pool is not a ThreadPoolExecutor"
    )


# ---------------------------------------------------------------------------
# generate() functional tests
# ---------------------------------------------------------------------------


def test_async_engine_generate_returns_string() -> None:
    """Awaiting ``generate`` yields exactly what the wrapped engine returned."""
    expected = "hello world"
    wrapper = oxillama_py.AsyncEngine(MockEngine(expected))
    assert asyncio.run(wrapper.generate("test prompt")) == expected


def test_async_engine_generate_passes_prompt() -> None:
    """The prompt given to ``generate`` reaches the wrapped engine verbatim."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("my prompt"))
    calls = backend.generate_calls
    assert len(calls) == 1
    assert calls[0]["prompt"] == "my prompt"


def test_async_engine_generate_passes_max_tokens() -> None:
    """An explicit ``max_tokens`` is forwarded to the engine unchanged."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("prompt", max_tokens=256))
    first_call = backend.generate_calls[0]
    assert first_call["max_tokens"] == 256


def test_async_engine_generate_default_max_tokens() -> None:
    """When ``max_tokens`` is omitted, ``generate`` forwards 512."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("x"))
    first_call = backend.generate_calls[0]
    assert first_call["max_tokens"] == 512


def test_async_engine_generate_passes_temperature() -> None:
    """generate() must forward the temperature kwarg when provided.

    ``pytest.approx`` is the idiomatic float comparison.  If a non-numeric
    value is forwarded, it yields an ordinary assertion failure rather than a
    ``TypeError`` from ``abs(value - 0.5)``.
    """
    mock = MockEngine()
    ae = oxillama_py.AsyncEngine(mock)
    asyncio.run(ae.generate("x", temperature=0.5))
    assert mock.generate_calls[0]["temperature"] == pytest.approx(0.5, abs=1e-6)


def test_async_engine_generate_omits_none_temperature() -> None:
    """``temperature=None`` must be dropped rather than forwarded to the engine."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("x", temperature=None))
    forwarded = backend.generate_calls[0]
    assert "temperature" not in forwarded


def test_async_engine_generate_passes_top_p() -> None:
    """generate() must forward top_p when provided.

    Compared with ``pytest.approx``, the idiomatic float check, instead of a
    hand-rolled ``abs(a - b) < eps``.
    """
    mock = MockEngine()
    ae = oxillama_py.AsyncEngine(mock)
    asyncio.run(ae.generate("x", top_p=0.9))
    assert mock.generate_calls[0]["top_p"] == pytest.approx(0.9, abs=1e-6)


def test_async_engine_generate_passes_top_k() -> None:
    """An explicit ``top_k`` is forwarded to the engine unchanged."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("x", top_k=40))
    first_call = backend.generate_calls[0]
    assert first_call["top_k"] == 40


def test_async_engine_generate_passes_seed() -> None:
    """An explicit ``seed`` is forwarded to the engine unchanged."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("x", seed=42))
    first_call = backend.generate_calls[0]
    assert first_call["seed"] == 42


def test_async_engine_generate_passes_kwargs() -> None:
    """Unrecognised keyword arguments are passed straight through to the engine."""
    backend = MockEngine()
    asyncio.run(oxillama_py.AsyncEngine(backend).generate("x", custom_flag=True))
    first_call = backend.generate_calls[0]
    assert first_call["custom_flag"] is True


def test_async_engine_generate_error_propagation() -> None:
    """A ``RuntimeError`` raised by the engine surfaces through ``generate``."""
    failing = oxillama_py.AsyncEngine(ErrorEngine())
    with pytest.raises(RuntimeError, match="generate failed"):
        # The call stays inside the ``raises`` block so a synchronous raise
        # would also be caught.
        asyncio.run(failing.generate("x"))


def test_async_engine_generate_multiple_calls() -> None:
    """Several sequential awaits on one AsyncEngine all succeed."""
    mock = MockEngine("token")
    ae = oxillama_py.AsyncEngine(mock)

    async def _sequential() -> list[str]:
        # Awaits run one after another, in prompt order.
        return [await ae.generate(prompt) for prompt in ("a", "b", "c")]

    assert asyncio.run(_sequential()) == ["token"] * 3
    assert len(mock.generate_calls) == 3


# ---------------------------------------------------------------------------
# stream() functional tests
# ---------------------------------------------------------------------------


def test_async_engine_stream_yields_tokens() -> None:
    """``stream`` yields every token passed to the engine's callback, in order."""
    tokens = ["h", "e", "l", "l", "o"]
    ae = oxillama_py.AsyncEngine(MultiTokenEngine(tokens))

    async def _gather() -> list[str]:
        received: list[str] = []
        async for token in ae.stream("hi"):
            received.append(token)
        return received

    assert asyncio.run(_gather()) == tokens


def test_async_engine_stream_concatenated_equals_full_text() -> None:
    """Joining the streamed tokens reproduces the complete response text."""
    text = "hello world"
    ae = oxillama_py.AsyncEngine(MultiTokenEngine(list(text)))

    async def _joined() -> str:
        pieces = [piece async for piece in ae.stream("hi")]
        return "".join(pieces)

    assert asyncio.run(_joined()) == text


def test_async_engine_stream_empty_response() -> None:
    """An engine that emits no tokens makes ``stream`` finish without yielding."""
    ae = oxillama_py.AsyncEngine(MultiTokenEngine([]))

    async def _gather() -> list[str]:
        received: list[str] = []
        async for token in ae.stream("hi"):
            received.append(token)
        return received

    assert asyncio.run(_gather()) == []


def test_async_engine_stream_single_token() -> None:
    """A one-token response is yielded exactly once."""
    ae = oxillama_py.AsyncEngine(MultiTokenEngine(["only"]))

    async def _gather() -> list[str]:
        received: list[str] = []
        async for token in ae.stream("hi"):
            received.append(token)
        return received

    assert asyncio.run(_gather()) == ["only"]


def test_async_engine_stream_error_propagation() -> None:
    """A ``RuntimeError`` from ``generate_streaming`` surfaces through ``stream``."""
    ae = oxillama_py.AsyncEngine(ErrorEngine())

    async def _consume_all() -> None:
        async for _token in ae.stream("x"):
            continue

    with pytest.raises(RuntimeError, match="generate_streaming failed"):
        asyncio.run(_consume_all())


def test_async_engine_stream_passes_max_tokens() -> None:
    """An explicit ``max_tokens`` reaches ``generate_streaming`` unchanged."""
    backend = MockEngine("ab")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _ in wrapper.stream("x", max_tokens=64):
            pass

    asyncio.run(_drain())
    first_call = backend.stream_calls[0]
    assert first_call["max_tokens"] == 64


def test_async_engine_stream_default_max_tokens() -> None:
    """When ``max_tokens`` is omitted, ``stream`` forwards 512."""
    backend = MockEngine("x")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _ in wrapper.stream("y"):
            pass

    asyncio.run(_drain())
    first_call = backend.stream_calls[0]
    assert first_call["max_tokens"] == 512


def test_async_engine_stream_passes_temperature() -> None:
    """stream() must forward temperature kwarg when provided.

    Compared with ``pytest.approx``, the idiomatic float check, instead of a
    hand-rolled ``abs(a - b) < eps``.
    """
    mock = MockEngine("x")
    ae = oxillama_py.AsyncEngine(mock)

    async def _run() -> None:
        async for _ in ae.stream("y", temperature=0.8):
            pass

    asyncio.run(_run())
    assert mock.stream_calls[0]["temperature"] == pytest.approx(0.8, abs=1e-6)


def test_async_engine_stream_omits_none_temperature() -> None:
    """``temperature=None`` passed to ``stream`` must not reach the engine."""
    backend = MockEngine("x")
    wrapper = oxillama_py.AsyncEngine(backend)

    async def _drain() -> None:
        async for _ in wrapper.stream("y", temperature=None):
            pass

    asyncio.run(_drain())
    forwarded = backend.stream_calls[0]
    assert "temperature" not in forwarded


# ---------------------------------------------------------------------------
# Native Engine.async_engine() method (requires native extension)
# ---------------------------------------------------------------------------


def _native_available() -> bool:
    """Return True if the native PyO3 extension is importable."""
    try:
        import oxillama_py.oxillama_py  # type: ignore[import-untyped]  # noqa: F401
        return True
    except ImportError:
        return False


# Skip marker for tests that need the compiled extension.  The import probe
# runs once, when this module is imported.
_REQUIRES_NATIVE = pytest.mark.skipif(
    not _native_available(),
    reason="Native extension not built (run `maturin develop`)",
)


@_REQUIRES_NATIVE
def test_engine_has_async_engine_method() -> None:
    """The Rust-side ``Engine`` class must expose an ``async_engine`` method."""
    engine_cls = oxillama_py.Engine
    assert hasattr(engine_cls, "async_engine"), (
        "Engine.async_engine method is missing"
    )


@_REQUIRES_NATIVE
def test_engine_async_engine_method_callable() -> None:
    """``Engine.async_engine`` must be callable."""
    method = oxillama_py.Engine.async_engine
    assert callable(method)


@_REQUIRES_NATIVE
def test_engine_async_engine_returns_async_engine_instance() -> None:
    """Calling ``async_engine()`` on an ``Engine`` returns an ``AsyncEngine``."""
    # NOTE(review): assumes Engine(...) does not eagerly open "dummy.gguf";
    # if construction loads the model, this needs a real fixture — TODO confirm.
    engine = oxillama_py.Engine(oxillama_py.EngineConfig(model_path="dummy.gguf"))
    wrapper = engine.async_engine()
    assert isinstance(wrapper, oxillama_py.AsyncEngine), (
        f"async_engine() returned {type(wrapper)}, expected AsyncEngine"
    )


@_REQUIRES_NATIVE
def test_engine_async_engine_wraps_same_instance() -> None:
    """The wrapper from ``async_engine()`` holds the very Engine it came from."""
    # NOTE(review): assumes Engine(...) does not eagerly open "dummy.gguf";
    # if construction loads the model, this needs a real fixture — TODO confirm.
    engine = oxillama_py.Engine(oxillama_py.EngineConfig(model_path="dummy.gguf"))
    wrapper = engine.async_engine()
    # Identity, not equality: _engine must be the original Engine instance.
    assert wrapper._engine is engine, (
        "async_engine()._engine does not point back to the caller"
    )