use ::burn::tensor::{Tensor, backend::Backend};
use crate::base::geometry::Indices;
use super::{Point, contour, point};
/// The three tensors returned by [`quadratic_basis`], in the order
/// `((1 - t)^2, 2 (1 - t) t, t^2)`.
pub type Basis<B> = (Tensor<B, 1>, Tensor<B, 1>, Tensor<B, 1>);
/// A quadratic Bézier curve described by a start point, a control point, and an
/// end point.
///
/// Fields are private; read them through the accessor methods on the `impl`.
#[derive(Clone, Debug)]
pub struct Quadratic<B: Backend> {
    /// Point the curve passes through at `t = 0`.
    start: Point<B>,
    /// Control point, weighted by `2 (1 - t) t` when the curve is evaluated.
    control: Point<B>,
    /// Point the curve passes through at `t = 1`.
    end: Point<B>,
}
/// Evaluates the three quadratic Bernstein basis polynomials element-wise over `t`.
///
/// Returns `((1 - t)^2, 2 (1 - t) t, t^2)`: the weights applied to the start,
/// control, and end points respectively when a quadratic curve is evaluated.
pub fn quadratic_basis<B: Backend>(t: Tensor<B, 1>) -> Basis<B> {
    let one_minus_t = t.ones_like() - t.clone();
    let start_weight = one_minus_t.clone().powi_scalar(2);
    let control_weight = one_minus_t * t.clone() * 2.0;
    let end_weight = t.powi_scalar(2);
    (start_weight, control_weight, end_weight)
}
impl<B: Backend> Quadratic<B> {
    /// Creates a curve from its start point, control point, and end point.
    pub const fn new(start: Point<B>, control: Point<B>, end: Point<B>) -> Self {
        Self {
            start,
            control,
            end,
        }
    }
    /// Approximates the arc length with two chords: start → `evaluate(0.5)` → end.
    ///
    /// The result never exceeds the true arc length. It equals it when the curve
    /// is a straight segment traversed without doubling back. Use
    /// [`Self::approximate_length_with_segments`] for a finer approximation.
    pub fn approximate_length(&self) -> Tensor<B, 1> {
        self.approximate_length_with_segments(2)
    }
    /// Approximates the arc length by summing the chords between samples taken
    /// at the evenly spaced parameters `t = k / segments`, for `k = 0..=segments`.
    ///
    /// `segments` is clamped to at least 1, which gives a single chord from
    /// start to end. A chord is never longer than the arc it spans, so the
    /// result never exceeds the true arc length.
    pub fn approximate_length_with_segments(&self, segments: usize) -> Tensor<B, 1> {
        let segments = segments.max(1);
        let samples: Vec<Point<B>> = (0..=segments)
            .map(|step| match step {
                0 => self.start(),
                // Use the stored end point rather than re-evaluating at t = 1.
                last if last == segments => self.end(),
                step => self.evaluate(step as f32 / segments as f32),
            })
            .collect();
        samples
            .windows(2)
            .map(|pair| point::distance(pair[0].clone(), pair[1].clone()))
            .reduce(|sum, chord| sum + chord)
            .expect("segments >= 1 always yields at least one chord")
    }
    /// Returns a clone of the control point.
    pub fn control(&self) -> Point<B> {
        self.control.clone()
    }
    /// Returns a clone of the end point.
    pub fn end(&self) -> Point<B> {
        self.end.clone()
    }
    /// Evaluates the curve at parameter `t`, in Bernstein form:
    /// `(1 - t)^2 * start + 2 (1 - t) t * control + t^2 * end`.
    ///
    /// `t` is not clamped. Values outside `[0, 1]` extrapolate along the same
    /// parabola.
    pub fn evaluate(&self, t: f32) -> Point<B> {
        let u = 1.0 - t;
        self.start() * (u * u) + self.control() * (2.0 * u * t) + self.end() * (t * t)
    }
    /// Builds a curve by picking its start, control, and end points out of
    /// `arguments` at the positions given by `indices`.
    ///
    /// Each `[usize; 2]` index addresses the first two axes of `arguments`. The
    /// first two components along the last axis become the point.
    /// NOTE(review): presumably `arguments` is laid out as `[contour, point, xy]`;
    /// confirm against `contour::Arguments`.
    pub fn from_arguments(arguments: contour::Arguments<B>, indices: Indices) -> Self {
        Self::new(
            select_point(arguments.clone(), indices.start()),
            select_point(arguments.clone(), indices.control()),
            select_point(arguments, indices.end()),
        )
    }
    /// Returns a clone of the start point.
    pub fn start(&self) -> Point<B> {
        self.start.clone()
    }
}
/// Pulls a single 2-component point out of `arguments`.
///
/// `index` addresses a position along the first two axes. The first two entries
/// along the last axis are returned as a 1-D tensor.
fn select_point<B: Backend>(arguments: contour::Arguments<B>, index: [usize; 2]) -> Point<B> {
    let [outer, inner] = index;
    let cell = arguments.slice([outer..outer + 1, inner..inner + 1, 0..2]);
    // Drop the two unit-length axes left by the slice, outermost first.
    let row = cell.squeeze_dim::<2>(0);
    row.squeeze_dim::<1>(0)
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Tensor, TensorData};
    use super::{Quadratic, quadratic_basis};
    use crate::{base::geometry::Indices, burn::tests::Backend};
    type TestPoint = super::Point<Backend>;
    /// Absolute tolerance shared by every floating-point comparison below.
    const TOLERANCE: f32 = 1e-6;
    #[test]
    fn approximate_length_matches_straight_curve() {
        let curve = line();
        assert_scalar(curve.approximate_length(), 3.0);
    }
    #[test]
    fn bounds_include_controls() {
        let curve = Quadratic::new(point([1.0, 2.0]), point([-1.0, 4.0]), point([2.0, 1.0]));
        let bounds = bounds(&curve);
        assert_point(bounds.min, [-1.0, 1.0]);
        assert_point(bounds.max, [2.0, 4.0]);
    }
    #[test]
    fn evaluate_matches_bernstein_form_on_curved_curve() {
        let curve = Quadratic::new(point([1.0, 2.0]), point([-1.0, 4.0]), point([2.0, 1.0]));
        // 0.25 * (1, 2) + 0.5 * (-1, 4) + 0.25 * (2, 1) = (0.25, 2.75)
        assert_point(curve.evaluate(0.5), [0.25, 2.75]);
    }
    #[test]
    fn evaluate_returns_endpoints() {
        let curve = line();
        assert_point(curve.evaluate(0.0), [0.0, 0.0]);
        assert_point(curve.evaluate(1.0), [3.0, 0.0]);
    }
    #[test]
    fn evaluate_returns_midpoint_on_straight_curve() {
        let curve = line();
        assert_point(curve.evaluate(0.5), [1.5, 0.0]);
    }
    #[test]
    fn from_arguments_uses_indices() {
        let arguments = Tensor::<Backend, 3>::from_floats(
            [[[0.0, 0.0], [1.0, 0.0]], [[3.0, 0.0], [4.0, 0.0]]],
            &Default::default(),
        );
        let indices = Indices::new(1, 2);
        let curve = Quadratic::from_arguments(arguments, indices);
        assert_point(curve.start(), [3.0, 0.0]);
        assert_point(curve.control(), [4.0, 0.0]);
        assert_point(curve.end(), [0.0, 0.0]);
    }
    #[test]
    fn normal_is_left_normal() {
        let curve = line();
        assert_point(normal(&curve, 0.5), [0.0, 1.0]);
    }
    #[test]
    fn quadratic_basis_matches_bernstein_weights() {
        let t = Tensor::<Backend, 1>::from_floats([0.0, 0.5, 1.0], &Default::default());
        let (start, control, end) = quadratic_basis(t);
        assert_values(start, &[1.0, 0.25, 0.0]);
        assert_values(control, &[0.0, 0.5, 0.0]);
        assert_values(end, &[0.0, 0.25, 1.0]);
    }
    #[test]
    fn tangent_matches_end_derivatives() {
        let curve = line();
        assert_point(tangent(&curve, 0.0), [3.0, 0.0]);
        assert_point(tangent(&curve, 1.0), [3.0, 0.0]);
    }
    // NOTE: `bounds`, `normal`, and `tangent` below are helpers local to this
    // module; the tests that use them exercise these helpers, not `Quadratic`.
    struct Bounds {
        min: TestPoint,
        max: TestPoint,
    }
    /// Asserts `actual` is within `TOLERANCE` of `expected`, reporting both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {expected}, got {actual} (tolerance {TOLERANCE})"
        );
    }
    /// Asserts `point` has exactly two components, each close to `expected`.
    fn assert_point(point: TestPoint, expected: [f32; 2]) {
        assert_values(point, &expected);
    }
    /// Asserts a one-element tensor holds a value close to `expected`.
    fn assert_scalar(tensor: Tensor<Backend, 1>, expected: f32) {
        let actual = tensor.into_scalar();
        assert_close(actual, expected);
    }
    /// Asserts `tensor` has the same length as `expected` and matches it element-wise.
    fn assert_values(tensor: Tensor<Backend, 1>, expected: &[f32]) {
        let actual = tensor
            .into_data()
            .to_vec::<f32>()
            .expect("test tensors hold f32 data");
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: got {actual:?}, expected {expected:?}"
        );
        for (&got, &want) in actual.iter().zip(expected) {
            assert_close(got, want);
        }
    }
    /// Axis-aligned bounds of the control polygon (start, control, end).
    fn bounds(curve: &Quadratic<Backend>) -> Bounds {
        let points = Tensor::stack::<2>(vec![curve.start(), curve.control(), curve.end()], 0);
        Bounds {
            min: points.clone().min_dim(0).squeeze_dim::<1>(0),
            max: points.max_dim(0).squeeze_dim::<1>(0),
        }
    }
    fn coordinate(point: TestPoint, index: usize) -> Tensor<Backend, 1> {
        point.slice_dim(0, index..index + 1)
    }
    /// Rotates `point` by +90 degrees, giving `(-y, x)`, then normalizes it.
    fn left_normal(point: TestPoint) -> TestPoint {
        let x = coordinate(point.clone(), 0);
        let y = coordinate(point, 1);
        normalize(Tensor::cat(vec![-y, x], 0))
    }
    /// A straight curve from (0, 0) to (3, 0) with its control point at the midpoint.
    fn line() -> Quadratic<Backend> {
        Quadratic::new(point([0.0, 0.0]), point([1.5, 0.0]), point([3.0, 0.0]))
    }
    fn normal(curve: &Quadratic<Backend>, t: f32) -> TestPoint {
        left_normal(tangent(curve, t))
    }
    fn point(value: [f32; 2]) -> TestPoint {
        Tensor::<Backend, 1>::from_data(TensorData::from(value), &Default::default())
    }
    /// Divides `point` by its Euclidean norm. A zero vector yields non-finite values.
    fn normalize(point: TestPoint) -> TestPoint {
        point.clone() / point.powi_scalar(2).sum().sqrt()
    }
    /// Derivative of the quadratic Bézier: `2 (1 - t)(control - start) + 2 t (end - control)`.
    fn tangent(curve: &Quadratic<Backend>, t: f32) -> TestPoint {
        let u = 1.0 - t;
        (curve.control() - curve.start()) * (2.0 * u) + (curve.end() - curve.control()) * (2.0 * t)
    }
}