use crate::{
kernel::{CompiledKernel, KernelDefinition, KernelMetadata},
server::ExecutionMode,
};
use alloc::string::String;
use cubecl_common::backtrace::BackTrace;
use cubecl_ir::{ElemType, StorageType};
use thiserror::Error;
/// A kernel task that can be compiled by the compiler backend `C`.
///
/// Implementors must also implement [`KernelMetadata`] and be `Send + Sync`.
pub trait CubeTask<C>: KernelMetadata + Send + Sync
where
    C: Compiler,
{
    /// Compiles this task with `compiler`, producing a [`CompiledKernel`] for `C`.
    ///
    /// - `compiler`: the compiler backend, borrowed mutably for the duration of the call.
    /// - `compilation_options`: backend-specific options of type `C::CompilationOptions`.
    /// - `mode`: the [`ExecutionMode`] to compile for.
    /// - `address_type`: a [`StorageType`]; presumably the type used for address/index
    ///   arithmetic in the generated kernel — NOTE(review): confirm against implementors.
    ///
    /// # Errors
    ///
    /// Returns a [`CompilationError`] when the task cannot be compiled.
    fn compile(
        &self,
        compiler: &mut C,
        compilation_options: &C::CompilationOptions,
        mode: ExecutionMode,
        address_type: StorageType,
    ) -> Result<CompiledKernel<C>, CompilationError>;
}
/// Error returned when a kernel fails to compile.
///
/// Every variant carries a human-readable `reason` and a [`BackTrace`]; both are
/// included in the `Display` message generated by `thiserror`.
///
/// When the `std_io` cfg is enabled, the enum also derives `serde::Serialize` and
/// `serde::Deserialize`, and the `backtrace` field is skipped by serde
/// (NOTE(review): on deserialization it is presumably filled via `BackTrace: Default` — confirm).
///
/// Variant and field order are kept stable on purpose: they determine discriminants
/// and the serde encoding of this type.
#[derive(Error, Clone)]
#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
pub enum CompilationError {
    /// Compilation failed because the kernel used an instruction the backend does not support.
    #[error(
        "An unsupported instruction caused the compilation to fail\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    UnsupportedInstruction {
        /// Description of why compilation failed.
        reason: String,
        /// Backtrace stored with the error; skipped by serde when `std_io` is enabled.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// Compilation failed for a reason not covered by the other variants.
    #[error(
        "An error caused the compilation to fail\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Generic {
        /// Description of why compilation failed.
        reason: String,
        /// Backtrace stored with the error; skipped by serde when `std_io` is enabled.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
    /// Compilation failed because a validation step rejected the kernel.
    #[error(
        "A validation error caused the compilation to fail\nCaused by:\n {reason}\nBacktrace:\n{backtrace}"
    )]
    Validation {
        /// Description of why validation failed.
        reason: String,
        /// Backtrace stored with the error; skipped by serde when `std_io` is enabled.
        #[cfg_attr(std_io, serde(skip))]
        backtrace: BackTrace,
    },
}
/// Renders the error with its `Display` message, so `{:?}` output (as printed by
/// `unwrap`/`expect`) shows the full reason and backtrace text rather than a struct dump.
impl core::fmt::Debug for CompilationError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Delegate straight to the `thiserror`-generated `Display` implementation.
        core::fmt::Display::fmt(self, f)
    }
}
/// A compiler backend that lowers a [`KernelDefinition`] into its own representation.
///
/// Implementors must be `Sync + Send + 'static + Clone + Debug`.
pub trait Compiler: Sync + Send + 'static + Clone + core::fmt::Debug {
    /// Output produced by [`Compiler::compile`]; must implement `Display`.
    /// NOTE(review): presumably the generated kernel source text — confirm per backend.
    type Representation: core::fmt::Display;
    /// Backend-specific options passed to [`Compiler::compile`]; must provide a
    /// `Default` value and be `Send + Debug`.
    type CompilationOptions: Send + Default + core::fmt::Debug;

    /// Compiles `kernel` into [`Self::Representation`].
    ///
    /// - `kernel`: the kernel definition, taken by value.
    /// - `compilation_options`: backend-specific options.
    /// - `mode`: the [`ExecutionMode`] to compile for.
    /// - `address_type`: a [`StorageType`], named to match the corresponding parameter of
    ///   `CubeTask::compile`; presumably the type used for address/index arithmetic —
    ///   NOTE(review): confirm against backends.
    ///
    /// # Errors
    ///
    /// Returns a [`CompilationError`] when the kernel cannot be compiled.
    fn compile(
        &mut self,
        kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        mode: ExecutionMode,
        address_type: StorageType,
    ) -> Result<Self::Representation, CompilationError>;

    /// Returns the size of `elem` for this backend.
    /// NOTE(review): presumably in bytes — confirm.
    fn elem_size(&self, elem: ElemType) -> usize;

    /// Returns a static string associated with this backend's output.
    /// NOTE(review): presumably the file extension used for compiled kernels — confirm.
    fn extension(&self) -> &'static str;
}