burn_cubecl/kernel/interpolate/
base.rs1use crate::{
2 CubeRuntime,
3 kernel::into_contiguous,
4 ops::{numeric::empty_device_dtype, permute_nchw_to_nhwc, permute_nhwc_to_nchw},
5 tensor::CubeTensor,
6};
7use burn_backend::{
8 Shape, TensorMetadata,
9 ops::{InterpolateMode, InterpolateOptions},
10};
11
12use super::{
13 bicubic::interpolate_bicubic_launch, bilinear::interpolate_bilinear_launch,
14 lanczos3::interpolate_lanczos3_launch, nearest::interpolate_nearest_launch,
15 nearest_backward::interpolate_nearest_backward_launch,
16};
17
18pub fn interpolate<R: CubeRuntime>(
22 input: CubeTensor<R>,
23 output_size: [usize; 2],
24 options: InterpolateOptions,
25) -> CubeTensor<R> {
26 let [batch_size, channels, _, _] = input.meta.shape().dims();
27 let [out_height, out_width] = output_size;
28
29 let input = into_contiguous(permute_nchw_to_nhwc(input));
30
31 let shape_out = Shape::new([batch_size, out_height, out_width, channels]);
32 let output = empty_device_dtype(
33 input.client.clone(),
34 input.device.clone(),
35 shape_out,
36 input.dtype,
37 );
38
39 let align_corners = options.align_corners;
40 let output = match options.mode {
41 InterpolateMode::Nearest => interpolate_nearest_launch(input, output),
42 InterpolateMode::Bilinear => interpolate_bilinear_launch(input, output, align_corners),
43 InterpolateMode::Bicubic => interpolate_bicubic_launch(input, output, align_corners),
44 InterpolateMode::Lanczos3 => interpolate_lanczos3_launch(input, output, align_corners),
45 };
46
47 permute_nhwc_to_nchw(output)
48}
49
50pub fn interpolate_backward<R: CubeRuntime>(
54 input: CubeTensor<R>,
55 out_grad: CubeTensor<R>,
56 _output_size: [usize; 2],
57 options: InterpolateOptions,
58) -> CubeTensor<R> {
59 let input = permute_nchw_to_nhwc(input);
60 let out_grad = permute_nchw_to_nhwc(out_grad);
61
62 let output_shape = input.shape();
63 let output = empty_device_dtype(
64 input.client.clone(),
65 input.device.clone(),
66 output_shape,
67 input.dtype,
68 );
69
70 let output = match options.mode {
71 InterpolateMode::Nearest => interpolate_nearest_backward_launch(out_grad, output),
72 InterpolateMode::Bilinear => {
73 panic!("bilinear interpolation backward is not supported by JIT backend")
74 }
75 InterpolateMode::Bicubic => {
76 panic!("bicubic interpolation backward is not supported by JIT backend")
77 }
78 InterpolateMode::Lanczos3 => {
79 panic!("lanczos3 interpolation backward is not supported by JIT backend")
80 }
81 };
82
83 permute_nhwc_to_nchw(output)
84}