Skip to main content

burn_cubecl/kernel/interpolate/
base.rs

1use crate::{
2    CubeRuntime,
3    kernel::into_contiguous,
4    ops::{numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
5    tensor::CubeTensor,
6};
7use burn_backend::{
8    Shape, TensorMetadata,
9    ops::{InterpolateMode, InterpolateOptions},
10};
11
12use super::{
13    bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch,
14    lanczos3::interpolate_lanczos3_launch, nearest::interpolate_nearest_launch,
15    nearest_backward::interpolate_nearest_backward_launch,
16};
17
18/// Interpolate operation
19///
20/// Supports nearest, bilinear, bicubic and lanczos3 modes
21pub fn interpolate<R: CubeRuntime>(
22    input: CubeTensor<R>,
23    output_size: [usize; 2],
24    options: InterpolateOptions,
25) -> CubeTensor<R> {
26    let [batch_size, channels, _, _] = input.meta.shape().dims();
27    let [out_height, out_width] = output_size;
28
29    let input = into_contiguous(permute_nchw_to_nhwc(input));
30
31    let shape_out = Shape::new([batch_size, out_height, out_width, channels]);
32    let output = empty_device_dtype(
33        input.client.clone(),
34        input.device.clone(),
35        shape_out,
36        input.dtype,
37    );
38
39    let align_corners = options.align_corners;
40    let output = match options.mode {
41        InterpolateMode::Nearest => interpolate_nearest_launch(input, output),
42        InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output, align_corners),
43        InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output, align_corners),
44        InterpolateMode::Lanczos3 => interpolate_lanczos3_launch(input, output, align_corners),
45    };
46
47    permute_nhwc_to_nchw(output)
48}
49
50/// Backward interpolate operation
51///
52/// Note: only nearest mode is supported
53pub fn interpolate_backward<R: CubeRuntime>(
54    input: CubeTensor<R>,
55    out_grad: CubeTensor<R>,
56    _output_size: [usize; 2],
57    options: InterpolateOptions,
58) -> CubeTensor<R> {
59    let input = permute_nchw_to_nhwc(input);
60    let out_grad = permute_nchw_to_nhwc(out_grad);
61
62    let output_shape = input.shape();
63    let output = empty_device_dtype(
64        input.client.clone(),
65        input.device.clone(),
66        output_shape,
67        input.dtype,
68    );
69
70    let output = match options.mode {
71        InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output),
72        InterpolateMode::Bilinear => {
73            panic!("bilinear interpolation backward is not supported by JIT backend")
74        }
75        InterpolateMode::Bicubic => {
76            panic!("bicubic interpolation backward is not supported by JIT backend")
77        }
78        InterpolateMode::Lanczos3 => {
79            panic!("lanczos3 interpolation backward is not supported by JIT backend")
80        }
81    };
82
83    permute_nhwc_to_nchw(output)
84}