Skip to main content

burn_nn/modules/interpolate/
interpolate1d.rs

1use alloc::format;
2
3use burn::tensor::module::interpolate;
4
5use burn_core as burn;
6
7use burn::config::Config;
8use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11use burn::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
/// Configuration for the 1D interpolation module.
///
/// This struct defines the configuration options for the 1D interpolation operation.
/// It allows specifying the output size, scale factor, interpolation mode and corner
/// alignment. Call `init` to build an [Interpolate1d] module from it.
///
/// By default neither `output_size` nor `scale_factor` is set. One of them must be
/// supplied (e.g. with `with_output_size` or `with_scale_factor`); otherwise the
/// built module's `forward` panics when it computes the target length.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
    /// Output size of the interpolated tensor.
    /// If specified, this takes precedence over `scale_factor`.
    /// Defaults to `None`.
    #[config(default = "None")]
    pub output_size: Option<usize>,

    /// Scale factor for resizing the input tensor.
    /// This is used when `output_size` is not specified: the input length is
    /// multiplied by this factor and the result is truncated to an integer.
    /// Defaults to `None`.
    #[config(default = "None")]
    pub scale_factor: Option<f32>,

    /// Interpolation mode to use for resizing.
    /// Determines how the output values are calculated.
    /// Defaults to [InterpolateMode::Nearest].
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,

    /// If `true`, the input and output tensors are aligned by their corner pixels.
    /// If `false`, half-pixel coordinate mapping is used instead.
    /// Defaults to `true`.
    #[config(default = true)]
    pub align_corners: bool,
}
41
/// Interpolate module for resizing 1D tensors with shape [N, C, L].
///
/// This struct represents a 1D interpolation module that can resize tensors
/// using various interpolation methods. It provides flexibility in specifying
/// either an output size or a scale factor for resizing, along with options
/// for the interpolation mode and corner alignment.
///
/// The module can be used to upsample or downsample 1D tensors, preserving the
/// number of channels and batch size while adjusting the length dimension.
///
/// All fields are plain configuration values (no tensors or parameters); they are
/// copied verbatim from the config by `init`.
///
/// The module can be created using the [Interpolate1dConfig] struct and the
/// `init` method, which returns an instance of the [Interpolate1d] struct.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
    /// Explicit target length for the last dimension, if any.
    pub output_size: Option<usize>,

    /// Multiplier applied to the input length to derive the target length, if any.
    pub scale_factor: Option<f32>,

    /// Interpolation mode used for resizing; converted with `.into()` into the
    /// tensor-level mode passed to `InterpolateOptions` in `forward`.
    pub mode: InterpolateMode,

    /// Whether to align corner pixels; forwarded to
    /// `InterpolateOptions::with_align_corners` in `forward`.
    pub align_corners: bool,
}
69
70impl Interpolate1dConfig {
71    /// Initialize the interpolation module
72    pub fn init(self) -> Interpolate1d {
73        Interpolate1d {
74            output_size: self.output_size,
75            scale_factor: self.scale_factor,
76            mode: self.mode,
77            align_corners: self.align_corners,
78        }
79    }
80}
81
82impl Interpolate1d {
83    /// Performs the forward pass of the 1D interpolation module
84    ///
85    /// # Arguments
86    ///
87    /// * `input` - Input tensor with shape [N, C, L]
88    ///
89    /// # Returns
90    ///
91    /// Resized tensor with shape [N, C, L'], where L' is determined by
92    /// the output_size or scale_factor specified in the module configuration
93    ///
94    /// # Example
95    ///
96    /// ```ignore
97    /// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);
98    /// let interpolate = Interpolate1dConfig::new()
99    ///     .with_output_size(Some(128))
100    ///     .init();
101    /// let output = interpolate.forward(input);
102    /// assert_eq!(output.dims(), [1, 3, 128]);
103    /// ```
104    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
105        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
106
107        // Use the interpolate operation to resize the temporal input tensor
108        // by adding a new dimension for the interpolation axis
109        let input = input.unsqueeze_dim(2);
110
111        let result = interpolate(
112            input,
113            [1, output_size],
114            InterpolateOptions::new(self.mode.clone().into())
115                .with_align_corners(self.align_corners),
116        );
117
118        result.squeeze_dims(&[2])
119    }
120}
121
/// Computes the target length `L'` of the interpolated dimension.
///
/// # Arguments
///
/// * `input_dims` - Dimensions `[N, C, L]` of the input tensor; only `L` is read.
/// * `output_size` - Explicit target length. When set, it takes precedence over
///   `scale_factor`, as documented on `Interpolate1dConfig::output_size`.
/// * `scale_factor` - Multiplier applied to `L` when `output_size` is `None`; the
///   product is truncated toward zero.
///
/// # Returns
///
/// The length of the interpolated dimension.
///
/// # Panics
///
/// - If neither `output_size` nor `scale_factor` is provided.
/// - If `scale_factor` is NaN, zero or negative.
/// - If the scaled length does not fit in a `usize` (the scale factor is too large).
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        // An explicit size wins, even when a scale factor is also configured.
        (Some(output_size), _) => output_size,
        (None, Some(scale_factor)) => {
            // Without this check, `as usize` below would silently map NaN and
            // negative products to 0.
            if scale_factor.is_nan() || scale_factor <= 0.0 {
                panic!("Scale factor must be positive, got {scale_factor}");
            }

            let [_, _, l] = input_dims;
            let new_dim = (l as f64) * (scale_factor as f64);

            // 2^BITS, the first value that does not fit in a usize. On 64-bit targets
            // `usize::MAX as f64` already rounds up to 2^64 (and `+ 1.0` is absorbed),
            // so the comparison must be inclusive to reject it.
            let limit = usize::MAX as f64 + 1.0;
            if new_dim >= limit {
                panic!("Scale factor is too large");
            }

            new_dim as usize
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
163
164impl ModuleDisplay for Interpolate1d {
165    fn custom_settings(&self) -> Option<DisplaySettings> {
166        DisplaySettings::new()
167            .with_new_line_after_attribute(false)
168            .optional()
169    }
170
171    fn custom_content(&self, content: Content) -> Option<Content> {
172        content
173            .add_debug_attribute("mode", &self.mode)
174            .add("output_size", &format!("{:?}", self.output_size))
175            .add("scale_factor", &self.scale_factor)
176            .optional()
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn::tensor::Distribution;

    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4];
        // (output_size, scale_factor, expected length)
        let cases = [
            (Some(2), None, 2),
            (None, Some(2.0), 8),
            (None, Some(0.5), 2),
            (None, Some(1.5), 6),
        ];

        for (output_size, scale_factor, expected) in cases {
            assert_eq!(
                calculate_output_size(input_dims, output_size, scale_factor),
                expected
            );
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        calculate_output_size([1, 1, 4], None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        calculate_output_size([1, 1, usize::MAX - 1], None, Some(2.0));
    }

    #[test]
    fn test_module() {
        let device = Default::default();
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &device,
        );

        // Each config is paired with the output shape it must produce for `input`:
        // explicit size, scale factor, and explicit size with a non-default mode.
        let scenarios = [
            (Interpolate1dConfig::new().with_output_size(Some(8)), [2, 3, 8]),
            (Interpolate1dConfig::new().with_scale_factor(Some(0.5)), [2, 3, 2]),
            (
                Interpolate1dConfig::new()
                    .with_output_size(Some(6))
                    .with_mode(InterpolateMode::Linear),
                [2, 3, 6],
            ),
        ];

        for (config, expected_dims) in scenarios {
            let output = config.init().forward(input.clone());
            assert_eq!(output.dims(), expected_dims);
        }
    }

    #[test]
    fn display() {
        let layer = Interpolate1dConfig::new().with_output_size(Some(20)).init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}