1mod args;
2mod credentials;
3mod error;
4mod profile;
5mod shell;
6
7pub use args::Args;
8use credentials::*;
9use error::CliError;
10pub use profile::get_mfa_serial_from_profile;
11use shell::Shell;
12
13use std::collections::HashMap;
14use std::env;
15use std::io;
16use std::process::Command;
17
18use aws_config::{BehaviorVersion, Region, meta::credentials::CredentialsProviderChain};
19use aws_sdk_iam::Client;
20use aws_sdk_sts::Client as StsClient;
21
/// Shell used by `run` when the `SHELL` environment variable is unset: it is
/// both the program spawned for the interactive session and the name passed to
/// `Shell::from` to choose the export syntax.
#[cfg(not(target_os = "windows"))]
const DEFAULT_SHELL: &str = "/bin/sh";

/// Shell used by `run` when the `SHELL` environment variable is unset: it is
/// both the program spawned for the interactive session and the name passed to
/// `Shell::from` to choose the export syntax.
#[cfg(target_os = "windows")]
const DEFAULT_SHELL: &str = "cmd.exe";

/// Environment variable that `run` sets from `opts.profile` to select a named
/// profile (presumably honoured by the AWS SDK and by tools in the spawned
/// shell, which inherits it — confirm).
const AWS_PROFILE: &str = "AWS_PROFILE";
/// Environment variable that `run` consults for the region after the SDK's
/// default region provider and before the hard-coded `us-east-1` fallback.
const AWS_DEFAULT_REGION: &str = "AWS_DEFAULT_REGION";

/// Environment variable that `run` sets from `opts.credentials_file` to point at
/// an alternative shared credentials file (presumably read by the SDK's profile
/// loader — confirm).
const AWS_SHARED_CREDENTIALS_FILE: &str = "AWS_SHARED_CREDENTIALS_FILE";
32
33pub async fn run(opts: Args) -> Result<(), CliError> {
34 if let Some(ref profile) = opts.profile {
36 unsafe {
39 env::set_var(AWS_PROFILE, profile);
40 }
41 }
42
43 if let Some(file) = opts.credentials_file {
44 unsafe {
47 env::set_var(AWS_SHARED_CREDENTIALS_FILE, file);
48 }
49 }
50
51 let region_provider =
52 aws_config::meta::region::RegionProviderChain::first_try(opts.region.clone())
53 .or_default_provider()
54 .or_else(env::var(AWS_DEFAULT_REGION).ok().map(Region::new))
55 .or_else(Region::new("us-east-1"));
56
57 let credentials_provider = CredentialsProviderChain::default_provider().await;
58 let shared_config = aws_config::defaults(BehaviorVersion::latest())
59 .region(region_provider)
60 .credentials_provider(credentials_provider)
61 .load()
62 .await;
63
64 let iam_client = Client::new(&shared_config);
65 let serial_number = match opts.arn {
66 None => {
67 if let Some(mfa_serial) = get_mfa_serial_from_profile(opts.profile.as_deref()) {
69 mfa_serial
70 } else {
71 let response = iam_client.list_mfa_devices().max_items(1).send().await?;
73 let mfa_devices = response.mfa_devices();
74 let serial = &mfa_devices.first().ok_or(CliError::NoMFA)?.serial_number();
75 (*serial).to_owned()
76 }
77 }
78 Some(other) => other,
79 };
80
81 let sts_client = StsClient::new(&shared_config);
82
83 let credentials = sts_client
84 .get_session_token()
85 .set_serial_number(Some(serial_number))
86 .token_code(
87 opts.code
88 .expect("MFA code should be available after get_code() call"),
89 )
90 .duration_seconds(opts.duration)
91 .send()
92 .await?
93 .credentials()
94 .map(ToOwned::to_owned)
95 .ok_or(CliError::NoCredentials)?;
96
97 let identity = sts_client.get_caller_identity().send().await?;
98
99 let user = iam_client
100 .get_user()
101 .send()
102 .await?
103 .user()
104 .map(ToOwned::to_owned)
105 .ok_or(CliError::NoAccount)?;
106
107 let account = identity.account.ok_or(CliError::NoAccount)?;
108 let ps = format!("AWS:{}@{} \\$ ", user.user_name(), account);
109 let shell = std::env::var("SHELL").unwrap_or_else(|_| DEFAULT_SHELL.to_owned());
110
111 if let Some(name) = opts.session_profile {
112 let c = credentials.clone();
113 let profile = Profile {
114 name,
115 access_key_id: c.access_key_id().to_owned(),
116 secret_access_key: c.secret_access_key().to_owned(),
117 session_token: Some(c.session_token().to_owned()),
118 region: opts.region.map(|r| r.to_string()),
119 };
120 update_credentials(&profile)?;
121 }
122
123 if opts.shell {
124 let c = credentials.clone();
125 let envs: HashMap<&str, String> = [
126 ("AWS_ACCESS_KEY_ID", c.access_key_id().to_owned()),
127 ("AWS_SECRET_ACCESS_KEY", c.secret_access_key().to_owned()),
128 ("AWS_SESSION_TOKEN", c.session_token().to_owned()),
129 ("PS1", ps.clone()),
130 ]
131 .iter()
132 .cloned()
133 .collect();
134
135 Command::new(shell.clone()).envs(envs).status()?;
136 }
137
138 if opts.export {
139 let mut stdout = io::stdout().lock();
140 Shell::from(shell.as_str()).export(
141 &mut stdout,
142 credentials.access_key_id(),
143 credentials.secret_access_key(),
144 credentials.session_token(),
145 &ps,
146 )?;
147 }
148
149 Ok(())
150}
151
#[cfg(test)]
mod tests {
    use super::{AWS_DEFAULT_REGION, AWS_PROFILE, AWS_SHARED_CREDENTIALS_FILE, DEFAULT_SHELL};

    /// `run` writes and reads these environment variables by name. Pin the
    /// spellings so a typo cannot silently make `--profile`,
    /// `--credentials-file` or the region fallback ineffective.
    #[test]
    fn env_var_names_are_the_standard_aws_names() {
        assert_eq!(AWS_PROFILE, "AWS_PROFILE");
        assert_eq!(AWS_DEFAULT_REGION, "AWS_DEFAULT_REGION");
        assert_eq!(AWS_SHARED_CREDENTIALS_FILE, "AWS_SHARED_CREDENTIALS_FILE");
    }

    /// The fallback shell (used when `SHELL` is unset) must match the platform.
    #[test]
    fn default_shell_matches_platform() {
        let expected = if cfg!(target_os = "windows") {
            "cmd.exe"
        } else {
            "/bin/sh"
        };
        assert_eq!(DEFAULT_SHELL, expected);
    }
}