use ::burn::tensor::{Tensor, backend::Backend};
use crate::base::{Config, geometry::Layout};
use super::{
boundary,
geometry::{Arguments, Contour, contour_tensors},
raster,
tensor::{Targets, Value},
};
/// Mean-squared-error loss between rasterized layouts and target images.
///
/// The gradient with respect to the layout arguments is supplied by
/// `boundary::differentiate` and attached to the returned value by
/// `attach_gradient`.
pub struct Loss {
    /// Rendering/differentiation settings forwarded to `raster::render` and
    /// `boundary::differentiate`.
    config: Config,
}

impl Loss {
    /// Creates a loss that renders and differentiates using `config`.
    pub const fn new(config: Config) -> Self {
        Self { config }
    }

    /// Renders every layout from `arguments` and returns the mean squared error
    /// against `targets`, with boundary gradients attached to `arguments`.
    ///
    /// `targets` holds one `height × width` image per entry of `layouts`.
    ///
    /// # Panics
    ///
    /// Panics with "targets must have one non-empty image per layout" when the
    /// first dimension of `targets` differs from `layouts.len()` or any target
    /// dimension is zero. It also panics, via `prepare_layouts`, when
    /// `arguments` is not of shape `[batch, sequence, 2, 2]` with non-zero
    /// `batch` and `sequence`.
    pub fn evaluate<B: Backend>(
        &self,
        layouts: &[Layout],
        arguments: Arguments<B>,
        targets: Targets<B>,
    ) -> Value<B> {
        let [layout_count, height, width] = targets.dims();
        let one_image_per_layout = layout_count == layouts.len();
        let images_non_empty = layout_count > 0 && height > 0 && width > 0;
        assert!(
            one_image_per_layout && images_non_empty,
            "targets must have one non-empty image per layout"
        );
        let pixel_total = (layout_count * height * width) as f32;
        let shape = arguments.dims();
        let device = arguments.device();
        let prepared = prepare_layouts(layouts, arguments.clone());
        let rendered = render_layouts(&prepared, height, width, self.config, &device);
        let residual = rendered - targets;
        let mean_squared = residual.clone().powi_scalar(2).mean();
        // Derivative of mean(residual²) with respect to the rendered image:
        // 2 · residual / pixel_total, detached so it acts as a constant signal.
        let upstream = residual.detach() * (2.0 / pixel_total);
        let gradients = differentiate_layouts(&prepared, upstream, shape, &device, self.config);
        attach_gradient(mean_squared, arguments, gradients)
    }
}
/// A layout paired with the contour tensors that `contour_tensors` extracted
/// for it from the argument tensor.
struct PreparedLayout<'a, B>
where
    B: Backend,
{
    /// Source layout; its `batch()` and `contours()` ranges locate the
    /// layout's entries in the argument tensor.
    layout: &'a Layout,
    /// Contour tensors for this layout. When empty, the layout renders as an
    /// all-zero image and contributes no gradient.
    contours: Vec<Contour<B>>,
}
/// Returns `loss` with `gradients` attached as its gradient contribution with
/// respect to `arguments`, leaving the forward value unchanged.
///
/// The surrogate term `sum((arguments - detach(arguments)) * detach(gradients))`
/// evaluates to zero for finite inputs. Its derivative with respect to
/// `arguments` is exactly `gradients`.
///
/// The difference is taken *before* multiplying and summing on purpose. The
/// previous form, `loss + sum(a·g) - sum(a·g)`, evaluates left to right. It
/// therefore rounds `loss` against the magnitude of `sum(a·g)`, so the reported
/// loss drifts whenever that product is large compared with the loss itself.
fn attach_gradient<B: Backend>(
    loss: Value<B>,
    arguments: Arguments<B>,
    gradients: Arguments<B>,
) -> Value<B> {
    // Treat the externally computed gradients as constants.
    let gradients = gradients.detach();
    let frozen = arguments.clone().detach();
    // Zero-valued elementwise, but still carries d/d(arguments) = 1 through autodiff.
    let delta = arguments - frozen;
    loss + (delta * gradients).sum()
}
/// Accumulates per-contour boundary gradients into a tensor shaped like the
/// layout arguments.
///
/// `signal` holds one `height × width` image per layout. In `Loss::evaluate`
/// it is the derivative of the mean-squared error with respect to the
/// rendered prediction.
///
/// For every layout with at least one contour, that layout's image is passed
/// to `boundary::differentiate`. Each returned contour gradient is then added
/// into the region `[batch, range.start()..range.end(), 0..2, 0..2]` of the
/// output, where `batch` is the layout's batch index and `range` is the
/// corresponding entry of its contour ranges. Layouts without contours are
/// skipped.
///
/// The result has shape `[batch_size, sequence_length, 2, 2]`, taking the
/// first two sizes from `argument_shape`.
fn differentiate_layouts<B: Backend>(
    layouts: &[PreparedLayout<'_, B>],
    signal: Targets<B>,
    argument_shape: [usize; 4],
    device: &B::Device,
    config: Config,
) -> Arguments<B> {
    let [_, height, width] = signal.dims();
    let [batch_size, sequence_length, _, _] = argument_shape;
    let mut accumulated = Tensor::<B, 4>::zeros([batch_size, sequence_length, 2, 2], device);
    for (index, prepared) in layouts.iter().enumerate() {
        if prepared.contours.is_empty() {
            continue;
        }
        let layout_signal = signal
            .clone()
            .slice([index..index + 1, 0..height, 0..width])
            .squeeze_dim::<2>(0);
        let per_contour = boundary::differentiate(&prepared.contours, layout_signal, config);
        let batch = prepared.layout.batch();
        let ranges = prepared.layout.contours();
        for (position, contribution) in per_contour.into_iter().enumerate() {
            // Indexing (not zip) preserves the panic if more gradients than ranges arrive.
            let range = &ranges[position];
            let region = [batch..batch + 1, range.start()..range.end(), 0..2, 0..2];
            let existing = accumulated.clone().slice(region.clone()).squeeze_dim::<3>(0);
            let updated = (existing + contribution).unsqueeze_dim::<4>(0);
            accumulated = accumulated.slice_assign(region, updated);
        }
    }
    accumulated
}
/// Pairs each layout with the contour tensors that `contour_tensors` extracts
/// from `arguments`, preserving the order of `layouts`.
///
/// # Panics
///
/// Panics when `layouts` is empty, or when `arguments` is not of shape
/// `[batch, sequence, 2, 2]` with non-zero `batch` and `sequence`.
/// NOTE(review): `contour_tensors` presumably panics as well for layouts that
/// do not fit `arguments` (see the `rejects_invalid_*` tests) — confirm.
fn prepare_layouts<'a, B: Backend>(
    layouts: &'a [Layout],
    arguments: Arguments<B>,
) -> Vec<PreparedLayout<'a, B>> {
    let [batch_size, sequence_length, point_count, coordinate_count] = arguments.dims();
    let has_layouts = !layouts.is_empty();
    let has_entries = batch_size > 0 && sequence_length > 0;
    let is_segment_shaped = point_count == 2 && coordinate_count == 2;
    assert!(
        has_layouts && has_entries && is_segment_shaped,
        "layouts must not be empty and arguments must have shape [batch, sequence, 2, 2]"
    );
    let mut prepared = Vec::with_capacity(layouts.len());
    for layout in layouts {
        let contours = contour_tensors(layout, arguments.clone());
        prepared.push(PreparedLayout { layout, contours });
    }
    prepared
}
fn render_layouts<B: Backend>(
layouts: &[PreparedLayout<'_, B>],
height: usize,
width: usize,
config: Config,
device: &B::Device,
) -> Targets<B> {
let mut rasters = Vec::with_capacity(layouts.len());
for layout in layouts {
let raster = if layout.contours.is_empty() {
Tensor::<B, 2>::zeros([height, width], device)
} else {
raster::render(&layout.contours, height, width, config)
};
rasters.push(raster);
}
Tensor::<B, 2>::stack::<3>(rasters, 0)
}
#[cfg(test)]
mod tests {
    use ::burn::{
        backend::Autodiff,
        tensor::{Tensor, TensorData},
    };

    use super::Loss;
    use crate::{
        base::{
            Config,
            geometry::{Command, Layout, Range},
        },
        burn::tests::Backend,
    };

    type AutodiffBackend = Autodiff<Backend>;

    #[test]
    fn attaches_boundary_gradients_to_arguments() {
        let arguments = sequence_auto(square_auto([0.0, 0.0], [1.0, 1.0])).require_grad();
        let loss = Loss::new(Config::default()).evaluate(
            &square_layouts(),
            arguments.clone(),
            image_auto([[[0.0]]]),
        );
        let gradients = loss.backward();
        let argument_gradient = arguments.grad(&gradients).unwrap();
        assert_nonzero(argument_gradient);
    }

    #[test]
    fn gradient_step_moves_square_toward_target() {
        let arguments = sequence(square([0.0, 0.0], [1.0, 1.0]));
        let gradient = argument_gradient(arguments.clone(), image_auto([[[0.0, 1.0]]]));
        let updated = arguments.clone() - gradient * 0.25;
        assert!(center(updated, 0) > center(arguments, 0));
    }

    #[test]
    fn gradient_step_reduces_shifted_square_loss() {
        let target = image([[[0.0, 1.0]]]);
        let arguments = sequence(square([0.0, 0.0], [1.0, 1.0]));
        let before = loss_value(arguments.clone(), target.clone());
        // Derive the autodiff target from the same tensor so the two cannot drift apart.
        let gradient = argument_gradient(
            arguments.clone(),
            Tensor::<AutodiffBackend, 3>::from_inner(target.clone()),
        );
        let updated = arguments - gradient * 4.0;
        let after = loss_value(updated, target);
        assert!(after < before, "expected {after} to be less than {before}");
    }

    #[test]
    #[should_panic]
    fn rejects_empty_range() {
        let _ = Loss::new(Config::default()).evaluate(
            &[Layout::new(0, commands(4), vec![Range::new(4, 4)])],
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[1.0]]]),
        );
    }

    #[test]
    #[should_panic]
    fn rejects_invalid_config() {
        let _ = Loss::new(Config::new(0, 2, 0.5));
    }

    #[test]
    #[should_panic]
    fn rejects_invalid_layout_batch_index() {
        let _ = Loss::new(Config::default()).evaluate(
            &[Layout::new(1, commands(4), vec![Range::new(0, 4)])],
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[1.0]]]),
        );
    }

    #[test]
    #[should_panic]
    fn rejects_invalid_range_end() {
        let _ = Loss::new(Config::default()).evaluate(
            &[Layout::new(0, commands(5), vec![Range::new(0, 5)])],
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[1.0]]]),
        );
    }

    #[test]
    #[should_panic]
    fn rejects_invalid_arguments() {
        let _ = Loss::new(Config::default()).evaluate(
            &square_layouts(),
            Tensor::<Backend, 4>::from_data(
                TensorData::new(Vec::<f32>::new(), [1, 0, 2, 2]),
                &Default::default(),
            ),
            image([[[1.0]]]),
        );
    }

    // Pinned to the message so the test cannot pass on an unrelated panic.
    #[test]
    #[should_panic(expected = "targets must have one non-empty image per layout")]
    fn rejects_mismatched_target_count() {
        let _ = Loss::new(Config::default()).evaluate(
            &square_layouts(),
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[1.0]], [[1.0]]]),
        );
    }

    #[test]
    fn returns_mean_squared_error() {
        let loss = Loss::new(Config::default()).evaluate(
            &square_layouts(),
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[0.0]]]),
        );
        assert_close(loss.into_scalar(), 1.0);
    }

    #[test]
    fn returns_zero_for_empty_layout() {
        let loss = Loss::new(Config::default()).evaluate(
            &[
                Layout::new(0, commands(0), vec![]),
                Layout::new(0, commands(4), vec![Range::new(0, 4)]),
            ],
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[0.0]], [[1.0]]]),
        );
        assert_close(loss.into_scalar(), 0.0);
    }

    #[test]
    fn returns_zero_for_matching_layouts() {
        let arguments = Tensor::<Backend, 4>::from_data(
            TensorData::from([
                square_segments([0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]),
                square_segments([0.0, 0.0], [0.5, 0.0], [0.5, 1.0], [0.0, 1.0]),
            ]),
            &Default::default(),
        );
        let loss = Loss::new(Config::default()).evaluate(
            &[
                Layout::new(0, commands(4), vec![Range::new(0, 4)]),
                Layout::new(1, commands(4), vec![Range::new(0, 4)]),
            ],
            arguments,
            image([[[1.0]], [[0.5]]]),
        );
        assert_close(loss.into_scalar(), 0.0);
    }

    #[test]
    fn returns_zero_for_matching_ring() {
        let loss = Loss::new(Config::default()).evaluate(
            &[Layout::new(
                0,
                commands(8),
                vec![Range::new(0, 4), Range::new(4, 8)],
            )],
            sequence(ring()),
            image([[[1.0, 1.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]]]),
        );
        assert_close(loss.into_scalar(), 0.0);
    }

    #[test]
    fn returns_zero_for_matching_target() {
        let loss = Loss::new(Config::default()).evaluate(
            &square_layouts(),
            sequence(square([0.0, 0.0], [1.0, 1.0])),
            image([[[1.0]]]),
        );
        assert_close(loss.into_scalar(), 0.0);
    }

    /// Gradient of the square-layout loss with respect to `arguments` for `target`.
    fn argument_gradient(
        arguments: Tensor<Backend, 4>,
        target: Tensor<AutodiffBackend, 3>,
    ) -> Tensor<Backend, 4> {
        let arguments = Tensor::<AutodiffBackend, 4>::from_inner(arguments).require_grad();
        let loss =
            Loss::new(Config::default()).evaluate(&square_layouts(), arguments.clone(), target);
        let gradients = loss.backward();
        arguments.grad(&gradients).unwrap()
    }

    /// Asserts `actual` is within 1e-6 of `expected`, reporting both on failure.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-6,
            "expected {expected}, got {actual}"
        );
    }

    /// Asserts at least one entry of `tensor` exceeds 1e-6 in magnitude.
    fn assert_nonzero<const D: usize>(tensor: Tensor<Backend, D>) {
        let actual = tensor.into_data().to_vec::<f32>().unwrap();
        assert!(
            actual.iter().any(|value| value.abs() > 1e-6),
            "expected at least one non-zero entry, got {actual:?}"
        );
    }

    /// Mean of the given coordinate (0 = x, 1 = y) over all points in `tensor`.
    fn center<const D: usize>(tensor: Tensor<Backend, D>, coordinate: usize) -> f32 {
        let values = tensor.into_data().to_vec::<f32>().unwrap();
        let point_count = values.len() / 2;
        values
            .chunks_exact(2)
            .map(|point| point[coordinate])
            .sum::<f32>()
            / point_count as f32
    }

    fn commands(count: usize) -> Vec<Command> {
        vec![Command::Linear; count]
    }

    fn image<const BATCH: usize, const HEIGHT: usize, const WIDTH: usize>(
        values: [[[f32; WIDTH]; HEIGHT]; BATCH],
    ) -> Tensor<Backend, 3> {
        Tensor::<Backend, 3>::from_data(TensorData::from(values), &Default::default())
    }

    fn image_auto<const BATCH: usize, const HEIGHT: usize, const WIDTH: usize>(
        values: [[[f32; WIDTH]; HEIGHT]; BATCH],
    ) -> Tensor<AutodiffBackend, 3> {
        Tensor::<AutodiffBackend, 3>::from_data(TensorData::from(values), &Default::default())
    }

    fn interpolate(start: [f32; 2], end: [f32; 2], t: f32) -> [f32; 2] {
        [
            start[0] + (end[0] - start[0]) * t,
            start[1] + (end[1] - start[1]) * t,
        ]
    }

    fn loss_value(arguments: Tensor<Backend, 4>, target: Tensor<Backend, 3>) -> f32 {
        Loss::new(Config::default())
            .evaluate(&square_layouts(), arguments, target)
            .into_scalar()
    }

    /// A 3×3 outer square with a reversed 1×1 inner square, as 8 segments.
    fn ring() -> Tensor<Backend, 3> {
        Tensor::<Backend, 4>::from_data(
            TensorData::from([
                square_segments([0.0, 0.0], [3.0, 0.0], [3.0, 3.0], [0.0, 3.0]),
                square_segments([1.0, 1.0], [1.0, 2.0], [2.0, 2.0], [2.0, 1.0]),
            ]),
            &Default::default(),
        )
        .reshape([8, 2, 2])
    }

    fn sequence(arguments: Tensor<Backend, 3>) -> Tensor<Backend, 4> {
        arguments.unsqueeze_dim::<4>(0)
    }

    fn sequence_auto(arguments: Tensor<AutodiffBackend, 3>) -> Tensor<AutodiffBackend, 4> {
        arguments.unsqueeze_dim::<4>(0)
    }

    fn square(min: [f32; 2], max: [f32; 2]) -> Tensor<Backend, 3> {
        Tensor::<Backend, 3>::from_data(square_data(min, max), &Default::default())
    }

    fn square_auto(min: [f32; 2], max: [f32; 2]) -> Tensor<AutodiffBackend, 3> {
        Tensor::<AutodiffBackend, 3>::from_data(square_data(min, max), &Default::default())
    }

    fn square_data(min: [f32; 2], max: [f32; 2]) -> TensorData {
        TensorData::from(square_segments(
            [min[0], min[1]],
            [max[0], min[1]],
            [max[0], max[1]],
            [min[0], max[1]],
        ))
    }

    fn square_layouts() -> Vec<Layout> {
        vec![Layout::new(0, commands(4), vec![Range::new(0, 4)])]
    }

    /// Four segments a→b→c→d→a, each stored as `[start, midpoint]`.
    fn square_segments(a: [f32; 2], b: [f32; 2], c: [f32; 2], d: [f32; 2]) -> [[[f32; 2]; 2]; 4] {
        [
            [a, interpolate(a, b, 0.5)],
            [b, interpolate(b, c, 0.5)],
            [c, interpolate(c, d, 0.5)],
            [d, interpolate(d, a, 0.5)],
        ]
    }
}