use ::burn::tensor::{Bool, Int, Tensor, TensorData, backend::Backend};
use crate::{
base::geometry::Command,
burn::geometry::{Basis, Contour, Coordinates, contour, left_normals, quadratic_basis},
};
use super::distribution::{Records, SegmentIndices, UnitSamples};
/// Per-sample probability density, one value per sampled record.
pub type Pdf<B> = Tensor<B, 1>;

/// Flat point indices `(start, control, end)` of the segment each sample lies on.
pub type Indices<B> = (Tensor<B, 1, Int>, Tensor<B, 1, Int>, Tensor<B, 1, Int>);

/// Per-sample `(start, control, end)` control points, each as `(x, y)` tensors.
type Points<B> = (Coordinates<B>, Coordinates<B>, Coordinates<B>);

/// Lower bound applied to tangent lengths before they are divided by; it is
/// also forwarded to `left_normals`.
const EPSILON: f32 = 1e-6;
/// Per-sample geometry on a contour boundary, as produced by [`evaluate`].
///
/// Every accessor hands out a clone of the stored tensor(s), so the struct
/// itself is never consumed.
pub struct Geometry<B: Backend> {
    /// Evaluated positions as `(x, y)` coordinate tensors.
    points: Coordinates<B>,
    /// Left normals at each position, as `(x, y)` coordinate tensors.
    normals: Coordinates<B>,
    /// Sampling density at each position.
    pdf: Pdf<B>,
    /// Basis weights `(start, control, end)` used to evaluate each position.
    basis: Basis<B>,
    /// Flat point indices `(start, control, end)` of each sample's segment.
    indices: Indices<B>,
}

impl<B: Backend> Geometry<B> {
    /// Evaluated boundary positions.
    pub fn points(&self) -> Coordinates<B> {
        self.points.clone()
    }

    /// Left normals at the evaluated positions.
    pub fn normals(&self) -> Coordinates<B> {
        self.normals.clone()
    }

    /// Sampling density of each evaluated position.
    pub fn pdf(&self) -> Pdf<B> {
        self.pdf.clone()
    }

    /// Basis weights `(start, control, end)` applied to the control points.
    pub fn basis(&self) -> Basis<B> {
        self.basis.clone()
    }

