from __future__ import annotations
import asyncio
import concurrent.futures
import queue
from typing import TYPE_CHECKING, Any, AsyncIterator
from oxillama_py.callback import StreamingCallback, TokenCallback
from oxillama_py.progress import ProgressEvent, make_progress_adapter
from oxillama_py.utils import decode_from_logits
class AsyncEngine:
    """Asyncio front-end for a blocking engine object.

    Blocking engine calls are submitted to a private thread pool with a
    single worker, so calls made through one ``AsyncEngine`` execute one at
    a time, in submission order, off the event-loop thread.
    """

    __slots__ = ("_engine", "_pool")

    def __init__(self, engine: Any) -> None:
        """Wrap *engine*.

        Args:
            engine: Object exposing ``generate(prompt, max_tokens, **kw)``
                and ``generate_streaming(prompt, max_tokens, callback, **kw)``.
        """
        self._engine = engine
        # A single worker keeps calls into the wrapped engine sequential.
        self._pool: concurrent.futures.ThreadPoolExecutor = (
            concurrent.futures.ThreadPoolExecutor(max_workers=1)
        )

    @staticmethod
    def _build_call_kwargs(
        temperature: float | None,
        top_p: float | None,
        top_k: int | None,
        seed: int | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        """Return the engine keyword arguments, omitting unset sampling options.

        Sampling options left as ``None`` are not forwarded at all; *extra*
        is merged in last.
        """
        named = {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "seed": seed,
        }
        call_kwargs: dict[str, Any] = {
            key: value for key, value in named.items() if value is not None
        }
        call_kwargs.update(extra)
        return call_kwargs

    async def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        seed: int | None = None,
        **kwargs: Any,
    ) -> str:
        """Run ``engine.generate`` on the worker thread and await its result.

        Args:
            prompt: Passed to the engine as the first positional argument.
            max_tokens: Passed to the engine as the second positional argument.
            temperature: Forwarded as a keyword argument only when not ``None``.
            top_p: Forwarded as a keyword argument only when not ``None``.
            top_k: Forwarded as a keyword argument only when not ``None``.
            seed: Forwarded as a keyword argument only when not ``None``.
            **kwargs: Additional keyword arguments forwarded unchanged.

        Returns:
            Whatever ``engine.generate`` returns.

        Raises:
            Any exception raised by ``engine.generate`` is re-raised here.
        """
        loop = asyncio.get_running_loop()
        call_kwargs = self._build_call_kwargs(
            temperature, top_p, top_k, seed, kwargs
        )
        engine = self._engine
        return await loop.run_in_executor(
            self._pool,
            lambda: engine.generate(prompt, max_tokens, **call_kwargs),
        )

    async def stream(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        seed: int | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Yield items produced by ``engine.generate_streaming`` as they arrive.

        Generation runs on the worker thread; each item handed to the
        engine's callback is pushed onto a thread-safe queue and yielded
        from this async generator. Accepts the same arguments as
        :meth:`generate`.

        This wrapper does not signal the engine to stop if the consumer
        stops iterating early.

        Yields:
            Each item the engine passes to the streaming callback.

        Raises:
            Any exception raised by ``engine.generate_streaming`` is re-raised
            in the consumer after the items produced before it.
        """
        loop = asyncio.get_running_loop()
        token_queue: queue.Queue[Any] = queue.Queue()
        _sentinel = object()
        call_kwargs = self._build_call_kwargs(
            temperature, top_p, top_k, seed, kwargs
        )
        engine = self._engine

        def _run_generation() -> None:
            try:
                engine.generate_streaming(
                    prompt,
                    max_tokens,
                    lambda tok: token_queue.put(tok),
                    **call_kwargs,
                )
            except BaseException as exc:
                # Forward every failure, not just Exception subclasses: the
                # Future returned by submit() is discarded, so anything that
                # escapes here would be lost and the stream would appear to
                # have finished normally.
                token_queue.put(exc)
            finally:
                # Always unblock the consumer, even after a failure.
                token_queue.put(_sentinel)

        self._pool.submit(_run_generation)
        while True:
            # queue.Queue.get blocks, so wait for it on the loop's default
            # executor rather than on the event-loop thread.
            item = await loop.run_in_executor(None, token_queue.get)
            if item is _sentinel:
                break
            if isinstance(item, BaseException):
                raise item
            yield item
# Core classes and error types come from the oxillama_py.oxillama_py module
# (presumably the compiled extension -- confirm). If that import fails, bind
# every name to None so that importing this package still succeeds.
try:
    from oxillama_py.oxillama_py import (
        Engine,
        EngineConfig,
        GenerateError,
        GrammarError,
        LoadError,
        Lora,
        OxiLlamaError,
        QuantError,
        SamplerConfig,
        SnapshotInfo,
        SpeculativeConfig,
        SpeculativeEngine,
        Tokenizer,
        TokenizerError,
    )
except ImportError:
    Engine = None
    EngineConfig = None
    GenerateError = None
    GrammarError = None
    LoadError = None
    Lora = None
    OxiLlamaError = None
    QuantError = None
    SamplerConfig = None
    SnapshotInfo = None
    SpeculativeConfig = None
    SpeculativeEngine = None
    Tokenizer = None
    TokenizerError = None
# Re-export the snapshot submodule and its error type at package level.
from oxillama_py import snapshot
from oxillama_py.snapshot import SnapshotError
from oxillama_py import torch_helper as _torch_helper

# Package version string.
__version__ = "0.1.0"

import sys as _sys

# Pass this module object to torch_helper.try_patch at import time.
# NOTE(review): presumably this attaches optional PyTorch integration when
# torch is available -- confirm against torch_helper.
_torch_helper.try_patch(_sys.modules[__name__])
# Names that are still importable from the package but emit a
# DeprecationWarning when accessed.
_DEPRECATED_NAMES = ("TqdmProgress", "CollectTokens")


def __getattr__(name: str) -> Any:
    """Resolve deprecated package attributes lazily (module ``__getattr__``).

    Args:
        name: Attribute that normal module lookup failed to find.

    Returns:
        The object from ``oxillama_py.tqdm_helper`` with the requested name,
        after a ``DeprecationWarning`` has been issued.

    Raises:
        AttributeError: If *name* is not one of the deprecated names.
    """
    if name not in _DEPRECATED_NAMES:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import warnings

    message = (
        f"oxillama_py.{name} is deprecated; pass progress= to generate*() "
        "instead. See oxillama_py.progress.make_progress_adapter for the "
        "new API."
    )
    # stacklevel=2 attributes the warning to the code accessing the name.
    warnings.warn(message, DeprecationWarning, stacklevel=2)

    # Imported only on demand, after the warning has been emitted.
    from oxillama_py.tqdm_helper import CollectTokens, TqdmProgress

    legacy_exports = {"TqdmProgress": TqdmProgress, "CollectTokens": CollectTokens}
    return legacy_exports[name]
if TYPE_CHECKING:
    # Static-analysis-only import: at runtime these two names are served by
    # the module-level __getattr__ instead.
    from oxillama_py.tqdm_helper import CollectTokens, TqdmProgress

# Names exported by ``from oxillama_py import *``.
__all__ = [
    # Imported from oxillama_py.oxillama_py (None when that import fails).
    "EngineConfig",
    "Engine",
    "SamplerConfig",
    "SpeculativeConfig",
    "SpeculativeEngine",
    "Lora",
    "Tokenizer",
    # Defined in this module.
    "AsyncEngine",
    "SnapshotInfo",
    # The oxillama_py.snapshot submodule itself.
    "snapshot",
    "OxiLlamaError",
    "LoadError",
    "GenerateError",
    "TokenizerError",
    "GrammarError",
    "QuantError",
    # Imported from oxillama_py.snapshot.
    "SnapshotError",
    # Imported from oxillama_py.callback.
    "StreamingCallback",
    "TokenCallback",
    # Imported from oxillama_py.progress.
    "ProgressEvent",
    "make_progress_adapter",
    # Deprecated names, resolved through the module-level __getattr__.
    "TqdmProgress",
    "CollectTokens",
]