use ::burn::tensor::{Bool, Tensor, backend::Backend};
use crate::base::geometry::Indices;
use super::super::{contour, roots};
use super::{Values, segment_coordinate};
/// Coefficients of one coordinate of a linear segment, stored as
/// `(linear coefficient, constant term)`: `point_coefficients` builds the pair
/// as `(p1 - p0, p0)` and `evaluate` destructures it in that same order.
type Coefficients<B> = (Values<B>, Values<B>);
/// Builds the x and y coefficient pairs for the segments addressed by `indices`.
///
/// For each axis the coordinate at `indices.start()` and at `indices.end()` is
/// fetched through `segment_coordinate` (selector `0` is used for x, `1` for y)
/// and the two values are combined by `point_coefficients`.
///
/// Returns `(x coefficients, y coefficients)`.
pub fn coefficients<B: Backend>(
    arguments: contour::Arguments<B>,
    indices: Indices,
) -> (Coefficients<B>, Coefficients<B>) {
    let first = indices.start();
    let last = indices.end();

    // Horizontal axis: coordinate selector 0.
    let x_first = segment_coordinate(arguments.clone(), first, 0);
    let x_last = segment_coordinate(arguments.clone(), last, 0);
    let x_axis = point_coefficients(x_first, x_last);

    // Vertical axis: coordinate selector 1. The final call consumes `arguments`,
    // so no extra clone is made.
    let y_first = segment_coordinate(arguments.clone(), first, 1);
    let y_last = segment_coordinate(arguments, last, 1);
    let y_axis = point_coefficients(y_first, y_last);

    (x_axis, y_axis)
}
/// Computes the signed contribution of each linear segment for the query
/// coordinates `x`, `y`.
///
/// The y coefficients, with `y` subtracted from the constant term, are handed
/// to `roots::linear::solve` (presumably solving `y1 * t + (y0 - y) = 0` for
/// `t` — confirm against that module). The resulting `t` and `valid` mask are
/// then filtered by `root_contribution`, using the x coefficients to compare
/// the segment's x position with the query `x`.
///
/// The sign is `+1` where the linear y coefficient is strictly positive and
/// `-1` everywhere else.
pub fn evaluate<B: Backend>(
    x_coefficients: Coefficients<B>,
    y_coefficients: Coefficients<B>,
    x: Values<B>,
    y: Values<B>,
) -> Values<B> {
    let (x_slope, x_origin) = x_coefficients;
    let (y_slope, y_origin) = y_coefficients;

    let solution = roots::linear::solve(y_slope.clone(), y_origin - y);

    // Map the boolean "y increases along the segment" onto {-1, +1}.
    let direction = y_slope.greater_elem(0.0).float() * 2.0 - 1.0;

    root_contribution(
        solution.t,
        solution.valid,
        x - x_origin,
        x_slope,
        direction,
    )
}
/// Turns two endpoint values into a coefficient pair `(p1 - p0, p0)`.
fn point_coefficients<B: Backend>(p0: Values<B>, p1: Values<B>) -> Coefficients<B> {
    // `p0` is cloned for the subtraction because it is also returned as the
    // constant term.
    let delta = p1 - p0.clone();
    (delta, p0)
}
/// Masks `sign` down to the roots that count.
///
/// A root is kept when all of the following hold:
/// - the incoming `valid` flag is set,
/// - `0 <= t <= 1` (both ends inclusive),
/// - `x_coefficient * t` is strictly greater than `x_offset`.
///
/// Returns `sign` where the root is kept and zero elsewhere (the boolean mask
/// is converted to float and multiplied by `sign`).
fn root_contribution<B: Backend>(
    t: Values<B>,
    valid: Tensor<B, 1, Bool>,
    x_offset: Values<B>,
    x_coefficient: Values<B>,
    sign: Values<B>,
) -> Values<B> {
    let not_before_start = t.clone().greater_equal_elem(0.0);
    let not_after_end = t.clone().lower_equal_elem(1.0);
    // Segment x position at the root, relative to the same origin as `x_offset`.
    let beyond_offset = (x_coefficient * t).greater(x_offset);

    let counted = valid
        .bool_and(not_before_start)
        .bool_and(not_after_end)
        .bool_and(beyond_offset);

    counted.float() * sign
}