// rops_cli/config/retrieve.rs

use std::path::Path;

use anyhow::Context;
use rops::file::format::*;
use serde::de::DeserializeOwned;

6pub type DefaulConfigFileFormat = TomlFileFormat;
7const ROPS_CONFIG_ENV_VAR_NAME: &str = "ROPS_CONFIG";
8const ROPS_CONFIG_DEFAULT_FILE_NAME: &str = ".rops.toml";
9
10// separated with generic parameter to simplify unit testing of strategy
11pub(super) fn retrieve_impl<T: DeserializeOwned + Default>(optional_config_path: Option<&Path>) -> anyhow::Result<T> {
12    if let Some(arg_path) = optional_config_path {
13        return read_fs_path_and_deserialize::<T>(arg_path);
14    }
15
16    if let Some(env_path) = std::env::var_os(ROPS_CONFIG_ENV_VAR_NAME) {
17        return read_fs_path_and_deserialize::<T>(env_path);
18    }
19
20    return traverse_fs_or_default::<T>();
21
22    fn traverse_fs_or_default<T: DeserializeOwned + Default>() -> anyhow::Result<T> {
23        let mut traversal_path = std::env::current_dir()?;
24        loop {
25            traversal_path.push(ROPS_CONFIG_DEFAULT_FILE_NAME);
26            if traversal_path.exists() {
27                return read_fs_path_and_deserialize::<T>(traversal_path);
28            }
29            traversal_path.pop();
30
31            if !traversal_path.pop() {
32                return Ok(T::default());
33            }
34        }
35    }
36
37    fn read_fs_path_and_deserialize<T: DeserializeOwned>(config_path: impl AsRef<Path>) -> anyhow::Result<T> {
38        let config_string = std::fs::read_to_string(config_path)?;
39        DefaulConfigFileFormat::deserialize_from_str(&config_string).map_err(Into::into)
40    }
41}
42
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use serde::{Deserialize, Serialize};
    use tempfile::NamedTempFile;

    use super::*;

    #[derive(Debug, Default, PartialEq, Serialize, Deserialize)]
    struct StubConfig {
        location: Location,
    }

    impl StubConfig {
        /// Writes this config to `path` using the default config file format.
        pub fn serialize(&self, path: &Path) {
            let config_string = DefaulConfigFileFormat::serialize_to_string(self).unwrap();
            std::fs::write(path, config_string).unwrap();
        }
    }

    /// Marks which lookup stage a retrieved config is expected to come from.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
    enum Location {
        Arg,
        Env,
        Traversal,
        #[default]
        Fallback,
    }

    /// Deletes the wrapped file when dropped (also on panic), so a failing assertion cannot
    /// leave a stray `.rops.toml` behind that later tests would pick up.
    struct RemoveFileOnDrop(PathBuf);

    impl Drop for RemoveFileOnDrop {
        fn drop(&mut self) {
            // Best effort: panicking inside `drop` while already unwinding would abort the process.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Unsets `ROPS_CONFIG` when dropped (also on panic), so the override cannot leak into
    /// the traversal/fallback tests.
    struct UnsetEnvVarOnDrop;

    impl Drop for UnsetEnvVarOnDrop {
        fn drop(&mut self) {
            std::env::remove_var(ROPS_CONFIG_ENV_VAR_NAME);
        }
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_arg() {
        let expected_config = StubConfig { location: Location::Arg };
        let temp_file = NamedTempFile::new().unwrap();
        expected_config.serialize(temp_file.path());

        let retrieved_config = retrieve_impl(Some(temp_file.path())).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_env() {
        let expected_config = StubConfig { location: Location::Env };
        let temp_file = NamedTempFile::new().unwrap();
        expected_config.serialize(temp_file.path());

        std::env::set_var(ROPS_CONFIG_ENV_VAR_NAME, temp_file.path());
        // Declared after `temp_file`, so it is dropped (env var unset) before the file goes away.
        let _unset_env_var = UnsetEnvVarOnDrop;

        let retrieved_config = retrieve_impl(None).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_traversal_in_current() {
        test_traversal_impl(&std::env::current_dir().unwrap())
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_traversal_in_ancestor() {
        test_traversal_impl(std::env::current_dir().unwrap().parent().unwrap())
    }

    /// Writes a `.rops.toml` into `directory_path` and checks that traversal from the current
    /// directory finds it. The file is removed again even if the assertion fails.
    fn test_traversal_impl(directory_path: &Path) {
        // An inherited `ROPS_CONFIG` would take precedence over traversal.
        std::env::remove_var(ROPS_CONFIG_ENV_VAR_NAME);

        let path = directory_path.join(ROPS_CONFIG_DEFAULT_FILE_NAME);
        // Never clobber (and then delete) a config file this test did not create.
        assert!(
            !path.exists(),
            "refusing to overwrite pre-existing config file at {}",
            path.display()
        );

        let expected_config = StubConfig {
            location: Location::Traversal,
        };
        expected_config.serialize(&path);
        let _remove_config_file = RemoveFileOnDrop(path);

        let retrieved_config = retrieve_impl(None).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_default_fallback() {
        // Assumes no `.rops.toml` exists in the current directory or any of its ancestors.
        std::env::remove_var(ROPS_CONFIG_ENV_VAR_NAME);
        assert_eq!(StubConfig::default(), retrieve_impl(None).unwrap());
    }
}