use ::burn::tensor::{Bool, Tensor, backend::Backend};
use super::{EPSILON, Values, safe_divisor};
/// Element-wise roots of a batch of linear equations, paired with a validity mask.
pub struct Roots<B>
where
    B: Backend,
{
    /// Candidate root per element; only meaningful where the matching `valid` entry is `true`.
    pub t: Values<B>,
    /// Boolean mask marking which entries of `t` are usable roots.
    pub valid: Tensor<B, 1, Bool>,
}
/// Solves `b * t + c = 0` for `t`, element-wise.
///
/// An entry is flagged `false` in the returned `valid` mask when its coefficient
/// `b` has magnitude below `EPSILON` (no unique root exists when `b == 0`).
/// The matching `t` values come from dividing by `safe_divisor(b)`. That helper
/// presumably keeps the division finite (TODO confirm), but these values are not
/// meaningful roots.
pub fn solve<B: Backend>(b: Values<B>, c: Values<B>) -> Roots<B> {
    // `safe_divisor` takes ownership of `b`, so the magnitude test runs on a copy.
    let magnitude = b.clone().abs();
    let valid = magnitude.greater_equal_elem(EPSILON);

    let numerator = -c;
    let t = numerator / safe_divisor(b);

    Roots { t, valid }
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Bool, Tensor};

    use super::Roots;
    use crate::burn::tests::Backend;

    /// A zero coefficient has no unique root and must be flagged invalid.
    #[test]
    fn rejects_degenerate() {
        let roots = solve([0.0], [1.0]);
        assert_bool(roots.valid, [false]);
    }

    /// `-2 * t + 1 = 0` has the root `t = 0.5`.
    #[test]
    fn solves_root() {
        let roots = solve([-2.0], [1.0]);
        assert_values(roots.t, [0.5]);
        assert_bool(roots.valid, [true]);
    }

    /// Several equations are solved independently, element by element.
    #[test]
    fn solves_batch() {
        let roots = solve([4.0, -2.0], [2.0, 1.0]);
        assert_values(roots.t, [-0.5, 0.5]);
        assert_bool(roots.valid, [true, true]);
    }

    /// Asserts that a boolean tensor equals `expected` exactly, length included.
    fn assert_bool<const N: usize>(tensor: Tensor<Backend, 1, Bool>, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();
        assert_eq!(actual, expected);
    }

    /// Asserts that a float tensor matches `expected` element-wise within `1e-6`.
    ///
    /// The length is checked first. `zip` alone would stop at the shorter side,
    /// so a result of the wrong length (even an empty one) would pass.
    fn assert_values<const N: usize>(tensor: Tensor<Backend, 1>, expected: [f32; N]) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();
        assert_eq!(
            actual.len(),
            N,
            "length mismatch: got {actual:?}, expected {expected:?}"
        );
        for (index, (actual, expected)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (actual - expected).abs() < 1e-6,
                "element {index}: got {actual}, expected {expected}"
            );
        }
    }

    /// Builds 1-D tensors from the given arrays on the default device and calls `super::solve`.
    fn solve<const N: usize>(b: [f32; N], c: [f32; N]) -> Roots<Backend> {
        super::solve(
            Tensor::<Backend, 1>::from_floats(b, &Default::default()),
            Tensor::<Backend, 1>::from_floats(c, &Default::default()),
        )
    }
}