aws_mfa_session/
profile.rs1use ini::Ini;
2
3pub fn get_mfa_serial_from_profile(profile_name: Option<&str>) -> Option<String> {
5 let profile_name = profile_name.unwrap_or("default");
6
7 let config_path = match std::env::var("AWS_CONFIG_FILE") {
9 Ok(path) => path,
10 Err(_) => {
11 let mut home = dirs::home_dir()?;
12 home.push(".aws");
13 home.push("config");
14 home.to_string_lossy().to_string()
15 }
16 };
17
18 let credentials_path = match std::env::var("AWS_SHARED_CREDENTIALS_FILE") {
19 Ok(path) => path,
20 Err(_) => {
21 let mut home = dirs::home_dir()?;
22 home.push(".aws");
23 home.push("credentials");
24 home.to_string_lossy().to_string()
25 }
26 };
27
28 let mut mfa_serial = None;
30
31 if std::path::Path::new(&credentials_path).exists()
33 && let Some(mfa) = extract_mfa_serial_with_ini(&credentials_path, profile_name)
34 {
35 mfa_serial = Some(mfa);
36 }
37
38 if std::path::Path::new(&config_path).exists()
40 && let Some(mfa) = extract_mfa_serial_with_ini(&config_path, profile_name)
41 {
42 mfa_serial = Some(mfa);
43 }
44
45 mfa_serial
46}
47
48fn extract_mfa_serial_with_ini(file_path: &str, target_profile: &str) -> Option<String> {
50 let conf = Ini::load_from_file(file_path).ok()?;
51
52 let profile_section_name = if target_profile == "default" {
58 vec!["default".to_string(), format!("profile {}", target_profile)]
60 } else {
61 vec![format!("profile {}", target_profile)]
62 };
63
64 for section_name in &profile_section_name {
65 if let Some(section) = conf.section(Some(section_name))
66 && let Some(mfa_serial) = section.get("mfa_serial")
67 {
68 return Some(mfa_serial.to_string());
69 }
70 }
71
72 if let Some(section) = conf.section(Some(target_profile))
74 && let Some(mfa_serial) = section.get("mfa_serial")
75 {
76 return Some(mfa_serial.to_string());
77 }
78
79 None
80}
81
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing `content` and hands back its handle.
    fn ini_fixture(content: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file
    }

    /// A profile that is not defined anywhere yields no MFA serial.
    ///
    /// NOTE(review): this goes through the real environment variables and home
    /// directory, so it assumes no profile named `nonexistent` with an
    /// `mfa_serial` exists on the machine running the tests.
    #[test]
    fn test_get_mfa_serial_missing_files() {
        assert_eq!(get_mfa_serial_from_profile(Some("nonexistent")), None);
    }

    /// Both the `profile `-prefixed and the bare section form are recognised.
    #[test]
    fn test_extract_mfa_serial_with_ini_basic() {
        let fixture = ini_fixture(
            r#"
[profile dev]
mfa_serial = arn:aws:iam::123456789012:mfa/dev-user
region = us-west-2

[prod]
mfa_serial = GAHT12345678
"#,
        );
        let path = fixture.path().to_str().unwrap();

        let dev = extract_mfa_serial_with_ini(path, "dev");
        let prod = extract_mfa_serial_with_ini(path, "prod");

        assert_eq!(
            dev.as_deref(),
            Some("arn:aws:iam::123456789012:mfa/dev-user")
        );
        assert_eq!(prod.as_deref(), Some("GAHT12345678"));
    }

    /// Sections without an `mfa_serial` key produce `None`.
    #[test]
    fn test_extract_mfa_serial_with_ini_none() {
        let fixture = ini_fixture(
            r#"
[profile dev]
region = us-west-2

[prod]
region = us-east-1
"#,
        );
        let path = fixture.path().to_str().unwrap();

        let dev = extract_mfa_serial_with_ini(path, "dev");
        let prod = extract_mfa_serial_with_ini(path, "prod");

        assert_eq!(dev, None);
        assert_eq!(prod, None);
    }
}