1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use super::credentials::{CredentialsError, CredentialsProvider};

use async_trait::async_trait;
use aws_sdk_sts::Client;
use aws_types::{region::Region, Credentials};
use std::time::SystemTime;

use crate::async_value::AsyncAccessor;
use crate::data_service_credentials::DataServiceCredentials;
use crate::utils::StdError;

/// A [`CredentialsProvider`] that always hands out the same, caller-supplied
/// AWS credentials.
pub struct ExplicitCredentialsProvider {
    /// The fixed credentials returned by every acquisition.
    credentials: Credentials,
}

impl ExplicitCredentialsProvider {
    /// Wraps `credentials` so they can be served through the
    /// [`CredentialsProvider`] interface.
    pub fn new(credentials: Credentials) -> Self {
        ExplicitCredentialsProvider { credentials }
    }
}

#[async_trait(?Send)]
impl CredentialsProvider for ExplicitCredentialsProvider {
    type Credentials = Credentials;

    fn is_expired(&self, _: &Self::Credentials) -> bool {
        false
    }

    async fn acquire_credentials(&self) -> Result<Credentials, CredentialsError> {
        Ok(self.credentials.clone())
    }
}

/// A [`CredentialsProvider`] that obtains temporary AWS credentials from STS by
/// assuming a role, presenting a data-service access token as the web identity.
pub struct STSCredentialsProvider {
    /// Region of the STS endpoint used to assume the role.
    region: Region,
    /// ARN of the role to assume.
    arn: String,
    /// Source of the data-service credentials whose access token is sent to STS.
    ds_credentials: Box<dyn AsyncAccessor<Value = DataServiceCredentials, Error = StdError>>,
}

impl STSCredentialsProvider {
    /// Creates a provider that assumes `role_arn` via the STS endpoint in
    /// `role_region`, authenticating with the access token from `ds_credentials`.
    pub fn new(
        role_region: &str,
        role_arn: &str,
        ds_credentials: Box<dyn AsyncAccessor<Value = DataServiceCredentials, Error = StdError>>,
    ) -> Self {
        let region = Region::new(String::from(role_region));
        let arn = String::from(role_arn);
        Self {
            region,
            arn,
            ds_credentials,
        }
    }

    /// Builds a fresh STS client configured for this provider's region.
    fn get_client(&self) -> Client {
        // Rebuild the region through its string form so it is the STS crate's own type.
        let sts_region = aws_sdk_sts::Region::new(self.region.to_string());
        Client::from_conf(aws_sdk_sts::Config::builder().region(sts_region).build())
    }
}

#[async_trait(?Send)]
impl CredentialsProvider for STSCredentialsProvider {
    type Credentials = Credentials;

    /// Returns `true` once the credentials' expiry time has passed.
    ///
    /// Credentials that carry no expiry are treated as never expiring.
    fn is_expired(&self, credentials: &Self::Credentials) -> bool {
        matches!(credentials.expiry(), Some(expiration) if expiration < SystemTime::now())
    }

    /// Assumes the configured role via STS `AssumeRoleWithWebIdentity`, using the
    /// current data-service access token as the web identity token.
    ///
    /// The returned credentials carry the session expiration reported by STS, so
    /// `is_expired` can detect when they must be re-acquired.
    ///
    /// # Errors
    ///
    /// * [`CredentialsError::AcquireFailed`] if the access token cannot be
    ///   obtained or the STS request fails.
    /// * [`CredentialsError::MissingCredentials`] if the STS response contains
    ///   no credentials.
    /// * [`CredentialsError::MissingCredential`] if the access key id or secret
    ///   access key is absent from the response.
    async fn acquire_credentials(&self) -> Result<Credentials, CredentialsError> {
        let client = self.get_client();

        let access_token = self
            .ds_credentials
            .get_value()
            .await
            // Keep the underlying cause so token failures are diagnosable.
            .map_err(|e| CredentialsError::AcquireFailed(format!("No access token: {}", e)))?
            .access_token();

        let assume_role_output = client
            .assume_role_with_web_identity()
            .role_session_name("rust_client")
            .role_arn(&self.arn)
            .web_identity_token(access_token)
            .send()
            .await
            .map_err(|e| CredentialsError::AcquireFailed(e.to_string()))?;

        // Borrow instead of cloning: `assume_role_output` outlives every use below.
        let creds = assume_role_output
            .credentials()
            .ok_or(CredentialsError::MissingCredentials)?;

        let access_key_id = creds
            .access_key_id()
            .ok_or(CredentialsError::MissingCredential("access_key_id"))?;
        let secret_access_key = creds
            .secret_access_key()
            .ok_or(CredentialsError::MissingCredential("secret_access_key"))?;

        // Propagate the STS session expiration. Without it `is_expired` would
        // always return `false` and the temporary credentials would never be
        // refreshed. A pre-epoch timestamp is treated as "no expiry".
        let expiry = creds.expiration().and_then(|expiration| {
            let secs = u64::try_from(expiration.secs()).ok()?;
            std::time::UNIX_EPOCH
                .checked_add(std::time::Duration::new(secs, expiration.subsec_nanos()))
        });

        Ok(Credentials::new(
            access_key_id,
            secret_access_key,
            creds.session_token().map(String::from),
            expiry,
            "STS",
        ))
    }
}