cubecl_core/codegen/execution.rs

1use crate::compute::KernelTask;
2use crate::frontend::TensorHandleRef;
3use crate::ir::Elem;
4use crate::pod::CubeElement;
5use crate::{calculate_cube_count_elemwise, CubeDim, Kernel, Runtime};
6use cubecl_runtime::client::ComputeClient;
7use cubecl_runtime::server::{Binding, CubeCount, Handle};
8
9/// The position of the input or output to calculate the number of cubes to launch.
10pub enum CubeCountSettings {
11    Input { pos: usize },
12    Output { pos: usize },
13    Custom(CubeCount),
14}
15
16pub struct Execution<'h, K, R: Runtime, Scalars> {
17    scalars: Scalars,
18    client: ComputeClient<R::Server, R::Channel>,
19    kernel: K,
20    inputs: &'h [TensorHandleRef<'h, R>],
21    outputs: &'h [TensorHandleRef<'h, R>],
22}
23
24impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
25    pub fn start(
26        kernel: K,
27        client: ComputeClient<R::Server, R::Channel>,
28    ) -> Execution<'h, K, R, ()> {
29        Execution {
30            scalars: (),
31            client,
32            kernel,
33            inputs: &[],
34            outputs: &[],
35        }
36    }
37
38    #[allow(unused)]
39    pub fn inputs(self, inputs: &'h [TensorHandleRef<'h, R>]) -> Execution<'h, K, R, ()> {
40        Execution {
41            scalars: self.scalars,
42            client: self.client,
43            kernel: self.kernel,
44            inputs,
45            outputs: self.outputs,
46        }
47    }
48
49    pub fn outputs(self, outputs: &'h [TensorHandleRef<'h, R>]) -> Execution<'h, K, R, ()> {
50        Execution {
51            scalars: self.scalars,
52            client: self.client,
53            kernel: self.kernel,
54            inputs: self.inputs,
55            outputs,
56        }
57    }
58}
59
60impl<'h, K, R> Execution<'h, K, R, ()>
61where
62    K: Kernel + 'static,
63    R: Runtime,
64{
65    pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
66        Execution {
67            scalars: (scalars,),
68            client: self.client,
69            kernel: self.kernel,
70            inputs: self.inputs,
71            outputs: self.outputs,
72        }
73    }
74    /// Execute a dynamic kernel.
75    #[allow(unused)]
76    pub fn execute(self, launch: CubeCountSettings) {
77        execute_dynamic::<R, K, f32, f32, f32>(
78            self.inputs,
79            self.outputs,
80            None,
81            None,
82            None,
83            self.kernel,
84            launch,
85            self.client,
86        )
87    }
88}
89
90impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
91where
92    K: Kernel + 'static,
93    R: Runtime,
94    E: CubeElement,
95{
96    pub fn with_scalars<'b, E2>(
97        self,
98        scalars: &'b [E2],
99    ) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
100        Execution {
101            scalars: (self.scalars.0, scalars),
102            client: self.client,
103            kernel: self.kernel,
104            inputs: self.inputs,
105            outputs: self.outputs,
106        }
107    }
108
109    /// Execute a dynamic kernel.
110    #[allow(unused)]
111    pub fn execute(self, launch: CubeCountSettings) {
112        execute_dynamic::<R, K, E, f32, f32>(
113            self.inputs,
114            self.outputs,
115            Some(self.scalars.0),
116            None,
117            None,
118            self.kernel,
119            launch,
120            self.client,
121        )
122    }
123}
124
125impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
126where
127    K: Kernel + 'static,
128    R: Runtime,
129    E1: CubeElement,
130    E2: CubeElement,
131{
132    #[allow(unused, clippy::type_complexity)]
133    pub fn with_scalars<'c, E3>(
134        self,
135        scalars: &'c [E3],
136    ) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
137        Execution {
138            scalars: (self.scalars.0, self.scalars.1, scalars),
139            client: self.client,
140            kernel: self.kernel,
141            inputs: self.inputs,
142            outputs: self.outputs,
143        }
144    }
145    /// Execute a dynamic kernel.
146    #[allow(clippy::too_many_arguments)]
147    pub fn execute(self, launch: CubeCountSettings)
148    where
149        K: Kernel + 'static,
150        R: Runtime,
151    {
152        execute_dynamic::<R, K, E1, E2, f32>(
153            self.inputs,
154            self.outputs,
155            Some(self.scalars.0),
156            Some(self.scalars.1),
157            None,
158            self.kernel,
159            launch,
160            self.client,
161        )
162    }
163}
164
165impl<K, R, E1, E2, E3> Execution<'_, K, R, (&[E1], &[E2], &[E3])>
166where
167    K: Kernel + 'static,
168    R: Runtime,
169    E1: CubeElement,
170    E2: CubeElement,
171    E3: CubeElement,
172{
173    /// Execute a dynamic kernel.
174    #[allow(unused)]
175    pub fn execute(self, launch: CubeCountSettings) {
176        execute_dynamic::<R, K, E1, E2, E3>(
177            self.inputs,
178            self.outputs,
179            Some(self.scalars.0),
180            Some(self.scalars.1),
181            Some(self.scalars.2),
182            self.kernel,
183            launch,
184            self.client,
185        )
186    }
187}
188
189#[allow(clippy::too_many_arguments)]
190fn execute_dynamic<R, K, E1, E2, E3>(
191    inputs: &[TensorHandleRef<R>],
192    outputs: &[TensorHandleRef<R>],
193    scalars_1: Option<&[E1]>,
194    scalars_2: Option<&[E2]>,
195    scalars_3: Option<&[E3]>,
196    kernel: K,
197    launch: CubeCountSettings,
198    client: ComputeClient<R::Server, R::Channel>,
199) where
200    K: Kernel + 'static,
201    R: Runtime,
202    E1: CubeElement,
203    E2: CubeElement,
204    E3: CubeElement,
205{
206    let settings = execute_settings::<R, E1, E2, E3>(
207        inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
208    );
209
210    let mut handles = settings.handles_tensors;
211
212    handles.push(settings.handle_info.binding());
213    for handle in settings.handles_scalars.into_iter() {
214        handles.push(handle.binding());
215    }
216
217    let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
218    client.execute(kernel, settings.cube_count, handles);
219}
220
221struct ExecuteSettings {
222    handles_tensors: Vec<Binding>,
223    handle_info: Handle,
224    handles_scalars: Vec<Handle>,
225    cube_count: CubeCount,
226}
227
228#[allow(clippy::too_many_arguments)]
229fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
230    inputs: &'a [TensorHandleRef<R>],
231    outputs: &'a [TensorHandleRef<R>],
232    scalars_1: Option<&[E1]>,
233    scalars_2: Option<&[E2]>,
234    scalars_3: Option<&[E3]>,
235    launch: CubeCountSettings,
236    client: &ComputeClient<R::Server, R::Channel>,
237) -> ExecuteSettings {
238    let mut info = Vec::new();
239    let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
240
241    // Inner function to fill the info buffer.
242    let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
243        if info.is_empty() {
244            info.push(strides.len() as u32);
245        }
246
247        for s in strides.iter() {
248            info.push(*s as u32);
249        }
250        for s in shape.iter() {
251            info.push(*s as u32);
252        }
253    };
254
255    let mut num_elems_output = 0;
256
257    // We start by registering the inputs.
258    for (i, input) in inputs.iter().enumerate() {
259        if let CubeCountSettings::Input { pos } = &launch {
260            if i == *pos {
261                num_elems_output = calculate_num_elems_dyn_rank(input.shape);
262            }
263        };
264        register_info_tensor(input.strides, input.shape);
265        handles.push(input.handle.clone().binding());
266    }
267
268    // Then we follow with the outputs.
269    for (i, output) in outputs.iter().enumerate() {
270        if let CubeCountSettings::Output { pos } = &launch {
271            if i == *pos {
272                num_elems_output = calculate_num_elems_dyn_rank(output.shape);
273            }
274        };
275        register_info_tensor(output.strides, output.shape);
276        handles.push(output.handle.clone().binding());
277    }
278
279    // [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0...,  I0len, I1len1, O0len]
280    if R::require_array_lengths() {
281        for input in inputs.iter() {
282            let len = calculate_num_elems_dyn_rank(input.shape);
283            info.push(len as u32);
284        }
285
286        for output in outputs.iter() {
287            let len = calculate_num_elems_dyn_rank(output.shape);
288            info.push(len as u32);
289        }
290    }
291
292    let info = client.create(bytemuck::cast_slice(&info));
293
294    // Finally we finish with the named bindings.
295    let handles_scalars =
296        create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);
297
298    let cube_count = match launch {
299        CubeCountSettings::Custom(count) => count,
300        _ => calculate_cube_count_elemwise(num_elems_output, CubeDim::default()),
301    };
302
303    ExecuteSettings {
304        handles_tensors: handles,
305        handle_info: info,
306        handles_scalars,
307        cube_count,
308    }
309}
310
311fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
312    scalars_0: Option<&[E1]>,
313    scalars_1: Option<&[E2]>,
314    scalars_2: Option<&[E3]>,
315    client: &ComputeClient<R::Server, R::Channel>,
316) -> Vec<Handle> {
317    // It is crucial that scalars follow this order: float, int, uint
318    let element_priority = |elem: Elem| match elem {
319        Elem::Float(_) | Elem::AtomicFloat(_) => 0,
320        Elem::Int(_) | Elem::AtomicInt(_) => 1,
321        Elem::UInt(_) | Elem::AtomicUInt(_) => 2,
322        Elem::Bool => panic!("Bool scalars are not supported"),
323    };
324    let scalar_priorities: [usize; 3] = [
325        element_priority(E1::cube_elem()),
326        element_priority(E2::cube_elem()),
327        element_priority(E3::cube_elem()),
328    ];
329
330    let mut handles_scalars = Vec::new();
331    for i in 0..3 {
332        for (j, scalar_priority) in scalar_priorities.iter().enumerate() {
333            if scalar_priority == &i {
334                if j == 0 {
335                    if let Some(values) = &scalars_0 {
336                        handles_scalars.push(client.create(bytemuck::cast_slice(values)));
337                    }
338                } else if j == 1 {
339                    if let Some(values) = &scalars_1 {
340                        handles_scalars.push(client.create(bytemuck::cast_slice(values)));
341                    }
342                } else if j == 2 {
343                    if let Some(values) = &scalars_2 {
344                        handles_scalars.push(client.create(bytemuck::cast_slice(values)));
345                    }
346                }
347            }
348        }
349    }
350
351    handles_scalars
352}
353
/// Total number of elements described by `shape` (the product of its dimensions).
///
/// An empty shape yields `1`, and any zero dimension yields `0`.
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}