use ::burn::tensor::{IndexingUpdateOp, Tensor, backend::Backend};
use crate::burn::geometry::{Contour, Coordinates, contains};
use super::{geometry::Geometry, signal};
/// Offset applied along each boundary normal when probing coverage on either side of a sample.
const EPSILON: f32 = 0.000_1;
/// Floor applied to sample densities before they are used as divisors.
const MIN_PDF: f32 = 0.000_001;
/// One coverage jump per boundary sample, as returned by [`jumps`].
pub type Jumps<B> = Tensor<B, 1>;
/// Accumulated gradients shaped `[segment_count, 2, 2]`, as returned by [`segment_gradients`].
pub type SegmentGradients<B> = Tensor<B, 3>;
/// Measures the coverage jump across the shape boundary at every boundary sample.
///
/// Each sample point is probed twice: once at `point - EPSILON * normal` and once at
/// `point + EPSILON * normal`. Both probe sets are stacked and evaluated with a single
/// `contains` call. The result for a sample is
/// `coverage(point - EPSILON * normal) - coverage(point + EPSILON * normal)`.
/// The probe distance scales with the length of the supplied normal.
/// NOTE(review): if `contains` yields a 0/1 mask (it is converted with `.float()`), each entry
/// is -1, 0 or 1 — confirm against `contains`.
///
/// # Panics
/// Panics if `contours` is empty, if there are no samples, or if the point and normal
/// columns do not all share the same shape.
pub fn jumps<B, I>(contours: I, points: Coordinates<B>, normals: Coordinates<B>) -> Jumps<B>
where
    B: Backend,
    I: IntoIterator<Item = Contour<B>>,
{
    let contours: Vec<Contour<B>> = contours.into_iter().collect();
    let ((px, py), (nx, ny)) = (points, normals);

    let shape = px.dims();
    let columns_agree = [py.dims(), nx.dims(), ny.dims()]
        .iter()
        .all(|dims| *dims == shape);
    assert!(
        !contours.is_empty() && shape[0] > 0 && columns_agree,
        "boundary point and normal columns must have matching non-empty [samples] shapes"
    );

    let count = shape[0];
    let (dx, dy) = (nx * EPSILON, ny * EPSILON);

    // Rows 0..count hold the probes behind the normal (point - offset); rows count..2*count
    // hold the probes ahead of it (point + offset). One `contains` call covers both sides.
    let probes_x = Tensor::cat(vec![px.clone() - dx.clone(), px + dx], 0);
    let probes_y = Tensor::cat(vec![py.clone() - dy.clone(), py + dy], 0);
    let covered = contains(contours, probes_x, probes_y).float();

    let behind = covered.clone().slice_dim(0, 0..count);
    let ahead = covered.slice_dim(0, count..count * 2);
    behind - ahead
}
/// Scatters per-sample boundary contributions onto the points referenced by each sample.
///
/// Each sample's weight is `signal * jumps / max(pdf, MIN_PDF)`. The weight is shared among
/// the sample's start, control and end point indices (from `geometry.indices()`) using the
/// matching basis values `b0`, `b1` and `b2` (from `geometry.basis()`). Each share is scaled
/// by the sample normal and summed into a flat buffer, where point `i` accumulates its x
/// component at `2 * i` and its y component at `2 * i + 1`. Repeated indices add up via
/// `IndexingUpdateOp::Add`.
///
/// The buffer has `segment_count * 4` entries and is returned reshaped to
/// `[segment_count, 2, 2]`, indexed as `[point_index / 2, point_index % 2, x|y]`.
/// NOTE(review): this presumes every point index lies in `0..2 * segment_count`. The indices
/// are not range-checked here.
///
/// # Panics
/// Panics in either of these cases:
/// - `segment_count` or the sample count is zero;
/// - any of `jumps`, the basis, the normals, the pdf or the index tensors has a different
///   leading length from `signal`.
pub fn segment_gradients<B: Backend>(
    segment_count: usize,
    geometry: &Geometry<B>,
    signal: signal::Values<B>,
    jumps: Jumps<B>,
) -> SegmentGradients<B> {
    let sample_count = signal.dims()[0];
    let (b0, b1, b2) = geometry.basis();
    let (normal_x, normal_y) = geometry.normals();
    let pdf = geometry.pdf();
    let (start_indices, control_indices, end_indices) = geometry.indices();

    // Every per-sample column must line up with `signal`.
    let leading_lengths = [
        jumps.dims()[0],
        b0.dims()[0],
        b1.dims()[0],
        b2.dims()[0],
        normal_x.dims()[0],
        normal_y.dims()[0],
        pdf.dims()[0],
    ];
    let index_shapes = [
        start_indices.dims(),
        control_indices.dims(),
        end_indices.dims(),
    ];
    assert!(
        segment_count > 0
            && sample_count > 0
            && leading_lengths.iter().all(|&len| len == sample_count)
            && index_shapes.iter().all(|shape| *shape == [sample_count]),
        "segment gradients require positive segment and sample counts with matching boundary sample shapes"
    );

    // Importance-weighted contribution of each sample, then its basis-weighted shares for
    // the start, control and end points (in that order).
    let contribution = signal * jumps / pdf.clamp_min(MIN_PDF);
    let weights = [b0, b1, b2].map(|basis| basis * contribution.clone());

    // Flat slot of point `i`: x at `2 * i`, y at `2 * i + 1`.
    let update_count = 3 * sample_count;
    let x_slots = Tensor::cat(vec![start_indices, control_indices, end_indices], 0) * 2;
    let y_slots = (x_slots.clone() + 1).reshape([update_count, 1]);
    let x_slots = x_slots.reshape([update_count, 1]);

    let x_updates = Tensor::cat(
        weights
            .iter()
            .map(|share| share.clone() * normal_x.clone())
            .collect::<Vec<_>>(),
        0,
    );
    let y_updates = Tensor::cat(
        weights
            .iter()
            .map(|share| share.clone() * normal_y.clone())
            .collect::<Vec<_>>(),
        0,
    );

    let flat = Tensor::<B, 1>::zeros([segment_count * 4], &x_updates.device());
    flat.scatter_nd::<2, 1>(x_slots, x_updates, IndexingUpdateOp::Add)
        .scatter_nd::<2, 1>(y_slots, y_updates, IndexingUpdateOp::Add)
        .reshape([segment_count, 2, 2])
}
// Unit tests for boundary jumps and segment-gradient accumulation.
#[cfg(test)]
mod tests {
use super::super::{
distribution::distribution,
geometry::evaluate,
tests::{assert_floats, assert_matrix, samples, square, triangle},
};
use super::{jumps, segment_gradients};
use crate::burn::tests::Backend;
// Samples three boundary points on `triangle()` and checks the scattered gradients,
// viewed as one row per point slot (6 rows) with x/y columns, against fixed expected values.
// Signal [1, 2, 3] and jumps [1, 1, -1] give each sample a different weight and sign.
#[test]
fn accumulates_segment_gradients() {
let segments = triangle();
let segment_count = segments.dims()[0];
let distribution = distribution::<Backend>(segments.clone());
let records = distribution.sample(samples([0.125, 5.0 / 12.0, 19.0 / 24.0]));
let geometry = evaluate(segments, &records);
let gradients = segment_gradients(
segment_count,
&geometry,
samples([1.0, 2.0, 3.0]),
samples([1.0, 1.0, -1.0]),
);
assert_matrix(
gradients.reshape([6, 2]),
[
[-7.2, 8.4],
[0.0, 6.0],
[-6.0, 3.0],
[-12.0, 0.0],
[-13.2, 5.4],
[-14.4, 10.8],
],
);
}
// Probes the four edge midpoints of the square spanned by (0, 0) and (1, 1), with each
// normal pointing into the square. NOTE(review): assumes `square(min, max)` builds that
// axis-aligned square. Each "minus" probe then falls outside and each "plus" probe inside,
// so every jump is 0 - 1 = -1.
#[test]
fn computes_grayscale_boundary_jumps() {
let values = jumps(
[square([0.0, 0.0], [1.0, 1.0])],
(samples([0.5, 1.0, 0.5, 0.0]), samples([0.0, 0.5, 1.0, 0.5])),
(
samples([0.0, -1.0, 0.0, 1.0]),
samples([1.0, 0.0, -1.0, 0.0]),
),
);
assert_floats(values, [-1.0, -1.0, -1.0, -1.0]);
}
// Two boundary samples but only one jump value: the shape assertion in
// `segment_gradients` is expected to panic.
#[test]
#[should_panic]
fn rejects_mismatched_segment_gradient_jumps() {
let segments = triangle();
let distribution = distribution::<Backend>(segments.clone());
let records = distribution.sample(samples([0.125, 5.0 / 12.0]));
let geometry = evaluate(segments, &records);
let _ = segment_gradients(3, &geometry, samples([1.0, 2.0]), samples([1.0]));
}
// One signal value against two boundary samples and two jumps: the sample count taken from
// the signal no longer matches, so `segment_gradients` is expected to panic.
#[test]
#[should_panic]
fn rejects_mismatched_segment_gradient_signals() {
let segments = triangle();
let distribution = distribution::<Backend>(segments.clone());
let records = distribution.sample(samples([0.125, 5.0 / 12.0]));
let geometry = evaluate(segments, &records);
let _ = segment_gradients(3, &geometry, samples([1.0]), samples([1.0, 1.0]));
}
// A segment count of zero is rejected by the `segment_count > 0` precondition.
#[test]
#[should_panic]
fn rejects_zero_segment_gradient_count() {
let segments = triangle();
let distribution = distribution::<Backend>(segments.clone());
let records = distribution.sample(samples([0.125]));
let geometry = evaluate(segments, &records);
let _ = segment_gradients(0, &geometry, samples([1.0]), samples([1.0]));
}
}