confik/sources/file_source.rs

1use std::{error::Error, path::PathBuf};
2
3use cfg_if::cfg_if;
4use thiserror::Error;
5
6use crate::{ConfigurationBuilder, Source};
7
8#[derive(Debug, Error)]
9#[error("Could not parse {}", .path.display())]
10struct FileError {
11    path: PathBuf,
12
13    #[source]
14    kind: FileErrorKind,
15}
16
17#[derive(Debug, Error)]
18enum FileErrorKind {
19    #[error(transparent)]
20    CouldNotReadFile(#[from] std::io::Error),
21
22    #[allow(dead_code)]
23    #[error("{0} feature is not enabled")]
24    MissingFeatureForExtension(&'static str),
25
26    #[error("Unknown file extension")]
27    UnknownExtension,
28
29    #[cfg(feature = "toml")]
30    #[error(transparent)]
31    Toml(#[from] toml::de::Error),
32
33    #[cfg(feature = "json")]
34    #[error(transparent)]
35    Json(#[from] serde_json::Error),
36}
37
/// A [`Source`] that loads configuration from a file on disk.
///
/// The file format is picked from the path's extension; see [`FileSource::new`].
#[derive(Debug, Clone)]
pub struct FileSource {
    /// Location of the configuration file.
    path: PathBuf,
    /// Whether this source is permitted to provide secret values.
    allow_secrets: bool,
}
44
45impl FileSource {
46    /// Create a [`Source`] referring to a file path,
47    ///
48    /// The deserialization method will be determined by the file extension.
49    ///
50    /// Supported extensions:
51    /// - `toml`
52    /// - `json`
53    pub fn new(path: impl Into<PathBuf>) -> Self {
54        Self {
55            path: path.into(),
56            allow_secrets: false,
57        }
58    }
59
60    /// Allows this source to contain secrets.
61    pub fn allow_secrets(mut self) -> Self {
62        self.allow_secrets = true;
63        self
64    }
65
66    fn deserialize<T: ConfigurationBuilder>(&self) -> Result<T, FileErrorKind> {
67        #[allow(unused_variables)]
68        let contents = std::fs::read_to_string(&self.path)?;
69
70        match self.path.extension().and_then(|ext| ext.to_str()) {
71            Some("toml") => {
72                cfg_if! {
73                    if #[cfg(feature = "toml")] {
74                        Ok(toml::from_str(&contents)?)
75                    } else {
76                        Err(FileErrorKind::MissingFeatureForExtension("toml"))
77                    }
78                }
79            }
80
81            Some("json") => {
82                cfg_if! {
83                    if #[cfg(feature = "json")] {
84                        Ok(serde_json::from_str(&contents)?)
85                    } else {
86                        Err(FileErrorKind::MissingFeatureForExtension("json"))
87                    }
88                }
89            }
90
91            _ => Err(FileErrorKind::UnknownExtension),
92        }
93    }
94}
95
96impl<T: ConfigurationBuilder> Source<T> for FileSource {
97    fn allows_secrets(&self) -> bool {
98        self.allow_secrets
99    }
100
101    fn provide(&self) -> Result<T, Box<dyn Error + Sync + Send>> {
102        self.deserialize().map_err(|err| {
103            Box::new(FileError {
104                path: self.path.clone(),
105                kind: err,
106            }) as _
107        })
108    }
109}
110
#[cfg(test)]
mod tests {
    use std::{fs, io};

    use confik_macros::Configuration;

    use super::*;

    #[derive(Debug, Default, serde::Deserialize, Configuration)]
    struct NoopConfig {}

    // `foo` is only read by the feature-gated format tests.
    #[derive(Debug, Default, serde::Deserialize, Configuration)]
    #[allow(dead_code)]
    struct SimpleConfig {
        foo: u64,
    }

    #[test]
    fn non_existent() {
        let source = FileSource::new("non-existent-config.toml");
        let err = source.deserialize::<Option<NoopConfig>>().unwrap_err();
        // Match on the I/O error kind rather than its message: the OS-provided text differs
        // between platforms (Windows does not say "No such file or directory").
        assert!(
            matches!(
                &err,
                FileErrorKind::CouldNotReadFile(io_err) if io_err.kind() == io::ErrorKind::NotFound
            ),
            "unexpected error: {err}",
        );
    }

    #[test]
    fn unknown_extension() {
        let dir = tempfile::TempDir::new().unwrap();

        let cfg_path = dir.path().join("config.cfg");
        fs::write(&cfg_path, "").unwrap();

        let source = FileSource::new(&cfg_path);
        let err = source.deserialize::<Option<NoopConfig>>().unwrap_err();
        assert!(
            matches!(err, FileErrorKind::UnknownExtension),
            "unexpected error: {err}",
        );

        dir.close().unwrap();
    }

    #[cfg(feature = "json")]
    #[test]
    fn json() {
        let dir = tempfile::TempDir::new().unwrap();

        let json_path = dir.path().join("config.json");

        fs::write(&json_path, "{}").unwrap();
        let source = FileSource::new(&json_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            err.to_string().contains("missing field"),
            "unexpected error message: {err}",
        );

        fs::write(&json_path, "{\"foo\":42}").unwrap();
        let source = FileSource::new(&json_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 42);

        dir.close().unwrap();
    }

    #[cfg(not(feature = "json"))]
    #[test]
    fn json_feature_disabled() {
        let dir = tempfile::TempDir::new().unwrap();

        let json_path = dir.path().join("config.json");
        fs::write(&json_path, "{\"foo\":42}").unwrap();

        let source = FileSource::new(&json_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            matches!(err, FileErrorKind::MissingFeatureForExtension("json")),
            "unexpected error: {err}",
        );

        dir.close().unwrap();
    }

    #[cfg(feature = "toml")]
    #[test]
    fn toml() {
        let dir = tempfile::TempDir::new().unwrap();

        let toml_path = dir.path().join("config.toml");

        fs::write(&toml_path, "").unwrap();
        let source = FileSource::new(&toml_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            err.to_string().contains("missing field"),
            "unexpected error message: {err}",
        );

        fs::write(&toml_path, "foo = 42").unwrap();
        let source = FileSource::new(&toml_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 42);

        dir.close().unwrap();
    }

    #[cfg(not(feature = "toml"))]
    #[test]
    fn toml_feature_disabled() {
        let dir = tempfile::TempDir::new().unwrap();

        let toml_path = dir.path().join("config.toml");
        fs::write(&toml_path, "foo = 42").unwrap();

        let source = FileSource::new(&toml_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            matches!(err, FileErrorKind::MissingFeatureForExtension("toml")),
            "unexpected error: {err}",
        );

        dir.close().unwrap();
    }
}