mod linear;
mod quadratic;
use ::burn::tensor::{Bool, Int, Tensor, TensorData, backend::Backend};
use crate::base::geometry::Command;
use super::{Contour, contour};
/// One boolean per sample, as returned by [`contains`]: `true` where the sample's
/// accumulated winding value is nonzero.
pub type Mask<B> = Tensor<B, 1, Bool>;
/// One scalar per sample (e.g. the winding values summed over contours).
type Values<B> = Tensor<B, 1>;
/// One scalar per contour segment, e.g. a single coordinate of a single point.
type Column<B> = Tensor<B, 1>;
/// Sample coordinates laid out as a `[1, samples]` row (see `evaluate`); presumably
/// combined by broadcasting with `[segments, 1]` segment columns — TODO confirm in
/// the `linear`/`quadratic` evaluators.
type Samples<B> = Tensor<B, 2>;
/// Tests which sample points lie inside `contours` under the nonzero winding rule.
///
/// `x[i]` and `y[i]` are the coordinates of sample `i`. The winding values of all
/// contours are summed per sample, and a sample counts as inside when that sum is
/// not zero. Opposite-direction contours can therefore cancel out and form holes.
///
/// NOTE(review): the comparison against `0.0` is exact. This assumes the segment
/// evaluators produce exactly-zero windings for outside points — confirm.
///
/// # Panics
///
/// Panics if `x` and `y` have different shapes, or if any contour is malformed.
pub fn contains<B, I>(contours: I, x: Tensor<B, 1>, y: Tensor<B, 1>) -> Mask<B>
where
    B: Backend,
    I: IntoIterator<Item = Contour<B>>,
{
    let winding = evaluate(contours, x, y);
    winding.not_equal_elem(0.0)
}
/// Sums the winding value of every contour at each sample point.
///
/// `x` and `y` hold one coordinate per sample, and the result has the same length.
/// An empty `contours` iterator yields all zeros.
///
/// # Panics
///
/// Panics if `x` and `y` differ in shape, or if any contour is malformed
/// (see `evaluate_contour`).
fn evaluate<B, I>(contours: I, x: Values<B>, y: Values<B>) -> Values<B>
where
    B: Backend,
    I: IntoIterator<Item = Contour<B>>,
{
    assert!(
        x.dims() == y.dims(),
        "sample coordinate columns must have matching shapes"
    );
    let [sample_count] = x.dims();
    let initial = x.zeros_like();
    // Lay the samples out as a single row so every contour sees them as `[1, samples]`.
    let row_x = x.reshape([1, sample_count]);
    let row_y = y.reshape([1, sample_count]);
    contours.into_iter().fold(initial, |total, contour| {
        total + evaluate_contour(contour, row_x.clone(), row_y.clone())
    })
}
fn evaluate_contour<B: Backend>(contour: Contour<B>, x: Samples<B>, y: Samples<B>) -> Values<B> {
let [segment_count, point_count, coordinate_count] = contour.dims();
assert!(
segment_count > 0
&& point_count == 2
&& coordinate_count == 2
&& contour.commands().len() == segment_count,
"contours must have shape [segments, 2, 2]"
);
let sample_count = x.dims()[1];
let arguments = contour.arguments();
let device = arguments.device();
let start_x = coordinate_column(arguments.clone(), 0, 0);
let start_y = coordinate_column(arguments.clone(), 0, 1);
let control_x = coordinate_column(arguments.clone(), 1, 0);
let control_y = coordinate_column(arguments, 1, 1);
let end_x = roll(start_x.clone());
let end_y = roll(start_y.clone());
let mut linear = Vec::new();
let mut quadratic = Vec::new();
for (segment, command) in contour.commands().iter().enumerate() {
match command {
Command::Linear => linear.push(segment as i64),
Command::Quadratic => quadratic.push(segment as i64),
}
}
let mut values: Option<Values<B>> = None;
if !linear.is_empty() {
let index = indices::<B>(&linear, &device);
let contribution = linear_winding(
start_x.clone().select(0, index.clone()),
start_y.clone().select(0, index.clone()),
end_x.clone().select(0, index.clone()),
end_y.clone().select(0, index),
x.clone(),
y.clone(),
sample_count,
);
values = Some(accumulate(values, contribution));
}
if !quadratic.is_empty() {
let index = indices::<B>(&quadratic, &device);
let contribution = quadratic_winding(
start_x.select(0, index.clone()),
start_y.select(0, index.clone()),
control_x.select(0, index.clone()),
control_y.select(0, index.clone()),
end_x.select(0, index.clone()),
end_y.select(0, index),
x,
y,
sample_count,
);
values = Some(accumulate(values, contribution));
}
values.unwrap_or_else(|| Tensor::zeros([sample_count], &device))
}
/// Extracts one coordinate of one point from every segment as a 1-D column.
///
/// At the call sites, `point` 0 selects the segment start and 1 its control point.
/// `coordinate` 0 selects x and 1 selects y.
fn coordinate_column<B: Backend>(
    arguments: contour::Arguments<B>,
    point: usize,
    coordinate: usize,
) -> Column<B> {
    let segment_count = arguments.dims()[0];
    let point_range = point..point + 1;
    let coordinate_range = coordinate..coordinate + 1;
    let picked = arguments.slice([0..segment_count, point_range, coordinate_range]);
    picked.reshape([segment_count])
}
/// Rotates `column` left by one element. Entry `i` of the result is entry `i + 1` of
/// the input, and the last entry wraps around to the first.
///
/// Applied to segment start points, this yields each segment's end point for a
/// closed contour.
///
/// Columns of length 0 or 1 are returned unchanged. Rotating them is the identity,
/// and returning early avoids slicing the empty range `1..segments`, which burn's
/// slice checks may reject. Without the guard, a single-segment contour would panic.
fn roll<B: Backend>(column: Column<B>) -> Column<B> {
    let segments = column.dims()[0];
    if segments <= 1 {
        return column;
    }
    Tensor::cat(
        vec![
            column.clone().slice_dim(0, 1..segments),
            column.slice_dim(0, 0..1),
        ],
        0,
    )
}
/// Builds a 1-D integer index tensor from segment positions, for use with `select`.
fn indices<B: Backend>(positions: &[i64], device: &B::Device) -> Tensor<B, 1, Int> {
    let shape = [positions.len()];
    let data = TensorData::new(positions.to_vec(), shape);
    Tensor::from_data(data, device)
}
/// Reshapes a per-segment column into a `[segments, 1]` matrix, one row per segment.
fn segment_column<B: Backend>(column: Column<B>) -> Tensor<B, 2> {
    let [segments] = column.dims();
    column.reshape([segments, 1])
}
/// Adds `contribution` to the running total, or uses it as the total if none exists yet.
fn accumulate<B: Backend>(values: Option<Values<B>>, contribution: Values<B>) -> Values<B> {
    if let Some(total) = values {
        total + contribution
    } else {
        contribution
    }
}
fn linear_winding<B: Backend>(
start_x: Column<B>,
start_y: Column<B>,
end_x: Column<B>,
end_y: Column<B>,
x: Samples<B>,
y: Samples<B>,
sample_count: usize,
) -> Values<B> {
let x_coefficients = linear::point_coefficients(segment_column(start_x), segment_column(end_x));
let y_coefficients = linear::point_coefficients(segment_column(start_y), segment_column(end_y));
let contribution = linear::evaluate(x_coefficients, y_coefficients, x, y);
contribution.sum_dim(0).reshape([sample_count])
}
#[allow(clippy::too_many_arguments)]
fn quadratic_winding<B: Backend>(
start_x: Column<B>,
start_y: Column<B>,
control_x: Column<B>,
control_y: Column<B>,
end_x: Column<B>,
end_y: Column<B>,
x: Samples<B>,
y: Samples<B>,
sample_count: usize,
) -> Values<B> {
let x_coefficients = quadratic::point_coefficients(
segment_column(start_x),
segment_column(control_x),
segment_column(end_x),
);
let y_coefficients = quadratic::point_coefficients(
segment_column(start_y),
segment_column(control_y),
segment_column(end_y),
);
let contribution = quadratic::evaluate(x_coefficients, y_coefficients, x, y);
contribution.sum_dim(0).reshape([sample_count])
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Tensor, TensorData};
    use crate::{base::geometry::Command, burn::tests::Backend};

    /// Boolean per-sample result, as returned by `super::contains`.
    type MaskTensor = Tensor<Backend, 1, ::burn::tensor::Bool>;
    /// Separate x and y coordinate columns for a batch of samples.
    type Coordinates = (Tensor<Backend, 1>, Tensor<Backend, 1>);

    #[test]
    fn classifies_samples_inside_curved_quadratic_contour() {
        let samples = sample_coordinates([[1.25, 0.5], [1.8, 0.5]]);
        assert_bool(contour_contains([curved_square()], samples), [true, false]);
    }

    #[test]
    fn classifies_samples_inside_multiple_contours() {
        let outer = square([0.0, 0.0], [4.0, 4.0], Direction::CounterClockwise);
        // Opposite orientation, so it cancels the outer winding and forms a hole.
        let hole = square([1.0, 1.0], [3.0, 3.0], Direction::Clockwise);
        let samples = sample_coordinates([[0.5, 0.5], [2.0, 2.0], [5.0, 2.0]]);
        assert_bool(contour_contains([outer, hole], samples), [true, false, false]);
    }

    #[test]
    fn classifies_samples_inside_one_contour() {
        let unit = square([0.0, 0.0], [1.0, 1.0], Direction::CounterClockwise);
        let samples = sample_coordinates([[0.5, 0.5], [1.5, 0.5], [-0.5, 0.5]]);
        assert_bool(contour_contains([unit], samples), [true, false, false]);
    }

    #[test]
    #[should_panic]
    fn rejects_invalid_contours() {
        // Three points per segment instead of the required two.
        let malformed = Tensor::<Backend, 3>::from_data(
            TensorData::new(vec![0.0; 6], [1, 3, 2]),
            &Default::default(),
        );
        let contour = super::Contour::new(vec![Command::Quadratic], malformed);
        let _ = evaluate_contour_values(contour, sample_coordinates([[0.5, 0.5]]));
    }

    #[test]
    fn winding_sign_depends_on_direction() {
        let winding_at_center = |direction: Direction| {
            let unit = square([0.0, 0.0], [1.0, 1.0], direction);
            scalar(evaluate_contour_values(unit, sample_coordinates([[0.5, 0.5]])))
        };
        assert_close(winding_at_center(Direction::CounterClockwise), 1.0);
        assert_close(winding_at_center(Direction::Clockwise), -1.0);
    }

    /// Traversal orientation of a generated square.
    enum Direction {
        Clockwise,
        CounterClockwise,
    }

    fn assert_bool<const N: usize>(tensor: MaskTensor, expected: [bool; N]) {
        let actual = tensor.into_data().to_vec::<bool>().unwrap();
        assert_eq!(actual, expected);
    }

    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    fn contour_contains(
        contours: impl IntoIterator<Item = super::Contour<Backend>>,
        coordinates: Coordinates,
    ) -> MaskTensor {
        let (x, y) = coordinates;
        super::contains(contours, x, y)
    }

    /// Unit square whose edges are quadratic segments.
    ///
    /// Rows are `[start, control]`. The control point `(2.0, 0.5)` bows the right edge
    /// outward past `x = 1`.
    fn curved_square() -> super::Contour<Backend> {
        let segments = TensorData::from([
            [[0.0, 0.0], [0.5, 0.0]],
            [[1.0, 0.0], [2.0, 0.5]],
            [[1.0, 1.0], [0.5, 1.0]],
            [[0.0, 1.0], [0.0, 0.5]],
        ]);
        let arguments = Tensor::<Backend, 3>::from_data(segments, &Default::default());
        super::Contour::new(vec![Command::Quadratic; 4], arguments)
    }

    fn evaluate_contour_values(
        contour: super::Contour<Backend>,
        coordinates: Coordinates,
    ) -> Tensor<Backend, 1> {
        let (x, y) = coordinates;
        super::evaluate([contour], x, y)
    }

    /// Linear interpolation between two points, applied per axis.
    fn interpolate(start: [f32; 2], end: [f32; 2], t: f32) -> [f32; 2] {
        std::array::from_fn(|axis| start[axis] + (end[axis] - start[axis]) * t)
    }

    /// Splits `[x, y]` pairs into separate x and y tensors.
    fn sample_coordinates<const N: usize>(values: [[f32; 2]; N]) -> Coordinates {
        let (x, y): (Vec<f32>, Vec<f32>) = values.iter().map(|&[x, y]| (x, y)).unzip();
        (
            Tensor::<Backend, 1>::from_data(TensorData::new(x, [N]), &Default::default()),
            Tensor::<Backend, 1>::from_data(TensorData::new(y, [N]), &Default::default()),
        )
    }

    fn scalar(tensor: Tensor<Backend, 1>) -> f32 {
        tensor.into_scalar()
    }

    /// Axis-aligned square from `min` to `max` built from four linear segments.
    fn square(min: [f32; 2], max: [f32; 2], direction: Direction) -> super::Contour<Backend> {
        // The off-diagonal corners. Visiting them in a different order flips the
        // traversal direction.
        let lower_right = [max[0], min[1]];
        let upper_left = [min[0], max[1]];
        let segments = match direction {
            Direction::Clockwise => square_segments(min, upper_left, max, lower_right),
            Direction::CounterClockwise => square_segments(min, lower_right, max, upper_left),
        };
        let arguments =
            Tensor::<Backend, 3>::from_data(TensorData::from(segments), &Default::default());
        super::Contour::new(vec![Command::Linear; 4], arguments)
    }

    /// One `[start, midpoint]` row per edge of the closed polygon `a -> b -> c -> d`.
    ///
    /// Linear segments never read their second point, so the midpoint is a placeholder.
    fn square_segments(a: [f32; 2], b: [f32; 2], c: [f32; 2], d: [f32; 2]) -> [[[f32; 2]; 2]; 4] {
        let corners = [a, b, c, d];
        std::array::from_fn(|index| {
            let start = corners[index];
            let next = corners[(index + 1) % corners.len()];
            [start, interpolate(start, next, 0.5)]
        })
    }
}