import os
from typing import Any, AsyncIterator, Callable, NoReturn, Optional, Sequence, Union
try:
from typing import Protocol, TypedDict, runtime_checkable
except ImportError:
from typing_extensions import Protocol, TypedDict, runtime_checkable
try:
import numpy as np
import numpy.typing as npt
_HAS_NUMPY = True
except ImportError:
_HAS_NUMPY = False
try:
import torch
_HAS_TORCH = True
except ImportError:
_HAS_TORCH = False
# Package version string; declared here by type only, with no value assigned.
__version__: str
# List of exported names; declared here by type only, with no value assigned.
__all__: list[str]
class ProgressEvent:
    """Progress snapshot; the argument type of a ``Callable[[ProgressEvent], None]``.

    The class declares seven attributes and a constructor taking the same
    seven values, none of which has a default.  Units are implied by the
    field names only; the code that produces these events is not visible here.
    """

    # Number of tokens generated so far.
    tokens_generated: int
    # Expected total, or None (presumably when unknown -- confirm).
    tokens_total: Optional[int]
    # Elapsed time, in seconds per the field name.
    elapsed_secs: float
    # Throughput, in tokens per second per the field name.
    tokens_per_sec: float
    # Estimated remaining seconds, or None (presumably when no estimate exists).
    eta_secs: Optional[float]
    # Flag marking the last event of a run, per the field name.
    is_final: bool
    # Text accumulated so far, per the field name.
    text_so_far: str

    def __init__(
        self,
        tokens_generated: int,
        tokens_total: Optional[int],
        elapsed_secs: float,
        tokens_per_sec: float,
        eta_secs: Optional[float],
        is_final: bool,
        text_so_far: str,
    ) -> None:
        """Build an event from the seven field values; all are required."""
        ...
# Shape of a progress callback: called with one ProgressEvent, returns nothing.
ProgressCallback = Callable[[ProgressEvent], None]
# Type of the ``progress`` arguments used throughout this file: a
# ProgressCallback, None, or any other object.  Because ``Any`` is a member,
# type checkers accept every value here; which non-callable objects are
# actually supported is not visible in this file -- confirm against
# make_progress_adapter's implementation.
ProgressLike = Union[Any, ProgressCallback, None]
def make_progress_adapter(
    obj: ProgressLike,
    max_tokens: int,
) -> tuple[
    Optional[Callable[[ProgressEvent], None]],
    Optional[Callable[[Optional[BaseException]], None]],
]:
    """Turn a progress-like value into a pair of optional callables.

    Args:
        obj: The progress-like value (a callback, some other object, or None).
        max_tokens: An integer token count; how it is used is not visible
            in this declaration.

    Returns:
        A 2-tuple.  The first element, when not None, accepts a
        ProgressEvent.  The second, when not None, accepts an optional
        exception -- presumably a completion/cleanup hook; confirm against
        the implementation.
    """
    ...
@runtime_checkable
class StreamingCallback(Protocol):
    """Structural type for a per-token callback.

    Being ``runtime_checkable``, ``isinstance`` only verifies that the object
    has a ``__call__`` attribute; the argument list is not checked at runtime.
    """

    def __call__(self, token: str, token_id: int, is_final: bool) -> None:
        """Receive one token's text, its integer id, and a final-token flag."""
        ...
# Plain-callable alias taking (str, int, bool) and returning None -- the same
# argument types as StreamingCallback.__call__ (token, token_id, is_final).
TokenCallback = Callable[[str, int, bool], None]
class HubOrigin(TypedDict):
    """Dictionary with three required string keys describing a hub file.

    What each value must look like (for example the digest encoding) is
    implied by the key names only and is not validated by this type.
    """

    # Repository identifier.
    repo_id: str
    # File name.
    filename: str
    # SHA-256 digest as a string (encoding not specified here).
    sha256: str
class SnapshotInfo:
    """Metadata record about a snapshot; returned by ``Engine.snapshot_info``.

    Only attribute types are declared here.  The notes below restate the
    field names; exact semantics should be confirmed against the writer.
    """

    # Architecture identifier string.
    arch_id: str
    # Model path string.
    model_path: str
    # Tokenizer path, or None.
    tokenizer_path: Optional[str]
    # Maximum context length.
    max_context_length: int
    # Thread count.
    num_threads: int
    # Integer version field (presumably the snapshot format version -- confirm).
    version: int
    # Raw magic bytes (presumably the file signature -- confirm).
    magic: bytes
    # Token count.
    tokens_count: int

    def __repr__(self) -> str:
        """Return the string representation of this record."""
        ...
