burn_nn/modules/interpolate/
interpolate1d.rs1use alloc::format;
2
3use burn::tensor::module::interpolate;
4
5use burn_core as burn;
6
7use burn::config::Config;
8use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11use burn::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
/// Configuration for the 1D interpolation module, [`Interpolate1d`].
///
/// The output length is chosen either directly through `output_size` or relative to the input
/// length through `scale_factor`. Both default to `None`, so one of them must be set. If neither
/// or both are `Some` when the module's `forward` runs, it panics.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
    /// Explicit length of the output's last axis. Defaults to `None`.
    #[config(default = "None")]
    pub output_size: Option<usize>,

    /// Factor multiplied with the input's last-axis length to obtain the output length.
    /// Defaults to `None`.
    #[config(default = "None")]
    pub scale_factor: Option<f32>,

    /// Interpolation algorithm. Defaults to `InterpolateMode::Nearest`.
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,

    /// Forwarded unchanged to `InterpolateOptions::with_align_corners`. Defaults to `true`.
    #[config(default = true)]
    pub align_corners: bool,
}
41
/// Resizes the last axis of a rank-3 tensor by interpolation.
///
/// Build it with [`Interpolate1dConfig::init`]. Every field is a plain setting (no tensors) copied
/// from the config. Display output is customised through the `ModuleDisplay` impl.
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
    /// Explicit output length. Used only when `scale_factor` is `None`; setting both (or
    /// neither) makes `forward` panic.
    pub output_size: Option<usize>,

    /// Multiplier for the input length. Used only when `output_size` is `None`.
    pub scale_factor: Option<f32>,

    /// Interpolation algorithm; converted into the backend's mode in `forward`.
    pub mode: InterpolateMode,

    /// Forwarded to `InterpolateOptions::with_align_corners` in `forward`.
    pub align_corners: bool,
}
69
70impl Interpolate1dConfig {
71 pub fn init(self) -> Interpolate1d {
73 Interpolate1d {
74 output_size: self.output_size,
75 scale_factor: self.scale_factor,
76 mode: self.mode,
77 align_corners: self.align_corners,
78 }
79 }
80}
81
82impl Interpolate1d {
83 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
105 let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
106
107 let input = input.unsqueeze_dim(2);
110
111 let result = interpolate(
112 input,
113 [1, output_size],
114 InterpolateOptions::new(self.mode.clone().into())
115 .with_align_corners(self.align_corners),
116 );
117
118 result.squeeze_dims(&[2])
119 }
120}
121
/// Resolves the output length of the last axis of a rank-3 input.
///
/// Exactly one of `output_size` or `scale_factor` must be provided:
/// - `output_size` is returned unchanged;
/// - `scale_factor` is multiplied with the input's last dimension. The product is truncated
///   toward zero, which is a floor because the product is positive.
///
/// # Panics
///
/// - If neither, or both, of `output_size` and `scale_factor` are provided.
/// - If `scale_factor` is NaN, zero or negative.
/// - If the scaled length does not fit in a `usize`.
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        (Some(output_size), None) => output_size,
        (None, Some(scale_factor)) => {
            // Without this guard a negative or NaN factor would silently become a zero-length
            // output through the saturating float-to-int `as` cast below.
            if scale_factor.is_nan() || scale_factor <= 0.0 {
                panic!("Scale factor must be positive, got {scale_factor}");
            }

            let [_, _, length] = input_dims;
            let new_dim = (length as f64) * (scale_factor as f64);

            // On 64-bit targets `usize::MAX as f64` rounds up to 2^64, which itself does not
            // fit in a `usize`, so the boundary value must be rejected too.
            if new_dim >= usize::MAX as f64 {
                panic!("Scale factor is too large");
            }

            new_dim as usize
        }
        (Some(_), Some(_)) => {
            panic!("Either output_size or scale_factor must be provided, not both")
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
163
164impl ModuleDisplay for Interpolate1d {
165 fn custom_settings(&self) -> Option<DisplaySettings> {
166 DisplaySettings::new()
167 .with_new_line_after_attribute(false)
168 .optional()
169 }
170
171 fn custom_content(&self, content: Content) -> Option<Content> {
172 content
173 .add_debug_attribute("mode", &self.mode)
174 .add("output_size", &format!("{:?}", self.output_size))
175 .add("scale_factor", &self.scale_factor)
176 .optional()
177 }
178}
179
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4];

        let output_size = calculate_output_size(input_dims, Some(2), None);
        assert_eq!(output_size, 2);

        let output_size = calculate_output_size(input_dims, None, Some(2.0));
        assert_eq!(output_size, 8);

        let output_size = calculate_output_size(input_dims, None, Some(0.5));
        assert_eq!(output_size, 2);

        let output_size = calculate_output_size(input_dims, None, Some(1.5));
        assert_eq!(output_size, 6);

        // A non-integral scaled length (5 * 0.5 = 2.5) is truncated toward zero.
        let output_size = calculate_output_size([1, 1, 5], None, Some(0.5));
        assert_eq!(output_size, 2);
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        let input_dims = [1, 1, 4];
        calculate_output_size(input_dims, None, None);
    }

    // Providing both sizing options is ambiguous and must be rejected, not resolved by
    // silently preferring one of them.
    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic_when_both_provided() {
        let input_dims = [1, 1, 4];
        calculate_output_size(input_dims, Some(2), Some(2.0));
    }

    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        let input_dims = [1, 1, usize::MAX - 1];
        calculate_output_size(input_dims, None, Some(2.0));
    }

    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        let config = Interpolate1dConfig::new().with_output_size(Some(8));
        let interpolate = config.init();
        let output = interpolate.forward(input.clone());
        assert_eq!(output.dims(), [2, 3, 8]);

        let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5));
        let interpolate = config.init();
        let output = interpolate.forward(input.clone());
        assert_eq!(output.dims(), [2, 3, 2]);

        let config = Interpolate1dConfig::new()
            .with_output_size(Some(6))
            .with_mode(InterpolateMode::Linear);
        let interpolate = config.init();
        let output = interpolate.forward(input);
        assert_eq!(output.dims(), [2, 3, 6]);
    }

    #[test]
    fn display() {
        let config = Interpolate1dConfig::new().with_output_size(Some(20));
        let layer = config.init();

        assert_eq!(
            alloc::format!("{layer}"),
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
            scale_factor: None}"
        );
    }
}