1use std::collections::HashMap;
4use std::fs;
5use std::fs::File;
6use std::io::{BufRead, BufReader};
7use std::path::{Path, PathBuf};
8
9use async_trait::async_trait;
10use dirs_next::home_dir;
11use serde::Deserialize;
12use tokio::process::Command;
13
14use crate::{non_empty_env_var, AwsCredentials, CredentialsError, ProvideAwsCredentials};
15
// Environment variables consulted when locating AWS files and choosing a profile.

/// Environment variable naming the profile to use when none is given explicitly.
const AWS_PROFILE: &str = "AWS_PROFILE";
/// Environment variable overriding the location of the AWS config file.
const AWS_CONFIG_FILE: &str = "AWS_CONFIG_FILE";
/// Environment variable overriding the location of the shared credentials file.
const AWS_SHARED_CREDENTIALS_FILE: &str = "AWS_SHARED_CREDENTIALS_FILE";

// Names used inside the profile files themselves.

/// Profile name used when `AWS_PROFILE` is unset or empty.
const DEFAULT: &str = "default";
/// Config-file key holding a profile's region.
const REGION: &str = "region";
21
/// Credentials provider backed by AWS profile files.
///
/// `credentials()` first checks the AWS config file for a `credential_process` command and
/// runs it if one is configured; otherwise it reads the selected profile from the
/// credentials file at `file_path`.
#[derive(Clone, Debug)]
pub struct ProfileProvider {
    /// File read for static credentials by `credentials()`; `region_from_profile()` parses
    /// this same path with the config-file parser.
    file_path: PathBuf,
    /// Profile (section) name looked up in that file.
    profile: String,
}
40
41impl ProfileProvider {
42    pub fn new() -> Result<ProfileProvider, CredentialsError> {
44        let profile_location = ProfileProvider::default_profile_location()?;
45        Ok(ProfileProvider::with_default_configuration(
46            profile_location,
47        ))
48    }
49
50    pub fn with_configuration<F, P>(file_path: F, profile: P) -> ProfileProvider
53    where
54        F: Into<PathBuf>,
55        P: Into<String>,
56    {
57        ProfileProvider {
58            file_path: file_path.into(),
59            profile: profile.into(),
60        }
61    }
62
63    pub fn with_default_configuration<F>(file_path: F) -> ProfileProvider
67    where
68        F: Into<PathBuf>,
69    {
70        ProfileProvider::with_configuration(file_path, ProfileProvider::default_profile_name())
71    }
72
73    pub fn with_default_credentials<P>(profile: P) -> Result<ProfileProvider, CredentialsError>
76        where
77            P: Into<String>,
78    {
79        let profile_location = ProfileProvider::default_profile_location()?;
80        Ok(ProfileProvider {
81            file_path: profile_location.into(),
82            profile: profile.into(),
83        })
84    }
85
86    pub fn region() -> Result<Option<String>, CredentialsError> {
92        let location = ProfileProvider::default_config_location();
93        location.map(|location| {
94            parse_config_file(&location).and_then(|config| {
95                config
96                    .get(&ProfileProvider::default_profile_name())
97                    .and_then(|props| props.get(REGION))
98                    .map(std::borrow::ToOwned::to_owned)
99            })
100        })
101    }
102
103    pub fn region_from_profile(&self) -> Result<Option<String>, CredentialsError> {
107        Ok(
108            parse_config_file(&self.file_path).and_then(|config| {
109                config
110                    .get(&self.profile)
111                    .and_then(|props| props.get(REGION))
112                    .map(std::borrow::ToOwned::to_owned)
113            }))
114    }
115
116    fn default_config_location() -> Result<PathBuf, CredentialsError> {
120        let env = non_empty_env_var(AWS_CONFIG_FILE);
121        match env {
122            Some(path) => Ok(PathBuf::from(path)),
123            None => ProfileProvider::hardcoded_config_location(),
124        }
125    }
126
127    fn hardcoded_config_location() -> Result<PathBuf, CredentialsError> {
128        match home_dir() {
129            Some(mut home_path) => {
130                home_path.push(".aws");
131                home_path.push("config");
132                Ok(home_path)
133            }
134            None => Err(CredentialsError::new("Failed to determine home directory.")),
135        }
136    }
137
138    fn default_profile_location() -> Result<PathBuf, CredentialsError> {
142        let env = non_empty_env_var(AWS_SHARED_CREDENTIALS_FILE);
143        match env {
144            Some(path) => Ok(PathBuf::from(path)),
145            None => ProfileProvider::hardcoded_profile_location(),
146        }
147    }
148
149    fn hardcoded_profile_location() -> Result<PathBuf, CredentialsError> {
150        match home_dir() {
151            Some(mut home_path) => {
152                home_path.push(".aws");
153                home_path.push("credentials");
154                Ok(home_path)
155            }
156            None => Err(CredentialsError::new("Failed to determine home directory.")),
157        }
158    }
159
160    fn default_profile_name() -> String {
165        non_empty_env_var(AWS_PROFILE).unwrap_or_else(|| DEFAULT.to_owned())
166    }
167
168    pub fn file_path(&self) -> &Path {
170        self.file_path.as_ref()
171    }
172
173    pub fn profile(&self) -> &str {
175        &self.profile
176    }
177
178    pub fn set_file_path<F>(&mut self, file_path: F)
180    where
181        F: Into<PathBuf>,
182    {
183        self.file_path = file_path.into();
184    }
185
186    pub fn set_profile<P>(&mut self, profile: P)
188    where
189        P: Into<String>,
190    {
191        self.profile = profile.into();
192    }
193}
194
195#[async_trait]
196impl ProvideAwsCredentials for ProfileProvider {
197    async fn credentials(&self) -> Result<AwsCredentials, CredentialsError> {
198        match ProfileProvider::default_config_location().map(|location| {
199            parse_config_file(&location).and_then(|config| {
200                config
201                    .get(&ProfileProvider::default_profile_name())
202                    .and_then(|props| props.get("credential_process"))
203                    .map(std::borrow::ToOwned::to_owned)
204            })
205        }) {
206            Ok(Some(command)) => {
207                let mut command = parse_command_str(&command)?;
209                let output = command.output().await.map_err(|e| {
210                    CredentialsError::new(format!("Credential process failed: {:?}", e))
211                })?;
212                if output.status.success() {
213                    parse_credential_process_output(&output.stdout)
214                } else {
215                    Err(CredentialsError::new(format!(
216                        "Credential process failed with {}: {}",
217                        output.status,
218                        String::from_utf8_lossy(&output.stderr)
219                    )))
220                }
221            }
222            Ok(None) => {
223                parse_credentials_file(self.file_path()).and_then(|mut profiles| {
225                    profiles
226                        .remove(self.profile())
227                        .ok_or_else(|| CredentialsError::new("profile not found"))
228                })
229            }
230            Err(err) => Err(err),
231        }
232    }
233}
234
/// JSON document printed on stdout by a `credential_process` command.
#[derive(Deserialize)]
struct CredentialProcessOutput {
    /// Credential fields, read from the same JSON object via `#[serde(flatten)]`; the exact
    /// field names come from `AwsCredentials`'s `Deserialize` impl (not visible here).
    #[serde(flatten)]
    creds: AwsCredentials,
    /// Format version from the `Version` key; `parse_credential_process_output` accepts
    /// only `1`.
    #[serde(rename = "Version")]
    version: u8,
}
242
243fn parse_credential_process_output(v: &[u8]) -> Result<AwsCredentials, CredentialsError> {
244    let output: CredentialProcessOutput = serde_json::from_slice(v)?;
245    if output.version == 1 {
246        Ok(output.creds)
247    } else {
248        Err(CredentialsError::new(format!(
249            "Unsupported version '{}' for credential process provider, supported versions: 1",
250            output.version
251        )))
252    }
253}
254
/// Extracts the profile name from an INI section-header line.
///
/// Accepts both credentials-file headers (`[name]`) and config-file headers
/// (`[profile name]`). Whitespace around the line, just inside the brackets, and between
/// the `profile` keyword and the name is ignored. Returns `None` when the line is not a
/// bracketed header.
fn parse_profile_name(line: &str) -> Option<&str> {
    let inner = line.trim().strip_prefix('[')?.strip_suffix(']')?.trim();
    // Only treat `profile` as a keyword when whitespace follows it, so a section literally
    // named e.g. `profilex` keeps its full name.
    let name = match inner.strip_prefix("profile") {
        Some(rest) if rest.starts_with(char::is_whitespace) => rest.trim(),
        _ => inner,
    };
    Some(name)
}
270
271fn parse_config_file(file_path: &Path) -> Option<HashMap<String, HashMap<String, String>>> {
272    match fs::metadata(file_path) {
273        Err(_) => return None,
274        Ok(metadata) => {
275            if !metadata.is_file() {
276                return None;
277            }
278        }
279    };
280    let file = File::open(file_path).expect("expected file");
281    let file_lines = BufReader::new(&file);
282    let result: (HashMap<String, HashMap<String, String>>, Option<String>) = file_lines
283        .lines()
284        .filter_map(|line| {
285            line.ok()
286                .map(|l| l.trim_matches(' ').to_owned())
287                .into_iter()
288                .find(|l| !l.starts_with('#') && !l.is_empty())
289        })
290        .fold(Default::default(), |(mut result, profile), line| {
291            if let Some(next_profile) = parse_profile_name(&line) {
292                (result, Some(next_profile.to_owned()))
293            } else {
294                match &line
295                    .splitn(2, '=')
296                    .map(|value| value.trim_matches(' '))
297                    .collect::<Vec<&str>>()[..]
298                {
299                    [key, value] if !key.is_empty() && !value.is_empty() => {
300                        if let Some(current) = profile.clone() {
301                            let values = result.entry(current).or_insert_with(HashMap::new);
302                            (*values).insert((*key).to_string(), (*value).to_string());
303                        }
304                        (result, profile)
305                    }
306                    _ => (result, profile),
307                }
308            }
309        });
310    Some(result.0)
311}
312
313fn parse_credentials_file(
315    file_path: &Path,
316) -> Result<HashMap<String, AwsCredentials>, CredentialsError> {
317    match fs::metadata(file_path) {
318        Err(_) => {
319            return Err(CredentialsError::new(format!(
320                "Couldn't stat credentials file: [ {:?} ]. Non existant, or no permission.",
321                file_path
322            )));
323        }
324        Ok(metadata) => {
325            if !metadata.is_file() {
326                return Err(CredentialsError::new(format!(
327                    "Credentials file: [ {:?} ] is not a file.",
328                    file_path
329                )));
330            }
331        }
332    };
333
334    let file = File::open(file_path)?;
335
336    let mut profiles: HashMap<String, AwsCredentials> = HashMap::new();
337    let mut access_key: Option<String> = None;
338    let mut secret_key: Option<String> = None;
339    let mut token: Option<String> = None;
340    let mut profile_name: Option<String> = None;
341
342    let file_lines = BufReader::new(&file);
343    for (line_no, line) in file_lines.lines().enumerate() {
344        let unwrapped_line: String =
345            line.unwrap_or_else(|_| panic!("Failed to read credentials file, line: {}", line_no));
346
347        if unwrapped_line.is_empty() {
349            continue;
350        }
351
352        if unwrapped_line.starts_with('#') {
354            continue;
355        }
356
357        if let Some(new_profile_name) = parse_profile_name(&unwrapped_line) {
359            if let (Some(profile), Some(access), Some(secret)) =
360                (profile_name, access_key, secret_key)
361            {
362                let creds = AwsCredentials::new(access, secret, token, None);
363                profiles.insert(profile, creds);
364            }
365
366            access_key = None;
367            secret_key = None;
368            token = None;
369
370            profile_name = Some(new_profile_name.to_owned());
371            continue;
372        }
373
374        let lower_case_line = unwrapped_line.to_ascii_lowercase().to_string();
376
377        if lower_case_line.contains("aws_access_key_id") && access_key.is_none() {
378            let v: Vec<&str> = unwrapped_line.split('=').collect();
379            if !v.is_empty() {
380                access_key = Some(v[1].trim_matches(' ').to_string());
381            }
382        } else if lower_case_line.contains("aws_secret_access_key") && secret_key.is_none() {
383            let v: Vec<&str> = unwrapped_line.split('=').collect();
384            if !v.is_empty() {
385                secret_key = Some(v[1].trim_matches(' ').to_string());
386            }
387        } else if lower_case_line.contains("aws_session_token") && token.is_none() {
388            let v: Vec<&str> = unwrapped_line.split('=').collect();
389            if !v.is_empty() {
390                token = Some(v[1].trim_matches(' ').to_string());
391            }
392        } else if lower_case_line.contains("aws_security_token") {
393            if token.is_none() {
394                let v: Vec<&str> = unwrapped_line.split('=').collect();
395                if !v.is_empty() {
396                    token = Some(v[1].trim_matches(' ').to_string());
397                }
398            }
399        } else {
400            continue;
402        }
403    }
404
405    if let (Some(profile), Some(access), Some(secret)) = (profile_name, access_key, secret_key) {
406        let creds = AwsCredentials::new(access, secret, token, None);
407        profiles.insert(profile, creds);
408    }
409
410    if profiles.is_empty() {
411        return Err(CredentialsError::new("No credentials found."));
412    }
413
414    Ok(profiles)
415}
416
417fn parse_command_str(s: &str) -> Result<Command, CredentialsError> {
418    let args = shlex::split(s)
419        .ok_or_else(|| CredentialsError::new("Unable to parse credential_process value."))?;
420    let mut iter = args.iter();
421    let mut command = Command::new(
422        iter.next()
423            .ok_or_else(|| CredentialsError::new("credential_process value is empty."))?,
424    );
425    command.args(iter);
426    Ok(command)
427}
428
// Unit tests for the config/credentials parsers and `ProfileProvider`.
// Fixtures are read from `tests/sample-data/` relative to the working directory. Tests that
// read or modify process environment variables hold `lock_env()`. Presumably that guard
// serializes environment access across tests; confirm in `crate::test_utils`.
#[cfg(test)]
mod tests {
    use std::env;
    use std::path::Path;

