use ::burn::tensor::{Int, Tensor, backend::Backend};
use crate::{
base::geometry::{Command, Indices},
burn::geometry::{Contour, Quadratic, contour},
};
/// Lower bound applied to divisors (total length, segment probability) so that
/// degenerate, zero-length inputs never divide by zero.
const EPSILON: f32 = 1.0e-6;

/// Per-segment lengths of a contour.
pub type Lengths<B> = Tensor<B, 1>;
/// Per-segment probabilities: each length divided by the total length.
pub type Pmf<B> = Tensor<B, 1>;
/// Running sum of a [`Pmf`].
pub type Cdf<B> = Tensor<B, 1>;

/// Samples expected to lie in `[0, 1]`.
pub type UnitSamples<B> = Tensor<B, 1>;
/// Integer segment index chosen for each sample.
pub type SegmentIndices<B> = Tensor<B, 1, Int>;
/// Length-weighted distribution over the segments of a single contour.
///
/// Built by [`distribution`]; `pmf` is `lengths / total` and `cdf` is the running
/// sum of `pmf`, so samples can be mapped to segments with [`Distribution::sample`].
#[derive(Debug, Clone)]
pub struct Distribution<B: Backend> {
    /// Length of every segment.
    lengths: Lengths<B>,
    /// Each segment's share of the total length.
    pmf: Pmf<B>,
    /// Cumulative sum of `pmf`.
    cdf: Cdf<B>,
}
/// Result of [`Distribution::sample`]: which segment each sample landed in and
/// where within that segment.
#[derive(Debug, Clone)]
pub struct Records<B: Backend> {
    /// Number of segments in the distribution that produced these records.
    segment_count: usize,
    /// Selected segment for every sample.
    segment_indices: SegmentIndices<B>,
    /// Sample position within its segment, normalised by the segment's probability.
    t: UnitSamples<B>,
    /// Probability of the selected segment for every sample.
    segment_pmf: Pmf<B>,
}
impl<B: Backend> Distribution<B> {
    /// Maps unit samples onto segments in proportion to segment length.
    ///
    /// Every sample is clamped to `[0, 1]` and assigned to the first segment whose
    /// cumulative probability is at least the sample. The returned `t` is the
    /// sample's offset past the preceding segments' cumulative probability, divided
    /// by the selected segment's probability (floored at [`EPSILON`]) and clamped to
    /// `[0, 1]`.
    pub fn sample(&self, samples: UnitSamples<B>) -> Records<B> {
        let segment_count = self.segment_count();
        let samples = samples.clamp(0.0, 1.0);
        // Compare one CDF row against one sample column: [samples, segments].
        let cdf = self.cdf.clone().unsqueeze_dim::<2>(0);
        let sample_grid = samples.clone().unsqueeze_dim::<2>(1);
        // The CDF is non-decreasing, so the number of entries strictly below a sample
        // is the index of the first entry >= that sample. Unlike `argmax` over a
        // boolean mask, this does not depend on tie-breaking. When no entry reaches
        // the sample (float rounding can leave the final CDF value just under 1.0,
        // and an all-zero-length contour has an all-zero CDF), the count saturates
        // to the last segment instead of silently falling back to segment 0.
        let last_segment = i64::try_from(segment_count.saturating_sub(1)).unwrap_or(i64::MAX);
        let segment_indices = cdf
            .lower(sample_grid)
            .int()
            .sum_dim(1)
            .squeeze_dim::<1>(1)
            .clamp_max(last_segment);
        let segment_start = previous_cdf(self.cdf.clone()).gather(0, segment_indices.clone());
        let pmf = self.pmf.clone().gather(0, segment_indices.clone());
        // Clamp so `t` stays inside the segment even if the sample overshoots the
        // segment's CDF interval (rounding) or its probability is below EPSILON.
        let t = ((samples - segment_start) / pmf.clone().clamp_min(EPSILON)).clamp(0.0, 1.0);
        Records {
            segment_count,
            segment_indices,
            t,
            segment_pmf: pmf,
        }
    }

