use ::burn::tensor::{Bool, Int, Tensor, TensorData, backend::Backend};
use crate::{
base::geometry::Command,
burn::{
geometry::{
Basis, Contour, Coordinates, contour, left_normals, quadratic_basis, row_lengths,
},
tensor::{column, expand},
},
};
use super::distribution::{Records, UnitSamples};
/// Small positive floor shared by this module: `evaluate` clamps tangent
/// lengths to at least this value before dividing, and `evaluate_normals`
/// forwards it to `left_normals`.
const EPSILON: f32 = 1e-6;
/// Rank-1 float tensor holding the density values computed by `evaluate`
/// (`records.segment_pmf()` divided by the clamped tangent length).
pub type Pdf<B> = Tensor<B, 1>;
/// `(start, control, end)` index tensors built by `indices`. Each is a rank-2
/// integer tensor made by stacking a segment-index column with a point-index
/// column (`0` or `1`) along dimension 1.
pub type Indices<B> = (Tensor<B, 2, Int>, Tensor<B, 2, Int>, Tensor<B, 2, Int>);
/// `(p0, p1, p2)` control points gathered from the contour arguments by
/// `control_points`, one rank-2 float tensor per control point.
type Points<B> = (Tensor<B, 2>, Tensor<B, 2>, Tensor<B, 2>);
/// Sampled boundary geometry produced by `evaluate`.
///
/// All fields are private; the accessor methods on `Geometry` hand out clones.
pub struct Geometry<B: Backend> {
/// Coordinate columns of the control points blended with `basis`.
points: Coordinates<B>,
/// Normals returned by `left_normals` for the sample tangents.
normals: Coordinates<B>,
/// `records.segment_pmf()` divided by the tangent length (floored at `EPSILON`).
pdf: Pdf<B>,
/// Basis weights `(b0, b1, b2)` applied to the three gathered control points.
basis: Basis<B>,
/// `(start, control, end)` indices used to gather the control points.
indices: Indices<B>,
}
impl<B: Backend> Geometry<B> {
    /// Returns a clone of the evaluated point coordinates.
    pub fn points(&self) -> Coordinates<B> {
        let Self { points, .. } = self;
        points.clone()
    }

    /// Returns a clone of the evaluated normal coordinates.
    pub fn normals(&self) -> Coordinates<B> {
        let Self { normals, .. } = self;
        normals.clone()
    }

    /// Returns a clone of the density tensor.
    pub fn pdf(&self) -> Pdf<B> {
        let Self { pdf, .. } = self;
        pdf.clone()
    }

    /// Returns a clone of the basis weights used to blend the control points.
    pub fn basis(&self) -> Basis<B> {
        let Self { basis, .. } = self;
        basis.clone()
    }

