burn_core/nn/interpolate/
interpolate1d.rs1use alloc::format;
2
3use burn_tensor::module::interpolate;
4
5use crate as burn;
6
7use crate::config::Config;
8use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
9use crate::tensor::Tensor;
10use crate::tensor::backend::Backend;
11use crate::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
/// Configuration for the [`Interpolate1d`] module.
///
/// Exactly one of `output_size` or `scale_factor` should be set. The module's
/// forward pass panics when neither (or both) is provided.
#[derive(Config, Debug)]
pub struct Interpolate1dConfig {
    /// Target length of the last input dimension.
    ///
    /// Defaults to `None`.
    #[config(default = "None")]
    pub output_size: Option<usize>,

    /// Multiplier applied to the length of the last input dimension to obtain the
    /// output length. The product is truncated toward zero.
    ///
    /// Defaults to `None`.
    #[config(default = "None")]
    pub scale_factor: Option<f32>,

    /// Interpolation algorithm to use.
    ///
    /// Defaults to `InterpolateMode::Nearest`.
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,
}
36
/// Resizes the last dimension of a rank-3 tensor by interpolation.
///
/// Create it with [`Interpolate1dConfig::init`].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate1d {
    /// Explicit output length for the last dimension, if configured.
    pub output_size: Option<usize>,

    /// Factor applied to the input's last dimension length, if configured.
    pub scale_factor: Option<f32>,

    /// Interpolation algorithm. Wrapped in `Ignored`, presumably so the derive
    /// does not treat it as a learnable/record field (TODO: confirm).
    pub mode: Ignored<InterpolateMode>,
}
61
62impl Interpolate1dConfig {
63 pub fn init(self) -> Interpolate1d {
65 Interpolate1d {
66 output_size: self.output_size,
67 scale_factor: self.scale_factor,
68 mode: Ignored(self.mode),
69 }
70 }
71}
72
73impl Interpolate1d {
74 pub fn forward<B: Backend>(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
96 let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
97
98 let input = input.unsqueeze_dim(2);
101
102 let result = interpolate(
103 input,
104 [1, output_size],
105 InterpolateOptions::new(self.mode.0.clone().into()),
106 );
107
108 result.squeeze_dims(&[2])
109 }
110}
111
/// Resolves the output length of the last (interpolated) dimension.
///
/// Exactly one of `output_size` or `scale_factor` must be set:
/// - `output_size` is returned unchanged.
/// - `scale_factor` is multiplied by the last entry of `input_dims`, and the
///   product is truncated toward zero.
///
/// # Panics
///
/// - If neither `output_size` nor `scale_factor` is provided.
/// - If both are provided.
/// - If `scale_factor` is zero, negative, or NaN.
/// - If the scaled length does not fit in a `usize`.
fn calculate_output_size(
    input_dims: [usize; 3],
    output_size: Option<usize>,
    scale_factor: Option<f32>,
) -> usize {
    match (output_size, scale_factor) {
        (Some(output_size), None) => output_size,
        (None, Some(scale_factor)) => {
            // Without this guard, negative and NaN factors would silently
            // saturate to 0 in the `as usize` cast below.
            if scale_factor.is_nan() || scale_factor <= 0.0 {
                panic!("Scale factor must be a positive number, got {scale_factor}");
            }

            let [_, _, length] = input_dims;

            // Widen to f64 so large lengths keep full precision. NOTE: the factor
            // arrives as f32, so decimal values such as 0.7 are slightly below
            // their nominal value and can truncate one element short.
            let new_dim = (length as f64) * (scale_factor as f64);

            // On 64-bit targets `usize::MAX as f64` rounds up to 2^64, which a
            // `usize` cannot hold. Use `>=` so that value is rejected instead of
            // saturating to `usize::MAX`.
            if new_dim >= usize::MAX as f64 {
                panic!("Scale factor is too large");
            }

            new_dim as usize
        }
        (Some(_), Some(_)) => {
            panic!("Only one of output_size and scale_factor may be provided, not both")
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}
153
154impl ModuleDisplay for Interpolate1d {
155 fn custom_settings(&self) -> Option<DisplaySettings> {
156 DisplaySettings::new()
157 .with_new_line_after_attribute(false)
158 .optional()
159 }
160
161 fn custom_content(&self, content: Content) -> Option<Content> {
162 content
163 .add("mode", &self.mode)
164 .add("output_size", &format!("{:?}", self.output_size))
165 .add("scale_factor", &self.scale_factor)
166 .optional()
167 }
168}
169
#[cfg(test)]
mod tests {
    use burn_tensor::Distribution;

    use super::*;
    use crate::TestBackend;

    /// Explicit sizes pass through; scale factors multiply the last dimension.
    #[test]
    fn test_calculate_output_size() {
        let dims = [1, 1, 4];
        let cases = [
            (Some(2), None, 2),
            (None, Some(2.0), 8),
            (None, Some(0.5), 2),
            (None, Some(1.5), 6),
        ];

        for (size, factor, expected) in cases {
            assert_eq!(calculate_output_size(dims, size, factor), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_panic() {
        calculate_output_size([1, 1, 4], None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor is too large")]
    fn test_large_scale_factor() {
        // (usize::MAX - 1) * 2 cannot be represented as a usize.
        calculate_output_size([1, 1, usize::MAX - 1], None, Some(2.0));
    }

    /// Each configuration must resize only the last dimension of a `[2, 3, 4]` input.
    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 3>::random(
            [2, 3, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        let cases = [
            (Interpolate1dConfig::new().with_output_size(Some(8)), [2, 3, 8]),
            (Interpolate1dConfig::new().with_scale_factor(Some(0.5)), [2, 3, 2]),
            (
                Interpolate1dConfig::new()
                    .with_output_size(Some(6))
                    .with_mode(InterpolateMode::Linear),
                [2, 3, 6],
            ),
        ];

        for (config, expected_dims) in cases {
            let output = config.init().forward(input.clone());
            assert_eq!(output.dims(), expected_dims);
        }
    }

    #[test]
    fn display() {
        let layer = Interpolate1dConfig::new().with_output_size(Some(20)).init();

        assert_eq!(
            alloc::format!("{layer}"),
            "Interpolate1d {mode: Nearest, output_size: Some(20), \
             scale_factor: None}"
        );
    }
}