class OxiLlamaError(Exception):
    """Common base class of the six exception types declared after it."""


class LoadError(OxiLlamaError):
    """OxiLlamaError subclass; by name, a loading failure (raise sites not shown)."""


class GenerateError(OxiLlamaError):
    """OxiLlamaError subclass; by name, a generation failure (raise sites not shown)."""


class TokenizerError(OxiLlamaError):
    """OxiLlamaError subclass; by name, a tokenizer failure (raise sites not shown)."""


class GrammarError(OxiLlamaError):
    """OxiLlamaError subclass; by name, a grammar failure (raise sites not shown)."""


class QuantError(OxiLlamaError):
    """OxiLlamaError subclass; by name, a quantization failure (raise sites not shown)."""


class KvCacheFullError(OxiLlamaError):
    """OxiLlamaError subclass; by name, a full KV cache (raise sites not shown)."""
class SamplerConfig:
    """Sampling parameters; every constructor argument is keyword-only.

    The declared defaults are: temperature 0.7, top_k 40, top_p 0.9,
    min_p 0.0, repetition_penalty 1.1, repetition_penalty_window 64,
    seed None, mirostat 0, mirostat_tau 5.0, mirostat_eta 0.1.
    How each value influences sampling is not visible in this declaration.
    """

    temperature: float
    top_k: int
    top_p: float
    min_p: float
    repetition_penalty: float
    repetition_penalty_window: int
    # None is allowed (presumably meaning "no fixed seed" -- confirm).
    seed: Optional[int]
    # Integer mode selector (presumably 0 disables mirostat -- confirm).
    mirostat: int
    mirostat_tau: float
    mirostat_eta: float

    def __init__(
        self,
        *,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        min_p: float = 0.0,
        repetition_penalty: float = 1.1,
        repetition_penalty_window: int = 64,
        seed: Optional[int] = None,
        mirostat: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
    ) -> None:
        """Create a configuration; omitted arguments take the defaults above."""
        ...

    # The return annotations below are quoted, like the file's other forward
    # references: the class name is not bound yet while its own body is being
    # evaluated, so an unquoted reference fails at runtime before Python 3.14.
    @staticmethod
    def greedy() -> "SamplerConfig":
        """Return a preset configuration (by name, for greedy decoding)."""
        ...

    @staticmethod
    def mirostat_v2(tau: float = 5.0, eta: float = 0.1) -> "SamplerConfig":
        """Return a preset configuration (by name, mirostat v2) for ``tau``/``eta``."""
        ...

    def __repr__(self) -> str:
        """Return the string representation of this configuration."""
        ...
class EngineConfig:
    """Settings used to construct an engine.

    ``model_path`` is the only positional constructor argument; the rest are
    keyword-only with defaults (context_size None, num_threads 4,
    tokenizer_path None, sampler None).
    """

    model_path: str
    tokenizer_path: Optional[str]
    context_size: Optional[int]
    num_threads: int
    # Declared non-optional although the constructor accepts None --
    # presumably a default SamplerConfig is substituted; confirm.
    sampler: SamplerConfig

    def __init__(
        self,
        model_path: str,
        *,
        context_size: Optional[int] = None,
        num_threads: int = 4,
        tokenizer_path: Optional[str] = None,
        sampler: Optional[SamplerConfig] = None,
    ) -> None:
        """Create a configuration for the model at ``model_path``."""
        ...

    def __repr__(self) -> str:
        """Return the string representation of this configuration."""
        ...
