use ::burn::tensor::{Bool, Tensor, backend::Backend};
use super::{EPSILON, safe_divisor};
/// Roots returned by [`solve`], one candidate per tensor element.
///
/// An entry of `t` is only meaningful where the matching entry of `valid`
/// is `true`.
pub struct Roots<B, const D: usize>
where
    B: Backend,
{
    /// Candidate root for each element.
    pub t: Tensor<B, D>,
    /// Mask that is `true` where the corresponding root in `t` is usable.
    pub valid: Tensor<B, D, Bool>,
}
/// Solves `b * t + c = 0` for `t`, element-wise.
///
/// The returned `t` is computed as `-(c * (1 / b))`. The returned `valid`
/// mask is `true` wherever `|b| >= EPSILON`. Where the mask is `false` the
/// equation is (numerically) degenerate; the value in `t` there depends on
/// `safe_divisor` and should be ignored.
/// NOTE(review): presumably `safe_divisor` swaps near-zero divisors for a
/// non-zero stand-in to avoid inf/NaN — confirm against its definition.
pub fn solve<B: Backend, const D: usize>(b: Tensor<B, D>, c: Tensor<B, D>) -> Roots<B, D> {
    // `b` is needed twice: once for the validity test, once as the divisor.
    let magnitude = b.clone().abs();
    // Multiply by the reciprocal rather than dividing, to keep results
    // bit-identical with the established formulation.
    let inverse = safe_divisor(b).recip();

    Roots {
        t: -(c * inverse),
        valid: magnitude.greater_equal_elem(EPSILON),
    }
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    /// A zero leading coefficient has no root, so it must be flagged invalid.
    #[test]
    fn rejects_degenerate() {
        let roots = solve([0.0], [1.0]);
        assert_bool(roots.valid, [false]);
    }

    /// `-2 * t + 1 = 0` has the root `t = 0.5`.
    #[test]
    fn solves_root() {
        let roots = solve([-2.0], [1.0]);
        assert_values(roots.t, [0.5]);
        assert_bool(roots.valid, [true]);
    }

    /// Several equations are solved independently, element by element:
    /// `2t + 4 = 0` gives `t = -2`, and `-4t + 1 = 0` gives `t = 0.25`.
    #[test]
    fn solves_multiple_roots() {
        let roots = solve([2.0, -4.0], [4.0, 1.0]);
        assert_values(roots.t, [-2.0, 0.25]);
        assert_bool(roots.valid, [true, true]);
    }

    /// Asserts that a boolean tensor matches `expected` exactly, including its length.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();
        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor has exactly `N` values, each within `1e-6`
    /// of the corresponding entry of `expected`.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();
        // `zip` stops at the shorter side, so without this check a length
        // mismatch (e.g. an empty result) would make the test pass vacuously.
        assert_eq!(
            actual.len(),
            N,
            "length mismatch: got {actual:?}, expected {expected:?}"
        );
        for (index, (actual, expected)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "value {index}: got {actual}, expected {expected}"
            );
        }
    }

    /// Builds 1-D tensors from the coefficient arrays and forwards to [`super::solve`].
    fn solve<const N: usize>(b: [f32; N], c: [f32; N]) -> Roots<Backend, 1> {
        super::solve(
            Tensor::<Backend, 1>::from_floats(b, &Default::default()),
            Tensor::<Backend, 1>::from_floats(c, &Default::default()),
        )
    }
}