cubecl_runtime/
compiler.rs

1use crate::kernel::{CompiledKernel, KernelDefinition, KernelMetadata};
2use alloc::string::String;
3use cubecl_common::{ExecutionMode, backtrace::BackTrace};
4use cubecl_ir::ElemType;
5use thiserror::Error;
6
7/// Kernel trait with the ComputeShader that will be compiled and cached based on the
8/// provided id.
9pub trait CubeTask<C: Compiler>: KernelMetadata + Send + Sync {
10    /// Compile a kernel and return the compiled form with an optional non-text representation
11    fn compile(
12        &self,
13        compiler: &mut C,
14        compilation_options: &C::CompilationOptions,
15        mode: ExecutionMode,
16    ) -> Result<CompiledKernel<C>, CompilationError>;
17}
18
19/// JIT compilation error.
20#[derive(Error, Clone)]
21#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
22pub enum CompilationError {
23    /// An instruction isn't supported.
24    #[error(
25        "An unsupported instruction caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
26    )]
27    UnsupportedInstruction {
28        /// The caused of the error.
29        reason: String,
30        /// The backtrace for this error.
31        #[cfg_attr(std_io, serde(skip))]
32        backtrace: BackTrace,
33    },
34
35    /// A generic compilation error.
36    #[error(
37        "An error caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
38    )]
39    Generic {
40        /// The error context.
41        reason: String,
42        /// The backtrace for this error.
43        #[cfg_attr(std_io, serde(skip))]
44        backtrace: BackTrace,
45    },
46    /// A generic compilation error.
47    #[error(
48        "A validation error caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
49    )]
50    Validation {
51        /// The error context.
52        reason: String,
53        /// The backtrace for this error.
54        #[cfg_attr(std_io, serde(skip))]
55        backtrace: BackTrace,
56    },
57}
58
59impl core::fmt::Debug for CompilationError {
60    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
61        f.write_fmt(format_args!("{self}"))
62    }
63}
64
65/// Compiles the representation into its own representation that can be formatted into tokens.
66pub trait Compiler: Sync + Send + 'static + Clone + core::fmt::Debug {
67    /// The representation for the compiled code.
68    type Representation: core::fmt::Display;
69    /// The compilation options used to configure the compiler
70    type CompilationOptions: Send + Default + core::fmt::Debug;
71
72    /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
73    fn compile(
74        &mut self,
75        kernel: KernelDefinition,
76        compilation_options: &Self::CompilationOptions,
77        mode: ExecutionMode,
78    ) -> Result<Self::Representation, CompilationError>;
79
80    /// The size of the given element in bytes.
81    fn elem_size(&self, elem: ElemType) -> usize;
82
83    /// The default extension for the runtime's kernel/shader code.
84    /// Might change based on which compiler is used.
85    fn extension(&self) -> &'static str;
86}