1use std::{collections::BTreeMap, marker::PhantomData};
2
3use crate::prelude::{ArrayArg, TensorArg, TensorMapArg};
4use crate::{CubeScalar, KernelSettings};
5use crate::{MetadataBuilder, Runtime};
6use cubecl_ir::StorageType;
7use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
8use cubecl_runtime::{
9 client::ComputeClient,
10 kernel::{CubeKernel, KernelTask},
11 server::Bindings,
12};
13
14pub struct KernelLauncher<R: Runtime> {
16 tensors: TensorState<R>,
17 scalars: ScalarState,
18 pub settings: KernelSettings,
19 runtime: PhantomData<R>,
20}
21
22impl<R: Runtime> KernelLauncher<R> {
23 pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
25 self.tensors.push_tensor(tensor);
26 }
27
28 pub fn register_tensor_map(&mut self, tensor: &TensorMapArg<'_, R>) {
30 self.tensors.push_tensor_map(tensor);
31 }
32
33 pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
35 self.tensors.push_array(array);
36 }
37
38 pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
40 self.scalars.push(scalar);
41 }
42
43 pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
45 self.scalars.push_raw(bytes, dtype);
46 }
47
48 #[track_caller]
50 pub fn launch<K: CubeKernel>(
51 self,
52 cube_count: CubeCount,
53 kernel: K,
54 client: &ComputeClient<R>,
55 ) -> Result<(), LaunchError> {
56 let bindings = self.into_bindings();
57 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
58
59 client.launch(kernel, cube_count, bindings)
60 }
61
62 #[track_caller]
71 pub unsafe fn launch_unchecked<K: CubeKernel>(
72 self,
73 cube_count: CubeCount,
74 kernel: K,
75 client: &ComputeClient<R>,
76 ) -> Result<(), LaunchError> {
77 unsafe {
78 let bindings = self.into_bindings();
79 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
80
81 client.launch_unchecked(kernel, cube_count, bindings)
82 }
83 }
84
85 fn into_bindings(self) -> Bindings {
95 let mut bindings = Bindings::new();
96
97 self.tensors.register(&mut bindings);
98 self.scalars.register(&mut bindings);
99
100 bindings
101 }
102}
103
104pub enum TensorState<R: Runtime> {
106 Empty,
108 Some {
110 buffers: Vec<Binding>,
111 tensor_maps: Vec<TensorMapBinding>,
112 metadata: MetadataBuilder,
113 runtime: PhantomData<R>,
114 },
115}
116
117#[derive(Default, Clone)]
121pub struct ScalarState {
122 data: BTreeMap<StorageType, ScalarValues>,
123}
124
125pub type ScalarValues = Vec<u8>;
127
128impl<R: Runtime> TensorState<R> {
129 fn maybe_init(&mut self) {
130 if matches!(self, TensorState::Empty) {
131 *self = TensorState::Some {
132 buffers: Vec::new(),
133 tensor_maps: Vec::new(),
134 metadata: MetadataBuilder::default(),
135 runtime: PhantomData,
136 };
137 }
138 }
139
140 fn buffers(&mut self) -> &mut Vec<Binding> {
141 self.maybe_init();
142 let TensorState::Some { buffers, .. } = self else {
143 panic!("Should be init");
144 };
145 buffers
146 }
147
148 fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
149 self.maybe_init();
150 let TensorState::Some { tensor_maps, .. } = self else {
151 panic!("Should be init");
152 };
153 tensor_maps
154 }
155
156 fn metadata(&mut self) -> &mut MetadataBuilder {
157 self.maybe_init();
158 let TensorState::Some { metadata, .. } = self else {
159 panic!("Should be init");
160 };
161 metadata
162 }
163
164 pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
166 if let Some(tensor) = self.process_tensor(tensor) {
167 self.buffers().push(tensor);
168 }
169 }
170
171 fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
172 let (tensor, vectorization) = match tensor {
173 TensorArg::Handle {
174 handle,
175 line_size: vectorization_factor,
176 ..
177 } => (handle, vectorization_factor),
178 TensorArg::Alias { .. } => return None,
179 };
180
181 let elem_size = tensor.elem_size * *vectorization as usize;
182 let buffer_len = tensor.handle.size() / elem_size as u64;
183 let len = tensor.shape.iter().product::<usize>() / *vectorization as usize;
184 self.metadata().with_tensor(
185 tensor.strides.len() as u32,
186 buffer_len as u32,
187 len as u32,
188 tensor.shape.iter().map(|it| *it as u32).collect(),
189 tensor.strides.iter().map(|it| *it as u32).collect(),
190 );
191 Some(tensor.handle.clone().binding())
192 }
193
194 pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
196 if let Some(tensor) = self.process_array(array) {
197 self.buffers().push(tensor);
198 }
199 }
200
201 fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
202 let (array, vectorization) = match array {
203 ArrayArg::Handle {
204 handle,
205 line_size: vectorization_factor,
206 ..
207 } => (handle, vectorization_factor),
208 ArrayArg::Alias { .. } => return None,
209 };
210
211 let elem_size = array.elem_size * *vectorization as usize;
212 let buffer_len = array.handle.size() / elem_size as u64;
213 self.metadata().with_array(
214 buffer_len as u32,
215 array.length[0] as u32 / *vectorization as u32,
216 );
217 Some(array.handle.clone().binding())
218 }
219
220 pub fn push_tensor_map(&mut self, map: &TensorMapArg<'_, R>) {
222 let binding = self
223 .process_tensor(&map.tensor)
224 .expect("Can't use alias for TensorMap");
225
226 let map = map.metadata.clone();
227 self.tensor_maps().push(TensorMapBinding { binding, map });
228 }
229
230 fn register(self, bindings_global: &mut Bindings) {
231 if let Self::Some {
232 buffers,
233 tensor_maps,
234 metadata,
235 ..
236 } = self
237 {
238 let metadata = metadata.finish();
239
240 bindings_global.buffers = buffers;
241 bindings_global.tensor_maps = tensor_maps;
242 bindings_global.metadata = metadata;
243 }
244 }
245}
246
247impl ScalarState {
248 pub fn push<T: CubeScalar>(&mut self, val: T) {
250 let val = [val];
251 let bytes = T::as_bytes(&val);
252 self.data
253 .entry(T::cube_type())
254 .or_default()
255 .extend(bytes.iter().copied());
256 }
257
258 pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
260 self.data
261 .entry(dtype)
262 .or_default()
263 .extend(bytes.iter().copied());
264 }
265
266 fn register(&self, bindings: &mut Bindings) {
267 for (ty, values) in self.data.iter() {
268 let len = values.len() / ty.size();
269 let len_u64 = len.div_ceil(size_of::<u64>() / ty.size());
270
271 let mut data = vec![0; len_u64];
272 let slice = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
273 slice[0..values.len()].copy_from_slice(values);
274 bindings
275 .scalars
276 .insert(*ty, ScalarBinding::new(*ty, len, data));
277 }
278 }
279}
280
281impl<R: Runtime> Default for KernelLauncher<R> {
282 fn default() -> Self {
283 Self {
284 tensors: TensorState::Empty,
285 scalars: Default::default(),
286 settings: Default::default(),
287 runtime: PhantomData,
288 }
289 }
290}