    /// Returns a clone of the `(start, control, end)` gather indices.
    pub fn indices(&self) -> Indices<B> {
        let Self { indices, .. } = self;
        indices.clone()
    }
}
/// Evaluates the boundary geometry of `contour` at the samples held in `records`.
///
/// The control points of each sampled segment are gathered and blended with
/// the basis weights to obtain the points; normals come from the sample
/// tangents. The density is `records.segment_pmf()` divided by the tangent
/// length, with the length floored at `EPSILON` to avoid dividing by zero.
///
/// # Panics
///
/// Panics when the contour is empty, when its dimensions are not
/// `[segment_count, 2, 2]`, or when the number of commands or the segment
/// count reported by `records` differs from `segment_count`.
pub fn evaluate<B: Backend>(contour: Contour<B>, records: &Records<B>) -> Geometry<B> {
    let [segment_count, point_count, coordinate_count] = contour.dims();
    assert!(
        segment_count > 0
            && point_count == 2
            && coordinate_count == 2
            && contour.commands().len() == segment_count
            && segment_count == records.segment_count(),
        "commands and arguments must match sampled records for a non-empty contour"
    );

    let t = records.t();
    let segment_indices = records.segment_indices();
    let linear_mask = linear_mask(contour.commands(), segment_indices.clone(), &t.device());
    let indices = indices(segment_indices, segment_count);
    let control_points = control_points(contour.arguments(), indices.clone());
    let basis = basis(t.clone(), linear_mask.clone());
    let points = evaluate_points(control_points.clone(), basis.clone());
    let (normals, tangent_lengths) = evaluate_normals(control_points, t, linear_mask);

    // `tangent_lengths` has no further use, so it is consumed here rather
    // than cloned first.
    let pdf = records.segment_pmf() / tangent_lengths.clamp_min(EPSILON);

    Geometry {
        points,
        normals,
        pdf,
        basis,
        indices,
    }
}
/// Builds the three basis weights for every sample parameter in `t`.
///
/// The weights start as the output of `quadratic_basis`; wherever
/// `linear_mask` is set they are replaced by the linear weights
/// `(1 - t, 0, t)`, so the middle control point contributes nothing there.
fn basis<B: Backend>(t: UnitSamples<B>, linear_mask: Tensor<B, 1, Bool>) -> Basis<B> {
    let (curved_start, curved_control, curved_end) = quadratic_basis(t.clone());
    let complement = t.ones_like() - t.clone();
    let nothing = t.zeros_like();

    let start_weight = curved_start.mask_where(linear_mask.clone(), complement);
    let control_weight = curved_control.mask_where(linear_mask.clone(), nothing);
    let end_weight = curved_end.mask_where(linear_mask, t);
    (start_weight, control_weight, end_weight)
}
/// Gathers the `(p0, p1, p2)` control points from `arguments`, using the
/// `(start, control, end)` index tensors in `indices`.
fn control_points<B: Backend>(arguments: contour::Arguments<B>, indices: Indices<B>) -> Points<B> {
    let (start_index, control_index, end_index) = indices;
    // Tuple elements are evaluated left to right, so `arguments` is cloned
    // for the first two lookups and moved into the last one.
    (
        arguments.clone().gather_nd::<2, 2>(start_index),
        arguments.clone().gather_nd::<2, 2>(control_index),
        arguments.gather_nd::<2, 2>(end_index),
    )
}
/// Computes the normal and tangent length for every sample.
///
/// The quadratic tangent is `2(1 - t)(p1 - p0) + 2t(p2 - p1)`. Samples
/// flagged by `linear_mask` use the chord `p2 - p0` instead. Lengths come
/// from `row_lengths`, and the normals from `left_normals`, which receives
/// the tangents, their lengths and `EPSILON`.
fn evaluate_normals<B: Backend>(
    control_points: Points<B>,
    t: UnitSamples<B>,
    linear_mask: Tensor<B, 1, Bool>,
) -> (Coordinates<B>, Tensor<B, 1>) {
    let (start, control, end) = control_points;

    // Edge vectors of the control polygon, and the straight chord.
    let leading_edge = control.clone() - start.clone();
    let trailing_edge = end.clone() - control;
    let chord = end - start;

    // Derivative weights of the quadratic curve, passed through `expand` so
    // they can multiply the edge tensors.
    let leading_weight = expand((t.ones_like() - t.clone()) * 2.0);
    let trailing_weight = expand(t * 2.0);
    let curved = leading_edge * leading_weight + trailing_edge * trailing_weight;

    // Repeat the per-sample flag along dimension 1 so both coordinates of a
    // flagged sample are replaced.
    let per_coordinate_mask = linear_mask.unsqueeze_dim::<2>(1).repeat_dim(1, 2);
    let tangents = curved.mask_where(per_coordinate_mask, chord);

    let lengths = row_lengths(tangents.clone());
    (left_normals(tangents, lengths.clone(), EPSILON), lengths)
}
/// Blends the control points with their basis weights and splits the result
/// into its two coordinate columns.
fn evaluate_points<B: Backend>(control_points: Points<B>, basis: Basis<B>) -> Coordinates<B> {
    let (start, control, end) = control_points;
    let (start_weight, control_weight, end_weight) = basis;

    let blended =
        start * expand(start_weight) + control * expand(control_weight) + end * expand(end_weight);

    let first_column = column(blended.clone(), 0);
    let second_column = column(blended, 1);
    (first_column, second_column)
}
/// Builds the `(start, control, end)` index tensors for the sampled segments.
///
/// Each result pairs a segment index with a point index along dimension 1:
/// start is `[segment, 0]`, control is `[segment, 1]`, and end is
/// `[(segment + 1) % segment_count, 0]`, so the last segment ends on the
/// first segment's start.
fn indices<B: Backend>(segment_indices: Tensor<B, 1, Int>, segment_count: usize) -> Indices<B> {
    let first_point = segment_indices.zeros_like();
    let second_point = first_point.clone() + 1;
    let next_segment = (segment_indices.clone() + 1).remainder_scalar(segment_count as i64);

    // Joins a segment column and a point column into one rank-2 index tensor.
    let pair = |segment: Tensor<B, 1, Int>, point: Tensor<B, 1, Int>| {
        Tensor::<B, 1, Int>::stack::<2>(vec![segment, point], 1)
    };

    (
        pair(segment_indices.clone(), first_point.clone()),
        pair(segment_indices, second_point),
        pair(next_segment, first_point),
    )
}
/// Produces a boolean flag per sample telling whether its segment's command
/// is `Command::Linear`.
///
/// A per-segment flag tensor is built on `device` from `commands`, then
/// indexed along dimension 0 with `segment_indices`.
fn linear_mask<B: Backend>(
    commands: &[Command],
    segment_indices: Tensor<B, 1, Int>,
    device: &B::Device,
) -> Tensor<B, 1, Bool> {
    let flags: Vec<bool> = commands
        .iter()
        .map(|command| matches!(command, Command::Linear))
        .collect();
    let shape = [flags.len()];
    let per_segment = Tensor::<B, 1, Bool>::from_data(TensorData::new(flags, shape), device);
    per_segment.select(0, segment_indices)
}
#[cfg(test)]
mod tests {
    use super::super::{
        distribution::distribution,
        tests::{assert_floats, assert_int_matrix, one_segment, samples, triangle},
    };
    use super::evaluate;
    use crate::burn::tests::Backend;

