cubecl_core/compute/
launcher.rs

1use std::marker::PhantomData;
2
3use crate::MetadataBuilder;
4use crate::compute::KernelTask;
5use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
6use crate::{Kernel, Runtime};
7use crate::{KernelSettings, prelude::CubePrimitive};
8use bytemuck::{AnyBitPattern, NoUninit};
9use cubecl_runtime::server::{Binding, CubeCount, ScalarBinding, TensorMapBinding};
10use cubecl_runtime::{client::ComputeClient, server::Bindings};
11
/// Prepare a kernel for [launch](KernelLauncher::launch).
///
/// Arguments are accumulated through the `register_*` methods: tensors, tensor maps and arrays
/// go into a single [`TensorState`], while scalars are collected in one [`ScalarState`] per
/// element type.
pub struct KernelLauncher<R: Runtime> {
    /// Registered tensors, tensor maps and arrays, along with their metadata.
    tensors: TensorState<R>,
    /// Registered `bf16` scalars.
    scalar_bf16: ScalarState<half::bf16>,
    /// Registered `f16` scalars.
    scalar_f16: ScalarState<half::f16>,
    /// Registered `f32` scalars.
    scalar_f32: ScalarState<f32>,
    /// Registered `f64` scalars.
    scalar_f64: ScalarState<f64>,
    /// Registered `u64` scalars.
    scalar_u64: ScalarState<u64>,
    /// Registered `u32` scalars.
    scalar_u32: ScalarState<u32>,
    /// Registered `u16` scalars.
    scalar_u16: ScalarState<u16>,
    /// Registered `u8` scalars.
    scalar_u8: ScalarState<u8>,
    /// Registered `i64` scalars.
    scalar_i64: ScalarState<i64>,
    /// Registered `i32` scalars.
    scalar_i32: ScalarState<i32>,
    /// Registered `i16` scalars.
    scalar_i16: ScalarState<i16>,
    /// Registered `i8` scalars.
    scalar_i8: ScalarState<i8>,
    /// Kernel settings; publicly accessible so callers can adjust them before launching.
    pub settings: KernelSettings,
    /// Zero-sized marker for the runtime type `R`.
    runtime: PhantomData<R>,
}
30
31impl<R: Runtime> KernelLauncher<R> {
32    /// Register a tensor to be launched.
33    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
34        self.tensors.push_tensor(tensor);
35    }
36
37    /// Register a mapped tensor to be launched.
38    pub fn register_tensor_map(&mut self, tensor: &TensorMapArg<'_, R>) {
39        self.tensors.push_tensor_map(tensor);
40    }
41
42    /// Register an input array to be launched.
43    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
44        self.tensors.push_array(array);
45    }
46
47    /// Register a u8 scalar to be launched.
48    pub fn register_u8(&mut self, scalar: u8) {
49        self.scalar_u8.push(scalar);
50    }
51
52    /// Register a u16 scalar to be launched.
53    pub fn register_u16(&mut self, scalar: u16) {
54        self.scalar_u16.push(scalar);
55    }
56
57    /// Register a u32 scalar to be launched.
58    pub fn register_u32(&mut self, scalar: u32) {
59        self.scalar_u32.push(scalar);
60    }
61
62    /// Register a u64 scalar to be launched.
63    pub fn register_u64(&mut self, scalar: u64) {
64        self.scalar_u64.push(scalar);
65    }
66
67    /// Register a i8 scalar to be launched.
68    pub fn register_i8(&mut self, scalar: i8) {
69        self.scalar_i8.push(scalar);
70    }
71
72    /// Register a i16 scalar to be launched.
73    pub fn register_i16(&mut self, scalar: i16) {
74        self.scalar_i16.push(scalar);
75    }
76
77    /// Register a i32 scalar to be launched.
78    pub fn register_i32(&mut self, scalar: i32) {
79        self.scalar_i32.push(scalar);
80    }
81
82    /// Register a i64 scalar to be launched.
83    pub fn register_i64(&mut self, scalar: i64) {
84        self.scalar_i64.push(scalar);
85    }
86
87    /// Register a bf16 scalar to be launched.
88    pub fn register_bf16(&mut self, scalar: half::bf16) {
89        self.scalar_bf16.push(scalar);
90    }
91
92    /// Register a f16 scalar to be launched.
93    pub fn register_f16(&mut self, scalar: half::f16) {
94        self.scalar_f16.push(scalar);
95    }
96
97    /// Register a f32 scalar to be launched.
98    pub fn register_f32(&mut self, scalar: f32) {
99        self.scalar_f32.push(scalar);
100    }
101
102    /// Register a f64 scalar to be launched.
103    pub fn register_f64(&mut self, scalar: f64) {
104        self.scalar_f64.push(scalar);
105    }
106
107    /// Launch the kernel.
108    pub fn launch<K: Kernel>(
109        self,
110        cube_count: CubeCount,
111        kernel: K,
112        client: &ComputeClient<R::Server, R::Channel>,
113    ) {
114        let bindings = self.into_bindings();
115        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
116
117        client.execute(kernel, cube_count, bindings);
118    }
119
120    /// Launch the kernel without check bounds.
121    ///
122    /// # Safety
123    ///
124    /// The kernel must not:
125    /// - Contain any out of bounds reads or writes. Doing so is immediate UB.
126    /// - Contain any loops that never terminate. These may be optimized away entirely or cause
127    ///   other unpredictable behaviour.
128    pub unsafe fn launch_unchecked<K: Kernel>(
129        self,
130        cube_count: CubeCount,
131        kernel: K,
132        client: &ComputeClient<R::Server, R::Channel>,
133    ) {
134        unsafe {
135            let bindings = self.into_bindings();
136            let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
137
138            client.execute_unchecked(kernel, cube_count, bindings);
139        }
140    }
141
142    /// We need to create the bindings in the same order they are defined in the compilation step.
143    ///
144    /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed
145    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
146    /// are registered in the same order they are added. This is why we store the scalar data type
147    /// in the `scalar_order` vector, so that we can register them in the same order.
148    ///
149    /// Also returns an ordered list of constant bindings. The ordering between constants and tensors
150    /// is up to the runtime.
151    fn into_bindings(self) -> Bindings {
152        let mut bindings = Bindings::new();
153
154        self.tensors.register(&mut bindings);
155
156        self.scalar_u8.register(&mut bindings);
157        self.scalar_u16.register(&mut bindings);
158        self.scalar_u32.register(&mut bindings);
159        self.scalar_u64.register(&mut bindings);
160        self.scalar_i8.register(&mut bindings);
161        self.scalar_i16.register(&mut bindings);
162        self.scalar_i32.register(&mut bindings);
163        self.scalar_i64.register(&mut bindings);
164        self.scalar_f16.register(&mut bindings);
165        self.scalar_bf16.register(&mut bindings);
166        self.scalar_f32.register(&mut bindings);
167        self.scalar_f64.register(&mut bindings);
168
169        bindings
170    }
171}
172
/// Handles the tensor state.
///
/// The state is switched from [`TensorState::Empty`] to [`TensorState::Some`] lazily, the first
/// time one of its collections is accessed while registering an argument.
pub enum TensorState<R: Runtime> {
    /// No tensor is registered yet.
    Empty,
    /// The registered tensors.
    Some {
        /// Bindings of the registered tensors and arrays, in registration order.
        buffers: Vec<Binding>,
        /// Bindings of the registered tensor maps, each paired with its map metadata.
        tensor_maps: Vec<TensorMapBinding>,
        /// Builder collecting the metadata recorded for every registered tensor and array.
        metadata: MetadataBuilder,
        /// Zero-sized marker for the runtime type `R`.
        runtime: PhantomData<R>,
    },
}
185
/// Handles the scalar state of an element type
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
/// Each element type gets its own `ScalarState`, which holds the values in the order they were
/// pushed.
pub enum ScalarState<T> {
    /// No scalar of that type is registered yet.
    Empty,
    /// The registered scalars.
    Some(Vec<T>),
}
195
196impl<R: Runtime> TensorState<R> {
197    fn maybe_init(&mut self) {
198        if matches!(self, TensorState::Empty) {
199            *self = TensorState::Some {
200                buffers: Vec::new(),
201                tensor_maps: Vec::new(),
202                metadata: MetadataBuilder::default(),
203                runtime: PhantomData,
204            };
205        }
206    }
207
208    fn buffers(&mut self) -> &mut Vec<Binding> {
209        self.maybe_init();
210        let TensorState::Some { buffers, .. } = self else {
211            panic!("Should be init");
212        };
213        buffers
214    }
215
216    fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
217        self.maybe_init();
218        let TensorState::Some { tensor_maps, .. } = self else {
219            panic!("Should be init");
220        };
221        tensor_maps
222    }
223
224    fn metadata(&mut self) -> &mut MetadataBuilder {
225        self.maybe_init();
226        let TensorState::Some { metadata, .. } = self else {
227            panic!("Should be init");
228        };
229        metadata
230    }
231
232    /// Push a new input tensor to the state.
233    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
234        if let Some(tensor) = self.process_tensor(tensor) {
235            self.buffers().push(tensor);
236        }
237    }
238
239    fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
240        let (tensor, vectorization) = match tensor {
241            TensorArg::Handle {
242                handle,
243                vectorization_factor,
244                ..
245            } => (handle, vectorization_factor),
246            TensorArg::Alias { .. } => return None,
247        };
248
249        let elem_size = tensor.elem_size * *vectorization as usize;
250        let buffer_len = tensor.handle.size() / elem_size as u64;
251        let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
252        self.metadata().with_tensor(
253            tensor.strides.len() as u32,
254            buffer_len as u32,
255            len as u32,
256            tensor.shape.iter().map(|it| *it as u32).collect(),
257            tensor.strides.iter().map(|it| *it as u32).collect(),
258        );
259        Some(tensor.handle.clone().binding())
260    }
261
262    /// Push a new input array to the state.
263    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
264        if let Some(tensor) = self.process_array(array) {
265            self.buffers().push(tensor);
266        }
267    }
268
269    fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
270        let (array, vectorization) = match array {
271            ArrayArg::Handle {
272                handle,
273                vectorization_factor,
274                ..
275            } => (handle, vectorization_factor),
276            ArrayArg::Alias { .. } => return None,
277        };
278
279        let elem_size = array.elem_size * *vectorization as usize;
280        let buffer_len = array.handle.size() / elem_size as u64;
281        self.metadata()
282            .with_array(buffer_len as u32, array.length[0] as u32);
283        Some(array.handle.clone().binding())
284    }
285
286    /// Push a new tensor to the state.
287    pub fn push_tensor_map(&mut self, map: &TensorMapArg<'_, R>) {
288        let binding = self
289            .process_tensor(&map.tensor)
290            .expect("Can't use alias for TensorMap");
291
292        let map = map.metadata.clone();
293        self.tensor_maps().push(TensorMapBinding { binding, map });
294    }
295
296    fn register(self, bindings_global: &mut Bindings) {
297        if let Self::Some {
298            buffers,
299            tensor_maps,
300            metadata,
301            ..
302        } = self
303        {
304            let metadata = metadata.finish();
305
306            bindings_global.buffers = buffers;
307            bindings_global.tensor_maps = tensor_maps;
308            bindings_global.metadata = metadata;
309        }
310    }
311}
312
313impl<T: NoUninit + AnyBitPattern + CubePrimitive> ScalarState<T> {
314    /// Add a new scalar value to the state.
315    pub fn push(&mut self, val: T) {
316        match self {
317            ScalarState::Empty => *self = Self::Some(vec![val]),
318            ScalarState::Some(values) => values.push(val),
319        }
320    }
321
322    fn register(&self, bindings: &mut Bindings) {
323        if let ScalarState::Some(values) = self {
324            let len = values.len();
325            let len_u64 = len.div_ceil(size_of::<u64>() / size_of::<T>());
326            let mut data = vec![0; len_u64];
327            let slice = bytemuck::cast_slice_mut::<u64, T>(&mut data);
328            slice[0..values.len()].copy_from_slice(values);
329            let elem = T::as_elem_native_unchecked();
330            bindings
331                .scalars
332                .insert(elem, ScalarBinding::new(elem, len, data));
333        }
334    }
335}
336
337impl<R: Runtime> Default for KernelLauncher<R> {
338    fn default() -> Self {
339        Self {
340            tensors: TensorState::Empty,
341            scalar_bf16: ScalarState::Empty,
342            scalar_f16: ScalarState::Empty,
343            scalar_f32: ScalarState::Empty,
344            scalar_f64: ScalarState::Empty,
345            scalar_u64: ScalarState::Empty,
346            scalar_u32: ScalarState::Empty,
347            scalar_u16: ScalarState::Empty,
348            scalar_u8: ScalarState::Empty,
349            scalar_i64: ScalarState::Empty,
350            scalar_i32: ScalarState::Empty,
351            scalar_i16: ScalarState::Empty,
352            scalar_i8: ScalarState::Empty,
353            settings: Default::default(),
354            runtime: PhantomData,
355        }
356    }
357}