use ::burn::tensor::{Bool, Int, Tensor, backend::Backend};
use crate::burn::{filter::Weights, geometry::Coordinates};
/// Lower bound applied to the weights before the image is divided by them, so zero or negative
/// weights cannot turn the division into infinities or NaNs.
const EPSILON: f32 = 1e-6;
/// A single-channel image stored as a 2-D tensor; `sample` reads its dims as `[height, width]`.
pub type Image<B> = Tensor<B, 2>;
/// One value per sample point, stored as a 1-D tensor.
pub type Values<B> = Tensor<B, 1>;
/// Samples `image / weights` at each point with a square box window of half-width `radius`.
///
/// `points` is an `(x, y)` pair of equally long 1-D coordinate tensors in pixel units. Along each
/// axis, pixel `p` has its center at `p + 0.5`. For every point, the result is the sum of
/// `image / max(weights, EPSILON)` over all pixels whose center lies within `radius` of the
/// point along *both* axes (inclusive comparison). Pixels outside the image contribute nothing.
/// A point whose `x` lies outside `[0, width)`, or whose `y` lies outside `[0, height)`, yields `0`.
///
/// # Panics
///
/// Panics if the image is empty, if `weights` does not have the image's `[height, width]` shape,
/// or if the coordinate tensors are empty or differ in length.
pub fn sample<B: Backend>(
    image: Image<B>,
    points: Coordinates<B>,
    radius: f32,
    weights: Weights<B>,
) -> Values<B> {
    let [height, width] = image.dims();
    let (x, y) = points;
    let [sample_count] = x.dims();
    assert!(
        height > 0
            && width > 0
            && sample_count > 0
            && y.dims() == [sample_count]
            && weights.dims() == [height, width],
        "image must be non-empty, match the weights, and point columns must have matching non-empty shapes"
    );
    let device = image.device();
    // Clamping keeps zero or negative weights from turning the division into inf/NaN.
    let weighted = (image / weights.clamp_min(EPSILON)).reshape([height * width]);
    let x_floor = x.clone().floor();
    let y_floor = y.clone().floor();
    let zeros = Tensor::<B, 1>::zeros([sample_count], &device);
    let (minimum_offset, maximum_offset) = offset_range(radius, width.max(height));
    let columns = (minimum_offset..=maximum_offset)
        .map(|offset| axis(&x, &x_floor, offset, width, radius))
        .collect::<Vec<_>>();
    // Row indices are pre-scaled by `width`, so `row + column` indexes the flattened image.
    let rows = (minimum_offset..=maximum_offset)
        .map(|offset| {
            let (inside, index) = axis(&y, &y_floor, offset, height, radius);
            (inside, index * width as i64)
        })
        .collect::<Vec<_>>();
    let mut values = zeros.clone();
    for (row_inside, row_index) in &rows {
        for (column_inside, column_index) in &columns {
            let inside_filter = row_inside.clone().bool_and(column_inside.clone());
            let index = row_index.clone() + column_index.clone();
            let sample = weighted.clone().gather(0, index);
            // Taps outside the window or the image add zero, not the pixel at the clamped index.
            values = values + zeros.clone().mask_where(inside_filter, sample);
        }
    }
    values
}