    /// Number of segments covered by this distribution.
    pub fn segment_count(&self) -> usize {
        self.lengths.dims()[0]
    }
}
impl<B: Backend> Records<B> {
    /// Number of segments in the distribution these records were drawn from.
    pub const fn segment_count(&self) -> usize {
        self.segment_count
    }

    /// Returns a handle to the chosen segment index of every sample.
    pub fn segment_indices(&self) -> SegmentIndices<B> {
        Clone::clone(&self.segment_indices)
    }

    /// Returns a handle to the probability of each sample's chosen segment.
    pub fn segment_pmf(&self) -> Pmf<B> {
        Clone::clone(&self.segment_pmf)
    }

    /// Returns a handle to each sample's position within its chosen segment.
    pub fn t(&self) -> UnitSamples<B> {
        Clone::clone(&self.t)
    }
}
/// Builds the length-weighted segment distribution of `contour`.
///
/// Linear segments are measured as straight-line distances; quadratic segments
/// use [`Quadratic::approximate_length`].
///
/// # Panics
///
/// Panics if the contour has no segments, is not shaped `[segments, 2, 2]`, or
/// has a command count that differs from its segment count.
pub fn distribution<B: Backend>(contour: Contour<B>) -> Distribution<B> {
    let [segment_count, point_count, coordinate_count] = contour.dims();
    let well_formed = segment_count > 0
        && point_count == 2
        && coordinate_count == 2
        && contour.commands().len() == segment_count;
    assert!(well_formed, "contours must have shape [segments, 2, 2]");

    let arguments = contour.arguments();
    let per_segment = (0..segment_count)
        .map(|segment| {
            let indices = Indices::new(segment, segment_count);
            match contour.command(segment) {
                Command::Linear => linear_length(arguments.clone(), indices),
                Command::Quadratic => {
                    Quadratic::from_arguments(arguments.clone(), indices).approximate_length()
                }
            }
        })
        .collect::<Vec<_>>();

    let lengths = Tensor::cat(per_segment, 0);
    // Floor the total so an all-zero-length contour yields zeros instead of NaNs.
    let total = lengths.clone().sum().clamp_min(EPSILON);
    let pmf = lengths.clone() / total;
    let cdf = pmf.clone().cumsum(0);
    Distribution { lengths, pmf, cdf }
}
/// Produces `count` evenly spaced samples at the midpoints `(i + 0.5) / count`
/// for `i` in `0..count`.
pub fn unit_samples<B: Backend>(count: usize, device: &B::Device) -> UnitSamples<B> {
    let positions = Tensor::<B, 1, Int>::arange(0..count as i64, device).float();
    let midpoints = positions + 0.5;
    midpoints / count as f32
}
/// Euclidean distance between a segment's start and end points, as a
/// single-element tensor.
fn linear_length<B: Backend>(arguments: contour::Arguments<B>, indices: Indices) -> Tensor<B, 1> {
    let from = point(arguments.clone(), indices.start());
    let to = point(arguments, indices.end());
    let delta = to - from;
    delta.powi_scalar(2).sum().sqrt()
}
/// Extracts the two coordinates stored at `[index[0], index[1]]` of the
/// contour arguments as a 1-D tensor.
fn point<B: Backend>(arguments: contour::Arguments<B>, index: [usize; 2]) -> Tensor<B, 1> {
    let [segment, slot] = index;
    let selected = arguments.slice([segment..segment + 1, slot..slot + 1, 0..2]);
    selected.squeeze_dim::<2>(0).squeeze_dim::<1>(0)
}
/// Shifts `cdf` right by one slot and puts a zero in front, so entry `i` is the
/// cumulative probability of the segments before segment `i`.
fn previous_cdf<B: Backend>(cdf: Cdf<B>) -> Cdf<B> {
    let len = cdf.dims()[0];
    let leading_zero = Tensor::<B, 1>::zeros([1], &cdf.device());
    match len {
        // With a single segment there is nothing to keep after the shift.
        1 => leading_zero,
        _ => Tensor::cat(vec![leading_zero, cdf.slice_dim(0, 0..len - 1)], 0),
    }
}
#[cfg(test)]
mod tests {
    use ::burn::tensor::{Tensor, TensorData, backend};
    use super::super::tests::{
        assert_floats, assert_ints, samples, segments, small_triangle, triangle,
    };
    use super::{Cdf, Distribution, EPSILON, Lengths, Pmf, distribution};
    use crate::{
        base::geometry::Command,
        burn::{geometry::Contour, tests::Backend},
    };

