Skip to main content

cubecl_core/frontend/container/tensor/
launch.rs

1use core::marker::PhantomData;
2
3use cubecl_ir::AddressType;
4use cubecl_runtime::{runtime::Runtime, server::CopyDescriptor};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    compute::{KernelBuilder, KernelLauncher},
9    ir::{Id, LineSize, Type},
10    prelude::{
11        ArgSettings, ArrayArg, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg,
12    },
13};
14
15use super::Tensor;
16
/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
///
/// A tensor is either described by its own handle ([`TensorArg::Handle`]) or declared as an
/// alias of another input tensor ([`TensorArg::Alias`]).
#[derive(Debug)]
pub enum TensorArg<'a, R: Runtime> {
    /// The tensor is passed with a tensor handle.
    Handle {
        /// The tensor handle, borrowed together with the tensor's strides and shape.
        handle: TensorHandleRef<'a, R>,
        /// The vectorization factor (line size) of the tensor.
        line_size: LineSize,
    },
    /// The tensor is aliasing another input tensor, so no handle is carried.
    Alias {
        /// The position of the aliased input tensor.
        input_pos: usize,
    },
}
33
/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle),
/// the strides and the shape.
///
/// Everything is borrowed for `'a`; the struct owns none of the data it describes.
pub struct TensorHandleRef<'a, R: Runtime> {
    /// Reference to the server handle of the tensor.
    pub handle: &'a cubecl_runtime::server::Handle,
    /// Strides of the tensor, one entry per dimension.
    pub strides: &'a [usize],
    /// Shape of the tensor, one entry per dimension.
    pub shape: &'a [usize],
    /// Size of a single element in bytes.
    pub elem_size: usize,
    /// Zero-sized marker recording the runtime `R` this reference is used with.
    pub runtime: PhantomData<R>,
}
43
44impl<'a, R: Runtime> Clone for TensorHandleRef<'a, R> {
45    fn clone(&self) -> Self {
46        *self
47    }
48}
49
50impl<'a, R: Runtime> Copy for TensorHandleRef<'a, R> {}
51
52impl<R: Runtime> TensorHandleRef<'_, R> {
53    pub fn size(&self) -> usize {
54        self.shape.iter().product()
55    }
56
57    /// Address type required to fully index this tensor handle, assuming scalar access.
58    pub fn required_address_type(&self) -> AddressType {
59        AddressType::from_len(self.handle.size() as usize / self.elem_size)
60    }
61}
62
63impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
64    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65        writeln!(
66            f,
67            "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
68            self.strides, self.shape
69        )
70    }
71}
72
73/// Compilation argument for a [tensor](Tensor).
74#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
75pub struct TensorCompilationArg {
76    pub inplace: Option<Id>,
77    pub line_size: LineSize,
78}
79
80impl CompilationArg for TensorCompilationArg {}
81
82impl<C: CubePrimitive> LaunchArg for Tensor<C> {
83    type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
84    type CompilationArg = TensorCompilationArg;
85
86    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
87        match runtime_arg {
88            TensorArg::Handle { line_size, .. } => TensorCompilationArg {
89                inplace: None,
90                line_size: *line_size as LineSize,
91            },
92            TensorArg::Alias { input_pos } => TensorCompilationArg {
93                inplace: Some(*input_pos as Id),
94                line_size: 0,
95            },
96        }
97    }
98
99    fn expand(
100        arg: &Self::CompilationArg,
101        builder: &mut KernelBuilder,
102    ) -> ExpandElementTyped<Tensor<C>> {
103        builder
104            .input_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
105            .into()
106    }
107    fn expand_output(
108        arg: &Self::CompilationArg,
109        builder: &mut KernelBuilder,
110    ) -> ExpandElementTyped<Tensor<C>> {
111        match arg.inplace {
112            Some(id) => builder.inplace_output(id).into(),
113            None => builder
114                .output_tensor(Type::new(C::as_type(&builder.scope)).line(arg.line_size))
115                .into(),
116        }
117    }
118}
119
120impl<'a, R: Runtime> TensorArg<'a, R> {
121    /// Create a new tensor argument specified with its vectorization factor.
122    ///
123    /// # Safety
124    ///
125    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
126    /// out-of-bound reads and writes.
127    pub unsafe fn from_raw_parts<E: CubePrimitive>(
128        handle: &'a cubecl_runtime::server::Handle,
129        strides: &'a [usize],
130        shape: &'a [usize],
131        factor: LineSize,
132    ) -> Self {
133        unsafe {
134            Self::Handle {
135                handle: TensorHandleRef::from_raw_parts(
136                    handle,
137                    strides,
138                    shape,
139                    E::size().expect("Element should have a size"),
140                ),
141                line_size: factor,
142            }
143        }
144    }
145
146    /// Create a new tensor argument specified with its vectorization factor with a manual element
147    /// size in bytes.
148    ///
149    /// # Safety
150    ///
151    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
152    /// out-of-bound reads and writes.
153    pub unsafe fn from_raw_parts_and_size(
154        handle: &'a cubecl_runtime::server::Handle,
155        strides: &'a [usize],
156        shape: &'a [usize],
157        factor: LineSize,
158        elem_size: usize,
159    ) -> Self {
160        unsafe {
161            Self::Handle {
162                handle: TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size),
163                line_size: factor,
164            }
165        }
166    }
167
168    /// Create an alias argument.
169    pub fn alias(position: usize) -> Self {
170        Self::Alias {
171            input_pos: position,
172        }
173    }
174}
175
impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
    /// Registers this tensor argument with the launcher by forwarding it to
    /// `register_tensor`.
    fn register(&self, launcher: &mut KernelLauncher<R>) {
        launcher.register_tensor(self);
    }
}
181
182impl<'a, R: Runtime> TensorHandleRef<'a, R> {
183    /// Convert the handle into a [tensor argument](TensorArg).
184    pub fn as_tensor_arg(&'a self, line_size: LineSize) -> TensorArg<'a, R> {
185        unsafe {
186            TensorArg::from_raw_parts_and_size(
187                self.handle,
188                self.strides,
189                self.shape,
190                line_size,
191                self.elem_size,
192            )
193        }
194    }
195    /// Convert the handle into an [array argument](ArrayArg).
196    pub fn as_array_arg(&'a self, line_size: LineSize) -> ArrayArg<'a, R> {
197        let length = self.shape.iter().product();
198        unsafe { ArrayArg::from_raw_parts_and_size(self.handle, length, line_size, self.elem_size) }
199    }
200    /// Create a handle from raw parts.
201    ///
202    /// # Safety
203    ///
204    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
205    /// out-of-bounds reads and writes.
206    pub unsafe fn from_raw_parts(
207        handle: &'a cubecl_runtime::server::Handle,
208        strides: &'a [usize],
209        shape: &'a [usize],
210        elem_size: usize,
211    ) -> Self {
212        Self {
213            handle,
214            strides,
215            shape,
216            elem_size,
217            runtime: PhantomData,
218        }
219    }
220
221    pub fn as_copy_descriptor(&self) -> CopyDescriptor<'_> {
222        CopyDescriptor {
223            binding: self.handle.clone().binding(),
224            shape: self.shape,
225            strides: self.strides,
226            elem_size: self.elem_size,
227        }
228    }
229}