cubecl_core/compute/
launcher.rs

1use std::{collections::BTreeMap, marker::PhantomData};
2
3use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
4use crate::{CubeScalar, KernelSettings};
5use crate::{MetadataBuilder, Runtime};
6use cubecl_ir::StorageType;
7use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
8use cubecl_runtime::{
9    client::ComputeClient,
10    kernel::{CubeKernel, KernelTask},
11    server::Bindings,
12};
13
/// Prepare a kernel for [launch](KernelLauncher::launch).
///
/// Arguments are accumulated through the `register_*` methods and turned into
/// [`Bindings`] when the launcher is consumed by [`KernelLauncher::launch`] or
/// [`KernelLauncher::launch_unchecked`].
pub struct KernelLauncher<R: Runtime> {
    /// Buffers, tensor maps and metadata of every registered tensor and array.
    tensors: TensorState<R>,
    /// Raw bytes of every registered scalar, grouped by storage type.
    scalars: ScalarState,
    /// Settings of the kernel. Not read by any method of the launcher itself;
    /// presumably consumed by the code that builds the kernel (TODO confirm).
    pub settings: KernelSettings,
    /// Ties the launcher to a runtime without storing a value of it.
    runtime: PhantomData<R>,
}
21
22impl<R: Runtime> KernelLauncher<R> {
23    /// Register a tensor to be launched.
24    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
25        self.tensors.push_tensor(tensor);
26    }
27
28    /// Register a mapped tensor to be launched.
29    pub fn register_tensor_map(&mut self, tensor: &TensorMapArg<'_, R>) {
30        self.tensors.push_tensor_map(tensor);
31    }
32
33    /// Register an input array to be launched.
34    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
35        self.tensors.push_array(array);
36    }
37
38    /// Register a scalar to be launched.
39    pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
40        self.scalars.push(scalar);
41    }
42
43    /// Register a scalar to be launched from raw data.
44    pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
45        self.scalars.push_raw(bytes, dtype);
46    }
47
48    /// Launch the kernel.
49    #[track_caller]
50    pub fn launch<K: CubeKernel>(
51        self,
52        cube_count: CubeCount,
53        kernel: K,
54        client: &ComputeClient<R>,
55    ) -> Result<(), LaunchError> {
56        let bindings = self.into_bindings();
57        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
58
59        client.launch(kernel, cube_count, bindings)
60    }
61
62    /// Launch the kernel without check bounds.
63    ///
64    /// # Safety
65    ///
66    /// The kernel must not:
67    /// - Contain any out of bounds reads or writes. Doing so is immediate UB.
68    /// - Contain any loops that never terminate. These may be optimized away entirely or cause
69    ///   other unpredictable behaviour.
70    #[track_caller]
71    pub unsafe fn launch_unchecked<K: CubeKernel>(
72        self,
73        cube_count: CubeCount,
74        kernel: K,
75        client: &ComputeClient<R>,
76    ) -> Result<(), LaunchError> {
77        unsafe {
78            let bindings = self.into_bindings();
79            let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
80
81            client.launch_unchecked(kernel, cube_count, bindings)
82        }
83    }
84
85    /// We need to create the bindings in the same order they are defined in the compilation step.
86    ///
87    /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed
88    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
89    /// are registered in the same order they are added. This is why we store the scalar data type
90    /// in the `scalar_order` vector, so that we can register them in the same order.
91    ///
92    /// Also returns an ordered list of constant bindings. The ordering between constants and tensors
93    /// is up to the runtime.
94    fn into_bindings(self) -> Bindings {
95        let mut bindings = Bindings::new();
96
97        self.tensors.register(&mut bindings);
98        self.scalars.register(&mut bindings);
99
100        bindings
101    }
102}
103
/// Handles the tensor state.
///
/// Starts as [`TensorState::Empty`] and is switched to [`TensorState::Some`] the first
/// time a buffer, tensor map or metadata entry is accessed.
pub enum TensorState<R: Runtime> {
    /// No tensor is registered yet.
    Empty,
    /// The registered tensors.
    Some {
        /// Bindings of the registered tensors and arrays, in registration order.
        buffers: Vec<Binding>,
        /// Bindings of the registered tensor maps, in registration order.
        tensor_maps: Vec<TensorMapBinding>,
        /// Builder accumulating the metadata of every registered tensor and array.
        metadata: MetadataBuilder,
        /// Ties the state to a runtime without storing a value of it.
        runtime: PhantomData<R>,
    },
}
116
/// Handles the scalar state of every element type.
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
/// Values are stored as raw bytes, keyed by their storage type; within one type the bytes are
/// kept in the order the scalars were pushed.
#[derive(Default, Clone)]
pub struct ScalarState {
    /// Raw bytes of the pushed scalars, one entry per storage type.
    data: BTreeMap<StorageType, ScalarValues>,
}
124
/// Stores the raw bytes of the scalar args of a single storage type.
///
/// The type itself is not stored here; it is the key under which the bytes are kept.
pub type ScalarValues = Vec<u8>;
127
128impl<R: Runtime> TensorState<R> {
129    fn maybe_init(&mut self) {
130        if matches!(self, TensorState::Empty) {
131            *self = TensorState::Some {
132                buffers: Vec::new(),
133                tensor_maps: Vec::new(),
134                metadata: MetadataBuilder::default(),
135                runtime: PhantomData,
136            };
137        }
138    }
139
140    fn buffers(&mut self) -> &mut Vec<Binding> {
141        self.maybe_init();
142        let TensorState::Some { buffers, .. } = self else {
143            panic!("Should be init");
144        };
145        buffers
146    }
147
148    fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
149        self.maybe_init();
150        let TensorState::Some { tensor_maps, .. } = self else {
151            panic!("Should be init");
152        };
153        tensor_maps
154    }
155
156    fn metadata(&mut self) -> &mut MetadataBuilder {
157        self.maybe_init();
158        let TensorState::Some { metadata, .. } = self else {
159            panic!("Should be init");
160        };
161        metadata
162    }
163
164    /// Push a new input tensor to the state.
165    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
166        if let Some(tensor) = self.process_tensor(tensor) {
167            self.buffers().push(tensor);
168        }
169    }
170
171    fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
172        let (tensor, vectorization) = match tensor {
173            TensorArg::Handle {
174                handle,
175                line_size: vectorization_factor,
176                ..
177            } => (handle, vectorization_factor),
178            TensorArg::Alias { .. } => return None,
179        };
180
181        let elem_size = tensor.elem_size * *vectorization as usize;
182        let buffer_len = tensor.handle.size() / elem_size as u64;
183        let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
184        self.metadata().with_tensor(
185            tensor.strides.len() as u32,
186            buffer_len as u32,
187            len as u32,
188            tensor.shape.iter().map(|it| *it as u32).collect(),
189            tensor.strides.iter().map(|it| *it as u32).collect(),
190        );
191        Some(tensor.handle.clone().binding())
192    }
193
194    /// Push a new input array to the state.
195    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
196        if let Some(tensor) = self.process_array(array) {
197            self.buffers().push(tensor);
198        }
199    }
200
201    fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
202        let (array, vectorization) = match array {
203            ArrayArg::Handle {
204                handle,
205                line_size: vectorization_factor,
206                ..
207            } => (handle, vectorization_factor),
208            ArrayArg::Alias { .. } => return None,
209        };
210
211        let elem_size = array.elem_size * *vectorization as usize;
212        let buffer_len = array.handle.size() / elem_size as u64;
213        self.metadata().with_array(
214            buffer_len as u32,
215            array.length[0] as u32 / *vectorization as u32,
216        );
217        Some(array.handle.clone().binding())
218    }
219
220    /// Push a new tensor to the state.
221    pub fn push_tensor_map(&mut self, map: &TensorMapArg<'_, R>) {
222        let binding = self
223            .process_tensor(&map.tensor)
224            .expect("Can't use alias for TensorMap");
225
226        let map = map.metadata.clone();
227        self.tensor_maps().push(TensorMapBinding { binding, map });
228    }
229
230    fn register(self, bindings_global: &mut Bindings) {
231        if let Self::Some {
232            buffers,
233            tensor_maps,
234            metadata,
235            ..
236        } = self
237        {
238            let metadata = metadata.finish();
239
240            bindings_global.buffers = buffers;
241            bindings_global.tensor_maps = tensor_maps;
242            bindings_global.metadata = metadata;
243        }
244    }
245}
246
247impl ScalarState {
248    /// Add a new scalar value to the state.
249    pub fn push<T: CubeScalar>(&mut self, val: T) {
250        let val = [val];
251        let bytes = T::as_bytes(&val);
252        self.data
253            .entry(T::cube_type())
254            .or_default()
255            .extend(bytes.iter().copied());
256    }
257
258    /// Add a new raw value to the state.
259    pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
260        self.data
261            .entry(dtype)
262            .or_default()
263            .extend(bytes.iter().copied());
264    }
265
266    fn register(&self, bindings: &mut Bindings) {
267        for (ty, values) in self.data.iter() {
268            let len = values.len() / ty.size();
269            let len_u64 = len.div_ceil(size_of::<u64>() / ty.size());
270
271            let mut data = vec![0; len_u64];
272            let slice = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
273            slice[0..values.len()].copy_from_slice(values);
274            bindings
275                .scalars
276                .insert(*ty, ScalarBinding::new(*ty, len, data));
277        }
278    }
279}
280
281impl<R: Runtime> Default for KernelLauncher<R> {
282    fn default() -> Self {
283        Self {
284            tensors: TensorState::Empty,
285            scalars: Default::default(),
286            settings: Default::default(),
287            runtime: PhantomData,
288        }
289    }
290}