class Engine:
    """Engine constructed from an :class:`EngineConfig`.

    Declares tokenization, blocking and streaming generation, embeddings,
    snapshot/restore, DLPack/torch exports and an async wrapper.  Only
    signatures are given here; notes worded "presumably" are inferred from
    names and should be confirmed against the implementation.
    """

    def __init__(self, config: EngineConfig) -> None:
        """Create an engine from ``config``."""
        ...

    def load_model(self) -> None:
        """Load the model; takes no arguments and returns nothing."""
        ...

    def is_loaded(self) -> bool:
        """Report whether the model is loaded."""
        ...

    def reset(self) -> None:
        """Reset the engine (exactly which state is cleared is not shown here)."""
        ...

    def tokenize(self, text: str) -> list[int]:
        """Convert ``text`` to a list of integer token ids."""
        ...

    def decode_token(self, token: int) -> str:
        """Return the text for one token id."""
        ...

    def is_eos(self, token: int) -> bool:
        """Report whether ``token`` is an end-of-sequence token."""
        ...

    def hidden_size(self) -> Optional[int]:
        """Return the hidden size, or None (the None condition is not shown here)."""
        ...

    def generate(
        self,
        prompt: str,
        max_tokens: int = 128,
        *,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        cancel_token: Optional["CancellationToken"] = None,
        progress: ProgressLike = None,
        progress_throttle_ms: Optional[int] = None,
        progress_throttle_tokens: Optional[int] = None,
        progress_capture_text: bool = False,
        strict_progress: bool = False,
    ) -> str:
        """Generate text for ``prompt`` and return it as one string.

        Args:
            prompt: Input text.
            max_tokens: Token limit; defaults to 128.
            temperature, top_p, top_k, seed: Optional per-call sampling
                values (presumably None keeps the configured sampler's
                value -- confirm).
            cancel_token: Optional CancellationToken for this call.
            progress: Optional progress-like value.
            progress_throttle_ms, progress_throttle_tokens: Optional
                integer throttles for progress reporting.
            progress_capture_text: Boolean flag, default False (by name,
                controls whether progress events carry text -- confirm).
            strict_progress: Boolean flag, default False (presumably makes
                progress-callback errors fatal -- confirm).
        """
        ...

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Optional[Callable[[str], None]] = None,
        *,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        cancel_token: Optional["CancellationToken"] = None,
        strict_callback: bool = False,
        progress: ProgressLike = None,
        progress_throttle_ms: Optional[int] = None,
        progress_throttle_tokens: Optional[int] = None,
        progress_capture_text: bool = False,
        strict_progress: bool = False,
    ) -> str:
        """Generate text for ``prompt``, with an optional string callback.

        Takes the same arguments as :meth:`generate`, plus:

        Args:
            callback: Optional callable receiving one ``str`` per call
                (presumably each generated piece of text -- confirm).
            strict_callback: Boolean flag, default False (presumably makes
                callback errors fatal -- confirm).

        Returns:
            A string (presumably the complete generated text -- confirm).
        """
        ...

    def embed(self, text: str) -> list[float]:
        """Return an embedding of ``text`` as a list of floats."""
        ...

    def embed_numpy(self, text: str) -> "np.ndarray[tuple[int], np.dtype[np.float32]]":
        """Return an embedding of ``text`` as a 1-D float32 NumPy array."""
        ...

    def embed_batch_numpy(self, texts: Sequence[str]) -> "np.ndarray[tuple[int, int], np.dtype[np.float32]]":
        """Return embeddings of ``texts`` as a 2-D float32 NumPy array."""
        ...

    def apply_lora(self, lora_path: str) -> None:
        """Apply the LoRA found at ``lora_path``."""
        ...

    # Self-referential return annotations are quoted, like the file's other
    # forward references: ``Engine`` is not bound yet while its own class
    # body is being evaluated.
    @classmethod
    def from_hub(
        cls,
        repo_id: str,
        *,
        filename: Optional[str] = None,
        revision: Optional[str] = None,
        token: Optional[str] = None,
        config: Optional[EngineConfig] = None,
    ) -> "Engine":
        """Create an engine from a hub repository id.

        ``filename``, ``revision``, ``token`` and ``config`` are optional
        and keyword-only; how defaults are chosen when they are None is not
        shown here.
        """
        ...

    def snapshot(
        self,
        path: Union[str, os.PathLike[str]],
        *,
        hub_origin: Optional[HubOrigin] = None,
    ) -> None:
        """Write a snapshot to ``path``, optionally given a ``hub_origin`` record."""
        ...

    def snapshot_bytes(self) -> bytes:
        """Return a snapshot as bytes."""
        ...

    @classmethod
    def snapshot_info(cls, path: Union[str, os.PathLike[str]]) -> SnapshotInfo:
        """Return the SnapshotInfo for the snapshot at ``path``."""
        ...

    @classmethod
    def restore(
        cls,
        path: Union[str, os.PathLike[str]],
        *,
        model_path: Optional[Union[str, os.PathLike[str]]] = None,
    ) -> "Engine":
        """Create an engine from the snapshot at ``path``.

        ``model_path`` is optional (presumably overrides the model location
        recorded in the snapshot -- confirm).
        """
        ...

    @classmethod
    def from_snapshot_with_hub(
        cls,
        snapshot_path: Union[str, os.PathLike[str]],
    ) -> "Engine":
        """Create an engine from ``snapshot_path`` (by name, using its hub origin)."""
        ...

    def logits_dlpack(self, text: str) -> object:
        """Return logits for ``text`` as an opaque object (by name, a DLPack export)."""
        ...

    def embeddings_dlpack(self, text: str) -> object:
        """Return embeddings for ``text`` as an opaque object (by name, a DLPack export)."""
        ...

    def logits_torch(self, text: str, **kwargs: Any) -> "torch.Tensor":
        """Return logits for ``text`` as a torch tensor; ``kwargs`` are not specified here."""
        ...

    def embeddings_torch(self, text: str, **kwargs: Any) -> "torch.Tensor":
        """Return embeddings for ``text`` as a torch tensor; ``kwargs`` are not specified here."""
        ...

    def async_engine(self) -> "AsyncEngine":
        """Return an AsyncEngine for this engine."""
        ...

    # NoReturn rather than None: ``object.__reduce__`` / ``__reduce_ex__`` must
    # return a str or a tuple, so a ``None`` return type is an incompatible
    # override and could never describe a working pickle hook.
    # NOTE(review): this assumes the hooks raise to refuse pickling -- confirm
    # against the implementation.
    def __reduce__(self) -> NoReturn:
        """Pickle hook; declared as never returning."""
        ...

    def __reduce_ex__(self, protocol: int) -> NoReturn:
        """Pickle hook for ``protocol``; declared as never returning."""
        ...
