1use std::{collections::BTreeMap, marker::PhantomData};
2
3use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
4use crate::{CubeScalar, KernelSettings};
5use crate::{MetadataBuilder, Runtime};
6use cubecl_ir::StorageType;
7use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
8use cubecl_runtime::{
9 client::ComputeClient,
10 kernel::{CubeKernel, KernelTask},
11 server::Bindings,
12};
13
/// Collects the arguments of a kernel launch (tensors, arrays, tensor maps and scalars)
/// and converts them into [`Bindings`] when the kernel is launched.
pub struct KernelLauncher<R: Runtime> {
    /// Buffer bindings, tensor-map bindings and the metadata describing them.
    tensors: TensorState<R>,
    /// Scalar arguments, grouped by storage type.
    scalars: ScalarState,
    /// Settings this launcher was created with.
    pub settings: KernelSettings,
    runtime: PhantomData<R>,
}
21
22impl<R: Runtime> KernelLauncher<R> {
23 pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
25 self.tensors.push_tensor(tensor);
26 }
27
28 pub fn register_tensor_map<K: TensorMapKind>(&mut self, tensor: &TensorMapArg<'_, R, K>) {
30 self.tensors.push_tensor_map(tensor);
31 }
32
33 pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
35 self.tensors.push_array(array);
36 }
37
38 pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
40 self.scalars.push(scalar);
41 }
42
43 pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
45 self.scalars.push_raw(bytes, dtype);
46 }
47
48 #[track_caller]
50 pub fn launch<K: CubeKernel>(
51 self,
52 cube_count: CubeCount,
53 kernel: K,
54 client: &ComputeClient<R>,
55 ) -> Result<(), LaunchError> {
56 let bindings = self.into_bindings();
57 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
58
59 client.launch(kernel, cube_count, bindings)
60 }
61
62 #[track_caller]
71 pub unsafe fn launch_unchecked<K: CubeKernel>(
72 self,
73 cube_count: CubeCount,
74 kernel: K,
75 client: &ComputeClient<R>,
76 ) -> Result<(), LaunchError> {
77 unsafe {
78 let bindings = self.into_bindings();
79 let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
80
81 client.launch_unchecked(kernel, cube_count, bindings)
82 }
83 }
84
85 fn into_bindings(self) -> Bindings {
95 let mut bindings = Bindings::new();
96
97 self.tensors.register(&mut bindings);
98 self.scalars.register(&mut bindings);
99
100 bindings
101 }
102}
103
/// Buffer and tensor-map arguments accumulated for a launch.
///
/// Starts as `Empty` and is switched to `Some` the first time a binding or metadata
/// entry is needed.
pub enum TensorState<R: Runtime> {
    /// Nothing recorded yet; only the address type used to build the metadata is kept.
    Empty { addr_type: StorageType },
    /// Initialized state holding everything recorded so far.
    Some {
        /// Buffer bindings, appended in registration order.
        buffers: Vec<Binding>,
        /// Tensor-map bindings, appended in registration order.
        tensor_maps: Vec<TensorMapBinding>,
        /// Shape/stride/length metadata for the registered tensors and arrays.
        metadata: MetadataBuilder,
        runtime: PhantomData<R>,
    },
}
116
/// Scalar arguments of a launch, kept as raw bytes grouped by storage type.
#[derive(Default, Clone)]
pub struct ScalarState {
    /// For each storage type, the concatenated bytes of every scalar pushed, in push order.
    data: BTreeMap<StorageType, ScalarValues>,
}