/// Returns the inclusive range of offsets `o`, relative to `floor(coordinate)`, for which pixel
/// `floor(coordinate) + o` can have its center within `radius` of the coordinate. The range is
/// clamped to `[-limit, limit]`.
///
/// Let `f = coordinate - floor(coordinate)`, which lies in `[0, 1)`. The center of that pixel
/// is at signed distance `o + 0.5 - f`. Requiring this to be within `[-radius, radius]` for some
/// such `f` gives `-0.5 - radius <= o < 0.5 + radius`. For example, `radius = 0.5` yields
/// `[-1, 0]` and `radius = 1.0` yields `[-1, 1]`. Negative or NaN radii produce at most one
/// offset, which the per-pixel distance test then rejects.
fn offset_range(radius: f32, limit: usize) -> (i32, i32) {
    // f64 represents every `f32 ± 0.5` exactly, so the bounds are not perturbed by rounding.
    let radius = f64::from(radius);
    let limit = limit.min(i32::MAX as usize) as i32;
    // Float-to-int `as` casts saturate and map NaN to 0, so infinite radii stay in range.
    let minimum = ((-0.5 - radius).ceil() as i32).max(-limit);
    let maximum = (((0.5 + radius).ceil() - 1.0) as i32).min(limit);
    (minimum, maximum)
}
/// Evaluates one filter tap along a single axis for every sample.
///
/// The tap is the pixel `floor + offset` along this axis. Returns `(inside, index)`:
///
/// - `inside` is `true` where three conditions hold: that pixel lies in `[0, extent)`, its center
///   `pixel + 0.5` is within `radius` of `coordinate` (inclusive), and `coordinate` itself lies
///   in `[0, extent)`.
/// - `index` is the pixel clamped into `[0, extent - 1]`. It stays a valid gather index (for
///   `extent > 0`) even where `inside` is `false`.
fn axis<B: Backend>(
    coordinate: &Tensor<B, 1>,
    floor: &Tensor<B, 1>,
    offset: i32,
    extent: usize,
    radius: f32,
) -> (Tensor<B, 1, Bool>, Tensor<B, 1, Int>) {
    let limit = extent as f32;
    let pixel = floor.clone() + offset as f32;

    let pixel_in_bounds = pixel
        .clone()
        .greater_equal_elem(0.0)
        .bool_and(pixel.clone().lower_elem(limit));
    let point_in_bounds = coordinate
        .clone()
        .greater_equal_elem(0.0)
        .bool_and(coordinate.clone().lower_elem(limit));
    let within_radius = (pixel.clone() + 0.5 - coordinate.clone())
        .abs()
        .lower_equal_elem(radius);

    let inside = within_radius
        .bool_and(pixel_in_bounds)
        .bool_and(point_in_bounds);
    let index = pixel.int().clamp(0, extent as i64 - 1);
    (inside, index)
}
#[cfg(test)]
mod tests {
    use super::super::tests::{assert_floats, image, samples};
    use super::sample;

    /// Each point sits on a pixel center, so radius 0.5 picks exactly that pixel, halved by the
    /// uniform weight of 2.
    #[test]
    fn divides_by_weights() {
        let values = sample(
            image([[2.0, 4.0], [6.0, 8.0]]),
            (samples([0.5, 1.5, 0.5, 1.5]), samples([0.5, 0.5, 1.5, 1.5])),
            0.5,
            image([[2.0, 2.0], [2.0, 2.0]]),
        );
        assert_floats(values, [1.0, 2.0, 3.0, 4.0]);
    }

    /// Radius 1.0 around the center of pixel (row 0, column 1) reaches every pixel of a 2x3 image.
    #[test]
    fn gathers_within_wider_radius() {
        let values = sample(
            image([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            (samples([1.5]), samples([0.5])),
            1.0,
            image([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
        );
        assert_floats(values, [21.0]);
    }

    /// Radius 0.5 covers three cases. A point on a pixel center picks that pixel. A point on a
    /// pixel corner picks all four touching pixels, because the comparison is inclusive. A point
    /// on the image's right edge (x == width) lies outside the half-open image and yields 0.
    #[test]
    fn samples_grayscale_loss_signal() {
        let values = sample(
            image([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            (
                samples([0.5, 1.5, 0.5, 1.0, 3.0]),
                samples([0.5, 0.5, 1.5, 1.0, 1.0]),
            ),
            0.5,
            image([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
        );
        assert_floats(values, [1.0, 2.0, 4.0, 12.0, 0.0]);
    }

    /// Radius 1.5 around a point with a fractional offset reaches the neighbouring pixel on the
    /// near side, but never two pixels away. This holds at both the top-left and the bottom-right
    /// corner. A point just outside the image yields 0, even though pixel (0, 0)'s center is
    /// within the radius.
    #[test]
    fn gathers_within_fractional_radius() {
        let values = sample(
            image([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
            (samples([0.25, 2.75, -0.25]), samples([0.25, 1.75, 0.5])),
            1.5,
            image([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
        );
        assert_floats(values, [12.0, 16.0, 0.0]);
    }
}