    #[test]
    fn builds_contour_length_distribution() {
        let Contours {
            distributions,
            lengths,
            pmf,
            cdf,
        } = contours::<Backend, _>([triangle(), small_triangle()]);
        assert_eq!(distributions.len(), 2);
        assert!(!distributions.is_empty());
        assert_floats(lengths, [12.0, 6.0]);
        assert_floats(pmf, [2.0 / 3.0, 1.0 / 3.0]);
        assert_floats(cdf, [2.0 / 3.0, 1.0]);
        assert_floats(distributions[1].lengths.clone(), [1.5, 2.0, 2.5]);
    }

    #[test]
    fn builds_quadratic_length_distribution() {
        let Distribution { lengths, pmf, cdf } = distribution::<Backend>(triangle());
        assert_floats(lengths, [3.0, 4.0, 5.0]);
        assert_floats(pmf, [0.25, 1.0 / 3.0, 5.0 / 12.0]);
        assert_floats(cdf, [0.25, 7.0 / 12.0, 1.0]);
    }

    #[test]
    #[should_panic]
    fn rejects_empty_contour_distribution() {
        let _ = distribution::<Backend>(segments([]));
    }

    #[test]
    #[should_panic]
    fn rejects_invalid_segment_shape_distribution() {
        // Three points per segment instead of two must be rejected.
        let data = TensorData::new(vec![0.0; 6], [1, 3, 2]);
        let invalid = Tensor::<Backend, 3>::from_data(data, &Default::default());
        let _ = distribution::<Backend>(Contour::new(vec![Command::Quadratic], invalid));
    }

    #[test]
    fn samples_boundary_records() {
        // Each sample sits at the midpoint of its segment's CDF interval.
        let records = distribution::<Backend>(triangle())
            .sample(samples([0.125, 5.0 / 12.0, 19.0 / 24.0]));
        assert_eq!(records.segment_count(), 3);
        assert_ints(records.segment_indices(), [0, 1, 2]);
        assert_floats(records.t(), [0.5, 0.5, 0.5]);
        assert_floats(records.segment_pmf(), [0.25, 1.0 / 3.0, 5.0 / 12.0]);
    }

    /// Length-weighted distribution over several contours plus each contour's
    /// own segment distribution.
    struct Contours<B: backend::Backend> {
        distributions: Vec<Distribution<B>>,
        lengths: Lengths<B>,
        pmf: Pmf<B>,
        cdf: Cdf<B>,
    }

    /// Weights each contour by its total segment length.
    fn contours<B, I>(contours: I) -> Contours<B>
    where
        B: backend::Backend,
        I: IntoIterator<Item = Contour<B>>,
    {
        let distributions: Vec<Distribution<B>> = contours.into_iter().map(distribution).collect();
        assert!(!distributions.is_empty(), "contours must not be empty");
        let totals = distributions
            .iter()
            .map(|item| item.lengths.clone().sum())
            .collect::<Vec<_>>();
        let lengths = Tensor::cat(totals, 0);
        let total = lengths.clone().sum().clamp_min(EPSILON);
        let pmf = lengths.clone() / total;
        let cdf = pmf.clone().cumsum(0);
        Contours {
            distributions,
            lengths,
            pmf,
            cdf,
        }
    }
}