leibniz 0.2.0

The package provides a differentiable vector graphics rasterization loss.
Documentation
//! Linear roots.

use ::burn::tensor::{Bool, Tensor, backend::Backend};

use super::{EPSILON, safe_divisor};

/// Roots of element-wise linear equations `b * t + c = 0`, as produced by
/// [`solve`].
///
/// `t` takes the (possibly broadcast) shape of the product of the two
/// coefficient tensors. `valid` takes the shape of the leading coefficient `b`.
#[derive(Clone, Debug)]
pub struct Roots<B: Backend, const D: usize> {
    /// Candidate root of each equation. It is only a root of the original
    /// equation where `valid` is `true`.
    pub t: Tensor<B, D>,
    /// `true` where the leading coefficient `b` has magnitude of at least
    /// `EPSILON`, i.e. where the equation is not treated as degenerate.
    pub valid: Tensor<B, D, Bool>,
}

/// Solves `b * t + c = 0` for `t`, element-wise.
///
/// An equation counts as valid only where `|b| >= EPSILON`. The candidate
/// root is always computed as `-c / safe_divisor(b)`. Where `valid` is `false`
/// it should not be read as a root of the original equation.
/// NOTE(review): presumably `safe_divisor` keeps the reciprocal finite for
/// near-zero `b` — confirm against its definition.
pub fn solve<B: Backend, const D: usize>(b: Tensor<B, D>, c: Tensor<B, D>) -> Roots<B, D> {
    // Scaling by a negated reciprocal instead of dividing is deliberate. On the
    // winding path, `c` is sample-sized while `b` is a per-segment column, so
    // one reciprocal per segment followed by a sample-sized multiplication is
    // cheaper than a sample-sized division. The product may differ from the
    // true quotient in the last ulp; that departure from the reference
    // arithmetic is accepted.
    let inverse = safe_divisor(b.clone()).recip();
    let t = c * -inverse;

    let magnitude = b.abs();
    let valid = magnitude.greater_equal_elem(EPSILON);

    Roots { t, valid }
}

#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    #[test]
    fn rejects_degenerate() {
        let roots = solve([0.0], [1.0]);

        assert_bool(roots.valid, [false]);
    }

    #[test]
    fn solves_root() {
        let roots = solve([-2.0], [1.0]);

        assert_values(roots.t, [0.5]);
        assert_bool(roots.valid, [true]);
    }

    #[test]
    fn solves_elementwise() {
        // -2 / 4 and -3 / -1 are exactly representable, so both roots are exact.
        let roots = solve([4.0, -1.0], [2.0, 3.0]);

        assert_values(roots.t, [-0.5, 3.0]);
        assert_bool(roots.valid, [true, true]);
    }

    #[test]
    fn flags_each_element_independently() {
        let roots = solve([0.0, 2.0], [1.0, 1.0]);

        assert_bool(roots.valid, [false, true]);
    }

    /// Asserts that a boolean tensor holds exactly `expected`.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();

        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor has exactly `N` elements, each within `1e-6`
    /// of the matching entry of `expected`.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();

        // Check the length first. `zip` stops at the shorter side, so without
        // this a short (or empty) result would pass without comparing anything.
        assert_eq!(actual.len(), N, "expected {} values, got {:?}", N, actual);
        for (index, (actual, expected)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "value {}: expected {}, got {}",
                index,
                expected,
                actual,
            );
        }
    }

    /// Builds 1-D tensors from `b` and `c` and solves `b * t + c = 0`.
    fn solve<const N: usize>(b: [f32; N], c: [f32; N]) -> Roots<Backend, 1> {
        super::solve(
            Tensor::<Backend, 1>::from_floats(b, &Default::default()),
            Tensor::<Backend, 1>::from_floats(c, &Default::default()),
        )
    }
}