use alloc::format;
use burn::tensor::module::interpolate;
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::InterpolateOptions;
use super::InterpolateMode;
/// Configuration used to create an [`Interpolate2d`] layer via [`Interpolate2dConfig::init`].
///
/// Both `output_size` and `scale_factor` default to `None`. The forward pass of the
/// resulting layer panics unless exactly one of the two is set.
#[derive(Config, Debug)]
pub struct Interpolate2dConfig {
/// Explicit output size as `[height, width]`. Defaults to `None`.
#[config(default = "None")]
pub output_size: Option<[usize; 2]>,
/// Multipliers applied to the input `[height, width]` to derive the output size.
/// Defaults to `None`.
#[config(default = "None")]
pub scale_factor: Option<[f32; 2]>,
/// Interpolation mode. Defaults to [`InterpolateMode::Nearest`].
#[config(default = "InterpolateMode::Nearest")]
pub mode: InterpolateMode,
}
/// 2D interpolation layer that resizes the last two dimensions (height and width)
/// of a rank-4 tensor.
///
/// Created with [`Interpolate2dConfig::init`].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate2d {
/// Explicit output size as `[height, width]`, if configured.
pub output_size: Option<[usize; 2]>,
/// Multipliers applied to the input `[height, width]`, if configured.
pub scale_factor: Option<[f32; 2]>,
/// Interpolation mode, wrapped in [`Ignored`].
// NOTE(review): presumably `Ignored` keeps this field out of the module's
// parameters/record — confirm against the `burn::module` documentation.
pub mode: Ignored<InterpolateMode>,
}
impl Interpolate2dConfig {
    /// Builds an [`Interpolate2d`] layer from this configuration.
    ///
    /// The size options are carried over unchanged and the interpolation mode is
    /// wrapped in [`Ignored`].
    pub fn init(self) -> Interpolate2d {
        let Self {
            output_size,
            scale_factor,
            mode,
        } = self;

        Interpolate2d {
            output_size,
            scale_factor,
            mode: Ignored(mode),
        }
    }
}
impl Interpolate2d {
    /// Resizes the height and width of `input` and returns the interpolated tensor.
    ///
    /// The target size is resolved from the input dimensions together with the
    /// layer's `output_size` / `scale_factor`, then handed to `interpolate` along
    /// with options built from the configured mode.
    ///
    /// # Panics
    ///
    /// Panics if the target size cannot be resolved, e.g. when neither or both of
    /// `output_size` and `scale_factor` are set.
    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let target_size =
            calculate_output_size(input.dims(), self.output_size, self.scale_factor);
        let options = InterpolateOptions::new(self.mode.0.clone().into());

        interpolate(input, target_size, options)
    }
}
/// Resolves the spatial output size `[height, width]` for an interpolation.
///
/// Exactly one of `output_size` and `scale_factor` must be `Some`:
/// - with `output_size`, the value is returned unchanged;
/// - with `scale_factor`, the height and width of `input_dims` (its last two
///   entries) are multiplied by `scale_factor[0]` and `scale_factor[1]`
///   respectively, and each product is truncated toward zero.
///
/// # Panics
///
/// - If both or neither of `output_size` and `scale_factor` are provided.
/// - If a scale factor is negative or NaN.
/// - If a scaled dimension does not fit in a `usize`.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        (Some(output_size), None) => output_size,
        (None, Some([scale_h, scale_w])) => {
            let [_, _, h, w] = input_dims;
            // Height is validated before width so the first offending axis is reported.
            let new_h = scale_dimension(h, scale_h, "height");
            let new_w = scale_dimension(w, scale_w, "width");
            [new_h, new_w]
        }
        _ => panic!("Either output_size or scale_factor must be provided"),
    }
}

/// Multiplies `dim` by `factor` in `f64` and truncates the product to a `usize`,
/// panicking (naming `axis`) when the factor is invalid or the result overflows.
fn scale_dimension(dim: usize, factor: f32, axis: &str) -> usize {
    // A negative or NaN factor would otherwise saturate to 0 in the cast below and
    // silently yield an empty dimension.
    if factor.is_nan() || factor < 0.0 {
        panic!(
            "Scale factor for {} must be non-negative and not NaN",
            axis
        );
    }

    let scaled = (dim as f64) * (factor as f64);

    // `usize::MAX as f64` rounds up to 2^64 on 64-bit targets, so a float comparison
    // lets a product of exactly 2^64 through and the final cast saturates silently.
    // Float-to-integer casts saturate, so comparing in `u128` is exact on every target.
    if scaled as u128 > usize::MAX as u128 {
        panic!("Scale factor for {} is too large", axis);
    }

    scaled as usize
}
impl ModuleDisplay for Interpolate2d {
    /// Display settings for this module: no new line after each attribute.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Adds the `mode`, `output_size` and `scale_factor` attributes, in that order.
    fn custom_content(&self, content: Content) -> Option<Content> {
        // `output_size` is rendered through its `Debug` representation.
        let output_size = format!("{:?}", self.output_size);

        let content = content.add("mode", &self.mode);
        let content = content.add("output_size", &output_size);
        let content = content.add("scale_factor", &self.scale_factor);
        content.optional()
    }
}
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    /// Dimensions of a single 4x4 feature map, shared by the size checks.
    const DIMS_4X4: [usize; 4] = [1, 1, 4, 4];

    #[test]
    fn test_calculate_output_size() {
        // An explicit output size is returned as-is.
        assert_eq!(calculate_output_size(DIMS_4X4, Some([2, 2]), None), [2, 2]);

        // Scale factors are applied to height and width independently.
        let cases: [([f32; 2], [usize; 2]); 3] = [
            ([2.0, 2.0], [8, 8]),
            ([0.5, 0.5], [2, 2]),
            ([2.0, 1.5], [8, 6]),
        ];
        for (scale, expected) in cases {
            assert_eq!(calculate_output_size(DIMS_4X4, None, Some(scale)), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        let _ = calculate_output_size(DIMS_4X4, None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        let dims = [1, 1, usize::MAX - 1, 4];
        let _ = calculate_output_size(dims, None, Some([2.0, 1.0]));
    }

    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        let dims = [1, 1, 4, usize::MAX - 1];
        let _ = calculate_output_size(dims, None, Some([1.0, 2.0]));
    }

    #[test]
    fn test_module() {
        let device = Default::default();
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &device,
        );

        // Explicit output size.
        let by_size = Interpolate2dConfig::new()
            .with_output_size(Some([8, 8]))
            .init();
        assert_eq!(by_size.forward(input.clone()).dims(), [2, 3, 8, 8]);

        // Scale factor.
        let by_scale = Interpolate2dConfig::new()
            .with_scale_factor(Some([0.5, 0.5]))
            .init();
        assert_eq!(by_scale.forward(input.clone()).dims(), [2, 3, 2, 2]);

        // Explicit output size with a non-default mode.
        let linear = Interpolate2dConfig::new()
            .with_output_size(Some([6, 6]))
            .with_mode(InterpolateMode::Linear)
            .init();
        assert_eq!(linear.forward(input).dims(), [2, 3, 6, 6]);
    }

    #[test]
    fn display() {
        let layer = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
             scale_factor: None}"
        );
    }
}