1use alloc::{boxed::Box, collections::BTreeMap, vec, vec::Vec};
2use core::marker::PhantomData;
3
4use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
5use crate::{CubeScalar, KernelSettings};
6use crate::{MetadataBuilder, Runtime};
7#[cfg(feature = "std")]
8use core::cell::RefCell;
9use cubecl_ir::{AddressType, StorageType};
10use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
11use cubecl_runtime::{
12 client::ComputeClient,
13 kernel::{CubeKernel, KernelTask},
14 server::Bindings,
15};
16
17pub struct KernelLauncher<R: Runtime> {
19 tensors: TensorState<R>,
20 scalars: ScalarState,
21 pub settings: KernelSettings,
22 runtime: PhantomData<R>,
23}
24
25impl<R: Runtime> KernelLauncher<R> {
26 pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
28 self.tensors.push_tensor(tensor);
29 }
30
31 pub fn register_tensor_map<K: TensorMapKind>(&mut self, tensor: &TensorMapArg<'_, R, K>) {
33 self.tensors.push_tensor_map(tensor);
34 }
35
36 pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
38 self.tensors.push_array(array);
39 }
40
41 pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
43 self.scalars.push(scalar);
44 }
45
46 pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
48 self.scalars.push_raw(bytes, dtype);
49 }
50
51 #[track_caller]
53 pub fn launch<K: CubeKernel>(
54 self,
55 cube_count: CubeCount,
56 kernel: K,
57 client: &ComputeClient<R>,
58 ) -> Result<(), LaunchError> {
59 let bindings = self.into_bindings();
60 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
61
62 client.launch(kernel, cube_count, bindings)
63 }
64
65 #[track_caller]
74 pub unsafe fn launch_unchecked<K: CubeKernel>(
75 self,
76 cube_count: CubeCount,
77 kernel: K,
78 client: &ComputeClient<R>,
79 ) -> Result<(), LaunchError> {
80 unsafe {
81 let bindings = self.into_bindings();
82 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
83
84 client.launch_unchecked(kernel, cube_count, bindings)
85 }
86 }
87
88 fn into_bindings(self) -> Bindings {
98 let mut bindings = Bindings::new();
99
100 self.tensors.register(&mut bindings);
101 self.scalars.register(&mut bindings);
102
103 bindings
104 }
105}
106
#[cfg(feature = "std")]
std::thread_local! {
    /// Per-thread metadata builder used by `TensorState::with_metadata` on `std` builds,
    /// where `TensorState` does not carry a builder of its own.
    static METADATA: RefCell<MetadataBuilder> = RefCell::default();
}
111
112pub enum TensorState<R: Runtime> {
114 Empty { addr_type: AddressType },
116 Some {
118 buffers: Vec<Binding>,
119 tensor_maps: Vec<TensorMapBinding>,
120 addr_type: AddressType,
121 runtime: PhantomData<R>,
122 #[cfg(not(feature = "std"))]
123 metadata: MetadataBuilder,
124 },
125}
126
127#[derive(Default, Clone)]
131pub struct ScalarState {
132 data: BTreeMap<StorageType, ScalarValues>,
133}
134
/// Raw byte storage for the scalars of a single storage type.
pub type ScalarValues = Vec<u8>;
137
138impl<R: Runtime> TensorState<R> {
139 fn maybe_init(&mut self) {
140 if let TensorState::Empty { addr_type } = self {
141 *self = TensorState::Some {
142 buffers: Vec::new(),
143 tensor_maps: Vec::new(),
144 addr_type: *addr_type,
145 runtime: PhantomData,
146 #[cfg(not(feature = "std"))]
147 metadata: MetadataBuilder::default(),
148 };
149 }
150 }
151
152 #[cfg(feature = "std")]
153 fn with_metadata<T>(&mut self, fun: impl FnMut(&mut MetadataBuilder) -> T) -> T {
154 METADATA.with_borrow_mut(fun)
155 }
156
157 #[cfg(not(feature = "std"))]
158 fn with_metadata<T>(&mut self, mut fun: impl FnMut(&mut MetadataBuilder) -> T) -> T {
159 self.maybe_init();
160 let TensorState::Some { metadata, .. } = self else {
161 panic!("Should be init");
162 };
163 fun(metadata)
164 }
165
166 fn buffers(&mut self) -> &mut Vec<Binding> {
167 self.maybe_init();
168 let TensorState::Some { buffers, .. } = self else {
169 panic!("Should be init");
170 };
171 buffers
172 }
173
174 fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
175 self.maybe_init();
176 let TensorState::Some { tensor_maps, .. } = self else {
177 panic!("Should be init");
178 };
179 tensor_maps
180 }
181
182 fn address_type(&self) -> AddressType {
183 match self {
184 TensorState::Empty { addr_type } => *addr_type,
185 TensorState::Some { addr_type, .. } => *addr_type,
186 }
187 }
188
189 pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
191 if let Some(tensor) = self.process_tensor(tensor) {
192 self.buffers().push(tensor);
193 }
194 }
195
196 fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
197 let (tensor, vectorization) = match tensor {
198 TensorArg::Handle {
199 handle,
200 line_size: vectorization_factor,
201 ..
202 } => (handle, vectorization_factor),
203 TensorArg::Alias { .. } => return None,
204 };
205
206 let elem_size = tensor.elem_size * *vectorization;
207 let buffer_len = tensor.handle.size() / elem_size as u64;
208 let len = tensor.shape.iter().product::<usize>() / *vectorization;
209 let address_type = self.address_type();
210 self.with_metadata(|meta| {
211 meta.register_tensor(
212 tensor.strides.len() as u64,
213 buffer_len,
214 len as u64,
215 tensor.shape,
216 tensor.strides,
217 address_type,
218 )
219 });
220 Some(tensor.handle.clone().binding())
221 }
222
223 pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
225 if let Some(tensor) = self.process_array(array) {
226 self.buffers().push(tensor);
227 }
228 }
229
230 fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
231 let (array, vectorization) = match array {
232 ArrayArg::Handle {
233 handle,
234 line_size: vectorization_factor,
235 ..
236 } => (handle, vectorization_factor),
237 ArrayArg::Alias { .. } => return None,
238 };
239
240 let elem_size = array.elem_size * *vectorization;
241 let buffer_len = array.handle.size() / elem_size as u64;
242 let address_type = self.address_type();
243 self.with_metadata(|meta| {
244 meta.register_array(
245 buffer_len,
246 array.length[0] as u64 / *vectorization as u64,
247 address_type,
248 )
249 });
250 Some(array.handle.clone().binding())
251 }
252
253 pub fn push_tensor_map<K: TensorMapKind>(&mut self, map: &TensorMapArg<'_, R, K>) {
255 let binding = self
256 .process_tensor(&map.tensor)
257 .expect("Can't use alias for TensorMap");
258
259 let map = map.metadata.clone();
260 self.tensor_maps().push(TensorMapBinding { binding, map });
261 }
262
263 fn register(mut self, bindings_global: &mut Bindings) {
264 let metadata = matches!(self, Self::Some { .. }).then(|| {
265 let addr_type = self.address_type();
266 self.with_metadata(|meta| meta.finish(addr_type))
267 });
268 if let Self::Some {
269 buffers,
270 tensor_maps,
271 ..
272 } = self
273 {
274 let metadata = metadata.unwrap();
275
276 bindings_global.buffers = buffers;
277 bindings_global.tensor_maps = tensor_maps;
278 bindings_global.metadata = metadata;
279 }
280 }
281}
282
283impl ScalarState {
284 pub fn push<T: CubeScalar>(&mut self, val: T) {
286 let val = [val];
287 let bytes = T::as_bytes(&val);
288 self.data
289 .entry(T::cube_type())
290 .or_default()
291 .extend(bytes.iter().copied());
292 }
293
294 pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
296 self.data
297 .entry(dtype)
298 .or_default()
299 .extend(bytes.iter().copied());
300 }
301
302 fn register(&self, bindings: &mut Bindings) {
303 for (ty, values) in self.data.iter() {
304 let len = values.len() / ty.size();
305 let len_u64 = len.div_ceil(size_of::<u64>() / ty.size());
306
307 let mut data = vec![0; len_u64];
308 let slice = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
309 slice[0..values.len()].copy_from_slice(values);
310 bindings
311 .scalars
312 .insert(*ty, ScalarBinding::new(*ty, len, data));
313 }
314 }
315}
316
317impl<R: Runtime> KernelLauncher<R> {
318 pub fn new(settings: KernelSettings) -> Self {
319 Self {
320 tensors: TensorState::Empty {
321 addr_type: settings.address_type,
322 },
323 scalars: Default::default(),
324 settings,
325 runtime: PhantomData,
326 }
327 }
328}