burn_core/nn/interpolate/
interpolate1d.rs

1use alloc::format;
2
3use burn_tensor::module::interpolate;
4
5use crate as burn;
6
7use crate::config::Config;
8use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
9use crate::tensor::Tensor;
10use crate::tensor::backend::Backend;
11use crate::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
15/// Configuration for the 1D interpolation module.
16///
17/// This struct defines the configuration options for the 1D interpolation operation.
18/// It allows specifying the output size, scale factor, and interpolation mode.
19#[derive(Config, Debug)]
20pub struct Interpolate1dConfig {
21    /// Output size of the interpolated tensor.
22    /// If specified, this takes precedence over `scale_factor`.
23    #[config(default = "None")]
24    pub output_size: Option<usize>,
25
26    /// Scale factor for resizing the input tensor.
27    /// This is used when `output_size` is not specified.
28    #[config(default = "None")]
29    pub scale_factor: Option<f32>,
30
31    /// Interpolation mode to use for resizing.
32    /// Determines how the output values are calculated.
33    #[config(default = "InterpolateMode::Nearest")]
34    pub mode: InterpolateMode,
35}
36
37/// Interpolate module for resizing 1D tensors with shape [N, C, L].
38///
39/// This struct represents a 1D interpolation module that can resize tensors
40/// using various interpolation methods. It provides flexibility in specifying
41/// either an output size or a scale factor for resizing, along with options
42/// for the interpolation mode.
43///
44/// The module can be used to upsample or downsample 1D tensors, preserving the
45/// number of channels and batch size while adjusting the length dimension.
46///
47/// The module can be created using the [Interpolate1dConfig] struct and the
48/// `init` method, which returns an instance of the [Interpolate1d] struct.
49#[derive(Module, Clone, Debug)]
50#[module(custom_display)]
51pub struct Interpolate1d {
52    /// Output size of the interpolated tensor
53    pub output_size: Option<usize>,
54
55    /// Scale factor for resizing the input tensor
56    pub scale_factor: Option<f32>,
57
58    /// Interpolation mode used for resizing
59    pub mode: Ignored<InterpolateMode>,
60}
61
62impl Interpolate1dConfig {
63    /// Initialize the interpolation module
64    pub fn init(self) -> Interpolate1d {
65        Interpolate1d {
66            output_size: self.output_size,
67            scale_factor: self.scale_factor,
68            mode: Ignored(self.mode),
69        }
70    }
71}
72
73impl Interpolate1d {
74    /// Performs the forward pass of the 1D interpolation module
75    ///
76    /// # Arguments
77    ///
78    /// * `input` - Input tensor with shape [N, C, L]
79    ///
80    /// # Returns
81    ///
82    /// Resized tensor with shape [N, C, L'], where L' is determined by
83    /// the output_size or scale_factor specified in the module configuration
84    ///
85    /// # Example
86    ///
87    /// ```ignore
88    /// let input = Tensor::<Backend, 3>::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device);
89    /// let interpolate = Interpolate1dConfig::new()
90    ///     .with_output_size(Some(128))
91    ///     .init();
92    /// let output = interpolate.forward(input);
93    /// assert_eq!(output.dims(), [1, 3, 128]);
94    /// ```
95    pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
96        let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
97
98        // Use the interpolate operation to resize the temporal input tensor
99        // by adding a new dimension for the interpolation axis
100        let input = input.unsqueeze_dim(2);
101
102        let result = interpolate(
103            input,
104            [1, output_size],
105            InterpolateOptions::new(self.mode.0.clone().into()),
106        );
107
108        result.squeeze_dims(&[2])
109    }
110}
111
/// Resolves the length of the interpolated tensor.
///
/// An explicit `output_size` always wins: when it is set, it is returned as-is
/// and `scale_factor` is ignored. Otherwise the input length (the last entry of
/// `input_dims`) is multiplied by `scale_factor` and truncated toward zero.
///
/// # Arguments
///
/// * `input_dims` - Dimensions of the input tensor, `[N, C, L]`
/// * `output_size` - Explicit output length, if any
/// * `scale_factor` - Multiplier applied to the input length, if any
///
/// # Returns
///
/// Output length for the interpolated tensor
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided,
/// or if the scaled length exceeds `usize::MAX`
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        // An explicit size takes precedence, even if a scale factor is also set.
        (Some(output_size), _) => output_size,
        (None, Some(scale_factor)) => {
            let [_, _, l] = input_dims;

            // Multiply in f64 so the overflow check below sees the full product.
            let new_dim = (l as f64) * f64::from(scale_factor);

            if new_dim > usize::MAX as f64 {
                panic!("Scale factor is too large");
            }

            // Float-to-int `as` truncates toward zero and saturates, so a
            // negative or NaN product yields 0 here rather than wrapping.
            new_dim as usize
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
153
154impl ModuleDisplay for Interpolate1d {
155    fn custom_settings(&self) -> Option<DisplaySettings> {
156        DisplaySettings::new()
157            .with_new_line_after_attribute(false)
158            .optional()
159    }
160
161    fn custom_content(&self, content: Content) -> Option<Content> {
162        content
163            .add("mode", &self.mode)
164            .add("output_size", &format!("{:?}", self.output_size))
165            .add("scale_factor", &self.scale_factor)
166            .optional()
167    }
168}
169
#[cfg(test)]
mod tests {

    use burn_tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    /// An explicit size is returned as-is; scale factors multiply the input length.
    #[test]
    fn test_calculate_output_size() {
        let dims = [1, 1, 4];

        assert_eq!(calculate_output_size(dims, Some(2), None), 2);

        // (scale factor, expected output length) for an input of length 4.
        for (scale, expected) in [(2.0, 8), (0.5, 2), (1.5, 6)] {
            assert_eq!(calculate_output_size(dims, None, Some(scale)), expected);
        }
    }

    /// With neither an output size nor a scale factor there is nothing to compute.
    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        calculate_output_size([1, 1, 4], None, None);
    }

    /// Doubling a length close to `usize::MAX` must be rejected.
    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        let dims = [1, 1, usize::MAX - 1];
        calculate_output_size(dims, None, Some(2.0));
    }

    /// The forward pass only changes the last dimension of the input.
    #[test]
    fn test_module() {
        let device = Default::default();
        let input =
            Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Uniform(0.0, 1.0), &device);

        // Explicit output size
        let by_size = Interpolate1dConfig::new().with_output_size(Some(8)).init();
        assert_eq!(by_size.forward(input.clone()).dims(), [2, 3, 8]);

        // Scale factor
        let by_scale = Interpolate1dConfig::new()
            .with_scale_factor(Some(0.5))
            .init();
        assert_eq!(by_scale.forward(input.clone()).dims(), [2, 3, 2]);

        // Non-default interpolation mode
        let linear = Interpolate1dConfig::new()
            .with_output_size(Some(6))
            .with_mode(InterpolateMode::Linear)
            .init();
        assert_eq!(linear.forward(input).dims(), [2, 3, 6]);
    }

    /// The custom display prints all attributes on a single line.
    #[test]
    fn display() {
        let layer = Interpolate1dConfig::new()
            .with_output_size(Some(20))
            .init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}