Skip to main content

confik/sources/
file_source.rs

1use std::{error::Error, path::PathBuf};
2
3use cfg_if::cfg_if;
4use thiserror::Error;
5
6use crate::{ConfigurationBuilder, Source};
7
/// Error returned by [`FileSource`] when loading a configuration file fails.
///
/// Pairs the offending file path with the specific failure, which is exposed as this
/// error's `source()`.
// NOTE(review): the message says "Could not parse" but this type also wraps read failures
// (`FileErrorKind::CouldNotReadFile`), so the wording can be misleading for I/O errors.
#[derive(Debug, Error)]
#[error("Could not parse {}", .path.display())]
struct FileError {
    /// Path of the file that failed to load; shown in the error message.
    path: PathBuf,

    /// The underlying reason for the failure, reported as the error source.
    #[source]
    kind: FileErrorKind,
}
16
/// The specific reason a [`FileSource`] failed to produce a configuration.
///
/// Format-specific variants only exist when the corresponding crate feature is enabled.
// NOTE(review): `#[non_exhaustive]` only restricts other crates, so it has no effect on this
// crate-private enum.
#[derive(Debug, Error)]
#[non_exhaustive]
enum FileErrorKind {
    /// The file could not be read into a `String` (for example it does not exist, cannot be
    /// opened, or is not valid UTF-8).
    #[error(transparent)]
    CouldNotReadFile(#[from] std::io::Error),

    /// The extension is recognised, but the crate feature for that format is not enabled.
    /// The payload is the format name shown in the error message.
    // Never constructed when every format feature is enabled, hence the `dead_code` allowance.
    #[allow(dead_code)]
    #[error("{0} feature is not enabled")]
    MissingFeatureForExtension(&'static str),

    /// The path has no extension, a non-UTF-8 extension, or one that is not supported.
    #[error("Unknown file extension")]
    UnknownExtension,

    /// The contents of a `.toml` file failed to deserialize.
    #[cfg(feature = "toml")]
    #[error(transparent)]
    Toml(#[from] toml::de::Error),

    /// The contents of a `.json` file failed to deserialize.
    #[cfg(feature = "json")]
    #[error(transparent)]
    Json(#[from] serde_json::Error),

    /// The contents of a `.ron` file failed to deserialize.
    #[cfg(feature = "ron-0_12")]
    #[error(transparent)]
    Ron(#[from] ron_0_12::error::SpannedError),

    /// The contents of a `.yaml`/`.yml` file failed to deserialize.
    #[cfg(feature = "yaml_serde-0_10")]
    #[error(transparent)]
    Yaml(#[from] yaml_serde_0_10::Error),
}
46
/// A [`Source`] referring to a file path.
///
/// Construct it with [`FileSource::new`]. The file format is chosen from the file extension
/// when the source is read, not at construction time.
#[derive(Debug, Clone)]
pub struct FileSource {
    /// Path to the configuration file; it is read each time the source provides a value.
    path: PathBuf,
    /// Whether this source may contain secrets; `false` unless `allow_secrets` was called.
    allow_secrets: bool,
}
53
54impl FileSource {
55    /// Create a [`Source`] referring to a file path,
56    ///
57    /// The deserialization method will be determined by the file extension.
58    ///
59    /// Supported extensions:
60    /// - `toml`
61    /// - `json`
62    /// - `ron`
63    /// - `yaml`
64    /// - `yml`
65    pub fn new(path: impl Into<PathBuf>) -> Self {
66        Self {
67            path: path.into(),
68            allow_secrets: false,
69        }
70    }
71
72    /// Allows this source to contain secrets.
73    pub fn allow_secrets(mut self) -> Self {
74        self.allow_secrets = true;
75        self
76    }
77
78    fn deserialize<T: ConfigurationBuilder>(&self) -> Result<T, FileErrorKind> {
79        #[allow(unused_variables)]
80        let contents = std::fs::read_to_string(&self.path)?;
81
82        match self.path.extension().and_then(|ext| ext.to_str()) {
83            Some("toml") => {
84                cfg_if! {
85                    if #[cfg(feature = "toml")] {
86                        Ok(toml::from_str(&contents)?)
87                    } else {
88                        Err(FileErrorKind::MissingFeatureForExtension("toml"))
89                    }
90                }
91            }
92
93            Some("json") => {
94                cfg_if! {
95                    if #[cfg(feature = "json")] {
96                        Ok(serde_json::from_str(&contents)?)
97                    } else {
98                        Err(FileErrorKind::MissingFeatureForExtension("json"))
99                    }
100                }
101            }
102
103            Some("ron") => {
104                cfg_if! {
105                    if #[cfg(feature = "ron-0_12")] {
106                        Ok(ron_0_12::from_str(&contents)?)
107                    } else {
108                        Err(FileErrorKind::MissingFeatureForExtension("ron"))
109                    }
110                }
111            }
112
113            Some("yaml" | "yml") => {
114                cfg_if! {
115                    if #[cfg(feature = "yaml_serde-0_10")] {
116                        Ok(yaml_serde_0_10::from_str(&contents)?)
117                    } else {
118                        Err(FileErrorKind::MissingFeatureForExtension("yaml"))
119                    }
120                }
121            }
122
123            _ => Err(FileErrorKind::UnknownExtension),
124        }
125    }
126}
127
128impl<T: ConfigurationBuilder> Source<T> for FileSource {
129    fn allows_secrets(&self) -> bool {
130        self.allow_secrets
131    }
132
133    fn provide(&self) -> Result<T, Box<dyn Error + Sync + Send>> {
134        self.deserialize().map_err(|err| {
135            Box::new(FileError {
136                path: self.path.clone(),
137                kind: err,
138            }) as _
139        })
140    }
141}
142
#[cfg(test)]
mod tests {
    use std::{fs, io};

    use confik_macros::Configuration;

    use super::*;

    #[derive(Debug, Default, serde::Deserialize, Configuration)]
    struct NoopConfig {}

    #[derive(Debug, Default, serde::Deserialize, Configuration)]
    // `foo` is only read by the feature-gated format tests below, so it is dead code when no
    // format feature is enabled.
    #[allow(dead_code)]
    struct SimpleConfig {
        foo: u64,
    }

    #[test]
    fn non_existent() {
        // Use a fresh temp dir so the path is guaranteed not to exist, regardless of CWD.
        let dir = tempfile::TempDir::new().unwrap();
        let missing_path = dir.path().join("non-existent-config.toml");

        let source = FileSource::new(&missing_path);
        let err = source.deserialize::<Option<NoopConfig>>().unwrap_err();

        // Check the I/O error kind rather than the OS-specific message text, which differs
        // between platforms (e.g. Unix vs Windows).
        assert!(
            matches!(
                &err,
                FileErrorKind::CouldNotReadFile(io_err) if io_err.kind() == io::ErrorKind::NotFound
            ),
            "unexpected error: {err:?}",
        );

        dir.close().unwrap();
    }

    #[test]
    fn unknown_extension() {
        let dir = tempfile::TempDir::new().unwrap();

        let cfg_path = dir.path().join("config.cfg");
        fs::write(&cfg_path, "").unwrap();

        let source = FileSource::new(&cfg_path);
        let err = source.deserialize::<Option<NoopConfig>>().unwrap_err();
        assert!(
            matches!(err, FileErrorKind::UnknownExtension),
            "unexpected error: {err:?}",
        );
        assert!(
            err.to_string().contains("Unknown file extension"),
            "unexpected error message: {err}",
        );

        dir.close().unwrap();
    }

    #[cfg(feature = "json")]
    #[test]
    fn json() {
        let dir = tempfile::TempDir::new().unwrap();

        let json_path = dir.path().join("config.json");

        fs::write(&json_path, "{}").unwrap();
        let source = FileSource::new(&json_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            err.to_string().contains("missing field"),
            "unexpected error message: {err}",
        );

        fs::write(&json_path, "{\"foo\":42}").unwrap();
        let source = FileSource::new(&json_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 42);

        dir.close().unwrap();
    }

    #[cfg(feature = "toml")]
    #[test]
    fn toml() {
        let dir = tempfile::TempDir::new().unwrap();

        let toml_path = dir.path().join("config.toml");

        fs::write(&toml_path, "").unwrap();
        let source = FileSource::new(&toml_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            err.to_string().contains("missing field"),
            "unexpected error message: {err}",
        );

        fs::write(&toml_path, "foo = 42").unwrap();
        let source = FileSource::new(&toml_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 42);

        dir.close().unwrap();
    }

    #[cfg(feature = "ron-0_12")]
    #[test]
    fn ron() {
        let dir = tempfile::TempDir::new().unwrap();

        let ron_path = dir.path().join("config.ron");

        fs::write(&ron_path, "(bar:42)").unwrap();
        let source = FileSource::new(&ron_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            err.to_string().contains("Expected option"),
            "unexpected error message: {err}",
        );

        fs::write(&ron_path, "Some((foo:42))").unwrap();
        let source = FileSource::new(&ron_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 42);

        dir.close().unwrap();
    }

    #[cfg(feature = "yaml_serde-0_10")]
    #[test]
    fn yaml() {
        let dir = tempfile::TempDir::new().unwrap();

        let yaml_path = dir.path().join("config.yaml");

        fs::write(&yaml_path, "{}").unwrap();
        let source = FileSource::new(&yaml_path);
        let err = source.deserialize::<Option<SimpleConfig>>().unwrap_err();
        assert!(
            err.to_string().contains("missing field"),
            "unexpected error message: {err}",
        );

        fs::write(&yaml_path, "foo: 42\n").unwrap();
        let source = FileSource::new(&yaml_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 42);

        let yml_path = dir.path().join("config.yml");
        fs::write(&yml_path, "foo: 43\n").unwrap();
        let source = FileSource::new(&yml_path);
        let config = source.deserialize::<Option<SimpleConfig>>().unwrap();
        assert_eq!(config.unwrap().foo, 43);

        dir.close().unwrap();
    }
}