1use alloc::{format, string::String, string::ToString};
2pub use burn_derive::Config;
3use core::fmt::Debug;
4
/// Errors that can occur while loading a `Config`.
///
/// Each variant carries a human-readable description of the problem.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigError {
    /// The content could not be decoded into a config: it was either not valid UTF-8
    /// or not valid JSON for the target config type.
    InvalidFormat(String),

    /// The config file could not be read. Holds the offending path, lossily converted
    /// to a string. Note that any read failure is reported with this variant, not only
    /// a missing file.
    FileNotFound(String),
}
14
15impl core::fmt::Display for ConfigError {
16 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
17 let mut message = "Config error => ".to_string();
18
19 match self {
20 Self::InvalidFormat(err) => {
21 message += format!("Invalid format: {err}").as_str();
22 }
23 Self::FileNotFound(err) => {
24 message += format!("File not found: {err}").as_str();
25 }
26 };
27
28 f.write_str(message.as_str())
29 }
30}
31
32impl core::error::Error for ConfigError {}
33
34pub trait Config: Debug + serde::Serialize + serde::de::DeserializeOwned {
36 #[cfg(feature = "std")]
46 fn save<P: AsRef<std::path::Path>>(&self, file: P) -> std::io::Result<()> {
47 std::fs::write(file, config_to_json(self))
48 }
49
50 #[cfg(feature = "std")]
60 fn load<P: AsRef<std::path::Path>>(file: P) -> Result<Self, ConfigError> {
61 let content = std::fs::read_to_string(file.as_ref())
62 .map_err(|_| ConfigError::FileNotFound(file.as_ref().to_string_lossy().to_string()))?;
63 config_from_str(&content)
64 }
65
66 fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
76 let content = core::str::from_utf8(data).map_err(|_| {
77 ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
78 })?;
79 config_from_str(content)
80 }
81}
82
83pub fn config_to_json<C: Config>(config: &C) -> String {
93 serde_json::to_string_pretty(config).unwrap()
94}
95
96fn config_from_str<C: Config>(content: &str) -> Result<C, ConfigError> {
97 serde_json::from_str(content).map_err(|err| ConfigError::InvalidFormat(format!("{err}")))
98}