1use std::marker::PhantomData;
2
3use crate::MetadataBuilder;
4use crate::Runtime;
5use crate::compute::KernelTask;
6use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
7use crate::{KernelSettings, prelude::CubePrimitive};
8use bytemuck::{AnyBitPattern, NoUninit};
9use cubecl_runtime::server::{Binding, CubeCount, ScalarBinding, TensorMapBinding};
10use cubecl_runtime::{client::ComputeClient, server::Bindings};
11
12use super::CubeKernel;
13
14pub struct KernelLauncher<R: Runtime> {
16 tensors: TensorState<R>,
17 scalar_bf16: ScalarState<half::bf16>,
18 scalar_f16: ScalarState<half::f16>,
19 scalar_f32: ScalarState<f32>,
20 scalar_f64: ScalarState<f64>,
21 scalar_u64: ScalarState<u64>,
22 scalar_u32: ScalarState<u32>,
23 scalar_u16: ScalarState<u16>,
24 scalar_u8: ScalarState<u8>,
25 scalar_i64: ScalarState<i64>,
26 scalar_i32: ScalarState<i32>,
27 scalar_i16: ScalarState<i16>,
28 scalar_i8: ScalarState<i8>,
29 pub settings: KernelSettings,
30 runtime: PhantomData<R>,
31}
32
33impl<R: Runtime> KernelLauncher<R> {
34 pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
36 self.tensors.push_tensor(tensor);
37 }
38
39 pub fn register_tensor_map(&mut self, tensor: &TensorMapArg<'_, R>) {
41 self.tensors.push_tensor_map(tensor);
42 }
43
44 pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
46 self.tensors.push_array(array);
47 }
48
49 pub fn register_u8(&mut self, scalar: u8) {
51 self.scalar_u8.push(scalar);
52 }
53
54 pub fn register_u16(&mut self, scalar: u16) {
56 self.scalar_u16.push(scalar);
57 }
58
59 pub fn register_u32(&mut self, scalar: u32) {
61 self.scalar_u32.push(scalar);
62 }
63
64 pub fn register_u64(&mut self, scalar: u64) {
66 self.scalar_u64.push(scalar);
67 }
68
69 pub fn register_i8(&mut self, scalar: i8) {
71 self.scalar_i8.push(scalar);
72 }
73
74 pub fn register_i16(&mut self, scalar: i16) {
76 self.scalar_i16.push(scalar);
77 }
78
79 pub fn register_i32(&mut self, scalar: i32) {
81 self.scalar_i32.push(scalar);
82 }
83
84 pub fn register_i64(&mut self, scalar: i64) {
86 self.scalar_i64.push(scalar);
87 }
88
89 pub fn register_bf16(&mut self, scalar: half::bf16) {
91 self.scalar_bf16.push(scalar);
92 }
93
94 pub fn register_f16(&mut self, scalar: half::f16) {
96 self.scalar_f16.push(scalar);
97 }
98
99 pub fn register_f32(&mut self, scalar: f32) {
101 self.scalar_f32.push(scalar);
102 }
103
104 pub fn register_f64(&mut self, scalar: f64) {
106 self.scalar_f64.push(scalar);
107 }
108
109 #[track_caller]
111 pub fn launch<K: CubeKernel>(
112 self,
113 cube_count: CubeCount,
114 kernel: K,
115 client: &ComputeClient<R::Server, R::Channel>,
116 ) {
117 let bindings = self.into_bindings();
118 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
119
120 client.execute(kernel, cube_count, bindings);
121 }
122
123 #[track_caller]
132 pub unsafe fn launch_unchecked<K: CubeKernel>(
133 self,
134 cube_count: CubeCount,
135 kernel: K,
136 client: &ComputeClient<R::Server, R::Channel>,
137 ) {
138 unsafe {
139 let bindings = self.into_bindings();
140 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
141
142 client.execute_unchecked(kernel, cube_count, bindings);
143 }
144 }
145
146 fn into_bindings(self) -> Bindings {
156 let mut bindings = Bindings::new();
157
158 self.tensors.register(&mut bindings);
159
160 self.scalar_u8.register(&mut bindings);
161 self.scalar_u16.register(&mut bindings);
162 self.scalar_u32.register(&mut bindings);
163 self.scalar_u64.register(&mut bindings);
164 self.scalar_i8.register(&mut bindings);
165 self.scalar_i16.register(&mut bindings);
166 self.scalar_i32.register(&mut bindings);
167 self.scalar_i64.register(&mut bindings);
168 self.scalar_f16.register(&mut bindings);
169 self.scalar_bf16.register(&mut bindings);
170 self.scalar_f32.register(&mut bindings);
171 self.scalar_f64.register(&mut bindings);
172
173 bindings
174 }
175}
176
177pub enum TensorState<R: Runtime> {
179 Empty,
181 Some {
183 buffers: Vec<Binding>,
184 tensor_maps: Vec<TensorMapBinding>,
185 metadata: MetadataBuilder,
186 runtime: PhantomData<R>,
187 },
188}
189
/// Scalar arguments of a single element type `T` accumulated for a kernel launch.
pub enum ScalarState<T> {
    /// No scalar of this type has been registered.
    Empty,
    /// The registered scalars, in registration order.
    Some(Vec<T>),
}
199
200impl<R: Runtime> TensorState<R> {
201 fn maybe_init(&mut self) {
202 if matches!(self, TensorState::Empty) {
203 *self = TensorState::Some {
204 buffers: Vec::new(),
205 tensor_maps: Vec::new(),
206 metadata: MetadataBuilder::default(),
207 runtime: PhantomData,
208 };
209 }
210 }
211
212 fn buffers(&mut self) -> &mut Vec<Binding> {
213 self.maybe_init();
214 let TensorState::Some { buffers, .. } = self else {
215 panic!("Should be init");
216 };
217 buffers
218 }
219
220 fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
221 self.maybe_init();
222 let TensorState::Some { tensor_maps, .. } = self else {
223 panic!("Should be init");
224 };
225 tensor_maps
226 }
227
228 fn metadata(&mut self) -> &mut MetadataBuilder {
229 self.maybe_init();
230 let TensorState::Some { metadata, .. } = self else {
231 panic!("Should be init");
232 };
233 metadata
234 }
235
236 pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
238 if let Some(tensor) = self.process_tensor(tensor) {
239 self.buffers().push(tensor);
240 }
241 }
242
243 fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
244 let (tensor, vectorization) = match tensor {
245 TensorArg::Handle {
246 handle,
247 vectorization_factor,
248 ..
249 } => (handle, vectorization_factor),
250 TensorArg::Alias { .. } => return None,
251 };
252
253 let elem_size = tensor.elem_size * *vectorization as usize;
254 let buffer_len = tensor.handle.size() / elem_size as u64;
255 let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
256 self.metadata().with_tensor(
257 tensor.strides.len() as u32,
258 buffer_len as u32,
259 len as u32,
260 tensor.shape.iter().map(|it| *it as u32).collect(),
261 tensor.strides.iter().map(|it| *it as u32).collect(),
262 );
263 Some(tensor.handle.clone().binding())
264 }
265
266 pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
268 if let Some(tensor) = self.process_array(array) {
269 self.buffers().push(tensor);
270 }
271 }
272
273 fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
274 let (array, vectorization) = match array {
275 ArrayArg::Handle {
276 handle,
277 vectorization_factor,
278 ..
279 } => (handle, vectorization_factor),
280 ArrayArg::Alias { .. } => return None,
281 };
282
283 let elem_size = array.elem_size * *vectorization as usize;
284 let buffer_len = array.handle.size() / elem_size as u64;
285 self.metadata()
286 .with_array(buffer_len as u32, array.length[0] as u32);
287 Some(array.handle.clone().binding())
288 }
289
290 pub fn push_tensor_map(&mut self, map: &TensorMapArg<'_, R>) {
292 let binding = self
293 .process_tensor(&map.tensor)
294 .expect("Can't use alias for TensorMap");
295
296 let map = map.metadata.clone();
297 self.tensor_maps().push(TensorMapBinding { binding, map });
298 }
299
300 fn register(self, bindings_global: &mut Bindings) {
301 if let Self::Some {
302 buffers,
303 tensor_maps,
304 metadata,
305 ..
306 } = self
307 {
308 let metadata = metadata.finish();
309
310 bindings_global.buffers = buffers;
311 bindings_global.tensor_maps = tensor_maps;
312 bindings_global.metadata = metadata;
313 }
314 }
315}
316
317impl<T: NoUninit + AnyBitPattern + CubePrimitive> ScalarState<T> {
318 pub fn push(&mut self, val: T) {
320 match self {
321 ScalarState::Empty => *self = Self::Some(vec![val]),
322 ScalarState::Some(values) => values.push(val),
323 }
324 }
325
326 fn register(&self, bindings: &mut Bindings) {
327 if let ScalarState::Some(values) = self {
328 let len = values.len();
329 let len_u64 = len.div_ceil(size_of::<u64>() / size_of::<T>());
330 let mut data = vec![0; len_u64];
331 let slice = bytemuck::cast_slice_mut::<u64, T>(&mut data);
332 slice[0..values.len()].copy_from_slice(values);
333 let elem = T::as_elem_native_unchecked();
334 bindings
335 .scalars
336 .insert(elem, ScalarBinding::new(elem, len, data));
337 }
338 }
339}
340
341impl<R: Runtime> Default for KernelLauncher<R> {
342 fn default() -> Self {
343 Self {
344 tensors: TensorState::Empty,
345 scalar_bf16: ScalarState::Empty,
346 scalar_f16: ScalarState::Empty,
347 scalar_f32: ScalarState::Empty,
348 scalar_f64: ScalarState::Empty,
349 scalar_u64: ScalarState::Empty,
350 scalar_u32: ScalarState::Empty,
351 scalar_u16: ScalarState::Empty,
352 scalar_u8: ScalarState::Empty,
353 scalar_i64: ScalarState::Empty,
354 scalar_i32: ScalarState::Empty,
355 scalar_i16: ScalarState::Empty,
356 scalar_i8: ScalarState::Empty,
357 settings: Default::default(),
358 runtime: PhantomData,
359 }
360 }
361}