Skip to main content

burn_nn/modules/interpolate/
interpolate2d.rs

1use alloc::format;
2
3use burn::tensor::module::interpolate;
4
5use burn_core as burn;
6
7use burn::config::Config;
8use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11use burn::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
/// Configuration for the 2D interpolation module.
///
/// This struct defines the configuration options for the 2D interpolation operation.
/// It allows specifying the output size, scale factor, interpolation mode and corner
/// alignment. At least one of `output_size` or `scale_factor` must be set; otherwise
/// the resulting module panics in its forward pass when computing the output size.
#[derive(Config, Debug)]
pub struct Interpolate2dConfig {
    /// Output size `[H', W']` of the interpolated tensor.
    /// If specified, this takes precedence over `scale_factor`.
    #[config(default = "None")]
    pub output_size: Option<[usize; 2]>,

    /// Scale factor `[scale_h, scale_w]` for resizing the input tensor.
    /// This is used when `output_size` is not specified. Each scaled dimension is
    /// truncated toward zero (e.g. `4 * 1.5 = 6`, `3 * 0.5 = 1`).
    #[config(default = "None")]
    pub scale_factor: Option<[f32; 2]>,

    /// Interpolation mode to use for resizing (defaults to nearest-neighbour).
    /// Determines how the output values are calculated.
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,

    /// If `true` (the default), the input and output tensors are aligned by their
    /// corner pixels. If `false`, half-pixel coordinate mapping is used instead.
    #[config(default = true)]
    pub align_corners: bool,
}
41
/// Interpolate module for resizing tensors with shape [N, C, H, W].
///
/// This struct represents an interpolation module that can resize tensors
/// using various interpolation methods. It provides flexibility in specifying
/// either an output size or a scale factor for resizing, along with options
/// for the interpolation mode and corner alignment.
///
/// The module can be used to upsample or downsample tensors, preserving the
/// number of channels and batch size while adjusting the height and width
/// dimensions.
///
/// The module can be created using the [Interpolate2dConfig] struct and the
/// `init` method, which returns an instance of the [Interpolate2d] struct.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate2d {
    /// Explicit output size `[H', W']` of the interpolated tensor, if set.
    pub output_size: Option<[usize; 2]>,

    /// Scale factor `[scale_h, scale_w]` applied to the input's H and W, if set.
    pub scale_factor: Option<[f32; 2]>,

    /// Interpolation mode used for resizing.
    pub mode: InterpolateMode,

    /// Whether to align corner pixels (`true`) or use half-pixel mapping (`false`).
    pub align_corners: bool,
}
70
71impl Interpolate2dConfig {
72    /// Initialize the interpolation module
73    pub fn init(self) -> Interpolate2d {
74        Interpolate2d {
75            output_size: self.output_size,
76            scale_factor: self.scale_factor,
77            mode: self.mode,
78            align_corners: self.align_corners,
79        }
80    }
81}
82impl Interpolate2d {
83    /// Performs the forward pass of the interpolation module
84    ///
85    /// # Arguments
86    ///
87    /// * `input` - Input tensor with shape [N, C, H, W]
88    ///
89    /// # Returns
90    ///
91    /// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by
92    /// the output_size or scale_factor specified in the module configuration
93    ///
94    /// # Example
95    ///
96    /// ```ignore
97    /// let input = Tensor::<Backend, 2>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);
98    /// let interpolate = Interpolate2dConfig::new()
99    ///     .with_output_size(Some([128, 128]))
100    ///     .init();
101    /// let output = interpolate.forward(input);
102    /// assert_eq!(output.dims(), [1, 3, 128, 128]);
103    /// ```
104    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
105        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
106        interpolate(
107            input,
108            output_size,
109            InterpolateOptions::new(self.mode.clone().into())
110                .with_align_corners(self.align_corners),
111        )
112    }
113}
114
/// Calculates the output size for tensor interpolation.
///
/// # Arguments
///
/// * `input_dims` - The dimensions of the input tensor [N, C, H, W].
/// * `output_size` - Optional desired output size [H', W']. When present it takes
///   precedence over `scale_factor`, as documented on the configuration struct.
/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w].
///
/// # Returns
///
/// An array [H', W'] with the calculated output size. When derived from
/// `scale_factor`, each dimension is `dim * scale` truncated toward zero.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if a scaled dimension does not fit in a `usize`.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        // An explicit output size wins, even if a scale factor is also configured.
        (Some(output_size), _) => output_size,
        (None, Some([scale_h, scale_w])) => {
            let [_, _, h, w] = input_dims;
            // Height is evaluated (and may panic) before width.
            [
                scale_dimension(h, scale_h, "Scale factor for height is too large"),
                scale_dimension(w, scale_w, "Scale factor for width is too large"),
            ]
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}

