cubecl_std/tensor/contiguous/perpendicular.rs

use crate::tensor::{TensorHandle, into_contiguous_ref};
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel,
    tensor_line_size_perpendicular,
};
use std::cmp::min;

/// Kernel for converting a non-contiguous tensor into a contiguous one when
/// the vectorization axis is perpendicular to the last dimension.
///
/// This kernel handles the case where memory is laid out such that the
/// unit-stride axis is not the last dimension, requiring a
/// "gather-and-transpose" pattern to write out contiguous lines.
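///
/// For example, with `line_size = 2` the kernel loads two perpendicular lines
/// `[a, b]` and `[c, d]`, transposes them in registers into `[a, c]` and
/// `[b, d]`, then writes each as a contiguous line along the output's last
/// dimension.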
#[cube(launch_unchecked)]
fn into_contiguous_perpendicular<N: Numeric>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    axis_vectorized: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = input.line_size();
    let last_axis = input.rank() - 1;

    // Calculate how many vectorized lines fit into the last dimension's shape.
    let num_batch = output.shape(last_axis) / line_size;

    // Local registers to perform a small in-register transpose.
    let mut accumulators = Sequence::<Line<N>>::new();

    #[unroll]
    for _ in 0..line_size {
        accumulators.push(Line::empty(line_size));
    }

    let channel_input_stride_elem = input.stride(last_axis);
    let channel_output_stride_elem = output.stride(axis_vectorized);

    // Strides expressed in lines rather than elements (divide by line_size).
    let channel_input_stride = channel_input_stride_elem / line_size;
    let channel_output_stride = channel_output_stride_elem / line_size;

    // Total parallel units needed to cover the output space.
    let num_runs = output.len() / (num_batch * line_size);

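    // Units beyond `num_runs` have no work to do, so they exit early.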
    if ABSOLUTE_POS >= num_runs {
        terminate!()
    }

    // Map the global unit ID to its starting line index in the output,
    // leaving room for the `line_size` lines written along the vectorized axis.
    let batch_index = ABSOLUTE_POS * num_batch;
    let skip_interval = batch_index / channel_output_stride;
    let skip_index = batch_index % channel_output_stride;
    let skip_size = channel_output_stride_elem;
    let global_index = (skip_interval * skip_size) + skip_index;
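    // Illustrative numbers: with num_batch = 4, line_size = 4, and
    // channel_output_stride = 8 (so skip_size = 32), unit 3 gets
    // batch_index = 12, hence skip_interval = 1, skip_index = 4, and
    // global_index = 1 * 32 + 4 = 36.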

    for b in 0..num_batch {
        let offset_output = global_index + b;

        // Calculate the physical offset in the input tensor for the current
        // output coordinate.
        let mut batch_offset = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(offset_output * line_size, axis);
            batch_offset += coordinate * input.stride(axis);
        }
        let batch_offset = batch_offset / line_size;

        // --- STEP 1: GATHER ---
        // Load data from the input tensor. Each load is vectorized along the
        // input's stride-1 axis, and successive loads step along its last
        // axis to fill the accumulators.
        for i in 0..line_size {
            let index = batch_offset + i * channel_input_stride;
            let batched = input[index];

            // --- STEP 2: TRANSPOSE ---
            // Rearrange the loaded vector components into the accumulators.
            #[unroll]
            for o in 0..line_size {
                let line = accumulators.index_mut(o);
                line[i] = batched[o];
            }
        }

        // --- STEP 3: STORE ---
        // Write the transposed lines to the output in a contiguous fashion.
        #[unroll]
        for o in 0..line_size {
            let index_out = offset_output + o * channel_output_stride;
            let batched = accumulators[o];

            output[index_out] = batched;
        }
    }
}

/// Launches the perpendicular contiguous kernel.
///
/// Use this when the input tensor's stride-1 (vectorized) axis is not the
/// last dimension. It optimizes the copy by using hardware vectorization
/// (`Line`s) and an in-register transpose.
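///
/// # Example
///
/// A minimal sketch (not compiled here); `client`, `input`, and `dtype` are
/// assumed to be an existing compute client, a strided `TensorHandleRef`, and
/// its storage type for some runtime `R`:
///
/// ```ignore
/// let contiguous = launch_into_contiguous_perpendicular::<R>(&client, &input, dtype)?;
/// ```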
pub fn launch_into_contiguous_perpendicular<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    // Fallback for 1D tensors where perpendicularity doesn't apply.
    if input.shape.len() <= 1 {
        return into_contiguous_ref(client, input, dtype);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec(), dtype);
    launch_into_contiguous_perpendicular_ref(client, input, &output.as_ref(), dtype)?;

    Ok(output)
}

/// Launches the perpendicular contiguous kernel, writing into a
/// caller-provided output.
///
/// Same as [`launch_into_contiguous_perpendicular`], but stores the result in
/// `output` instead of allocating a new handle.
pub fn launch_into_contiguous_perpendicular_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
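    // Find the input's unit-stride axis; the kernel vectorizes its output
    // stores along this axis.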
    let mut axis = 0;

    for (i, stride) in input.strides.iter().enumerate() {
        if *stride == 1 {
            axis = i;
            break;
        }
    }
    let rank = output.shape.len();

    let line_size_perpendicular = tensor_line_size_perpendicular(
        client.io_optimized_line_sizes(&dtype),
        input.shape,
        input.strides,
        rank - 1,
    );
    let line_size_parallel = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        rank - 1,
    );
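    // The chosen line size must be valid both for loads perpendicular to the
    // input's last axis and for stores parallel to the output's last axis, so
    // take the smaller of the two.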
    let line_size = min(line_size_perpendicular, line_size_parallel);

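    // One working unit per tile of `line_size` rows of the last dimension
    // (each unit writes `line_size * output.shape[rank - 1]` elements).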
    let num_elems = output.shape.iter().product::<usize>();
    let working_units = num_elems / (line_size as usize * output.shape[rank - 1]);
    let cube_dim = CubeDim::new(client, working_units);
    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);

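    // Safety: units with `ABSOLUTE_POS >= num_runs` terminate before touching
    // memory, so the unchecked launch stays within the allocated handles.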
    unsafe {
        into_contiguous_perpendicular::launch_unchecked::<R>(
            client,
            cube_count,
            cube_dim,
            input.as_tensor_arg(line_size),
            output.as_tensor_arg(line_size),
            ScalarArg::new(axis),
            dtype,
        )?;
    }

    Ok(())
}