aws_mfa_session/
args.rs

1use crate::error::CliError;
2use aws_config::Region;
3use clap::Parser;
4use dialoguer::Input;
5
6pub fn region(s: &str) -> Result<Region, CliError> {
7    Ok(Region::new(s.to_owned()))
8}
9
10fn parse_code(s: &str) -> Result<String, CliError> {
11    if s.chars().all(char::is_numeric) && s.len() == 6 {
12        Ok(s.to_string())
13    } else {
14        Err(CliError::ValidationError(
15            "MFA code must be exactly 6 digits".to_string(),
16        ))
17    }
18}
19
20#[derive(Parser, Debug, Clone)]
21#[command(
22    name = "aws-mfa-session",
23    about = "AWS MFA session manager",
24    long_about = None,
25)]
26pub struct Args {
27    /// AWS credential profile to use. AWS_PROFILE is used by default
28    #[arg(long = "profile", short = 'p')]
29    pub profile: Option<String>,
30    /// AWS credentials file location to use. AWS_SHARED_CREDENTIALS_FILE is used if not defined
31    #[arg(long = "credentials-file", short = 'f')]
32    pub credentials_file: Option<String>,
33    /// AWS region. AWS_REGION is used if not defined
34    #[arg(long = "region", short = 'r', value_parser = region)]
35    pub region: Option<Region>,
36    /// MFA code from MFA resource
37    #[arg(long = "code", short = 'c', value_parser = parse_code)]
38    pub code: Option<String>,
39    /// MFA device ARN. If not provided, will try to read mfa_serial from AWS profile configuration, then fall back to automatic detection
40    #[arg(long = "arn", short = 'a')]
41    pub arn: Option<String>,
42    /// Session duration in seconds (900-129600)
43    #[arg(long = "duration", short = 'd', default_value_t = 3600, value_parser = clap::value_parser!(i32).range(900..129600))]
44    pub duration: i32,
45    /// Run shell with AWS credentials as environment variables
46    #[arg(long = "shell", short = 's')]
47    pub shell: bool,
48    /// Print(export) AWS credentials as environment variables
49    #[arg(long = "export", short = 'e')]
50    pub export: bool,
51    /// Update AWS credential profile with temporary session credentials
52    #[arg(long = "update-profile", short = 'u')]
53    pub session_profile: Option<String>,
54}
55
56impl Args {
57    pub fn get_code(&mut self) -> Result<(), CliError> {
58        self.code = match &self.code {
59            None => {
60                if cfg!(test) {
61                    Some("123456".to_string())
62                } else {
63                    Some(ask_code_interactive()?)
64                }
65            }
66            code => code.to_owned(),
67        };
68        Ok(())
69    }
70}
71
72fn ask_code_interactive() -> Result<String, CliError> {
73    let code: String = Input::new()
74        .with_prompt("Enter MFA code")
75        .interact_text()
76        .map_err(|e| CliError::ValidationError(e.to_string()))?;
77
78    parse_code(&code)
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;

    /// Well-formed six-digit codes are returned unchanged.
    #[test]
    fn test_parse_code_valid() {
        for code in ["123456", "000000", "999999"] {
            assert_eq!(parse_code(code).unwrap(), code);
        }
    }

    /// Codes that are too short, too long, or empty are rejected.
    #[test]
    fn test_parse_code_invalid_length() {
        for code in ["12345", "1234567", ""] {
            assert!(parse_code(code).is_err(), "{code:?} should be rejected");
        }
    }

    /// Six-character inputs containing anything but digits are rejected.
    #[test]
    fn test_parse_code_invalid_characters() {
        for code in ["12345a", "abcdef", "12-456", "123 56"] {
            assert!(parse_code(code).is_err(), "{code:?} should be rejected");
        }
    }

    /// The region parser keeps the given name intact.
    #[test]
    fn test_region_parsing() {
        for name in ["us-east-1", "eu-west-1"] {
            let parsed = region(name).unwrap();
            assert_eq!(parsed.to_string(), name);
        }
    }

    /// Only `--code` given: everything else takes its default.
    #[test]
    fn test_args_parsing_minimal() {
        let parsed = Args::try_parse_from(["aws-mfa-session", "--code", "123456"]).unwrap();

        assert_eq!(parsed.code.as_deref(), Some("123456"));
        assert_eq!(parsed.duration, 3600); // default
        assert!(!parsed.shell);
        assert!(!parsed.export);
    }

    /// Every option supplied, mostly through long flags.
    #[test]
    fn test_args_parsing_all_options() {
        let argv = [
            "aws-mfa-session",
            "--profile",
            "test-profile",
            "--credentials-file",
            "/custom/path/credentials",
            "--region",
            "us-west-2",
            "--code",
            "654321",
            "--arn",
            "arn:aws:iam::123456789012:mfa/test-user",
            "--duration",
            "7200",
            "-s",
            "-e",
            "--update-profile",
            "temp-session",
        ];
        let parsed = Args::try_parse_from(argv).unwrap();

        assert_eq!(parsed.profile.as_deref(), Some("test-profile"));
        assert_eq!(
            parsed.credentials_file.as_deref(),
            Some("/custom/path/credentials")
        );
        assert_eq!(parsed.region.as_ref().unwrap().to_string(), "us-west-2");
        assert_eq!(parsed.code.as_deref(), Some("654321"));
        assert_eq!(
            parsed.arn.as_deref(),
            Some("arn:aws:iam::123456789012:mfa/test-user")
        );
        assert_eq!(parsed.duration, 7200);
        assert!(parsed.shell);
        assert!(parsed.export);
        assert_eq!(parsed.session_profile.as_deref(), Some("temp-session"));
    }

    /// Malformed `--code` values make argument parsing fail.
    #[test]
    fn test_args_parsing_invalid_code() {
        for bad_code in ["12345", "abcdef"] {
            let outcome = Args::try_parse_from(["aws-mfa-session", "--code", bad_code]);
            assert!(outcome.is_err(), "--code {bad_code:?} should be rejected");
        }
    }

    /// Durations below and above the allowed range make argument parsing fail.
    #[test]
    fn test_args_parsing_invalid_duration() {
        for bad_duration in ["800", "200000"] {
            let outcome = Args::try_parse_from([
                "aws-mfa-session",
                "--code",
                "123456",
                "--duration",
                bad_duration,
            ]);
            assert!(
                outcome.is_err(),
                "--duration {bad_duration} should be rejected"
            );
        }
    }

    /// `--code` is optional at parse time, so omitting it must succeed.
    #[test]
    fn test_args_parsing_missing_code() {
        assert!(Args::try_parse_from(["aws-mfa-session"]).is_ok());
    }

    /// Every option supplied through its short flag.
    #[test]
    fn test_args_short_flags() {
        let argv = [
            "aws-mfa-session",
            "-p",
            "profile",
            "-f",
            "/path/to/file",
            "-r",
            "ap-southeast-1",
            "-c",
            "123456",
            "-a",
            "arn:aws:iam::123:mfa/user",
            "-d",
            "1800",
            "-s",
            "-e",
            "-u",
            "session",
        ];
        let parsed = Args::try_parse_from(argv).unwrap();

        assert_eq!(parsed.profile.as_deref(), Some("profile"));
        assert_eq!(parsed.credentials_file.as_deref(), Some("/path/to/file"));
        assert_eq!(
            parsed.region.as_ref().unwrap().to_string(),
            "ap-southeast-1"
        );
        assert_eq!(parsed.code.as_deref(), Some("123456"));
        assert_eq!(parsed.arn.as_deref(), Some("arn:aws:iam::123:mfa/user"));
        assert_eq!(parsed.duration, 1800);
        assert!(parsed.shell);
        assert!(parsed.export);
        assert_eq!(parsed.session_profile.as_deref(), Some("session"));
    }

    /// The generated clap command carries the configured name and an about text.
    #[test]
    fn test_command_structure() {
        let command = Args::command();
        assert_eq!(command.get_name(), "aws-mfa-session");
        assert!(command.get_about().is_some());
    }

    /// A clone carries the same code and duration as the value it was cloned from.
    #[test]
    fn test_args_clone() {
        let original = Args::try_parse_from(["aws-mfa-session", "--code", "123456"]).unwrap();
        let copy = original.clone();

        assert_eq!(original.code, copy.code);
        assert_eq!(original.duration, copy.duration);
    }

    /// The `Debug` output names the struct and includes the code value.
    #[test]
    fn test_args_debug() {
        let parsed = Args::try_parse_from(["aws-mfa-session", "--code", "123456"]).unwrap();
        let rendered = format!("{parsed:?}");

        for needle in ["Args", "123456"] {
            assert!(rendered.contains(needle));
        }
    }

    /// `get_code` leaves a code given on the command line untouched.
    #[test]
    fn test_get_code_with_existing_code() {
        let mut parsed = Args::try_parse_from(["aws-mfa-session", "--code", "654321"]).unwrap();
        assert_eq!(parsed.code.as_deref(), Some("654321"));

        parsed.get_code().unwrap();
        assert_eq!(parsed.code.as_deref(), Some("654321"));
    }

    /// Without `--code`, `get_code` fills in the fixed test-build value.
    #[test]
    fn test_get_code_without_code_in_test_mode() {
        let mut parsed = Args::try_parse_from(["aws-mfa-session"]).unwrap();
        assert!(parsed.code.is_none());

        // In test builds no prompt is shown; "123456" is stored instead.
        parsed.get_code().unwrap();
        assert_eq!(parsed.code.as_deref(), Some("123456"));
    }

    /// The code before and after `get_code` is the same when one was supplied.
    #[test]
    fn test_get_code_preserves_existing_valid_code() {
        let mut parsed = Args::try_parse_from(["aws-mfa-session", "--code", "999888"]).unwrap();
        let before = parsed.code.clone();

        parsed.get_code().unwrap();
        assert_eq!(before, parsed.code);
    }

    /// Exercises `parse_code`, the validator applied to interactively entered codes.
    /// The prompt itself is not driven by this test.
    #[test]
    fn test_ask_code_interactive_validation() {
        assert!(parse_code("123456").is_ok());

        for rejected in ["abcdef", "12345", "1234567"] {
            assert!(parse_code(rejected).is_err());
        }
    }
}