// burn_core/nn/interpolate/interpolate2d.rs

1use alloc::format;
2
3use burn_tensor::module::interpolate;
4
5use crate as burn;
6
7use crate::config::Config;
8use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
9use crate::tensor::Tensor;
10use crate::tensor::backend::Backend;
11use crate::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
15/// Configuration for the 2D interpolation module.
16///
17/// This struct defines the configuration options for the 2D interpolation operation.
18/// It allows specifying the output size, scale factor, and interpolation mode.
19#[derive(Config, Debug)]
20pub struct Interpolate2dConfig {
21    /// Output size of the interpolated tensor.
22    /// If specified, this takes precedence over `scale_factor`.
23    #[config(default = "None")]
24    pub output_size: Option<[usize; 2]>,
25
26    /// Scale factor for resizing the input tensor.
27    /// This is used when `output_size` is not specified.
28    #[config(default = "None")]
29    pub scale_factor: Option<[f32; 2]>,
30
31    /// Interpolation mode to use for resizing.
32    /// Determines how the output values are calculated.
33    #[config(default = "InterpolateMode::Nearest")]
34    pub mode: InterpolateMode,
35}
36
37/// Interpolate module for resizing tensors with shape [N, C, H, W].
38///
39/// This struct represents an interpolation module that can resize tensors
40/// using various interpolation methods. It provides flexibility in specifying
41/// either an output size or a scale factor for resizing, along with options
42/// for the interpolation mode.
43///
44/// The module can be used to upsample or downsample tensors, preserving the
45/// number of channels and batch size while adjusting the height and width
46/// dimensions.
47///
48/// The module can be created using the [Interpolate2dConfig] struct and the
49/// `init` method, which returns an instance of the [Interpolate2d] struct.
50#[derive(Module, Clone, Debug)]
51#[module(custom_display)]
52pub struct Interpolate2d {
53    /// Output size of the interpolated tensor
54    pub output_size: Option<[usize; 2]>,
55
56    /// Scale factor for resizing the input tensor
57    pub scale_factor: Option<[f32; 2]>,
58
59    /// Interpolation mode used for resizing
60    pub mode: Ignored<InterpolateMode>,
61}
62
63impl Interpolate2dConfig {
64    /// Initialize the interpolation module
65    pub fn init(self) -> Interpolate2d {
66        Interpolate2d {
67            output_size: self.output_size,
68            scale_factor: self.scale_factor,
69            mode: Ignored(self.mode),
70        }
71    }
72}
73impl Interpolate2d {
74    /// Performs the forward pass of the interpolation module
75    ///
76    /// # Arguments
77    ///
78    /// * `input` - Input tensor with shape [N, C, H, W]
79    ///
80    /// # Returns
81    ///
82    /// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by
83    /// the output_size or scale_factor specified in the module configuration
84    ///
85    /// # Example
86    ///
87    /// ```ignore
88    /// let input = Tensor::<Backend, 2>::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device);
89    /// let interpolate = Interpolate2dConfig::new()
90    ///     .with_output_size(Some([128, 128]))
91    ///     .init();
92    /// let output = interpolate.forward(input);
93    /// assert_eq!(output.dims(), [1, 3, 128, 128]);
94    /// ```
95    pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
96        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
97        interpolate(
98            input,
99            output_size,
100            InterpolateOptions::new(self.mode.0.clone().into()),
101        )
102    }
103}
104
/// Calculates the spatial output size `[H', W']` for tensor interpolation.
///
/// # Arguments
///
/// * `input_dims` - The dimensions of the input tensor `[N, C, H, W]`.
/// * `output_size` - Optional explicit output size `[H', W']`. When present it takes
///   precedence over `scale_factor`, as documented on `Interpolate2dConfig`.
/// * `scale_factor` - Optional scale factor for height and width `[scale_h, scale_w]`.
///   Each scaled dimension is truncated toward zero (e.g. `5 * 0.5` yields `2`).
///
/// # Returns
///
/// An array `[H', W']` holding the output height and width.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided, or if a scaled
/// dimension does not fit in `usize`.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        // An explicit size wins, even if a scale factor is also configured.
        (Some(output_size), _) => output_size,
        (None, Some([scale_h, scale_w])) => {
            let [_, _, h, w] = input_dims;
            [
                scale_dimension(h, scale_h, "height"),
                scale_dimension(w, scale_w, "width"),
            ]
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}

/// Scales one spatial dimension by `factor`, truncating toward zero.
///
/// `axis` is only used in the panic message. Negative or NaN results saturate to
/// `0` through the final `as` cast.
///
/// # Panics
///
/// Panics if the scaled value does not fit in `usize`.
fn scale_dimension(dim: usize, factor: f32, axis: &str) -> usize {
    let scaled = (dim as f64) * (factor as f64);

    // `usize::MAX as f64 + 1.0` equals 2^BITS on both 32- and 64-bit targets (on
    // 64-bit the cast itself already rounds up to 2^64). Anything at or above it
    // cannot be represented, and `as usize` would otherwise saturate silently.
    let limit = usize::MAX as f64 + 1.0;
    if scaled >= limit {
        panic!("Scale factor for {} is too large", axis);
    }

    scaled as usize
}
152
153impl ModuleDisplay for Interpolate2d {
154    fn custom_settings(&self) -> Option<DisplaySettings> {
155        DisplaySettings::new()
156            .with_new_line_after_attribute(false)
157            .optional()
158    }
159
160    fn custom_content(&self, content: Content) -> Option<Content> {
161        content
162            .add("mode", &self.mode)
163            .add("output_size", &format!("{:?}", self.output_size))
164            .add("scale_factor", &self.scale_factor)
165            .optional()
166    }
167}
#[cfg(test)]
mod tests {
    use burn_tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4, 4];

        // (explicit output size, scale factor, expected [H', W'])
        let cases: [(Option<[usize; 2]>, Option<[f32; 2]>, [usize; 2]); 4] = [
            (Some([2, 2]), None, [2, 2]),
            (None, Some([2.0, 2.0]), [8, 8]),
            (None, Some([0.5, 0.5]), [2, 2]),
            (None, Some([2.0, 1.5]), [8, 6]),
        ];

        for (output_size, scale_factor, expected) in cases {
            let actual = calculate_output_size(input_dims, output_size, scale_factor);
            assert_eq!(actual, expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        let input_dims = [1, 1, 4, 4];
        calculate_output_size(input_dims, None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        let input_dims = [1, 1, usize::MAX - 1, 4];
        calculate_output_size(input_dims, None, Some([2.0, 1.0]));
    }

    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        let input_dims = [1, 1, 4, usize::MAX - 1];
        calculate_output_size(input_dims, None, Some([1.0, 2.0]));
    }

    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // Each case: a configuration and the output shape it must produce.
        let cases = [
            // explicit output size
            (
                Interpolate2dConfig::new().with_output_size(Some([8, 8])),
                [2, 3, 8, 8],
            ),
            // scale factor
            (
                Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5])),
                [2, 3, 2, 2],
            ),
            // non-default interpolation mode
            (
                Interpolate2dConfig::new()
                    .with_output_size(Some([6, 6]))
                    .with_mode(InterpolateMode::Linear),
                [2, 3, 6, 6],
            ),
        ];

        for (config, expected_dims) in cases {
            let output = config.init().forward(input.clone());
            assert_eq!(output.dims(), expected_dims);
        }
    }

    #[test]
    fn display() {
        let layer = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();

        let rendered = alloc::format!("{layer}");
        assert_eq!(
            rendered,
            "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
            scale_factor: None}"
        );
    }
}