class SpeculativeConfig:
    """Settings for a SpeculativeEngine: one target and one draft EngineConfig.

    ``target`` and ``draft`` are positional; ``num_speculative`` (default 4)
    and ``seed`` (default None) are keyword-only.
    """

    target: EngineConfig
    draft: EngineConfig
    # Integer count (presumably draft tokens proposed per step -- confirm).
    num_speculative: int
    seed: Optional[int]

    def __init__(
        self,
        target: EngineConfig,
        draft: EngineConfig,
        *,
        num_speculative: int = 4,
        seed: Optional[int] = None,
    ) -> None:
        """Create a configuration from the target and draft engine settings."""
        ...

    def __repr__(self) -> str:
        """Return the string representation of this configuration."""
        ...
class SpeculativeEngine:
    """Engine constructed from a :class:`SpeculativeConfig`.

    Declares blocking and streaming generation, snapshot/restore and pickle
    hooks.  Only signatures are given here.
    """

    def __init__(self, config: SpeculativeConfig) -> None:
        """Create an engine from ``config``."""
        ...

    def generate(
        self,
        prompt: str,
        max_tokens: int = 128,
        *,
        progress: ProgressLike = None,
        progress_throttle_ms: Optional[int] = None,
        progress_throttle_tokens: Optional[int] = None,
        progress_capture_text: bool = False,
        strict_progress: bool = False,
    ) -> str:
        """Generate text for ``prompt`` and return it as one string.

        ``max_tokens`` defaults to 128.  The keyword-only arguments are an
        optional progress-like value, two optional integer throttles, and
        two boolean flags that default to False.
        """
        ...

    def generate_streaming(
        self,
        prompt: str,
        max_tokens: int = 128,
        callback: Optional[Callable[[str], None]] = None,
        *,
        progress: ProgressLike = None,
        progress_throttle_ms: Optional[int] = None,
        progress_throttle_tokens: Optional[int] = None,
        progress_capture_text: bool = False,
        strict_progress: bool = False,
    ) -> str:
        """Generate text for ``prompt``, with an optional string callback.

        ``callback`` receives one ``str`` per call (presumably each generated
        piece of text -- confirm).  Returns a string.
        """
        ...

    def snapshot(self, path: str) -> None:
        """Write a snapshot to ``path``."""
        ...

    def snapshot_bytes(self) -> bytes:
        """Return a snapshot as bytes."""
        ...

    @classmethod
    def restore(cls, path: str, target_model: str, draft_model: str) -> "SpeculativeEngine":
        """Create an engine from the snapshot at ``path`` and the two model strings."""
        ...

    def __reduce__(self) -> tuple:
        """Pickle hook; returns a tuple."""
        ...

    def __reduce_ex__(self, protocol: int) -> tuple:
        """Pickle hook for ``protocol``; returns a tuple."""
        ...
