from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["ProgressEvent", "make_progress_adapter"]
@dataclass(frozen=True, slots=True)
class ProgressEvent:
    """Immutable snapshot of generation progress handed to progress sinks.

    Instances are frozen (``frozen=True``) and slotted (``slots=True``, which
    requires Python 3.10+), so they are cheap to create per update and safe
    to share with user callbacks.
    """

    # Number of tokens generated so far.
    tokens_generated: int
    # Expected total number of tokens, or None when no total is known.
    tokens_total: Optional[int]
    # Seconds elapsed as reported by the producer (presumably since generation
    # began — confirm against the caller).
    elapsed_secs: float
    # Average throughput so far, in tokens per second.
    tokens_per_sec: float
    # Estimated seconds remaining, or None when no estimate is available.
    eta_secs: Optional[float]
    # Producer-supplied flag; presumably True only on the last event of a run.
    is_final: bool
    # Text generated so far, as supplied by the producer.
    text_so_far: str
class _TqdmAdapter:
    """Forward ``ProgressEvent`` objects to a tqdm-style progress bar.

    The bar is driven purely through duck typing (``total``, ``update``,
    ``set_postfix_str``, ``close``). Every interaction is best-effort: a
    misbehaving bar must never break the generation it is reporting on, so
    exceptions raised by the bar are deliberately swallowed.
    """

    __slots__ = ("_pbar", "_max_tokens", "_first", "_last_tokens")

    def __init__(self, pbar: Any, max_tokens: int) -> None:
        """Wrap *pbar*; *max_tokens* is the fallback total for the bar."""
        self._pbar = pbar
        self._max_tokens = max_tokens
        # The bar's total is configured lazily, on the first event.
        self._first = True
        # Highest token count already pushed to the bar (high-water mark).
        self._last_tokens = 0

    def __call__(self, event: ProgressEvent) -> None:
        """Advance the bar to ``event.tokens_generated`` and show the rate."""
        if self._first:
            self._first = False
            total = event.tokens_total or self._max_tokens
            # Only install a meaningful budget. A zero/negative total would
            # leave the bar in a nonsensical state, so in that case keep
            # whatever total the caller configured on the bar.
            if total is not None and total > 0:
                try:
                    self._pbar.total = total
                except Exception:
                    pass

        delta = event.tokens_generated - self._last_tokens
        if delta > 0:
            try:
                self._pbar.update(delta)
            except Exception:
                pass
            # Move the high-water mark forward only. Rewinding it on a lower
            # count would make the next increase double-count tokens, since
            # tqdm's ``update`` is purely additive.
            self._last_tokens = event.tokens_generated

        try:
            self._pbar.set_postfix_str(
                f"{event.tokens_per_sec:.1f} tok/s", refresh=False
            )
        except Exception:
            pass

    def finalise(self, error: Optional[BaseException]) -> None:
        """Close the bar, first labelling it with the exception type on failure.

        Args:
            error: the exception that ended the run, or None on success.
        """
        if error is not None:
            label = type(error).__name__
            try:
                self._pbar.set_postfix_str(f"error: {label}", refresh=False)
            except Exception:
                pass
        try:
            self._pbar.close()
        except Exception:
            pass
class _IPyWidgetAdapter:
    """Forward ``ProgressEvent`` objects to an ipywidgets-style progress widget.

    Writes the widget's ``max``, ``value`` and ``description`` while running,
    and ``bar_style`` on completion. Every write is best-effort: exceptions
    raised by the widget are deliberately swallowed so the UI can never break
    the generation it is reporting on.
    """

    __slots__ = ("_w", "_max_tokens", "_first")

    def __init__(self, widget: Any, max_tokens: int) -> None:
        """Wrap *widget*; *max_tokens* is the fallback value for ``max``."""
        self._w = widget
        self._max_tokens = max_tokens
        # ``max`` is configured lazily, on the first event.
        self._first = True

    def __call__(self, event: ProgressEvent) -> None:
        """Set the widget's value to the token count and show the rate."""
        if self._first:
            self._first = False
            total = event.tokens_total or self._max_tokens
            # Only install a meaningful budget. A zero/negative ``max`` would
            # pin the bar at empty, so otherwise keep the widget's own ``max``.
            if total is not None and total > 0:
                try:
                    self._w.max = total
                except Exception:
                    pass
        try:
            self._w.value = event.tokens_generated
        except Exception:
            pass
        try:
            self._w.description = f"{event.tokens_per_sec:.1f} tok/s"
        except Exception:
            pass

    def finalise(self, error: Optional[BaseException]) -> None:
        """Colour the bar by outcome and fill it to ``max``.

        Args:
            error: the exception that ended the run, or None on success.
                Exceptions whose type name contains ``"Cancel"`` are shown as
                a warning; any other exception is shown as ``"danger"``.
        """
        if error is None:
            style = "success"
        elif "Cancel" in type(error).__name__:
            style = "warning"
        else:
            style = "danger"
        try:
            self._w.bar_style = style
        except Exception:
            pass
        try:
            self._w.value = self._w.max
        except Exception:
            pass
