import asyncio
import collections
import itertools
import queue
import threading
from typing import AsyncIterable, Callable, Iterable, List, Optional, Union

from ctranslate2._ext import (
GenerationResult,
GenerationStepResult,
Generator,
ScoringResult,
TranslationResult,
Translator,
)


def register_extensions():
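    """Registers the extension methods defined in this module on the
    Translator and Generator classes, so they can be called directly on
    model instances (e.g. ``translator.translate_iterable(...)``).
    """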
setattr(Translator, "translate_iterable", translator_translate_iterable)
setattr(Translator, "score_iterable", translator_score_iterable)
setattr(Translator, "generate_tokens", translator_generate_tokens)
setattr(Generator, "generate_iterable", generator_generate_iterable)
setattr(Generator, "score_iterable", generator_score_iterable)
setattr(Generator, "generate_tokens", generator_generate_tokens)
setattr(Generator, "async_generate_tokens", generator_async_generate_tokens)


def translator_translate_iterable(
translator: Translator,
source: Iterable[List[str]],
target_prefix: Optional[Iterable[List[str]]] = None,
max_batch_size: int = 32,
batch_type: str = "examples",
**kwargs,
) -> Iterable[TranslationResult]:
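    """Translates an iterable of tokenized examples.

    This method is built on top of ``Translator.translate_batch``: batches are
    loaded from the input iterable(s) on the fly and submitted asynchronously,
    so translation overlaps with reading the next inputs. Results are yielded
    in the same order as the source examples.

    Example (a minimal usage sketch; the model path and tokens are
    illustrative only):

        >>> translator = Translator("ende_ct2_model/")
        >>> source = [["▁Hello", "▁world", "!"]]
        >>> for result in translator.translate_iterable(source):
        ...     print(" ".join(result.hypotheses[0]))
    """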
iterables = [source]
if target_prefix is not None:
iterables.append(target_prefix)
yield from _process_iterable(
translator.translate_batch,
iterables,
max_batch_size,
batch_type,
**kwargs,
)


def translator_score_iterable(
translator: Translator,
source: Iterable[List[str]],
target: Iterable[List[str]],
max_batch_size: int = 64,
batch_type: str = "examples",
**kwargs,
) -> Iterable[ScoringResult]:
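    """Scores an iterable of tokenized examples.

    Source and target examples are consumed in parallel, batched on the fly,
    and scored asynchronously with ``Translator.score_batch``. Results are
    yielded in the same order as the inputs.
    """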
yield from _process_iterable(
translator.score_batch,
[source, target],
max_batch_size,
batch_type,
**kwargs,
)


def generator_generate_iterable(
generator: Generator,
start_tokens: Iterable[List[str]],
max_batch_size: int = 32,
batch_type: str = "examples",
**kwargs,
) -> Iterable[GenerationResult]:
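    """Generates from an iterable of tokenized prompts.

    Batches are built from the iterable on the fly and submitted
    asynchronously with ``Generator.generate_batch``. Results are yielded in
    the same order as the prompts.
    """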
yield from _process_iterable(
generator.generate_batch,
[start_tokens],
max_batch_size,
batch_type,
**kwargs,
)


def generator_score_iterable(
generator: Generator,
tokens: Iterable[List[str]],
max_batch_size: int = 64,
batch_type: str = "examples",
**kwargs,
) -> Iterable[ScoringResult]:
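    """Scores an iterable of tokenized examples with ``Generator.score_batch``,
    batching on the fly and yielding results in input order.
    """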
yield from _process_iterable(
generator.score_batch,
[tokens],
max_batch_size,
batch_type,
**kwargs,
)


def translator_generate_tokens(
translator: Translator,
source: List[str],
target_prefix: Optional[List[str]] = None,
*,
max_decoding_length: int = 256,
min_decoding_length: int = 1,
sampling_topk: int = 1,
sampling_topp: float = 1,
sampling_temperature: float = 1,
return_log_prob: bool = False,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
disable_unk: bool = False,
suppress_sequences: Optional[List[List[str]]] = None,
end_token: Optional[Union[str, List[str], List[int]]] = None,
max_input_length: int = 1024,
use_vmap: bool = False,
) -> Iterable[GenerationStepResult]:
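    """Yields the translation of a single example token by token.

    Decoding runs asynchronously with ``beam_size`` forced to 1 (see
    ``_generate_tokens``), and each decoding step is streamed back as a
    ``GenerationStepResult``. Closing the returned generator cancels the
    ongoing translation.

    Example (a minimal usage sketch; the model path and tokens are
    illustrative only):

        >>> translator = Translator("ende_ct2_model/")
        >>> for step in translator.generate_tokens(["▁Hello", "▁world", "!"]):
        ...     print(step.token, end="", flush=True)
    """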
yield from _generate_tokens(
translator.translate_batch,
[source],
[target_prefix] if target_prefix is not None else None,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
disable_unk=disable_unk,
suppress_sequences=suppress_sequences,
end_token=end_token,
max_decoding_length=max_decoding_length,
min_decoding_length=min_decoding_length,
sampling_topk=sampling_topk,
sampling_topp=sampling_topp,
sampling_temperature=sampling_temperature,
return_scores=return_log_prob,
max_input_length=max_input_length,
use_vmap=use_vmap,
)


def generator_generate_tokens(
generator: Generator,
prompt: Union[List[str], List[List[str]]],
max_batch_size: int = 0,
batch_type: str = "examples",
*,
max_length: int = 512,
min_length: int = 0,
sampling_topk: int = 1,
sampling_topp: float = 1,
sampling_temperature: float = 1,
return_log_prob: bool = False,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
disable_unk: bool = False,
suppress_sequences: Optional[List[List[str]]] = None,
end_token: Optional[Union[str, List[str], List[int]]] = None,
static_prompt: Optional[List[str]] = None,
cache_static_prompt: bool = True,
    callback: Optional[Callable[[GenerationStepResult], bool]] = None,
) -> Iterable[GenerationStepResult]:
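    """Yields generated tokens as they are produced.

    Decoding runs asynchronously with ``beam_size`` forced to 1 and each step
    is streamed back as a ``GenerationStepResult``; the prompt itself is not
    included in the results. Closing the returned generator cancels the
    ongoing generation.

    Example (a minimal usage sketch; the model path and prompt are
    illustrative only):

        >>> generator = Generator("gpt2_ct2_model/")
        >>> for step in generator.generate_tokens(["<|endoftext|>"], max_length=32):
        ...     print(step.token, end="", flush=True)
    """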
    # A single tokenized prompt is wrapped into a batch of size 1.
    if len(prompt) > 0 and isinstance(prompt[0], str):
        prompt = [prompt]
yield from _generate_tokens(
generator.generate_batch,
prompt,
max_batch_size=max_batch_size,
batch_type=batch_type,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
disable_unk=disable_unk,
suppress_sequences=suppress_sequences,
end_token=end_token,
max_length=max_length,
min_length=min_length,
sampling_topk=sampling_topk,
sampling_topp=sampling_topp,
sampling_temperature=sampling_temperature,
return_scores=return_log_prob,
static_prompt=static_prompt,
cache_static_prompt=cache_static_prompt,
include_prompt_in_result=False,
callback=callback,
)


async def generator_async_generate_tokens(
generator: Generator,
prompt: Union[List[str], List[List[str]]],
max_batch_size: int = 0,
batch_type: str = "examples",
*,
max_length: int = 512,
min_length: int = 0,
sampling_topk: int = 1,
sampling_topp: float = 1,
sampling_temperature: float = 1,
return_log_prob: bool = False,
repetition_penalty: float = 1,
no_repeat_ngram_size: int = 0,
disable_unk: bool = False,
suppress_sequences: Optional[List[List[str]]] = None,
end_token: Optional[Union[str, List[str], List[int]]] = None,
static_prompt: Optional[List[str]] = None,
cache_static_prompt: bool = True,
    callback: Optional[Callable[[GenerationStepResult], bool]] = None,
) -> AsyncIterable[GenerationStepResult]:
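    """Asynchronous version of ``generate_tokens``: yields generated tokens as
    an async iterable for use inside an event loop.

    Example (a minimal usage sketch; assumes a running event loop and an
    illustrative prompt):

        async def stream():
            async for step in generator.async_generate_tokens(["<|endoftext|>"]):
                print(step.token, end="", flush=True)
    """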
    # A single tokenized prompt is wrapped into a batch of size 1.
    if len(prompt) > 0 and isinstance(prompt[0], str):
        prompt = [prompt]
async for step_result in AsyncGenerator(
generator.generate_batch,
prompt,
max_batch_size=max_batch_size,
batch_type=batch_type,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
disable_unk=disable_unk,
suppress_sequences=suppress_sequences,
end_token=end_token,
max_length=max_length,
min_length=min_length,
sampling_topk=sampling_topk,
sampling_topp=sampling_topp,
sampling_temperature=sampling_temperature,
return_scores=return_log_prob,
static_prompt=static_prompt,
cache_static_prompt=cache_static_prompt,
include_prompt_in_result=False,
callback=callback,
):
yield step_result


class AsyncGenerator:
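    """Wraps the synchronous ``_generate_tokens`` generator in an async
    iterator: a producer task forwards each step result into an
    ``asyncio.Queue``, with ``None`` as the end-of-stream sentinel, and
    ``__anext__`` awaits items from that queue.
    """
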
def __init__(self, process_func, *args, **kwargs):
self.queue = asyncio.Queue()
self.shutdown_event = threading.Event()
self.iterator_task = None
self.process_func = process_func
self.args = args
self.kwargs = kwargs

    async def producer(self):
        # Drain the synchronous token generator and forward each step into the
        # asyncio queue, yielding control to the event loop between steps.
        for step_result in _generate_tokens(
            self.process_func, *self.args, **self.kwargs
        ):
            await self.queue.put(step_result)
            await asyncio.sleep(0.0001)
            if self.shutdown_event.is_set():
                break
        await self.queue.put(None)

    def __aiter__(self):
        # Start the producer as a background task on first iteration.
        self.iterator_task = asyncio.create_task(self.producer())
        return self

    async def __anext__(self):
        if self.shutdown_event.is_set():
            raise StopAsyncIteration

        try:
            item = await self.queue.get()
            if item is None:
                # End-of-stream sentinel pushed by the producer.
                self.shutdown_event.set()
                raise StopAsyncIteration
            return item
        except asyncio.CancelledError:
            self.shutdown_event.set()
            raise StopAsyncIteration


def _generate_tokens(process_func, *args, **kwargs):
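    """Bridges the callback-based asynchronous decoding API to a synchronous
    generator: a callback pushes each step result into a thread-safe queue, a
    background thread waits on the futures to forward exceptions and signal
    completion, and the main loop below yields from the queue.
    """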
    step_results = queue.Queue()
    generator_closed = threading.Event()

    user_callback = kwargs.get("callback", None)
    if user_callback is None:
        def user_callback(step_result):
            return False

    def _callback(step_result):
        # Forward each step to the user callback and to the queue consumed by
        # the loop below. Returning True cancels the ongoing generation.
        user_callback_result = user_callback(step_result)
        step_results.put(step_result)
        return generator_closed.is_set() or user_callback_result

    # The step callback only supports greedy search, so force beam_size to 1
    # and run the generation asynchronously.
    kwargs.update(
        {
            "asynchronous": True,
            "beam_size": 1,
            "callback": _callback,
        }
    )

    async_results = process_func(*args, **kwargs)

    def _catch_exception():
        # Wait on the asynchronous results so that an exception raised during
        # generation is forwarded to the consumer, then push the None sentinel
        # to mark the end of the stream.
        try:
            for result in async_results:
                result.result()
        except Exception as e:
            step_results.put(e)
        step_results.put(None)

    thread = threading.Thread(target=_catch_exception, daemon=True)
    thread.start()
    while True:
        step_result = step_results.get()
        if step_result is None:
            break
        if isinstance(step_result, Exception):
            raise step_result

        try:
            yield step_result
        except GeneratorExit:
            # The consumer closed the generator: let the callback cancel the
            # generation and stop iterating.
            generator_closed.set()
            break

    thread.join()


def _process_iterable(process_func, iterables, max_batch_size, batch_type, **kwargs):
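    """Common driver for the ``*_iterable`` methods: batches the input
    iterable(s), submits the batches asynchronously, and yields completed
    results in order while reading ahead.
    """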
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")

    if len(iterables) == 1:
        iterable = iterables[0]
    else:
        iterable = itertools.zip_longest(*iterables)

    kwargs.update(
        {
            "max_batch_size": max_batch_size,
            "batch_type": batch_type,
            "asynchronous": True,
        }
    )

    # Read several batches ahead so that new batches can be submitted while
    # earlier ones are still being processed.
    read_batch_size = max_batch_size * 16 if max_batch_size > 1 else max_batch_size

    # Deque of pending asynchronous results, completed in submission order.
    pending = collections.deque()

    for streams in _batch_iterator(iterable, read_batch_size, batch_type):
        pending.extend(process_func(*streams, **kwargs))

        # Yield the results that are already available without blocking.
        while pending and pending[0].done():
            yield pending.popleft().result()

    # Wait for and yield the remaining results.
    while pending:
        yield pending.popleft().result()


def _batch_iterator(iterable, batch_size, batch_type):
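    """Groups examples from the (possibly zipped) input iterable into batches
    of parallel streams, either by number of examples or by an estimated
    number of tokens.
    """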
    streams = None
    max_length = 0

    for example in iterable:
        # Normalize single-iterable input so each example is a tuple of
        # parallel elements, one per input stream.
        if not isinstance(example, tuple):
            example = (example,)

        if batch_type == "examples":
            if streams and len(streams[0]) == batch_size:
                yield streams
                streams = None
        elif batch_type == "tokens":
            # Estimate the batch size in tokens as the number of examples
            # times the length of the longest example seen so far.
            max_length = max(max_length, len(example[0]))
            if streams and (len(streams[0]) + 1) * max_length > batch_size:
                yield streams
                streams = None
                max_length = len(example[0])
        else:
            raise ValueError("Invalid batch type %s" % batch_type)

        if streams is None:
            streams = tuple([] for _ in example)

        for batch, element in zip(streams, example):
            if element is None and len(streams) > 1:
                # zip_longest pads the shorter iterable with None.
                raise ValueError("Input iterables do not have the same length")
            batch.append(element)

    if streams is not None:
        yield streams