#include "megbrain/opr/training/loss.h"
#include "megbrain/exception.h"
#include "megbrain/opr/indexing.h"
namespace mgb {
namespace loss {
CrossEntropyLoss::CrossEntropyLoss(
bool with_logits, float label_smooth, ReduceMode reduce_mode, int axis)
: m_with_logits(with_logits),
m_label_smooth(label_smooth),
m_reduce_mode(reduce_mode),
m_axis(axis) {}
/*!
 * \brief Build the cross-entropy loss expression for \p symbol_pred against
 *      \p symbol_label.
 *
 * With logits, the loss is log(sum(exp(pred))) - pred[label], evaluated along
 * m_axis as max + log(sum(exp(pred - max))). Without logits, pred is used
 * as-is and the loss is -log(pred[label]). When m_label_smooth is positive,
 * the selected entry is blended with the axis average (sum over m_axis divided
 * by GetVarShape of pred for that axis) using the weights
 * (1 - m_label_smooth) and m_label_smooth.
 *
 * \param symbol_pred prediction var; must have at least as many dimensions as
 *      the label, and matching extents on every label dimension except m_axis
 * \param symbol_label label var, used as the index of IndexingOneHot along
 *      m_axis
 * \return the loss var; summed over the flattened loss for ReduceMode::SUM,
 *      additionally divided by the element count for ReduceMode::MEAN, and
 *      returned unreduced for any other mode
 */
SymbolVar CrossEntropyLoss::operator()(
        mgb::SymbolVar symbol_pred, mgb::SymbolVar symbol_label) {
    const auto& pred_shape = symbol_pred.shape();
    const auto& label_shape = symbol_label.shape();

    mgb_assert(
            pred_shape.ndim >= label_shape.ndim,
            "The label must have less dimensions than the pred.");
    for (size_t i = 0; i < label_shape.ndim; ++i) {
        // The class axis is allowed to differ between pred and label.
        mgb_assert(
                pred_shape[i] == label_shape[i] || static_cast<int>(i) == m_axis,
                "Unmatched shape for pred and label.");
    }
    mgb_assert(
            m_label_smooth >= .0f, "The label_smooth must be a non-negative value");

    SymbolVar symbol_loss;
    SymbolVar symbol_middle;
    // Entry of pred picked by the label along m_axis.
    SymbolVar symbol_primary_item =
            opr::IndexingOneHot::make(symbol_pred, symbol_label, {m_axis});

    if (m_with_logits) {
        // Subtract the per-axis maximum before exponentiating so that exp()
        // does not overflow; the maximum is added back outside the log.
        SymbolVar symbol_max = opr::reduce_ax_max(symbol_pred, m_axis);
        symbol_middle = opr::reduce_ax_sum(symbol_pred, m_axis) /
                        opr::GetVarShape::make(symbol_pred, {m_axis});
        symbol_loss =
                symbol_max + opr::log(opr::reduce_ax_sum(
                                     opr::exp(symbol_pred - symbol_max), m_axis));
    } else {
        symbol_middle = opr::reduce_ax_sum(opr::log(symbol_pred), m_axis) /
                        opr::GetVarShape::make(symbol_pred, {m_axis});
        symbol_primary_item = opr::log(symbol_primary_item);
        // There is no log-sum-exp term in this mode. Start from zero so the
        // subtractions below produce the negative log-likelihood instead of
        // operating on a default-constructed SymbolVar.
        symbol_loss = symbol_pred.make_scalar(.0f);
    }

    if (m_label_smooth > .0f) {
        symbol_loss = symbol_loss - m_label_smooth * symbol_middle -
                      (1 - m_label_smooth) * symbol_primary_item;
    } else {
        symbol_loss = symbol_loss - symbol_primary_item;
    }

    if (m_reduce_mode == ReduceMode::MEAN) {
        symbol_loss =
                opr::reduce_sum(symbol_loss.flatten(), symbol_loss.make_scalar(1)) /
                static_cast<float>(symbol_loss.shape().total_nr_elems());
    } else if (m_reduce_mode == ReduceMode::SUM) {
        symbol_loss =
                opr::reduce_sum(symbol_loss.flatten(), symbol_loss.make_scalar(1));
    }
    return symbol_loss;
}
MSELoss::MSELoss(ReduceMode reduce_mode) : m_reduce_mode(reduce_mode){};
/*!
 * \brief Build the squared-error expression (pred - label) ^ 2.
 *
 * \param symbol_pred prediction var
 * \param symol_label label var subtracted from the prediction
 * \return the var produced by opr::pow on the difference with exponent 2
 *
 * NOTE(review): m_reduce_mode is not consulted here, so the result is returned
 * without any MEAN/SUM reduction -- confirm whether that is intended.
 */
mgb::SymbolVar MSELoss::operator()(
        mgb::SymbolVar symbol_pred, mgb::SymbolVar symol_label) {
    const mgb::SymbolVar difference = symbol_pred - symol_label;
    const mgb::SymbolVar exponent = symbol_pred.make_scalar(2);
    return opr::pow(difference, exponent);
}
}
}