Skip to main content

cubecl_core/frontend/container/tensor/
launch.rs

1use core::marker::PhantomData;
2
3use cubecl_ir::AddressType;
4use cubecl_runtime::{runtime::Runtime, server::CopyDescriptor};
5use cubecl_zspace::{Shape, Strides};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    compute::{KernelBuilder, KernelLauncher},
10    ir::Id,
11    prelude::{ArrayArg, ArrayBinding, CubePrimitive, LaunchArg, NativeExpand},
12};
13
14use super::Tensor;
15
/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
///
/// A tensor argument is either backed by its own binding ([`TensorArg::Handle`]) or refers
/// to another input tensor of the same launch by position ([`TensorArg::Alias`]). When the
/// argument is registered, an alias is recorded as an in-place compilation argument.
#[derive(Debug)]
pub enum TensorArg<R: Runtime> {
    /// The tensor is passed with a tensor handle.
    Handle {
        /// The tensor handle.
        handle: TensorBinding<R>,
    },
    /// The tensor is aliasing another input tensor.
    Alias {
        /// The position of the input tensor.
        input_pos: usize,
        /// Strides of the aliasing tensor view (no binding is stored for an alias).
        strides: Strides,
        /// Shape of the aliasing tensor view; its product is the element count reported
        /// by `TensorArg::size`.
        shape: Shape,
    },
}
32
/// Tensor representation with a reference to the [server binding](cubecl_runtime::server::Binding),
/// the strides and the shape.
pub struct TensorBinding<R: Runtime> {
    /// Server binding of the buffer backing this tensor.
    pub handle: cubecl_runtime::server::Binding,
    /// Stride of each dimension of the tensor.
    pub strides: Strides,
    /// Extent of each dimension; the element count is the product of these values.
    pub shape: Shape,
    /// Ties the binding to runtime `R` without storing a value of it.
    pub runtime: PhantomData<R>,
}
41
42impl<R: Runtime> Clone for TensorBinding<R> {
43    fn clone(&self) -> Self {
44        Self {
45            handle: self.handle.clone(),
46            strides: self.strides.clone(),
47            shape: self.shape.clone(),
48            runtime: PhantomData,
49        }
50    }
51}
52
53impl<R: Runtime> TensorBinding<R> {
54    pub fn size(&self) -> usize {
55        self.shape.iter().product()
56    }
57
58    /// Address type required to fully index this tensor handle, assuming scalar access.
59    pub fn required_address_type(&self, elem_size: usize) -> AddressType {
60        AddressType::from_len(self.handle.size() as usize / elem_size)
61    }
62}
63
64impl<R: Runtime> core::fmt::Debug for TensorBinding<R> {
65    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
66        writeln!(
67            f,
68            "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
69            self.strides, self.shape
70        )
71    }
72}
73
/// Compilation argument for a [tensor](Tensor).
///
/// Produced by `LaunchArg::register` for `Tensor<C>` and consumed by `expand_output`.
/// NOTE(review): the `Hash`/`Eq`/`Serialize` derives suggest it participates in a kernel
/// cache key — confirm.
#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
pub struct TensorCompilationArg {
    /// `Some(id)` when the tensor aliases the input at position `id` (from
    /// `TensorArg::Alias`), so the output is expanded in place; `None` for a tensor that is
    /// passed with its own handle.
    pub inplace: Option<Id>,
}
79
80impl<C: CubePrimitive> LaunchArg for Tensor<C> {
81    type RuntimeArg<R: Runtime> = TensorArg<R>;
82    type CompilationArg = TensorCompilationArg;
83
84    fn register<R: Runtime>(
85        arg: Self::RuntimeArg<R>,
86        launcher: &mut KernelLauncher<R>,
87    ) -> Self::CompilationArg {
88        let ty = launcher.with_scope(|scope| C::as_type(scope));
89        let compilation_arg = match &arg {
90            TensorArg::Handle { .. } => TensorCompilationArg { inplace: None },
91            TensorArg::Alias { input_pos, .. } => TensorCompilationArg {
92                inplace: Some(*input_pos as Id),
93            },
94        };
95        launcher.register_tensor(arg, ty);
96        compilation_arg
97    }
98
99    fn expand(_arg: &Self::CompilationArg, builder: &mut KernelBuilder) -> NativeExpand<Tensor<C>> {
100        builder.input_tensor(C::as_type(&builder.scope)).into()
101    }
102    fn expand_output(
103        arg: &Self::CompilationArg,
104        builder: &mut KernelBuilder,
105    ) -> NativeExpand<Tensor<C>> {
106        match arg.inplace {
107            Some(id) => builder.inplace_output(id).into(),
108            None => builder.output_tensor(C::as_type(&builder.scope)).into(),
109        }
110    }
111}
112
113impl<R: Runtime> TensorArg<R> {
114    /// Create a new tensor argument specified with its vectorization factor.
115    ///
116    /// # Safety
117    ///
118    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
119    /// out-of-bound reads and writes.
120    pub unsafe fn from_raw_parts(
121        handle: cubecl_runtime::server::Handle,
122        strides: Strides,
123        shape: Shape,
124    ) -> Self {
125        unsafe { Self::from_raw_parts_binding(handle.binding(), strides, shape) }
126    }
127
128    pub(crate) unsafe fn from_raw_parts_binding(
129        handle: cubecl_runtime::server::Binding,
130        strides: Strides,
131        shape: Shape,
132    ) -> Self {
133        unsafe {
134            Self::Handle {
135                handle: TensorBinding::from_raw_parts_binding(handle, strides, shape),
136            }
137        }
138    }
139
140    /// Create an alias argument.
141    pub fn into_alias(self, position: usize) -> Self {
142        match self {
143            TensorArg::Handle { handle } => handle.into_alias(position),
144            alias @ TensorArg::Alias { .. } => alias,
145        }
146    }
147
148    pub fn size(&self) -> usize {
149        match self {
150            TensorArg::Handle { handle } => handle.size(),
151            TensorArg::Alias { shape, .. } => shape.iter().product(),
152        }
153    }
154
155    pub fn shape(&self) -> &[usize] {
156        match self {
157            TensorArg::Handle { handle } => &handle.shape,
158            TensorArg::Alias { shape, .. } => shape,
159        }
160    }
161
162    pub fn strides(&self) -> &[usize] {
163        match self {
164            TensorArg::Handle { handle } => &handle.strides,
165            TensorArg::Alias { strides, .. } => strides,
166        }
167    }
168}
169
170impl<R: Runtime> TensorArg<R> {
171    pub fn into_array_arg(self) -> ArrayArg<R> {
172        match self {
173            TensorArg::Handle { handle } => {
174                let handle = unsafe {
175                    let size = handle.size();
176                    ArrayBinding::from_raw_parts_binding(handle.handle, size)
177                };
178                ArrayArg::Handle { handle }
179            }
180            TensorArg::Alias {
181                input_pos, shape, ..
182            } => ArrayArg::Alias {
183                input_pos,
184                length: [shape.iter().product()],
185            },
186        }
187    }
188}
189
190impl<R: Runtime> TensorBinding<R> {
191    /// Convert the handle into a [tensor argument](TensorArg).
192    pub fn into_tensor_arg(self) -> TensorArg<R> {
193        unsafe { TensorArg::from_raw_parts_binding(self.handle, self.strides, self.shape) }
194    }
195    /// Convert the handle into a [tensor argument](TensorArg).
196    pub fn into_alias(self, index: usize) -> TensorArg<R> {
197        TensorArg::Alias {
198            input_pos: index,
199            strides: self.strides,
200            shape: self.shape,
201        }
202    }
203    /// Convert the handle into a [tensor argument](TensorArg).
204    pub fn as_alias(&self, index: usize) -> TensorArg<R> {
205        TensorArg::Alias {
206            input_pos: index,
207            strides: self.strides.clone(),
208            shape: self.shape.clone(),
209        }
210    }
211    /// Convert the handle into an [array argument](ArrayArg).
212    pub fn into_array_arg(self) -> ArrayArg<R> {
213        let length = self.shape.iter().product();
214        unsafe { ArrayArg::from_raw_parts_binding(self.handle, length) }
215    }
216
217    /// Create a handle from raw parts.
218    ///
219    /// # Safety
220    ///
221    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
222    /// out-of-bounds reads and writes.
223    pub unsafe fn from_raw_parts(
224        handle: cubecl_runtime::server::Handle,
225        strides: Strides,
226        shape: Shape,
227    ) -> Self {
228        unsafe { Self::from_raw_parts_binding(handle.binding(), strides, shape) }
229    }
230
231    /// Create a handle from raw parts.
232    ///
233    /// # Safety
234    ///
235    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
236    /// out-of-bounds reads and writes.
237    pub unsafe fn from_raw_parts_binding(
238        handle: cubecl_runtime::server::Binding,
239        strides: Strides,
240        shape: Shape,
241    ) -> Self {
242        Self {
243            handle,
244            strides,
245            shape,
246            runtime: PhantomData,
247        }
248    }
249
250    pub fn into_copy_descriptor(self, elem_size: usize) -> CopyDescriptor {
251        CopyDescriptor {
252            handle: self.handle,
253            shape: self.shape,
254            strides: self.strides,
255            elem_size,
256        }
257    }
258}