/// Raw bytes of the scalars registered for one storage type.
pub type ScalarValues = Vec<u8>;
127
128impl<R: Runtime> TensorState<R> {
129 fn maybe_init(&mut self) {
130 if let TensorState::Empty { addr_type } = self {
131 *self = TensorState::Some {
132 buffers: Vec::new(),
133 tensor_maps: Vec::new(),
134 metadata: MetadataBuilder::new(*addr_type),
135 runtime: PhantomData,
136 };
137 }
138 }
139
140 fn buffers(&mut self) -> &mut Vec<Binding> {
141 self.maybe_init();
142 let TensorState::Some { buffers, .. } = self else {
143 panic!("Should be init");
144 };
145 buffers
146 }
147
148 fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
149 self.maybe_init();
150 let TensorState::Some { tensor_maps, .. } = self else {
151 panic!("Should be init");
152 };
153 tensor_maps
154 }
155
156 fn metadata(&mut self) -> &mut MetadataBuilder {
157 self.maybe_init();
158 let TensorState::Some { metadata, .. } = self else {
159 panic!("Should be init");
160 };
161 metadata
162 }
163
164 pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
166 if let Some(tensor) = self.process_tensor(tensor) {
167 self.buffers().push(tensor);
168 }
169 }
170
171 fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
172 let (tensor, vectorization) = match tensor {
173 TensorArg::Handle {
174 handle,
175 line_size: vectorization_factor,
176 ..
177 } => (handle, vectorization_factor),
178 TensorArg::Alias { .. } => return None,
179 };
180
181 let elem_size = tensor.elem_size * *vectorization;
182 let buffer_len = tensor.handle.size() / elem_size as u64;
183 let len = tensor.shape.iter().product::<usize>() / *vectorization;
184 self.metadata().with_tensor(
185 tensor.strides.len() as u64,
186 buffer_len,
187 len as u64,
188 tensor.shape.iter().map(|it| *it as u64).collect(),
189 tensor.strides.iter().map(|it| *it as u64).collect(),
190 );
191 Some(tensor.handle.clone().binding())
192 }
193
194 pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
196 if let Some(tensor) = self.process_array(array) {
197 self.buffers().push(tensor);
198 }
199 }
200
201 fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
202 let (array, vectorization) = match array {
203 ArrayArg::Handle {
204 handle,
205 line_size: vectorization_factor,
206 ..
207 } => (handle, vectorization_factor),
208 ArrayArg::Alias { .. } => return None,
209 };
210
211 let elem_size = array.elem_size * *vectorization;
212 let buffer_len = array.handle.size() / elem_size as u64;
213 self.metadata()
214 .with_array(buffer_len, array.length[0] as u64 / *vectorization as u64);
215 Some(array.handle.clone().binding())
216 }
217
218 pub fn push_tensor_map<K: TensorMapKind>(&mut self, map: &TensorMapArg<'_, R, K>) {
220 let binding = self
221 .process_tensor(&map.tensor)
222 .expect("Can't use alias for TensorMap");
223
224 let map = map.metadata.clone();
225 self.tensor_maps().push(TensorMapBinding { binding, map });
226 }
227
228 fn register(self, bindings_global: &mut Bindings) {
229 if let Self::Some {
230 buffers,
231 tensor_maps,
232 metadata,
233 ..
234 } = self
235 {
236 let metadata = metadata.finish();
237
238 bindings_global.buffers = buffers;
239 bindings_global.tensor_maps = tensor_maps;
240 bindings_global.metadata = metadata;
241 }
242 }
243}
244
245impl ScalarState {
246 pub fn push<T: CubeScalar>(&mut self, val: T) {
248 let val = [val];
249 let bytes = T::as_bytes(&val);
250 self.data
251 .entry(T::cube_type())
252 .or_default()
253 .extend(bytes.iter().copied());
254 }
255
256 pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
258 self.data
259 .entry(dtype)
260 .or_default()
261 .extend(bytes.iter().copied());
262 }
263
264 fn register(&self, bindings: &mut Bindings) {
265 for (ty, values) in self.data.iter() {
266 let len = values.len() / ty.size();
267 let len_u64 = len.div_ceil(size_of::<u64>() / ty.size());
268
269 let mut data = vec![0; len_u64];
270 let slice = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
271 slice[0..values.len()].copy_from_slice(values);
272 bindings
273 .scalars
274 .insert(*ty, ScalarBinding::new(*ty, len, data));
275 }
276 }
277}
278
279impl<R: Runtime> KernelLauncher<R> {
280 pub fn new(settings: KernelSettings) -> Self {
281 Self {
282 tensors: TensorState::Empty {
283 addr_type: settings.address_type.unsigned_type(),
284 },
285 scalars: Default::default(),
286 settings,
287 runtime: PhantomData,
288 }
289 }
290}