from __future__ import annotations
from typing import List
__all__ = ["decode_from_logits"]


def decode_from_logits(
    logits: "numpy.ndarray",
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> int:
    """Return the index chosen from ``logits`` after optional filtering.

    The logits are optionally temperature-scaled and restricted to the
    ``top_k`` highest values. They are then converted to probabilities with a
    softmax and restricted to the nucleus whose cumulative probability first
    reaches ``top_p``. Selection is deterministic (``argmax`` of the filtered
    probabilities), not sampled. Ties in the final probabilities resolve to
    the lowest index.

    Args:
        logits: 1-D array-like of unnormalised scores, converted to
            ``float32`` via ``numpy.asarray``.
        temperature: Divisor applied to the logits. A value of ``1.0``, or
            any value that is not strictly positive, leaves the logits
            unscaled.
        top_k: Keep only the ``top_k`` largest logits. Entries tied with the
            k-th largest are also kept. ``0`` or a negative value disables
            the filter, and values ``>= len(logits)`` keep every entry.
        top_p: Nucleus threshold. Values ``>= 1.0`` disable the filter. The
            highest-probability entry is always kept.

    Returns:
        The selected index as a Python ``int``.

    Raises:
        ValueError: If ``logits`` is not 1-D or is empty.
    """
    import numpy as np

    arr = np.asarray(logits, dtype=np.float32)
    if arr.ndim != 1:
        raise ValueError(f"Expected 1-D logits array, got shape {arr.shape}")
    if arr.size == 0:
        # Give a clear message instead of numpy's "zero-size array" error.
        raise ValueError("Expected a non-empty logits array")

    if temperature != 1.0 and temperature > 0.0:
        arr = arr / temperature

    # np.partition rejects a kth beyond the array length. A top_k that is at
    # least the array size would keep everything anyway, so filter only when
    # it can actually remove entries.
    if 0 < top_k < arr.size:
        kth = np.partition(arr, -top_k)[-top_k]
        arr = np.where(arr >= kth, arr, -np.inf)

    # Numerically stable softmax: shift so the largest logit is 0 before exp.
    arr = arr - arr.max()
    exp_arr = np.exp(arr)
    probs = exp_arr / exp_arr.sum()

    if top_p < 1.0:
        # Use a stable descending sort so tied probabilities stay in index
        # order. Otherwise the nucleus could keep only a later tied index and
        # change the argmax tie-break relative to the unfiltered result.
        sorted_idx = np.argsort(-probs, kind="stable")
        cumprobs = np.cumsum(probs[sorted_idx])
        # First rank where the cumulative mass reaches top_p: keep that
        # entry and everything ranked above it.
        cutoff = int(np.searchsorted(cumprobs, top_p))
        allowed = sorted_idx[: cutoff + 1]
        mask = np.zeros_like(probs)
        mask[allowed] = 1.0
        probs = probs * mask
        total = float(probs.sum())
        if total > 0:
            probs = probs / total
        else:
            # Defensive fallback: spread mass uniformly over the kept entries.
            probs = mask / mask.sum()

    return int(np.argmax(probs))