use ::burn::tensor::{Bool, Tensor, TensorData, backend::Backend};
use super::super::roots;
/// Coefficients of a quadratic polynomial in `t`, ordered `(t² term, t term, constant term)`.
pub type Coefficients<B, const D: usize> = (Tensor<B, D>, Tensor<B, D>, Tensor<B, D>);
/// Signed crossing contributions of quadratic curve segments for a set of sample points.
///
/// For every segment, `y2·t² + y1·t + y0 = y` is solved per sample with
/// `roots::quadratic::solve`. A root is kept only when:
/// - the solver marks it valid,
/// - it lies in the closed range `0 ≤ t ≤ 1`, and
/// - the segment's x at that parameter is strictly greater than the sample's `x`.
///
/// Each kept root is weighted by the ±1 value produced by `signs`; discarded roots give 0.
/// The `[root, segment, sample]` result is flattened to
/// `[root_count * segment_count, sample_count]`.
///
/// NOTE(review): both ends of the `t` range are inclusive. If consecutive segments share an
/// endpoint, a crossing exactly at that point may be counted by both — confirm this matches
/// the fill rule used by the caller.
pub fn evaluate<B: Backend>(
    x_coefficients: Coefficients<B, 2>,
    y_coefficients: Coefficients<B, 2>,
    x: Tensor<B, 2>,
    y: Tensor<B, 2>,
) -> Tensor<B, 2> {
    let (x_quadratic, x_linear, x_constant) = x_coefficients;
    let (y_quadratic, y_linear, y_constant) = y_coefficients;

    // Moving `y` into the constant term makes the roots the parameters where the curve hits `y`.
    let solution = roots::quadratic::solve(y_quadratic, y_linear.clone(), y_constant - y);
    let params = solution.t;

    // Both sides of the x comparison drop the curve's constant term: the curve side is
    // evaluated without it, and it is subtracted from the sample side instead.
    let crossing_x = evaluate_polynomial_without_constant(
        (x_quadratic.unsqueeze::<3>(), x_linear.unsqueeze::<3>()),
        params.clone(),
    );
    let sample_x = (x - x_constant).unsqueeze::<3>();

    let not_before_start = params.clone().greater_equal_elem(0.0);
    let not_after_end = params.lower_equal_elem(1.0);
    let right_of_sample = crossing_x.greater(sample_x);
    let counted = solution
        .valid
        .bool_and(not_before_start)
        .bool_and(not_after_end)
        .bool_and(right_of_sample);

    let weighted = counted.float() * signs(y_linear, solution.valid_linear);
    let [root_count, segment_count, sample_count] = weighted.dims();
    weighted.reshape([root_count * segment_count, sample_count])
}
/// Evaluates `a·t² + b·t` element-wise, using the Horner form `(a·t + b)·t`.
///
/// The constant term is deliberately absent; callers account for it separately.
fn evaluate_polynomial_without_constant<B: Backend, const D: usize>(
    coefficients: (Tensor<B, D>, Tensor<B, D>),
    t: Tensor<B, D>,
) -> Tensor<B, D> {
    let (quadratic, linear) = coefficients;
    quadratic.mul(t.clone()).add(linear).mul(t)
}
/// Power-basis coefficients of the quadratic Bézier curve with control points `p0`, `p1`, `p2`.
///
/// Expanding `(1-t)²·p0 + 2(1-t)t·p1 + t²·p2` gives
/// `(p0 - 2·p1 + p2)·t² + 2·(p1 - p0)·t + p0`. The result is returned as
/// `(t² coefficient, t coefficient, constant)`.
pub fn point_coefficients<B: Backend, const D: usize>(
    p0: Tensor<B, D>,
    p1: Tensor<B, D>,
    p2: Tensor<B, D>,
) -> Coefficients<B, D> {
    let quadratic = p0.clone() - p1.clone() * 2.0 + p2;
    let linear = p1 * 2.0 - p0.clone() * 2.0;
    (quadratic, linear, p0)
}
/// Per-root sign weights; the leading (root) axis has length 2.
///
/// By default, root 0 is weighted `-1.0` and root 1 is weighted `+1.0`. Wherever
/// `valid_linear` is set, root 0's weight is replaced by the sign of `y1`: `+1.0` when
/// `y1 > 0`, otherwise `-1.0`. Root 1 always keeps `+1.0`.
fn signs<B: Backend>(y1: Tensor<B, 2>, valid_linear: Tensor<B, 2, Bool>) -> Tensor<B, 3> {
    let device = y1.device();

    // Map `y1 > 0` to +1.0 and every other value (including 0) to -1.0.
    let slope_sign = (y1.greater_elem(0.0).float() * 2.0 - 1.0).unsqueeze::<3>();

    let is_first_root =
        Tensor::<B, 3, Bool>::from_data(TensorData::from([[[true]], [[false]]]), &device);
    let use_slope_sign = valid_linear.unsqueeze::<3>().bool_and(is_first_root);

    let default_signs = Tensor::<B, 3>::from_floats([[[-1.0]], [[1.0]]], &device);
    default_signs.mask_where(use_slope_sign, slope_sign)
}