cubecl_std/tensor/contiguous/perpendicular.rs

use crate::tensor::{TensorHandle, into_contiguous_ref};
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel,
    tensor_line_size_perpendicular,
};
use std::cmp::min;

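/// Copies a strided `input` into the contiguous `output` when the input's
/// contiguous axis is perpendicular to the output's last (vectorized) axis.
/// Each unit gathers `line_size` input lines (vectorized along the input's
/// contiguous axis), transposes the block in registers, then stores
/// `line_size` lines vectorized along the output's last axis.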
#[cube(launch_unchecked)]
fn into_contiguous_perpendicular<N: Numeric>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    axis_vectorized: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = input.line_size();
    let last_axis = input.rank() - 1;

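    // Number of lines in one row along the output's last axis; each unit
    // walks one whole row.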
    let num_batch = output.shape(last_axis) / line_size;

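    // One accumulator per output line: together they hold the
    // `line_size` x `line_size` block being transposed in registers.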
    let mut accumulators = Sequence::<Line<N>>::new();

    #[unroll]
    for _ in 0..line_size {
        accumulators.push(Line::empty(line_size));
    }

    let channel_input_stride_elem = input.stride(last_axis);
    let channel_output_stride_elem = output.stride(axis_vectorized);

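    // The same strides expressed in lines rather than elements.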
    let channel_input_stride = channel_input_stride_elem / line_size;
    let channel_output_stride = channel_output_stride_elem / line_size;

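    // Each run covers `num_batch * line_size` output lines; extra units exit
    // early.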
    let num_runs = output.len() / (num_batch * line_size);

    if ABSOLUTE_POS >= num_runs {
        terminate!()
    }

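    // Map the flat unit index to an output line index, skipping the rows that
    // the write-back loop fills itself via `o * channel_output_stride`.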
    let batch_index = ABSOLUTE_POS * num_batch;
    let skip_interval = batch_index / channel_output_stride;
    let skip_index = batch_index % channel_output_stride;
    let skip_size = channel_output_stride_elem;
    let global_index = (skip_interval * skip_size) + skip_index;

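    // Each iteration gathers, transposes, and writes one
    // `line_size` x `line_size` block of the row.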
    for b in 0..num_batch {
        let offset_output = global_index + b;

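        // Translate the output line index back to an input offset through
        // the (possibly permuted) input strides.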
        let mut batch_offset = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(offset_output * line_size, axis);
            batch_offset += coordinate * input.stride(axis);
        }
        let batch_offset = batch_offset / line_size;

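        // Gather `line_size` input lines and scatter their elements into the
        // accumulators, transposing the block in registers.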
        for i in 0..line_size {
            let index = batch_offset + i * channel_input_stride;
            let batched = input[index];

            #[unroll]
            for o in 0..line_size {
                let line = accumulators.index_mut(o);
                line[i] = batched[o];
            }
        }

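        // Write the transposed lines back, stepping along the output's
        // vectorized axis.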
        #[unroll]
        for o in 0..line_size {
            let index_out = offset_output + o * channel_output_stride;
            let batched = accumulators[o];

            output[index_out] = batched;
        }
    }
}

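/// Makes `input` contiguous by launching [`into_contiguous_perpendicular`],
/// allocating the output handle. Tensors of rank 0 or 1 fall back to
/// [`into_contiguous_ref`], since there is no perpendicular axis to transpose.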
pub fn launch_into_contiguous_perpendicular<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    if input.shape.len() <= 1 {
        return into_contiguous_ref(client, input, dtype);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec(), dtype);
    launch_into_contiguous_perpendicular_ref(client, input, &output.as_ref(), dtype)?;

    Ok(output)
}

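/// Same as [`launch_into_contiguous_perpendicular`], but writes into an
/// existing `output` handle instead of allocating one.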
pub fn launch_into_contiguous_perpendicular_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
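    // Locate the input's contiguous (stride == 1) axis, which the kernel
    // receives as `axis_vectorized`.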
    let mut axis = 0;

    for (i, stride) in input.strides.iter().enumerate() {
        if *stride == 1 {
            axis = i;
            break;
        }
    }
    let rank = output.shape.len();

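    // The line size must be valid both perpendicular to the input's last axis
    // and parallel to the output's last axis, so take the smaller of the two.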
    let line_size_perpendicular = tensor_line_size_perpendicular(
        client.io_optimized_line_sizes(&dtype),
        input.shape,
        input.strides,
        rank - 1,
    );
    let line_size_parallel = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        rank - 1,
    );
    let line_size = min(line_size_perpendicular, line_size_parallel);

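    // One working unit per group of `line_size` rows along the last axis,
    // matching `num_runs` inside the kernel.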
    let num_elems = output.shape.iter().product::<usize>();
    let working_units = num_elems / (line_size as usize * output.shape[rank - 1]);
    let cube_dim = CubeDim::new(client, working_units);
    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);

    unsafe {
        into_contiguous_perpendicular::launch_unchecked::<R>(
            client,
            cube_count,
            cube_dim,
            input.as_tensor_arg(line_size),
            output.as_tensor_arg(line_size),
            ScalarArg::new(axis),
            dtype,
        )?;
    }

    Ok(())
}