    /// Flat point indices `(start, control, end)` for each sample.
    pub fn indices(&self) -> Indices<B> {
        self.indices.clone()
    }
}
/// Evaluates position, normal, sampling density, basis weights and point
/// indices for every sampled record on `contour`.
///
/// Segment `i` runs from its own start point to the start point of segment
/// `(i + 1) % segment_count`, so the contour is treated as closed. Segments
/// whose command is `Command::Linear` ignore their control point. They use
/// linear weights and the constant chord as tangent.
///
/// # Panics
///
/// Panics if any of these hold:
/// - the contour has no segments;
/// - its arguments are not shaped `[segments, 2, 2]`;
/// - the number of commands differs from the number of segments;
/// - `records` report a different segment count than `contour`.
pub fn evaluate<B: Backend>(contour: Contour<B>, records: &Records<B>) -> Geometry<B> {
    let [segment_count, point_count, coordinate_count] = contour.dims();
    assert!(
        segment_count > 0
            && point_count == 2
            && coordinate_count == 2
            && contour.commands().len() == segment_count
            && segment_count == records.segment_count(),
        "commands and arguments must match sampled records for a non-empty contour"
    );
    let t = records.t();
    let segment_indices = records.segment_indices();
    let linear_mask = linear_mask(contour.commands(), segment_indices.clone(), &t.device());
    // A segment ends where the next one starts; the last wraps to the first.
    let end_segment_indices =
        (segment_indices.clone() + 1).remainder_scalar(segment_count as i64);
    let indices = indices(segment_indices.clone(), end_segment_indices.clone());
    let control_points = control_points(contour.arguments(), segment_indices, end_segment_indices);
    let basis = basis(t.clone(), linear_mask.clone());
    let points = evaluate_points(control_points.clone(), basis.clone());
    let (normals, tangent_lengths) = evaluate_normals(control_points, t, linear_mask);
    // Divide the segment's probability mass by the curve speed |dB/dt| to get a
    // density along the curve. This presumably assumes `t` is uniform within the
    // segment; confirm against the distribution module. The clamp keeps
    // degenerate (zero-speed) tangents from dividing by zero.
    let pdf = records.segment_pmf() / tangent_lengths.clamp_min(EPSILON);
    Geometry {
        points,
        normals,
        pdf,
        basis,
        indices,
    }
}
/// Builds the per-sample basis weights `(start, control, end)`.
///
/// Where `linear_mask` is `true`, the weights are the linear interpolation
/// weights `(1 - t, 0, t)`. Everywhere else, the weights returned by
/// `quadratic_basis` are kept; these are presumably the quadratic Bernstein
/// polynomials.
fn basis<B: Backend>(t: UnitSamples<B>, linear_mask: Tensor<B, 1, Bool>) -> Basis<B> {
    let start_weight = t.ones_like() - t.clone();
    let control_weight = t.zeros_like();
    let (quad_start, quad_control, quad_end) = quadratic_basis(t.clone());
    let end_weight = t;

    let blended_start = quad_start.mask_where(linear_mask.clone(), start_weight);
    let blended_control = quad_control.mask_where(linear_mask.clone(), control_weight);
    let blended_end = quad_end.mask_where(linear_mask, end_weight);
    (blended_start, blended_control, blended_end)
}
fn control_points<B: Backend>(
arguments: contour::Arguments<B>,
segment_indices: SegmentIndices<B>,
end_segment_indices: SegmentIndices<B>,
) -> Points<B> {
let segment_count = arguments.dims()[0];
let flat = arguments.reshape([segment_count * 4]);
let start = segment_indices * 4;
let end = end_segment_indices * 4;
let p0 = (
flat.clone().gather(0, start.clone()),
flat.clone().gather(0, start.clone() + 1),
);
let p1 = (
flat.clone().gather(0, start.clone() + 2),
flat.clone().gather(0, start + 3),
);
let p2 = (flat.clone().gather(0, end.clone()), flat.gather(0, end + 1));
(p0, p1, p2)
}
/// Computes the left normal and the tangent magnitude at every sample.
///
/// Quadratic segments use the Bézier derivative
/// `2(1 - t)(p1 - p0) + 2t(p2 - p1)`. Samples flagged in `linear_mask` use the
/// constant chord `p2 - p0` instead. Normals come from `left_normals`, which
/// also receives `EPSILON`.
///
/// Returns `(normals, tangent_lengths)`, where `tangent_lengths` is the
/// Euclidean norm of the unnormalized tangent.
fn evaluate_normals<B: Backend>(
    control_points: Points<B>,
    t: UnitSamples<B>,
    linear_mask: Tensor<B, 1, Bool>,
) -> (Coordinates<B>, Tensor<B, 1>) {
    let ((x0, y0), (x1, y1), (x2, y2)) = control_points;

    // Weights of the quadratic Bézier derivative.
    let start_weight = (t.ones_like() - t.clone()) * 2.0;
    let end_weight = t * 2.0;

    // Chord direction, used as the tangent of linear segments.
    let chord_x = x2.clone() - x0.clone();
    let chord_y = y2.clone() - y0.clone();

    let dx = ((x1.clone() - x0) * start_weight.clone() + (x2 - x1) * end_weight.clone())
        .mask_where(linear_mask.clone(), chord_x);
    let dy = ((y1.clone() - y0) * start_weight + (y2 - y1) * end_weight)
        .mask_where(linear_mask, chord_y);

    let speed = (dx.clone().powi_scalar(2) + dy.clone().powi_scalar(2)).sqrt();
    (left_normals((dx, dy), speed.clone(), EPSILON), speed)
}
/// Weights the three control points by the per-sample basis:
/// `b0 * p0 + b1 * p1 + b2 * p2`, applied to x and y independently.
fn evaluate_points<B: Backend>(control_points: Points<B>, basis: Basis<B>) -> Coordinates<B> {
    let (start, control, end) = control_points;
    let (w_start, w_control, w_end) = basis;
    let x = start.0 * w_start.clone() + control.0 * w_control.clone() + end.0 * w_end.clone();
    let y = start.1 * w_start + control.1 * w_control + end.1 * w_end;
    (x, y)
}
/// Maps segment indices to flat point indices `(start, control, end)`.
///
/// Points are numbered two per segment: `2i` is segment `i`'s start point and
/// `2i + 1` its control point. The end point is the start point of the
/// segment in `end_segment_indices`.
fn indices<B: Backend>(
    segment_indices: SegmentIndices<B>,
    end_segment_indices: SegmentIndices<B>,
) -> Indices<B> {
    let end = end_segment_indices * 2;
    let start = segment_indices * 2;
    (start.clone(), start + 1, end)
}
/// Expands per-segment command kinds into a per-sample boolean mask.
///
/// An entry is `true` when the segment looked up through `segment_indices`
/// is a `Command::Linear` segment.
fn linear_mask<B: Backend>(
    commands: &[Command],
    segment_indices: Tensor<B, 1, Int>,
    device: &B::Device,
) -> Tensor<B, 1, Bool> {
    let shape = [commands.len()];
    let is_linear: Vec<bool> = commands
        .iter()
        .map(|command| matches!(command, Command::Linear))
        .collect();
    let data = TensorData::new(is_linear, shape);
    let per_segment = Tensor::<B, 1, Bool>::from_data(data, device);
    per_segment.gather(0, segment_indices)
}
#[cfg(test)]
mod tests {
    use super::super::{
        distribution::distribution,
        tests::{assert_floats, assert_ints, one_segment, samples, triangle},
    };
    use super::evaluate;
    use crate::burn::tests::Backend;