    use super::*;
    use crate::test_utils::lock_env;
    use crate::{CredentialsError, ProvideAwsCredentials};

    #[test]
    fn parse_config_file_default_profile() {
        let result = super::parse_config_file(Path::new("tests/sample-data/default_config"));
        assert!(result.is_some());
        let profiles = result.unwrap();
        assert_eq!(profiles.len(), 1);
        let default_profile = profiles
            .get(DEFAULT)
            .expect("No Default profile in default_profile_credentials");
        assert_eq!(default_profile.get(REGION), Some(&"us-east-2".to_string()));
        assert_eq!(default_profile.get("output"), Some(&"json".to_string()));
    }

    #[test]
    fn parse_config_file_multiple_profiles() {
        let result =
            super::parse_config_file(Path::new("tests/sample-data/multiple_profile_config"));
        assert!(result.is_some());

        let profiles = result.unwrap();
        assert_eq!(profiles.len(), 3);

        let foo_profile = profiles
            .get("foo")
            .expect("No foo profile in multiple_profile_credentials");
        assert_eq!(foo_profile.get(REGION), Some(&"us-east-3".to_string()));
        assert_eq!(foo_profile.get("output"), Some(&"json".to_string()));

        let bar_profile = profiles
            .get("bar")
            .expect("No bar profile in multiple_profile_credentials");
        assert_eq!(bar_profile.get(REGION), Some(&"us-east-4".to_string()));
        assert_eq!(bar_profile.get("output"), Some(&"json".to_string()));
        // Comment lines must not leak into the parsed key/value map.
        assert_eq!(bar_profile.get("# comments"), None);
    }

