Skip to main content

cubecl_std/tensor/contiguous/
launch.rs

1use crate::tensor::{TensorHandle, copy_gpu_ref, launch_copy_perpendicular_ref};
2use cubecl_core::{
3    Runtime, client::ComputeClient, ir::StorageType, prelude::TensorHandleRef, server::LaunchError,
4};
5
6/// Make a jit tensor contiguous.
7pub fn into_contiguous_ref<R: Runtime>(
8    client: &ComputeClient<R>,
9    input: &TensorHandleRef<'_, R>,
10    dtype: StorageType,
11) -> Result<TensorHandle<R>, LaunchError> {
12    let num_elems: usize = input.shape.iter().product();
13
14    let handle = client.empty(num_elems * dtype.size());
15    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle, dtype);
16
17    copy_into(client, input, &output.as_ref(), dtype)?;
18
19    Ok(output)
20}
21
22/// Make a jit tensor contiguous, using the pitched allocator if available.
23/// See [`create_tensor`](cubecl_runtime::client::ComputeClient::create_tensor).
24pub fn into_contiguous_pitched_ref<R: Runtime>(
25    client: &ComputeClient<R>,
26    input: &TensorHandleRef<'_, R>,
27    dtype: StorageType,
28) -> Result<TensorHandle<R>, LaunchError> {
29    if input.shape.len() <= 1 {
30        return into_contiguous_ref(client, input, dtype);
31    }
32
33    let output = TensorHandle::empty(client, input.shape.to_vec(), dtype);
34
35    copy_into(client, input, &output.as_ref(), dtype)?;
36
37    Ok(output)
38}
39
40/// Copies the input tensor into the output tensor following the strides.
41pub fn copy_into<R: Runtime>(
42    client: &ComputeClient<R>,
43    input: &TensorHandleRef<'_, R>,
44    output: &TensorHandleRef<'_, R>,
45    dtype: StorageType,
46) -> Result<(), LaunchError> {
47    let rank = input.strides.len();
48
49    // It's normally faster on all devices, but since it doesn't parallelize on an axis, it
50    // might be worst on GPU. Should tune at some point.
51    let is_cpu = client.properties().hardware.num_cpu_cores.is_some();
52    if input.strides[rank - 1] != 1 && is_cpu {
53        launch_copy_perpendicular_ref(client, input, output, dtype)?;
54    } else {
55        copy_gpu_ref(client, input, output, dtype)?;
56    };
57
58    Ok(())
59}