1use std::borrow::{Borrow, Cow};
2use std::collections::HashMap;
3use std::time::Duration;
4
5use lazy_static::lazy_static;
6use log::{debug, info};
7use rusoto_core::credential::AwsCredentials;
8use rusoto_core::param::{Params, ServiceParams};
9use rusoto_core::signature::{SignedRequest, SignedRequestPayload};
10use rusoto_core::Region;
11use serde::{Deserialize, Serialize};
12
13#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
21pub struct AwsAuthIamPayload {
22 pub iam_http_request_method: String,
24 pub iam_request_url: String,
26 pub iam_request_body: String,
28 pub iam_request_headers: HashMap<String, Vec<String>>,
30}
31
32impl AwsAuthIamPayload {
33 pub fn new<R>(
42 credentials: &AwsCredentials,
43 region: Option<R>,
44 additional_headers: HashMap<&str, &str>,
45 ) -> Self
46 where
47 R: Borrow<Region>,
48 {
49 info!("Building Login Payload for AWS authentication");
50 let region = region
51 .as_ref()
52 .map(|r| Cow::Borrowed(r.borrow()))
53 .unwrap_or_default();
54 let mut request = SignedRequest::new("POST", "sts", ®ion, "/");
61 let mut params = Params::new();
62
63 params.put("Action", "GetCallerIdentity");
64 params.put("Version", "2011-06-15");
65 request.set_payload(Some(
66 serde_urlencoded::to_string(¶ms).unwrap().into_bytes(),
67 ));
68 request.set_content_type("application/x-www-form-urlencoded".to_owned());
69
70 for (header, value) in additional_headers.into_iter() {
71 request.add_header(header, value)
72 }
73
74 request.sign(credentials);
75
76 let uri = format!(
77 "{}://{}{}",
78 request.scheme(),
79 request.hostname(),
80 request.canonical_path()
81 );
82
83 let payload = match request.payload {
84 Some(SignedRequestPayload::Buffer(ref buffer)) => base64::encode(buffer),
85 _ => unreachable!("Payload was set above"),
86 };
87
88 let headers = request
90 .headers
91 .iter()
92 .map(|(k, v)| {
93 let values = v
94 .iter()
95 .map(|v| unsafe { String::from_utf8_unchecked(v.to_vec()) })
96 .collect();
97
98 (k.to_string(), values)
99 })
100 .collect();
101
102 let result = Self {
103 iam_http_request_method: "POST".to_string(),
104 iam_request_url: base64::encode(&uri),
105 iam_request_body: payload,
106 iam_request_headers: headers,
107 };
108
109 debug!("AWS Payload: {:#?}", result);
110
111 result
112 }
113}
114
115#[allow(clippy::implicit_hasher)]
124pub fn presigned_url<R>(
125 credentials: &AwsCredentials,
126 region: Option<R>,
127 additional_headers: HashMap<&str, &str>,
128 expires_in: Option<&Duration>,
129) -> String
130where
131 R: Borrow<Region>,
132{
133 lazy_static! {
134 static ref DEFAULT_EXPIRES: Duration = Duration::from_secs(60);
135 }
136
137 info!("Building pre-signed URL for AWS authentication");
138 let region = region
139 .as_ref()
140 .map(|r| Cow::Borrowed(r.borrow()))
141 .unwrap_or_default();
142
143 let mut request = SignedRequest::new("GET", "sts", ®ion, "/");
144
145 let mut params = Params::new();
146 params.put("Action", "GetCallerIdentity");
147 params.put("Version", "2011-06-15");
148 request.set_params(params);
149
150 for (header, value) in additional_headers.into_iter() {
151 request.add_header(header, value)
152 }
153
154 request.generate_presigned_url(credentials, expires_in.unwrap_or(&DEFAULT_EXPIRES), true)
155}
156
#[cfg(test)]
pub(crate) mod tests {
    //! Tests for the STS `GetCallerIdentity` helpers.
    //!
    //! Also provides crate-visible fixtures (`credentials`,
    //! `post_aws_iam_payload`, `get_presigned_url`) built on rusoto's mock
    //! credentials provider.

    use super::*;

    use rusoto_core::credential::ProvideAwsCredentials;

    /// Returns the credentials served by `rusoto_mock::MockCredentialsProvider`.
    pub(crate) async fn credentials() -> Result<AwsCredentials, crate::Error> {
        let mock_provider = rusoto_mock::MockCredentialsProvider;
        let creds = mock_provider.credentials().await?;
        Ok(creds)
    }

    /// Builds a signed login payload using the mock credentials.
    pub(crate) async fn post_aws_iam_payload(
        region: Option<Region>,
        header: HashMap<&str, &str>,
    ) -> Result<AwsAuthIamPayload, crate::Error> {
        let mock_creds = credentials().await?;
        let payload = AwsAuthIamPayload::new(&mock_creds, region, header);
        Ok(payload)
    }

    /// Builds a pre-signed URL with the mock credentials and the default expiry.
    pub(crate) async fn get_presigned_url(
        region: Option<Region>,
        header: HashMap<&str, &str>,
    ) -> Result<String, crate::Error> {
        let mock_creds = credentials().await?;
        let url = presigned_url(&mock_creds, region, header, None);
        Ok(url)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn post_aws_iam_payload_has_expected_values() -> Result<(), crate::Error> {
        let region = Region::UsEast1;
        // Computed before `region` is moved into the call below.
        let expected_url = format!("https://sts.{}.amazonaws.com/", region.name());
        let extra_headers: HashMap<&str, &str> =
            [("X-Vault-AWS-IAM-Server-ID", "vault.example.com")]
                .iter()
                .copied()
                .collect();

        let payload = post_aws_iam_payload(Some(region), extra_headers).await?;

        assert_eq!(payload.iam_http_request_method, "POST");
        assert_eq!(payload.iam_request_url, base64::encode(&expected_url));
        assert_eq!(
            payload.iam_request_body,
            base64::encode("Action=GetCallerIdentity&Version=2011-06-15")
        );

        // Header names are expected in lowercase on the signed payload.
        let signed_headers = &payload.iam_request_headers;
        assert!(signed_headers.contains_key("authorization"));
        assert_eq!(
            signed_headers.get("x-vault-aws-iam-server-id"),
            Some(&vec!["vault.example.com".to_string()])
        );
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn presigned_url_has_expected_values() -> Result<(), crate::Error> {
        let extra_headers: HashMap<&str, &str> =
            [("X-K8S-AWS-ID", "example")].iter().copied().collect();
        let raw_url = get_presigned_url(Some(Region::UsEast1), extra_headers).await?;
        let parsed = url::Url::parse(&raw_url).unwrap();

        assert_eq!(
            parsed.host().unwrap().to_string(),
            "sts.us-east-1.amazonaws.com"
        );

        let query: HashMap<_, _> = parsed.query_pairs().collect();

        let expected_pairs = [
            ("Action", "GetCallerIdentity"),
            ("Version", "2011-06-15"),
            ("X-Amz-SignedHeaders", "host;x-k8s-aws-id"),
            ("X-Amz-Expires", "60"),
            ("X-Amz-Algorithm", "AWS4-HMAC-SHA256"),
        ];
        for (key, value) in expected_pairs.iter() {
            assert_eq!(query[*key], *value, "unexpected value for {}", key);
        }

        for key in ["X-Amz-Signature", "X-Amz-Credential", "X-Amz-Date"].iter() {
            assert!(query.contains_key(*key), "missing query parameter {}", key);
        }
        Ok(())
    }
}