1#![doc = include_str!("../README.md")]
2
3mod error;
4
5use crate::error::{
6 DeserializationError, Result, SerializationError, UniversalConfigError as Error,
7 UniversalConfigError,
8};
9use dirs::{config_dir, home_dir};
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::fs;
13use std::path::{Path, PathBuf};
14use tracing::debug;
15
/// A configuration file format understood by [`ConfigLoader`].
///
/// Each variant only exists when the crate feature of the same name is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// JSON (`.json`). Requires the `json` feature.
    #[cfg(feature = "json")]
    Json,
    /// YAML. Both `.yaml` and `.yml` files are searched when loading; saving
    /// writes `.yaml`. Requires the `yaml` feature.
    #[cfg(feature = "yaml")]
    Yaml,
    /// TOML (`.toml`). Requires the `toml` feature.
    #[cfg(feature = "toml")]
    Toml,
    /// Corn (`.corn`). Load-only: [`ConfigLoader::save`] returns an
    /// unsupported-extension error for this format. Requires the `corn` feature.
    #[cfg(feature = "corn")]
    Corn,
    /// XML (`.xml`). Requires the `xml` feature.
    #[cfg(feature = "xml")]
    Xml,
    /// RON (`.ron`). Requires the `ron` feature.
    #[cfg(feature = "ron")]
    Ron,
}
37
38impl Format {
39 const fn extension(&self) -> &str {
40 match self {
41 #[cfg(feature = "json")]
42 Self::Json => "json",
43 #[cfg(feature = "yaml")]
44 Self::Yaml => "yaml",
45 #[cfg(feature = "toml")]
46 Self::Toml => "toml",
47 #[cfg(feature = "corn")]
48 Self::Corn => "corn",
49 #[cfg(feature = "xml")]
50 Self::Xml => "xml",
51 #[cfg(feature = "ron")]
52 Self::Ron => "ron",
53 }
54 }
55}
56
57pub struct ConfigLoader<'a> {
62 app_name: &'a str,
64 file_name: &'a str,
67 formats: &'a [Format],
71 config_dir: Option<&'a str>,
75}
76
77impl<'a> ConfigLoader<'a> {
78 #[must_use]
81 pub const fn new(app_name: &'a str) -> ConfigLoader<'a> {
82 Self {
83 app_name,
84 file_name: "config",
85 formats: &[
86 #[cfg(feature = "json")]
87 Format::Json,
88 #[cfg(feature = "yaml")]
89 Format::Yaml,
90 #[cfg(feature = "toml")]
91 Format::Toml,
92 #[cfg(feature = "corn")]
93 Format::Corn,
94 #[cfg(feature = "xml")]
95 Format::Xml,
96 #[cfg(feature = "ron")]
97 Format::Ron,
98 ],
99 config_dir: None,
100 }
101 }
102
103 #[must_use]
107 pub const fn with_file_name(mut self, file_name: &'a str) -> Self {
108 self.file_name = file_name;
109 self
110 }
111
112 #[must_use]
117 pub const fn with_formats(mut self, formats: &'a [Format]) -> Self {
118 self.formats = formats;
119 self
120 }
121
122 #[must_use]
127 pub const fn with_config_dir(mut self, dir: &'a str) -> Self {
128 self.config_dir = Some(dir);
129 self
130 }
131
132 pub fn find_and_load<T: DeserializeOwned>(&self) -> Result<T> {
139 let file = self.try_find_file()?;
140 debug!("Found file at: '{}", file.display());
141 Self::load(&file)
142 }
143
144 pub fn config_dir(&self) -> std::result::Result<PathBuf, UniversalConfigError> {
150 self.config_dir
151 .map(Into::into)
152 .or_else(|| config_dir().map(|dir| dir.join(self.app_name)))
153 .or_else(|| home_dir().map(|dir| dir.join(format!(".{}", self.app_name))))
154 .ok_or(Error::MissingUserDir)
155 }
156
157 fn try_find_file(&self) -> Result<PathBuf> {
161 let config_dir = self.config_dir()?;
162
163 let extensions = self.get_extensions();
164
165 debug!("Using config dir: {}", config_dir.display());
166
167 let file = extensions.into_iter().find_map(|extension| {
168 let full_path = config_dir.join(format!("{}.{extension}", self.file_name));
169
170 if Path::exists(&full_path) {
171 Some(full_path)
172 } else {
173 None
174 }
175 });
176
177 file.ok_or(Error::FileNotFound)
178 }
179
180 pub fn load<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T> {
189 let str = fs::read_to_string(&path)?;
190
191 let extension = path
192 .as_ref()
193 .extension()
194 .unwrap_or_default()
195 .to_str()
196 .unwrap_or_default();
197
198 let config = Self::deserialize(&str, extension)?;
199 Ok(config)
200 }
201
202 fn get_extensions(&self) -> Vec<&'static str> {
204 let mut extensions = vec![];
205
206 for format in self.formats {
207 match format {
208 #[cfg(feature = "json")]
209 Format::Json => extensions.push("json"),
210 #[cfg(feature = "yaml")]
211 Format::Yaml => {
212 extensions.push("yaml");
213 extensions.push("yml");
214 }
215 #[cfg(feature = "toml")]
216 Format::Toml => extensions.push("toml"),
217 #[cfg(feature = "corn")]
218 Format::Corn => extensions.push("corn"),
219 #[cfg(feature = "xml")]
220 Format::Xml => extensions.push("xml"),
221 #[cfg(feature = "ron")]
222 Format::Ron => extensions.push("ron"),
223 }
224 }
225
226 extensions
227 }
228
229 fn deserialize<T: DeserializeOwned>(
232 str: &str,
233 extension: &str,
234 ) -> std::result::Result<T, DeserializationError> {
235 let res = match extension {
236 #[cfg(feature = "json")]
237 "json" => serde_json::from_str(str).map_err(DeserializationError::from),
238 #[cfg(feature = "toml")]
239 "toml" => toml::from_str(str).map_err(DeserializationError::from),
240 #[cfg(feature = "yaml")]
241 "yaml" | "yml" => serde_yaml::from_str(str).map_err(DeserializationError::from),
242 #[cfg(feature = "corn")]
243 "corn" => corn::from_str(str).map_err(DeserializationError::from),
244 #[cfg(feature = "xml")]
245 "xml" => serde_xml_rs::from_str(str).map_err(DeserializationError::from),
246 #[cfg(feature = "ron")]
247 "ron" => ron::from_str(str).map_err(DeserializationError::from),
248 _ => Err(DeserializationError::UnsupportedExtension(
249 extension.to_string(),
250 )),
251 }?;
252
253 Ok(res)
254 }
255
256 pub fn save<T: Serialize>(&self, config: &T, format: &Format) -> Result<()> {
270 let str = match format {
271 #[cfg(feature = "json")]
272 Format::Json => serde_json::to_string_pretty(config).map_err(SerializationError::from),
273 #[cfg(feature = "yaml")]
274 Format::Yaml => serde_yaml::to_string(config).map_err(SerializationError::from),
275 #[cfg(feature = "toml")]
276 Format::Toml => toml::to_string_pretty(config).map_err(SerializationError::from),
277 #[cfg(feature = "corn")]
278 Format::Corn => Err(SerializationError::UnsupportedExtension("corn".to_string())),
279 #[cfg(feature = "xml")]
280 Format::Xml => serde_xml_rs::to_string(config).map_err(SerializationError::from),
281 #[cfg(feature = "ron")]
282 Format::Ron => ron::to_string(config).map_err(SerializationError::from),
283 }?;
284
285 let config_dir = self.config_dir()?;
286 let file_name = format!("{}.{}", self.file_name, format.extension());
287 let full_path = config_dir.join(file_name);
288
289 fs::create_dir_all(config_dir)?;
290 fs::write(full_path, str)?;
291
292 Ok(())
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Shape shared by every fixture in `test_configs/`.
    #[derive(Deserialize)]
    struct ConfigContents {
        test: String,
    }

    /// Loads the fixture at `path` and checks that it holds the expected value.
    ///
    /// Each caller is gated on the feature for its format. Without that gate,
    /// running the suite with a feature disabled would fail with an
    /// unsupported-extension error instead of skipping the test.
    fn assert_fixture(path: &str) {
        let res: ConfigContents = ConfigLoader::load(path)
            .unwrap_or_else(|err| panic!("failed to load {path}: {err:?}"));
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_json() {
        assert_fixture("test_configs/config.json");
    }

    #[test]
    #[cfg(feature = "yaml")]
    fn test_yaml() {
        assert_fixture("test_configs/config.yaml");
    }

    #[test]
    #[cfg(feature = "toml")]
    fn test_toml() {
        assert_fixture("test_configs/config.toml");
    }

    #[test]
    #[cfg(feature = "corn")]
    fn test_corn() {
        assert_fixture("test_configs/config.corn");
    }

    #[test]
    #[cfg(feature = "xml")]
    fn test_xml() {
        assert_fixture("test_configs/config.xml");
    }

    #[test]
    #[cfg(feature = "ron")]
    fn test_ron() {
        assert_fixture("test_configs/config.ron");
    }

    /// Discovery needs at least one format feature; with none enabled the
    /// default format list is empty and nothing can be found.
    #[test]
    #[cfg(any(
        feature = "json",
        feature = "yaml",
        feature = "toml",
        feature = "corn",
        feature = "xml",
        feature = "ron"
    ))]
    fn test_find_load() {
        let config = ConfigLoader::new("universal-config");
        let res: ConfigContents = config
            .with_config_dir("test_configs")
            .find_and_load()
            .unwrap();
        assert_eq!(res.test, "hello world");
    }
}