    /// Evaluates one sample on each segment of the `triangle` fixture and checks
    /// every field of the resulting geometry.
    #[test]
    fn evaluates_quadratic_boundary_geometry() {
        let segments = triangle();
        let distribution = distribution::<Backend>(segments.clone());
        let records = distribution.sample(samples([0.125, 5.0 / 12.0, 19.0 / 24.0]));
        let geometry = evaluate(segments, &records);
        let (point_x, point_y) = geometry.points();
        let (normal_x, normal_y) = geometry.normals();
        assert_floats(point_x, [1.5, 3.0, 1.5]);
        assert_floats(point_y, [0.0, 2.0, 2.0]);
        assert_floats(normal_x, [0.0, -1.0, 0.8]);
        assert_floats(normal_y, [1.0, 0.0, -0.6]);
        assert_floats(geometry.pdf(), [1.0 / 12.0, 1.0 / 12.0, 1.0 / 12.0]);
        let (b0, b1, b2) = geometry.basis();
        assert_floats(b0, [0.25, 0.25, 0.25]);
        assert_floats(b1, [0.5, 0.5, 0.5]);
        assert_floats(b2, [0.25, 0.25, 0.25]);
        let (start, control, end) = geometry.indices();
        assert_ints(start, [0, 2, 4]);
        assert_ints(control, [1, 3, 5]);
        assert_ints(end, [2, 4, 0]);
    }

    /// Records sampled from one contour must be rejected when evaluated on a
    /// contour with a different segment count.
    // `expected` pins the panic to `evaluate`'s own precondition, so an
    // unrelated panic (e.g. during sampling) cannot make this test pass.
    #[test]
    #[should_panic(expected = "must match sampled records")]
    fn rejects_mismatched_boundary_records() {
        let distribution = distribution::<Backend>(triangle());
        let records = distribution.sample(samples([0.125]));
        let _ = evaluate(one_segment(), &records);
    }
}