cubecl_runtime/
compiler.rs

1use crate::{
2    kernel::{CompiledKernel, KernelDefinition, KernelMetadata},
3    server::ExecutionMode,
4};
5use alloc::string::String;
6use cubecl_common::backtrace::BackTrace;
7use cubecl_ir::ElemType;
8use thiserror::Error;
9
10/// Kernel trait with the ComputeShader that will be compiled and cached based on the
11/// provided id.
12pub trait CubeTask<C: Compiler>: KernelMetadata + Send + Sync {
13    /// Compile a kernel and return the compiled form with an optional non-text representation
14    fn compile(
15        &self,
16        compiler: &mut C,
17        compilation_options: &C::CompilationOptions,
18        mode: ExecutionMode,
19    ) -> Result<CompiledKernel<C>, CompilationError>;
20}
21
22/// JIT compilation error.
23#[derive(Error, Clone)]
24#[cfg_attr(std_io, derive(serde::Serialize, serde::Deserialize))]
25pub enum CompilationError {
26    /// An instruction isn't supported.
27    #[error(
28        "An unsupported instruction caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
29    )]
30    UnsupportedInstruction {
31        /// The caused of the error.
32        reason: String,
33        /// The backtrace for this error.
34        #[cfg_attr(std_io, serde(skip))]
35        backtrace: BackTrace,
36    },
37
38    /// A generic compilation error.
39    #[error(
40        "An error caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
41    )]
42    Generic {
43        /// The error context.
44        reason: String,
45        /// The backtrace for this error.
46        #[cfg_attr(std_io, serde(skip))]
47        backtrace: BackTrace,
48    },
49    /// A generic compilation error.
50    #[error(
51        "A validation error caused the compilation to fail\nCaused by:\n  {reason}\nBacktrace:\n{backtrace}"
52    )]
53    Validation {
54        /// The error context.
55        reason: String,
56        /// The backtrace for this error.
57        #[cfg_attr(std_io, serde(skip))]
58        backtrace: BackTrace,
59    },
60}
61
62impl core::fmt::Debug for CompilationError {
63    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
64        f.write_fmt(format_args!("{self}"))
65    }
66}
67
68/// Compiles the representation into its own representation that can be formatted into tokens.
69pub trait Compiler: Sync + Send + 'static + Clone + core::fmt::Debug {
70    /// The representation for the compiled code.
71    type Representation: core::fmt::Display;
72    /// The compilation options used to configure the compiler
73    type CompilationOptions: Send + Default + core::fmt::Debug;
74
75    /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation.
76    fn compile(
77        &mut self,
78        kernel: KernelDefinition,
79        compilation_options: &Self::CompilationOptions,
80        mode: ExecutionMode,
81    ) -> Result<Self::Representation, CompilationError>;
82
83    /// The size of the given element in bytes.
84    fn elem_size(&self, elem: ElemType) -> usize;
85
86    /// The default extension for the runtime's kernel/shader code.
87    /// Might change based on which compiler is used.
88    fn extension(&self) -> &'static str;
89}