leibniz 0.1.0

The package provides a differentiable vector graphics rasterization loss.
Documentation
//! Quadratic roots.

use ::burn::tensor::{Bool, Tensor, backend::Backend};

use super::{EPSILON, Values, linear, safe_divisor};

/// Element-wise roots of `a·t² + b·t + c = 0` as produced by [`solve`], together
/// with the masks that say which entries are meaningful.
pub struct Roots<B: Backend> {
    /// First root. Where `valid_quadratic0` is set this is the quadratic root
    /// built with the minus sign in front of the square root; everywhere else
    /// it holds the value returned by the linear solver.
    pub t0: Values<B>,
    /// Quadratic root built with the plus sign in front of the square root.
    /// It is not masked, so only read it where `valid_quadratic1` is set.
    pub t1: Values<B>,
    /// Set where the linear solver reported a valid root and `|a| < EPSILON`.
    pub valid_linear: Tensor<B, 1, Bool>,
    /// Set where `|a| >= EPSILON` and the discriminant is `>= 0`.
    pub valid_quadratic0: Tensor<B, 1, Bool>,
    /// Set where `|a| >= EPSILON` and the discriminant is `> EPSILON`, i.e.
    /// where the second root is treated as distinct from the first.
    pub valid_quadratic1: Tensor<B, 1, Bool>,
}

/// Solves `a·t² + b·t + c = 0` element-wise.
///
/// An element is treated as quadratic when `|a| >= EPSILON`; otherwise only the
/// result of `linear::solve(b, c)` can be valid for it.
///
/// The returned [`Roots`] contain:
/// - `t0`: the minus-branch quadratic root where `valid_quadratic0` is set,
///   and the linear solver's value elsewhere;
/// - `t1`: the plus-branch quadratic root, unmasked;
/// - `valid_linear`: linear solver valid and the element is not quadratic;
/// - `valid_quadratic0`: quadratic and discriminant `>= 0`;
/// - `valid_quadratic1`: quadratic and discriminant `> EPSILON`, so a repeated
///   root is reported once, through `t0` only.
pub fn solve<B: Backend>(a: Values<B>, b: Values<B>, c: Values<B>) -> Roots<B> {
    // Build validity masks directly; earlier zero/full helper tensors were a
    // measurable cost in root solving.
    let linear_roots = linear::solve(b.clone(), c.clone());
    let quadratic_mask = a.clone().abs().greater_equal_elem(EPSILON);
    // Keep the linear fallback separate so winding can specialize quadratic
    // root signs without evaluating the derivative at each root.
    let valid_linear = linear_roots
        .valid
        .bool_and(quadratic_mask.clone().bool_not());
    // Last use of `c`: move it instead of cloning.
    let discriminant = b.clone().powi_scalar(2) - a.clone() * c * 4.0;
    let valid_quadratic0 = quadratic_mask
        .clone()
        .bool_and(discriminant.clone().greater_equal_elem(0.0));
    let valid_quadratic1 = quadratic_mask.bool_and(discriminant.clone().greater_elem(EPSILON));
    // Clamp before the square root so negative discriminants do not produce
    // NaN; those elements are rejected through the masks above.
    let root = discriminant.clamp_min(0.0).sqrt();
    // Last use of `a`: move it instead of cloning.
    let denominator = safe_divisor(a * 2.0);
    // Share the denominator across both quadratic roots to avoid one
    // sample-sized division pass.
    let quadratic_base = -b / denominator.clone();
    let scaled_root = root / denominator;
    let quadratic_t0 = quadratic_base.clone() - scaled_root.clone();
    let quadratic_t1 = quadratic_base + scaled_root;
    let t0 = linear_roots
        .t
        .mask_where(valid_quadratic0.clone(), quadratic_t0);

    Roots {
        t0,
        t1: quadratic_t1,
        valid_linear,
        valid_quadratic0,
        valid_quadratic1,
    }
}

#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    /// With `a == 0` the equation is linear: `-2t + 1 = 0` gives `t = 0.5`.
    #[test]
    fn preserves_linear_fallback() {
        let roots = solve([0.0], [-2.0], [1.0]);

        assert_values(roots.t0, [0.5]);
        assert_bool(roots.valid_linear, [true]);
        assert_bool(roots.valid_quadratic0, [false]);
        assert_bool(roots.valid_quadratic1, [false]);
    }

    /// `t² + t + 1` has a negative discriminant, so every mask must be false.
    #[test]
    fn rejects_without_real_roots() {
        let roots = solve([1.0], [1.0], [1.0]);

        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [false]);
        assert_bool(roots.valid_quadratic1, [false]);
    }

    /// `t² - 1` has the two distinct roots `-1` and `1`.
    #[test]
    fn solves_roots() {
        let roots = solve([1.0], [0.0], [-1.0]);

        assert_values(roots.t0, [-1.0]);
        assert_values(roots.t1, [1.0]);
        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [true]);
        assert_bool(roots.valid_quadratic1, [true]);
    }

    /// `t² - 2t + 1` has a repeated root at `1`, reported through `t0` only.
    #[test]
    fn uses_single_repeated_root() {
        let roots = solve([1.0], [-2.0], [1.0]);

        assert_values(roots.t0, [1.0]);
        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [true]);
        assert_bool(roots.valid_quadratic1, [false]);
    }

    /// Asserts that a boolean tensor holds exactly `expected`.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();

        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor matches `expected` element-wise within an
    /// absolute tolerance of `1e-6`.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();

        // `zip` stops at the shorter side, so a tensor with too few (or no)
        // elements would otherwise pass without comparing anything.
        assert_eq!(
            actual.len(),
            N,
            "expected {} values, got {:?}",
            N,
            actual
        );

        for (actual, expected) in actual.iter().zip(expected) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "expected {}, got {}",
                expected,
                actual
            );
        }
    }

    /// Builds one-dimensional coefficient tensors on the default device and
    /// forwards them to the solver under test.
    fn solve<const N: usize>(a: [f32; N], b: [f32; N], c: [f32; N]) -> Roots<Backend> {
        super::solve(
            Tensor::<Backend, 1>::from_floats(a, &Default::default()),
            Tensor::<Backend, 1>::from_floats(b, &Default::default()),
            Tensor::<Backend, 1>::from_floats(c, &Default::default()),
        )
    }
}