Skip to main content

cubecl_common/config/
mod.rs

1/// Reusable logger configuration and sink management.
2pub mod logger;
3
4#[cfg(target_has_atomic = "ptr")]
5use alloc::sync::Arc;
6
7#[cfg(not(target_has_atomic = "ptr"))]
8use portable_atomic_util::Arc;
9
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12
13/// Trait for runtime configurations potentially loaded from a TOML file.
14///
15/// Implementors provide a global storage slot and the set of file names to search for;
16/// the trait supplies the lookup, lazy-initialization, and serialization logic.
17///
18/// The singleton stored in [`Config::storage`] is initialized on the first call to
19/// [`Config::get`] by walking up the current working directory looking for any of the
20/// names returned by [`Config::file_names`]. If none is found, [`Default`] is used.
21pub trait RuntimeConfig:
22    Default + Clone + Serialize + DeserializeOwned + Send + Sync + 'static
23{
24    /// Global storage for the configuration singleton.
25    ///
26    /// Each implementor must declare its own `static` slot, because Rust traits
27    /// cannot own statics directly.
28    fn storage() -> &'static spin::Mutex<Option<Arc<Self>>>;
29
30    /// File names searched in each directory during [`Config::from_current_dir`].
31    ///
32    /// The first existing file wins.
33    fn file_names() -> &'static [&'static str];
34
35    /// File names searched in each directory, where only a specific TOML section is loaded
36    /// instead of the whole file.
37    ///
38    /// Each entry is `(file_name, section_name)` and the section must deserialize to `Self`.
39    /// Checked after [`Config::file_names`] at each directory level.
40    fn section_file_names() -> &'static [(&'static str, &'static str)] {
41        &[]
42    }
43
44    /// Hook to override fields from environment variables after loading from disk.
45    ///
46    /// The default implementation returns `self` unchanged.
47    #[cfg(std_io)]
48    fn override_from_env(self) -> Self {
49        self
50    }
51
52    /// Retrieves the current configuration, loading it from the current directory if not set.
53    ///
54    /// If no configuration is set, it attempts to load one from any of [`Config::file_names`] in
55    /// the current directory or its parents. If no file is found, a default configuration is used.
56    ///
57    /// # Notes
58    ///
59    /// Calling this function is somewhat expensive, because of a global static lock. The config
60    /// format is optimized for parsing, not for consumption. A good practice is to use a local
61    /// static atomic value that you can populate with the appropriate value from the config
62    /// during initialization.
63    fn get() -> Arc<Self> {
64        let mut state = Self::storage().lock();
65        if state.as_ref().is_none() {
66            cfg_if::cfg_if! {
67                if #[cfg(std_io)] {
68                    let config = Self::from_current_dir();
69                    let config = config.override_from_env();
70                } else {
71                    let config = Self::default();
72                }
73            }
74
75            *state = Some(Arc::new(config));
76        }
77
78        state.as_ref().cloned().unwrap()
79    }
80
81    /// Sets the configuration to the provided value.
82    ///
83    /// # Panics
84    /// Panics if the configuration has already been set or read, as it cannot be overridden.
85    ///
86    /// # Warning
87    /// This method must be called at the start of the program, before any calls to
88    /// [`Config::get`]. Attempting to set the configuration after it has been initialized will
89    /// cause a panic.
90    fn set(config: Self) {
91        let mut state = Self::storage().lock();
92        if state.is_some() {
93            panic!("Cannot set the configuration multiple times.");
94        }
95        *state = Some(Arc::new(config));
96    }
97
98    /// Save the default configuration to the provided file path.
99    #[cfg(std_io)]
100    fn save_default<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<()> {
101        use std::io::Write;
102
103        let config = Self::get();
104        let content =
105            toml::to_string_pretty(config.as_ref()).expect("Default config should be serializable");
106        let mut file = std::fs::File::create(path)?;
107        file.write_all(content.as_bytes())?;
108
109        Ok(())
110    }
111
112    /// Loads configuration from any of [`Config::file_names`] in the current directory or its
113    /// parents.
114    ///
115    /// Traverses up the directory tree until a valid configuration file is found or the root
116    /// is reached. Returns a default configuration if no file is found.
117    #[cfg(std_io)]
118    fn from_current_dir() -> Self {
119        let mut dir = std::env::current_dir().unwrap();
120
121        loop {
122            for name in Self::file_names() {
123                if let Ok(content) = Self::from_file_path(dir.join(name)) {
124                    return content;
125                }
126            }
127
128            for (name, section) in Self::section_file_names() {
129                if let Ok(content) = Self::from_section_file_path(dir.join(name), section) {
130                    return content;
131                }
132            }
133
134            if !dir.pop() {
135                break;
136            }
137        }
138
139        Self::default()
140    }
141
142    /// Loads configuration from a specified file path.
143    #[cfg(std_io)]
144    fn from_file_path<P: AsRef<std::path::Path>>(path: P) -> std::io::Result<Self> {
145        let content = std::fs::read_to_string(path)?;
146        let config: Self = match toml::from_str(&content) {
147            Ok(val) => val,
148            Err(err) => panic!("The file provided doesn't have the right format => {err:?}"),
149        };
150
151        Ok(config)
152    }
153
154    /// Loads configuration from a specific TOML section of the file at the given path.
155    #[cfg(std_io)]
156    fn from_section_file_path<P: AsRef<std::path::Path>>(
157        path: P,
158        section: &str,
159    ) -> std::io::Result<Self> {
160        let content = std::fs::read_to_string(path)?;
161        let mut table: toml::Table = match toml::from_str(&content) {
162            Ok(val) => val,
163            Err(err) => panic!("The file provided doesn't have the right format => {err:?}"),
164        };
165
166        let value = match table.remove(section) {
167            Some(val) => val,
168            None => {
169                return Err(std::io::Error::new(
170                    std::io::ErrorKind::NotFound,
171                    alloc::format!("Section '{section}' not found"),
172                ));
173            }
174        };
175
176        let config: Self = match value.try_into() {
177            Ok(val) => val,
178            Err(err) => {
179                panic!("The section '{section}' doesn't have the right format => {err:?}")
180            }
181        };
182
183        Ok(config)
184    }
185}