use alloc::{boxed::Box, vec::Vec};
use core::marker::PhantomData;
use crate::Runtime;
use crate::prelude::{ArrayArg, TensorArg, TensorMapArg, TensorMapKind};
use crate::{InfoBuilder, KernelSettings, ScalarArgType};
#[cfg(feature = "std")]
use core::cell::RefCell;
use cubecl_ir::{AddressType, Scope, StorageType, Type};
use cubecl_runtime::server::{Binding, CubeCount, TensorMapBinding};
use cubecl_runtime::{
client::ComputeClient,
kernel::{CubeKernel, KernelTask},
server::KernelArguments,
};
// Per-thread state backing `KernelLauncher` when the `std` feature is enabled.
// Every launcher used on a given thread goes through these same two cells;
// without `std` the launcher carries its own `info` and `scope` fields instead.
#[cfg(feature = "std")]
std::thread_local! {
// Builder that receives the scalars and tensor/array metadata registered on a launcher.
static INFO: RefCell<InfoBuilder> = RefCell::new(InfoBuilder::default());
// Scope handed to the closure passed to `KernelLauncher::with_scope`.
static SCOPE: RefCell<Scope> = RefCell::new(Scope::root(false));
}
/// Accumulates the arguments of a kernel launch for runtime `R`.
///
/// Arguments are added through the `register_*` methods and handed to a
/// `ComputeClient` by `launch` / `launch_unchecked`, which consume the launcher.
pub struct KernelLauncher<R: Runtime> {
/// Buffer bindings, appended by `register_tensor` and `register_array`.
buffers: Vec<Binding>,
/// Tensor-map bindings, appended by `register_tensor_map`.
tensor_maps: Vec<TensorMapBinding>,
/// Copied from `settings.address_type` in `new`; forwarded to metadata
/// registration and to `InfoBuilder::finish`.
address_type: AddressType,
/// The settings this launcher was created with.
pub settings: KernelSettings,
/// Scalar/metadata builder owned by the launcher in builds without `std`
/// (with `std`, a thread-local builder is used instead).
#[cfg(not(feature = "std"))]
info: InfoBuilder,
/// Scope owned by the launcher in builds without `std`
/// (with `std`, a thread-local scope is used instead).
#[cfg(not(feature = "std"))]
pub scope: Scope,
/// Ties the launcher to its runtime type without storing a value of it.
_runtime: PhantomData<R>,
}
impl<R: Runtime> KernelLauncher<R> {
    /// Runs `fun` with mutable access to the launcher's [`Scope`] and returns
    /// whatever `fun` returns.
    ///
    /// With the `std` feature the scope is a thread-local, so it is the same
    /// scope for every launcher used on the current thread.
    ///
    /// # Panics
    ///
    /// Panics if the thread-local scope is already mutably borrowed, e.g. when
    /// `with_scope` is called again from inside `fun`.
    #[cfg(feature = "std")]
    pub fn with_scope<T>(&mut self, fun: impl FnMut(&mut Scope) -> T) -> T {
        SCOPE.with_borrow_mut(fun)
    }

    /// Runs `fun` with mutable access to the launcher's own [`Scope`] and
    /// returns whatever `fun` returns.
    #[cfg(not(feature = "std"))]
    pub fn with_scope<T>(&mut self, mut fun: impl FnMut(&mut Scope) -> T) -> T {
        let scope = &mut self.scope;
        fun(scope)
    }

    /// Runs `fun` against the thread-local [`InfoBuilder`].
    ///
    /// NOTE(review): the builder is shared by all launchers on this thread;
    /// presumably `InfoBuilder::finish` leaves it ready for the next launch —
    /// confirm, since a launcher dropped without launching never calls it.
    #[cfg(feature = "std")]
    fn with_info<T>(&mut self, fun: impl FnMut(&mut InfoBuilder) -> T) -> T {
        INFO.with_borrow_mut(fun)
    }

    /// Runs `fun` against the [`InfoBuilder`] stored in this launcher.
    #[cfg(not(feature = "std"))]
    fn with_info<T>(&mut self, mut fun: impl FnMut(&mut InfoBuilder) -> T) -> T {
        let builder = &mut self.info;
        fun(builder)
    }

    /// Appends a typed scalar argument to the info builder's scalar list.
    pub fn register_scalar<C: ScalarArgType>(&mut self, scalar: C) {
        self.with_info(|builder| builder.scalars.push(scalar));
    }

    /// Appends a scalar argument given as raw bytes together with its storage type.
    pub fn register_scalar_raw(&mut self, bytes: &[u8], dtype: StorageType) {
        self.with_info(|builder| builder.scalars.push_raw(bytes, dtype));
    }

    /// Consumes the launcher and launches `kernel` on `client` with the
    /// registered arguments and the given `cube_count`.
    #[track_caller]
    pub fn launch<K: CubeKernel>(
        self,
        cube_count: CubeCount,
        kernel: K,
        client: &ComputeClient<R>,
    ) {
        let arguments = self.into_bindings();
        let task = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
        client.launch(task, cube_count, arguments)
    }