class Tokenizer:
    """Tokenizer created through ``from_file`` or ``from_json``.

    Declares encoding, decoding, vocabulary lookup and chat-template
    rendering.  Only signatures are given here.
    """

    # Return annotations are quoted, like the file's other forward
    # references: ``Tokenizer`` is not bound yet while its own class body is
    # being evaluated.
    @staticmethod
    def from_file(path: str) -> "Tokenizer":
        """Create a tokenizer from the file at ``path``."""
        ...

    @staticmethod
    def from_json(json: str) -> "Tokenizer":
        """Create a tokenizer from a JSON string."""
        ...

    def encode(self, text: str) -> list[int]:
        """Convert ``text`` to a list of integer token ids."""
        ...

    def encode_batch(self, texts: list[str]) -> list[list[int]]:
        """Convert each text to a list of token ids; returns one list per input."""
        ...

    def decode(self, ids: list[int]) -> str:
        """Convert a list of token ids to text."""
        ...

    @property
    def vocab_size(self) -> int:
        """Vocabulary size (read-only property)."""
        ...

    def id_to_token(self, id: int) -> Optional[str]:
        """Return the token string for ``id``, or None (presumably for unknown ids)."""
        ...

    def apply_chat_template(
        self,
        messages: list[dict],
        template: Optional[str] = None,
        add_generation_prompt: bool = True,
    ) -> str:
        """Render ``messages`` to a single string.

        Args:
            messages: List of dicts; their schema is not specified here.
            template: Optional template string (presumably None selects a
                built-in template -- confirm).
            add_generation_prompt: Boolean flag, default True.
        """
        ...

    def __repr__(self) -> str:
        """Return the string representation of this tokenizer."""
        ...
class Lora:
    """LoRA object created through ``load``.

    ``rank`` and ``alpha`` are read-only properties; ``num_adapters`` is a
    regular method.
    """

    # The return annotation is quoted, like the file's other forward
    # references: ``Lora`` is not bound yet while its own class body is
    # being evaluated.
    @staticmethod
    def load(path: str) -> "Lora":
        """Create a Lora from the file at ``path``."""
        ...

    @property
    def rank(self) -> int:
        """Integer rank."""
        ...

    @property
    def alpha(self) -> float:
        """Floating-point alpha value."""
        ...

    def num_adapters(self) -> int:
        """Return the number of adapters."""
        ...

    def __repr__(self) -> str:
        """Return the string representation of this object."""
        ...
class CancellationToken:
    """Token accepted by the ``cancel_token`` arguments in this file.

    Declares a cancel / query / reset trio; how generation reacts to a
    cancelled token is not visible here.
    """

    def __init__(self) -> None:
        """Create a token; takes no arguments."""
        ...

    def cancel(self) -> None:
        """Mark the token as cancelled."""
        ...

    def is_cancelled(self) -> bool:
        """Report whether the token is cancelled."""
        ...

    def reset(self) -> None:
        """Reset the token (presumably clears the cancelled state -- confirm)."""
        ...

    def __repr__(self) -> str:
        """Return the string representation of this token."""
        ...
class AsyncEngine:
    """Asynchronous wrapper around an engine object.

    ``generate`` is a coroutine returning a string; ``stream`` is a plain
    method returning an async iterator of strings.  Both default
    ``max_tokens`` to 512 and accept extra keyword arguments whose meaning
    is not specified here.
    """

    def __init__(self, engine: Any) -> None:
        """Wrap ``engine`` (typed ``Any``; no specific type is required here)."""
        ...

    async def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        **kwargs: Any,
    ) -> str:
        """Await generation for ``prompt`` and return a string.

        ``temperature``, ``top_p``, ``top_k`` and ``seed`` are optional and
        default to None.
        """
        ...

    def stream(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        seed: Optional[int] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Return an async iterator of strings for ``prompt``.

        Takes the same arguments as :meth:`generate`.
        """
        ...