use std::{collections::BTreeMap, marker::PhantomData};
use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
use crate::{CubeScalar, KernelSettings};
use crate::{MetadataBuilder, Runtime};
use cubecl_ir::StorageType;
use cubecl_runtime::server::{Binding, CubeCount, LaunchError, ScalarBinding, TensorMapBinding};
use cubecl_runtime::{
client::ComputeClient,
kernel::{CubeKernel, KernelTask},
server::Bindings,
};
/// Collects the arguments of one kernel launch and turns them into `Bindings` when
/// `launch`/`launch_unchecked` is called.
///
/// Arguments are split into buffer-like ones (tensors, arrays, tensor maps) and scalars.
pub struct KernelLauncher<R: Runtime> {
    /// Buffer and tensor-map bindings plus their metadata.
    /// Starts `Empty` and is initialized on the first non-alias registration.
    tensors: TensorState<R>,
    /// Scalar arguments, stored as bytes grouped by storage type.
    scalars: ScalarState,
    /// Settings given to `new`; its `address_type` selects the metadata address type.
    pub settings: KernelSettings,
    runtime: PhantomData<R>,
}
impl<R: Runtime> KernelLauncher<R> {
    /// Registers a tensor argument.
    ///
    /// A handle contributes its buffer binding and its metadata (rank, buffer length and
    /// element count measured in lines, shape, strides). An alias registers nothing.
    pub fn register_tensor(&mut self, tensor: &TensorArg<'_, R>) {
        self.tensors.push_tensor(tensor);
    }

    /// Registers a tensor-map argument: the metadata of the underlying tensor plus a
    /// tensor-map binding.
    ///
    /// # Panics
    /// Panics if the underlying tensor is an alias.
    pub fn register_tensor_map<K: TensorMapKind>(&mut self, tensor: &TensorMapArg<'_, R, K>) {
        self.tensors.push_tensor_map(tensor);
    }

    /// Registers an array argument: its buffer binding plus its buffer length and array
    /// length (both measured in lines). An alias registers nothing.
    pub fn register_array(&mut self, array: &ArrayArg<'_, R>) {
        self.tensors.push_array(array);
    }

    /// Registers a scalar argument, appending it to the scalars of its storage type.
    pub fn register_scalar<C: CubeScalar>(&mut self, scalar: C) {
        self.scalars.push(scalar);
    }