/// Scales one spatial dimension by `scale`, truncating the result toward zero.
///
/// Panics with `overflow_msg` if the scaled value is too large for a `usize`.
fn scale_dimension(dim: usize, scale: f32, overflow_msg: &str) -> usize {
    let scaled = dim as f64 * f64::from(scale);

    // Smallest value that does NOT fit in a usize, i.e. 2^usize::BITS. On 64-bit targets
    // `usize::MAX as f64` already rounds up to 2^64, so a `>` test against it would let
    // exactly 2^64 through and `as usize` would silently saturate to usize::MAX.
    let limit = usize::MAX as f64 + 1.0;
    if scaled >= limit {
        panic!("{overflow_msg}");
    }

    scaled as usize
}
162
163impl ModuleDisplay for Interpolate2d {
164    fn custom_settings(&self) -> Option<DisplaySettings> {
165        DisplaySettings::new()
166            .with_new_line_after_attribute(false)
167            .optional()
168    }
169
170    fn custom_content(&self, content: Content) -> Option<Content> {
171        content
172            .add_debug_attribute("mode", &self.mode)
173            .add("output_size", &format!("{:?}", self.output_size))
174            .add("scale_factor", &self.scale_factor)
175            .optional()
176    }
177}
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    /// Input shape shared by the small output-size checks.
    const SMALL_DIMS: [usize; 4] = [1, 1, 4, 4];

    #[test]
    fn test_calculate_output_size() {
        // (output_size, scale_factor, expected [H', W'])
        let cases: [(Option<[usize; 2]>, Option<[f32; 2]>, [usize; 2]); 4] = [
            (Some([2, 2]), None, [2, 2]),
            (None, Some([2.0, 2.0]), [8, 8]),
            (None, Some([0.5, 0.5]), [2, 2]),
            (None, Some([2.0, 1.5]), [8, 6]),
        ];

        for (size, scale, expected) in cases {
            assert_eq!(calculate_output_size(SMALL_DIMS, size, scale), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        calculate_output_size(SMALL_DIMS, None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        let huge_height = [1, 1, usize::MAX - 1, 4];
        calculate_output_size(huge_height, None, Some([2.0, 1.0]));
    }

    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        let huge_width = [1, 1, 4, usize::MAX - 1];
        calculate_output_size(huge_width, None, Some([1.0, 2.0]));
    }

    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // (configuration, expected output dims)
        let cases = [
            // Explicit output size.
            (
                Interpolate2dConfig::new().with_output_size(Some([8, 8])),
                [2, 3, 8, 8],
            ),
            // Scale factor.
            (
                Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5])),
                [2, 3, 2, 2],
            ),
            // Non-default interpolation mode.
            (
                Interpolate2dConfig::new()
                    .with_output_size(Some([6, 6]))
                    .with_mode(InterpolateMode::Linear),
                [2, 3, 6, 6],
            ),
        ];

        for (config, expected_dims) in cases {
            let resized = config.init().forward(input.clone());
            assert_eq!(resized.dims(), expected_dims);
        }
    }

    #[test]
    fn display() {
        let layer = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
            scale_factor: None}"
        );
    }
}