1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::path::Path;

use rops::file::format::*;
use serde::de::DeserializeOwned;

/// File format used to (de)serialize the rops configuration file.
pub type DefaultConfigFileFormat = TomlFileFormat;

/// Misspelled alias of [`DefaultConfigFileFormat`], kept so existing callers keep compiling.
pub type DefaulConfigFileFormat = DefaultConfigFileFormat;

/// Environment variable that may hold an explicit path to the rops config file.
const ROPS_CONFIG_ENV_VAR_NAME: &str = "ROPS_CONFIG";

/// Config file name looked up in the current directory and its ancestors.
const ROPS_CONFIG_DEFAULT_FILE_NAME: &str = ".rops.toml";

// separated with generic parameter to simplify unit testing of strategy
pub(super) fn retrieve_impl<T: DeserializeOwned + Default>(optional_config_path: Option<&Path>) -> anyhow::Result<T> {
    if let Some(arg_path) = optional_config_path {
        return read_fs_path_and_deserialize::<T>(arg_path);
    }

    if let Some(env_path) = std::env::var_os(ROPS_CONFIG_ENV_VAR_NAME) {
        return read_fs_path_and_deserialize::<T>(env_path);
    }

    return traverse_fs_or_default::<T>();

    fn traverse_fs_or_default<T: DeserializeOwned + Default>() -> anyhow::Result<T> {
        let mut traversal_path = std::env::current_dir()?;
        loop {
            traversal_path.push(ROPS_CONFIG_DEFAULT_FILE_NAME);
            if traversal_path.exists() {
                return read_fs_path_and_deserialize::<T>(traversal_path);
            }
            traversal_path.pop();

            if !traversal_path.pop() {
                return Ok(T::default());
            }
        }
    }

    fn read_fs_path_and_deserialize<T: DeserializeOwned>(config_path: impl AsRef<Path>) -> anyhow::Result<T> {
        let config_string = std::fs::read_to_string(config_path)?;
        DefaulConfigFileFormat::deserialize_from_str(&config_string).map_err(Into::into)
    }
}

#[cfg(test)]
mod tests {
    use std::ffi::{OsStr, OsString};
    use std::path::PathBuf;

    use serde::{Deserialize, Serialize};
    use tempfile::NamedTempFile;

    use super::*;

    #[derive(Debug, Default, PartialEq, Serialize, Deserialize)]
    struct StubConfig {
        location: Location,
    }

    impl StubConfig {
        pub fn serialize(&self, path: &Path) {
            let config_string = DefaulConfigFileFormat::serialize_to_string(self).unwrap();
            std::fs::write(path, config_string).unwrap();
        }
    }

    /// Records which lookup strategy a stub config was written for.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
    enum Location {
        Arg,
        Env,
        Traversal,
        #[default]
        Fallback,
    }

    /// Overrides `ROPS_CONFIG` while alive and restores its previous state when
    /// dropped, so a failing assertion cannot leak the override into later tests.
    struct EnvVarGuard {
        previous: Option<OsString>,
    }

    impl EnvVarGuard {
        fn set(value: impl AsRef<OsStr>) -> Self {
            let previous = std::env::var_os(ROPS_CONFIG_ENV_VAR_NAME);
            std::env::set_var(ROPS_CONFIG_ENV_VAR_NAME, value);
            Self { previous }
        }

        /// Removes `ROPS_CONFIG` so an ambient value from the developer's shell
        /// cannot take precedence over the lookup strategy under test.
        fn unset() -> Self {
            let previous = std::env::var_os(ROPS_CONFIG_ENV_VAR_NAME);
            std::env::remove_var(ROPS_CONFIG_ENV_VAR_NAME);
            Self { previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(ROPS_CONFIG_ENV_VAR_NAME, value),
                None => std::env::remove_var(ROPS_CONFIG_ENV_VAR_NAME),
            }
        }
    }

    /// Writes a stub config file while alive. On drop it removes the file, or
    /// restores the original contents when a file already existed at that path,
    /// so the tests never destroy a real `.rops.toml` in the workspace.
    struct ConfigFileGuard {
        path: PathBuf,
        previous_contents: Option<Vec<u8>>,
    }

    impl ConfigFileGuard {
        fn write(path: PathBuf, config: &StubConfig) -> Self {
            let previous_contents = std::fs::read(&path).ok();
            config.serialize(&path);
            Self { path, previous_contents }
        }
    }

    impl Drop for ConfigFileGuard {
        fn drop(&mut self) {
            let cleanup_result = match self.previous_contents.take() {
                Some(contents) => std::fs::write(&self.path, contents),
                None => std::fs::remove_file(&self.path),
            };
            // Report cleanup failures, but never panic while already unwinding
            // (that would abort and hide the original assertion failure).
            if let Err(error) = cleanup_result {
                if !std::thread::panicking() {
                    panic!("failed to clean up {}: {error}", self.path.display());
                }
            }
        }
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_arg() {
        let expected_config = StubConfig { location: Location::Arg };
        let temp_file = NamedTempFile::new().unwrap();
        expected_config.serialize(temp_file.path());

        let retrieved_config = retrieve_impl(Some(temp_file.path())).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_env() {
        let expected_config = StubConfig { location: Location::Env };
        let temp_file = NamedTempFile::new().unwrap();
        expected_config.serialize(temp_file.path());

        // Declared after `temp_file`, so it is dropped (and the env restored) first.
        let _env_guard = EnvVarGuard::set(temp_file.path());

        let retrieved_config = retrieve_impl(None).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_traversal_in_current() {
        test_traversal_impl(&std::env::current_dir().unwrap())
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_traversal_in_ancestor() {
        test_traversal_impl(std::env::current_dir().unwrap().parent().unwrap())
    }

    fn test_traversal_impl(directory_path: &Path) {
        let _env_guard = EnvVarGuard::unset();
        let expected_config = StubConfig {
            location: Location::Traversal,
        };
        let _config_file = ConfigFileGuard::write(directory_path.join(ROPS_CONFIG_DEFAULT_FILE_NAME), &expected_config);

        let retrieved_config = retrieve_impl(None).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_default_fallback() {
        let _env_guard = EnvVarGuard::unset();
        assert_eq!(StubConfig::default(), retrieve_impl(None).unwrap());
    }
}