cubecl_core/compute/
launcher.rs

1use std::marker::PhantomData;
2
3use crate::prelude::{ArrayArg, TensorArg};
4use crate::KernelSettings;
5use crate::{compute::KernelTask, ir::UIntKind};
6use crate::{
7    ir::{Elem, FloatKind, IntKind},
8    MetadataBuilder,
9};
10use crate::{Kernel, Runtime};
11use bytemuck::NoUninit;
12use cubecl_runtime::client::ComputeClient;
13use cubecl_runtime::server::{Binding, CubeCount};
14
/// Prepare a kernel for [launch](KernelLauncher::launch).
///
/// Arguments are accumulated through the `register_*` methods and converted into
/// bindings when [launch](KernelLauncher::launch) or
/// [launch_unchecked](KernelLauncher::launch_unchecked) consumes the launcher.
pub struct KernelLauncher<R: Runtime> {
    // Registered tensors and arrays, together with their metadata.
    tensors: TensorState<R>,
    // One scalar group per element type: every scalar of a given type is appended
    // to the same group, in registration order.
    scalar_bf16: ScalarState<half::bf16>,
    scalar_f16: ScalarState<half::f16>,
    scalar_f32: ScalarState<f32>,
    scalar_f64: ScalarState<f64>,
    scalar_u64: ScalarState<u64>,
    scalar_u32: ScalarState<u32>,
    scalar_u16: ScalarState<u16>,
    scalar_u8: ScalarState<u8>,
    scalar_i64: ScalarState<i64>,
    scalar_i32: ScalarState<i32>,
    scalar_i16: ScalarState<i16>,
    scalar_i8: ScalarState<i8>,
    // Element types in the order each one was first registered; a type appears at
    // most once. Determines the order in which the scalar groups become bindings.
    scalar_order: Vec<Elem>,
    /// Settings of the kernel.
    // NOTE(review): not read anywhere in this file; presumably consumed by callers.
    pub settings: KernelSettings,
    // Zero-sized marker for the runtime type parameter.
    runtime: PhantomData<R>,
}
34
35impl<R: Runtime> KernelLauncher<R> {
36    /// Register a tensor to be launched.
37    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
38        self.tensors.push_tensor(tensor);
39    }
40
41    /// Register an array to be launched.
42    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
43        self.tensors.push_array(array);
44    }
45
46    /// Register a u8 scalar to be launched.
47    pub fn register_u8(&mut self, scalar: u8) {
48        self.register_scalar(Elem::UInt(UIntKind::U8));
49        self.scalar_u8.push(scalar);
50    }
51
52    /// Register a u16 scalar to be launched.
53    pub fn register_u16(&mut self, scalar: u16) {
54        self.register_scalar(Elem::UInt(UIntKind::U16));
55        self.scalar_u16.push(scalar);
56    }
57
58    /// Register a u32 scalar to be launched.
59    pub fn register_u32(&mut self, scalar: u32) {
60        self.register_scalar(Elem::UInt(UIntKind::U32));
61        self.scalar_u32.push(scalar);
62    }
63
64    /// Register a u64 scalar to be launched.
65    pub fn register_u64(&mut self, scalar: u64) {
66        self.register_scalar(Elem::UInt(UIntKind::U64));
67        self.scalar_u64.push(scalar);
68    }
69
70    /// Register a i8 scalar to be launched.
71    pub fn register_i8(&mut self, scalar: i8) {
72        self.register_scalar(Elem::Int(IntKind::I8));
73        self.scalar_i8.push(scalar);
74    }
75
76    /// Register a i16 scalar to be launched.
77    pub fn register_i16(&mut self, scalar: i16) {
78        self.register_scalar(Elem::Int(IntKind::I16));
79        self.scalar_i16.push(scalar);
80    }
81
82    /// Register a i32 scalar to be launched.
83    pub fn register_i32(&mut self, scalar: i32) {
84        self.register_scalar(Elem::Int(IntKind::I32));
85        self.scalar_i32.push(scalar);
86    }
87
88    /// Register a i64 scalar to be launched.
89    pub fn register_i64(&mut self, scalar: i64) {
90        self.register_scalar(Elem::Int(IntKind::I64));
91        self.scalar_i64.push(scalar);
92    }
93
94    /// Register a bf16 scalar to be launched.
95    pub fn register_bf16(&mut self, scalar: half::bf16) {
96        self.register_scalar(Elem::Float(FloatKind::BF16));
97        self.scalar_bf16.push(scalar);
98    }
99
100    /// Register a f16 scalar to be launched.
101    pub fn register_f16(&mut self, scalar: half::f16) {
102        self.register_scalar(Elem::Float(FloatKind::F16));
103        self.scalar_f16.push(scalar);
104    }
105
106    /// Register a f32 scalar to be launched.
107    pub fn register_f32(&mut self, scalar: f32) {
108        self.register_scalar(Elem::Float(FloatKind::F32));
109        self.scalar_f32.push(scalar);
110    }
111
112    /// Register a f64 scalar to be launched.
113    pub fn register_f64(&mut self, scalar: f64) {
114        self.register_scalar(Elem::Float(FloatKind::F64));
115        self.scalar_f64.push(scalar);
116    }
117
118    /// Launch the kernel.
119    pub fn launch<K: Kernel>(
120        self,
121        cube_count: CubeCount,
122        kernel: K,
123        client: &ComputeClient<R::Server, R::Channel>,
124    ) {
125        let bindings = self.into_bindings(client);
126
127        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
128
129        client.execute(kernel, cube_count, bindings);
130    }
131
132    /// Launch the kernel without check bounds.
133    ///
134    /// # Safety
135    ///
136    /// Out-of-bounds reads and writes can happen.
137    pub unsafe fn launch_unchecked<K: Kernel>(
138        self,
139        cube_count: CubeCount,
140        kernel: K,
141        client: &ComputeClient<R::Server, R::Channel>,
142    ) {
143        let bindings = self.into_bindings(client);
144
145        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
146
147        client.execute_unchecked(kernel, cube_count, bindings);
148    }
149
150    /// We need to create the bindings in the same order they are defined in the compilation step.
151    ///
152    /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed
153    /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars
154    /// are registered in the same order they are added. This is why we store the scalar data type
155    /// in the `scalar_order` vector, so that we can register them in the same order.
156    fn into_bindings(mut self, client: &ComputeClient<R::Server, R::Channel>) -> Vec<Binding> {
157        let mut bindings = Vec::new();
158
159        self.tensors.register(client, &mut bindings);
160
161        for elem in self.scalar_order.drain(..) {
162            match elem {
163                Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
164                    FloatKind::F16 => self.scalar_f16.register::<R>(client, &mut bindings),
165                    FloatKind::BF16 => self.scalar_bf16.register::<R>(client, &mut bindings),
166                    FloatKind::TF32 => self.scalar_f32.register::<R>(client, &mut bindings),
167                    FloatKind::Flex32 => self.scalar_f32.register::<R>(client, &mut bindings),
168                    FloatKind::F32 => self.scalar_f32.register::<R>(client, &mut bindings),
169                    FloatKind::F64 => self.scalar_f64.register::<R>(client, &mut bindings),
170                },
171                Elem::Int(kind) => match kind {
172                    IntKind::I8 => self.scalar_i8.register::<R>(client, &mut bindings),
173                    IntKind::I16 => self.scalar_i16.register::<R>(client, &mut bindings),
174                    IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
175                    IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
176                },
177                Elem::AtomicInt(kind) => match kind {
178                    IntKind::I8 => self.scalar_i8.register::<R>(client, &mut bindings),
179                    IntKind::I16 => self.scalar_i16.register::<R>(client, &mut bindings),
180                    IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
181                    IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
182                },
183                Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
184                    UIntKind::U8 => self.scalar_u8.register::<R>(client, &mut bindings),
185                    UIntKind::U16 => self.scalar_u16.register::<R>(client, &mut bindings),
186                    UIntKind::U32 => self.scalar_u32.register::<R>(client, &mut bindings),
187                    UIntKind::U64 => self.scalar_u64.register::<R>(client, &mut bindings),
188                },
189                Elem::Bool => panic!("Bool can't be passed as bindings."),
190            }
191        }
192
193        bindings
194    }
195
196    fn register_scalar(&mut self, elem: Elem) {
197        if !self.scalar_order.contains(&elem) {
198            self.scalar_order.push(elem);
199        }
200    }
201}
202
/// Handles the tensor state.
///
/// Arrays registered through `push_array` are stored in the same state.
pub enum TensorState<R: Runtime> {
    /// No tensor is registered yet.
    Empty,
    /// The registered tensors.
    Some {
        /// One binding per registered tensor or array, in registration order.
        bindings: Vec<Binding>,
        /// Builder collecting the metadata of every registered tensor and array.
        metadata: MetadataBuilder,
        /// Zero-sized marker for the runtime type parameter.
        runtime: PhantomData<R>,
    },
}
214
/// Handles the scalar state of an element type
///
/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device.
pub enum ScalarState<T> {
    /// No scalar of that type is registered yet.
    Empty,
    /// The registered scalars, in the order they were pushed.
    Some(Vec<T>),
}
224
225impl<R: Runtime> TensorState<R> {
226    /// Push a new tensor to the state.
227    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
228        let (tensor, vectorization) = match tensor {
229            TensorArg::Handle {
230                handle,
231                vectorization_factor,
232                ..
233            } => (handle, vectorization_factor),
234            TensorArg::Alias { .. } => return,
235        };
236
237        if let TensorState::Empty = self {
238            *self = TensorState::Some {
239                bindings: Vec::with_capacity(1),
240                metadata: MetadataBuilder::default(),
241                runtime: PhantomData,
242            };
243        };
244
245        let TensorState::Some {
246            bindings, metadata, ..
247        } = self
248        else {
249            panic!("Should be init")
250        };
251
252        let elem_size = tensor.elem_size * *vectorization as usize;
253        let buffer_len = tensor.handle.size() / elem_size as u64;
254        let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
255        bindings.push(tensor.handle.clone().binding());
256        metadata.with_tensor(
257            tensor.strides.len() as u32,
258            buffer_len as u32,
259            len as u32,
260            tensor.shape.iter().map(|it| *it as u32).collect(),
261            tensor.strides.iter().map(|it| *it as u32).collect(),
262        );
263    }
264
265    /// Push a new array to the state.
266    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
267        let (array, vectorization) = match array {
268            ArrayArg::Handle {
269                handle,
270                vectorization_factor,
271                ..
272            } => (handle, vectorization_factor),
273            ArrayArg::Alias { .. } => return,
274        };
275
276        if let TensorState::Empty = self {
277            *self = TensorState::Some {
278                bindings: Vec::with_capacity(1),
279                metadata: MetadataBuilder::default(),
280                runtime: PhantomData,
281            };
282        };
283
284        let TensorState::Some {
285            bindings, metadata, ..
286        } = self
287        else {
288            panic!("Should be init")
289        };
290
291        let elem_size = array.elem_size * *vectorization as usize;
292        let buffer_len = array.handle.size() / elem_size as u64;
293        bindings.push(array.handle.clone().binding());
294        metadata.with_array(buffer_len as u32, array.length[0] as u32);
295    }
296
297    fn register(
298        self,
299        client: &ComputeClient<R::Server, R::Channel>,
300        bindings_global: &mut Vec<Binding>,
301    ) {
302        if let Self::Some {
303            bindings,
304            metadata,
305            runtime: _,
306        } = self
307        {
308            let metadata = metadata.finish();
309
310            bindings_global.extend(bindings);
311            bindings_global.push(client.create(bytemuck::cast_slice(&metadata)).binding());
312        }
313    }
314}
315
316impl<T: NoUninit> ScalarState<T> {
317    /// Add a new scalar value to the state.
318    pub fn push(&mut self, val: T) {
319        match self {
320            ScalarState::Empty => *self = Self::Some(vec![val]),
321            ScalarState::Some(values) => values.push(val),
322        }
323    }
324
325    fn register<R: Runtime>(
326        &self,
327        client: &ComputeClient<R::Server, R::Channel>,
328        bindings: &mut Vec<Binding>,
329    ) {
330        match self {
331            ScalarState::Empty => (),
332            ScalarState::Some(values) => {
333                let handle = client.create(bytemuck::cast_slice(values));
334                bindings.push(handle.binding());
335            }
336        }
337    }
338}
339
340impl<R: Runtime> Default for KernelLauncher<R> {
341    fn default() -> Self {
342        Self {
343            tensors: TensorState::Empty,
344            scalar_bf16: ScalarState::Empty,
345            scalar_f16: ScalarState::Empty,
346            scalar_f32: ScalarState::Empty,
347            scalar_f64: ScalarState::Empty,
348            scalar_u64: ScalarState::Empty,
349            scalar_u32: ScalarState::Empty,
350            scalar_u16: ScalarState::Empty,
351            scalar_u8: ScalarState::Empty,
352            scalar_i64: ScalarState::Empty,
353            scalar_i32: ScalarState::Empty,
354            scalar_i16: ScalarState::Empty,
355            scalar_i8: ScalarState::Empty,
356            scalar_order: Vec::new(),
357            settings: Default::default(),
358            runtime: PhantomData,
359        }
360    }
361}