leibniz 0.2.0

The package provides a differentiable vector graphics rasterization loss.
Documentation
//! Quadratic roots.

use ::burn::tensor::{Bool, Tensor, TensorData, backend::Backend};

use super::{EPSILON, linear, safe_divisor};

/// Roots of a batch of quadratics together with their validity masks, as
/// produced by [`solve`].
pub struct Roots<B: Backend> {
    /// Stacked roots with shape `[2, segments, samples]`.
    pub t: Tensor<B, 3>,
    /// Root validity with shape `[2, segments, samples]`, with the linear
    /// fallback folded into the first root.
    pub valid: Tensor<B, 3, Bool>,
    /// Linear fallback rows with shape `[segments, 1]`.
    pub valid_linear: Tensor<B, 2, Bool>,
}

/// Solves `a·t² + b·t + c = 0` for every entry and returns both roots stacked
/// along the leading dimension.
///
/// Slot `0` holds `(-b - √D) / (2a)` and slot `1` holds `(-b + √D) / (2a)`,
/// where `D = b² - 4ac`. Slot `0` is valid when `D >= 0`; slot `1` only when
/// `D > EPSILON`, so a repeated root is reported once. Entries with
/// `|a| < EPSILON` are not treated as quadratic: they take the root from
/// [`linear::solve`] and report it through slot `0` only.
///
/// NOTE(review): the coefficient tensors are assumed to broadcast against each
/// other to `[segments, samples]` — confirm against the callers.
pub fn solve<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>, c: Tensor<B, 2>) -> Roots<B> {
    let device = a.device();

    // Validity masks are derived straight from comparisons; materialising
    // zero/full helper tensors was previously a measurable cost here.
    let fallback = linear::solve(b.clone(), c.clone());
    let is_quadratic = a.clone().abs().greater_equal_elem(EPSILON);

    // The linear fallback stays a separate mask so that winding can specialise
    // on quadratic root signs without evaluating the derivative per root.
    let not_quadratic = is_quadratic.clone().bool_not();
    let valid_linear = fallback.valid.bool_and(not_quadratic);

    // Apply the factor 4 to the `a` column before multiplying by `c`, instead
    // of scaling the larger product. Scaling by a power of two is exact.
    let four_a = a.clone() * 4.0;
    let discriminant = b.clone().powi_scalar(2) - four_a * c;

    // Both roots are handled in a single stacked pass. The first slot accepts
    // a repeated root; the second needs a strictly positive discriminant.
    let first_valid = discriminant.clone().greater_equal_elem(0.0);
    let second_valid = discriminant.clone().greater_elem(EPSILON);
    let quadratic_valid = Tensor::stack::<3>(vec![first_valid, second_valid], 0)
        .bool_and(is_quadratic.unsqueeze::<3>());

    // Fold the linear fallback into the first slot only.
    let linear_in_first_slot = valid_linear
        .clone()
        .unsqueeze::<3>()
        .bool_and(first_root(&device));
    let validity = quadratic_valid.clone().bool_or(linear_in_first_slot);

    // Clamp before the square root so negative discriminants do not yield NaN.
    let sqrt_discriminant = discriminant.clamp_min(0.0).sqrt();

    // One reciprocal of `2a` serves both roots, and the ± sign is baked into
    // the per-segment reciprocal columns, so the stacked roots cost a single
    // multiplication and a single addition. Multiplying by the reciprocal may
    // differ from a true division in the last ulp; that deviation from the
    // reference arithmetic is accepted.
    let reciprocal = safe_divisor(a * 2.0).recip();
    let base = -b * reciprocal.clone();
    let signs: Tensor<B, 3> = Tensor::from_floats([[[-1.0]], [[1.0]]], &device);
    let signed_reciprocals = reciprocal.unsqueeze::<3>() * signs;
    let quadratic_roots =
        base.unsqueeze::<3>() + sqrt_discriminant.unsqueeze::<3>() * signed_reciprocals;

    // Wherever no quadratic root is valid, keep the linear root instead.
    let roots = fallback
        .t
        .unsqueeze::<3>()
        .mask_where(quadratic_valid, quadratic_roots);

    Roots {
        t: roots,
        valid: validity,
        valid_linear,
    }
}

