from typing import Iterable, Union
import numpy as np
from ..tensor import Tensor
from .elemwise import abs, maximum, minimum
from .math import topk as _topk
from .tensor import broadcast_to, transpose
# Names exported by ``from <this module> import *``.
__all__ = ["topk_accuracy"]
def topk_accuracy(
    logits: Tensor, target: Tensor, topk: Union[int, Iterable[int]] = 1
) -> Union[Tensor, Iterable[Tensor]]:
    r"""Compute top-k classification accuracy from predicted scores and labels.

    Args:
        logits: prediction scores. The code slices the sorted indices as
            ``pred[:, :k]`` and compares them against a
            ``(target.shape[0], k)`` tensor, so this is presumably of shape
            ``(batch_size, num_classes)`` -- TODO confirm.
        target: ground-truth class indices, one per sample; ``target.shape[0]``
            is used as the batch size.
        topk: a single ``k`` (a Python or NumPy integer) or an iterable of
            ``k`` values. Any iterable is accepted, including a one-shot
            generator.

    Returns:
        For each ``k``, the fraction of samples whose target index is among
        the ``k`` highest-scoring predictions, as a float32 tensor. When
        exactly one ``k`` is given, that tensor is returned on its own;
        otherwise a list in the same order as ``topk``.

    Raises:
        ValueError: if ``topk`` contains no values.
    """
    if isinstance(topk, (int, np.integer)):
        topk = (int(topk),)
    else:
        # Materialize once: ``topk`` is read by ``max`` and then iterated
        # again below. A generator would be exhausted by ``max`` and the loop
        # would silently produce nothing; ``len`` also needs a sized sequence.
        topk = tuple(topk)
    if not topk:
        raise ValueError("topk must contain at least one value")

    # Indices of the ``max(topk)`` best-scoring entries per row, best first;
    # each smaller ``k`` reuses a prefix of this single top-k selection.
    _, pred = _topk(logits, k=max(topk), descending=True)

    batch_size = target.shape[0]
    # Independent of ``k``, so computed once. The ``(0, "x")`` pattern
    # presumably inserts a trailing axis, giving a (batch_size, 1) column.
    target_col = transpose(target, (0, "x"))

    accs = []
    for k in topk:
        correct = pred[:, :k].detach() == broadcast_to(
            target_col, (batch_size, k)
        )
        # Summing the matches over the whole batch and dividing by the batch
        # size gives the per-sample hit rate for this ``k``.
        accs.append(correct.astype(np.float32).sum() / batch_size)

    # Existing contract: a single requested ``k`` yields a bare tensor.
    if len(topk) == 1:
        return accs[0]
    return accs