1use std::marker::PhantomData;
2
3use crate::prelude::{ArrayArg, TensorArg};
4use crate::KernelSettings;
5use crate::{compute::KernelTask, ir::UIntKind};
6use crate::{
7 ir::{Elem, FloatKind, IntKind},
8 MetadataBuilder,
9};
10use crate::{Kernel, Runtime};
11use bytemuck::NoUninit;
12use cubecl_runtime::client::ComputeClient;
13use cubecl_runtime::server::{Binding, CubeCount};
14
/// Accumulates the arguments of a kernel launch for runtime `R`: tensor/array
/// handles and scalar values, the latter stored in one list per element type.
pub struct KernelLauncher<R: Runtime> {
    /// Tensor and array handles pushed so far, together with their metadata.
    tensors: TensorState<R>,
    /// Registered `bf16` scalars.
    scalar_bf16: ScalarState<half::bf16>,
    /// Registered `f16` scalars.
    scalar_f16: ScalarState<half::f16>,
    /// Registered `f32` scalars.
    scalar_f32: ScalarState<f32>,
    /// Registered `f64` scalars.
    scalar_f64: ScalarState<f64>,
    /// Registered `u64` scalars.
    scalar_u64: ScalarState<u64>,
    /// Registered `u32` scalars.
    scalar_u32: ScalarState<u32>,
    /// Registered `u16` scalars.
    scalar_u16: ScalarState<u16>,
    /// Registered `u8` scalars.
    scalar_u8: ScalarState<u8>,
    /// Registered `i64` scalars.
    scalar_i64: ScalarState<i64>,
    /// Registered `i32` scalars.
    scalar_i32: ScalarState<i32>,
    /// Registered `i16` scalars.
    scalar_i16: ScalarState<i16>,
    /// Registered `i8` scalars.
    scalar_i8: ScalarState<i8>,
    /// Element types that have at least one registered scalar, in order of
    /// first registration. Drives the order in which scalar buffers are bound.
    scalar_order: Vec<Elem>,
    /// Kernel settings.
    // NOTE(review): not read anywhere in this file; presumably consumed by callers — confirm.
    pub settings: KernelSettings,
    /// Zero-sized marker tying the launcher to the runtime type `R`.
    runtime: PhantomData<R>,
}
34
35impl<R: Runtime> KernelLauncher<R> {
36 pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
38 self.tensors.push_tensor(tensor);
39 }
40
41 pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
43 self.tensors.push_array(array);
44 }
45
46 pub fn register_u8(&mut self, scalar: u8) {
48 self.register_scalar(Elem::UInt(UIntKind::U8));
49 self.scalar_u8.push(scalar);
50 }
51
52 pub fn register_u16(&mut self, scalar: u16) {
54 self.register_scalar(Elem::UInt(UIntKind::U16));
55 self.scalar_u16.push(scalar);
56 }
57
58 pub fn register_u32(&mut self, scalar: u32) {
60 self.register_scalar(Elem::UInt(UIntKind::U32));
61 self.scalar_u32.push(scalar);
62 }
63
64 pub fn register_u64(&mut self, scalar: u64) {
66 self.register_scalar(Elem::UInt(UIntKind::U64));
67 self.scalar_u64.push(scalar);
68 }
69
70 pub fn register_i8(&mut self, scalar: i8) {
72 self.register_scalar(Elem::Int(IntKind::I8));
73 self.scalar_i8.push(scalar);
74 }
75
76 pub fn register_i16(&mut self, scalar: i16) {
78 self.register_scalar(Elem::Int(IntKind::I16));
79 self.scalar_i16.push(scalar);
80 }
81
82 pub fn register_i32(&mut self, scalar: i32) {
84 self.register_scalar(Elem::Int(IntKind::I32));
85 self.scalar_i32.push(scalar);
86 }
87
88 pub fn register_i64(&mut self, scalar: i64) {
90 self.register_scalar(Elem::Int(IntKind::I64));
91 self.scalar_i64.push(scalar);
92 }
93
94 pub fn register_bf16(&mut self, scalar: half::bf16) {
96 self.register_scalar(Elem::Float(FloatKind::BF16));
97 self.scalar_bf16.push(scalar);
98 }
99
100 pub fn register_f16(&mut self, scalar: half::f16) {
102 self.register_scalar(Elem::Float(FloatKind::F16));
103 self.scalar_f16.push(scalar);
104 }
105
106 pub fn register_f32(&mut self, scalar: f32) {
108 self.register_scalar(Elem::Float(FloatKind::F32));
109 self.scalar_f32.push(scalar);
110 }
111
112 pub fn register_f64(&mut self, scalar: f64) {
114 self.register_scalar(Elem::Float(FloatKind::F64));
115 self.scalar_f64.push(scalar);
116 }
117
118 pub fn launch<K: Kernel>(
120 self,
121 cube_count: CubeCount,
122 kernel: K,
123 client: &ComputeClient<R::Server, R::Channel>,
124 ) {
125 let bindings = self.into_bindings(client);
126
127 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
128
129 client.execute(kernel, cube_count, bindings);
130 }
131
132 pub unsafe fn launch_unchecked<K: Kernel>(
138 self,
139 cube_count: CubeCount,
140 kernel: K,
141 client: &ComputeClient<R::Server, R::Channel>,
142 ) {
143 let bindings = self.into_bindings(client);
144
145 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
146
147 client.execute_unchecked(kernel, cube_count, bindings);
148 }
149
150 fn into_bindings(mut self, client: &ComputeClient<R::Server, R::Channel>) -> Vec<Binding> {
157 let mut bindings = Vec::new();
158
159 self.tensors.register(client, &mut bindings);
160
161 for elem in self.scalar_order.drain(..) {
162 match elem {
163 Elem::Float(kind) | Elem::AtomicFloat(kind) => match kind {
164 FloatKind::F16 => self.scalar_f16.register::<R>(client, &mut bindings),
165 FloatKind::BF16 => self.scalar_bf16.register::<R>(client, &mut bindings),
166 FloatKind::TF32 => self.scalar_f32.register::<R>(client, &mut bindings),
167 FloatKind::Flex32 => self.scalar_f32.register::<R>(client, &mut bindings),
168 FloatKind::F32 => self.scalar_f32.register::<R>(client, &mut bindings),
169 FloatKind::F64 => self.scalar_f64.register::<R>(client, &mut bindings),
170 },
171 Elem::Int(kind) => match kind {
172 IntKind::I8 => self.scalar_i8.register::<R>(client, &mut bindings),
173 IntKind::I16 => self.scalar_i16.register::<R>(client, &mut bindings),
174 IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
175 IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
176 },
177 Elem::AtomicInt(kind) => match kind {
178 IntKind::I8 => self.scalar_i8.register::<R>(client, &mut bindings),
179 IntKind::I16 => self.scalar_i16.register::<R>(client, &mut bindings),
180 IntKind::I32 => self.scalar_i32.register::<R>(client, &mut bindings),
181 IntKind::I64 => self.scalar_i64.register::<R>(client, &mut bindings),
182 },
183 Elem::UInt(kind) | Elem::AtomicUInt(kind) => match kind {
184 UIntKind::U8 => self.scalar_u8.register::<R>(client, &mut bindings),
185 UIntKind::U16 => self.scalar_u16.register::<R>(client, &mut bindings),
186 UIntKind::U32 => self.scalar_u32.register::<R>(client, &mut bindings),
187 UIntKind::U64 => self.scalar_u64.register::<R>(client, &mut bindings),
188 },
189 Elem::Bool => panic!("Bool can't be passed as bindings."),
190 }
191 }
192
193 bindings
194 }
195
196 fn register_scalar(&mut self, elem: Elem) {
197 if !self.scalar_order.contains(&elem) {
198 self.scalar_order.push(elem);
199 }
200 }
201}
202
/// Tensor and array arguments accumulated for a kernel launch.
pub enum TensorState<R: Runtime> {
    /// No tensor or array has been pushed yet.
    Empty,
    /// At least one tensor or array has been pushed.
    Some {
        /// Bindings of the pushed handles, in push order.
        bindings: Vec<Binding>,
        /// Builder collecting the metadata of every pushed tensor and array.
        metadata: MetadataBuilder,
        /// Zero-sized marker tying the state to the runtime type `R`.
        runtime: PhantomData<R>,
    },
}
214
/// Scalar arguments of a single element type `T` accumulated for a kernel launch.
pub enum ScalarState<T> {
    /// No scalar of this type has been pushed yet.
    Empty,
    /// The scalars pushed so far, in push order.
    Some(Vec<T>),
}
224
225impl<R: Runtime> TensorState<R> {
226 pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
228 let (tensor, vectorization) = match tensor {
229 TensorArg::Handle {
230 handle,
231 vectorization_factor,
232 ..
233 } => (handle, vectorization_factor),
234 TensorArg::Alias { .. } => return,
235 };
236
237 if let TensorState::Empty = self {
238 *self = TensorState::Some {
239 bindings: Vec::with_capacity(1),
240 metadata: MetadataBuilder::default(),
241 runtime: PhantomData,
242 };
243 };
244
245 let TensorState::Some {
246 bindings, metadata, ..
247 } = self
248 else {
249 panic!("Should be init")
250 };
251
252 let elem_size = tensor.elem_size * *vectorization as usize;
253 let buffer_len = tensor.handle.size() / elem_size as u64;
254 let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
255 bindings.push(tensor.handle.clone().binding());
256 metadata.with_tensor(
257 tensor.strides.len() as u32,
258 buffer_len as u32,
259 len as u32,
260 tensor.shape.iter().map(|it| *it as u32).collect(),
261 tensor.strides.iter().map(|it| *it as u32).collect(),
262 );
263 }
264
265 pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
267 let (array, vectorization) = match array {
268 ArrayArg::Handle {
269 handle,
270 vectorization_factor,
271 ..
272 } => (handle, vectorization_factor),
273 ArrayArg::Alias { .. } => return,
274 };
275
276 if let TensorState::Empty = self {
277 *self = TensorState::Some {
278 bindings: Vec::with_capacity(1),
279 metadata: MetadataBuilder::default(),
280 runtime: PhantomData,
281 };
282 };
283
284 let TensorState::Some {
285 bindings, metadata, ..
286 } = self
287 else {
288 panic!("Should be init")
289 };
290
291 let elem_size = array.elem_size * *vectorization as usize;
292 let buffer_len = array.handle.size() / elem_size as u64;
293 bindings.push(array.handle.clone().binding());
294 metadata.with_array(buffer_len as u32, array.length[0] as u32);
295 }
296
297 fn register(
298 self,
299 client: &ComputeClient<R::Server, R::Channel>,
300 bindings_global: &mut Vec<Binding>,
301 ) {
302 if let Self::Some {
303 bindings,
304 metadata,
305 runtime: _,
306 } = self
307 {
308 let metadata = metadata.finish();
309
310 bindings_global.extend(bindings);
311 bindings_global.push(client.create(bytemuck::cast_slice(&metadata)).binding());
312 }
313 }
314}
315
316impl<T: NoUninit> ScalarState<T> {
317 pub fn push(&mut self, val: T) {
319 match self {
320 ScalarState::Empty => *self = Self::Some(vec![val]),
321 ScalarState::Some(values) => values.push(val),
322 }
323 }
324
325 fn register<R: Runtime>(
326 &self,
327 client: &ComputeClient<R::Server, R::Channel>,
328 bindings: &mut Vec<Binding>,
329 ) {
330 match self {
331 ScalarState::Empty => (),
332 ScalarState::Some(values) => {
333 let handle = client.create(bytemuck::cast_slice(values));
334 bindings.push(handle.binding());
335 }
336 }
337 }
338}
339
340impl<R: Runtime> Default for KernelLauncher<R> {
341 fn default() -> Self {
342 Self {
343 tensors: TensorState::Empty,
344 scalar_bf16: ScalarState::Empty,
345 scalar_f16: ScalarState::Empty,
346 scalar_f32: ScalarState::Empty,
347 scalar_f64: ScalarState::Empty,
348 scalar_u64: ScalarState::Empty,
349 scalar_u32: ScalarState::Empty,
350 scalar_u16: ScalarState::Empty,
351 scalar_u8: ScalarState::Empty,
352 scalar_i64: ScalarState::Empty,
353 scalar_i32: ScalarState::Empty,
354 scalar_i16: ScalarState::Empty,
355 scalar_i8: ScalarState::Empty,
356 scalar_order: Vec::new(),
357 settings: Default::default(),
358 runtime: PhantomData,
359 }
360 }
361}