rops_cli/config/
retrieve.rs1use std::path::Path;
2
3use rops::file::format::*;
4use serde::de::DeserializeOwned;
5
6pub type DefaulConfigFileFormat = TomlFileFormat;
7const ROPS_CONFIG_ENV_VAR_NAME: &str = "ROPS_CONFIG";
8const ROPS_CONFIG_DEFAULT_FILE_NAME: &str = ".rops.toml";
9
10pub(super) fn retrieve_impl<T: DeserializeOwned + Default>(optional_config_path: Option<&Path>) -> anyhow::Result<T> {
12 if let Some(arg_path) = optional_config_path {
13 return read_fs_path_and_deserialize::<T>(arg_path);
14 }
15
16 if let Some(env_path) = std::env::var_os(ROPS_CONFIG_ENV_VAR_NAME) {
17 return read_fs_path_and_deserialize::<T>(env_path);
18 }
19
20 return traverse_fs_or_default::<T>();
21
22 fn traverse_fs_or_default<T: DeserializeOwned + Default>() -> anyhow::Result<T> {
23 let mut traversal_path = std::env::current_dir()?;
24 loop {
25 traversal_path.push(ROPS_CONFIG_DEFAULT_FILE_NAME);
26 if traversal_path.exists() {
27 return read_fs_path_and_deserialize::<T>(traversal_path);
28 }
29 traversal_path.pop();
30
31 if !traversal_path.pop() {
32 return Ok(T::default());
33 }
34 }
35 }
36
37 fn read_fs_path_and_deserialize<T: DeserializeOwned>(config_path: impl AsRef<Path>) -> anyhow::Result<T> {
38 let config_string = std::fs::read_to_string(config_path)?;
39 DefaulConfigFileFormat::deserialize_from_str(&config_string).map_err(Into::into)
40 }
41}
42
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};
    use tempfile::NamedTempFile;

    use super::*;

    /// Minimal config whose single field records which retrieval strategy produced it.
    #[derive(Debug, Default, PartialEq, Serialize, Deserialize)]
    struct StubConfig {
        location: Location,
    }

    impl StubConfig {
        /// Writes `self` to `path` using the same format `retrieve_impl` reads.
        pub fn serialize(&self, path: &Path) {
            let config_string = DefaulConfigFileFormat::serialize_to_string(self).unwrap();
            std::fs::write(path, config_string).unwrap();
        }
    }

    /// Retrieval strategy marker; `Fallback` is the `Default`, i.e. what `retrieve_impl` returns when no
    /// file is found.
    #[derive(Debug, Default, Clone, Copy, PartialEq, Serialize, Deserialize)]
    enum Location {
        Arg,
        Env,
        Traversal,
        #[default]
        Fallback,
    }

    /// Sets an environment variable for the guard's lifetime and removes it on drop, so a failing
    /// assertion cannot leak the variable into the tests that run afterwards.
    struct EnvVarGuard {
        name: &'static str,
    }

    impl EnvVarGuard {
        fn set(name: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            std::env::set_var(name, value);
            Self { name }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            std::env::remove_var(self.name);
        }
    }

    /// Deletes the file at `path` on drop, so a failing assertion cannot leave a stray config file behind
    /// in the working tree.
    struct FileGuard {
        path: std::path::PathBuf,
    }

    impl Drop for FileGuard {
        fn drop(&mut self) {
            if let Err(error) = std::fs::remove_file(&self.path) {
                // Panicking while already unwinding would abort the test binary, so cleanup failures are
                // only reported when the test body itself succeeded.
                if !std::thread::panicking() {
                    panic!("failed to remove '{}': {error}", self.path.display());
                }
            }
        }
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_arg() {
        let expected_config = StubConfig { location: Location::Arg };
        let temp_file = NamedTempFile::new().unwrap();
        expected_config.serialize(temp_file.path());

        let retrieved_config = retrieve_impl(Some(temp_file.path())).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_env() {
        let expected_config = StubConfig { location: Location::Env };
        let temp_file = NamedTempFile::new().unwrap();
        expected_config.serialize(temp_file.path());

        // Declared after `temp_file`, so it is dropped (variable removed) before the file is deleted.
        let _env_guard = EnvVarGuard::set(ROPS_CONFIG_ENV_VAR_NAME, temp_file.path());

        let retrieved_config = retrieve_impl(None).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_traversal_in_current() {
        test_traversal_impl(&std::env::current_dir().unwrap())
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_traversal_in_ancestor() {
        test_traversal_impl(std::env::current_dir().unwrap().parent().unwrap())
    }

    /// Writes a config file into `directory_path` and checks that traversal from the current directory
    /// finds it. The file is removed afterwards even if the assertion fails.
    fn test_traversal_impl(directory_path: &Path) {
        let expected_config = StubConfig {
            location: Location::Traversal,
        };
        let path = directory_path.join(ROPS_CONFIG_DEFAULT_FILE_NAME);
        // Never clobber (and later delete) a real config file that lives in the checkout.
        assert!(!path.exists(), "refusing to overwrite existing '{}'", path.display());

        // Guard is armed before writing so a partially written file is cleaned up as well.
        let file_guard = FileGuard { path };
        expected_config.serialize(&file_guard.path);

        let retrieved_config = retrieve_impl(None).unwrap();
        assert_eq!(expected_config, retrieved_config);
    }

    #[test]
    #[serial_test::serial(config_retrieval)]
    fn retrieves_config_by_default_fallback() {
        assert_eq!(StubConfig::default(), retrieve_impl(None).unwrap());
    }
}