cubecl_runtime/
compiler.rs

1use crate::{
2    kernel::{CompiledKernel, KernelDefinition, KernelMetadata},
3    server::ExecutionMode,
4};
5use alloc::string::String;
6use cubecl_common::backtrace::BackTrace;
7use cubecl_ir::{ElemType, StorageType};
8use thiserror::Error;
9
10/// Kernel trait with the ComputeShader that will be compiled and cached based on the
11/// provided id.
12pub trait CubeTask<C: Compiler>: KernelMetadata + Send + Sync {
13    /// Compile a kernel and return the compiled form with an optional non-text representation
14    fn compile(
15        &self,
16        compiler: &mut C,
17        compilation_options: &C::CompilationOptions,
18        mode: ExecutionMode,
19        address_type: StorageType,
20    ) -> Result<CompiledKernel<C>, CompilationError>;
21}
22
23/// JIT compilation error.
24#[derive(Error, Clone)]
25#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
26pub enum CompilationError {
27    /// An instruction isn't supported.
28    #[error(
29        "An unsupported instruction caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
30    )]
31    UnsupportedInstruction {
32        /// The caused of the error.
33        reason: String,
34        /// The backtrace for this error.
35        #[cfg_attr(std_io, serde(skip))]
36        backtrace: BackTrace,
37    },
38
39    /// A generic compilation error.
40    #[error(
41        "An error caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
42    )]
43    Generic {
44        /// The error context.
45        reason: String,
46        /// The backtrace for this error.
47        #[cfg_attr(std_io, serde(skip))]
48        backtrace: BackTrace,
49    },
50    /// A generic compilation error.
51    #[error(
52        "A validation error caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
53    )]
54    Validation {
55        /// The error context.
56        reason: String,
57        /// The backtrace for this error.
58        #[cfg_attr(std_io, serde(skip))]
59        backtrace: BackTrace,
60    },
61}
62
63impl core::fmt::Debug for CompilationError {
64    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
65        write!(f, "{self}")
66    }
67}
68
69/// Compiles the representation into its own representation that can be formatted into tokens.
70pub trait Compiler: Sync + Send + 'static + Clone + core::fmt::Debug {
71    /// The representation for the compiled code.
72    type Representation: core::fmt::Display;
73    /// The compilation options used to configure the compiler
74    type CompilationOptions: Send + Default + core::fmt::Debug;
75
76    /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
77    fn compile(
78        &mut self,
79        kernel: KernelDefinition,
80        compilation_options: &Self::CompilationOptions,
81        mode: ExecutionMode,
82        addr_type: StorageType,
83    ) -> Result<Self::Representation, CompilationError>;
84
85    /// The size of the given element in bytes.
86    fn elem_size(&self, elem: ElemType) -> usize;
87
88    /// The default extension for the runtime's kernel/shader code.
89    /// Might change based on which compiler is used.
90    fn extension(&self) -> &'static str;
91}