use alloc::{
boxed::Box,
string::{String, ToString},
vec::Vec,
};
use core::{
fmt::Display,
marker::PhantomData,
sync::atomic::{AtomicI8, Ordering},
};
use cubecl_common::format::format_str;
use cubecl_ir::{Id, Scope, StorageType, Type};
use serde::{Deserialize, Serialize};
use crate::{
compiler::{CompilationError, Compiler, CubeTask},
config::{GlobalConfig, compilation::CompilationLogLevel},
id::KernelId,
server::{CubeDim, ExecutionMode},
};
/// Metadata that every kernel exposes: a name, an identifier and an address type.
///
/// Implementors must be `Send + Sync + 'static`.
pub trait KernelMetadata: Send + Sync + 'static {
/// Returns a name for the kernel.
///
/// The default implementation returns the type name of the implementor, as
/// reported by [`core::any::type_name`].
fn name(&self) -> &'static str {
core::any::type_name::<Self>()
}
/// Returns the [`KernelId`] of this kernel.
fn id(&self) -> KernelId;
/// Returns the [`StorageType`] this kernel reports as its address type.
fn address_type(&self) -> StorageType;
}
/// Description of a kernel as returned by [`CubeKernel::define`]: its
/// arguments, cube dimensions, body and options.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub struct KernelDefinition {
/// Buffer arguments of the kernel.
pub buffers: Vec<KernelArg>,
/// Tensor-map arguments; described with the same [`KernelArg`] type as buffers.
pub tensor_maps: Vec<KernelArg>,
/// Scalar arguments, each entry pairing a storage type with a count.
pub scalars: Vec<ScalarKernelArg>,
/// Cube dimensions; `KernelTask::compile` copies this value into the
/// resulting `CompiledKernel`.
pub cube_dim: CubeDim,
/// Body of the kernel, as a `cubecl_ir` [`Scope`].
pub body: Scope,
/// Additional options, including the kernel name.
pub options: KernelOptions,
}
/// Options carried inside a [`KernelDefinition`].
#[derive(Default, Clone, Debug, Hash, PartialEq, Eq)]
pub struct KernelOptions {
/// Name of the kernel; `KernelTask::compile` uses it as the entrypoint name
/// of the compiled kernel.
pub kernel_name: String,
/// Whether debug symbols are requested.
///
/// NOTE(review): not read in this file; presumably consumed by the backend
/// compiler. Confirm.
pub debug_symbols: bool,
/// Optional cluster dimensions; `None` when no cluster dimension is set.
///
/// NOTE(review): not read in this file; semantics are defined by the
/// consumer. Confirm.
pub cluster_dim: Option<CubeDim>,
}
/// Description of a single non-scalar kernel argument.
///
/// Used for both `KernelDefinition::buffers` and `KernelDefinition::tensor_maps`.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct KernelArg {
/// IR identifier of the argument.
pub id: Id,
/// Access mode of the argument (read-only or read-write).
pub visibility: Visibility,
/// IR type associated with the argument.
pub ty: Type,
/// Optional size of the argument.
///
/// NOTE(review): the unit and the meaning of `None` are not visible in this
/// file; presumably a static element count. Confirm.
pub size: Option<usize>,
/// Whether the argument has extended metadata.
///
/// NOTE(review): what "extended" covers is defined elsewhere. Confirm.
pub has_extended_meta: bool,
}
/// Scalar kernel argument entry: a storage type together with a count.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct ScalarKernelArg {
/// Storage type of the scalar(s).
pub ty: StorageType,
/// Count associated with this entry.
///
/// NOTE(review): presumably the number of scalars of type `ty`. Confirm.
pub count: usize,
}
/// Access mode of a kernel argument (see `KernelArg::visibility`).
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
#[allow(missing_docs)]
pub enum Visibility {
/// The argument is only read.
Read,
/// The argument is both read and written.
ReadWrite,
}
/// Result of compiling a kernel with a [`Compiler`] `C`.
///
/// The `Display` implementation in this file renders it as a compilation log
/// entry.
pub struct CompiledKernel<C: Compiler> {
/// Name of the kernel entrypoint; `KernelTask::compile` takes it from
/// `KernelOptions::kernel_name`.
pub entrypoint_name: String,
/// Optional name used when logging; `KernelTask::compile` sets it to the type
/// name of the kernel.
pub debug_name: Option<&'static str>,
/// Textual source of the kernel; `KernelTask::compile` fills it with the
/// string form of the compiler's representation.
pub source: String,
/// Compiler-specific representation of the kernel, when available.
pub repr: Option<C::Representation>,
/// Cube dimensions of the compiled kernel.
pub cube_dim: CubeDim,
/// Optional debugging details; when present, the full log format prints the
/// id and uses the language tag for the source code fence.
pub debug_info: Option<DebugInformation>,
}
/// Extra debugging details that can be attached to a [`CompiledKernel`].
#[derive(new)]
pub struct DebugInformation {
/// Language tag written right after the opening code fence when the full log
/// format prints the kernel source.
pub lang_tag: &'static str,
/// Identifier of the kernel; pretty-printed (`{:#?}`) by the full log format.
pub id: KernelId,
}
/// A kernel that, on top of its [`KernelMetadata`], can produce its own
/// [`KernelDefinition`].
pub trait CubeKernel: KernelMetadata {
/// Builds the definition of this kernel.
fn define(&self) -> KernelDefinition;
}
/// Wraps a [`CubeKernel`] `K` so that it can be compiled by a [`Compiler`] `C`
/// through the [`CubeTask`] trait.
pub struct KernelTask<C: Compiler, K: CubeKernel> {
/// The wrapped kernel. Despite the field name this is the kernel itself; its
/// definition is produced on demand by calling `define()`.
kernel_definition: K,
/// Ties the task to the compiler type `C` without storing a value of it.
_compiler: PhantomData<C>,
}
/// Owns a type-erased [`CubeTask`] for the compiler `C`.
pub struct CubeTaskKernel<C: Compiler> {
/// The boxed task.
pub task: Box<dyn CubeTask<C>>,
}
impl<C: Compiler, K: CubeKernel> KernelTask<C, K> {
pub fn new(kernel_definition: K) -> Self {
Self {
kernel_definition,
_compiler: PhantomData,
}
}
}
impl<C: Compiler, K: CubeKernel> CubeTask<C> for KernelTask<C, K> {
    /// Defines the wrapped kernel and compiles that definition with `compiler`.
    ///
    /// `compilation_options`, `mode` and `addr_type` are forwarded unchanged to
    /// the compiler. On success the returned [`CompiledKernel`] carries the
    /// kernel name from the definition's options as its entrypoint name, the
    /// type name of `K` as its debug name, the string form of the compiler's
    /// output as its source, and no debug information.
    ///
    /// # Errors
    ///
    /// Returns the [`CompilationError`] reported by the compiler.
    fn compile(
        &self,
        compiler: &mut C,
        compilation_options: &C::CompilationOptions,
        mode: ExecutionMode,
        addr_type: StorageType,
    ) -> Result<CompiledKernel<C>, CompilationError> {
        let definition = self.kernel_definition.define();

        // The definition is moved into the compiler below, so take out
        // everything that is still needed afterwards first.
        let cube_dim = definition.cube_dim;
        let entrypoint_name = definition.options.kernel_name.clone();

        let representation =
            compiler.compile(definition, compilation_options, mode, addr_type)?;
        let source = representation.to_string();

        Ok(CompiledKernel {
            entrypoint_name,
            // Always the type name of `K`, independent of `K::name()`.
            debug_name: Some(core::any::type_name::<K>()),
            source,
            repr: Some(representation),
            cube_dim,
            debug_info: None,
        })
    }
}
impl<C: Compiler, K: CubeKernel> KernelMetadata for KernelTask<C, K> {
fn id(&self) -> KernelId {
self.kernel_definition.id()
}
fn name(&self) -> &'static str {
self.kernel_definition.name()
}
fn address_type(&self) -> StorageType {
self.kernel_definition.address_type()
}
}
/// A boxed [`CubeTask`] reports the metadata of the task inside the box.
impl<C: Compiler> KernelMetadata for Box<dyn CubeTask<C>> {
    /// Forwards to the boxed task's `name`.
    fn name(&self) -> &'static str {
        // `**self` is the `dyn CubeTask<C>` behind the box, so the call is
        // dispatched to the task rather than back to this impl.
        (**self).name()
    }

    /// Forwards to the boxed task's `id`.
    fn id(&self) -> KernelId {
        (**self).id()
    }

    /// Forwards to the boxed task's `address_type`.
    fn address_type(&self) -> StorageType {
        (**self).address_type()
    }
}
/// Cached numeric compilation log level; `-1` means "not resolved yet".
static COMPILATION_LEVEL: AtomicI8 = AtomicI8::new(-1);