    /// Registers already-encoded scalar bytes of storage type `dtype`.
    pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
        self.scalars.push_raw(bytes, dtype);
    }

    /// Launches `kernel` over `cube_count` with every registered argument.
    ///
    /// # Errors
    /// Forwards any `LaunchError` returned by `ComputeClient::launch`.
    #[track_caller]
    pub fn launch<K: CubeKernel>(
        self,
        cube_count: CubeCount,
        kernel: K,
        client: &ComputeClient<R>,
    ) -> Result<(), LaunchError> {
        let bindings = self.into_bindings();
        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
        client.launch(kernel, cube_count, bindings)
    }

    /// Same as [`Self::launch`], but dispatches through `ComputeClient::launch_unchecked`.
    ///
    /// # Safety
    /// The caller must uphold the safety contract of `ComputeClient::launch_unchecked`.
    /// Presumably this means the kernel must stay within the bounds of the registered
    /// bindings. TODO: confirm against the client documentation.
    ///
    /// # Errors
    /// Forwards any `LaunchError` returned by `ComputeClient::launch_unchecked`.
    #[track_caller]
    pub unsafe fn launch_unchecked<K: CubeKernel>(
        self,
        cube_count: CubeCount,
        kernel: K,
        client: &ComputeClient<R>,
    ) -> Result<(), LaunchError> {
        // Building the bindings and the task is safe; only the dispatch itself is not.
        let bindings = self.into_bindings();
        let kernel = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
        // SAFETY: the caller upholds the `# Safety` contract of this function, which is
        // exactly the contract required by `ComputeClient::launch_unchecked`.
        unsafe { client.launch_unchecked(kernel, cube_count, bindings) }
    }

    /// Consumes the launcher and moves the tensor state, then the scalar state, into a
    /// fresh `Bindings`.
    fn into_bindings(self) -> Bindings {
        let mut bindings = Bindings::new();
        self.tensors.register(&mut bindings);
        self.scalars.register(&mut bindings);
        bindings
    }
}
/// Buffer-like launch arguments (tensors, arrays, tensor maps) and their metadata.
///
/// Starts as `Empty` and switches to `Some` the first time a non-alias tensor, array or
/// tensor map is registered.
pub enum TensorState<R: Runtime> {
    /// Nothing registered yet; keeps the address type used to build the `MetadataBuilder`.
    Empty { addr_type: StorageType },
    /// At least one buffer-like argument has been registered.
    Some {
        /// Buffer bindings of tensors and arrays, in registration order.
        buffers: Vec<Binding>,
        /// Tensor-map bindings, in registration order.
        tensor_maps: Vec<TensorMapBinding>,
        /// Metadata of every registered tensor, array and tensor map.
        metadata: MetadataBuilder,
        runtime: PhantomData<R>,
    },
}
/// Scalar launch arguments, stored as raw bytes grouped by storage type.
#[derive(Default, Clone)]
pub struct ScalarState {
    /// For each storage type, the concatenated bytes of its scalars in push order.
    /// Being a `BTreeMap`, iteration (and therefore registration) follows key order.
    data: BTreeMap<StorageType, ScalarValues>,
}
/// Concatenated raw bytes of all scalars of one storage type.
pub type ScalarValues = Vec<u8>;
impl<R: Runtime> TensorState<R> {
    /// Turns an `Empty` state into an empty `Some` state whose metadata builder uses the
    /// stored address type. Does nothing if the state is already `Some`.
    fn maybe_init(&mut self) {
        let addr_type = match self {
            TensorState::Empty { addr_type } => *addr_type,
            TensorState::Some { .. } => return,
        };
        *self = TensorState::Some {
            buffers: Vec::new(),
            tensor_maps: Vec::new(),
            metadata: MetadataBuilder::new(addr_type),
            runtime: PhantomData,
        };
    }

    /// Buffer bindings, initializing the state first if needed.
    fn buffers(&mut self) -> &mut Vec<Binding> {
        self.maybe_init();
        match self {
            TensorState::Some { buffers, .. } => buffers,
            TensorState::Empty { .. } => panic!("Should be init"),
        }
    }

    /// Tensor-map bindings, initializing the state first if needed.
    fn tensor_maps(&mut self) -> &mut Vec<TensorMapBinding> {
        self.maybe_init();
        match self {
            TensorState::Some { tensor_maps, .. } => tensor_maps,
            TensorState::Empty { .. } => panic!("Should be init"),
        }
    }

    /// Metadata builder, initializing the state first if needed.
    fn metadata(&mut self) -> &mut MetadataBuilder {
        self.maybe_init();
        match self {
            TensorState::Some { metadata, .. } => metadata,
            TensorState::Empty { .. } => panic!("Should be init"),
        }
    }

    /// Records a tensor's metadata and stores its buffer binding.
    /// An alias is ignored entirely.
    pub fn push_tensor(&mut self, tensor: &TensorArg<'_, R>) {
        let Some(binding) = self.process_tensor(tensor) else {
            return;
        };
        self.buffers().push(binding);
    }

    /// Records the metadata of a handle-backed tensor and returns its buffer binding.
    ///
    /// Buffer length and element count are measured in lines
    /// (`elem_size * line_size` bytes). Returns `None` for an alias without touching
    /// the state.
    fn process_tensor(&mut self, tensor: &TensorArg<'_, R>) -> Option<Binding> {
        let (tensor_ref, line_size) = match tensor {
            TensorArg::Handle {
                handle, line_size, ..
            } => (handle, *line_size),
            TensorArg::Alias { .. } => return None,
        };

        let line_bytes = tensor_ref.elem_size * line_size;
        let buffer_len = tensor_ref.handle.size() / line_bytes as u64;
        let len = tensor_ref.shape.iter().product::<usize>() / line_size;
        let rank = tensor_ref.strides.len() as u64;
        let shape = tensor_ref.shape.iter().map(|&dim| dim as u64).collect();
        let strides = tensor_ref.strides.iter().map(|&stride| stride as u64).collect();

        self.metadata()
            .with_tensor(rank, buffer_len, len as u64, shape, strides);
        Some(tensor_ref.handle.clone().binding())
    }

