orion_sec/
load.rs

1use std::{env, path::PathBuf};
2
3use log::{info, warn};
4use orion_conf::{TomlIO, Yamlable};
5use orion_error::{ErrorOwe, ErrorWith};
6use orion_variate::vars::UpperKey;
7use orion_variate::vars::{EnvDict, ValueDict};
8
9use crate::{
10    error::SecResult,
11    sec::{NoSecConv, SecFrom, SecValueObj, SecValueType},
12};
13
/// Prefix prepended to every key loaded from a secret file.
const SEC_PREFIX: &str = "SEC_";
/// File name of the default secret file inside the galaxy dot directory.
const SEC_VALUE_FILE_NAME: &str = "sec_value.yml";
/// Name of the dot directory that holds the default secret file.
const GALAXY_DOT_DIR: &str = ".galaxy";
/// Path returned by `dot_path` when no home directory can be resolved.
const DEFAULT_FALLBACK_DIR: &str = "./";
18
19pub fn load_sec_dict() -> SecResult<EnvDict> {
20    let space = load_secfile()?;
21    let mut dict = EnvDict::new();
22    for (k, v) in space.no_sec() {
23        dict.insert(k, v);
24    }
25    Ok(dict)
26}
27
28pub fn load_sec_dict_by(dot_name: &str, file_name: &str, fmt: SecFileFmt) -> SecResult<EnvDict> {
29    let sec_file = dot_path(dot_name).join(file_name);
30    let space = load_secfile_by(sec_file, fmt)?;
31    let mut dict = EnvDict::new();
32    for (k, v) in space.no_sec() {
33        dict.insert(k, v);
34    }
35    Ok(dict)
36}
37
38pub fn load_secfile() -> SecResult<SecValueObj> {
39    let default = sec_value_galaxy_path();
40    load_secfile_by(default, SecFileFmt::Yaml)
41}
42
43pub fn load_galaxy_secfile() -> SecResult<SecValueObj> {
44    let default = sec_value_galaxy_path();
45    load_secfile_by(default, SecFileFmt::Yaml)
46}
/// On-disk format of a secret file, selecting which parser is used to load it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecFileFmt {
    /// The file is parsed as YAML.
    Yaml,
    /// The file is parsed as TOML.
    Toml,
}
51
52pub fn load_secfile_by(sec_file: PathBuf, fmt: SecFileFmt) -> SecResult<SecValueObj> {
53    let mut vars_dict = SecValueObj::new();
54    if sec_file.exists() {
55        let dict = match fmt {
56            SecFileFmt::Yaml => ValueDict::load_yaml(&sec_file)
57                .owe_logic()
58                .with(&sec_file)?,
59            SecFileFmt::Toml => ValueDict::load_toml(&sec_file)
60                .owe_logic()
61                .with(&sec_file)?,
62        };
63        info!(target: "exec","  load {}", sec_file.display());
64        for (k, v) in dict.iter() {
65            vars_dict.insert(
66                UpperKey::from(format!("{}{}", SEC_PREFIX, k.as_str().to_uppercase())),
67                SecValueType::sec_from(v.clone()),
68            );
69        }
70    }
71    Ok(vars_dict)
72}
73
74pub fn sec_value_galaxy_path() -> PathBuf {
75    dot_path(GALAXY_DOT_DIR).join(SEC_VALUE_FILE_NAME)
76}
77
78pub fn dot_path(name: &str) -> PathBuf {
79    match resolve_home_dir() {
80        Some(home) => home.join(name),
81        None => {
82            warn!(target: "exec", "  HOME not set; defaulting to current directory for {}", GALAXY_DOT_DIR);
83            PathBuf::from(DEFAULT_FALLBACK_DIR)
84        }
85    }
86}
87
/// Resolves the user's home directory from environment variables.
///
/// Candidates are tried in order: `HOME`, then `USERPROFILE`, then
/// `HOMEDRIVE` joined with `HOMEPATH`. A variable that is set to an empty
/// string is treated as unset, so an empty `HOME` falls through to the next
/// candidate instead of producing an empty (relative) home path.
///
/// Returns `None` when no candidate yields a non-empty value.
fn resolve_home_dir() -> Option<PathBuf> {
    // Reads a variable, discarding values that are present but empty.
    let non_empty = |key: &str| env::var_os(key).filter(|value| !value.is_empty());

    non_empty("HOME")
        .or_else(|| non_empty("USERPROFILE"))
        .map(PathBuf::from)
        .or_else(|| {
            // Both halves are required; either one missing means no result.
            let drive = non_empty("HOMEDRIVE")?;
            let path = non_empty("HOMEPATH")?;
            Some(PathBuf::from(drive).join(path))
        })
}
100
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::fs;
    use std::io::Write;
    use std::path::Path;
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};
    use tempfile::{NamedTempFile, TempDir};

    /// A path that does not exist yields an empty result, not an error.
    #[test]
    fn test_load_secfile_by_nonexistent_file() {
        let path = PathBuf::from("/nonexistent/path/to/file.yml");
        let result = load_secfile_by(path, SecFileFmt::Yaml);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    /// YAML entries are loaded under `SEC_`-prefixed, upper-cased keys.
    #[test]
    fn test_load_secfile_by_yaml() {
        let mut file = NamedTempFile::with_suffix(".yml").unwrap();
        writeln!(file, "username: admin").unwrap();
        writeln!(file, "password: secret123").unwrap();
        writeln!(file, "port: 8080").unwrap();

        let result = load_secfile_by(file.path().to_path_buf(), SecFileFmt::Yaml);
        assert!(result.is_ok());

        let obj = result.unwrap();
        assert_eq!(obj.len(), 3);
        assert!(obj.contains_key(&UpperKey::from("SEC_USERNAME".to_string())));
        assert!(obj.contains_key(&UpperKey::from("SEC_PASSWORD".to_string())));
        assert!(obj.contains_key(&UpperKey::from("SEC_PORT".to_string())));
    }

    /// TOML entries are loaded under `SEC_`-prefixed, upper-cased keys.
    #[test]
    fn test_load_secfile_by_toml() {
        let mut file = NamedTempFile::with_suffix(".toml").unwrap();
        writeln!(file, "api_key = \"abc123\"").unwrap();
        writeln!(file, "debug = true").unwrap();

        let result = load_secfile_by(file.path().to_path_buf(), SecFileFmt::Toml);
        assert!(result.is_ok());

        let obj = result.unwrap();
        assert_eq!(obj.len(), 2);
        assert!(obj.contains_key(&UpperKey::from("SEC_API_KEY".to_string())));
        assert!(obj.contains_key(&UpperKey::from("SEC_DEBUG".to_string())));
    }

    /// Mixed-case and lower-case keys are both upper-cased.
    #[test]
    fn test_load_secfile_by_key_uppercase() {
        let mut file = NamedTempFile::with_suffix(".yml").unwrap();
        writeln!(file, "mixedCase: value1").unwrap();
        writeln!(file, "lower_case: value2").unwrap();

        let result = load_secfile_by(file.path().to_path_buf(), SecFileFmt::Yaml);
        assert!(result.is_ok());

        let obj = result.unwrap();
        assert!(obj.contains_key(&UpperKey::from("SEC_MIXEDCASE".to_string())));
        assert!(obj.contains_key(&UpperKey::from("SEC_LOWER_CASE".to_string())));
    }

    /// Loaded string values are marked as secret.
    #[test]
    fn test_load_secfile_by_values_are_secret() {
        let mut file = NamedTempFile::with_suffix(".yml").unwrap();
        writeln!(file, "token: my_secret_token").unwrap();

        let result = load_secfile_by(file.path().to_path_buf(), SecFileFmt::Yaml);
        assert!(result.is_ok());

        let obj = result.unwrap();
        let value = obj.get(&UpperKey::from("SEC_TOKEN".to_string())).unwrap();
        assert!(matches!(value, SecValueType::String(s) if s.is_secret()));
    }

    /// An existing but empty file yields an empty result.
    #[test]
    fn test_load_secfile_by_empty_file() {
        let file = NamedTempFile::with_suffix(".yml").unwrap();

        let result = load_secfile_by(file.path().to_path_buf(), SecFileFmt::Yaml);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    /// A YAML file inside a dot directory under HOME is found and loaded.
    #[test]
    fn test_load_sec_dict_by_yaml() {
        with_temp_home(|home_path| {
            let dot_dir = home_path.join(".myapp");
            fs::create_dir_all(&dot_dir).unwrap();

            let sec_file = dot_dir.join("secrets.yml");
            let mut file = fs::File::create(&sec_file).unwrap();
            writeln!(file, "db_user: root").unwrap();
            writeln!(file, "db_pass: password123").unwrap();

            let result = load_sec_dict_by(".myapp", "secrets.yml", SecFileFmt::Yaml);
            assert!(result.is_ok());

            let dict = result.unwrap();
            assert_eq!(dict.len(), 2);
            assert!(dict.contains_key("SEC_DB_USER"));
            assert!(dict.contains_key("SEC_DB_PASS"));
        });
    }

    /// A TOML file inside a dot directory under HOME is found and loaded.
    #[test]
    fn test_load_sec_dict_by_toml() {
        with_temp_home(|home_path| {
            let dot_dir = home_path.join(".config");
            fs::create_dir_all(&dot_dir).unwrap();

            let sec_file = dot_dir.join("app.toml");
            let mut file = fs::File::create(&sec_file).unwrap();
            writeln!(file, "secret_key = \"abc123\"").unwrap();
            writeln!(file, "enabled = true").unwrap();

            let result = load_sec_dict_by(".config", "app.toml", SecFileFmt::Toml);
            assert!(result.is_ok());

            let dict = result.unwrap();
            assert_eq!(dict.len(), 2);
            assert!(dict.contains_key("SEC_SECRET_KEY"));
            assert!(dict.contains_key("SEC_ENABLED"));
        });
    }

    /// A dot directory that does not exist yields an empty dictionary.
    #[test]
    fn test_load_sec_dict_by_nonexistent_dir() {
        with_temp_home(|_| {
            let result = load_sec_dict_by(".nonexistent", "file.yml", SecFileFmt::Yaml);
            assert!(result.is_ok());
            assert!(result.unwrap().is_empty());
        });
    }

    /// Entries survive the `no_sec()` conversion under their prefixed keys.
    #[test]
    fn test_load_sec_dict_by_values_not_secret() {
        with_temp_home(|home_path| {
            let dot_dir = home_path.join(".test");
            fs::create_dir_all(&dot_dir).unwrap();

            let sec_file = dot_dir.join("data.yml");
            let mut file = fs::File::create(&sec_file).unwrap();
            writeln!(file, "value: test_data").unwrap();

            let result = load_sec_dict_by(".test", "data.yml", SecFileFmt::Yaml);
            assert!(result.is_ok());

            let dict = result.unwrap();
            // Values in the EnvDict have already been converted via no_sec(),
            // so they are no longer secret.
            assert!(dict.contains_key("SEC_VALUE"));
        });
    }

    /// Runs `test` with HOME pointing at a fresh temporary directory.
    ///
    /// HOME is restored when the closure returns or panics, because the
    /// guard's `Drop` runs before the temporary directory is removed.
    fn with_temp_home<F>(test: F)
    where
        F: FnOnce(&Path),
    {
        let temp_dir = TempDir::new().unwrap();
        let _guard = HomeGuard::set(temp_dir.path());
        test(temp_dir.path());
    }

    /// Overrides HOME for its lifetime and restores the previous value on drop.
    struct HomeGuard {
        /// Value of HOME before the override, or `None` if it was unset.
        old_home: Option<OsString>,
        /// Held for the guard's lifetime so HOME-mutating tests run one at a time.
        _lock: MutexGuard<'static, ()>,
    }

    impl HomeGuard {
        /// Takes the HOME lock, remembers the current HOME and sets it to `path`.
        fn set(path: &Path) -> Self {
            // A test that panics while holding the lock poisons the mutex. The
            // protected data is `()`, so the poison flag carries no state worth
            // rejecting; recover the guard so one failing test does not make
            // every later HOME-dependent test fail as well.
            let lock = home_lock()
                .lock()
                .unwrap_or_else(PoisonError::into_inner);
            let old_home = env::var_os("HOME");
            // SAFETY: `lock` serializes every HOME mutation made through
            // `HomeGuard`. This assumes no other thread accesses the
            // environment through non-Rust code while the tests run.
            unsafe {
                env::set_var("HOME", path);
            }

            Self {
                old_home,
                _lock: lock,
            }
        }
    }

    impl Drop for HomeGuard {
        fn drop(&mut self) {
            // SAFETY: `_lock` is still held here (fields are dropped after this
            // body runs), so HOME mutations through `HomeGuard` stay
            // serialized; same non-Rust-access assumption as in `set`.
            match self.old_home.as_ref() {
                Some(home) => unsafe {
                    env::set_var("HOME", home);
                },
                None => unsafe {
                    env::remove_var("HOME");
                },
            }
        }
    }

    /// Returns the process-wide mutex that serializes HOME-mutating tests.
    fn home_lock() -> &'static Mutex<()> {
        static HOME_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
        HOME_MUTEX.get_or_init(|| Mutex::new(()))
    }
}