cubecl_std/tensor/contiguous/launch.rs

1use crate::tensor::{
2    TensorHandle, into_contiguous_gpu_ref, launch_into_contiguous_perpendicular_ref,
3};
4use cubecl_core::{
5    Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef, server::LaunchError,
6};
7
8/// Make a jit tensor contiguous.
9pub fn into_contiguous_ref<R: Runtime>(
10    client: &ComputeClient<R>,
11    input: &TensorHandleRef<'_, R>,
12    dtype: StorageType,
13) -> Result<TensorHandle<R>, LaunchError> {
14    let num_elems: usize = input.shape.iter().product();
15
16    let handle = client.empty(num_elems * dtype.size());
17    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle, dtype);
18
19    copy_into(client, input, &output.as_ref(), dtype)?;
20
21    Ok(output)
22}
23
24/// Make a jit tensor contiguous, using the pitched allocator if available.
25/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
26pub fn into_contiguous_pitched_ref<R: Runtime>(
27    client: &ComputeClient<R>,
28    input: &TensorHandleRef<'_, R>,
29    dtype: StorageType,
30) -> Result<TensorHandle<R>, LaunchError> {
31    if input.shape.len() <= 1 {
32        return into_contiguous_ref(client, input, dtype);
33    }
34
35    let output = TensorHandle::empty(client, input.shape.to_vec(), dtype);
36
37    copy_into(client, input, &output.as_ref(), dtype)?;
38
39    Ok(output)
40}
41
42/// Copies the input tensor into the output tensor following the strides.
43pub fn copy_into<R: Runtime>(
44    client: &ComputeClient<R>,
45    input: &TensorHandleRef<'_, R>,
46    output: &TensorHandleRef<'_, R>,
47    dtype: StorageType,
48) -> Result<(), LaunchError> {
49    let rank = input.strides.len();
50
51    // It's normally faster on all devices, but since it doesn't parallelize on an axis, it
52    // might be worst on GPU. Should tune at some point.
53    let is_cpu = client.properties().hardware.num_cpu_cores.is_some();
54    if input.strides[rank - 1] != 1 && is_cpu {
55        launch_into_contiguous_perpendicular_ref(client, input, output, dtype)?;
56    } else {
57        into_contiguous_gpu_ref(client, input, output, dtype)?;
58    };
59
60    Ok(())
61}