1use alloc::{format, string::String, string::ToString};
2pub use burn_derive::Config;
3
/// Error returned when a configuration cannot be loaded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// The data is not valid UTF-8 or could not be parsed as the expected JSON configuration.
    ///
    /// Holds a human-readable description of the parsing failure.
    InvalidFormat(String),

    /// Reading the configuration file failed.
    ///
    /// Holds the (lossily converted) path that was requested. Note that this variant is
    /// produced for any failure to read the file, not only when the file is missing.
    FileNotFound(String),
}
13
14impl core::fmt::Display for ConfigError {
15 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
16 let mut message = "Config error => ".to_string();
17
18 match self {
19 Self::InvalidFormat(err) => {
20 message += format!("Invalid format: {err}").as_str();
21 }
22 Self::FileNotFound(err) => {
23 message += format!("File not found: {err}").as_str();
24 }
25 };
26
27 f.write_str(message.as_str())
28 }
29}
30
31impl core::error::Error for ConfigError {}
32
33pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
35 #[cfg(feature = "std")]
45 fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
46 std::fs::write(file, config_to_json(self))
47 }
48
49 #[cfg(feature = "std")]
59 fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
60 let content = std::fs::read_to_string(file.as_ref())
61 .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;
62 config_from_str(&content)
63 }
64
65 fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
75 let content = core::str::from_utf8(data).map_err(|_| {
76 ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
77 })?;
78 config_from_str(content)
79 }
80}
81
82pub fn config_to_json<C: Config>(config: &C) -> String {
92 serde_json::to_string_pretty(config).unwrap()
93}
94
95fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
96 serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
97}