burn_core/nn/interpolate/
interpolate2d.rs1use alloc::format;
2
3use burn_tensor::module::interpolate;
4
5use crate as burn;
6
7use crate::config::Config;
8use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay};
9use crate::tensor::Tensor;
10use crate::tensor::backend::Backend;
11use crate::tensor::ops::InterpolateOptions;
12
13use super::InterpolateMode;
14
/// Configuration used to build an [`Interpolate2d`] module via
/// [`Interpolate2dConfig::init`].
///
/// Set either `output_size` or `scale_factor`. When neither is set, the output
/// size cannot be computed and the module's forward pass panics.
#[derive(Config, Debug)]
pub struct Interpolate2dConfig {
    /// Requested output size as `[height, width]`. Defaults to `None`.
    #[config(default = "None")]
    pub output_size: Option<[usize; 2]>,

    /// Multipliers `[height, width]` applied to the input's height and width.
    /// Defaults to `None`.
    #[config(default = "None")]
    pub scale_factor: Option<[f32; 2]>,

    /// Interpolation mode. Defaults to [`InterpolateMode::Nearest`].
    #[config(default = "InterpolateMode::Nearest")]
    pub mode: InterpolateMode,
}
36
/// 2D interpolation module.
///
/// Resizes the last two dimensions (height and width) of a rank-4 tensor, either
/// to an explicit `output_size` or by multiplying them with `scale_factor`.
/// Build it with [`Interpolate2dConfig::init`].
#[derive(Module, Clone, Debug)]
#[module(custom_display)]
pub struct Interpolate2d {
    /// Requested output size as `[height, width]`, if any.
    pub output_size: Option<[usize; 2]>,

    /// Multipliers `[height, width]` applied to the input's height and width, if any.
    pub scale_factor: Option<[f32; 2]>,

    /// Interpolation mode.
    // NOTE(review): wrapped in `Ignored`, presumably so the `Module` derive does
    // not treat it as module state — confirm against the `Ignored` docs.
    pub mode: Ignored<InterpolateMode>,
}
62
63impl Interpolate2dConfig {
64 pub fn init(self) -> Interpolate2d {
66 Interpolate2d {
67 output_size: self.output_size,
68 scale_factor: self.scale_factor,
69 mode: Ignored(self.mode),
70 }
71 }
72}
73impl Interpolate2d {
74 pub fn forward<B: Backend>(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
96 let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor);
97 interpolate(
98 input,
99 output_size,
100 InterpolateOptions::new(self.mode.0.clone().into()),
101 )
102 }
103}
104
/// Calculate the output `[height, width]` for a 2D interpolation.
///
/// # Arguments
///
/// * `input_dims` - Dimensions of the input tensor; the last two entries are
///   read as height and width.
/// * `output_size` - Explicit output `[height, width]`. When set, it is returned
///   unchanged and takes precedence over `scale_factor`.
/// * `scale_factor` - `[height, width]` multipliers applied to the input height
///   and width; only used when `output_size` is `None`.
///
/// # Returns
///
/// The `[height, width]` of the interpolated tensor. Scaled sizes are truncated
/// toward zero.
///
/// # Panics
///
/// Panics if neither `output_size` nor `scale_factor` is provided, or if a
/// scaled dimension is larger than `usize::MAX`.
fn calculate_output_size(
    input_dims: [usize; 4],
    output_size: Option<[usize; 2]>,
    scale_factor: Option<[f32; 2]>,
) -> [usize; 2] {
    match (output_size, scale_factor) {
        // An explicit size wins even when a scale factor is also set. That
        // combination used to fall through to the "must be provided" panic,
        // although both values had in fact been provided.
        (Some(output_size), _) => output_size,
        (None, Some([scale_h, scale_w])) => {
            let [_, _, h, w] = input_dims;

            // Height is evaluated before width, so a too-large height is
            // reported first when both overflow.
            [
                scale_dimension(h, scale_h, "height"),
                scale_dimension(w, scale_w, "width"),
            ]
        }
        (None, None) => panic!("Either output_size or scale_factor must be provided"),
    }
}

