import numpy as np
import pytest
import megengine.functional as F
from megengine import tensor
def test_cross_entropy_with_logits():
    """Check ``cross_entropy`` on raw logits.

    Labels matching each row's dominant logit give (numerically) zero loss,
    swapped labels give a mean loss of 100, and a plain NumPy label array is
    accepted just like a tensor.
    """
    logits = tensor([[0, 50], [0, -150]]).astype(np.float32)

    # Labels agree with the larger logit of each row.
    matching = tensor([1, 0]).astype(np.int32)
    np.testing.assert_allclose(F.nn.cross_entropy(logits, matching).numpy(), 0.0)

    # Swapped labels: per-row losses are ~50 and ~150, so the mean is 100.
    swapped = tensor([0, 1]).astype(np.int32)
    np.testing.assert_allclose(F.nn.cross_entropy(logits, swapped).numpy(), 100)

    # Same labels as `matching`, but passed as a NumPy array.
    numpy_labels = np.array([1, 0])
    np.testing.assert_allclose(
        F.nn.cross_entropy(logits, numpy_labels).numpy(), 0.0
    )
def test_cross_entropy():
    """Compare ``cross_entropy(..., with_logits=False)`` on probabilities
    against a NumPy mean negative log-likelihood."""

    def softmax(logits):
        exp = np.exp(logits)
        return exp / exp.sum(1, keepdims=True)

    def ref(probs, labels):
        # Mean of -log p(target) over the rows.
        return np.mean([-np.log(probs[row, col]) for row, col in enumerate(labels)])

    logits = (np.random.rand(5, 10) - 0.5) * 4
    labels = np.random.randint(10, size=(5,))
    # Raise each row's target logit by a random amount in [0, 2).
    for row, col in enumerate(labels):
        logits[row, col] += np.random.rand() * 2
    probs = softmax(logits)

    expected = ref(probs, labels)
    actual = F.nn.cross_entropy(
        tensor(probs, "float32"), tensor(labels, "int32"), with_logits=False
    )
    np.testing.assert_allclose(actual.numpy(), expected, 1e-6, 1e-6)
def test_cross_entropy_reduction():
    """Check that the ``reduction`` modes of ``cross_entropy`` agree with each
    other and that invalid reduction names raise ``ValueError``."""
    logits = tensor(np.random.randn(16, 10), dtype="float32")
    label = tensor(np.random.randint(10, size=[16]), dtype="int32")

    perm = np.random.permutation(16)
    logits_perm = tensor(logits[perm], dtype="float32")
    label_perm = tensor(label[perm], dtype="int32")

    # "none" yields per-sample losses, so permuting the batch permutes them.
    per_sample = F.nn.cross_entropy(logits, label, reduction="none").numpy()
    per_sample_perm = F.nn.cross_entropy(logits_perm, label_perm, reduction="none")
    np.testing.assert_allclose(per_sample[perm], per_sample_perm.numpy())

    # "sum" and "mean" must match the aggregates of the per-sample losses.
    total = F.nn.cross_entropy(logits, label, reduction="sum").numpy()
    np.testing.assert_allclose(per_sample.sum(), total, rtol=2e-7)
    average = F.nn.cross_entropy(logits, label, reduction="mean")
    np.testing.assert_allclose(average.numpy(), total / 16)

    # The same consistency must hold with label smoothing enabled.
    smoothed_mean = F.nn.cross_entropy(
        logits, label, reduction="mean", label_smooth=0.1
    )
    smoothed_none = F.nn.cross_entropy(
        logits, label, reduction="none", label_smooth=0.1
    )
    np.testing.assert_allclose(
        smoothed_mean.numpy(), smoothed_none.numpy().mean(), rtol=2e-7
    )

    # Reduction names are case-sensitive, and unknown modes are rejected.
    for bad_reduction in ("MEAN", "max"):
        with pytest.raises(ValueError):
            F.nn.cross_entropy(logits, label, reduction=bad_reduction)
def ctc_nll_naive_npy(
    pred,
    pred_lengths,
    label,
    label_lengths,
    blank=0,
    reduction="mean",
    time_major=False,
):
    """NumPy reference implementation of the batched CTC negative log-likelihood.

    Args:
        pred: per-step class probabilities shaped ``(batch, time, classes)``,
            or ``(time, batch, classes)`` when ``time_major`` is True.
        pred_lengths: ``(batch,)`` number of valid time steps per sequence.
        label: 1-D concatenation of every sequence's targets; its length must
            equal ``label_lengths.sum()``.
        label_lengths: ``(batch,)`` number of target labels per sequence.
        blank: index of the blank class.
        reduction: ``"mean"`` (each sequence loss divided by its label length,
            then averaged over the batch), ``"sum"``, or ``"none"``.
        time_major: whether ``pred`` is laid out as ``(time, batch, classes)``.

    Returns:
        A scalar for ``"mean"``/``"sum"``, or a ``(batch,)`` float32 array for
        ``"none"``.

    Raises:
        ValueError: if ``reduction`` is not one of the values listed above.
    """
    pred = np.asarray(pred, dtype=np.float32)
    # int32 rather than int8: lengths above 127 would overflow an int8 and
    # silently corrupt the per-sequence slicing below.
    pred_lengths = np.asarray(pred_lengths, dtype=np.int32)
    label = np.asarray(label, dtype=np.int32)
    label_lengths = np.asarray(label_lengths, dtype=np.int32)
    if time_major:
        pred = np.transpose(pred, (1, 0, 2))
    batch_size, time_len, nr_class = pred.shape
    assert pred_lengths.shape == (batch_size,) and pred_lengths.max() <= pred.shape[1]
    assert label_lengths.shape == (batch_size,)
    assert label.shape == (label_lengths.sum(),) and label.max() < nr_class

    ret = np.empty((batch_size,), dtype=np.float32)
    label_start = 0
    for i in range(batch_size):
        # Targets of sequence i occupy label[label_start:label_end].
        label_end = label_start + label_lengths[i]
        ret[i] = _ctc_npy_single_seq(
            pred[i][: pred_lengths[i]], label[label_start:label_end], blank
        )
        label_start = label_end

    if reduction == "mean":
        return (ret / label_lengths).mean()
    elif reduction == "sum":
        return ret.sum()
    elif reduction == "none":
        return ret
    else:
        raise ValueError("{} is not a valid value for reduction".format(reduction))
