aws_mfa_session/
lib.rs

1mod args;
2mod credentials;
3mod error;
4mod profile;
5mod shell;
6
7pub use args::Args;
8use credentials::*;
9use error::CliError;
10pub use profile::get_mfa_serial_from_profile;
11use shell::Shell;
12
13use std::collections::HashMap;
14use std::env;
15use std::io;
16use std::process::Command;
17
18use aws_config::{BehaviorVersion, Region, meta::credentials::CredentialsProviderChain};
19use aws_sdk_iam::Client;
20use aws_sdk_sts::Client as StsClient;
21
/// Shell to fall back on when the `SHELL` environment variable cannot be read:
/// `cmd.exe` on Windows targets, `/bin/sh` everywhere else.
const DEFAULT_SHELL: &str = if cfg!(target_os = "windows") {
    "cmd.exe"
} else {
    "/bin/sh"
};

/// Name of the environment variable that `run` sets from `opts.profile`.
const AWS_PROFILE: &str = "AWS_PROFILE";

/// Name of the environment variable `run` consults as a region fallback.
const AWS_DEFAULT_REGION: &str = "AWS_DEFAULT_REGION";

/// Name of the environment variable that `run` sets from `opts.credentials_file`.
const AWS_SHARED_CREDENTIALS_FILE: &str = "AWS_SHARED_CREDENTIALS_FILE";
32
33pub async fn run(opts: Args) -> Result<(), CliError> {
34    // ProfileProvider is limited, but AWS_PROFILE is used elsewhere
35    if let Some(ref profile) = opts.profile {
36        // SAFETY: Setting AWS_PROFILE environment variable is safe in this single-threaded context
37        // and doesn't interfere with other parts of the application
38        unsafe {
39            env::set_var(AWS_PROFILE, profile);
40        }
41    }
42
43    if let Some(file) = opts.credentials_file {
44        // SAFETY: Setting AWS_SHARED_CREDENTIALS_FILE environment variable is safe in this
45        // single-threaded context and doesn't interfere with other parts of the application
46        unsafe {
47            env::set_var(AWS_SHARED_CREDENTIALS_FILE, file);
48        }
49    }
50
51    let region_provider =
52        aws_config::meta::region::RegionProviderChain::first_try(opts.region.clone())
53            .or_default_provider()
54            .or_else(env::var(AWS_DEFAULT_REGION).ok().map(Region::new))
55            .or_else(Region::new("us-east-1"));
56
57    let credentials_provider = CredentialsProviderChain::default_provider().await;
58    let shared_config = aws_config::defaults(BehaviorVersion::latest())
59        .region(region_provider)
60        .credentials_provider(credentials_provider)
61        .load()
62        .await;
63
64    let iam_client = Client::new(&shared_config);
65    let serial_number = match opts.arn {
66        None => {
67            // First, try to get mfa_serial from profile configuration
68            if let Some(mfa_serial) = get_mfa_serial_from_profile(opts.profile.as_deref()) {
69                mfa_serial
70            } else {
71                // Fallback to automatic MFA device detection
72                let response = iam_client.list_mfa_devices().max_items(1).send().await?;
73                let mfa_devices = response.mfa_devices();
74                let serial = &mfa_devices.first().ok_or(CliError::NoMFA)?.serial_number();
75                (*serial).to_owned()
76            }
77        }
78        Some(other) => other,
79    };
80
81    let sts_client = StsClient::new(&shared_config);
82
83    let credentials = sts_client
84        .get_session_token()
85        .set_serial_number(Some(serial_number))
86        .token_code(
87            opts.code
88                .expect("MFA code should be available after get_code() call"),
89        )
90        .duration_seconds(opts.duration)
91        .send()
92        .await?
93        .credentials()
94        .map(ToOwned::to_owned)
95        .ok_or(CliError::NoCredentials)?;
96
97    let identity = sts_client.get_caller_identity().send().await?;
98
99    let user = iam_client
100        .get_user()
101        .send()
102        .await?
103        .user()
104        .map(ToOwned::to_owned)
105        .ok_or(CliError::NoAccount)?;
106
107    let account = identity.account.ok_or(CliError::NoAccount)?;
108    let ps = format!("AWS:{}@{} \\$ ", user.user_name(), account);
109    let shell = std::env::var("SHELL").unwrap_or_else(|_| DEFAULT_SHELL.to_owned());
110
111    if let Some(name) = opts.session_profile {
112        let c = credentials.clone();
113        let profile = Profile {
114            name,
115            access_key_id: c.access_key_id().to_owned(),
116            secret_access_key: c.secret_access_key().to_owned(),
117            session_token: Some(c.session_token().to_owned()),
118            region: opts.region.map(|r| r.to_string()),
119        };
120        update_credentials(&profile)?;
121    }
122
123    if opts.shell {
124        let c = credentials.clone();
125        let envs: HashMap<&str, String> = [
126            ("AWS_ACCESS_KEY_ID", c.access_key_id().to_owned()),
127            ("AWS_SECRET_ACCESS_KEY", c.secret_access_key().to_owned()),
128            ("AWS_SESSION_TOKEN", c.session_token().to_owned()),
129            ("PS1", ps.clone()),
130        ]
131        .iter()
132        .cloned()
133        .collect();
134
135        Command::new(shell.clone()).envs(envs).status()?;
136    }
137
138    if opts.export {
139        let mut stdout = io::stdout().lock();
140        Shell::from(shell.as_str()).export(
141            &mut stdout,
142            credentials.access_key_id(),
143            credentials.secret_access_key(),
144            credentials.session_token(),
145            &ps,
146        )?;
147    }
148
149    Ok(())
150}
151
#[cfg(test)]
mod tests {

    /// Mirrors the `Option` guards that sit in front of the `env::set_var` calls in
    /// `run()`.
    ///
    /// NOTE: this neither calls `run()` nor touches the process environment; it only
    /// confirms that `Some` values satisfy the guard and `None` values do not.
    #[test]
    fn test_env_var_setting_logic() {
        // Same condition `if let Some(..)` in run() relies on.
        let would_set_env = |value: &Option<String>| value.is_some();

        // Present values: these would trigger env::set_var in run().
        let present = [
            Some(String::from("test-profile")),
            Some(String::from("/test/credentials")),
        ];
        for value in &present {
            assert!(would_set_env(value));
        }

        // Absent values: these would NOT trigger env::set_var in run().
        let absent: [Option<String>; 2] = [None, None];
        for value in &absent {
            assert!(!would_set_env(value));
        }
    }
}
175}