1use std::marker::PhantomData;
2
3use crate::MetadataBuilder;
4use crate::compute::KernelTask;
5use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
6use crate::{Kernel, Runtime};
7use crate::{KernelSettings, prelude::CubePrimitive};
8use bytemuck::{AnyBitPattern, NoUninit};
9use cubecl_runtime::server::{Binding, CubeCount, ScalarBinding, TensorMapBinding};
10use cubecl_runtime::{client::ComputeClient, server::Bindings};
11
/// Accumulates the arguments of a kernel launch for runtime `R`: buffer-backed arguments
/// (tensors, arrays and tensor maps) plus one scalar list per supported primitive type.
pub struct KernelLauncher<R: Runtime> {
    // Buffer bindings, tensor-map bindings and the metadata recorded alongside them.
    tensors: TensorState<R>,
    // One scalar accumulator per primitive type; each stays `Empty` until a value of that
    // type is registered.
    scalar_bf16: ScalarState<half::bf16>,
    scalar_f16: ScalarState<half::f16>,
    scalar_f32: ScalarState<f32>,
    scalar_f64: ScalarState<f64>,
    scalar_u64: ScalarState<u64>,
    scalar_u32: ScalarState<u32>,
    scalar_u16: ScalarState<u16>,
    scalar_u8: ScalarState<u8>,
    scalar_i64: ScalarState<i64>,
    scalar_i32: ScalarState<i32>,
    scalar_i16: ScalarState<i16>,
    scalar_i8: ScalarState<i8>,
    /// Kernel settings carried by the launcher.
    // NOTE(review): not read by any code in this file; presumably consumed by callers.
    pub settings: KernelSettings,
    // Ties the launcher to the runtime type without storing a value of that type.
    runtime: PhantomData<R>,
}
30
31impl<R: Runtime> KernelLauncher<R> {
32 pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
34 self.tensors.push_tensor(tensor);
35 }
36
37 pub fn register_tensor_map(&mut self, tensor: &TensorMapArg<'_, R>) {
39 self.tensors.push_tensor_map(tensor);
40 }
41
42 pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
44 self.tensors.push_array(array);
45 }
46
47 pub fn register_u8(&mut self, scalar: u8) {
49 self.scalar_u8.push(scalar);
50 }
51
52 pub fn register_u16(&mut self, scalar: u16) {
54 self.scalar_u16.push(scalar);
55 }
56
57 pub fn register_u32(&mut self, scalar: u32) {
59 self.scalar_u32.push(scalar);
60 }
61
62 pub fn register_u64(&mut self, scalar: u64) {
64 self.scalar_u64.push(scalar);
65 }
66
67 pub fn register_i8(&mut self, scalar: i8) {
69 self.scalar_i8.push(scalar);
70 }
71
72 pub fn register_i16(&mut self, scalar: i16) {
74 self.scalar_i16.push(scalar);
75 }
76
77 pub fn register_i32(&mut self, scalar: i32) {
79 self.scalar_i32.push(scalar);
80 }
81
82 pub fn register_i64(&mut self, scalar: i64) {
84 self.scalar_i64.push(scalar);
85 }
86
87 pub fn register_bf16(&mut self, scalar: half::bf16) {
89 self.scalar_bf16.push(scalar);
90 }
91
92 pub fn register_f16(&mut self, scalar: half::f16) {
94 self.scalar_f16.push(scalar);
95 }
96
97 pub fn register_f32(&mut self, scalar: f32) {
99 self.scalar_f32.push(scalar);
100 }
101
102 pub fn register_f64(&mut self, scalar: f64) {
104 self.scalar_f64.push(scalar);
105 }
106
107 pub fn launch<K: Kernel>(
109 self,
110 cube_count: CubeCount,
111 kernel: K,
112 client: &ComputeClient<R::Server, R::Channel>,
113 ) {
114 let bindings = self.into_bindings();
115 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
116
117 client.execute(kernel, cube_count, bindings);
118 }
119
120 pub unsafe fn launch_unchecked<K: Kernel>(
129 self,
130 cube_count: CubeCount,
131 kernel: K,
132 client: &ComputeClient<R::Server, R::Channel>,
133 ) {
134 unsafe {
135 let bindings = self.into_bindings();
136 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
137
138 client.execute_unchecked(kernel, cube_count, bindings);
139 }
140 }
141
142 fn into_bindings(self) -> Bindings {
152 let mut bindings = Bindings::new();
153
154 self.tensors.register(&mut bindings);
155
156 self.scalar_u8.register(&mut bindings);
157 self.scalar_u16.register(&mut bindings);
158 self.scalar_u32.register(&mut bindings);
159 self.scalar_u64.register(&mut bindings);
160 self.scalar_i8.register(&mut bindings);
161 self.scalar_i16.register(&mut bindings);
162 self.scalar_i32.register(&mut bindings);
163 self.scalar_i64.register(&mut bindings);
164 self.scalar_f16.register(&mut bindings);
165 self.scalar_bf16.register(&mut bindings);
166 self.scalar_f32.register(&mut bindings);
167 self.scalar_f64.register(&mut bindings);
168
169 bindings
170 }
171}
172
/// Accumulated state for the buffer-backed kernel arguments (tensors, arrays and
/// tensor maps) of runtime `R`.
pub enum TensorState<R: Runtime> {
    /// No handle-backed argument has been registered yet.
    Empty,
    /// At least one handle-backed argument has been registered.
    Some {
        /// Bindings added by `push_tensor` and `push_array`.
        buffers: Vec<Binding>,
        /// Bindings added by `push_tensor_map`.
        tensor_maps: Vec<TensorMapBinding>,
        /// Builder that records the metadata of each registered argument.
        metadata: MetadataBuilder,
        /// Ties the state to the runtime type without storing a value of that type.
        runtime: PhantomData<R>,
    },
}
185
/// Accumulated scalar arguments of a single primitive type `T`.
pub enum ScalarState<T> {
    /// No scalar of this type has been pushed.
    Empty,
    /// The scalars pushed so far, in push order.
    Some(Vec<T>),
}
195
196impl<R: Runtime> TensorState<R> {
197 fn maybe_init(&mut self) {
198 if matches!(self, TensorState::Empty) {
199 *self = TensorState::Some {
200 buffers: Vec::new(),
201 tensor_maps: Vec::new(),
202 metadata: MetadataBuilder::default(),
203 runtime: PhantomData,
204 };
205 }
206 }
207
208 fn buffers(&mut self) -> &mut Vec<Binding> {
209 self.maybe_init();
210 let TensorState::Some { buffers, .. } = self else {
211 panic!("Should be init");
212 };
213 buffers
214 }
215
216 fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
217 self.maybe_init();
218 let TensorState::Some { tensor_maps, .. } = self else {
219 panic!("Should be init");
220 };
221 tensor_maps
222 }
223
224 fn metadata(&mut self) -> &mut MetadataBuilder {
225 self.maybe_init();
226 let TensorState::Some { metadata, .. } = self else {
227 panic!("Should be init");
228 };
229 metadata
230 }
231
232 pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
234 if let Some(tensor) = self.process_tensor(tensor) {
235 self.buffers().push(tensor);
236 }
237 }
238
239 fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
240 let (tensor, vectorization) = match tensor {
241 TensorArg::Handle {
242 handle,
243 vectorization_factor,
244 ..
245 } => (handle, vectorization_factor),
246 TensorArg::Alias { .. } => return None,
247 };
248
249 let elem_size = tensor.elem_size * *vectorization as usize;
250 let buffer_len = tensor.handle.size() / elem_size as u64;
251 let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
252 self.metadata().with_tensor(
253 tensor.strides.len() as u32,
254 buffer_len as u32,
255 len as u32,
256 tensor.shape.iter().map(|it| *it as u32).collect(),
257 tensor.strides.iter().map(|it| *it as u32).collect(),
258 );
259 Some(tensor.handle.clone().binding())
260 }
261
262 pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
264 if let Some(tensor) = self.process_array(array) {
265 self.buffers().push(tensor);
266 }
267 }
268
269 fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
270 let (array, vectorization) = match array {
271 ArrayArg::Handle {
272 handle,
273 vectorization_factor,
274 ..
275 } => (handle, vectorization_factor),
276 ArrayArg::Alias { .. } => return None,
277 };
278
279 let elem_size = array.elem_size * *vectorization as usize;
280 let buffer_len = array.handle.size() / elem_size as u64;
281 self.metadata()
282 .with_array(buffer_len as u32, array.length[0] as u32);
283 Some(array.handle.clone().binding())
284 }
285
286 pub fn push_tensor_map(&mut self, map: &TensorMapArg<'_, R>) {
288 let binding = self
289 .process_tensor(&map.tensor)
290 .expect("Can't use alias for TensorMap");
291
292 let map = map.metadata.clone();
293 self.tensor_maps().push(TensorMapBinding { binding, map });
294 }
295
296 fn register(self, bindings_global: &mut Bindings) {
297 if let Self::Some {
298 buffers,
299 tensor_maps,
300 metadata,
301 ..
302 } = self
303 {
304 let metadata = metadata.finish();
305
306 bindings_global.buffers = buffers;
307 bindings_global.tensor_maps = tensor_maps;
308 bindings_global.metadata = metadata;
309 }
310 }
311}
312
313impl<T: NoUninit + AnyBitPattern + CubePrimitive> ScalarState<T> {
314 pub fn push(&mut self, val: T) {
316 match self {
317 ScalarState::Empty => *self = Self::Some(vec![val]),
318 ScalarState::Some(values) => values.push(val),
319 }
320 }
321
322 fn register(&self, bindings: &mut Bindings) {
323 if let ScalarState::Some(values) = self {
324 let len = values.len();
325 let len_u64 = len.div_ceil(size_of::<u64>() / size_of::<T>());
326 let mut data = vec![0; len_u64];
327 let slice = bytemuck::cast_slice_mut::<u64, T>(&mut data);
328 slice[0..values.len()].copy_from_slice(values);
329 let elem = T::as_elem_native_unchecked();
330 bindings
331 .scalars
332 .insert(elem, ScalarBinding::new(elem, len, data));
333 }
334 }
335}
336
337impl<R: Runtime> Default for KernelLauncher<R> {
338 fn default() -> Self {
339 Self {
340 tensors: TensorState::Empty,
341 scalar_bf16: ScalarState::Empty,
342 scalar_f16: ScalarState::Empty,
343 scalar_f32: ScalarState::Empty,
344 scalar_f64: ScalarState::Empty,
345 scalar_u64: ScalarState::Empty,
346 scalar_u32: ScalarState::Empty,
347 scalar_u16: ScalarState::Empty,
348 scalar_u8: ScalarState::Empty,
349 scalar_i64: ScalarState::Empty,
350 scalar_i32: ScalarState::Empty,
351 scalar_i16: ScalarState::Empty,
352 scalar_i8: ScalarState::Empty,
353 settings: Default::default(),
354 runtime: PhantomData,
355 }
356 }
357}