// cubecl_std/tensor/handle.rs

1use core::marker::PhantomData;
2use cubecl_core::tensor_line_size_parallel;
3use cubecl_core::{Runtime, server};
4use cubecl_core::{calculate_cube_count_elemwise, server::Allocation};
5use cubecl_core::{prelude::*, server::CopyDescriptor};
6use cubecl_runtime::server::Handle;
7
8/// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata.,
9pub struct TensorHandle<R, E>
10where
11    R: Runtime,
12    E: CubePrimitive,
13{
14    /// The buffer where the data are stored.
15    pub handle: server::Handle,
16    pub shape: Vec<usize>,
17    pub strides: Vec<usize>,
18    elem: PhantomData<E>,
19    runtime: PhantomData<R>,
20}
21
22impl<R, E> core::fmt::Debug for TensorHandle<R, E>
23where
24    R: Runtime,
25    E: CubePrimitive,
26{
27    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28        f.write_fmt(format_args!(
29            "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
30            self.shape,
31            self.strides,
32            core::any::type_name::<E>(),
33        ))
34    }
35}
36
37impl<R, E> Clone for TensorHandle<R, E>
38where
39    R: Runtime,
40    E: CubePrimitive,
41{
42    fn clone(&self) -> Self {
43        Self {
44            handle: self.handle.clone(),
45            shape: self.shape.clone(),
46            strides: self.strides.clone(),
47            elem: PhantomData,
48            runtime: PhantomData,
49        }
50    }
51}
52
53impl<R, E> TensorHandle<R, E>
54where
55    R: Runtime,
56    E: CubePrimitive,
57{
58    /// Create a new tensor.
59    pub fn new(handle: server::Handle, shape: Vec<usize>, strides: Vec<usize>) -> Self {
60        Self {
61            handle,
62            shape,
63            strides,
64            elem: PhantomData,
65            runtime: PhantomData,
66        }
67    }
68
69    pub fn empty(client: &ComputeClient<R::Server>, shape: Vec<usize>) -> Self {
70        let elem_size = E::size().expect("To be a native type");
71        let Allocation { handle, strides } = client.empty_tensor(&shape, elem_size);
72
73        Self::new(handle, shape, strides)
74    }
75
76    /// Create a new tensor.
77    pub fn from_ref(handle: &TensorHandleRef<'_, R>) -> Self {
78        Self {
79            handle: handle.handle.clone(),
80            shape: handle.shape.to_vec(),
81            strides: handle.strides.to_vec(),
82            elem: PhantomData,
83            runtime: PhantomData,
84        }
85    }
86
87    /// Create a new tensor with a contiguous memory layout.
88    pub fn new_contiguous(shape: Vec<usize>, handle: Handle) -> Self {
89        let strides = Self::contiguous_strides(&shape);
90
91        Self {
92            handle,
93            shape,
94            strides,
95            elem: PhantomData,
96            runtime: PhantomData,
97        }
98    }
99
100    /// Check if the tensor is safe to mutate.
101    pub fn can_mut(&self) -> bool {
102        self.handle.can_mut()
103    }
104
105    pub fn as_ref(&self) -> TensorHandleRef<'_, R> {
106        unsafe {
107            TensorHandleRef::from_raw_parts(
108                &self.handle,
109                &self.strides,
110                &self.shape,
111                size_of::<E>(),
112            )
113        }
114    }
115
116    /// Return the reference to a tensor argument.
117    pub fn as_arg<'a>(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
118        let handle: TensorHandleRef<'a, R> = self.as_ref();
119
120        unsafe {
121            TensorArg::from_raw_parts::<E>(
122                handle.handle,
123                handle.strides,
124                handle.shape,
125                vectorisation,
126            )
127        }
128    }
129
130    pub fn as_copy_descriptor<'a>(&'a self) -> CopyDescriptor<'a> {
131        CopyDescriptor {
132            binding: self.handle.clone().binding(),
133            shape: &self.shape,
134            strides: &self.strides,
135            elem_size: size_of::<E>(),
136        }
137    }
138
139    fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
140        let mut strides = Vec::with_capacity(shape.len());
141
142        let mut current = 1;
143        shape.iter().enumerate().rev().for_each(|(_, val)| {
144            strides.push(current);
145            current *= val;
146        });
147        strides.reverse();
148        strides
149    }
150}
151impl<R, E> TensorHandle<R, E>
152where
153    R: Runtime,
154    E: Numeric,
155{
156    pub fn zeros(client: &ComputeClient<R::Server>, shape: Vec<usize>) -> Self {
157        let num_elements: usize = shape.iter().product();
158        let rank = shape.len();
159        let output = Self::empty(client, shape);
160
161        let line_size = tensor_line_size_parallel(
162            R::supported_line_sizes().iter().cloned(),
163            &output.shape,
164            &output.strides,
165            rank - 1,
166        );
167
168        let cube_dim = CubeDim::default();
169        let cube_count = calculate_cube_count_elemwise(num_elements / line_size as usize, cube_dim);
170        let array_len = output.handle.size() as usize / size_of::<E>();
171
172        unsafe {
173            init::zeros_array::launch_unchecked::<E, R>(
174                client,
175                cube_count,
176                cube_dim,
177                ArrayArg::from_raw_parts::<E>(&output.handle, array_len, line_size),
178            )
179        };
180
181        output
182    }
183}
184
185pub(crate) mod init {
186    use cubecl::prelude::*;
187    use cubecl_core as cubecl;
188
189    #[cube(launch_unchecked)]
190    pub fn zeros_array<C: Numeric>(output: &mut Array<Line<C>>) {
191        if ABSOLUTE_POS < output.len() {
192            output[ABSOLUTE_POS] = Line::cast_from(C::from_int(0));
193        }
194    }
195}