leibniz 0.1.0

The package provides a differentiable vector graphics rasterization loss.
Documentation
//! Roots of the linear equation `b * t + c = 0`, solved element-wise over tensors.

use ::burn::tensor::{Bool, Tensor, backend::Backend};

use super::{EPSILON, Values, safe_divisor};

/// Candidate roots of `b * t + c = 0` as produced by [`solve`], paired with a
/// per-element validity mask.
pub struct Roots<B: Backend> {
    /// Candidate root `t = -c / b` for each element. Where `valid` is `false`
    /// the coefficient `b` was (near) zero, so this entry is only whatever the
    /// guarded division produced and must be ignored.
    pub t: Values<B>,
    /// `true` where `|b| >= EPSILON`, i.e. where the matching entry of `t` is a
    /// genuine root of the equation.
    pub valid: Tensor<B, 1, Bool>,
}

/// Solves `b * t + c = 0` for `t`, element-wise.
///
/// Returns every candidate root together with a mask that marks the genuine
/// ones. An element is valid only when `|b| >= EPSILON`. Where `b` is (near)
/// zero the equation is degenerate. Its `t` is still computed, but against
/// `safe_divisor(b)` rather than `b` itself (presumably a guarded divisor that
/// avoids dividing by zero; confirm against its definition), so callers must
/// ignore it.
pub fn solve<B: Backend>(b: Values<B>, c: Values<B>) -> Roots<B> {
    // The mask needs its own copy of `b`; the original is consumed by the divisor.
    let magnitude = b.clone().abs();
    let divisor = safe_divisor(b);

    Roots {
        t: -c / divisor,
        valid: magnitude.greater_equal_elem(EPSILON),
    }
}

#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    #[test]
    fn rejects_degenerate() {
        let roots = solve([0.0], [1.0]);

        assert_bool(roots.valid, [false]);
    }

    #[test]
    fn solves_root() {
        let roots = solve([-2.0], [1.0]);

        assert_values(roots.t, [0.5]);
        assert_bool(roots.valid, [true]);
    }

    #[test]
    fn solves_batch() {
        // t = -c / b for each element: -2 / 1, 2 / 4, -1 / -4.
        let roots = solve([1.0, 4.0, -4.0], [2.0, -2.0, 1.0]);

        assert_values(roots.t, [-2.0, 0.5, 0.25]);
        assert_bool(roots.valid, [true, true, true]);
    }

    #[test]
    fn flags_only_degenerate_elements() {
        let roots = solve([1.0, 0.0, 4.0], [2.0, 3.0, -2.0]);

        assert_bool(roots.valid, [true, false, true]);
    }

    /// Asserts that a boolean tensor holds exactly `expected`.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();

        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor matches `expected` element-wise within `1e-6`.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();

        // `zip` stops at the shorter side, so compare lengths first; otherwise a
        // tensor with missing or extra elements would pass silently.
        assert_eq!(actual.len(), N, "length mismatch: {actual:?} vs {expected:?}");

        for (index, (actual, expected)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "element {index}: {actual} != {expected}",
            );
        }
    }

    /// Builds rank-1 tensors from `b` and `c` and forwards them to [`super::solve`].
    fn solve<const N: usize>(b: [f32; N], c: [f32; N]) -> Roots<Backend> {
        super::solve(
            Tensor::<Backend, 1>::from_floats(b, &Default::default()),
            Tensor::<Backend, 1>::from_floats(c, &Default::default()),
        )
    }
}