use super::{
    variable::{
        BCELoss, BCELossBackward, BCEWithLogitsLoss, BCEWithLogitsLossBackward, KLDivLoss,
        KLDivLossBackward, MAELoss, MAELossBackward, MSELoss, MSELossBackward, NLLLoss,
        NLLLossBackward, Overwrite,
    },
    Data, Gradient, Var, VarDiff,
};
use ndarray::Dimension;
use std::fmt::Debug;

/// Specifies the reduction to apply to the output of a loss function.
#[derive(Clone, Debug)]
pub enum Reduction {
    /// The output will be summed.
    Sum,
    /// The sum of the output will be divided by the number of its elements.
    Mean,
}
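
/// Computes the mean squared error (squared L2 norm) between each element of the
/// input `x` and the target `y`: `l = (x - y)^2`, with the result summed or
/// averaged over all elements according to `reduction`.
///
/// A minimal usage sketch; `input` (a differentiable variable) and `target` (a
/// non-differentiable variable of the same shape) are placeholders assumed to be
/// built elsewhere, so the example is left uncompiled:
///
/// ```ignore
/// let loss = mse_loss(input, target, Reduction::Mean);
/// ```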
pub fn mse_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<MSELoss<T, V>, MSELossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    // Merge the target's forward history into the input's, so the joint
    // computational graph is tracked once. Every loss below follows this pattern.
    input.var.past.merge(target.past);
    let forward_node = MSELoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    // Wrap the forward node into a non-differentiable variable, then pair it
    // with the backward node to obtain the differentiable result.
    let var = Var::from(forward_node, input.var.past);
    let backward_node = MSELossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}
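
/// Computes the mean absolute error (L1 norm) between each element of the input
/// `x` and the target `y`: `l = |x - y|`, with the result summed or averaged over
/// all elements according to `reduction`.
///
/// A minimal sketch with the same placeholder variables as above, this time
/// summing the loss over all elements instead of averaging it:
///
/// ```ignore
/// let loss = mae_loss(input, target, Reduction::Sum);
/// ```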
pub fn mae_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<MAELoss<T, V>, MAELossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = MAELoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);
    let backward_node = MAELossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}
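
/// Computes the binary cross-entropy between the input `x` and the target `y`:
/// `l = -[y * ln(x) + (1 - y) * ln(1 - x)]`, summed or averaged according to
/// `reduction`. The target should consist of values in `[0, 1]` and, as is
/// conventional for this loss, the input is expected to contain probabilities,
/// e.g. the output of a sigmoid.
///
/// A minimal sketch; `probabilities` and `target` are placeholders assumed to be
/// built elsewhere:
///
/// ```ignore
/// let loss = bce_loss(probabilities, target, Reduction::Mean);
/// ```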
pub fn bce_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<BCELoss<T, V>, BCELossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = BCELoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);
    let backward_node = BCELossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}
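
/// Computes the binary cross-entropy with logits between the input `x` and the
/// target `y`. This criterion combines a sigmoid and a binary cross-entropy in a
/// single step, which is conventionally more numerically stable than applying the
/// two separately, as it permits the log-sum-exp trick.
///
/// A minimal sketch; `logits` is a placeholder for a differentiable variable of
/// raw, unnormalized scores:
///
/// ```ignore
/// let loss = bce_with_logits_loss(logits, target, Reduction::Mean);
/// ```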
pub fn bce_with_logits_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<BCEWithLogitsLoss<T, V>, BCEWithLogitsLossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = BCEWithLogitsLoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);
    let backward_node =
        BCEWithLogitsLossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}
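
/// Computes the negative log likelihood between the input `x` and the target `y`.
/// As is conventional for this loss, the input is expected to contain
/// log-probabilities (e.g. the output of a log-softmax) and the target to hold
/// class indices; the trait bounds require the input to have exactly one more
/// dimension than the target, the extra one being the class dimension.
///
/// A minimal sketch with placeholder variables; `log_probs` would have shape
/// `(batch, classes)` and `classes` shape `(batch,)`:
///
/// ```ignore
/// let loss = nll_loss(log_probs, classes, Reduction::Mean);
/// ```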
pub fn nll_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<NLLLoss<T, V>, NLLLossBackward<U, V>>
where
    T: Data<Dim = <V::Dim as Dimension>::Larger>,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data,
    T::Dim: Copy,
{
    input.var.past.merge(target.past);
    let forward_node = NLLLoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);
    let backward_node = NLLLossBackward::new(input.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}
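
/// Computes the Kullback-Leibler divergence between the input `x` and the target
/// `y`: `l = y * (ln(y) - x)`. As is conventional for this pointwise form of the
/// divergence, the input is expected to contain log-probabilities while the
/// target is interpreted as probabilities.
///
/// A minimal sketch with placeholder variables holding log-probabilities and
/// probabilities respectively:
///
/// ```ignore
/// let loss = kldiv_loss(log_probs, probs, Reduction::Mean);
/// ```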
pub fn kldiv_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<KLDivLoss<T, V>, KLDivLossBackward<U, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = KLDivLoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);
    let backward_node = KLDivLossBackward::new(input.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}