use alloc::{string::ToString, vec::Vec};
use cubecl_ir::{Id, Scope, StorageType, Type};
use cubecl_runtime::{
kernel::{KernelArg, KernelDefinition, KernelOptions, ScalarKernelArg, Visibility},
server::CubeDim,
};
use crate::prelude::AddressType;
/// Builds a [`KernelDefinition`] from a [`KernelExpansion`] by converting its buffers,
/// scalars and tensor maps into kernel argument bindings (see `KernelIntegrator::integrate`).
#[derive(Clone)]
pub struct KernelIntegrator {
/// Expansion being integrated; its buffer, scalar and tensor-map lists are drained
/// into the binding vectors below during integration.
expansion: KernelExpansion,
/// Bindings produced from `expansion.buffers`.
buffer_bindings: Vec<KernelArg>,
/// Bindings produced from `expansion.scalars`; sorted by storage type before output.
scalar_bindings: Vec<ScalarKernelArg>,
/// Bindings produced from `expansion.tensor_maps`.
tensor_maps: Vec<KernelArg>,
}
/// An expanded kernel: its declared inputs plus the scope that forms the kernel body.
#[derive(Clone)]
pub struct KernelExpansion {
/// Buffer inputs, in declaration order.
pub buffers: Vec<BufferInfo>,
/// Scalar inputs, grouped per storage type (see [`ScalarInfo`]).
pub scalars: Vec<ScalarInfo>,
/// Tensor-map inputs; described with the same [`BufferInfo`] shape as buffers.
pub tensor_maps: Vec<BufferInfo>,
/// Scope that `KernelIntegrator::integrate` places in `KernelDefinition::body`.
pub scope: Scope,
}
/// Settings used when integrating a kernel expansion into a kernel definition.
///
/// Build with `KernelSettings::default()` and the chained builder methods
/// (`cube_dim`, `address_type`, `kernel_name`, `debug_symbols`, `cluster_dim`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct KernelSettings {
/// Cube dimension copied into the resulting kernel definition.
pub cube_dim: CubeDim,
/// Address type for the kernel; `AddressType::U32` by default.
pub address_type: AddressType,
/// Kernel options (kernel name, debug symbols, cluster dimension, ...).
pub options: KernelOptions,
}
impl Default for KernelSettings {
fn default() -> Self {
Self {
cube_dim: CubeDim::new_1d(1),
address_type: AddressType::U32,
options: Default::default(),
}
}
}
impl KernelSettings {
    /// Returns these settings with the cube dimension replaced by `cube_dim`.
    pub fn cube_dim(self, cube_dim: CubeDim) -> Self {
        Self { cube_dim, ..self }
    }

    /// Returns these settings with the address type replaced by `ty`.
    pub fn address_type(self, ty: AddressType) -> Self {
        Self {
            address_type: ty,
            ..self
        }
    }

    /// Returns these settings with `options.kernel_name` set to an owned copy of `name`.
    pub fn kernel_name<S: AsRef<str>>(self, name: S) -> Self {
        let mut updated = self;
        let owned_name = name.as_ref().to_string();
        updated.options.kernel_name = owned_name;
        updated
    }

    /// Returns these settings with `options.debug_symbols` switched on.
    pub fn debug_symbols(self) -> Self {
        let mut updated = self;
        updated.options.debug_symbols = true;
        updated
    }

    /// Returns these settings with `options.cluster_dim` set to `Some(cluster_dim)`.
    pub fn cluster_dim(self, cluster_dim: CubeDim) -> Self {
        let mut updated = self;
        updated.options.cluster_dim = Some(cluster_dim);
        updated
    }
}
/// Declaration of a buffer-like kernel input.
///
/// Used for both buffers and tensor maps in [`KernelExpansion`]; each one is turned
/// into a `KernelArg` with `size: None` during integration.
#[derive(Clone, Debug)]
pub struct BufferInfo {
/// Identifier of the input; becomes `KernelArg::id`.
pub id: Id,
/// Item type of the input; becomes `KernelArg::ty`.
pub item: Type,
/// Access visibility; becomes `KernelArg::visibility`.
pub visibility: Visibility,
/// Forwarded unchanged to `KernelArg::has_extended_meta` (presumably marks inputs
/// that need extra metadata — confirm against the runtime).
pub has_extended_meta: bool,
}
/// A group of scalar kernel inputs that share one storage type.
#[derive(Clone, Debug)]
pub struct ScalarInfo {
/// Storage type of the scalars; becomes `ScalarKernelArg::ty` and is the sort key
/// for scalar bindings.
pub ty: StorageType,
/// Presumably the number of scalars of type `ty`; forwarded to `ScalarKernelArg::count`.
pub count: usize,
}
impl KernelIntegrator {
pub fn new(info: KernelExpansion) -> Self {
Self {
expansion: info,
buffer_bindings: Default::default(),
scalar_bindings: Default::default(),
tensor_maps: Default::default(),
}
}
#[cfg_attr(feature = "tracing", tracing::instrument(level = "trace", skip(self)))]
pub fn integrate(mut self, settings: KernelSettings) -> KernelDefinition {
self.register_buffers();
self.register_scalars();
self.register_tensor_maps();
self.scalar_bindings.sort_by_key(|binding| binding.ty);
KernelDefinition {
buffers: self.buffer_bindings,
tensor_maps: self.tensor_maps,
scalars: self.scalar_bindings,
cube_dim: settings.cube_dim,
body: self.expansion.scope,
options: settings.options,
}
}
fn register_buffers(&mut self) {
for buffer in self.expansion.buffers.drain(..) {
self.buffer_bindings.push(KernelArg {
id: buffer.id,
ty: buffer.item,
visibility: buffer.visibility,
has_extended_meta: buffer.has_extended_meta,
size: None,
});
}
}
fn register_scalars(&mut self) {
for scalar in self.expansion.scalars.drain(..) {
self.scalar_bindings.push(ScalarKernelArg {
ty: scalar.ty,
count: scalar.count,
});
}
}
fn register_tensor_maps(&mut self) {
for buffer in self.expansion.tensor_maps.drain(..) {
self.tensor_maps.push(KernelArg {
id: buffer.id,
ty: buffer.item,
visibility: buffer.visibility,
has_extended_meta: buffer.has_extended_meta,
size: None,
});
}
}
}