    /// Consumes the launcher and launches `kernel` through
    /// `ComputeClient::launch_unchecked`.
    ///
    /// # Safety
    ///
    /// The arguments are forwarded unchanged to `ComputeClient::launch_unchecked`;
    /// the caller must uphold that method's contract.
    #[track_caller]
    pub unsafe fn launch_unchecked<K: CubeKernel>(
        self,
        cube_count: CubeCount,
        kernel: K,
        client: &ComputeClient<R>,
    ) {
        let arguments = self.into_bindings();
        let task = Box::new(KernelTask::<R::Compiler, K>::new(kernel));
        // SAFETY: this function is `unsafe` and documents that its caller takes
        // on the contract of `ComputeClient::launch_unchecked`.
        unsafe { client.launch_unchecked(task, cube_count, arguments) }
    }

    /// Consumes the launcher and assembles the [`KernelArguments`]: the info
    /// builder is finished with the launcher's address type, and the collected
    /// buffers and tensor maps are moved in.
    fn into_bindings(mut self) -> KernelArguments {
        let address_type = self.address_type;
        let info = self.with_info(|builder| builder.finish(address_type));

        let mut arguments = KernelArguments::new();
        arguments.buffers = self.buffers;
        arguments.tensor_maps = self.tensor_maps;
        arguments.info = info;
        arguments
    }
}
impl<R: Runtime> KernelLauncher<R> {
    /// Registers a tensor argument.
    ///
    /// A `TensorArg::Handle` has its metadata recorded and its binding appended
    /// to the launcher's buffers; a `TensorArg::Alias` registers nothing.
    pub fn register_tensor(&mut self, tensor: TensorArg<R>, ty: Type) {
        if let Some(binding) = self.process_tensor(tensor, ty) {
            self.buffers.push(binding);
        }
    }

    /// Records the metadata of `tensor` in the info builder and returns its binding.
    ///
    /// Returns `None` for `TensorArg::Alias`, in which case nothing is recorded.
    /// The registered buffer length is `handle.size_in_used() / ty.size()`, and
    /// the registered length is the product of the shape divided by
    /// `ty.vector_size()`.
    fn process_tensor(&mut self, tensor: TensorArg<R>, ty: Type) -> Option<Binding> {
        let tensor = match tensor {
            TensorArg::Handle { handle, .. } => handle,
            TensorArg::Alias { .. } => return None,
        };

        let elem_size = ty.size();
        let vectorization = ty.vector_size();
        let buffer_len = tensor.handle.size_in_used() / elem_size as u64;
        let len = tensor.shape.iter().product::<usize>() / vectorization;
        let rank = tensor.strides.len() as u64;
        let address_type = self.address_type;

        // `with_info` takes an `FnMut`, so shape and strides are cloned inside
        // the closure instead of being moved out of `tensor`.
        self.with_info(|info| {
            info.metadata.register_tensor(
                rank,
                buffer_len,
                len as u64,
                tensor.shape.clone(),
                tensor.strides.clone(),
                address_type,
            )
        });

        Some(tensor.handle)
    }

    /// Registers an array argument.
    ///
    /// An `ArrayArg::Handle` has its metadata recorded and its binding appended
    /// to the launcher's buffers; an `ArrayArg::Alias` registers nothing.
    pub fn register_array(&mut self, array: ArrayArg<R>, ty: Type) {
        if let Some(binding) = self.process_array(array, ty) {
            self.buffers.push(binding);
        }
    }

    /// Records the metadata of `array` in the info builder and returns its binding.
    ///
    /// Returns `None` for `ArrayArg::Alias`, in which case nothing is recorded.
    /// The registered buffer length is `handle.size_in_used() / ty.size()`, and
    /// the registered length is `length[0] / ty.vector_size()`.
    fn process_array(&mut self, array: ArrayArg<R>, ty: Type) -> Option<Binding> {
        let array = match array {
            ArrayArg::Handle { handle, .. } => handle,
            ArrayArg::Alias { .. } => return None,
        };

        let elem_size = ty.size();
        let vectorization = ty.vector_size();
        let buffer_len = array.handle.size_in_used() / elem_size as u64;
        let len = array.length[0] as u64 / vectorization as u64;
        let address_type = self.address_type;

        self.with_info(|info| {
            info.metadata
                .register_array(buffer_len, len, address_type)
        });

        Some(array.handle)
    }

    /// Registers a tensor-map argument: the underlying tensor's metadata is
    /// recorded like any other tensor, and its binding is stored together with
    /// the map metadata in the launcher's tensor maps.
    ///
    /// # Panics
    ///
    /// Panics if `map.tensor` is a `TensorArg::Alias`.
    pub fn register_tensor_map<K: TensorMapKind>(&mut self, map: TensorMapArg<R, K>, ty: Type) {
        let binding = self
            .process_tensor(map.tensor, ty)
            .expect("Can't use alias for TensorMap");

        // `map` is owned and not used past this point, so its metadata is
        // moved into the binding; no copy is needed.
        self.tensor_maps.push(TensorMapBinding {
            binding,
            map: map.metadata,
        });
    }
}
impl<R: Runtime> KernelLauncher<R> {
    /// Creates a launcher with no registered arguments.
    ///
    /// The address type is taken from `settings.address_type`. In builds
    /// without `std`, the launcher also gets a default [`InfoBuilder`] and a
    /// fresh `Scope::root(false)`.
    pub fn new(settings: KernelSettings) -> Self {
        let address_type = settings.address_type;

        Self {
            buffers: Vec::new(),
            tensor_maps: Vec::new(),
            address_type,
            settings,
            #[cfg(not(feature = "std"))]
            info: InfoBuilder::default(),
            #[cfg(not(feature = "std"))]
            scope: Scope::root(false),
            _runtime: PhantomData,
        }
    }
}