use ::burn::tensor::{Int, Tensor, backend::Backend};
/// Subsample-selection matrix for one image axis, as returned by [`axis`]:
/// shape `[extent, extent * sample_count]`, with `1.0` where a subsample is
/// selected for a pixel and `0.0` elsewhere.
pub type Selection<B> = Tensor<B, 2>;
/// Number of selected subsamples for each pixel along one axis, as returned by
/// [`axis`]: shape `[extent]`.
pub type Counts<B> = Tensor<B, 1>;
/// Per-pixel subsample counts for a whole image, as returned by
/// [`weight_image`]: shape `[height, width]`.
pub type Weights<B> = Tensor<B, 2>;
/// Builds the subsample-selection matrix for one image axis.
///
/// The axis has `extent` pixels, each divided into `sample_count` evenly spaced
/// subsamples, for `extent * sample_count` subsamples in total. Measured in
/// pixel units, pixel `i` is centred at `i + 0.5` and subsample `k` sits at
/// `(k + 0.5) / sample_count`.
///
/// Entry `[i, k]` of the returned selection is `1.0` when subsample `k` is at
/// most `radius` away from pixel `i`'s centre (inclusive) and `0.0` otherwise.
/// The selection therefore has shape `[extent, extent * sample_count]`.
///
/// Returns `(selection, counts)`, where `counts[i]` is the number of
/// subsamples selected in row `i`. Subsamples only exist inside `(0, extent)`,
/// so a window reaching past either end of the axis is clipped there.
pub fn axis<B: Backend>(
    extent: usize,
    sample_count: usize,
    radius: f32,
    device: &B::Device,
) -> (Selection<B>, Counts<B>) {
    let subsample_total = (extent * sample_count) as i64;

    // Pixel centres laid out as a column: shape [extent, 1].
    let pixel_centers = Tensor::<B, 1, Int>::arange(0..extent as i64, device)
        .float()
        .add_scalar(0.5)
        .unsqueeze_dim::<2>(1);

    // Subsample positions laid out as a row: shape [1, extent * sample_count].
    let subsample_positions = Tensor::<B, 1, Int>::arange(0..subsample_total, device)
        .float()
        .add_scalar(0.5)
        .div_scalar(sample_count as f32)
        .unsqueeze_dim::<2>(0);

    // Broadcasting the column against the row yields every centre-to-subsample
    // distance at once.
    let distances = (pixel_centers - subsample_positions).abs();
    let selection: Selection<B> = distances.lower_equal_elem(radius).float();

    // Row sums keep a trailing unit dimension; drop it to get one count per pixel.
    let row_sums = selection.clone().sum_dim(1);
    (selection, row_sums.squeeze_dim::<1>(1))
}
/// Computes, for every pixel of a `height` x `width` image, how many
/// subsamples fall within `radius` of that pixel's centre on both axes.
///
/// Horizontally there are `x_sample_count` subsamples per pixel and vertically
/// `y_sample_count`. Because the window is separable, the count for pixel
/// `(y, x)` is the product of the row count for `y` and the column count for
/// `x`, each computed by [`axis`]. The result has shape `[height, width]`.
pub fn weight_image<B: Backend>(
    height: usize,
    width: usize,
    x_sample_count: usize,
    y_sample_count: usize,
    radius: f32,
    device: &B::Device,
) -> Weights<B> {
    // Only the per-pixel counts are needed here; the selection matrices are dropped.
    let row_counts = axis::<B>(height, y_sample_count, radius, device).1;
    let column_counts = axis::<B>(width, x_sample_count, radius, device).1;

    // Outer product via broadcasting: [height, 1] * [1, width] -> [height, width].
    let rows = row_counts.unsqueeze_dim::<2>(1);
    let columns = column_counts.unsqueeze_dim::<2>(0);
    rows * columns
}
#[cfg(test)]
mod tests {
    use super::{axis, weight_image};
    use crate::burn::tests::Backend;

    #[test]
    fn counts_subsamples_within_radius() {
        let weight = weight_image::<Backend>(2, 2, 2, 2, 0.5, &Default::default());
        assert_eq!(weight.into_data().to_vec::<f32>().unwrap(), [4.0; 4]);
    }

    #[test]
    fn widens_with_radius_and_clips_at_borders() {
        let weight = weight_image::<Backend>(1, 3, 1, 1, 1.0, &Default::default());
        assert_eq!(weight.dims(), [1, 3]);
        assert_eq!(weight.into_data().to_vec::<f32>().unwrap(), [2.0, 3.0, 2.0]);
    }

    #[test]
    fn rows_follow_height_and_columns_follow_width() {
        // Same 1-D profile as the test above, but laid out vertically.
        let weight = weight_image::<Backend>(3, 1, 1, 1, 1.0, &Default::default());
        assert_eq!(weight.dims(), [3, 1]);
        assert_eq!(weight.into_data().to_vec::<f32>().unwrap(), [2.0, 3.0, 2.0]);
    }

    #[test]
    fn x_sample_count_drives_columns() {
        // Columns: 2 px x 2 subsamples, radius 1 -> 3 subsamples per pixel.
        // Rows: 1 px x 1 subsample -> 1. Swapping the sample counts would give 4s.
        let weight = weight_image::<Backend>(1, 2, 2, 1, 1.0, &Default::default());
        assert_eq!(weight.dims(), [1, 2]);
        assert_eq!(weight.into_data().to_vec::<f32>().unwrap(), [3.0, 3.0]);
    }

    #[test]
    fn axis_selects_subsamples_near_each_center() {
        // Centres 0.5, 1.5; subsamples 0.25, 0.75, 1.25, 1.75; radius 0.5.
        let (selection, counts) = axis::<Backend>(2, 2, 0.5, &Default::default());
        assert_eq!(selection.dims(), [2, 4]);
        assert_eq!(
            selection.into_data().to_vec::<f32>().unwrap(),
            [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0]
        );
        assert_eq!(counts.into_data().to_vec::<f32>().unwrap(), [2.0, 2.0]);
    }
}