#[cfg(test)]
use super::{assert_almost_equals, new_backward_input, new_input, new_tensor};
use super::{
expect_tensor, expect_tensor_mut, Backward, Data, Forward, Gradient, Overwrite, Reduction,
Tensor,
};
use ndarray::{arr0, Ix0, Zip};
use std::{
cell::{Cell, Ref, RefCell, RefMut},
fmt::{Debug, Display},
rc::Rc,
};
/// Forward node holding the absolute-error (L1) loss between an input and a target.
///
/// The loss is a single scalar: the sum of `|input - target|` over all elements,
/// optionally divided by the element count depending on `reduction`.
// NOTE(review): the `allow` is presumably here because of the all-caps `MAE` prefix
// in the type name — confirm against the project's clippy configuration.
#[allow(clippy::upper_case_acronyms)]
pub struct MAELoss<T, U>
where
T: Data,
U: Data<Dim = T::Dim>,
{
/// Operand whose data is compared against the target.
input: Rc<T>,
/// Operand providing the reference values; it has the same dimensionality as `input`.
target: Rc<U>,
/// Zero-dimensional tensor storing the most recently computed loss (initialised to `0.`).
data: RefCell<Tensor<Ix0>>,
/// Selects whether the summed absolute error is averaged (`Mean`) or kept as is (`Sum`).
reduction: Reduction,
/// Set once `forward` has run; cleared by `reset_computation`.
computed: Cell<bool>,
}
impl<T, U> MAELoss<T, U>
where
    T: Data,
    U: Data<Dim = T::Dim>,
{
    /// Builds a loss node over `input` and `target` using the given `reduction`.
    ///
    /// The stored loss starts at `0.` and the node is marked as not yet computed.
    pub(crate) fn new(input: Rc<T>, target: Rc<U>, reduction: Reduction) -> Self {
        let data = RefCell::new(arr0(0.));
        let computed = Cell::new(false);

        Self {
            input,
            target,
            data,
            reduction,
            computed,
        }
    }
}
impl<T, U> Data for MAELoss<T, U>
where
    T: Data,
    U: Data<Dim = T::Dim>,
{
    /// The loss is a scalar, hence a zero-dimensional tensor.
    type Dim = Ix0;

    /// Immutably borrows the stored loss value.
    fn data(&self) -> Ref<Tensor<Self::Dim>> {
        let loss = &self.data;
        loss.borrow()
    }

    /// Mutably borrows the stored loss value.
    fn data_mut(&self) -> RefMut<Tensor<Self::Dim>> {
        let loss = &self.data;
        loss.borrow_mut()
    }
}
impl<T, U> Forward for MAELoss<T, U>
where
    T: Data,
    U: Data<Dim = T::Dim>,
{
    /// Computes the loss and stores it in this node's data.
    ///
    /// Does nothing if the node is already flagged as computed. Otherwise the
    /// absolute differences between input and target are summed element-wise;
    /// with `Reduction::Mean` the sum is divided by the number of input elements,
    /// with `Reduction::Sum` it is stored unchanged.
    fn forward(&self) {
        // Raise the flag and bail out in one step if it was already raised.
        if self.computed.replace(true) {
            return;
        }

        let mut loss_data = self.data.borrow_mut();
        let input_data = self.input.data();
        let target_data = self.target.data();

        let abs_error_sum = Zip::from(&*input_data)
            .and(&*target_data)
            .fold(0.0, |acc, predicted, expected| {
                acc + (predicted - expected).abs()
            });

        *loss_data = arr0(match self.reduction {
            Reduction::Mean => abs_error_sum / input_data.len() as f32,
            Reduction::Sum => abs_error_sum,
        });
    }

    /// Reports whether `forward` has run since the last reset.
    fn was_computed(&self) -> bool {
        Cell::get(&self.computed)
    }

    /// Clears the computed flag so the next `forward` call recomputes the loss.
    fn reset_computation(&self) {
        Cell::set(&self.computed, false);
    }
}
impl<T, U> Debug for MAELoss<T, U>
where
    T: Data,
    U: Data<Dim = T::Dim>,
{
    /// Formats the node as a struct showing its loss data, reduction and computed flag.
    ///
    /// The operands themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MAELoss");
        out.field("data", &self.data.borrow());
        out.field("reduction", &self.reduction);
        out.field("computed", &self.computed.get());
        out.finish()
    }
}
impl<T, U> Display for MAELoss<T, U>
where
    T: Data,
    U: Data<Dim = T::Dim>,
{
    /// Writes the stored loss tensor using its own `Display` implementation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        let loss = self.data.borrow();
        write!(f, "{}", &loss)
    }
}
/// Backward node of the absolute-error (L1) loss.
///
/// `backward` writes into the gradient of `diff_input` the sign of
/// `input - target` scaled by this node's own scalar gradient (and divided by
/// the element count under `Reduction::Mean`).
// NOTE(review): the `allow` is presumably here because of the all-caps `MAE` prefix
// in the type name — confirm against the project's clippy configuration.
#[allow(clippy::upper_case_acronyms)]
pub struct MAELossBackward<T, U, V>
where
T: Gradient + Overwrite,
U: Data<Dim = T::Dim>,
V: Data<Dim = T::Dim>,
{
/// Scalar gradient of this node; `None` after `no_grad`, `Some` zero tensor after `new`/`with_grad`.
gradient: RefCell<Option<Tensor<Ix0>>>,
/// Overwrite flag exposed through the `Overwrite` trait; starts as `true`.
overwrite: Cell<bool>,
/// Differentiable operand whose gradient is written (or accumulated into) by `backward`.
diff_input: Rc<T>,
/// Source of the input values read during `backward`.
input: Rc<U>,
/// Source of the target values read during `backward`.
target: Rc<V>,
/// Selects whether the propagated gradient is divided by the element count (`Mean`) or not (`Sum`).
reduction: Reduction,
}
impl<T, U, V> MAELossBackward<T, U, V>
where
    T: Gradient + Overwrite,
    U: Data<Dim = T::Dim>,
    V: Data<Dim = T::Dim>,
{
    /// Builds the backward node for the given operands and `reduction`.
    ///
    /// The node starts with a zero scalar gradient and its overwrite flag set.
    pub(crate) fn new(
        diff_input: Rc<T>,
        input: Rc<U>,
        target: Rc<V>,
        reduction: Reduction,
    ) -> Self {
        let gradient = RefCell::new(Some(arr0(0.)));
        let overwrite = Cell::new(true);

        Self {
            gradient,
            overwrite,
            diff_input,
            input,
            target,
            reduction,
        }
    }
}
impl<T, U, V> Gradient for MAELossBackward<T, U, V>
where
    T: Gradient + Overwrite,
    U: Data<Dim = T::Dim>,
    V: Data<Dim = T::Dim>,
{
    /// The loss gradient is a scalar, hence a zero-dimensional tensor.
    type Dim = Ix0;

    /// Immutably borrows the gradient tensor through `expect_tensor`.
    fn gradient(&self) -> Ref<Tensor<Self::Dim>> {
        let slot = &self.gradient;
        expect_tensor(slot)
    }

    /// Mutably borrows the gradient tensor through `expect_tensor_mut`.
    fn gradient_mut(&self) -> RefMut<Tensor<Self::Dim>> {
        let slot = &self.gradient;
        expect_tensor_mut(slot)
    }
}
impl<T, U, V> Overwrite for MAELossBackward<T, U, V>
where
T: Gradient + Overwrite,
U: Data<Dim = T::Dim>,
V: Data<Dim = T::Dim>,
{
fn can_overwrite(&self) -> bool {
self.overwrite.get()
}
fn set_overwrite(&self, state: bool) {
self.overwrite.set(state);
}
}
impl<T, U, V> Backward for MAELossBackward<T, U, V>
where
    T: Gradient + Overwrite,
    U: Data<Dim = T::Dim>,
    V: Data<Dim = T::Dim>,
{
    /// Propagates this node's scalar gradient to the differentiable operand.
    ///
    /// For every element the contribution is `signum(input - target) * gradient`,
    /// divided by the number of input elements under `Reduction::Mean`; where
    /// `input == target` the contribution is `0.`. If the operand reports that it
    /// can be overwritten, its gradient is assigned and the flag is cleared;
    /// otherwise the contribution is added to the existing gradient.
    fn backward(&self) {
        let mut operand_gradient = self.diff_input.gradient_mut();
        let gradient = self.gradient();
        let input_data = self.input.data();
        let target_data = self.target.data();

        // Dividing by exactly 1.0 leaves an f32 unchanged, so a single formula
        // serves both reductions.
        let divisor = match self.reduction {
            Reduction::Mean => input_data.len() as f32,
            Reduction::Sum => 1.,
        };

        // Per-element contribution; the explicit zero check is needed because
        // `signum` returns +/-1 even for +/-0.
        let contribution = |upstream: &f32, predicted: &f32, expected: &f32| -> f32 {
            let diff = predicted - expected;
            if diff != 0. {
                diff.signum() * upstream / divisor
            } else {
                0.
            }
        };

        let zip = Zip::from(&mut *operand_gradient)
            .and_broadcast(&*gradient)
            .and(&*input_data)
            .and(&*target_data);

        if self.diff_input.can_overwrite() {
            zip.for_each(|op_grad, grad, input, target| {
                *op_grad = contribution(grad, input, target)
            });
            self.diff_input.set_overwrite(false);
        } else {
            zip.for_each(|op_grad, grad, input, target| {
                *op_grad += contribution(grad, input, target)
            });
        }
    }

    /// Drops the stored gradient, leaving `None` in its place.
    fn no_grad(&self) {
        let cleared = None;
        *self.gradient.borrow_mut() = cleared;
    }

    /// Installs a fresh zero-valued scalar gradient.
    fn with_grad(&self) {
        let zeroed = arr0(0.);
        *self.gradient.borrow_mut() = Some(zeroed);
    }
}
impl<T, U, V> Debug for MAELossBackward<T, U, V>
where
    T: Gradient + Overwrite,
    U: Data<Dim = T::Dim>,
    V: Data<Dim = T::Dim>,
{
    /// Formats the node as a struct showing its gradient, reduction and overwrite flag.
    ///
    /// The operands themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MAELossBackward");
        out.field("gradient", &self.gradient.borrow());
        out.field("reduction", &self.reduction);
        out.field("overwrite", &self.overwrite.get());
        out.finish()
    }
}
impl<T, U, V> Display for MAELossBackward<T, U, V>
where
    T: Gradient + Overwrite,
    U: Data<Dim = T::Dim>,
    V: Data<Dim = T::Dim>,
{
    /// Writes the gradient tensor if one is present, or the text `None` otherwise.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        let slot = self.gradient.borrow();
        if let Some(gradient) = &*slot {
            write!(f, "{}", &gradient)
        } else {
            write!(f, "None")
        }
    }
}
#[cfg(test)]
mod test;