cubecl_core/compute/
launcher.rs

1use std::marker::PhantomData;
2
3use crate::MetadataBuilder;
4use crate::Runtime;
5use crate::compute::KernelTask;
6use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
7use crate::{KernelSettings, prelude::CubePrimitive};
8use bytemuck::{AnyBitPattern, NoUninit};
9use cubecl_runtime::server::{Binding, CubeCount, ScalarBinding, TensorMapBinding};
10use cubecl_runtime::{client::ComputeClient, server::Bindings};
11
12use super::CubeKernel;
13
14/// Prepare a kernel for [launch](KernelLauncher::launch).
15pub struct KernelLauncher<R: Runtime> {
16    tensors: TensorState<R>,
17    scalar_bf16: ScalarState<half::bf16>,
18    scalar_f16: ScalarState<half::f16>,
19    scalar_f32: ScalarState<f32>,
20    scalar_f64: ScalarState<f64>,
21    scalar_u64: ScalarState<u64>,
22    scalar_u32: ScalarState<u32>,
23    scalar_u16: ScalarState<u16>,
24    scalar_u8: ScalarState<u8>,
25    scalar_i64: ScalarState<i64>,
26    scalar_i32: ScalarState<i32>,
27    scalar_i16: ScalarState<i16>,
28    scalar_i8: ScalarState<i8>,
29    pub settings: KernelSettings,
30    runtime: PhantomData<R>,
31}
32
33impl<R: Runtime> KernelLauncher<R> {
34    /// Register a tensor to be launched.
35    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
36        self.tensors.push_tensor(tensor);
37    }
38
39    /// Register a mapped tensor to be launched.
40    pub fn register_tensor_map(&mut self, tensor: &TensorMapArg<'_, R>) {
41        self.tensors.push_tensor_map(tensor);
42    }
43
44    /// Register an input array to be launched.
45    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
46        self.tensors.push_array(array);
47    }
48
49    /// Register a u8 scalar to be launched.
50    pub fn register_u8(&mut self, scalar: u8) {
51        self.scalar_u8.push(scalar);
52    }
53
54    /// Register a u16 scalar to be launched.
55    pub fn register_u16(&mut self, scalar: u16) {
56        self.scalar_u16.push(scalar);
57    }
58
59    /// Register a u32 scalar to be launched.
60    pub fn register_u32(&mut self, scalar: u32) {
61        self.scalar_u32.push(scalar);
62    }
63
64    /// Register a u64 scalar to be launched.
65    pub fn register_u64(&mut self, scalar: u64) {
66        self.scalar_u64.push(scalar);
67    }
68
69    /// Register a i8 scalar to be launched.
70    pub fn register_i8(&mut self, scalar: i8) {
71        self.scalar_i8.push(scalar);
72    }
73
74    /// Register a i16 scalar to be launched.
75    pub fn register_i16(&mut self, scalar: i16) {
76        self.scalar_i16.push(scalar);
77    }
78
79    /// Register a i32 scalar to be launched.
80    pub fn register_i32(&mut self, scalar: i32) {
81        self.scalar_i32.push(scalar);
82    }
83
84    /// Register a i64 scalar to be launched.
85    pub fn register_i64(&mut self, scalar: i64) {
86        self.scalar_i64.push(scalar);
87    }
88
89    /// Register a bf16 scalar to be launched.
90    pub fn register_bf16(&mut self, scalar: half::bf16) {
91        self.scalar_bf16.push(scalar);
92    }
93
94    /// Register a f16 scalar to be launched.
95    pub fn register_f16(&mut self, scalar: half::f16) {
96        self.scalar_f16.push(scalar);
97    }
98
99    /// Register a f32 scalar to be launched.
100    pub fn register_f32(&mut self, scalar: f32) {
101        self.scalar_f32.push(scalar);
102    }
103
104    /// Register a f64 scalar to be launched.
105    pub fn register_f64(&mut self, scalar: f64) {
106        self.scalar_f64.push(scalar);
107    }
108
109    /// Launch the kernel.
110    #[track_caller]
111    pub fn launch<K: CubeKernel>(
112        self,
113        cube_count: CubeCount,
114        kernel: K,
115        client: &ComputeClient<R::Server, R::Channel>,
116    ) {
117        let bindings = self.into_bindings();
118        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
119
120        client.execute(kernel, cube_count, bindings);
121    }
122
123    /// Launch the kernel without check bounds.
124    ///
125    /// # Safety
126    ///
127    /// The kernel must not:
128    /// - Contain any out of bounds reads or writes. Doing so is immediate UB.
129    /// - Contain any loops that never terminate. These may be optimized away entirely or cause
130    ///   other unpredictable behaviour.
131    #[track_caller]
132    pub unsafe fn launch_unchecked<K: CubeKernel>(
133        self,
134        cube_count: CubeCount,
135        kernel: K,
136        client: &ComputeClient<R::Server, R::Channel>,
137    ) {
138        unsafe {
139            let bindings = self.into_bindings();
140            let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
141
142            client.execute_unchecked(kernel, cube_count, bindings);
143        }
144    }
145
146    /// We need to create the bindings in the same order they are defined in the compilation step.
147    ///
148    /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed
149    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
150    /// are registered in the same order they are added. This is why we store the scalar data type
151    /// in the `scalar_order` vector, so that we can register them in the same order.
152    ///
153    /// Also returns an ordered list of constant bindings. The ordering between constants and tensors
154    /// is up to the runtime.
155    fn into_bindings(self) -> Bindings {
156        let mut bindings = Bindings::new();
157
158        self.tensors.register(&mut bindings);
159
160        self.scalar_u8.register(&mut bindings);
161        self.scalar_u16.register(&mut bindings);
162        self.scalar_u32.register(&mut bindings);
163        self.scalar_u64.register(&mut bindings);
164        self.scalar_i8.register(&mut bindings);
165        self.scalar_i16.register(&mut bindings);
166        self.scalar_i32.register(&mut bindings);
167        self.scalar_i64.register(&mut bindings);
168        self.scalar_f16.register(&mut bindings);
169        self.scalar_bf16.register(&mut bindings);
170        self.scalar_f32.register(&mut bindings);
171        self.scalar_f64.register(&mut bindings);
172
173        bindings
174    }
175}
176
177/// Handles the tensor state.
178pub enum TensorState<R: Runtime> {
179    /// No tensor is registered yet.
180    Empty,
181    /// The registered tensors.
182    Some {
183        buffers: Vec<Binding>,
184        tensor_maps: Vec<TensorMapBinding>,
185        metadata: MetadataBuilder,
186        runtime: PhantomData<R>,
187    },
188}
189
/// Handles the scalar state of an element type.
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute
/// device.
///
/// The default state is [`ScalarState::Empty`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum ScalarState<T> {
    /// No scalar of that type is registered yet.
    #[default]
    Empty,
    /// The registered scalars, in registration order.
    Some(Vec<T>),
}
199
200impl<R: Runtime> TensorState<R> {
201    fn maybe_init(&mut self) {
202        if matches!(self, TensorState::Empty) {
203            *self = TensorState::Some {
204                buffers: Vec::new(),
205                tensor_maps: Vec::new(),
206                metadata: MetadataBuilder::default(),
207                runtime: PhantomData,
208            };
209        }
210    }
211
212    fn buffers(&mut self) -> &mut Vec<Binding> {
213        self.maybe_init();
214        let TensorState::Some { buffers, .. } = self else {
215            panic!("Should be init");
216        };
217        buffers
218    }
219
220    fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
221        self.maybe_init();
222        let TensorState::Some { tensor_maps, .. } = self else {
223            panic!("Should be init");
224        };
225        tensor_maps
226    }
227
228    fn metadata(&mut self) -> &mut MetadataBuilder {
229        self.maybe_init();
230        let TensorState::Some { metadata, .. } = self else {
231            panic!("Should be init");
232        };
233        metadata
234    }
235
236    /// Push a new input tensor to the state.
237    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
238        if let Some(tensor) = self.process_tensor(tensor) {
239            self.buffers().push(tensor);
240        }
241    }
242
243    fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
244        let (tensor, vectorization) = match tensor {
245            TensorArg::Handle {
246                handle,
247                vectorization_factor,
248                ..
249            } => (handle, vectorization_factor),
250            TensorArg::Alias { .. } => return None,
251        };
252
253        let elem_size = tensor.elem_size * *vectorization as usize;
254        let buffer_len = tensor.handle.size() / elem_size as u64;
255        let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
256        self.metadata().with_tensor(
257            tensor.strides.len() as u32,
258            buffer_len as u32,
259            len as u32,
260            tensor.shape.iter().map(|it| *it as u32).collect(),
261            tensor.strides.iter().map(|it| *it as u32).collect(),
262        );
263        Some(tensor.handle.clone().binding())
264    }
265
266    /// Push a new input array to the state.
267    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
268        if let Some(tensor) = self.process_array(array) {
269            self.buffers().push(tensor);
270        }
271    }
272
273    fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
274        let (array, vectorization) = match array {
275            ArrayArg::Handle {
276                handle,
277                vectorization_factor,
278                ..
279            } => (handle, vectorization_factor),
280            ArrayArg::Alias { .. } => return None,
281        };
282
283        let elem_size = array.elem_size * *vectorization as usize;
284        let buffer_len = array.handle.size() / elem_size as u64;
285        self.metadata()
286            .with_array(buffer_len as u32, array.length[0] as u32);
287        Some(array.handle.clone().binding())
288    }
289
290    /// Push a new tensor to the state.
291    pub fn push_tensor_map(&mut self, map: &TensorMapArg<'_, R>) {
292        let binding = self
293            .process_tensor(&map.tensor)
294            .expect("Can't use alias for TensorMap");
295
296        let map = map.metadata.clone();
297        self.tensor_maps().push(TensorMapBinding { binding, map });
298    }
299
300    fn register(self, bindings_global: &mut Bindings) {
301        if let Self::Some {
302            buffers,
303            tensor_maps,
304            metadata,
305            ..
306        } = self
307        {
308            let metadata = metadata.finish();
309
310            bindings_global.buffers = buffers;
311            bindings_global.tensor_maps = tensor_maps;
312            bindings_global.metadata = metadata;
313        }
314    }
315}
316
317impl<T: NoUninit + AnyBitPattern + CubePrimitive> ScalarState<T> {
318    /// Add a new scalar value to the state.
319    pub fn push(&mut self, val: T) {
320        match self {
321            ScalarState::Empty => *self = Self::Some(vec![val]),
322            ScalarState::Some(values) => values.push(val),
323        }
324    }
325
326    fn register(&self, bindings: &mut Bindings) {
327        if let ScalarState::Some(values) = self {
328            let len = values.len();
329            let len_u64 = len.div_ceil(size_of::<u64>() / size_of::<T>());
330            let mut data = vec![0; len_u64];
331            let slice = bytemuck::cast_slice_mut::<u64, T>(&mut data);
332            slice[0..values.len()].copy_from_slice(values);
333            let elem = T::as_elem_native_unchecked();
334            bindings
335                .scalars
336                .insert(elem, ScalarBinding::new(elem, len, data));
337        }
338    }
339}
340
341impl<R: Runtime> Default for KernelLauncher<R> {
342    fn default() -> Self {
343        Self {
344            tensors: TensorState::Empty,
345            scalar_bf16: ScalarState::Empty,
346            scalar_f16: ScalarState::Empty,
347            scalar_f32: ScalarState::Empty,
348            scalar_f64: ScalarState::Empty,
349            scalar_u64: ScalarState::Empty,
350            scalar_u32: ScalarState::Empty,
351            scalar_u16: ScalarState::Empty,
352            scalar_u8: ScalarState::Empty,
353            scalar_i64: ScalarState::Empty,
354            scalar_i32: ScalarState::Empty,
355            scalar_i16: ScalarState::Empty,
356            scalar_i8: ScalarState::Empty,
357            settings: Default::default(),
358            runtime: PhantomData,
359        }
360    }
361}