burn_jit/kernel/index/
flip.rs1use crate::{
2 element::JitElement, ops::numeric::empty_device, tensor::JitTensor, BoolElement, JitRuntime,
3};
4use cubecl::{calculate_cube_count_elemwise, prelude::*};
5
6#[cube(launch_unchecked)]
7fn flip_kernel<E: CubePrimitive, Bool: Int>(
8 input: &Tensor<E>,
9 output: &mut Tensor<E>,
10 indices: Sequence<Bool>,
11 #[comptime] rank: u32,
12) {
13 if ABSOLUTE_POS >= output.len() {
14 return;
15 }
16
17 let mut offset_input = 0;
18
19 #[unroll]
20 for i in 0..rank {
21 let stride = input.stride(i);
22 let shape = output.shape(i);
23 let flip = *indices.index(i) == Bool::from_int(1);
24 let mut offset_local = ABSOLUTE_POS / stride % shape;
25
26 if flip {
27 offset_local = shape - offset_local - 1;
28 }
29
30 offset_input += offset_local * stride;
31 }
32
33 output[ABSOLUTE_POS] = input[offset_input];
34}
35
36pub(crate) fn flip<R: JitRuntime, E: JitElement, BT: BoolElement>(
37 tensor: JitTensor<R>,
38 indices: &[usize],
39) -> JitTensor<R> {
40 let output = empty_device::<R, E>(
41 tensor.client.clone(),
42 tensor.device.clone(),
43 tensor.shape.clone(),
44 );
45 flip_on_output::<R, E, BT>(tensor, output, indices)
46}
47
48pub(crate) fn flip_on_output<R: JitRuntime, E: JitElement, BT: BoolElement>(
49 tensor: JitTensor<R>,
50 output: JitTensor<R>,
51 indices: &[usize],
52) -> JitTensor<R> {
53 let ndims = tensor.shape.num_dims();
54 let mut indices_sequence = SequenceArg::<'_, R, BT>::new();
55
56 for i in 0..ndims {
57 indices_sequence.push(ScalarArg::new(BT::new_bool(indices.contains(&i))));
58 }
59
60 let cube_dim = CubeDim::default();
61 let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
62
63 unsafe {
64 flip_kernel::launch_unchecked::<E, BT, R>(
65 &tensor.client,
66 cube_count,
67 cube_dim,
68 tensor.as_tensor_arg::<E>(1),
69 output.as_tensor_arg::<E>(1),
70 indices_sequence,
71 ndims as u32,
72 );
73 }
74
75 output
76}