Function dfdx::tensor_ops::huber_error
pub fn huber_error<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D> + Merge<R>, R: Tape<E, D>>(
lhs: Tensor<S, E, D, T>,
rhs: Tensor<S, E, D, R>,
delta: impl Into<f64>
) -> Tensor<S, E, D, T>
Huber loss uses absolute error when the error is greater than `delta`, and squared error when the error is less than `delta`. The result has the same shape `S` as the inputs; no reduction is applied.
It computes, elementwise, with x taken from `lhs` and y from `rhs`:
- if |x - y| < delta: 0.5 * (x - y)^2
- otherwise: delta * (|x - y| - 0.5 * delta)
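```rust
use dfdx::prelude::*;
let dev: Cpu = Default::default();
let a = dev.tensor([1.0, 1.0, 1.0]);
let b = dev.tensor([1.5, 1.75, 2.5]);
// |a - b| = [0.5, 0.75, 1.5]; only the last element reaches delta = 1.0
let r = a.huber_error(b, 1.0);
assert_eq!(r.array(), [0.125, 0.28125, 1.0]);
```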
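To make the piecewise definition concrete, here is a minimal plain-Rust sketch of the same formula applied elementwise to fixed-size arrays. The helpers `huber_scalar` and `huber_elementwise` are hypothetical, written only for illustration; they are not part of dfdx and assume plain `f64` inputs. They mirror the documented math, not dfdx's tensor, device, or tape machinery.

```rust
/// Hypothetical scalar reference for the piecewise Huber formula (not a dfdx API).
fn huber_scalar(x: f64, y: f64, delta: f64) -> f64 {
    let abs_err = (x - y).abs();
    if abs_err < delta {
        // Quadratic region: 0.5 * (x - y)^2
        0.5 * abs_err * abs_err
    } else {
        // Linear region: delta * (|x - y| - 0.5 * delta)
        delta * (abs_err - 0.5 * delta)
    }
}

/// Hypothetical helper: applies `huber_scalar` to each pair of elements, with no reduction.
fn huber_elementwise<const N: usize>(lhs: [f64; N], rhs: [f64; N], delta: f64) -> [f64; N] {
    std::array::from_fn(|i| huber_scalar(lhs[i], rhs[i], delta))
}

fn main() {
    // Same inputs and delta as the dfdx example.
    let r = huber_elementwise([1.0, 1.0, 1.0], [1.5, 1.75, 2.5], 1.0);
    println!("{:?}", r);
}
```

With the inputs from the dfdx example this prints `[0.125, 0.28125, 1.0]`. At |x - y| = delta both branches evaluate to 0.5 * delta^2, so the function is continuous at the threshold and the choice of `<` rather than `<=` does not change the value there.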