class _CallableAdapter:
    """Give a plain user callable the adapter interface.

    Each event is handed straight to the callable. There is nothing to tear
    down, so ``finalise`` does nothing.
    """

    __slots__ = ("_fn",)

    def __init__(self, fn: Callable[[ProgressEvent], None]) -> None:
        self._fn = fn

    def __call__(self, event: ProgressEvent) -> None:
        """Pass *event* to the wrapped callable, discarding its return value."""
        callback = self._fn
        callback(event)

    def finalise(self, error: Optional[BaseException]) -> None:
        """Ignore completion; a bare callable has no teardown hook."""
def _is_tqdm(obj: Any) -> bool:
    """Return True if *obj* exposes ``update``, ``set_postfix_str`` and ``close``."""
    for required in ("update", "set_postfix_str", "close"):
        if not hasattr(obj, required):
            return False
    return True
def _is_ipywidget(obj: Any) -> bool:
    """Return True if *obj* has ``value``/``max`` and its type name contains "Progress"."""
    has_progress_traits = hasattr(obj, "value") and hasattr(obj, "max")
    return has_progress_traits and "Progress" in type(obj).__name__
def make_progress_adapter(
    obj: Any, max_tokens: int
) -> Tuple[Optional[Callable[[ProgressEvent], None]], Optional[Callable[[Optional[BaseException]], None]]]:
    """Turn a user-supplied progress sink into a ``(callback, finaliser)`` pair.

    Sinks are recognised in this order:

    * ``None`` -> ``(None, None)``; progress reporting is disabled.
    * an object with ``update``, ``set_postfix_str`` and ``close`` (tqdm-like).
    * an object with ``value`` and ``max`` whose type name contains
      ``"Progress"`` (ipywidgets-like).
    * any other callable, which receives every ``ProgressEvent`` directly.

    Args:
        obj: the progress sink.
        max_tokens: passed to the tqdm / ipywidget adapters as their fallback
            total.

    Returns:
        The chosen adapter's bound ``__call__`` and ``finalise`` methods.

    Raises:
        TypeError: if *obj* matches none of the accepted kinds.
    """
    if obj is None:
        return (None, None)

    # (predicate, factory) pairs tried in priority order. The factories are
    # lazy, so only the adapter for the first matching predicate is built.
    candidates = (
        (_is_tqdm, lambda: _TqdmAdapter(obj, max_tokens)),
        (_is_ipywidget, lambda: _IPyWidgetAdapter(obj, max_tokens)),
        (callable, lambda: _CallableAdapter(obj)),
    )
    for matches, build in candidates:
        if matches(obj):
            adapter: Any = build()
            return (adapter.__call__, adapter.finalise)

    raise TypeError(
        "progress must be a tqdm pbar, ipywidgets.IntProgress, callable, "
        f"or None; got {type(obj).__name__}"
    )
def _build_bridge(
    obj: Any, max_tokens: int
) -> Tuple[Callable[[Tuple[int, float, bool, str]], None], Callable[[Optional[BaseException]], None]]:
    """Build a ``(callback, finaliser)`` pair that accepts raw progress tuples.

    The callback takes ``(tokens, elapsed_secs, is_final, text_so_far)``. It
    derives throughput and an ETA, wraps everything in a ``ProgressEvent``, and
    forwards that event to the adapter chosen by ``make_progress_adapter``.
    The finaliser is that adapter's own ``finalise``. When *obj* is None both
    returned callables are no-ops, so callers never have to check for None.

    Args:
        obj: progress sink accepted by ``make_progress_adapter``.
        max_tokens: token budget. Values ``<= 0`` mean "unknown"; in that case
            events carry ``tokens_total=None`` and ``eta_secs=None``.

    Raises:
        TypeError: propagated from ``make_progress_adapter`` for an
            unsupported *obj*.
    """
    cb, fin = make_progress_adapter(obj, max_tokens)
    if cb is None or fin is None:

        def _noop_callback(_payload: Tuple[int, float, bool, str]) -> None:
            return None

        def _noop_finaliser(_error: Optional[BaseException]) -> None:
            return None

        return (_noop_callback, _noop_finaliser)

    # Re-bind with non-Optional types so the closure below type-checks.
    cb_resolved: Callable[[ProgressEvent], None] = cb
    fin_resolved: Callable[[Optional[BaseException]], None] = fin

    # The budget is fixed for the lifetime of the bridge, so resolve it once.
    total_known = max_tokens > 0
    tokens_total: Optional[int] = max_tokens if total_known else None

    def _wrapped_callback(payload: Tuple[int, float, bool, str]) -> None:
        tokens, elapsed_secs, is_final, text_so_far = payload
        # Average throughput over the elapsed time; 0.0 until time has elapsed.
        tps = tokens / elapsed_secs if elapsed_secs > 0 else 0.0
        # Only extrapolate when a positive budget exists. Without one there is
        # nothing to count down to, and reporting 0.0 would falsely claim the
        # run is done. Also wait for at least two tokens, presumably because a
        # single-token rate is too noisy to project from.
        eta: Optional[float] = None
        if total_known and tokens >= 2 and tps > 0:
            eta = max(max_tokens - tokens, 0) / tps
        cb_resolved(
            ProgressEvent(
                tokens_generated=tokens,
                tokens_total=tokens_total,
                elapsed_secs=elapsed_secs,
                tokens_per_sec=tps,
                eta_secs=eta,
                is_final=is_final,
                text_so_far=text_so_far,
            )
        )

    return (_wrapped_callback, fin_resolved)