Skip to main content

burn_std/config/
fusion.rs

1use cubecl_common::config::logger::{LogLevel, LoggerConfig};
2
3/// Configuration for operation fusion in Burn.
4#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
5pub struct FusionConfig {
6    /// Logger configuration for fusion logs.
7    #[serde(default)]
8    pub logger: LoggerConfig<FusionLogLevel>,
9
10    /// Beam search configuration used when exploring fusion opportunities.
11    #[serde(default)]
12    pub beam_search: BeamSearchConfig,
13}
14
15/// Beam search configuration controlling how the fusion optimizer explores independent blocks
16/// of operations.
17#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
18pub struct BeamSearchConfig {
19    /// Maximum number of independent blocks explored during the fusion search.
20    ///
21    /// Higher values can find better fusion opportunities at the cost of more cache misses
22    /// in the fusion cache.
23    #[serde(default = "default_max_blocks")]
24    pub max_blocks: usize,
25}
26
27impl Default for BeamSearchConfig {
28    fn default() -> Self {
29        Self {
30            max_blocks: default_max_blocks(),
31        }
32    }
33}
34
/// Fallback value for `BeamSearchConfig::max_blocks`, used both by the serde
/// `default = "default_max_blocks"` attribute and by `BeamSearchConfig::default`.
const fn default_max_blocks() -> usize {
    const MAX_BLOCKS: usize = 5;
    MAX_BLOCKS
}
38
39/// Log levels for fusion logging.
40#[derive(
41    Default,
42    Clone,
43    Copy,
44    Debug,
45    PartialEq,
46    Eq,
47    PartialOrd,
48    Ord,
49    serde::Serialize,
50    serde::Deserialize,
51)]
52pub enum FusionLogLevel {
53    /// Fusion logging is disabled.
54    #[default]
55    #[serde(rename = "disabled")]
56    Disabled,
57
58    /// Log the final execution strategy selected per stream (single vs composed).
59    #[serde(rename = "basic")]
60    Basic,
61
62    /// Log block merge/split decisions and cache hit/miss events.
63    #[serde(rename = "medium")]
64    Medium,
65
66    /// Log every registration, rejection and scoring decision.
67    #[serde(rename = "full")]
68    Full,
69}
70
71impl LogLevel for FusionLogLevel {}