    #[test]
    fn parse_config_file_credential_process() {
        let result =
            super::parse_config_file(Path::new("tests/sample-data/credential_process_config"));
        assert!(result.is_some());
        let profiles = result.unwrap();
        assert_eq!(profiles.len(), 2);
        let default_profile = profiles
            .get(DEFAULT)
            .expect("No Default profile in default_profile_credentials");
        assert_eq!(default_profile.get(REGION), Some(&"us-east-2".to_string()));
        assert_eq!(
            default_profile.get("credential_process"),
            Some(&"cat tests/sample-data/credential_process_sample_response".to_string())
        );
    }

    #[test]
    fn parse_credentials_file_default_profile() {
        let result = super::parse_credentials_file(Path::new(
            "tests/sample-data/default_profile_credentials",
        ));
        assert!(result.is_ok());

        let profiles = result.ok().unwrap();
        assert_eq!(profiles.len(), 1);

        let default_profile = profiles
            .get(DEFAULT)
            .expect("No Default profile in default_profile_credentials");
        assert_eq!(default_profile.aws_access_key_id(), "foo");
        assert_eq!(default_profile.aws_secret_access_key(), "bar");
    }

    #[test]
    fn parse_credentials_file_multiple_profiles() {
        let result = super::parse_credentials_file(Path::new(
            "tests/sample-data/multiple_profile_credentials",
        ));
        assert!(result.is_ok());

        let profiles = result.ok().unwrap();
        assert_eq!(profiles.len(), 2);

        let foo_profile = profiles
            .get("foo")
            .expect("No foo profile in multiple_profile_credentials");
        assert_eq!(foo_profile.aws_access_key_id(), "foo_access_key");
        assert_eq!(foo_profile.aws_secret_access_key(), "foo_secret_key");

        let bar_profile = profiles
            .get("bar")
            .expect("No bar profile in multiple_profile_credentials");
        assert_eq!(bar_profile.aws_access_key_id(), "bar_access_key");
        assert_eq!(bar_profile.aws_secret_access_key(), "bar_secret_key");
    }

