Skip to main content

burn_std/config/
autodiff.rs

1use cubecl_common::config::logger::{LogLevel, LoggerConfig};
2
3/// Configuration for autodiff in Burn.
4#[derive(Default, Clone, Debug, serde::Serialize, serde::Deserialize)]
5pub struct AutodiffConfig {
6    /// Logger configuration for autodiff logs.
7    #[serde(default)]
8    pub logger: LoggerConfig<AutodiffLogLevel>,
9}
10
11/// Log levels for autodiff logging.
12#[derive(
13    Default,
14    Clone,
15    Copy,
16    Debug,
17    PartialEq,
18    Eq,
19    PartialOrd,
20    Ord,
21    serde::Serialize,
22    serde::Deserialize,
23)]
24pub enum AutodiffLogLevel {
25    /// Autodiff logging is disabled.
26    #[default]
27    #[serde(rename = "disabled")]
28    Disabled,
29
30    /// Log backward graph size and the checkpoint strategy applied per forward pass.
31    #[serde(rename = "basic")]
32    Basic,
33
34    /// Additionally log each tensor that gets checkpointed or recomputed.
35    #[serde(rename = "medium")]
36    Medium,
37
38    /// Log every graph node traversal and recomputation event.
39    #[serde(rename = "full")]
40    Full,
41}
42
// Empty impl: nothing is overridden here, so any behavior comes from the
// `LogLevel` trait itself.
// NOTE(review): presumably this is what allows `AutodiffLogLevel` to be used as
// the level parameter of `LoggerConfig` — confirm against the trait definition.
impl LogLevel for AutodiffLogLevel {}