use std::marker::PhantomData;
use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel};
use alloc::sync::Arc;
use cubecl_runtime::server::{Binding, ComputeServer};
/// Result of compiling a [`CubeTask`].
///
/// Bundles the textual output of the compiler with the launch-related
/// values that were collected while compiling.
pub struct CompiledKernel {
    /// Compiled kernel rendered as text.
    pub source: String,
    /// Cube dimensions associated with the kernel.
    pub cube_dim: CubeDim,
    /// Shared memory size reported for the kernel.
    // NOTE(review): the field name says bytes; confirm the producer reports bytes.
    pub shared_mem_bytes: usize,
}
/// A unit of work that can be identified and compiled into a
/// [`CompiledKernel`].
///
/// Implementors must be `Send + Sync`.
pub trait CubeTask: Send + Sync {
    /// Returns the identifier of this task.
    // NOTE(review): presumably used as a key to tell kernels apart — confirm with callers.
    fn id(&self) -> String;

    /// Compiles this task and returns the resulting [`CompiledKernel`].
    fn compile(&self) -> CompiledKernel;
}
/// Pairs a kernel definition of type `K` with a compiler type `C`.
///
/// No value of `C` is stored: the compiler is carried purely as a type
/// marker through [`PhantomData`].
// NOTE(review): `new` is presumably the `derive_new` macro, which would
// generate a `KernelTask::new` constructor — confirm against the crate imports.
#[derive(new)]
pub struct KernelTask<C: Compiler, K: Kernel> {
    /// The kernel this task wraps.
    kernel_definition: K,
    /// Type-level marker for the compiler; holds no data.
    _compiler: PhantomData<C>,
}
impl<C: Compiler, K: Kernel> CubeTask for KernelTask<C, K> {
    /// Returns the identifier reported by the wrapped kernel definition.
    fn id(&self) -> String {
        let kernel_id = self.kernel_definition.id();
        // NOTE(review): if `Kernel::id` already returns an owned `String`,
        // this clone is redundant and can be dropped.
        kernel_id.clone()
    }

    /// Defines the kernel, hands the definition to the compiler `C`, and
    /// packages the outcome as a [`CompiledKernel`].
    fn compile(&self) -> CompiledKernel {
        let definition = self.kernel_definition.define();
        // Must be read before `definition` is moved into the compiler.
        let cube_dim = definition.cube_dim;
        let representation = C::compile(definition);

        // Fields are evaluated in the order written: the shared memory size
        // is queried first, then the representation is rendered to text.
        CompiledKernel {
            shared_mem_bytes: representation.shared_memory_size(),
            source: representation.to_string(),
            cube_dim,
        }
    }
}
/// Lets a shared, type-erased task be used wherever a [`CubeTask`] is
/// expected by forwarding every call to the task behind the `Arc`.
impl CubeTask for Arc<dyn CubeTask> {
    /// Forwards to the inner task's `compile`.
    fn compile(&self) -> CompiledKernel {
        // Explicit deref reaches the `dyn CubeTask` inside the `Arc`.
        (**self).compile()
    }

    /// Forwards to the inner task's `id`.
    fn id(&self) -> String {
        (**self).id()
    }
}
/// Lets an owned, type-erased task be used wherever a [`CubeTask`] is
/// expected by forwarding every call to the task behind the `Box`.
impl CubeTask for Box<dyn CubeTask> {
    /// Forwards to the inner task's `compile`.
    fn compile(&self) -> CompiledKernel {
        // Explicit deref reaches the `dyn CubeTask` inside the `Box`.
        (**self).compile()
    }

    /// Forwards to the inner task's `id`.
    fn id(&self) -> String {
        (**self).id()
    }
}
/// A cube count, given either as literal values or through a server binding.
pub enum CubeCount<S: ComputeServer> {
    /// Three counts supplied directly by the caller.
    // NOTE(review): presumably the x, y and z extents, in that order — confirm.
    Static(u32, u32, u32),
    /// A count carried by a [`Binding`] on the compute server `S`.
    // NOTE(review): presumably the values are read from the bound resource
    // when the kernel is launched — confirm against the server implementation.
    Dynamic(Binding<S>),
}
/// Hand-written `Clone` that only requires `S: ComputeServer`; a derived
/// implementation would additionally demand `S: Clone`.
impl<S: ComputeServer> Clone for CubeCount<S> {
    /// Copies the three counts of a static value, or clones the binding of
    /// a dynamic one.
    fn clone(&self) -> Self {
        match *self {
            // `u32` is `Copy`, so the counts are simply copied out.
            Self::Static(x, y, z) => Self::Static(x, y, z),
            Self::Dynamic(ref binding) => Self::Dynamic(binding.clone()),
        }
    }
}