burn_core/
config.rs

1use alloc::{format, string::String, string::ToString};
2pub use burn_derive::Config;
3
/// Error returned when a configuration cannot be loaded.
#[derive(Debug)]
pub enum ConfigError {
    /// The content could not be decoded as a configuration: either the bytes are not
    /// valid UTF-8 or the text is not valid JSON for the target type. Holds a
    /// human-readable description of the problem.
    InvalidFormat(String),

    /// The configuration file could not be read. Holds the offending path.
    FileNotFound(String),
}

impl core::fmt::Display for ConfigError {
    /// Formats the error as `Config error => <kind>: <detail>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write directly into the formatter instead of building an intermediate
        // `String`; the produced text is identical.
        f.write_str("Config error => ")?;
        match self {
            Self::InvalidFormat(err) => write!(f, "Invalid format: {err}"),
            Self::FileNotFound(err) => write!(f, "File not found: {err}"),
        }
    }
}

impl core::error::Error for ConfigError {}
32
33/// Configuration trait.
34pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
35    /// Saves the configuration to a file.
36    ///
37    /// # Arguments
38    ///
39    /// * `file` - File to save the configuration to.
40    ///
41    /// # Returns
42    ///
43    /// The output of the save operation.
44    #[cfg(feature = "std")]
45    fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
46        std::fs::write(file, config_to_json(self))
47    }
48
49    /// Loads the configuration from a file.
50    ///
51    /// # Arguments
52    ///
53    /// * `file` - File to load the configuration from.
54    ///
55    /// # Returns
56    ///
57    /// The loaded configuration.
58    #[cfg(feature = "std")]
59    fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
60        let content = std::fs::read_to_string(file.as_ref())
61            .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;
62        config_from_str(&content)
63    }
64
65    /// Loads the configuration from a binary buffer.
66    ///
67    /// # Arguments
68    ///
69    /// * `data` - Binary buffer to load the configuration from.
70    ///
71    /// # Returns
72    ///
73    /// The loaded configuration.
74    fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
75        let content = core::str::from_utf8(data).map_err(|_| {
76            ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
77        })?;
78        config_from_str(content)
79    }
80}
81
82/// Converts a configuration to a JSON string.
83///
84/// # Arguments
85///
86/// * `config` - Configuration to convert.
87///
88/// # Returns
89///
90/// The JSON string.
91pub fn config_to_json<C: Config>(config: &C) -> String {
92    serde_json::to_string_pretty(config).unwrap()
93}
94
95fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
96    serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
97}