use super::SourceTemplate;
use crate::{CubeRuntime, element::CubeElement, tensor::CubeTensor};
use cubecl::{CompilationError, Compiler, CubeTask, prelude::*};
/// A kernel whose source code is supplied directly as a [`SourceTemplate`],
/// together with the identifier used to refer to it.
pub trait KernelSource: Send + Sync + 'static {
    /// Returns the template holding this kernel's source code.
    fn source(&self) -> SourceTemplate;

    /// Returns the identifier of this kernel.
    fn id(&self) -> KernelId;
}
/// Adapts a [`KernelSource`] into a CubeCL task launched with a fixed cube dimension.
#[derive(new)]
pub struct SourceKernel<K> {
    /// Provider of the kernel's source template and id.
    kernel_source: K,
    /// Cube dimension recorded in the compiled kernel.
    cube_dim: CubeDim,
}
impl<C: Compiler, K: KernelSource> CubeTask<C> for SourceKernel<K> {
    /// Builds a [`CompiledKernel`] straight from the wrapped source template.
    ///
    /// The compiler, options, execution mode and address type are all ignored: the completed
    /// template is used verbatim, with `main` as the entry point and the wrapped kernel's
    /// type name as the debug name. This never returns an error.
    fn compile(
        &self,
        _compiler: &mut C,
        _options: &C::CompilationOptions,
        _mode: ExecutionMode,
        _address_type: StorageType,
    ) -> Result<CompiledKernel<C>, CompilationError> {
        let kernel = CompiledKernel {
            entrypoint_name: String::from("main"),
            debug_name: Some(core::any::type_name::<K>()),
            source: self.kernel_source.source().complete(),
            cube_dim: self.cube_dim,
            debug_info: None,
            repr: None,
        };
        Ok(kernel)
    }
}
impl<K: KernelSource> KernelMetadata for SourceKernel<K> {
    /// Delegates to the wrapped kernel source's id.
    fn id(&self) -> KernelId {
        let source = &self.kernel_source;
        source.id()
    }

    /// Reports `u32` as the address type, independent of the wrapped kernel.
    fn address_type(&self) -> StorageType {
        let index_type = u32::as_type_native_unchecked();
        index_type.storage_type()
    }
}
/// Declares a public unit struct `$name` with a private `source` method returning a
/// `SourceTemplate` built from the file at `$path`. The file is embedded at compile time
/// via `include_str!`.
///
/// The struct is annotated with `#[derive(new)]`. Because `macro_rules!` resolves item
/// paths at the call site, a `new` derive macro (e.g. from `derive_new`) must be in scope
/// wherever this macro is invoked.
#[macro_export]
macro_rules! kernel_source {
    ($name:ident, $path:expr) => {
        #[derive(new)]
        pub struct $name;

        impl $name {
            fn source(&self) -> $crate::template::SourceTemplate {
                let text = include_str!($path);
                $crate::template::SourceTemplate::new(text)
            }
        }
    };
}
/// Packs the rank, strides and shapes of `tensors` into a flat `u32` metadata buffer.
///
/// Layout (with `ndims` the shared rank and `n = tensors.len()`):
/// `[ndims, strides(t0)…, strides(t1)…, …, shape(t0)…, shape(t1)…, …]`,
/// i.e. `1 + 2 * n * ndims` entries. All strides come first, then all shapes, each in
/// tensor order.
///
/// The `E` type parameter is not used by the body. It is kept so existing call sites that
/// name it explicitly keep compiling.
///
/// # Panics
///
/// - if `tensors` is empty;
/// - if the tensors do not all have the same rank;
/// - if the rank, a stride or a dimension does not fit in a `u32`. The previous `as u32`
///   cast silently truncated these, which yields wrong indexing in the kernel.
pub fn build_info<R: CubeRuntime, E: CubeElement>(tensors: &[&CubeTensor<R>]) -> Vec<u32> {
    let first = tensors
        .first()
        .expect("build_info: at least one tensor is required");
    let ndims = first.meta.num_dims();

    // Every tensor must contribute exactly `ndims` strides and `ndims` dims. Otherwise the
    // fixed offsets into this buffer are wrong: extra dims would be silently dropped, and
    // missing dims would index out of bounds.
    for (i, tensor) in tensors.iter().enumerate() {
        assert_eq!(
            tensor.meta.num_dims(),
            ndims,
            "build_info: tensor {i} has a different rank than tensor 0"
        );
    }

    let mut info = Vec::with_capacity(1 + 2 * tensors.len() * ndims);
    info.push(u32::try_from(ndims).expect("build_info: rank does not fit in u32"));

    // Section 1: strides of every tensor, tensor by tensor.
    for tensor in tensors {
        let strides = tensor.meta.strides();
        info.extend((0..ndims).map(|d| {
            u32::try_from(strides[d]).expect("build_info: stride does not fit in u32")
        }));
    }

    // Section 2: shapes of every tensor, tensor by tensor.
    for tensor in tensors {
        let shape = tensor.meta.shape();
        info.extend((0..ndims).map(|d| {
            u32::try_from(shape[d]).expect("build_info: dimension does not fit in u32")
        }));
    }

    info
}