burn_nn/modules/interpolate/
interpolate2d.rs1use alloc::format;
2
3use burn::tensor::module::interpolate;
4
5use burn_core as burn;
6
7use burn::config::Config;
8use burn::module::{Content, DisplaySettings, Module, ModuleDisplay};
9use burn::tensor::Tensor;
10use burn::tensor::backend::Backend;
11use burn::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
15#[derive(Config, Debug)]
20pub struct Interpolate2dConfig {
21 #[config(default = "None")]
24 pub output_size: Option<[usize; 2]>,
25
26 #[config(default = "None")]
29 pub scale_factor: Option<[f32; 2]>,
30
31 #[config(default = "InterpolateMode::Nearest")]
34 pub mode: InterpolateMode,
35
36 #[config(default = true)]
39 pub align_corners: bool,
40}
41
42#[derive(Module, Clone, Debug)]
56#[module(custom_display)]
57pub struct Interpolate2d {
58 pub output_size: Option<[usize; 2]>,
60
61 pub scale_factor: Option<[f32; 2]>,
63
64 pub mode: InterpolateMode,
66
67 pub align_corners: bool,
69}
70
71impl Interpolate2dConfig {
72 pub fn init(self) -> Interpolate2d {
74 Interpolate2d {
75 output_size: self.output_size,
76 scale_factor: self.scale_factor,
77 mode: self.mode,
78 align_corners: self.align_corners,
79 }
80 }
81}
82impl Interpolate2d {
83 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
105 let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
106 interpolate(
107 input,
108 output_size,
109 InterpolateOptions::new(self.mode.clone().into())
110 .with_align_corners(self.align_corners),
111 )
112 }
113}
114
/// Computes the spatial output size `[height, width]` of a 2-D interpolation.
///
/// Exactly one of `output_size` or `scale_factor` must be provided:
/// - `output_size` is returned unchanged;
/// - `scale_factor` is multiplied with the last two entries of `input_dims`
///   (treated as height and width), and each product is truncated toward zero.
///
/// # Panics
///
/// - if neither, or both, of `output_size` and `scale_factor` are provided;
/// - if a scale factor is negative or NaN;
/// - if a scaled dimension cannot be represented as a `usize`.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        (Some(output_size), None) => output_size,
        (None, Some([scale_h, scale_w])) => {
            let [_, _, h, w] = input_dims;

            // `as usize` would silently map negative/NaN products to 0 (an empty
            // output), so reject such factors up front. NaN fails `>=` as well.
            assert!(scale_h >= 0.0, "Scale factor for height must be a non-negative number");
            let new_h = scaled_dim(h, scale_h).expect("Scale factor for height is too large");

            assert!(scale_w >= 0.0, "Scale factor for width must be a non-negative number");
            let new_w = scaled_dim(w, scale_w).expect("Scale factor for width is too large");

            [new_h, new_w]
        }
        (Some(_), Some(_)) => {
            panic!("Only one of output_size or scale_factor must be provided, not both")
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}

/// Multiplies `dim` by `scale` and truncates toward zero, returning `None` when
/// the product cannot be represented as a `usize`.
fn scaled_dim(dim: usize, scale: f32) -> Option<usize> {
    let scaled = dim as f64 * f64::from(scale);
    // First value that does NOT fit: 2^usize::BITS. `usize::MAX as f64` already
    // rounds up to 2^64 on 64-bit targets (the `+ 1.0` is absorbed), while on
    // 32-bit targets the addition is exact.
    let limit = usize::MAX as f64 + 1.0;
    // A NaN product (e.g. 0 * inf) fails `<` and is reported as not representable.
    (scaled < limit).then_some(scaled as usize)
}
162
163impl ModuleDisplay for Interpolate2d {
164 fn custom_settings(&self) -> Option<DisplaySettings> {
165 DisplaySettings::new()
166 .with_new_line_after_attribute(false)
167 .optional()
168 }
169
170 fn custom_content(&self, content: Content) -> Option<Content> {
171 content
172 .add_debug_attribute("mode", &self.mode)
173 .add("output_size", &format!("{:?}", self.output_size))
174 .add("scale_factor", &self.scale_factor)
175 .optional()
176 }
177}
#[cfg(test)]
mod tests {
    use burn::tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    #[test]
    fn test_calculate_output_size() {
        let input_dims = [1, 1, 4, 4];

        // An explicit output size is returned as-is.
        assert_eq!(calculate_output_size(input_dims, Some([2, 2]), None), [2, 2]);

        // Scale factors multiply the trailing [height, width] dimensions.
        let scaled_cases: [([f32; 2], [usize; 2]); 3] = [
            ([2.0, 2.0], [8, 8]),
            ([0.5, 0.5], [2, 2]),
            ([2.0, 1.5], [8, 6]),
        ];
        for (scale, expected) in scaled_cases {
            assert_eq!(calculate_output_size(input_dims, None, Some(scale)), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        calculate_output_size([1, 1, 4, 4], None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        calculate_output_size([1, 1, usize::MAX - 1, 4], None, Some([2.0, 1.0]));
    }

    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        calculate_output_size([1, 1, 4, usize::MAX - 1], None, Some([1.0, 2.0]));
    }

    #[test]
    fn test_module() {
        let device = Default::default();
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &device,
        );

        let cases = [
            (Interpolate2dConfig::new().with_output_size(Some([8, 8])), [2, 3, 8, 8]),
            (Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5])), [2, 3, 2, 2]),
            (
                Interpolate2dConfig::new()
                    .with_output_size(Some([6, 6]))
                    .with_mode(InterpolateMode::Linear),
                [2, 3, 6, 6],
            ),
        ];

        for (config, expected_dims) in cases {
            let output = config.init().forward(input.clone());
            assert_eq!(output.dims(), expected_dims);
        }
    }

    #[test]
    fn display() {
        let layer = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();

        let expected = "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
                        scale_factor: None}";
        assert_eq!(alloc::format!("{layer}"), expected);
    }
}