use alloc::format;
use burn::tensor::module::interpolate;
use burn_core as burn;
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
use burn::tensor::ops::InterpolateOptions;
use super::InterpolateMode;
/// Configuration for the [`Interpolate1d`] module.
///
/// Set exactly one of `output_size` or `scale_factor`. If neither is set, or
/// both are, [`Interpolate1d::forward`] panics when it computes the output length.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
    /// Explicit target length of the output's last dimension. Defaults to `None`.
    #[config(default = "None")]
    pub output_size: Option<usize>,
    /// Multiplier applied to the input's last-dimension length; the product is
    /// truncated to an integer to get the output length. Defaults to `None`.
    #[config(default = "None")]
    pub scale_factor: Option<f32>,
    /// Interpolation algorithm. Defaults to `InterpolateMode::Nearest`.
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,
}
/// Module that resamples the last dimension of a rank-3 tensor.
///
/// The first two dimensions are left unchanged; presumably the layout is
/// `[batch, channels, length]` (TODO confirm against callers). Build it with
/// [`Interpolate1dConfig::init`].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
    /// Explicit target length of the output's last dimension.
    pub output_size: Option<usize>,
    /// Multiplier applied to the input's last-dimension length.
    pub scale_factor: Option<f32>,
    /// Interpolation algorithm. Wrapped in `Ignored`, presumably so it is not
    /// treated as part of the module's learnable/recorded state (TODO confirm).
    pub mode: Ignored<InterpolateMode>,
}
impl Interpolate1dConfig {
    /// Consumes this configuration and builds the corresponding [`Interpolate1d`].
    ///
    /// No validation happens here; an invalid combination of `output_size` and
    /// `scale_factor` is only detected when `forward` is called.
    pub fn init(self) -> Interpolate1d {
        let Self {
            output_size,
            scale_factor,
            mode,
        } = self;

        Interpolate1d {
            output_size,
            scale_factor,
            mode: Ignored(mode),
        }
    }
}
impl Interpolate1d {
    /// Resamples the last dimension of `input` to the configured length.
    ///
    /// The first two dimensions of `input` are preserved; only the last one is
    /// resized, to either `output_size` or `length * scale_factor`.
    ///
    /// # Panics
    ///
    /// Panics if the output length cannot be determined from `output_size` /
    /// `scale_factor` (see `calculate_output_size`).
    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let target_len = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
        let options = InterpolateOptions::new(self.mode.0.clone().into());

        // Insert a singleton axis at position 2 so the 2-D `interpolate` can be reused
        // with a fixed height of 1, then drop that axis again.
        interpolate(input.unsqueeze_dim(2), [1, target_len], options).squeeze_dims(&[2])
    }
}
/// Computes the target length of the last dimension for `Interpolate1d::forward`.
///
/// # Arguments
///
/// * `input_dims` - Shape of the rank-3 input; only the last entry (the length) is used.
/// * `output_size` - Explicit target length, returned as-is.
/// * `scale_factor` - Multiplier applied to the input length; the product is truncated
///   toward zero (equivalent to flooring, since the factor must be positive).
///
/// # Panics
///
/// * If neither `output_size` nor `scale_factor` is provided.
/// * If both are provided (the two settings are mutually exclusive).
/// * If `scale_factor` is NaN, zero, or negative.
/// * If the scaled length does not fit in a `usize`.
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        (Some(output_size), None) => output_size,
        (None, Some(scale_factor)) => {
            // A float-to-int `as` cast maps NaN and negative values to 0, which would
            // silently produce an empty output. Reject them explicitly instead.
            if scale_factor.is_nan() || scale_factor <= 0.0 {
                panic!("Scale factor must be a positive number");
            }

            let [_, _, length] = input_dims;
            let new_dim = (length as f64) * (scale_factor as f64);

            // `usize::MAX + 1` is the smallest length that cannot be represented. The
            // bound is exact on 32-bit targets. On 64-bit targets both `usize::MAX as f64`
            // and the sum round to 2^64. Without `>=`, a value equal to the bound would
            // saturate silently in the cast below.
            if new_dim >= usize::MAX as f64 + 1.0 {
                panic!("Scale factor is too large");
            }

            // Truncation toward zero == floor for the positive values accepted above.
            new_dim as usize
        }
        (Some(_), Some(_)) => {
            panic!("Only one of output_size or scale_factor can be provided, not both")
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
impl ModuleDisplay for Interpolate1d {
    /// Renders all attributes on a single line (no newline after each attribute).
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Lists `mode`, `output_size` (rendered with `Debug`), and `scale_factor`, in that order.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let output_size_repr = format!("{:?}", self.output_size);

        let content = content.add("mode", &self.mode);
        let content = content.add("output_size", &output_size_repr);
        let content = content.add("scale_factor", &self.scale_factor);

        content.optional()
    }
}
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4];
        // (output_size, scale_factor, expected output length)
        let cases: [(Option<usize>, Option<f32>, usize); 4] = [
            (Some(2), None, 2),
            (None, Some(2.0), 8),
            (None, Some(0.5), 2),
            (None, Some(1.5), 6),
        ];

        for (output_size, scale_factor, expected) in cases {
            assert_eq!(
                calculate_output_size(input_dims, output_size, scale_factor),
                expected,
                "output_size={output_size:?}, scale_factor={scale_factor:?}"
            );
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        let input_dims = [1, 1, 4];
        calculate_output_size(input_dims, None, None);
    }

    /// Supplying both settings is ambiguous and must be rejected.
    #[test]
    #[should_panic]
    fn test_both_provided_panics() {
        let input_dims = [1, 1, 4];
        calculate_output_size(input_dims, Some(2), Some(2.0));
    }

    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        let input_dims = [1, 1, usize::MAX - 1];
        calculate_output_size(input_dims, None, Some(2.0));
    }

    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // (config, expected output shape)
        let cases = [
            (Interpolate1dConfig::new().with_output_size(Some(8)), [2, 3, 8]),
            (Interpolate1dConfig::new().with_scale_factor(Some(0.5)), [2, 3, 2]),
            (
                Interpolate1dConfig::new()
                    .with_output_size(Some(6))
                    .with_mode(InterpolateMode::Linear),
                [2, 3, 6],
            ),
        ];

        for (config, expected) in cases {
            let interpolate = config.init();
            let output = interpolate.forward(input.clone());
            assert_eq!(output.dims(), expected);
        }
    }

    #[test]
    fn display() {
        let config = Interpolate1dConfig::new().with_output_size(Some(20));
        let layer = config.init();

        assert_eq!(
            alloc::format!("{layer}"),
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}