/// Returns the compilation log level as a number: `0` for disabled, `1` for
/// basic and `2` for full.
///
/// The first call reads the level from the global configuration and stores it
/// in [`COMPILATION_LEVEL`]; later calls return the stored value without
/// consulting the configuration again.
fn compilation_level() -> u8 {
    match COMPILATION_LEVEL.load(Ordering::Relaxed) {
        -1 => {
            let resolved: i8 = match GlobalConfig::get().compilation.logger.level {
                CompilationLogLevel::Disabled => 0,
                CompilationLogLevel::Basic => 1,
                CompilationLogLevel::Full => 2,
            };
            COMPILATION_LEVEL.store(resolved, Ordering::Relaxed);
            resolved as u8
        }
        cached => cached as u8,
    }
}
/// Renders the kernel as a compilation log entry.
///
/// The full format is used when `compilation_level()` returns `2`; every
/// other level, including `0`, uses the basic one-line format.
impl<C: Compiler> Display for CompiledKernel<C> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        if compilation_level() == 2 {
            self.format_full(f)
        } else {
            self.format_basic(f)
        }
    }
}
impl<C: Compiler> CompiledKernel<C> {
    /// Writes the one-line log entry: `[Compiling kernel]`, followed by a space
    /// and the debug name when one is set.
    ///
    /// Names longer than 32 bytes are cut at the first `<`, so only the part
    /// before the generic arguments is printed.
    fn format_basic(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("[Compiling kernel]")?;
        match self.debug_name {
            None => Ok(()),
            Some(name) => {
                let shown = if name.len() > 32 {
                    name.split('<').next().unwrap_or("")
                } else {
                    name
                };
                write!(f, " {shown}")
            }
        }
    }

    /// Writes the multi-line log entry, delimited by
    /// `[START_KERNEL_COMPILATION]` and `[END_KERNEL_COMPILATION]`.
    ///
    /// In order: the debug name (if set; names longer than 32 bytes are passed
    /// through `format_str` first), the pretty-printed kernel id (if debug
    /// information is present), and the source inside a code fence tagged with
    /// the debug information's language tag (empty when there is none).
    fn format_full(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("[START_KERNEL_COMPILATION]")?;

        match self.debug_name {
            Some(name) if name.len() > 32 => {
                let pretty = format_str(name, &[('<', '>')], false);
                write!(f, "\nname: {pretty}")?;
            }
            Some(name) => write!(f, "\nname: {name}")?,
            None => {}
        }

        let info = self.debug_info.as_ref();
        if let Some(info) = info {
            write!(f, "\nid: {:#?}", info.id)?;
        }

        let lang_tag = info.map_or("", |info| info.lang_tag);
        write!(
            f,
            "\nsource:\n```{}\n{}\n```\n[END_KERNEL_COMPILATION]\n",
            lang_tag, self.source
        )
    }
}