use ::burn::tensor::{Bool, Tensor, TensorData, backend::Backend};
use super::{EPSILON, linear, safe_divisor};
/// Element-wise roots of `a·t² + b·t + c = 0`, as produced by [`solve`].
///
/// The first axis of the 3-D tensors indexes the root slot (0 or 1). The remaining
/// axes follow the shape of the coefficient tensors passed to [`solve`] (assumes `a`,
/// `b` and `c` share one shape — confirm at call sites).
pub struct Roots<B: Backend> {
/// Root value for each slot. Entries whose `valid` flag is false hold whatever the
/// linear fallback produced; treat them as unspecified.
pub t: Tensor<B, 3>,
/// Whether each slot holds a usable real root.
pub valid: Tensor<B, 3, Bool>,
/// True where the leading coefficient was treated as zero (`|a| < EPSILON`) and the
/// linear solver reported a valid root; in that case the root is stored in slot 0.
pub valid_linear: Tensor<B, 2, Bool>,
}
pub fn solve<B: Backend>(a: Tensor<B, 2>, b: Tensor<B, 2>, c: Tensor<B, 2>) -> Roots<B> {
let device = a.device();
let linear_roots = linear::solve(b.clone(), c.clone());
let quadratic_mask = a.clone().abs().greater_equal_elem(EPSILON);
let valid_linear = linear_roots
.valid
.bool_and(quadratic_mask.clone().bool_not());
let discriminant = b.clone().powi_scalar(2) - a.clone() * 4.0 * c;
let quadratic_valid = Tensor::stack::<3>(
vec![
discriminant.clone().greater_equal_elem(0.0),
discriminant.clone().greater_elem(EPSILON),
],
0,
)
.bool_and(quadratic_mask.unsqueeze::<3>());
let valid = quadratic_valid.clone().bool_or(
valid_linear
.clone()
.unsqueeze::<3>()
.bool_and(first_root(&device)),
);
let root = discriminant.clamp_min(0.0).sqrt();
let reciprocal = safe_divisor(a * 2.0).recip();
let base = -b * reciprocal.clone();
let signed_reciprocals =
reciprocal.unsqueeze::<3>() * Tensor::from_floats([[[-1.0]], [[1.0]]], &device);
let quadratic_t = base.unsqueeze::<3>() + root.unsqueeze::<3>() * signed_reciprocals;
let t = linear_roots
.t
.unsqueeze::<3>()
.mask_where(quadratic_valid, quadratic_t);
Roots {
t,
valid,
valid_linear,
}
}
/// Slot mask selecting only the first root: `[true, false]` along axis 0.
///
/// The mask has shape `[2, 1, 1]` so it broadcasts over the coefficient axes.
fn first_root<B: Backend>(device: &B::Device) -> Tensor<B, 3, Bool> {
    let slots = TensorData::new(vec![true, false], [2, 1, 1]);
    Tensor::from_data(slots, device)
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};
    use super::Roots;
    use crate::burn::tests::Backend;

    #[test]
    fn preserves_linear_fallback() {
        let roots = solve([0.0], [-2.0], [1.0]);
        assert_values(root_values(&roots, 0), [0.5]);
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [false]);
        assert_bool(linear_validity(&roots), [true]);
    }

    #[test]
    fn rejects_without_real_roots() {
        let roots = solve([1.0], [1.0], [1.0]);
        assert_bool(root_validity(&roots, 0), [false]);
        assert_bool(root_validity(&roots, 1), [false]);
        assert_bool(linear_validity(&roots), [false]);
    }

    #[test]
    fn solves_roots() {
        let roots = solve([1.0], [0.0], [-1.0]);
        assert_values(root_values(&roots, 0), [-1.0]);
        assert_values(root_values(&roots, 1), [1.0]);
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [true]);
        assert_bool(linear_validity(&roots), [false]);
    }

    #[test]
    fn solves_roots_with_negative_leading_coefficient() {
        // -t² + 1 = 0 has roots ±1. Slot order depends on the sign of `a`, so compare
        // the two slots as a sorted pair.
        let roots = solve([-1.0], [0.0], [1.0]);
        let mut values = [root_values(&roots, 0), root_values(&roots, 1)]
            .map(|tensor| tensor.into_data().to_vec::<f32>().unwrap()[0]);
        values.sort_by(f32::total_cmp);
        assert_values(
            Tensor::<Backend, 1>::from_floats(values, &Default::default()),
            [-1.0, 1.0],
        );
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [true]);
        assert_bool(linear_validity(&roots), [false]);
    }

    #[test]
    fn solves_each_row_independently() {
        // Row 0 is a proper quadratic (t² - 1); row 1 degenerates to -2t + 1.
        let roots = solve([1.0, 0.0], [0.0, -2.0], [-1.0, 1.0]);
        assert_values(root_values(&roots, 0), [-1.0, 0.5]);
        assert_bool(root_validity(&roots, 0), [true, true]);
        assert_bool(root_validity(&roots, 1), [true, false]);
        assert_bool(linear_validity(&roots), [false, true]);
    }

    #[test]
    fn uses_single_repeated_root() {
        let roots = solve([1.0], [-2.0], [1.0]);
        assert_values(root_values(&roots, 0), [1.0]);
        assert_bool(root_validity(&roots, 0), [true]);
        assert_bool(root_validity(&roots, 1), [false]);
        assert_bool(linear_validity(&roots), [false]);
    }

    /// Asserts a 1-D bool tensor equals `expected` exactly (length and contents).
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();
        assert_eq!(actual, expected);
    }

    /// Asserts a 1-D float tensor matches `expected` element-wise within 1e-6.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();
        // `zip` stops at the shorter side. Without this check, a short or empty result
        // would pass vacuously.
        assert_eq!(actual.len(), N, "expected {N} values, got {actual:?}");
        for (index, (actual_value, expected_value)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (actual_value - expected_value).abs() < 1e-6,
                "value {index}: expected {expected_value}, got {actual_value} (all: {actual:?})"
            );
        }
    }

    /// Builds an `[N, 1]` coefficient column from a flat array.
    fn column<const N: usize>(values: [f32; N]) -> Tensor<Backend, 2> {
        Tensor::<Backend, 1>::from_floats(values, &Default::default()).reshape([N, 1])
    }

    /// Flattens `valid_linear` to one flag per row.
    fn linear_validity(roots: &Roots<Backend>) -> Tensor<Backend, 1, Bool> {
        let [segments, _] = roots.valid_linear.dims();
        roots.valid_linear.clone().reshape([segments])
    }

    /// Extracts the validity flags of root slot `index`, one per row.
    fn root_validity(roots: &Roots<Backend>, index: usize) -> Tensor<Backend, 1, Bool> {
        let [_, segments, _] = roots.valid.dims();
        roots
            .valid
            .clone()
            .slice_dim(0, index..index + 1)
            .reshape([segments])
    }

    /// Extracts the values of root slot `index`, one per row.
    fn root_values(roots: &Roots<Backend>, index: usize) -> Tensor<Backend, 1> {
        let [_, segments, _] = roots.t.dims();
        roots
            .t
            .clone()
            .slice_dim(0, index..index + 1)
            .reshape([segments])
    }

    /// Runs the solver on column vectors built from flat coefficient arrays.
    fn solve<const N: usize>(a: [f32; N], b: [f32; N], c: [f32; N]) -> Roots<Backend> {
        super::solve(column(a), column(b), column(c))
    }
}