cubecl_core/frontend/container/tensor/launch.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    Runtime,
7    compute::{KernelBuilder, KernelLauncher},
8    ir::{Id, Item, Vectorization},
9    prelude::{
10        ArgSettings, CompilationArg, CubePrimitive, ExpandElementTyped, LaunchArg, LaunchArgExpand,
11    },
12};
13
14use super::Tensor;
15
16/// Argument to be used for [tensors](Tensor) passed as arguments to kernels.
17#[derive(Debug)]
18pub enum TensorArg<'a, R: Runtime> {
19    /// The tensor is passed with a tensor handle.
20    Handle {
21        /// The tensor handle.
22        handle: TensorHandleRef<'a, R>,
23        /// The vectorization factor.
24        vectorization_factor: u8,
25    },
26    /// The tensor is aliasing another input tensor.
27    Alias {
28        /// The position of the input tensor.
29        input_pos: usize,
30    },
31}
32
33/// Tensor representation with a reference to the [server handle](cubecl_runtime::server::Handle),
34/// the strides and the shape.
35pub struct TensorHandleRef<'a, R: Runtime> {
36    pub handle: &'a cubecl_runtime::server::Handle,
37    pub strides: &'a [usize],
38    pub shape: &'a [usize],
39    pub elem_size: usize,
40    pub runtime: PhantomData<R>,
41}
42
43impl<R: Runtime> TensorHandleRef<'_, R> {
44    pub fn size(&self) -> usize {
45        self.shape.iter().product()
46    }
47}
48
49impl<R: Runtime> core::fmt::Debug for TensorHandleRef<'_, R> {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        writeln!(
52            f,
53            "TensorHandleRef {{ strides: {:?}, shape: {:?} }}",
54            self.strides, self.shape
55        )
56    }
57}
58
59/// Compilation argument for a [tensor](Tensor).
60#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize, Deserialize)]
61pub struct TensorCompilationArg {
62    pub inplace: Option<Id>,
63    pub vectorisation: Vectorization,
64}
65
66impl CompilationArg for TensorCompilationArg {}
67
68impl<C: CubePrimitive> LaunchArgExpand for Tensor<C> {
69    type CompilationArg = TensorCompilationArg;
70
71    fn expand(
72        arg: &Self::CompilationArg,
73        builder: &mut KernelBuilder,
74    ) -> ExpandElementTyped<Tensor<C>> {
75        builder
76            .input_tensor(Item::vectorized(
77                C::as_elem(&builder.context),
78                arg.vectorisation,
79            ))
80            .into()
81    }
82    fn expand_output(
83        arg: &Self::CompilationArg,
84        builder: &mut KernelBuilder,
85    ) -> ExpandElementTyped<Tensor<C>> {
86        match arg.inplace {
87            Some(id) => builder.inplace_output(id).into(),
88            None => builder
89                .output_tensor(Item::vectorized(
90                    C::as_elem(&builder.context),
91                    arg.vectorisation,
92                ))
93                .into(),
94        }
95    }
96}
97
98impl<C: CubePrimitive> LaunchArg for Tensor<C> {
99    type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>;
100
101    fn compilation_arg<R: Runtime>(runtime_arg: &Self::RuntimeArg<'_, R>) -> Self::CompilationArg {
102        match runtime_arg {
103            TensorArg::Handle {
104                vectorization_factor,
105                ..
106            } => TensorCompilationArg {
107                inplace: None,
108                vectorisation: Vectorization::Some(NonZero::new(*vectorization_factor).unwrap()),
109            },
110            TensorArg::Alias { input_pos } => TensorCompilationArg {
111                inplace: Some(*input_pos as Id),
112                vectorisation: Vectorization::None,
113            },
114        }
115    }
116}
117
118impl<'a, R: Runtime> TensorArg<'a, R> {
119    /// Create a new tensor argument specified with its vectorization factor.
120    ///
121    /// # Safety
122    ///
123    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
124    /// out-of-bound reads and writes.
125    pub unsafe fn from_raw_parts<E: CubePrimitive>(
126        handle: &'a cubecl_runtime::server::Handle,
127        strides: &'a [usize],
128        shape: &'a [usize],
129        factor: u8,
130    ) -> Self {
131        unsafe {
132            Self::Handle {
133                handle: TensorHandleRef::from_raw_parts(
134                    handle,
135                    strides,
136                    shape,
137                    E::size().expect("Element should have a size"),
138                ),
139                vectorization_factor: factor,
140            }
141        }
142    }
143
144    /// Create a new tensor argument specified with its vectorization factor with a manual element
145    /// size in bytes.
146    ///
147    /// # Safety
148    ///
149    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
150    /// out-of-bound reads and writes.
151    pub unsafe fn from_raw_parts_and_size(
152        handle: &'a cubecl_runtime::server::Handle,
153        strides: &'a [usize],
154        shape: &'a [usize],
155        factor: u8,
156        elem_size: usize,
157    ) -> Self {
158        unsafe {
159            Self::Handle {
160                handle: TensorHandleRef::from_raw_parts(handle, strides, shape, elem_size),
161                vectorization_factor: factor,
162            }
163        }
164    }
165
166    /// Create an alias argument.
167    pub fn alias(position: usize) -> Self {
168        Self::Alias {
169            input_pos: position,
170        }
171    }
172}
173
174impl<R: Runtime> ArgSettings<R> for TensorArg<'_, R> {
175    fn register(&self, launcher: &mut KernelLauncher<R>) {
176        launcher.register_tensor(self);
177    }
178}
179
180impl<'a, R: Runtime> TensorHandleRef<'a, R> {
181    /// Convert the handle into a [tensor argument](TensorArg).
182    pub fn as_tensor_arg(&'a self, vectorisation: u8) -> TensorArg<'a, R> {
183        unsafe {
184            TensorArg::from_raw_parts_and_size(
185                self.handle,
186                self.strides,
187                self.shape,
188                vectorisation,
189                self.elem_size,
190            )
191        }
192    }
193    /// Create a handle from raw parts.
194    ///
195    /// # Safety
196    ///
197    /// If you provide wrong strides or shapes, it might create undefined behavior caused by
198    /// out-of-bounds reads and writes.
199    pub unsafe fn from_raw_parts(
200        handle: &'a cubecl_runtime::server::Handle,
201        strides: &'a [usize],
202        shape: &'a [usize],
203        elem_size: usize,
204    ) -> Self {
205        Self {
206            handle,
207            strides,
208            shape,
209            elem_size,
210            runtime: PhantomData,
211        }
212    }
213}