/// Scale a single dimension by `factor`, panicking if the result exceeds `usize::MAX`.
///
/// `axis` is only used to name the dimension in the panic message.
fn scale_dimension(dim: usize, factor: f32, axis: &str) -> usize {
    // Multiply in f64 so large dimensions keep more precision than f32 offers.
    let scaled = (dim as f64) * (factor as f64);

    if scaled > usize::MAX as f64 {
        panic!("Scale factor for {axis} is too large");
    }

    // Float-to-integer `as` casts truncate toward zero and saturate, so a
    // negative or NaN product yields 0 rather than wrapping.
    scaled as usize
}
152
153impl ModuleDisplay for Interpolate2d {
154 fn custom_settings(&self) -> Option<DisplaySettings> {
155 DisplaySettings::new()
156 .with_new_line_after_attribute(false)
157 .optional()
158 }
159
160 fn custom_content(&self, content: Content) -> Option<Content> {
161 content
162 .add("mode", &self.mode)
163 .add("output_size", &format!("{:?}", self.output_size))
164 .add("scale_factor", &self.scale_factor)
165 .optional()
166 }
167}
#[cfg(test)]
mod tests {
    use burn_tensor::Distribution;

    use crate::TestBackend;

    use super::*;

    /// Input dimensions shared by the `calculate_output_size` cases.
    const DIMS: [usize; 4] = [1, 1, 4, 4];

    #[test]
    fn test_calculate_output_size() {
        // An explicit output size is passed through unchanged.
        assert_eq!(calculate_output_size(DIMS, Some([2, 2]), None), [2, 2]);

        // Scale factors: up-scaling, down-scaling and a different factor per axis.
        let scaled_cases = [
            ([2.0, 2.0], [8, 8]),
            ([0.5, 0.5], [2, 2]),
            ([2.0, 1.5], [8, 6]),
        ];
        for (scale, expected) in scaled_cases {
            assert_eq!(calculate_output_size(DIMS, None, Some(scale)), expected);
        }
    }

    #[test]
    #[should_panic(expected = "Either output_size or scale_factor must be provided")]
    fn test_missing_params() {
        calculate_output_size(DIMS, None, None);
    }

    #[test]
    #[should_panic(expected = "Scale factor for height is too large")]
    fn test_infinite_height() {
        let dims = [1, 1, usize::MAX - 1, 4];
        calculate_output_size(dims, None, Some([2.0, 1.0]));
    }

    #[test]
    #[should_panic(expected = "Scale factor for width is too large")]
    fn test_infinite_width() {
        let dims = [1, 1, 4, usize::MAX - 1];
        calculate_output_size(dims, None, Some([1.0, 2.0]));
    }

    #[test]
    fn test_module() {
        let input = Tensor::<TestBackend, 4>::random(
            [2, 3, 4, 4],
            Distribution::Uniform(0.0, 1.0),
            &Default::default(),
        );

        // Explicit output size.
        let upsampled = Interpolate2dConfig::new()
            .with_output_size(Some([8, 8]))
            .init()
            .forward(input.clone());
        assert_eq!(upsampled.dims(), [2, 3, 8, 8]);

        // Scale factor.
        let downsampled = Interpolate2dConfig::new()
            .with_scale_factor(Some([0.5, 0.5]))
            .init()
            .forward(input.clone());
        assert_eq!(downsampled.dims(), [2, 3, 2, 2]);

        // Explicit output size with a non-default mode.
        let linear = Interpolate2dConfig::new()
            .with_output_size(Some([6, 6]))
            .with_mode(InterpolateMode::Linear)
            .init()
            .forward(input);
        assert_eq!(linear.dims(), [2, 3, 6, 6]);
    }

    #[test]
    fn display() {
        let layer = Interpolate2dConfig::new()
            .with_output_size(Some([20, 20]))
            .init();
        let rendered = alloc::format!("{layer}");

        assert_eq!(
            rendered,
            "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \
            scale_factor: None}"
        );
    }
}