def _ctc_npy_single_seq(pred, label, blank):
    """Return the CTC negative log-likelihood of a single sequence.

    Runs the CTC forward recursion in log space over the label extended with
    blanks, ``[blank, l1, blank, l2, ..., lN, blank]``.

    Args:
        pred: ``(len_pred, alphabet_size)`` array of per-step class
            probabilities; each row must sum to 1 within 1e-3.
        label: 1-D array of target class indices. It may be empty, in which
            case only the all-blank path is counted.
        blank: index of the blank class.

    Returns:
        ``-log p(label | pred)`` as a NumPy scalar.
    """

    def safelog(x):
        # Clamp to the smallest positive normal so log() never returns -inf.
        # This keeps every later log-sum-exp finite.
        eps = np.finfo(x.dtype).tiny
        return np.log(np.maximum(x, eps))

    def log_sum_exp(x, y):
        # Numerically stable elementwise log(exp(x) + exp(y)).
        x, y = np.maximum(x, y), np.minimum(x, y)
        return x + np.log1p(np.exp(y - x))

    assert np.abs(pred.sum(axis=1) - 1).max() <= 1e-3
    len_pred, _ = pred.shape
    (len_label,) = label.shape

    len_ex_label = len_label * 2 + 1
    ex_label = np.full(len_ex_label, blank, dtype=np.int32)
    ex_label[1::2] = label

    # At t = 0 a path can start only on the leading blank or on the first
    # label (which is absent for an empty label). Every other state starts
    # at probability 0, which safelog maps to log(tiny).
    prob = np.zeros(len_ex_label, dtype=np.float32)
    prob[:2] = pred[0][ex_label[:2]]
    prob = safelog(prob)

    # State s may also be reached from s - 2 (skipping a blank) only when its
    # symbol differs from the one two positions back.
    skip_allowed = ex_label[2:] != ex_label[:-2]
    for t in range(1, len_pred):
        new_prob = prob.copy()
        new_prob[1:] = log_sum_exp(new_prob[1:], prob[:-1])
        new_prob[2:] = np.where(
            skip_allowed, log_sum_exp(new_prob[2:], prob[:-2]), new_prob[2:]
        )
        new_prob += safelog(pred[t, ex_label])
        prob = new_prob

    if len_label == 0:
        # Only the all-blank path exists; there is no prob[-2] to combine with.
        return -prob[-1]
    # A valid path ends either on the trailing blank or on the last label.
    return -log_sum_exp(prob[-1], prob[-2])
def test_ctc_loss():
    """Compare ``F.nn.ctc_loss`` with the NumPy reference on random inputs
    for every reduction mode."""

    def test_func(T, C, N):
        """Run one comparison with ``T`` time steps, ``C`` classes, batch ``N``."""
        # Time-major (T, N, C) per-class probabilities.
        probs = np.random.randn(T, N, C)
        probs = F.softmax(tensor(probs), axis=-1).numpy()
        input_lengths = np.ones(N, dtype=np.int32) * T
        target_lengths = np.random.randint(low=1, high=T + 1, size=(N,), dtype=np.int32)

        # Pick the blank first, then draw targets from the other C - 1
        # classes. A CTC target must never contain the blank symbol.
        blank = np.random.randint(C)
        target = np.random.randint(
            low=0, high=C - 1, size=(int(target_lengths.sum()),), dtype=np.int32
        )
        target[target >= blank] += 1

        input_mge = tensor(probs)
        input_lengths_mge = tensor(input_lengths)
        target_mge = tensor(target)
        target_lengths_mge = tensor(target_lengths)

        for method in ["mean", "sum", "none"]:
            np_out = ctc_nll_naive_npy(
                probs,
                input_lengths,
                target,
                target_lengths,
                blank=blank,
                reduction=method,
                time_major=True,
            )
            mge_out = F.nn.ctc_loss(
                input_mge,
                input_lengths_mge,
                target_mge,
                target_lengths_mge,
                blank=blank,
                reduction=method,
            )
            np.testing.assert_allclose(mge_out.numpy(), np_out, rtol=2e-6)

    cases = [[1, 2, 1], [100, 50, 200], [100, 5, 1]]
    for case in cases:
        test_func(*case)