/// Builds a `[2, 1, 1]` boolean selector that is `true` for the first root
/// slot and `false` for the second.
fn first_root<B: Backend>(device: &B::Device) -> Tensor<B, 3, Bool> {
    let selector = TensorData::from([[[true]], [[false]]]);

    Tensor::from_data(selector, device)
}

#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    /// A vanishing quadratic coefficient falls back to the linear root, which
    /// is reported through the first slot only.
    #[test]
    fn preserves_linear_fallback() {
        let roots = solve([0.0], [-2.0], [1.0]);

        assert_values(root_values(&roots, 0), [0.5]);
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [false]);
        assert_bool(linear_validity(&roots), [true]);
    }

    /// A negative discriminant leaves both slots invalid.
    #[test]
    fn rejects_without_real_roots() {
        let roots = solve([1.0], [1.0], [1.0]);

        assert_bool(root_validity(&roots, 0), [false]);
        assert_bool(root_validity(&roots, 1), [false]);
        assert_bool(linear_validity(&roots), [false]);
    }

    /// Two distinct roots are returned in ascending order for positive `a`.
    #[test]
    fn solves_roots() {
        let roots = solve([1.0], [0.0], [-1.0]);

        assert_values(root_values(&roots, 0), [-1.0]);
        assert_values(root_values(&roots, 1), [1.0]);
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [true]);
        assert_bool(linear_validity(&roots), [false]);
    }

    /// A zero discriminant yields one valid root, in the first slot.
    #[test]
    fn uses_single_repeated_root() {
        let roots = solve([1.0], [-2.0], [1.0]);

        assert_values(root_values(&roots, 0), [1.0]);
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [false]);
        assert_bool(linear_validity(&roots), [false]);
    }

    /// Asserts that a boolean tensor holds exactly `expected`.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();

        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor matches `expected` element-wise within
    /// `1e-6`, and that it has exactly `N` elements.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();

        // `zip` stops at the shorter side, so pin the length first; otherwise
        // a tensor with too few values would pass vacuously.
        assert_eq!(actual.len(), N, "unexpected number of values");
        for (actual, expected) in actual.iter().zip(expected) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "expected {expected}, got {actual}",
            );
        }
    }

    /// Builds an `[N, 1]` coefficient column from the given values.
    fn column<const N: usize>(values: [f32; N]) -> Tensor<Backend, 2> {
        Tensor::<Backend, 1>::from_floats(values, &Default::default()).reshape([N, 1])
    }

    /// Flattens the linear-fallback mask to one entry per segment.
    fn linear_validity(roots: &Roots<Backend>) -> Tensor<Backend, 1, Bool> {
        let [segments, _] = roots.valid_linear.dims();

        roots.valid_linear.clone().reshape([segments])
    }

    /// Extracts the validity of root slot `index`, one entry per segment.
    fn root_validity(roots: &Roots<Backend>, index: usize) -> Tensor<Backend, 1, Bool> {
        let [_, segments, _] = roots.valid.dims();

        roots
            .valid
            .clone()
            .slice_dim(0, index..index + 1)
            .reshape([segments])
    }

    /// Extracts the values of root slot `index`, one entry per segment.
    fn root_values(roots: &Roots<Backend>, index: usize) -> Tensor<Backend, 1> {
        let [_, segments, _] = roots.t.dims();

        roots
            .t
            .clone()
            .slice_dim(0, index..index + 1)
            .reshape([segments])
    }

    /// Runs the solver on single-sample coefficient columns.
    fn solve<const N: usize>(a: [f32; N], b: [f32; N], c: [f32; N]) -> Roots<Backend> {
        super::solve(column(a), column(b), column(c))
    }
}