use crate::compute::{CubeCount, KernelTask};
use crate::frontend::TensorHandleRef;
use crate::ir::Elem;
use crate::pod::CubeElement;
use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX};
use cubecl_runtime::client::ComputeClient;
use cubecl_runtime::server::{Binding, ComputeServer, Handle};
/// Strategy used to choose the cube count of a kernel launch.
pub enum CubeCountSettings<S: ComputeServer> {
/// Size the launch from the number of elements of the input tensor at index `pos`.
Input { pos: usize },
/// Size the launch from the number of elements of the output tensor at index `pos`.
Output { pos: usize },
/// Use the given cube count unchanged.
Custom(CubeCount<S>),
}
/// Builder that collects everything needed to launch a kernel.
///
/// The `Scalars` type parameter records which scalar slices have been attached so far:
/// `()` for none, then a tuple of one, two or three slices.
pub struct Execution<'h, K, R: Runtime, Scalars> {
// Scalar slices attached through `with_scalars` (or `()` when none were given).
scalars: Scalars,
// Client used to create buffers and to submit the kernel.
client: ComputeClient<R::Server, R::Channel>,
// Kernel that will be executed.
kernel: K,
// Input tensors; bound before the output tensors.
inputs: &'h [TensorHandleRef<'h, R>],
// Output tensors; bound after the input tensors.
outputs: &'h [TensorHandleRef<'h, R>],
}
/// Builder steps available while no scalar slice has been attached yet.
impl<'h, K, R: Runtime> Execution<'h, K, R, ()> {
    /// Starts a new execution of `kernel` on `client`.
    ///
    /// The returned builder has no inputs, no outputs and no scalars.
    pub fn start(
        kernel: K,
        client: ComputeClient<R::Server, R::Channel>,
    ) -> Execution<'h, K, R, ()> {
        Self {
            scalars: (),
            client,
            kernel,
            inputs: &[],
            outputs: &[],
        }
    }

    /// Sets the input tensors, replacing any previously set inputs.
    #[allow(unused)]
    pub fn inputs(self, inputs: &'h [TensorHandleRef<'h, R>]) -> Execution<'h, K, R, ()> {
        // Every other field is carried over untouched.
        Self { inputs, ..self }
    }

    /// Sets the output tensors, replacing any previously set outputs.
    pub fn outputs(self, outputs: &'h [TensorHandleRef<'h, R>]) -> Execution<'h, K, R, ()> {
        // Every other field is carried over untouched.
        Self { outputs, ..self }
    }
}
/// Launch and scalar-attachment steps for an execution without scalars.
impl<'h, K, R> Execution<'h, K, R, ()>
where
    K: Kernel + 'static,
    R: Runtime,
{
    /// Attaches the first scalar slice, moving the builder to its one-scalar state.
    pub fn with_scalars<E>(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> {
        let Execution {
            client,
            kernel,
            inputs,
            outputs,
            ..
        } = self;

        Execution {
            scalars: (scalars,),
            client,
            kernel,
            inputs,
            outputs,
        }
    }

    /// Launches the kernel without any scalar arguments.
    ///
    /// The three `f32` type arguments are placeholders: every scalar slot is `None`.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        let Execution {
            client,
            kernel,
            inputs,
            outputs,
            ..
        } = self;

        execute_dynamic::<R, K, f32, f32, f32>(
            inputs, outputs, None, None, None, kernel, launch, client,
        )
    }
}
/// Launch and scalar-attachment steps for an execution with one scalar slice.
impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)>
where
    K: Kernel + 'static,
    R: Runtime,
    E: CubeElement,
{
    /// Attaches a second scalar slice, keeping the first one in place.
    pub fn with_scalars<'b, E2>(
        self,
        scalars: &'b [E2],
    ) -> Execution<'h, K, R, (&'a [E], &'b [E2])> {
        let Execution {
            scalars: (first,),
            client,
            kernel,
            inputs,
            outputs,
        } = self;

        Execution {
            scalars: (first, scalars),
            client,
            kernel,
            inputs,
            outputs,
        }
    }

    /// Launches the kernel with the single attached scalar slice.
    ///
    /// The two trailing `f32` type arguments are placeholders for the unused (`None`)
    /// scalar slots.
    #[allow(unused)]
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        let Execution {
            scalars: (first,),
            client,
            kernel,
            inputs,
            outputs,
        } = self;

        execute_dynamic::<R, K, E, f32, f32>(
            inputs,
            outputs,
            Some(first),
            None,
            None,
            kernel,
            launch,
            client,
        )
    }
}
/// Launch and scalar-attachment steps for an execution with two scalar slices.
impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])>
where
    K: Kernel + 'static,
    R: Runtime,
    E1: CubeElement,
    E2: CubeElement,
{
    /// Attaches a third scalar slice, keeping the first two in place.
    #[allow(unused, clippy::type_complexity)]
    pub fn with_scalars<'c, E3>(
        self,
        scalars: &'c [E3],
    ) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> {
        let Execution {
            scalars: (first, second),
            client,
            kernel,
            inputs,
            outputs,
        } = self;

        Execution {
            scalars: (first, second, scalars),
            client,
            kernel,
            inputs,
            outputs,
        }
    }

    /// Launches the kernel with the two attached scalar slices.
    ///
    /// The trailing `f32` type argument is a placeholder for the unused (`None`) third
    /// scalar slot.
    // The bounds on `K` and `R` come from the impl header, and this method takes only two
    // arguments, so it needs neither its own `where` clause nor a `too_many_arguments`
    // allowance.
    pub fn execute(self, launch: CubeCountSettings<R::Server>) {
        let Execution {
            scalars: (first, second),
            client,
            kernel,
            inputs,
            outputs,
        } = self;

        execute_dynamic::<R, K, E1, E2, f32>(
            inputs,
            outputs,
            Some(first),
            Some(second),
            None,
            kernel,
            launch,
            client,
        )
    }
}
impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])>
where
K: Kernel + 'static,
R: Runtime,
E1: CubeElement,
E2: CubeElement,
E3: CubeElement,
{
#[allow(unused)]
pub fn execute(self, launch: CubeCountSettings<R::Server>) {
execute_dynamic::<R, K, E1, E2, E3>(
self.inputs,
self.outputs,
Some(self.scalars.0),
Some(self.scalars.1),
Some(self.scalars.2),
self.kernel,
launch,
self.client,
)
}
}
#[allow(clippy::too_many_arguments)]
fn execute_dynamic<R, K, E1, E2, E3>(
inputs: &[TensorHandleRef<R>],
outputs: &[TensorHandleRef<R>],
scalars_1: Option<&[E1]>,
scalars_2: Option<&[E2]>,
scalars_3: Option<&[E3]>,
kernel: K,
launch: CubeCountSettings<R::Server>,
client: ComputeClient<R::Server, R::Channel>,
) where
K: Kernel + 'static,
R: Runtime,
E1: CubeElement,
E2: CubeElement,
E3: CubeElement,
{
let settings = execute_settings(
inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client,
);
let mut handles = settings.handles_tensors;
handles.push(settings.handle_info.binding());
for handle in settings.handles_scalars.into_iter() {
handles.push(handle.binding());
}
let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
client.execute(kernel, settings.cube_count, handles);
}
/// Everything `execute_settings` prepares for a kernel launch.
struct ExecuteSettings<R: Runtime> {
// Bindings of the input tensors followed by the output tensors.
handles_tensors: Vec<Binding<R::Server>>,
// Buffer holding the tensor metadata (rank, strides and shapes, optionally lengths).
handle_info: Handle<R::Server>,
// Buffers created from the scalar slices, in the order chosen by `create_scalar_handles`.
handles_scalars: Vec<Handle<R::Server>>,
// Cube count to launch the kernel with.
cube_count: CubeCount<R::Server>,
}
/// Builds the bindings, metadata buffer, scalar buffers and cube count for a launch.
///
/// The metadata ("info") buffer is a flat `u32` list laid out as:
/// 1. the number of strides of the first registered tensor, written once;
/// 2. for every input and then every output: its strides followed by its shape;
/// 3. only when `R::require_array_lengths()` is true: the element count of every input
///    and then of every output.
///
/// The cube count is `Custom`'s value when given. Otherwise it is computed by
/// `calculate_cube_count_elemwise` from the element count of the tensor selected by
/// `launch`, or from `0` when `pos` does not name an existing tensor.
fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    inputs: &'a [TensorHandleRef<R>],
    outputs: &'a [TensorHandleRef<R>],
    scalars_1: Option<&[E1]>,
    scalars_2: Option<&[E2]>,
    scalars_3: Option<&[E3]>,
    launch: CubeCountSettings<R::Server>,
    client: &ComputeClient<R::Server, R::Channel>,
) -> ExecuteSettings<R> {
    /// Appends one tensor's strides and shape to `info`, preceded by the rank when
    /// `info` is still empty.
    fn append_tensor_info(info: &mut Vec<u32>, strides: &[usize], shape: &[usize]) {
        if info.is_empty() {
            info.push(strides.len() as u32);
        }
        // NOTE(review): `as u32` silently truncates values above `u32::MAX`.
        info.extend(strides.iter().chain(shape.iter()).map(|&dim| dim as u32));
    }

    // Element count of the tensor that drives the launch size, if one is selected.
    let num_elems_output = match &launch {
        CubeCountSettings::Input { pos } => inputs
            .get(*pos)
            .map_or(0, |tensor| calculate_num_elems_dyn_rank(tensor.shape)),
        CubeCountSettings::Output { pos } => outputs
            .get(*pos)
            .map_or(0, |tensor| calculate_num_elems_dyn_rank(tensor.shape)),
        CubeCountSettings::Custom(_) => 0,
    };

    let mut info: Vec<u32> = Vec::new();
    let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);

    // Inputs are registered before outputs, both in the info buffer and in the bindings.
    for tensor in inputs.iter() {
        append_tensor_info(&mut info, tensor.strides, tensor.shape);
        handles.push(tensor.handle.clone().binding());
    }
    for tensor in outputs.iter() {
        append_tensor_info(&mut info, tensor.strides, tensor.shape);
        handles.push(tensor.handle.clone().binding());
    }

    if R::require_array_lengths() {
        info.extend(
            inputs
                .iter()
                .map(|tensor| calculate_num_elems_dyn_rank(tensor.shape) as u32),
        );
        info.extend(
            outputs
                .iter()
                .map(|tensor| calculate_num_elems_dyn_rank(tensor.shape) as u32),
        );
    }

    let handle_info = client.create(bytemuck::cast_slice(&info));
    let handles_scalars =
        create_scalar_handles::<R, E1, E2, E3>(scalars_1, scalars_2, scalars_3, client);

    let cube_count = match launch {
        CubeCountSettings::Custom(count) => count,
        CubeCountSettings::Input { .. } | CubeCountSettings::Output { .. } => {
            calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX)
        }
    };

    ExecuteSettings {
        handles_tensors: handles,
        handle_info,
        handles_scalars,
        cube_count,
    }
}
/// Uploads the provided scalar slices and returns their handles grouped by element kind.
///
/// Handles are ordered float first, then int, then unsigned int. Slices of the same kind
/// keep their argument order. Slots that are `None` produce no handle.
///
/// # Panics
///
/// Panics if `E1`, `E2` or `E3` is a boolean element type, even when the matching slice
/// is `None`.
fn create_scalar_handles<R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>(
    scalars_0: Option<&[E1]>,
    scalars_1: Option<&[E2]>,
    scalars_2: Option<&[E3]>,
    client: &ComputeClient<R::Server, R::Channel>,
) -> Vec<Handle<R::Server>> {
    let priority_of = |elem: Elem| match elem {
        Elem::Float(_) => 0,
        Elem::Int(_) => 1,
        Elem::UInt => 2,
        Elem::Bool => panic!("Bool scalars are not supported"),
    };

    // All three priorities are resolved up front, so an unsupported element type panics
    // before any buffer is created.
    let priorities: [usize; 3] = [
        priority_of(E1::cube_elem()),
        priority_of(E2::cube_elem()),
        priority_of(E3::cube_elem()),
    ];

    // Pair every slot with its priority, viewing the scalar data as raw bytes.
    let mut slots: [(usize, Option<&[u8]>); 3] = [
        (priorities[0], scalars_0.map(|values| bytemuck::cast_slice(values))),
        (priorities[1], scalars_1.map(|values| bytemuck::cast_slice(values))),
        (priorities[2], scalars_2.map(|values| bytemuck::cast_slice(values))),
    ];

    // The sort is stable, so slots sharing a priority keep their argument order.
    slots.sort_by_key(|&(priority, _)| priority);

    slots
        .iter()
        .filter_map(|&(_, bytes)| bytes)
        .map(|bytes| client.create(bytes))
        .collect()
}
/// Returns the number of elements described by `shape`, i.e. the product of its
/// dimensions.
///
/// An empty shape yields `1`.
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
    shape.iter().product()
}