Function _gradient_checker

pub fn _gradient_checker(
    op: &mut dyn OpTrait,
    one_input: &[Tensor],
    input_mask: Option<&[bool]>,
    step: Option<Tensor>,
    tolerance: Option<Tensor>,
) -> bool

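Verifies that an operator's gradient implementation is correct by comparing it with a numeric (finite-difference) estimate. The returned bool reports whether the check passed.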

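op: the operator under test.
one_input: the test data points.
input_mask: optional mask; a data point is skipped when its corresponding element is false.
step: the delta used to compute the numeric difference.
tolerance: the maximum difference allowed between the analytic and numeric gradients for them to be considered equal.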
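For reference, one common form of the finite-difference criterion is shown below. It assumes a central difference (the checker may use a forward difference instead), with h standing for step, ε for tolerance, e_i for the i-th unit direction, and g^analytic for the gradient reported by op:

```latex
g_i^{\text{num}} = \frac{f(x + h\,e_i) - f(x - h\,e_i)}{2h},
\qquad
\text{pass} \iff \bigl\lvert g_i^{\text{num}} - g_i^{\text{analytic}} \bigr\rvert \le \varepsilon
\quad \text{for every unmasked } i
```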

one_input and input_mask should have the same length. step and tolerance are both scalar tensors.
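The procedure can be illustrated with a minimal, self-contained sketch. It is not the crate's implementation. The function name gradient_check is hypothetical, plain f64 scalars stand in for Tensor, two closures stand in for the forward and backward passes of an OpTrait operator, and a central difference is assumed:

```rust
/// Hypothetical stand-in for `_gradient_checker`: f64 scalars replace `Tensor`,
/// and the closures `f` (forward) and `grad` (analytic gradient) replace `OpTrait`.
fn gradient_check(
    f: &dyn Fn(&[f64]) -> f64,
    grad: &dyn Fn(&[f64]) -> Vec<f64>,
    inputs: &[f64],
    input_mask: Option<&[bool]>,
    step: f64,
    tolerance: f64,
) -> bool {
    if let Some(mask) = input_mask {
        // Mirrors the documented rule that the inputs and the mask have the same length.
        assert_eq!(mask.len(), inputs.len(), "input_mask must match inputs in length");
    }
    let analytic = grad(inputs);
    let mut x = inputs.to_vec();
    for i in 0..inputs.len() {
        // Skip data points whose mask element is false.
        if let Some(mask) = input_mask {
            if !mask[i] {
                continue;
            }
        }
        let original = x[i];
        x[i] = original + step;
        let f_plus = f(&x);
        x[i] = original - step;
        let f_minus = f(&x);
        x[i] = original; // restore before perturbing the next input
        // Central-difference estimate of df/dx_i (an assumption; see lead-in).
        let numeric = (f_plus - f_minus) / (2.0 * step);
        if (numeric - analytic[i]).abs() > tolerance {
            return false;
        }
    }
    true
}

fn main() {
    // f(x, y) = x^2 * y, whose gradient is (2xy, x^2).
    let f = |v: &[f64]| v[0] * v[0] * v[1];
    let good = |v: &[f64]| vec![2.0 * v[0] * v[1], v[0] * v[0]];
    // Deliberately wrong: the factor 2 is missing from df/dx.
    let bad = |v: &[f64]| vec![v[0] * v[1], v[0] * v[0]];
    let x = [1.5, -2.0];
    let mask = [false, true];

    println!("{}", gradient_check(&f, &good, &x, None, 1e-5, 1e-6)); // true
    println!("{}", gradient_check(&f, &bad, &x, None, 1e-5, 1e-6)); // false
    // Masking out x hides the bug, since only df/dy is compared.
    println!("{}", gradient_check(&f, &bad, &x, Some(&mask[..]), 1e-5, 1e-6)); // true
}
```

The last call shows why input_mask should be used with care: any data point it skips is never compared, so an error in that part of the gradient goes undetected.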