// cubecl_std/tensor/handle.rs

use core::marker::PhantomData;
use cubecl_core::{Runtime, server, zspace::strides};
use cubecl_core::{calculate_cube_count_elemwise, server::MemoryLayout};
use cubecl_core::{ir::StorageType, zspace::metadata::Metadata};
use cubecl_core::{prelude::*, server::CopyDescriptor};
use cubecl_core::{
    tensor_vector_size_parallel,
    zspace::{Shape, Strides},
};
use cubecl_runtime::server::Handle;
11
/// Tensor representation containing a [server handle](Handle) as well as basic tensor metadata.
pub struct TensorHandle<R>
where
    R: Runtime,
{
    /// The buffer where the data are stored.
    pub handle: server::Handle,
    /// Shape and strides describing how the buffer is laid out.
    pub metadata: Box<Metadata>,
    /// The type used as storage.
    pub dtype: StorageType,
    // Ties this handle to a specific runtime `R` without storing a value.
    runtime: PhantomData<R>,
}
24
25impl<R> core::fmt::Debug for TensorHandle<R>
26where
27    R: Runtime,
28{
29    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30        f.write_fmt(format_args!(
31            "Tensor {{ shape: {:?}, strides: {:?}, dtype: {}}}",
32            self.shape(),
33            self.strides(),
34            self.dtype,
35        ))
36    }
37}
38
39impl<R> Clone for TensorHandle<R>
40where
41    R: Runtime,
42{
43    fn clone(&self) -> Self {
44        Self {
45            handle: self.handle.clone(),
46            metadata: self.metadata.clone(),
47            dtype: self.dtype,
48            runtime: PhantomData,
49        }
50    }
51}
52
53impl<R> TensorHandle<R>
54where
55    R: Runtime,
56{
57    /// Create a new tensor.
58    pub fn new(
59        handle: server::Handle,
60        shape: impl Into<Shape>,
61        strides: impl Into<Strides>,
62        storage: impl Into<Type>,
63    ) -> Self {
64        Self {
65            handle,
66            metadata: Box::new(Metadata::new(shape, strides)),
67            dtype: storage.into().storage_type(),
68            runtime: PhantomData,
69        }
70    }
71
72    pub fn empty(
73        client: &ComputeClient<R>,
74        shape: impl Into<Shape>,
75        storage: impl Into<Type>,
76    ) -> Self {
77        let storage = storage.into();
78        let shape: Shape = shape.into();
79        let elem_size = storage.storage_type().size();
80        let MemoryLayout {
81            memory: handle,
82            strides,
83        } = client.empty_tensor(shape.clone(), elem_size);
84
85        Self::new(handle, shape, strides, storage)
86    }
87
88    /// Create a new tensor with a contiguous memory layout.
89    pub fn new_contiguous(shape: impl Into<Shape>, handle: Handle, storage: StorageType) -> Self {
90        let shape = shape.into();
91        let strides = Self::contiguous_strides(&shape);
92
93        Self {
94            handle,
95            metadata: Box::new(Metadata::new(shape, strides)),
96            dtype: storage,
97            runtime: PhantomData,
98        }
99    }
100
101    /// Check if the tensor is safe to mutate.
102    pub fn can_mut(&self) -> bool {
103        self.handle.can_mut()
104    }
105
106    pub fn binding(self) -> TensorBinding<R> {
107        unsafe {
108            TensorBinding::from_raw_parts(self.handle, self.metadata.strides, self.metadata.shape)
109        }
110    }
111
112    /// Return the reference to a tensor argument.
113    pub fn into_arg(self) -> TensorArg<R> {
114        self.binding().into_tensor_arg()
115    }
116
117    pub fn into_copy_descriptor(self) -> CopyDescriptor {
118        CopyDescriptor {
119            handle: self.handle.binding(),
120            shape: self.metadata.shape,
121            strides: self.metadata.strides,
122            elem_size: self.dtype.size(),
123        }
124    }
125
126    pub fn required_address_type(&self) -> AddressType {
127        let len = self.handle.size() / self.dtype.size() as u64;
128        AddressType::from_len(len as usize)
129    }
130
131    pub fn shape(&self) -> &Shape {
132        self.metadata.shape()
133    }
134
135    pub fn strides(&self) -> &Strides {
136        self.metadata.strides()
137    }
138
139    fn contiguous_strides(shape: &[usize]) -> Strides {
140        let mut strides = strides![1; shape.len()];
141
142        let mut current = 1;
143        shape.iter().rev().enumerate().for_each(|(i, val)| {
144            strides[i] = current;
145            current *= val;
146        });
147        strides.reverse();
148        strides
149    }
150}
impl<R> TensorHandle<R>
where
    R: Runtime,
{
    /// Allocate a tensor of the given shape on `client` and fill it with
    /// zeros by launching the [`init::zeros_array`] kernel.
    pub fn zeros(
        client: &ComputeClient<R>,
        shape: impl Into<Shape>,
        dtype: impl Into<Type>,
    ) -> Self {
        let dtype = dtype.into();
        let shape = shape.into();
        let num_elements: usize = shape.iter().product();
        let rank = shape.len();
        // Runtime-chosen layout; strides may not be the contiguous ones.
        let output = Self::empty(client, shape, dtype);
        let dtype = dtype.storage_type();

        // Pick the widest vector width compatible with the output layout
        // along the innermost axis (rank - 1).
        // NOTE(review): `rank - 1` underflows for a rank-0 shape — confirm
        // callers never pass an empty shape.
        let vector_size = tensor_vector_size_parallel(
            client.io_optimized_vector_sizes(dtype.size()),
            output.shape(),
            output.strides(),
            rank - 1,
        );

        let working_units = num_elements / vector_size as usize;
        let cube_dim = CubeDim::new(client, working_units);
        let cube_count = calculate_cube_count_elemwise(client, working_units, cube_dim);
        // Length in `dtype`-sized elements of the used part of the
        // allocation — presumably covers layout padding as well, since it is
        // derived from the handle size rather than `num_elements`.
        let array_len = output.handle.size_in_used() as usize / dtype.size();

        // SAFETY: `array_len` is derived from the size of `output.handle`
        // itself, and the kernel guards writes with an `ABSOLUTE_POS <
        // output.len()` bounds check.
        unsafe {
            init::zeros_array::launch_unchecked(
                client,
                cube_count,
                cube_dim,
                output.required_address_type(),
                vector_size,
                ArrayArg::from_raw_parts(output.handle.clone(), array_len),
                dtype,
            )
        };

        output
    }
}
194
/// Kernels used to initialize tensor buffers.
pub(crate) mod init {
    use cubecl::prelude::*;
    use cubecl_core::{self as cubecl, ir::StorageType};

    /// Kernel writing zero into every element of `output`, vectorized by `N`.
    ///
    /// `_elem` only selects the concrete numeric type `C` at launch time
    /// (via `#[define(C)]`); it carries no runtime data.
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn zeros_array<C: Numeric, N: Size>(
        output: &mut Array<Vector<C, N>>,
        #[define(C)] _elem: StorageType,
    ) {
        // Guard: the launch may spawn more working units than elements.
        if ABSOLUTE_POS < output.len() {
            output[ABSOLUTE_POS] = Vector::cast_from(C::from_int(0));
        }
    }
}