burn_core/nn/interpolate/
interpolate1d.rs1use alloc::format;
2
3use burn_tensor::module::interpolate;
4
5use crate as burn;
6
7use crate::config::Config;
8use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
9use crate::tensor::Tensor;
10use crate::tensor::backend::Backend;
11use crate::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
/// Configuration for the [`Interpolate1d`] module.
///
/// Exactly one of `output_size` or `scale_factor` is expected to be set. `init` does not
/// check this; [`Interpolate1d::forward`] panics when the output length cannot be resolved.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
    /// Target length of the last axis. Defaults to `None`.
    #[config(default = "None")]
    pub output_size: Option<usize>,

    /// Multiplier applied to the input's last-axis length to derive the output length.
    /// Defaults to `None`.
    #[config(default = "None")]
    pub scale_factor: Option<f32>,

    /// Interpolation algorithm. Defaults to `InterpolateMode::Nearest`.
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,
}
36
/// Module that resizes the last axis of a rank-3 tensor: `[N, C, L] -> [N, C, L_out]`.
///
/// Built with [`Interpolate1dConfig::init`]; the resizing itself happens in
/// [`Interpolate1d::forward`]. Display output is customised via its `ModuleDisplay` impl.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
    /// Fixed output length of the last axis, if set.
    pub output_size: Option<usize>,

    /// Factor applied to the input's last-axis length to derive the output length, if set.
    pub scale_factor: Option<f32>,

    /// Interpolation algorithm, wrapped in `Ignored` (presumably so it is not treated as
    /// part of the module's learnable/recorded state — TODO confirm).
    pub mode: Ignored<InterpolateMode>,
}
61
62impl Interpolate1dConfig {
63    pub fn init(self) -> Interpolate1d {
65        Interpolate1d {
66            output_size: self.output_size,
67            scale_factor: self.scale_factor,
68            mode: Ignored(self.mode),
69        }
70    }
71}
72
73impl Interpolate1d {
74    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
96        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
97
98        let input = input.unsqueeze_dim(2);
101
102        let result = interpolate(
103            input,
104            [1, output_size],
105            InterpolateOptions::new(self.mode.0.clone().into()),
106        );
107
108        result.squeeze_dims(&[2])
109    }
110}
111
/// Resolves the target length of the last axis for `Interpolate1d::forward`.
///
/// Exactly one of `output_size` or `scale_factor` must be set:
/// - `output_size` is returned unchanged;
/// - `scale_factor` multiplies the input's last dimension (computed in `f64`), and the
///   product is truncated toward zero, e.g. `4 * 1.5 -> 6` and `5 * 0.5 -> 2`.
///
/// # Panics
///
/// - if neither, or both, of `output_size` and `scale_factor` are provided;
/// - if `scale_factor` is NaN, zero, or negative;
/// - if the scaled length exceeds `usize::MAX` (compared as `f64`).
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        (Some(output_size), None) => output_size,
        (None, Some(scale_factor)) => {
            // Without this check, NaN and non-positive factors reach the `as usize` cast
            // below, which silently maps them to a length of 0.
            if scale_factor.is_nan() || scale_factor <= 0.0 {
                panic!("Scale factor must be positive, got {scale_factor}");
            }

            let [_, _, length] = input_dims;
            let new_dim = (length as f64) * f64::from(scale_factor);

            // Float-to-int `as` casts saturate rather than fail, so overflow is caught here.
            if new_dim > usize::MAX as f64 {
                panic!("Scale factor is too large");
            }

            new_dim as usize
        }
        (Some(_), Some(_)) => {
            panic!("Either output_size or scale_factor must be provided, but not both")
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
153
154impl ModuleDisplay for Interpolate1d {
155    fn custom_settings(&self) -> Option<DisplaySettings> {
156        DisplaySettings::new()
157            .with_new_line_after_attribute(false)
158            .optional()
159    }
160
161    fn custom_content(&self, content: Content) -> Option<Content> {
162        content
163            .add("mode", &self.mode)
164            .add("output_size", &format!("{:?}", self.output_size))
165            .add("scale_factor", &self.scale_factor)
166            .optional()
167    }
168}
169
#[cfg(test)]
mod tests {

    use burn_tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    /// Covers both resolution paths, including a scale that does not divide evenly
    /// (the scaled length is truncated).
    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4];

        let output_size = calculate_output_size(input_dims, Some(2), None);
        assert_eq!(output_size, 2);

        let output_size = calculate_output_size(input_dims, None, Some(2.0));
        assert_eq!(output_size, 8);

        let output_size = calculate_output_size(input_dims, None, Some(0.5));
        assert_eq!(output_size, 2);

        let output_size = calculate_output_size(input_dims, None, Some(1.5));
        assert_eq!(output_size, 6);

        // 5 * 0.5 = 2.5, truncated to 2.
        let output_size = calculate_output_size([1, 1, 5], None, Some(0.5));
        assert_eq!(output_size, 2);
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        let input_dims = [1, 1, 4];
        calculate_output_size(input_dims, None, None);
    }

    /// Setting both options is ambiguous and must also be rejected.
    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic_both_provided() {
        let input_dims = [1, 1, 4];
        calculate_output_size(input_dims, Some(2), Some(2.0));
    }

    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        let input_dims = [1, 1, usize::MAX - 1];
        calculate_output_size(input_dims, None, Some(2.0));
    }

    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        let config = Interpolate1dConfig::new().with_output_size(Some(8));
        let interpolate = config.init();
        let output = interpolate.forward(input.clone());
        assert_eq!(output.dims(), [2, 3, 8]);

        let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5));
        let interpolate = config.init();
        let output = interpolate.forward(input.clone());
        assert_eq!(output.dims(), [2, 3, 2]);

        let config = Interpolate1dConfig::new()
            .with_output_size(Some(6))
            .with_mode(InterpolateMode::Linear);
        let interpolate = config.init();
        let output = interpolate.forward(input);
        assert_eq!(output.dims(), [2, 3, 6]);
    }

    #[test]
    fn display() {
        let config = Interpolate1dConfig::new().with_output_size(Some(20));
        let layer = config.init();

        assert_eq!(
            alloc::format!("{layer}"),
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}