    /// Draws three samples on the triangle fixture (one per segment, as the
    /// expected start indices show) and checks every output of `evaluate`.
    #[test]
    fn evaluates_quadratic_boundary_geometry() {
        let contour = triangle();
        let sampler = distribution::<Backend>(contour.clone());
        let records = sampler.sample(samples([0.125, 5.0 / 12.0, 19.0 / 24.0]));

        let geometry = evaluate(contour, &records);

        // Points and normals, one coordinate column at a time.
        let (xs, ys) = geometry.points();
        let (normal_xs, normal_ys) = geometry.normals();
        assert_floats(xs, [1.5, 3.0, 1.5]);
        assert_floats(ys, [0.0, 2.0, 2.0]);
        assert_floats(normal_xs, [0.0, -1.0, 0.8]);
        assert_floats(normal_ys, [1.0, 0.0, -0.6]);

        // Density per sample.
        assert_floats(geometry.pdf(), [1.0 / 12.0, 1.0 / 12.0, 1.0 / 12.0]);

        // Basis weights applied to the start, control and end points.
        let (start_weight, control_weight, end_weight) = geometry.basis();
        assert_floats(start_weight, [0.25, 0.25, 0.25]);
        assert_floats(control_weight, [0.5, 0.5, 0.5]);
        assert_floats(end_weight, [0.25, 0.25, 0.25]);

        // Gather indices; the last segment's end wraps back to segment 0.
        let (start_index, control_index, end_index) = geometry.indices();
        assert_int_matrix(start_index, [[0, 0], [1, 0], [2, 0]]);
        assert_int_matrix(control_index, [[0, 1], [1, 1], [2, 1]]);
        assert_int_matrix(end_index, [[1, 0], [2, 0], [0, 0]]);
    }

    /// Records sampled from the triangle must not be accepted for a
    /// different contour (`one_segment`).
    #[test]
    #[should_panic]
    fn rejects_mismatched_boundary_records() {
        let sampler = distribution::<Backend>(triangle());
        let records = sampler.sample(samples([0.125]));
        let _ = evaluate(one_segment(), &records);
    }
}