burn_core/
config.rs

1use alloc::{format, string::String, string::ToString};
2pub use burn_derive::Config;
3use core::fmt::Debug;
4
/// Configuration IO error.
///
/// Produced when a configuration cannot be read from disk or cannot be parsed.
#[derive(Debug)]
pub enum ConfigError {
    /// Invalid format.
    ///
    /// The content could not be decoded (e.g. it is not UTF-8) or could not be
    /// parsed as the expected configuration. The payload is a human-readable
    /// description of the failure.
    InvalidFormat(String),

    /// File not found.
    ///
    /// The configuration file could not be read. The payload is the path that
    /// was requested.
    FileNotFound(String),
}

impl core::fmt::Display for ConfigError {
    /// Renders the error as `Config error => <kind>: <detail>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Write directly into the formatter rather than assembling an
        // intermediate `String` first; the rendered text is unchanged.
        match self {
            Self::InvalidFormat(err) => write!(f, "Config error => Invalid format: {err}"),
            Self::FileNotFound(err) => write!(f, "Config error => File not found: {err}"),
        }
    }
}

// The variants only carry rendered messages (no wrapped error values), so the
// default `source()` (returning `None`) is the correct behavior.
impl core::error::Error for ConfigError {}
33
34/// Configuration trait.
35pub trait Config: Debug + serde::Serialize + serde::de::DeserializeOwned {
36    /// Saves the configuration to a file.
37    ///
38    /// # Arguments
39    ///
40    /// * `file` - File to save the configuration to.
41    ///
42    /// # Returns
43    ///
44    /// The output of the save operation.
45    #[cfg(feature = "std")]
46    fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
47        std::fs::write(file, config_to_json(self))
48    }
49
50    /// Loads the configuration from a file.
51    ///
52    /// # Arguments
53    ///
54    /// * `file` - File to load the configuration from.
55    ///
56    /// # Returns
57    ///
58    /// The loaded configuration.
59    #[cfg(feature = "std")]
60    fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
61        let content = std::fs::read_to_string(file.as_ref())
62            .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;
63        config_from_str(&content)
64    }
65
66    /// Loads the configuration from a binary buffer.
67    ///
68    /// # Arguments
69    ///
70    /// * `data` - Binary buffer to load the configuration from.
71    ///
72    /// # Returns
73    ///
74    /// The loaded configuration.
75    fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
76        let content = core::str::from_utf8(data).map_err(|_| {
77            ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
78        })?;
79        config_from_str(content)
80    }
81}
82
83/// Converts a configuration to a JSON string.
84///
85/// # Arguments
86///
87/// * `config` - Configuration to convert.
88///
89/// # Returns
90///
91/// The JSON string.
92pub fn config_to_json<C: Config>(config: &C) -> String {
93    serde_json::to_string_pretty(config).unwrap()
94}
95
96fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
97    serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
98}