1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
use crate::tensor_ops::cpu_kernels::BinaryDerivative;
use num_traits::Float;

/// Elementwise binary cross-entropy computed directly on raw logits.
///
/// With `x` the logit and `y` the target probability, the loss
/// `-(y * ln(sigmoid(x)) + (1 - y) * ln(1 - sigmoid(x)))` is evaluated in the
/// algebraically equivalent form `max(x, 0) - x * y + ln(1 + exp(-|x|))`,
/// which never passes a positive argument to `exp`.
impl<F: Float> BinaryDerivative<F> for super::BCEKernelOp {
    /// Loss for a single `(logit, prob)` pair.
    #[inline(always)]
    fn f(&self, logit: &F, prob: &F) -> F {
        let x = *logit;
        let relu = x.max(F::zero());
        // `-|x| <= 0`, so this `exp` is bounded by 1 and cannot overflow.
        let log_term = (F::one() + (-x.abs()).exp()).ln();
        relu - x * *prob + log_term
    }

    /// Partial derivative with respect to the logit: `sigmoid(x) - y`.
    #[inline(always)]
    fn dfdx(&self, logit: &F, prob: &F) -> F {
        let x = *logit;
        let one = F::one();
        // `1 - 1 / (1 + exp(x))` equals `sigmoid(x)`. For IEEE floats, if
        // `exp(x)` overflows to infinity the reciprocal is zero, leaving the
        // correct limit `1 - y`.
        let inv_denominator = (one + x.exp()).recip();
        one - *prob - inv_denominator
    }

    /// Partial derivative with respect to the target probability: `-x`.
    #[inline(always)]
    fn dfdy(&self, logit: &F, _prob: &F) -> F {
        let x = *logit;
        -x
    }
}