aws_auth_payload/
client.rs

1use std::borrow::{Borrow, Cow};
2use std::collections::HashMap;
3use std::time::Duration;
4
5use lazy_static::lazy_static;
6use log::{debug, info};
7use rusoto_core::credential::AwsCredentials;
8use rusoto_core::param::{Params, ServiceParams};
9use rusoto_core::signature::{SignedRequest, SignedRequestPayload};
10use rusoto_core::Region;
11use serde::{Deserialize, Serialize};
12
13/// Payload for use to generate a payload for AWS IAM authentication
14///
15/// This payload is used by HashiCorp's Vault and is generated by making a POST request
16/// to AWS STS `GetCallerIdentity`
17///
18/// See [Vault's Documentation](https://www.vaultproject.io/docs/auth/aws.html#iam-auth-method)
19/// for more information.
20#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
21pub struct AwsAuthIamPayload {
22    /// HTTP method used in the signed request. Currently only `POST` is supported
23    pub iam_http_request_method: String,
24    /// Base64-encoded HTTP URL used in the signed request
25    pub iam_request_url: String,
26    /// Base64-encoded body of the signed request
27    pub iam_request_body: String,
28    /// Headers of the signed request
29    pub iam_request_headers: HashMap<String, Vec<String>>,
30}
31
32impl AwsAuthIamPayload {
33    /// Creates a payload for use to generate a payload for AWS IAM authentication
34    ///
35    /// This payload is used by HashiCorp's Vault and is generated by making a POST request
36    /// to AWS STS `GetCallerIdentity`
37    ///
38    /// See [Vault's Documentation](https://www.vaultproject.io/docs/auth/aws.html#iam-auth-method)
39    /// for more information.
40    /// If you do not provide a `region`, we will use a the "global" AWS STS endpoint.
41    pub fn new<R>(
42        credentials: &AwsCredentials,
43        region: Option<R>,
44        additional_headers: HashMap<&str, &str>,
45    ) -> Self
46    where
47        R: Borrow<Region>,
48    {
49        info!("Building Login Payload for AWS authentication");
50        let region = region
51            .as_ref()
52            .map(|r| Cow::Borrowed(r.borrow()))
53            .unwrap_or_default();
54        // Code below is referenced from the code for
55        // https://rusoto.github.io/rusoto/rusoto_sts/trait.Sts.html#tymethod.get_caller_identity
56
57        // Additional processing for Vault is referenced from Vault CLI's source code:
58        // https://github.com/hashicorp/vault/blob/master/builtin/credential/aws/cli.go
59
60        let mut request = SignedRequest::new("POST", "sts", &region, "/");
61        let mut params = Params::new();
62
63        params.put("Action", "GetCallerIdentity");
64        params.put("Version", "2011-06-15");
65        request.set_payload(Some(
66            serde_urlencoded::to_string(&params).unwrap().into_bytes(),
67        ));
68        request.set_content_type("application/x-www-form-urlencoded".to_owned());
69
70        for (header, value) in additional_headers.into_iter() {
71            request.add_header(header, value)
72        }
73
74        request.sign(credentials);
75
76        let uri = format!(
77            "{}://{}{}",
78            request.scheme(),
79            request.hostname(),
80            request.canonical_path()
81        );
82
83        let payload = match request.payload {
84            Some(SignedRequestPayload::Buffer(ref buffer)) => base64::encode(buffer),
85            _ => unreachable!("Payload was set above"),
86        };
87
88        // We need to convert the headers from bytes back into Strings...
89        let headers = request
90            .headers
91            .iter()
92            .map(|(k, v)| {
93                let values = v
94                    .iter()
95                    .map(|v| unsafe { String::from_utf8_unchecked(v.to_vec()) })
96                    .collect();
97
98                (k.to_string(), values)
99            })
100            .collect();
101
102        let result = Self {
103            iam_http_request_method: "POST".to_string(),
104            iam_request_url: base64::encode(&uri),
105            iam_request_body: payload,
106            iam_request_headers: headers,
107        };
108
109        debug!("AWS Payload: {:#?}", result);
110
111        result
112    }
113}
114
115/// Generates a pre-signed URL using the provided AWS Credentials to
116/// AWS STS `GetCallerIdentity`
117///
118/// This is used by
119/// [Kubernetes AWS IAM Authenticator](https://github.com/kubernetes-sigs/aws-iam-authenticator)
120///
121/// See [Vault's Documentation](https://www.vaultproject.io/docs/auth/aws.html#iam-auth-method)
122/// for more information.
123#[allow(clippy::implicit_hasher)]
124pub fn presigned_url<R>(
125    credentials: &AwsCredentials,
126    region: Option<R>,
127    additional_headers: HashMap<&str, &str>,
128    expires_in: Option<&Duration>,
129) -> String
130where
131    R: Borrow<Region>,
132{
133    lazy_static! {
134        static ref DEFAULT_EXPIRES: Duration = Duration::from_secs(60);
135    }
136
137    info!("Building pre-signed URL for AWS authentication");
138    let region = region
139        .as_ref()
140        .map(|r| Cow::Borrowed(r.borrow()))
141        .unwrap_or_default();
142
143    let mut request = SignedRequest::new("GET", "sts", &region, "/");
144
145    let mut params = Params::new();
146    params.put("Action", "GetCallerIdentity");
147    params.put("Version", "2011-06-15");
148    request.set_params(params);
149
150    for (header, value) in additional_headers.into_iter() {
151        request.add_header(header, value)
152    }
153
154    request.generate_presigned_url(credentials, expires_in.unwrap_or(&DEFAULT_EXPIRES), true)
155}
156
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    use rusoto_core::credential::ProvideAwsCredentials;

    /// Fetches credentials from rusoto's mock provider (mock_key, mock_secret).
    pub(crate) async fn credentials() -> Result<AwsCredentials, crate::Error> {
        let creds = rusoto_mock::MockCredentialsProvider.credentials().await?;
        Ok(creds)
    }

    /// Builds an [`AwsAuthIamPayload`] signed with the mock credentials.
    pub(crate) async fn post_aws_iam_payload(
        region: Option<Region>,
        header: HashMap<&str, &str>,
    ) -> Result<AwsAuthIamPayload, crate::Error> {
        let mock_credentials = credentials().await?;
        let payload = AwsAuthIamPayload::new(&mock_credentials, region, header);
        Ok(payload)
    }

    /// Builds a pre-signed URL (default expiry) signed with the mock credentials.
    pub(crate) async fn get_presigned_url(
        region: Option<Region>,
        header: HashMap<&str, &str>,
    ) -> Result<String, crate::Error> {
        let mock_credentials = credentials().await?;
        let signed = presigned_url(&mock_credentials, region, header, None);
        Ok(signed)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn post_aws_iam_payload_has_expected_values() -> Result<(), crate::Error> {
        let region = Region::UsEast1;
        let headers: HashMap<&str, &str> =
            std::iter::once(("X-Vault-AWS-IAM-Server-ID", "vault.example.com")).collect();
        let payload = post_aws_iam_payload(Some(region.clone()), headers).await?;

        assert_eq!(payload.iam_http_request_method, "POST");

        let expected_url = format!("https://sts.{}.amazonaws.com/", region.name());
        assert_eq!(payload.iam_request_url, base64::encode(&expected_url));

        let expected_body = base64::encode("Action=GetCallerIdentity&Version=2011-06-15");
        assert_eq!(payload.iam_request_body, expected_body);

        assert!(payload.iam_request_headers.contains_key("authorization"));

        // The payload is looked up under the lower-cased header name.
        let server_id_header = "X-Vault-AWS-IAM-Server-ID".to_lowercase();
        let expected_values = vec!["vault.example.com".to_string()];
        assert_eq!(
            payload.iam_request_headers.get(&server_id_header),
            Some(&expected_values)
        );
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn presigned_url_has_expected_values() -> Result<(), crate::Error> {
        let headers: HashMap<&str, &str> = std::iter::once(("X-K8S-AWS-ID", "example")).collect();
        let raw_url = get_presigned_url(Some(Region::UsEast1), headers).await?;
        let parsed = url::Url::parse(&raw_url).unwrap();

        assert_eq!(
            parsed.host().unwrap().to_string(),
            "sts.us-east-1.amazonaws.com"
        );

        let params: HashMap<_, _> = parsed.query_pairs().collect();

        // Query parameters whose exact value is known up front.
        let fixed_params = [
            ("Action", "GetCallerIdentity"),
            ("Version", "2011-06-15"),
            ("X-Amz-SignedHeaders", "host;x-k8s-aws-id"),
            ("X-Amz-Expires", "60"),
            ("X-Amz-Algorithm", "AWS4-HMAC-SHA256"),
        ];
        for &(key, expected) in fixed_params.iter() {
            assert_eq!(params[key], expected);
        }

        // Query parameters whose value varies per signature; only presence is checked.
        let required_params = ["X-Amz-Signature", "X-Amz-Credential", "X-Amz-Date"];
        for &key in required_params.iter() {
            assert!(params.contains_key(key));
        }
        Ok(())
    }
}