//! # Loss functions.
//!
//! The purpose of a loss function is to compute the quantity that a model should seek to minimize
//! during training.
//!
//! All losses are provided as free functions.
//!
//! ## Regression losses
//!
//! * [`mse_loss`] - Measures the mean squared error between each element in the input and the
//! target.
//!
//! * [`mae_loss`] - Measures the mean absolute error between each element in the input and the
//! target.
//!
//! ## Probabilistic losses
//!
//! * [`bce_loss`] - Measures the binary cross entropy between the target and the input.
//!
//! * [`bce_with_logits_loss`] - Measures the binary cross entropy with logits between the target
//! and the input.
//!
//! * [`nll_loss`] - Measures the negative log likelihood between the target and the input.
//!
//! * [`kldiv_loss`] - Measures the Kullback-Leibler divergence between the target and the input.
use super::{
    variable::{
        BCELoss, BCELossBackward, BCEWithLogitsLoss, BCEWithLogitsLossBackward, KLDivLoss,
        KLDivLossBackward, MAELoss, MAELossBackward, MSELoss, MSELossBackward, NLLLoss,
        NLLLossBackward, Overwrite,
    },
    Data, Gradient, Var, VarDiff,
};
use ndarray::Dimension;
use std::fmt::Debug;

/// Specifies the reduction to apply to the *loss* output.
#[derive(Clone, Debug)]
pub enum Reduction {
    /// The output will be summed.
    Sum,
    /// The sum of the output will be divided by the batch size for the [`kldiv_loss`] and the
    /// [`nll_loss`]. For all other losses, the output will be divided by the number of elements.
    Mean,
}

/// Computes the **mean squared error** *(squared L2 norm)* between each element in the input x
/// and target y.
///
/// ```text
///        1   n
/// Lᴏss = ―   ∑ (xᵢ - ʏᵢ)²
///        n  i=1
/// ```
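///
/// A minimal usage sketch, not a doc-test: the crate-level constructors (`neuronika::rand`,
/// `neuronika::zeros`) and the `forward`/`backward` calls are assumed to be available, and
/// this module's items are assumed to be in scope.
///
/// ```ignore
/// // Differentiable predictions and a fixed target, both of shape (2, 3).
/// let input = neuronika::rand((2, 3)).requires_grad();
/// let target = neuronika::zeros((2, 3));
///
/// let loss = mse_loss(input, target, Reduction::Mean);
/// loss.forward();     // computes the loss value
/// loss.backward(1.0); // propagates the gradient back to `input`
/// ```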
pub fn mse_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<MSELoss<T, V>, MSELossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = MSELoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);

    let backward_node = MSELossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}

/// Computes the **mean absolute error** *(MAE)* between each element in the input x and target y.
///
/// ```text
///        1   n
/// Lᴏss = ―   ∑ |xᵢ - ʏᵢ|
///        n  i=1
/// ```
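///
/// The call pattern is the same as in the sketch for [`mse_loss`]; the constructors below are
/// again assumptions made for illustration only.
///
/// ```ignore
/// let input = neuronika::rand((2, 3)).requires_grad();
/// let target = neuronika::zeros((2, 3));
///
/// // Sums the element-wise absolute errors instead of averaging them.
/// let loss = mae_loss(input, target, Reduction::Sum);
/// ```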
pub fn mae_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<MAELoss<T, V>, MAELossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = MAELoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);

    let backward_node = MAELossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}

/// Computes the **binary cross entropy** between the target y and input x.
///
/// ```text
///        1   n
/// Lᴏss = ―   ∑ - [ʏᵢ * ln(xᵢ) + (1 - ʏᵢ) * ln(1 - xᵢ)]
///        n  i=1
/// ```
///
/// Note that the target y should contain numbers between 0 and 1.
/// Also notice that if a component of the input x is either 0 or 1,
/// one of the log terms would be mathematically undefined in the above loss equation.
/// Rust sets *ln(0) = -inf*; however, an infinite term in the loss equation is not desirable.
/// For this reason, this loss clamps its log function outputs to be greater than or equal
/// to -100, so that the loss value is always finite.
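///
/// A hedged sketch: the constructors and the `sigmoid` activation are assumptions made for
/// illustration, used here only to obtain probabilities in the (0, 1) range this loss expects.
///
/// ```ignore
/// // Raw scores squashed into probabilities; targets made of zeros.
/// let input = neuronika::rand((4, 1)).requires_grad().sigmoid();
/// let target = neuronika::zeros((4, 1));
///
/// let loss = bce_loss(input, target, Reduction::Mean);
/// ```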
pub fn bce_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<BCELoss<T, V>, BCELossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = BCELoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);

    let backward_node = BCELossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}

/// Computes the **binary cross entropy with logits** between the target y and input x.
///
/// ```text
///        1   n
/// Lᴏss = ―   ∑  - [ʏᵢ * ln(σ(xᵢ)) + (1 - ʏᵢ) * ln(1 - σ(xᵢ))]
///        n  i=1
/// ```
/// This loss combines a sigmoid and a binary cross entropy in a single operation.
/// It is more numerically stable than a plain sigmoid followed by a binary cross entropy
/// because, by fusing the two operations, it can take advantage of the log-sum-exp trick.
/// Note that the target y should contain numbers between 0 and 1, while the
/// input x should be raw, unnormalized scores.
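///
/// A hedged sketch, with the same assumed constructors as above; note that the raw scores are
/// passed directly, without a preceding sigmoid.
///
/// ```ignore
/// // Unnormalized scores and binary targets.
/// let input = neuronika::rand((4, 1)).requires_grad();
/// let target = neuronika::ones((4, 1));
///
/// let loss = bce_with_logits_loss(input, target, Reduction::Mean);
/// ```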
pub fn bce_with_logits_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<BCEWithLogitsLoss<T, V>, BCEWithLogitsLossBackward<U, T, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = BCEWithLogitsLoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);

    let backward_node =
        BCEWithLogitsLossBackward::new(input.node, input.var.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}

/// Computes the **negative log likelihood** between the target y and input x.
///
/// ```text
///         1   n
/// Lᴏss =  ―   ∑  - xᵢ,ᵧᵢ
///         n  i=1
/// ```
///
/// The input x is expected to contain log-probabilities for each class; this is typically
/// achieved by using [`.log_softmax()`]. The input has to be of size either (minibatch, C)
/// or (minibatch, C, d1, d2, ..., dk) with k >= 1 for the k-dimensional case. The target
/// that this loss expects should be a class index in the range [0, C) where C is the number
/// of classes. When the given reduction is equal to [`Reduction::Mean`] the total loss is
/// divided by the batch size.
///
/// As mentioned before, this loss can also be used for higher dimensional inputs, such as 2D
/// images, by providing an input of size (minibatch, C, d1, d2, ..., dk) with k >= 1, where k
/// is the number of additional dimensions. In the case of images, it computes the NLL loss
/// *per-pixel*.
///
/// In the k-dimensional case this loss expects a target of shape
/// (minibatch, d1, d2, ..., dk).
///
/// [`.log_softmax()`]: VarDiff::log_softmax()
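///
/// A hedged sketch: the constructors and the axis argument of `log_softmax` are assumptions
/// made for illustration only.
///
/// ```ignore
/// // A minibatch of 4 samples over 3 classes: log-probabilities of shape (4, 3)
/// // and class indices of shape (4,).
/// let input = neuronika::rand((4, 3)).requires_grad().log_softmax(1);
/// let target = neuronika::zeros(4); // every sample is labelled with class 0
///
/// let loss = nll_loss(input, target, Reduction::Mean);
/// ```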
pub fn nll_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<NLLLoss<T, V>, NLLLossBackward<U, V>>
where
    T: Data<Dim = <V::Dim as Dimension>::Larger>,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data,
    T::Dim: Copy,
{
    input.var.past.merge(target.past);
    let forward_node = NLLLoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);

    let backward_node = NLLLossBackward::new(input.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}

/// Computes the **Kullback-Leibler** divergence between the target and the input.
///
/// ```text
///         n
/// Lᴏss =  ∑  ʏᵢ * (ln(ʏᵢ) - xᵢ)
///        i=1
/// ```
///
/// The [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback–Leibler_divergence) is
/// a useful distance measure for continuous distributions and is often used when performing
/// direct regression over the space of (discretely sampled) continuous output distributions.
///
/// The input is expected to contain log-probabilities and is not restricted to a 2-dimensional
/// input, while the targets are interpreted as probabilities. When the given reduction is equal
/// to [`Reduction::Mean`] the total loss is divided by the batch size.
///
/// This criterion expects a target variable of the same size as the input variable.
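///
/// A hedged sketch: the constructors and the `log_softmax`/`softmax` calls are assumptions
/// made for illustration only.
///
/// ```ignore
/// // Two distributions over 5 classes for a minibatch of 2 samples.
/// let input = neuronika::rand((2, 5)).requires_grad().log_softmax(1); // log-probabilities
/// let target = neuronika::rand((2, 5)).softmax(1);                    // probabilities
///
/// let loss = kldiv_loss(input, target, Reduction::Mean);
/// ```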
pub fn kldiv_loss<T, U, V>(
    mut input: VarDiff<T, U>,
    target: Var<V>,
    reduction: Reduction,
) -> VarDiff<KLDivLoss<T, V>, KLDivLossBackward<U, V>>
where
    T: Data,
    U: Gradient<Dim = T::Dim> + Overwrite,
    V: Data<Dim = T::Dim>,
{
    input.var.past.merge(target.past);
    let forward_node = KLDivLoss::new(
        input.var.node.clone(),
        target.node.clone(),
        reduction.clone(),
    );
    let var = Var::from(forward_node, input.var.past);

    let backward_node = KLDivLossBackward::new(input.node, target.node, reduction);
    VarDiff::from(backward_node, input.past, var)
}