aws_mfa_session/
profile.rs

1use ini::Ini;
2
3/// Read MFA serial from AWS profile configuration using INI parsing
4pub fn get_mfa_serial_from_profile(profile_name: Option<&str>) -> Option<String> {
5    let profile_name = profile_name.unwrap_or("default");
6
7    // Use the same environment variable logic as AWS SDK for file paths
8    let config_path = match std::env::var("AWS_CONFIG_FILE") {
9        Ok(path) => path,
10        Err(_) => {
11            let mut home = dirs::home_dir()?;
12            home.push(".aws");
13            home.push("config");
14            home.to_string_lossy().to_string()
15        }
16    };
17
18    let credentials_path = match std::env::var("AWS_SHARED_CREDENTIALS_FILE") {
19        Ok(path) => path,
20        Err(_) => {
21            let mut home = dirs::home_dir()?;
22            home.push(".aws");
23            home.push("credentials");
24            home.to_string_lossy().to_string()
25        }
26    };
27
28    // Try to read mfa_serial from both files, with config taking precedence
29    let mut mfa_serial = None;
30
31    // Check credentials file first (lower precedence)
32    if std::path::Path::new(&credentials_path).exists()
33        && let Some(mfa) = extract_mfa_serial_with_ini(&credentials_path, profile_name)
34    {
35        mfa_serial = Some(mfa);
36    }
37
38    // Check config file second (higher precedence, will override credentials file)
39    if std::path::Path::new(&config_path).exists()
40        && let Some(mfa) = extract_mfa_serial_with_ini(&config_path, profile_name)
41    {
42        mfa_serial = Some(mfa);
43    }
44
45    mfa_serial
46}
47
48/// Extract MFA serial from AWS config file using proper INI parsing
49fn extract_mfa_serial_with_ini(file_path: &str, target_profile: &str) -> Option<String> {
50    let conf = Ini::load_from_file(file_path).ok()?;
51
52    // Try both AWS config file formats:
53    // 1. [profile name] format (used in config files)
54    // 2. [name] format (used in credentials files)
55
56    // First try the config file format [profile name]
57    let profile_section_name = if target_profile == "default" {
58        // In config files, default profile can be either [default] or [profile default]
59        vec!["default".to_string(), format!("profile {}", target_profile)]
60    } else {
61        vec![format!("profile {}", target_profile)]
62    };
63
64    for section_name in &profile_section_name {
65        if let Some(section) = conf.section(Some(section_name))
66            && let Some(mfa_serial) = section.get("mfa_serial")
67        {
68            return Some(mfa_serial.to_string());
69        }
70    }
71
72    // If not found, try the credentials file format [name]
73    if let Some(section) = conf.section(Some(target_profile))
74        && let Some(mfa_serial) = section.get("mfa_serial")
75    {
76        return Some(mfa_serial.to_string());
77    }
78
79    None
80}
81
#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write `content` to a fresh temporary file.
    ///
    /// The file is deleted when the returned handle is dropped, so keep the handle
    /// alive while the path is in use.
    fn write_temp_ini(content: &str) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(content.as_bytes()).unwrap();
        temp_file
    }

    #[test]
    fn test_get_mfa_serial_unknown_profile() {
        // This reads the real AWS config/credentials files, or the paths named by
        // AWS_CONFIG_FILE / AWS_SHARED_CREDENTIALS_FILE. It therefore only checks
        // that a profile absent from those files yields None. The profile name is
        // deliberately unlikely so it does not collide with a developer's profiles.
        let mfa_serial = get_mfa_serial_from_profile(Some("aws-mfa-session-test-nonexistent"));
        assert_eq!(mfa_serial, None);
    }

    #[test]
    fn test_extract_mfa_serial_with_ini_basic() {
        let content = r#"
[profile dev]
mfa_serial = arn:aws:iam::123456789012:mfa/dev-user
region = us-west-2

[prod]
mfa_serial = GAHT12345678
"#;
        let temp_file = write_temp_ini(content);
        let temp_path = temp_file.path().to_str().unwrap();

        let dev_mfa = extract_mfa_serial_with_ini(temp_path, "dev");
        let prod_mfa = extract_mfa_serial_with_ini(temp_path, "prod");

        assert_eq!(
            dev_mfa,
            Some("arn:aws:iam::123456789012:mfa/dev-user".to_string())
        );
        assert_eq!(prod_mfa, Some("GAHT12345678".to_string()));
    }

    #[test]
    fn test_extract_mfa_serial_with_ini_none() {
        // None of the target profiles defines mfa_serial.
        let content = r#"
[profile dev]
region = us-west-2

[prod]
region = us-east-1
"#;
        let temp_file = write_temp_ini(content);
        let temp_path = temp_file.path().to_str().unwrap();

        assert_eq!(extract_mfa_serial_with_ini(temp_path, "dev"), None);
        assert_eq!(extract_mfa_serial_with_ini(temp_path, "prod"), None);
    }

    #[test]
    fn test_extract_mfa_serial_with_ini_default_profile() {
        // The default profile may be written as either [default] or [profile default].
        let plain = write_temp_ini("[default]\nmfa_serial = plain-default\n");
        let prefixed = write_temp_ini("[profile default]\nmfa_serial = prefixed-default\n");

        assert_eq!(
            extract_mfa_serial_with_ini(plain.path().to_str().unwrap(), "default"),
            Some("plain-default".to_string())
        );
        assert_eq!(
            extract_mfa_serial_with_ini(prefixed.path().to_str().unwrap(), "default"),
            Some("prefixed-default".to_string())
        );
    }

    #[test]
    fn test_extract_mfa_serial_prefers_profile_section() {
        // When both layouts are present, [profile <name>] wins over [<name>].
        let temp_file = write_temp_ini(
            "[dev]\nmfa_serial = from-plain-section\n\n[profile dev]\nmfa_serial = from-profile-section\n",
        );
        let temp_path = temp_file.path().to_str().unwrap();

        assert_eq!(
            extract_mfa_serial_with_ini(temp_path, "dev"),
            Some("from-profile-section".to_string())
        );
    }

    #[test]
    fn test_extract_mfa_serial_with_ini_missing_file() {
        // A path that does not exist must yield None rather than panic.
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does-not-exist");

        assert_eq!(
            extract_mfa_serial_with_ini(missing.to_str().unwrap(), "dev"),
            None
        );
    }
}