burn_jit/kernel/index/
flip.rs

1use crate::{
2    element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime,
3};
4use cubecl::{calculate_cube_count_elemwise, prelude::*};
5
6#[cube(launch_unchecked)]
7fn flip_kernel<E: CubePrimitive, Bool: Int>(
8    input: &Tensor<E>,
9    output: &mut Tensor<E>,
10    indices: Sequence<Bool>,
11    #[comptime] rank: u32,
12) {
13    if ABSOLUTE_POS >= output.len() {
14        return;
15    }
16
17    let mut offset_input = 0;
18
19    #[unroll]
20    for i in 0..rank {
21        let stride = input.stride(i);
22        let shape = output.shape(i);
23        let flip = *indices.index(i) == Bool::from_int(1);
24        let mut offset_local = ABSOLUTE_POS / stride % shape;
25
26        if flip {
27            offset_local = shape - offset_local - 1;
28        }
29
30        offset_input += offset_local * stride;
31    }
32
33    output[ABSOLUTE_POS] = input[offset_input];
34}
35
36pub(crate) fn flip<R: JitRuntime, E: JitElement, BT: BoolElement>(
37    tensor: JitTensor<R>,
38    indices: &[usize],
39) -> JitTensor<R> {
40    let output = empty_device::<R, E>(
41        tensor.client.clone(),
42        tensor.device.clone(),
43        tensor.shape.clone(),
44    );
45    flip_on_output::<R, E, BT>(tensor, output, indices)
46}
47
48pub(crate) fn flip_on_output<R: JitRuntime, E: JitElement, BT: BoolElement>(
49    tensor: JitTensor<R>,
50    output: JitTensor<R>,
51    indices: &[usize],
52) -> JitTensor<R> {
53    let ndims = tensor.shape.num_dims();
54    let mut indices_sequence = SequenceArg::<'_, R, BT>::new();
55
56    for i in 0..ndims {
57        indices_sequence.push(ScalarArg::new(BT::new_bool(indices.contains(&i))));
58    }
59
60    let cube_dim = CubeDim::default();
61    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
62
63    unsafe {
64        flip_kernel::launch_unchecked::<E, BT, R>(
65            &tensor.client,
66            cube_count,
67            cube_dim,
68            tensor.as_tensor_arg::<E>(1),
69            output.as_tensor_arg::<E>(1),
70            indices_sequence,
71            ndims as u32,
72        );
73    }
74
75    output
76}