use crate::tensor::{TensorHandle, into_contiguous};
use cubecl::prelude::*;
use cubecl_core::{
self as cubecl, calculate_cube_count_elemwise, tensor_vector_size_parallel,
tensor_vector_size_perpendicular,
};
use std::cmp::min;
/// Cube kernel that copies `input` into `output` while re-vectorizing along
/// a different axis: it loads `vector_size` input vectors, transposes them as
/// a `vector_size x vector_size` tile held in per-thread accumulators, and
/// stores the transposed vectors to `output`.
///
/// * `input` / `output` — tensors of `Vector<T, N>` lines; shapes are assumed
///   compatible (same element count) — enforced by the caller, not here.
/// * `axis_vectorized` — the output axis whose stride is used to place the
///   transposed vectors (see `channel_output_stride`).
/// * `_elem` — compile-time element type binding for the `#[define(T)]` slot.
///
/// NOTE(review): index arithmetic below mixes element-unit and vector-unit
/// offsets (the `/ vector_size` divisions convert between them); the exact
/// unit of `output.len()` / `Tensor` indexing comes from the cubecl runtime —
/// confirm against cubecl's `Tensor` docs before changing any of it.
#[cube(launch_unchecked, address_type = "dynamic")]
fn copy_perpendicular<T: Numeric, N: Size>(
    input: &Tensor<Vector<T, N>>,
    output: &mut Tensor<Vector<T, N>>,
    axis_vectorized: usize,
    #[define(T)] _elem: StorageType,
) {
    let vector_size = input.vector_size();
    let last_axis = input.rank() - 1;
    // Number of tiles each thread processes along the last output axis.
    let num_batch = output.shape(last_axis) / vector_size;
    // One accumulator vector per output lane; together they hold the
    // vector_size x vector_size tile being transposed.
    let mut accumulators = Sequence::<Vector<T, N>>::new();
    #[unroll]
    for _ in 0..vector_size {
        accumulators.push(Vector::empty());
    }
    // Strides in element units, then converted to vector units.
    let channel_input_stride_elem = input.stride(last_axis);
    let channel_output_stride_elem = output.stride(axis_vectorized);
    let channel_input_stride = channel_input_stride_elem / vector_size;
    let channel_output_stride = channel_output_stride_elem / vector_size;
    // Total number of independent runs; threads beyond that do nothing.
    let num_runs = output.len() / (num_batch * vector_size);
    if ABSOLUTE_POS >= num_runs {
        terminate!()
    }
    let batch_index = ABSOLUTE_POS * num_batch;
    // Split the linear batch index around the vectorized-axis stride so the
    // per-thread tiles land in the correct output interval.
    let skip_interval = batch_index / channel_output_stride;
    let skip_index = batch_index % channel_output_stride;
    let skip_size = channel_output_stride_elem;
    let global_index = (skip_interval * skip_size) + skip_index;
    for b in 0..num_batch {
        let offset_output = global_index + b;
        // Gather: decompose the output linear position into coordinates and
        // re-project onto the input strides, so arbitrary (non-contiguous)
        // input layouts are read correctly.
        let mut batch_offset = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(offset_output * vector_size, axis);
            batch_offset += coordinate * input.stride(axis);
        }
        // Convert the element-unit offset back to vector units.
        let batch_offset = batch_offset / vector_size;
        for i in 0..vector_size {
            let index = batch_offset + i * channel_input_stride;
            let batched = input[index];
            // Transpose: lane o of input vector i becomes lane i of
            // accumulator o.
            #[unroll]
            for o in 0..vector_size {
                let vector = accumulators.index_mut(o);
                vector[i] = batched[o];
            }
        }
        // Write the transposed tile, one vector per output channel step.
        #[unroll]
        for o in 0..vector_size {
            let index_out = offset_output + o * channel_output_stride;
            let batched = accumulators[o];
            output[index_out] = batched;
        }
    }
}
/// Makes `input` contiguous via the perpendicular-copy kernel, returning a
/// freshly allocated output tensor.
///
/// Tensors of rank 0 or 1 have no perpendicular axis to re-vectorize over,
/// so they are delegated to the generic [`into_contiguous`] path instead.
pub fn launch_into_contiguous_perpendicular<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    dtype: StorageType,
) -> TensorHandle<R> {
    if input.shape.len() < 2 {
        // Too low-rank for the perpendicular kernel; use the generic copy.
        into_contiguous(client, input, dtype)
    } else {
        let out_shape = input.shape.to_vec();
        let out = TensorHandle::empty(client, out_shape, dtype);
        // The handle is cloned so we can both bind it for the launch and
        // return it to the caller.
        launch_copy_perpendicular_ref(client, input, out.clone().binding(), dtype);
        out
    }
}
/// Launches the [`copy_perpendicular`] kernel copying `input` into `output`.
///
/// * `input` / `output` — bindings of equal logical shape; `output` receives
///   the contiguous copy.
/// * `dtype` — element storage type forwarded to the kernel.
///
/// The vectorization width is the larger width that is simultaneously valid
/// perpendicular to the input's last axis and parallel to the output's last
/// axis (the minimum of the two per-tensor maxima).
pub fn launch_copy_perpendicular_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    output: TensorBinding<R>,
    dtype: StorageType,
) {
    // First axis with unit stride in the input; falls back to axis 0 when no
    // unit-stride axis exists (same fallback as the original manual loop).
    let axis = input
        .strides
        .iter()
        .position(|&stride| stride == 1)
        .unwrap_or(0);
    let rank = output.shape.len();
    let vector_size_perpendicular = tensor_vector_size_perpendicular(
        client.io_optimized_vector_sizes(dtype.size()),
        &input.shape,
        &input.strides,
        rank - 1,
    );
    let vector_size_parallel = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &output.shape,
        &output.strides,
        rank - 1,
    );
    // Both constraints must hold, so take the smaller width.
    let vector_size = vector_size_perpendicular.min(vector_size_parallel);
    let num_elems = output.shape.iter().product::<usize>();
    // One working unit per (vector_size x last-axis) slice of the output.
    let working_units = num_elems / (vector_size as usize * output.shape[rank - 1]);
    let cube_dim = CubeDim::new(client, working_units);
    let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
    // Widest addressing either tensor needs.
    let address_type = input
        .required_address_type(dtype.size())
        .max(output.required_address_type(dtype.size()));
    // SAFETY: launch_unchecked skips bounds validation; cube_count/cube_dim
    // are derived from `working_units` above, and the kernel itself guards
    // with `ABSOLUTE_POS >= num_runs`.
    unsafe {
        copy_perpendicular::launch_unchecked::<R>(
            client,
            cube_count,
            cube_dim,
            address_type,
            vector_size,
            input.into_tensor_arg(),
            output.into_tensor_arg(),
            axis,
            dtype,
        );
    }
}