    /// Records an array's metadata and stores its buffer binding.
    /// An alias is ignored entirely.
    pub fn push_array(&mut self, array: &ArrayArg<'_, R>) {
        let Some(binding) = self.process_array(array) else {
            return;
        };
        self.buffers().push(binding);
    }

    /// Records the buffer length and array length (both in lines) of a handle-backed
    /// array and returns its buffer binding. Returns `None` for an alias.
    fn process_array(&mut self, array: &ArrayArg<'_, R>) -> Option<Binding> {
        let (array_ref, line_size) = match array {
            ArrayArg::Handle {
                handle, line_size, ..
            } => (handle, *line_size),
            ArrayArg::Alias { .. } => return None,
        };

        let line_bytes = array_ref.elem_size * line_size;
        let buffer_len = array_ref.handle.size() / line_bytes as u64;
        let len = array_ref.length[0] as u64 / line_size as u64;

        self.metadata().with_array(buffer_len, len);
        Some(array_ref.handle.clone().binding())
    }

    /// Records the metadata of the tensor behind `map` and stores a tensor-map binding
    /// for it.
    ///
    /// # Panics
    /// Panics if the tensor behind `map` is an alias.
    pub fn push_tensor_map<K: TensorMapKind>(&mut self, map: &TensorMapArg<'_, R, K>) {
        let binding = self
            .process_tensor(&map.tensor)
            .expect("Can't use alias for TensorMap");
        let tensor_map = TensorMapBinding {
            binding,
            map: map.metadata.clone(),
        };
        self.tensor_maps().push(tensor_map);
    }

    /// Moves the buffers, tensor maps and finished metadata into `bindings_global`,
    /// replacing what it held. An `Empty` state leaves `bindings_global` untouched.
    fn register(self, bindings_global: &mut Bindings) {
        let Self::Some {
            buffers,
            tensor_maps,
            metadata,
            ..
        } = self
        else {
            return;
        };
        let finished = metadata.finish();
        bindings_global.buffers = buffers;
        bindings_global.tensor_maps = tensor_maps;
        bindings_global.metadata = finished;
    }
}
impl ScalarState {
    /// Appends the bytes of `val` (from `CubeScalar::as_bytes`) to the list kept for its
    /// storage type `T::cube_type()`.
    pub fn push<T: CubeScalar>(&mut self, val: T) {
        let val = [val];
        let bytes = T::as_bytes(&val);
        self.data
            .entry(T::cube_type())
            .or_default()
            .extend(bytes.iter().copied());
    }

    /// Appends already-encoded scalar bytes to the list kept for `dtype`.
    ///
    /// `bytes` is expected to contain whole values of `dtype`, i.e. a multiple of
    /// `dtype.size()` bytes. This is not checked here.
    pub fn push_raw(&mut self, bytes: &[u8], dtype: StorageType) {
        self.data.entry(dtype).or_default().extend_from_slice(bytes);
    }

    /// Copies every per-type scalar list into `bindings.scalars`, one `ScalarBinding` per
    /// storage type.
    ///
    /// Each list is packed into a zero-padded `u64` buffer; `len` is the number of whole
    /// values of that type.
    fn register(&self, bindings: &mut Bindings) {
        for (ty, values) in &self.data {
            let len = values.len() / ty.size();
            // Size the word buffer from the byte count so it always covers `values`.
            // Deriving it from `size_of::<u64>() / ty.size()` divides by zero for types
            // wider than a `u64`. It can also under-allocate when `values.len()` is not a
            // multiple of `ty.size()`.
            let word_count = values.len().div_ceil(size_of::<u64>());
            let mut data = vec![0u64; word_count];
            let bytes = bytemuck::cast_slice_mut::<u64, u8>(&mut data);
            bytes[..values.len()].copy_from_slice(values);
            bindings
                .scalars
                .insert(*ty, ScalarBinding::new(*ty, len, data));
        }
    }
}
impl<R: Runtime> KernelLauncher<R> {
pub fn new(settings: KernelSettings) -> Self {
Self {
tensors: TensorState::Empty {
addr_type: settings.address_type.unsigned_type(),
},
scalars: Default::default(),
settings,
runtime: PhantomData,
}
}
}