universal_config/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod error;
4
5use crate::error::{
6    DeserializationError, Result, SerializationError, UniversalConfigError as Error,
7    UniversalConfigError,
8};
9use dirs::{config_dir, home_dir};
10use serde::de::DeserializeOwned;
11use serde::Serialize;
12use std::fs;
13use std::path::{Path, PathBuf};
14use tracing::debug;
15
/// Supported config formats.
///
/// Each variant only exists when the cargo feature of the same (lowercase)
/// name is enabled, e.g. `Format::Json` requires the `json` feature.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// `.json` files.
    #[cfg(feature = "json")]
    Json,
    /// `.yaml` or `.yml` files.
    #[cfg(feature = "yaml")]
    Yaml,
    /// `.toml` files.
    #[cfg(feature = "toml")]
    Toml,
    /// `.corn` files.
    #[cfg(feature = "corn")]
    Corn,
    /// `.xml` files.
    #[cfg(feature = "xml")]
    Xml,
    /// `.ron` files.
    #[cfg(feature = "ron")]
    Ron,
}
37
38impl Format {
39    const fn extension(&self) -> &str {
40        match self {
41            #[cfg(feature = "json")]
42            Self::Json => "json",
43            #[cfg(feature = "yaml")]
44            Self::Yaml => "yaml",
45            #[cfg(feature = "toml")]
46            Self::Toml => "toml",
47            #[cfg(feature = "corn")]
48            Self::Corn => "corn",
49            #[cfg(feature = "xml")]
50            Self::Xml => "xml",
51            #[cfg(feature = "ron")]
52            Self::Ron => "ron",
53        }
54    }
55}
56
57/// The main loader struct.
58///
59/// Create a new loader and configure as appropriate
60/// to load your config file.
61pub struct ConfigLoader<'a> {
62    /// The name of your program, used when determining the directory path.
63    app_name: &'a str,
64    /// The name of the file (*excluding* extension) to search for.
65    /// Defaults to `config`.
66    file_name: &'a str,
67    /// Allowed file formats.
68    /// Defaults to all formats.
69    /// Set to disable formats you do not wish to allow.
70    formats: &'a [Format],
71    /// The directory to load the config file from.
72    /// Defaults to your system config dir (`$XDG_CONFIG_DIR` on Linux),
73    /// or your home dir if that does not exist.
74    config_dir: Option<&'a str>,
75}
76
77impl<'a> ConfigLoader<'a> {
78    /// Creates a new config loader for the provided app name.
79    /// Uses a default file name of "config" and all formats.
80    #[must_use]
81    pub const fn new(app_name: &'a str) -> ConfigLoader<'a> {
82        Self {
83            app_name,
84            file_name: "config",
85            formats: &[
86                #[cfg(feature = "json")]
87                Format::Json,
88                #[cfg(feature = "yaml")]
89                Format::Yaml,
90                #[cfg(feature = "toml")]
91                Format::Toml,
92                #[cfg(feature = "corn")]
93                Format::Corn,
94                #[cfg(feature = "xml")]
95                Format::Xml,
96                #[cfg(feature = "ron")]
97                Format::Ron,
98            ],
99            config_dir: None,
100        }
101    }
102
103    /// Specifies the file name to look for, excluding the extension.
104    ///
105    /// If not specified, defaults to "config".
106    #[must_use]
107    pub const fn with_file_name(mut self, file_name: &'a str) -> Self {
108        self.file_name = file_name;
109        self
110    }
111
112    /// Specifies which file formats to search for, and in which order.
113    ///
114    /// If not specified, all formats are checked for
115    /// in the order JSON, YAML, TOML, Corn.
116    #[must_use]
117    pub const fn with_formats(mut self, formats: &'a [Format]) -> Self {
118        self.formats = formats;
119        self
120    }
121
122    /// Specifies which directory the config should be loaded from.
123    ///
124    /// If not specified, loads from `$XDG_CONFIG_DIR/<app_name>`
125    /// or `$HOME/.<app_name>` if the config dir does not exist.
126    #[must_use]
127    pub const fn with_config_dir(mut self, dir: &'a str) -> Self {
128        self.config_dir = Some(dir);
129        self
130    }
131
132    /// Attempts to locate a config file on disk and load it.
133    ///
134    /// # Errors
135    ///
136    /// Will return a `UniversalConfigError` if any error occurs
137    /// when looking for, reading, or deserializing a config file.
138    pub fn find_and_load<T: DeserializeOwned>(&self) -> Result<T> {
139        let file = self.try_find_file()?;
140        debug!("Found file at: '{}", file.display());
141        Self::load(&file)
142    }
143
144    /// Attempts to find the directory in which the config file is stored.
145    ///
146    /// # Errors
147    ///
148    /// Will error if the user's home directory cannot be located.
149    pub fn config_dir(&self) -> std::result::Result<PathBuf, UniversalConfigError> {
150        self.config_dir
151            .map(Into::into)
152            .or_else(|| config_dir().map(|dir| dir.join(self.app_name)))
153            .or_else(|| home_dir().map(|dir| dir.join(format!(".{}", self.app_name))))
154            .ok_or(Error::MissingUserDir)
155    }
156
157    /// Attempts to find a config file for the given app name
158    /// in the app's config directory
159    /// that matches any of the allowed formats.
160    fn try_find_file(&self) -> Result<PathBuf> {
161        let config_dir = self.config_dir()?;
162
163        let extensions = self.get_extensions();
164
165        debug!("Using config dir: {}", config_dir.display());
166
167        let file = extensions.into_iter().find_map(|extension| {
168            let full_path = config_dir.join(format!("{}.{extension}", self.file_name));
169
170            if Path::exists(&full_path) {
171                Some(full_path)
172            } else {
173                None
174            }
175        });
176
177        file.ok_or(Error::FileNotFound)
178    }
179
180    /// Loads the file at the given path,
181    /// deserializing it into a new `T`.
182    ///
183    /// The type is automatically determined from the file extension.
184    ///
185    /// # Errors
186    ///
187    /// Will return a `UniversalConfigError` if unable to read or deserialize the file.
188    pub fn load<T: DeserializeOwned, P: AsRef<Path>>(path: P) -> Result<T> {
189        let str = fs::read_to_string(&path)?;
190
191        let extension = path
192            .as_ref()
193            .extension()
194            .unwrap_or_default()
195            .to_str()
196            .unwrap_or_default();
197
198        let config = Self::deserialize(&str, extension)?;
199        Ok(config)
200    }
201
202    /// Gets a list of supported and enabled file extensions.
203    fn get_extensions(&self) -> Vec<&'static str> {
204        let mut extensions = vec![];
205
206        for format in self.formats {
207            match format {
208                #[cfg(feature = "json")]
209                Format::Json => extensions.push("json"),
210                #[cfg(feature = "yaml")]
211                Format::Yaml => {
212                    extensions.push("yaml");
213                    extensions.push("yml");
214                }
215                #[cfg(feature = "toml")]
216                Format::Toml => extensions.push("toml"),
217                #[cfg(feature = "corn")]
218                Format::Corn => extensions.push("corn"),
219                #[cfg(feature = "xml")]
220                Format::Xml => extensions.push("xml"),
221                #[cfg(feature = "ron")]
222                Format::Ron => extensions.push("ron"),
223            }
224        }
225
226        extensions
227    }
228
229    /// Attempts to deserialize the provided input into `T`,
230    /// based on the provided file extension.
231    fn deserialize<T: DeserializeOwned>(
232        str: &str,
233        extension: &str,
234    ) -> std::result::Result<T, DeserializationError> {
235        let res = match extension {
236            #[cfg(feature = "json")]
237            "json" => serde_json::from_str(str).map_err(DeserializationError::from),
238            #[cfg(feature = "toml")]
239            "toml" => toml::from_str(str).map_err(DeserializationError::from),
240            #[cfg(feature = "yaml")]
241            "yaml" | "yml" => serde_yaml::from_str(str).map_err(DeserializationError::from),
242            #[cfg(feature = "corn")]
243            "corn" => corn::from_str(str).map_err(DeserializationError::from),
244            #[cfg(feature = "xml")]
245            "xml" => serde_xml_rs::from_str(str).map_err(DeserializationError::from),
246            #[cfg(feature = "ron")]
247            "ron" => ron::from_str(str).map_err(DeserializationError::from),
248            _ => Err(DeserializationError::UnsupportedExtension(
249                extension.to_string(),
250            )),
251        }?;
252
253        Ok(res)
254    }
255
256    /// Saves the provided configuration into a file of the specified format.
257    ///
258    /// The file is stored in the app's configuration directory.
259    /// Directories are automatically created if required.
260    ///
261    /// # Errors
262    ///
263    /// If the provided config cannot be serialised into the format, an error will be returned.
264    /// The `.corn` format is not supported, and the function will error if specified.
265    ///
266    /// If a valid config dir cannot be found, an error will be returned.
267    ///
268    /// If the file cannot be written to the specified path, an error will be returned.
269    pub fn save<T: Serialize>(&self, config: &T, format: &Format) -> Result<()> {
270        let str = match format {
271            #[cfg(feature = "json")]
272            Format::Json => serde_json::to_string_pretty(config).map_err(SerializationError::from),
273            #[cfg(feature = "yaml")]
274            Format::Yaml => serde_yaml::to_string(config).map_err(SerializationError::from),
275            #[cfg(feature = "toml")]
276            Format::Toml => toml::to_string_pretty(config).map_err(SerializationError::from),
277            #[cfg(feature = "corn")]
278            Format::Corn => Err(SerializationError::UnsupportedExtension("corn".to_string())),
279            #[cfg(feature = "xml")]
280            Format::Xml => serde_xml_rs::to_string(config).map_err(SerializationError::from),
281            #[cfg(feature = "ron")]
282            Format::Ron => ron::to_string(config).map_err(SerializationError::from),
283        }?;
284
285        let config_dir = self.config_dir()?;
286        let file_name = format!("{}.{}", self.file_name, format.extension());
287        let full_path = config_dir.join(file_name);
288
289        fs::create_dir_all(config_dir)?;
290        fs::write(full_path, str)?;
291
292        Ok(())
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    /// Shape of every fixture in `test_configs/`.
    #[derive(Deserialize, Serialize)]
    struct ConfigContents {
        test: String,
    }

    // Each format test is gated on its feature. With the feature disabled,
    // `load` reports an unsupported extension, so an ungated test would fail
    // under e.g. `cargo test --no-default-features --features json`.

    #[test]
    #[cfg(feature = "json")]
    fn test_json() {
        let res: ConfigContents = ConfigLoader::load("test_configs/config.json").unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "yaml")]
    fn test_yaml() {
        let res: ConfigContents = ConfigLoader::load("test_configs/config.yaml").unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "toml")]
    fn test_toml() {
        let res: ConfigContents = ConfigLoader::load("test_configs/config.toml").unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "corn")]
    fn test_corn() {
        let res: ConfigContents = ConfigLoader::load("test_configs/config.corn").unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "xml")]
    fn test_xml() {
        let res: ConfigContents = ConfigLoader::load("test_configs/config.xml").unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "ron")]
    fn test_ron() {
        let res: ConfigContents = ConfigLoader::load("test_configs/config.ron").unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(any(
        feature = "json",
        feature = "yaml",
        feature = "toml",
        feature = "corn",
        feature = "xml",
        feature = "ron"
    ))]
    fn test_find_load() {
        let config = ConfigLoader::new("universal-config");
        let res: ConfigContents = config
            .with_config_dir("test_configs")
            .find_and_load()
            .unwrap();
        assert_eq!(res.test, "hello world");
    }

    #[test]
    #[cfg(feature = "json")]
    fn test_save_round_trip() {
        // Per-process directory so concurrent `cargo test` runs don't clash.
        let dir = std::env::temp_dir().join(format!(
            "universal-config-roundtrip-{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        let dir_str = dir.to_str().expect("temp dir path should be valid UTF-8");

        let formats = [Format::Json];
        let loader = ConfigLoader::new("universal-config")
            .with_config_dir(dir_str)
            .with_formats(&formats);

        let config = ConfigContents {
            test: "round trip".to_string(),
        };
        loader.save(&config, &Format::Json).unwrap();
        let res = loader.find_and_load::<ConfigContents>();

        // Clean up before asserting so a failure doesn't leave files behind.
        let _ = std::fs::remove_dir_all(&dir);
        assert_eq!(res.unwrap().test, "round trip");
    }
}