    #[test]
    fn parse_all_values_credentials_file() {
        let result =
            super::parse_credentials_file(Path::new("tests/sample-data/full_profile_credentials"));
        assert!(result.is_ok());

        let profiles = result.ok().unwrap();
        assert_eq!(profiles.len(), 1);

        let default_profile = profiles
            .get(DEFAULT)
            .expect("No default profile in full_profile_credentials");
        assert_eq!(default_profile.aws_access_key_id(), "foo");
        assert_eq!(default_profile.aws_secret_access_key(), "bar");
    }

    #[tokio::test]
    async fn profile_provider_happy_path() {
        let _guard = lock_env();
        let provider = ProfileProvider::with_configuration(
            "tests/sample-data/multiple_profile_credentials",
            "foo",
        );
        let result = provider.credentials().await;

        assert!(result.is_ok());

        let creds = result.ok().unwrap();
        assert_eq!(creds.aws_access_key_id(), "foo_access_key");
        assert_eq!(creds.aws_secret_access_key(), "foo_secret_key");
    }

    #[test]
    fn profile_provider_via_environment_variable() {
        let _guard = lock_env();
        let credentials_path = "tests/sample-data/default_profile_credentials";
        env::set_var(AWS_SHARED_CREDENTIALS_FILE, credentials_path);
        let result = ProfileProvider::new();
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.file_path().to_str().unwrap(), credentials_path);
        // NOTE(review): not restored if an assertion above fails.
        env::remove_var(AWS_SHARED_CREDENTIALS_FILE);
    }

    #[tokio::test]
    async fn profile_provider_profile_name_via_environment_variable() {
        let _guard = lock_env();
        let credentials_path = "tests/sample-data/multiple_profile_credentials";
        env::set_var(AWS_SHARED_CREDENTIALS_FILE, credentials_path);
        env::set_var(AWS_PROFILE, "bar");
        let result = ProfileProvider::new();
        assert!(result.is_ok());
        let provider = result.unwrap();
        assert_eq!(provider.file_path().to_str().unwrap(), credentials_path);
        let creds = provider.credentials().await;
        assert_eq!(creds.unwrap().aws_access_key_id(), "bar_access_key");
        env::remove_var(AWS_SHARED_CREDENTIALS_FILE);
        env::remove_var(AWS_PROFILE);
    }

    #[tokio::test]
    async fn profile_provider_bad_profile() {
        let _guard = lock_env();
        let provider = ProfileProvider::with_configuration(
            "tests/sample-data/multiple_profile_credentials",
            "not_a_profile",
        );
        let result = provider.credentials().await;

        assert!(result.is_err());
        assert_eq!(
            result.err(),
            Some(CredentialsError::new("profile not found"))
        );
    }

    // The fixture config's default profile sets `credential_process` to `cat` a sample
    // response file (see `parse_config_file_credential_process`), so this spawns `cat`.
    #[tokio::test]
    async fn profile_provider_credential_process() {
        let _guard = lock_env();
        env::set_var(
            AWS_CONFIG_FILE,
            "tests/sample-data/credential_process_config",
        );
        let provider = ProfileProvider::new().unwrap();
        let result = provider.credentials().await;

        assert!(result.is_ok());

        let creds = result.ok().unwrap();
        assert_eq!(creds.aws_access_key_id(), "baz_access_key");
        assert_eq!(creds.aws_secret_access_key(), "baz_secret_key");
        assert_eq!(
            creds.token().as_ref().expect("session token not parsed"),
            "baz_session_token"
        );
        assert!(creds.expires_at().is_some());
        env::remove_var(AWS_CONFIG_FILE);
    }

    // NOTE(review): assumes AWS_PROFILE is unset (or empty) in the test environment.
    #[test]
    fn profile_provider_profile_name() {
        let _guard = lock_env();
        let mut provider = ProfileProvider::new().unwrap();
        assert_eq!(DEFAULT, provider.profile());
        provider.set_profile("foo");
        assert_eq!("foo", provider.profile());
    }

    #[test]
    fn existing_file_no_credentials() {
        let result = super::parse_credentials_file(Path::new("tests/sample-data/no_credentials"));
        assert_eq!(
            result.err(),
            Some(CredentialsError::new("No credentials found."))
        )
    }

    #[test]
    fn parse_credentials_bad_path() {
        let result = super::parse_credentials_file(Path::new("/bad/file/path"));
        assert_eq!(
            result.err(),
            Some(CredentialsError::new(
                "Couldn\'t stat credentials file: [ \"/bad/file/path\" ]. Non existant, or no permission.",
            ))
        );
    }

    #[test]
    fn parse_credentials_directory_path() {
        let result = super::parse_credentials_file(Path::new("tests/"));
        assert_eq!(
            result.err(),
            Some(CredentialsError::new(
                "Credentials file: [ \"tests/\" ] is not a file.",
            ))
        );
    }

    #[test]
    fn parse_credentials_unrecognized_field() {
        let result = super::parse_credentials_file(Path::new(
            "tests/sample-data/unrecognized_field_profile_credentials",
        ));
        assert!(result.is_ok());

        let profiles = result.ok().unwrap();
        assert_eq!(profiles.len(), 1);

        let default_profile = profiles
            .get(DEFAULT)
            .expect("No default profile in full_profile_credentials");
        assert_eq!(default_profile.aws_access_key_id(), "foo");
        assert_eq!(default_profile.aws_secret_access_key(), "bar");
    }

    #[test]
    fn default_profile_name_from_env_var() {
        let _guard = lock_env();
        env::set_var(AWS_PROFILE, "bar");
        assert_eq!("bar", ProfileProvider::default_profile_name());
        env::remove_var(AWS_PROFILE);
    }

    // An empty AWS_PROFILE is treated the same as an unset one.
    #[test]
    fn default_profile_name_from_empty_env_var() {
        let _guard = lock_env();
        env::set_var(AWS_PROFILE, "");
        assert_eq!(DEFAULT, ProfileProvider::default_profile_name());
        env::remove_var(AWS_PROFILE);
    }

    #[test]
    fn default_profile_name() {
        let _guard = lock_env();
        env::remove_var(AWS_PROFILE);
        assert_eq!(DEFAULT, ProfileProvider::default_profile_name());
    }

    #[test]
    fn default_profile_location_from_env_var() {
        let _guard = lock_env();
        env::set_var(AWS_SHARED_CREDENTIALS_FILE, "bar");
        assert_eq!(
            Ok(PathBuf::from("bar")),
            ProfileProvider::default_profile_location()
        );
        env::remove_var(AWS_SHARED_CREDENTIALS_FILE);
    }

    // An empty AWS_SHARED_CREDENTIALS_FILE falls back to the hard-coded location.
    #[test]
    fn default_profile_location_from_empty_env_var() {
        let _guard = lock_env();
        env::set_var(AWS_SHARED_CREDENTIALS_FILE, "");
        assert_eq!(
            ProfileProvider::hardcoded_profile_location(),
            ProfileProvider::default_profile_location()
        );
        env::remove_var(AWS_SHARED_CREDENTIALS_FILE);
    }

    #[test]
    fn default_profile_location() {
        let _guard = lock_env();
        env::remove_var(AWS_SHARED_CREDENTIALS_FILE);
        assert_eq!(
            ProfileProvider::hardcoded_profile_location(),
            ProfileProvider::default_profile_location()
        );
    }

    // `region_from_profile` parses `file_path` as a config file, so a config fixture is used.
    #[test]
    fn region_from_profile() {
        let provider = ProfileProvider::with_configuration(
            "tests/sample-data/multiple_profile_config",
            "foo",
        );
        let maybe_region = provider.region_from_profile().unwrap();

        assert_eq!(
            maybe_region,
            Some("us-east-3".to_string())
        );
    }

    #[test]
    fn region_from_profile_missing_profile() {
        let provider = ProfileProvider::with_configuration(
            "tests/sample-data/multiple_profile_config",
            "foobar",
        );
        let maybe_region = provider.region_from_profile().unwrap();

        assert_eq!(
            maybe_region,
            None
        );
    }

}