use ::burn::tensor::{Bool, Tensor, backend::Backend};
use super::{EPSILON, Values, linear, safe_divisor};
/// Element-wise roots produced by [`solve`] for `a * t^2 + b * t + c = 0`.
///
/// Every field holds one entry per lane of the input coefficients. Read a root
/// only where the matching validity mask is `true`; in other lanes the value is
/// unspecified.
pub struct Roots<B: Backend> {
/// First root: the quadratic root where `valid_quadratic0` is set, otherwise the
/// root reported by `linear::solve(b, c)`.
pub t0: Values<B>,
/// Second quadratic root; meaningful only where `valid_quadratic1` is set.
pub t1: Values<B>,
/// `true` where `|a| < EPSILON` and `linear::solve` reported a valid root, which
/// is then stored in `t0`.
pub valid_linear: Tensor<B, 1, Bool>,
/// `true` where `|a| >= EPSILON` and the discriminant `b^2 - 4ac` is `>= 0`,
/// i.e. `t0` holds a real quadratic root.
pub valid_quadratic0: Tensor<B, 1, Bool>,
/// `true` where `|a| >= EPSILON` and the discriminant exceeds `EPSILON`, i.e. a
/// second, distinct root is available in `t1`.
pub valid_quadratic1: Tensor<B, 1, Bool>,
}
/// Solves `a * t^2 + b * t + c = 0` element-wise, falling back to the linear
/// equation `b * t + c = 0` wherever the quadratic coefficient vanishes.
///
/// Lanes with `|a| >= EPSILON` are treated as quadratic:
/// - `valid_quadratic0` is set when the discriminant `b^2 - 4ac` is non-negative
///   (at least one real root, stored in `t0`);
/// - `valid_quadratic1` is additionally set when the discriminant exceeds
///   `EPSILON` (a second, distinct root, stored in `t1`).
///
/// The quadratic roots are returned in ascending order (`t0 <= t1`) for either
/// sign of `a`.
///
/// Lanes with `|a| < EPSILON` defer to [`linear::solve`]: `t0` holds its root and
/// `valid_linear` mirrors its validity flag. Values in lanes whose flags are all
/// `false` are unspecified.
///
/// NOTE(review): the textbook `(-b ± sqrt(disc)) / 2a` form loses precision to
/// cancellation when `b^2 >> |4ac|`; switch to the `q = -(b + sign(b)·sqrt(disc)) / 2`
/// formulation if that regime matters to callers.
pub fn solve<B: Backend>(a: Values<B>, b: Values<B>, c: Values<B>) -> Roots<B> {
    let linear_roots = linear::solve(b.clone(), c.clone());

    let quadratic_mask = a.clone().abs().greater_equal_elem(EPSILON);
    let valid_linear = linear_roots
        .valid
        .bool_and(quadratic_mask.clone().bool_not());

    let discriminant = b.clone().powi_scalar(2) - a.clone() * c * 4.0;
    let valid_quadratic0 = quadratic_mask
        .clone()
        .bool_and(discriminant.clone().greater_equal_elem(0.0));
    let valid_quadratic1 = quadratic_mask.bool_and(discriminant.clone().greater_elem(EPSILON));

    // Clamping keeps `sqrt` from producing NaN in lanes without real roots; those
    // lanes are already excluded by the validity masks above.
    let root = discriminant.clamp_min(0.0).sqrt();
    // NOTE(review): `safe_divisor` presumably keeps the division finite where `a`
    // is ~0 — those lanes use the linear root anyway. Confirm its semantics.
    let denominator = safe_divisor(a * 2.0);
    let quadratic_base = -b / denominator.clone();
    // `root >= 0`, but dividing by a negative `2a` flips its sign and would swap
    // the roots; using the magnitude keeps `t0 <= t1` for either sign of `a`.
    let half_width = (root / denominator).abs();
    let quadratic_t0 = quadratic_base.clone() - half_width.clone();
    let quadratic_t1 = quadratic_base + half_width;

    // Quadratic lanes take the quadratic root; all others keep the linear one.
    let t0 = linear_roots
        .t
        .mask_where(valid_quadratic0.clone(), quadratic_t0);

    Roots {
        t0,
        t1: quadratic_t1,
        valid_linear,
        valid_quadratic0,
        valid_quadratic1,
    }
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    #[test]
    fn preserves_linear_fallback() {
        let roots = solve([0.0], [-2.0], [1.0]);
        assert_values(roots.t0, [0.5]);
        assert_bool(roots.valid_linear, [true]);
        assert_bool(roots.valid_quadratic0, [false]);
        assert_bool(roots.valid_quadratic1, [false]);
    }

    #[test]
    fn rejects_without_real_roots() {
        let roots = solve([1.0], [1.0], [1.0]);
        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [false]);
        assert_bool(roots.valid_quadratic1, [false]);
    }

    #[test]
    fn solves_roots() {
        let roots = solve([1.0], [0.0], [-1.0]);
        assert_values(roots.t0, [-1.0]);
        assert_values(roots.t1, [1.0]);
        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [true]);
        assert_bool(roots.valid_quadratic1, [true]);
    }

    #[test]
    fn solves_roots_with_non_unit_leading_coefficient() {
        // 2t^2 - 6t + 4 = 2(t - 1)(t - 2)
        let roots = solve([2.0], [-6.0], [4.0]);
        assert_values(roots.t0, [1.0]);
        assert_values(roots.t1, [2.0]);
        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [true]);
        assert_bool(roots.valid_quadratic1, [true]);
    }

    #[test]
    fn uses_single_repeated_root() {
        let roots = solve([1.0], [-2.0], [1.0]);
        assert_values(roots.t0, [1.0]);
        assert_bool(roots.valid_linear, [false]);
        assert_bool(roots.valid_quadratic0, [true]);
        assert_bool(roots.valid_quadratic1, [false]);
    }

    /// Asserts that a boolean tensor equals `expected` exactly.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();
        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor has exactly `N` elements, each within `1e-6`
    /// of the corresponding `expected` value.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();
        // `zip` stops at the shorter side, so check the length explicitly to avoid
        // silently passing on a shape mismatch.
        assert_eq!(
            actual.len(),
            N,
            "length mismatch: got {actual:?}, expected {expected:?}"
        );
        for (index, (actual, expected)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "element {index}: got {actual}, expected {expected}"
            );
        }
    }

    /// Builds rank-1 tensors from the coefficient arrays and calls the solver.
    fn solve<const N: usize>(a: [f32; N], b: [f32; N], c: [f32; N]) -> Roots<Backend> {
        super::solve(
            Tensor::<Backend, 1>::from_floats(a, &Default::default()),
            Tensor::<Backend, 1>::from_floats(b, &Default::default()),
            Tensor::<Backend, 1>::from_floats(c, &Default::default()),
        )
    }
}