1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
use crate::tensor_ops::cpu_kernels::BinaryDerivative;
use num_traits::Float;
impl<F: Float> BinaryDerivative<F> for super::BCEKernelOp {
    /// Binary cross-entropy on a raw logit `x` and a target probability `y`.
    ///
    /// Mathematically `-(y * ln(sigmoid(x)) + (1 - y) * ln(1 - sigmoid(x)))`,
    /// evaluated in the overflow-free form
    /// `max(x, 0) - x * y + ln(1 + exp(-|x|))`.
    #[inline(always)]
    fn f(&self, &logit: &F, &prob: &F) -> F {
        // `exp(-|x|)` lies in (0, 1], so it never overflows. `ln_1p` keeps
        // full precision when that term is far smaller than 1; the plain
        // `(1 + t).ln()` would round it away to 0 for large |x|.
        logit.max(F::zero()) - logit * prob + (-logit.abs()).exp().ln_1p()
    }

    /// Partial derivative of [`f`](Self::f) with respect to the logit:
    /// `sigmoid(x) - y`.
    #[inline(always)]
    fn dfdx(&self, &logit: &F, &prob: &F) -> F {
        // Evaluate sigmoid without cancellation. For x >= 0 use
        // 1 / (1 + e^-x), where e^-x <= 1. For x < 0 use e^x / (1 + e^x),
        // where e^x < 1. The previous `1 - 1/(1 + e^x)` collapsed to 0 for
        // moderately negative x. Results at x = ±inf and NaN are unchanged.
        let sigmoid = if logit >= F::zero() {
            (F::one() + (-logit).exp()).recip()
        } else {
            let e = logit.exp();
            e / (F::one() + e)
        };
        sigmoid - prob
    }

    /// Partial derivative of [`f`](Self::f) with respect to the target
    /// probability. The loss is linear in `y` with coefficient `-x`.
    #[inline(always)]
    fn dfdy(&self, &logit: &F, _: &F) -> F {
        -logit
    }
}