1use crate::error::CliError;
2use aws_config::Region;
3use clap::Parser;
4use dialoguer::Input;
5
6pub fn region(s: &str) -> Result<Region, CliError> {
7 Ok(Region::new(s.to_owned()))
8}
9
10fn parse_code(s: &str) -> Result<String, CliError> {
11 if s.chars().all(char::is_numeric) && s.len() == 6 {
12 Ok(s.to_string())
13 } else {
14 Err(CliError::ValidationError(
15 "MFA code must be exactly 6 digits".to_string(),
16 ))
17 }
18}
19
20#[derive(Parser, Debug, Clone)]
21#[command(
22 name = "aws-mfa-session",
23 about = "AWS MFA session manager",
24 long_about = None,
25)]
26pub struct Args {
27 #[arg(long = "profile", short = 'p')]
29 pub profile: Option<String>,
30 #[arg(long = "credentials-file", short = 'f')]
32 pub credentials_file: Option<String>,
33 #[arg(long = "region", short = 'r', value_parser = region)]
35 pub region: Option<Region>,
36 #[arg(long = "code", short = 'c', value_parser = parse_code)]
38 pub code: Option<String>,
39 #[arg(long = "arn", short = 'a')]
41 pub arn: Option<String>,
42 #[arg(long = "duration", short = 'd', default_value_t = 3600, value_parser = clap::value_parser!(i32).range(900..129600))]
44 pub duration: i32,
45 #[arg(long = "shell", short = 's')]
47 pub shell: bool,
48 #[arg(long = "export", short = 'e')]
50 pub export: bool,
51 #[arg(long = "update-profile", short = 'u')]
53 pub session_profile: Option<String>,
54}
55
56impl Args {
57 pub fn get_code(&mut self) -> Result<(), CliError> {
58 self.code = match &self.code {
59 None => {
60 if cfg!(test) {
61 Some("123456".to_string())
62 } else {
63 Some(ask_code_interactive()?)
64 }
65 }
66 code => code.to_owned(),
67 };
68 Ok(())
69 }
70}
71
72fn ask_code_interactive() -> Result<String, CliError> {
73 let code: String = Input::new()
74 .with_prompt("Enter MFA code")
75 .interact_text()
76 .map_err(|e| CliError::ValidationError(e.to_string()))?;
77
78 parse_code(&code)
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use clap::CommandFactory;

    #[test]
    fn test_parse_code_valid() {
        assert_eq!(parse_code("123456").unwrap(), "123456");
        assert_eq!(parse_code("000000").unwrap(), "000000");
        assert_eq!(parse_code("999999").unwrap(), "999999");
    }

    #[test]
    fn test_parse_code_invalid_length() {
        assert!(parse_code("12345").is_err());
        assert!(parse_code("1234567").is_err());
        assert!(parse_code("").is_err());
    }

    #[test]
    fn test_parse_code_invalid_characters() {
        assert!(parse_code("12345a").is_err());
        assert!(parse_code("abcdef").is_err());
        assert!(parse_code("12-456").is_err());
        assert!(parse_code("123 56").is_err());
    }

    #[test]
    fn test_region_parsing() {
        let parsed_region = region("us-east-1").unwrap();
        assert_eq!(parsed_region.to_string(), "us-east-1");

        let parsed_region = region("eu-west-1").unwrap();
        assert_eq!(parsed_region.to_string(), "eu-west-1");
    }

    #[test]
    fn test_args_parsing_minimal() {
        let args = Args::try_parse_from(["aws-mfa-session", "--code", "123456"]);
        assert!(args.is_ok());
        let args = args.unwrap();
        assert_eq!(args.code, Some("123456".to_string()));
        // Defaults for every flag that was not given.
        assert_eq!(args.duration, 3600);
        assert!(!args.shell);
        assert!(!args.export);
        assert!(args.profile.is_none());
        assert!(args.credentials_file.is_none());
        assert!(args.region.is_none());
        assert!(args.arn.is_none());
        assert!(args.session_profile.is_none());
    }

    #[test]
    fn test_args_parsing_all_options() {
        let args = Args::try_parse_from([
            "aws-mfa-session",
            "--profile",
            "test-profile",
            "--credentials-file",
            "/custom/path/credentials",
            "--region",
            "us-west-2",
            "--code",
            "654321",
            "--arn",
            "arn:aws:iam::123456789012:mfa/test-user",
            "--duration",
            "7200",
            "-s",
            "-e",
            "--update-profile",
            "temp-session",
        ]);

        assert!(args.is_ok());
        let args = args.unwrap();
        assert_eq!(args.profile, Some("test-profile".to_string()));
        assert_eq!(
            args.credentials_file,
            Some("/custom/path/credentials".to_string())
        );
        assert_eq!(args.region.unwrap().to_string(), "us-west-2");
        assert_eq!(args.code, Some("654321".to_string()));
        assert_eq!(
            args.arn,
            Some("arn:aws:iam::123456789012:mfa/test-user".to_string())
        );
        assert_eq!(args.duration, 7200);
        assert!(args.shell);
        assert!(args.export);
        assert_eq!(args.session_profile, Some("temp-session".to_string()));
    }

    #[test]
    fn test_args_parsing_invalid_code() {
        let args = Args::try_parse_from(["aws-mfa-session", "--code", "12345"]);
        assert!(args.is_err());

        let args = Args::try_parse_from(["aws-mfa-session", "--code", "abcdef"]);
        assert!(args.is_err());
    }

    #[test]
    fn test_args_parsing_invalid_duration() {
        let args_low =
            Args::try_parse_from(["aws-mfa-session", "--code", "123456", "--duration", "800"]);
        assert!(args_low.is_err());

        let args_high = Args::try_parse_from([
            "aws-mfa-session",
            "--code",
            "123456",
            "--duration",
            "200000",
        ]);
        assert!(args_high.is_err());
    }

    #[test]
    fn test_args_parsing_missing_code() {
        // `--code` is optional at parse time; `get_code` fills it in later.
        let args = Args::try_parse_from(["aws-mfa-session"]);
        assert!(args.is_ok());
        assert_eq!(args.unwrap().code, None);
    }

    #[test]
    fn test_args_short_flags() {
        let args = Args::try_parse_from([
            "aws-mfa-session",
            "-p",
            "profile",
            "-f",
            "/path/to/file",
            "-r",
            "ap-southeast-1",
            "-c",
            "123456",
            "-a",
            "arn:aws:iam::123:mfa/user",
            "-d",
            "1800",
            "-s",
            "-e",
            "-u",
            "session",
        ]);

        assert!(args.is_ok());
        let args = args.unwrap();
        assert_eq!(args.profile, Some("profile".to_string()));
        assert_eq!(args.credentials_file, Some("/path/to/file".to_string()));
        assert_eq!(args.region.unwrap().to_string(), "ap-southeast-1");
        assert_eq!(args.code, Some("123456".to_string()));
        assert_eq!(args.arn, Some("arn:aws:iam::123:mfa/user".to_string()));
        assert_eq!(args.duration, 1800);
        assert!(args.shell);
        assert!(args.export);
        assert_eq!(args.session_profile, Some("session".to_string()));
    }

    #[test]
    fn test_command_structure() {
        let cmd = Args::command();
        assert_eq!(cmd.get_name(), "aws-mfa-session");
        assert!(cmd.get_about().is_some());
    }

    #[test]
    fn test_args_clone() {
        let args = Args::try_parse_from(["aws-mfa-session", "--code", "123456"]).unwrap();
        let cloned = args.clone();
        assert_eq!(args.code, cloned.code);
        assert_eq!(args.duration, cloned.duration);
    }

    #[test]
    fn test_args_debug() {
        let args = Args::try_parse_from(["aws-mfa-session", "--code", "123456"]).unwrap();
        let debug_str = format!("{args:?}");
        assert!(debug_str.contains("Args"));
        assert!(debug_str.contains("123456"));
    }

    #[test]
    fn test_get_code_with_existing_code() {
        let mut args = Args::try_parse_from(["aws-mfa-session", "--code", "654321"]).unwrap();
        assert_eq!(args.code, Some("654321".to_string()));

        // A code supplied on the command line must survive `get_code`.
        args.get_code().unwrap();
        assert_eq!(args.code, Some("654321".to_string()));
    }

    #[test]
    fn test_get_code_without_code_in_test_mode() {
        let mut args = Args::try_parse_from(["aws-mfa-session"]).unwrap();
        assert_eq!(args.code, None);

        // Under `cfg(test)`, `get_code` substitutes a fixed code instead of prompting.
        args.get_code().unwrap();
        assert_eq!(args.code, Some("123456".to_string()));
    }

    #[test]
    fn test_get_code_preserves_existing_valid_code() {
        let mut args = Args::try_parse_from(["aws-mfa-session", "--code", "999888"]).unwrap();
        let original_code = args.code.clone();

        args.get_code().unwrap();
        assert_eq!(args.code, original_code);
    }

    #[test]
    fn test_ask_code_interactive_validation() {
        // The prompt itself can't run in unit tests; this covers `parse_code`,
        // the validator `ask_code_interactive` delegates to.
        assert!(parse_code("123456").is_ok());
        assert!(parse_code("abcdef").is_err());
        assert!(parse_code("12345").is_err());
        assert!(parse_code("1234567").is_err());
    }
}