use ::burn::tensor::{Tensor, backend::Backend};
use crate::burn::{filter::Weights, geometry::Coordinates};
/// Lower bound applied to the weights (via `clamp_min`) before `sample` divides the
/// image by them, so zero or negative weights never cause a division by zero.
const EPSILON: f32 = 1e-6;
/// 2-D image tensor; `sample` reads its dimensions as `[height, width]` (row-major).
pub type Image<B> = Tensor<B, 2>;
/// 1-D tensor holding one value per sampled point, as returned by `sample`.
pub type Values<B> = Tensor<B, 1>;
/// Sums weight-normalized image values inside a square window around each point.
///
/// The image is first divided element-wise by `weights`, with the weights clamped
/// below at [`EPSILON`] so zero or negative weights cannot divide by zero. Then, for
/// every point `(x, y)`, the result is the sum of the normalized pixels `(ix, iy)`
/// whose center `(ix + 0.5, iy + 0.5)` satisfies `|cx - x| <= radius` and
/// `|cy - y| <= radius`. Points outside `[0, width) x [0, height)` produce `0.0`.
///
/// The first image dimension is treated as the row (`y`, height) axis and the second
/// as the column (`x`, width) axis.
///
/// # Panics
///
/// Panics if the image is empty, there are no points, the `x` and `y` columns differ
/// in length, or `weights` does not have the same `[height, width]` shape as `image`.
pub fn sample<B: Backend>(
    image: Image<B>,
    points: Coordinates<B>,
    radius: f32,
    weights: Weights<B>,
) -> Values<B> {
    let [height, width] = image.dims();
    let (x, y) = points;
    let [sample_count] = x.dims();
    assert!(
        height > 0
            && width > 0
            && sample_count > 0
            && y.dims() == [sample_count]
            && weights.dims() == [height, width],
        "image must be non-empty, match the weights, and point columns must have matching non-empty shapes"
    );
    let point_in_bounds = x
        .clone()
        .greater_equal_elem(0.0)
        .bool_and(x.clone().lower_elem(width as f32))
        .bool_and(y.clone().greater_equal_elem(0.0))
        .bool_and(y.clone().lower_elem(height as f32));
    let device = image.device();
    // Flatten row-major so that `iy * width + ix` addresses pixel (ix, iy).
    let weighted = (image / weights.clamp_min(EPSILON)).reshape([height * width]);
    let x_floor = x.clone().floor();
    let y_floor = y.clone().floor();
    let mut values = Tensor::<B, 1>::zeros([sample_count], &device);
    let (minimum_offset, maximum_offset) = if radius == 0.5 {
        // With radius 0.5, |floor(x) + dx + 0.5 - x| <= 0.5 only holds for dx = 0 or
        // dx = -1 (the latter only when x is an integer), so the +1 neighbour is skipped.
        (-1, 0)
    } else {
        // ceil(radius) neighbours on each side cover every center within `radius`;
        // an in-bounds point never needs a candidate farther than the image extent.
        let offset = (radius.ceil() as i32).min(width.max(height) as i32);
        (-offset, offset)
    };
    for dy in minimum_offset..=maximum_offset {
        // Everything that depends only on the row offset is computed once per `dy`
        // rather than once per (dy, dx) pair.
        let candidate_y = y_floor.clone() + dy as f32;
        let row_mask = (candidate_y.clone() + 0.5 - y.clone())
            .abs()
            .lower_equal_elem(radius)
            .bool_and(candidate_y.clone().greater_equal_elem(0.0))
            .bool_and(candidate_y.clone().lower_elem(height as f32))
            .bool_and(point_in_bounds.clone());
        // Clamped so `select` always receives a valid index; candidates outside the
        // image are excluded by the mask instead.
        let row_start = candidate_y.int().clamp(0, height as i64 - 1) * width as i64;
        for dx in minimum_offset..=maximum_offset {
            let candidate_x = x_floor.clone() + dx as f32;
            let inside_filter = (candidate_x.clone() + 0.5 - x.clone())
                .abs()
                .lower_equal_elem(radius)
                .bool_and(candidate_x.clone().greater_equal_elem(0.0))
                .bool_and(candidate_x.clone().lower_elem(width as f32))
                .bool_and(row_mask.clone());
            let index = row_start.clone() + candidate_x.int().clamp(0, width as i64 - 1);
            let sample = weighted.clone().select(0, index);
            // Zero out candidates outside the window without allocating a zeros tensor.
            values = values + sample.mask_fill(inside_filter.bool_not(), 0.0);
        }
    }
    values
}
#[cfg(test)]
mod tests {
    use super::sample;
    use super::super::tests::{assert_floats, image, samples};

    /// Pixel-centered points with radius 0.5 each pick only their own pixel, which is
    /// divided by the uniform weight of 2.
    #[test]
    fn divides_by_weights() {
        let pixels = image([[2.0, 4.0], [6.0, 8.0]]);
        let xs = samples([0.5, 1.5, 0.5, 1.5]);
        let ys = samples([0.5, 0.5, 1.5, 1.5]);
        let weights = image([[2.0, 2.0], [2.0, 2.0]]);
        let result = sample(pixels, (xs, ys), 0.5, weights);
        assert_floats(result, [1.0, 2.0, 3.0, 4.0]);
    }

    /// A radius of 1 around (1.5, 0.5) reaches every pixel center of the 3x2 image,
    /// so the result is the sum of all six values.
    #[test]
    fn gathers_within_wider_radius() {
        let pixels = image([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let xs = samples([1.5]);
        let ys = samples([0.5]);
        let weights = image([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
        let result = sample(pixels, (xs, ys), 1.0, weights);
        assert_floats(result, [21.0]);
    }

    /// Centered points read one pixel; (1.0, 1.0) sits on a pixel corner and sums the
    /// four surrounding pixels (1 + 2 + 4 + 5); (3.0, 1.0) lies outside the image width.
    #[test]
    fn samples_grayscale_loss_signal() {
        let pixels = image([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
        let xs = samples([0.5, 1.5, 0.5, 1.0, 3.0]);
        let ys = samples([0.5, 0.5, 1.5, 1.0, 1.0]);
        let weights = image([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]);
        let result = sample(pixels, (xs, ys), 0.5, weights);
        assert_floats(result, [1.0, 2.0, 4.0, 12.0, 0.0]);
    }
}