use std::borrow::{Borrow, Cow};
use std::collections::HashMap;
use std::time::Duration;
use lazy_static::lazy_static;
use log::{debug, info};
use rusoto_core::credential::AwsCredentials;
use rusoto_core::param::{Params, ServiceParams};
use rusoto_core::signature::{SignedRequest, SignedRequestPayload};
use rusoto_core::Region;
use serde::{Deserialize, Serialize};
/// A signed STS `GetCallerIdentity` request, split into the parts of a login payload.
///
/// Built by [`AwsAuthIamPayload::new`]. The field names appear to match the login
/// parameters of HashiCorp Vault's AWS IAM auth method (NOTE(review): confirm the
/// consumer). There are no serde rename attributes, so the fields serialize under
/// exactly these names and in this order.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct AwsAuthIamPayload {
/// HTTP method of the signed request; `new` always sets `"POST"`.
pub iam_http_request_method: String,
/// Request URL (`scheme://host/path`), base64-encoded when built by `new`.
pub iam_request_url: String,
/// Request body, base64-encoded when built by `new`.
pub iam_request_body: String,
/// Headers of the signed request, each header name mapped to all of its values.
pub iam_request_headers: HashMap<String, Vec<String>>,
}
impl AwsAuthIamPayload {
    /// Builds a login payload from a signed STS `GetCallerIdentity` request.
    ///
    /// The request is a `POST` to the STS endpoint of `region`. When `region` is
    /// `None`, the default [`Region`] is used. The body is form-urlencoded. Every
    /// entry in `additional_headers` is added to the request *before* signing, so
    /// those headers are covered by the signature.
    ///
    /// The returned payload carries:
    /// - the method,
    /// - the base64-encoded URL,
    /// - the base64-encoded body,
    /// - every header of the signed request.
    ///
    /// # Panics
    ///
    /// Panics only if the fixed STS parameters fail to URL-encode. That would be a
    /// programming error, not a runtime condition.
    pub fn new<R>(
        credentials: &AwsCredentials,
        region: Option<R>,
        additional_headers: HashMap<&str, &str>,
    ) -> Self
    where
        R: Borrow<Region>,
    {
        info!("Building Login Payload for AWS authentication");
        // Borrow the caller's region when one is given; otherwise use the default region.
        let region = region
            .as_ref()
            .map(|r| Cow::Borrowed(r.borrow()))
            .unwrap_or_default();

        let mut request = SignedRequest::new("POST", "sts", &region, "/");
        let mut params = Params::new();
        params.put("Action", "GetCallerIdentity");
        params.put("Version", "2011-06-15");
        let body = serde_urlencoded::to_string(&params)
            .expect("fixed STS GetCallerIdentity parameters are always URL-encodable");
        request.set_payload(Some(body.into_bytes()));
        request.set_content_type("application/x-www-form-urlencoded".to_owned());
        // Added before `sign` so these headers are part of the signature.
        for (header, value) in additional_headers {
            request.add_header(header, value);
        }
        request.sign(credentials);

        let uri = format!(
            "{}://{}{}",
            request.scheme(),
            request.hostname(),
            request.canonical_path()
        );
        let payload = match request.payload {
            Some(SignedRequestPayload::Buffer(ref buffer)) => base64::encode(buffer),
            _ => unreachable!("Payload was set above"),
        };
        // Header values are stored as raw bytes. Decode them safely instead of
        // assuming UTF-8: valid input is unchanged, and any invalid sequence becomes
        // U+FFFD rather than undefined behaviour.
        let headers = request
            .headers
            .iter()
            .map(|(k, v)| {
                let values = v
                    .iter()
                    .map(|v| String::from_utf8_lossy(v).into_owned())
                    .collect();
                (k.to_string(), values)
            })
            .collect();

        let result = Self {
            iam_http_request_method: "POST".to_string(),
            iam_request_url: base64::encode(&uri),
            iam_request_body: payload,
            iam_request_headers: headers,
        };
        // NOTE(review): this logs every signed header at debug level. With session
        // credentials that may include a security-token header; consider redacting.
        debug!("AWS Payload: {:#?}", result);
        result
    }
}
/// Builds a pre-signed `GET` URL for an STS `GetCallerIdentity` call.
///
/// - `region` selects the STS endpoint. `None` uses the default [`Region`].
/// - `additional_headers` are added to the request before the URL is signed, so
///   they appear in the URL's signed-header list.
/// - `expires_in` sets how long the URL stays valid. It defaults to 60 seconds.
///
/// The STS parameters (`Action`, `Version`) are carried in the query string.
// The parameter stays `HashMap<&str, &str>` with the default hasher, to keep the
// public signature stable and consistent with `AwsAuthIamPayload::new`.
// Generalising over `S: BuildHasher` is what `implicit_hasher` would ask for.
#[allow(clippy::implicit_hasher)]
pub fn presigned_url<R>(
    credentials: &AwsCredentials,
    region: Option<R>,
    additional_headers: HashMap<&str, &str>,
    expires_in: Option<&Duration>,
) -> String
where
    R: Borrow<Region>,
{
    lazy_static! {
        static ref DEFAULT_EXPIRES: Duration = Duration::from_secs(60);
    }
    info!("Building pre-signed URL for AWS authentication");
    // Borrow the caller's region when one is given; otherwise use the default region.
    let region = region
        .as_ref()
        .map(|r| Cow::Borrowed(r.borrow()))
        .unwrap_or_default();

    let mut request = SignedRequest::new("GET", "sts", &region, "/");
    let mut params = Params::new();
    params.put("Action", "GetCallerIdentity");
    params.put("Version", "2011-06-15");
    // A GET has no body: the STS parameters go into the query string.
    request.set_params(params);
    // Added before signing so these headers are part of the signature.
    for (header, value) in additional_headers {
        request.add_header(header, value);
    }
    request.generate_presigned_url(credentials, expires_in.unwrap_or(&DEFAULT_EXPIRES), true)
}
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use rusoto_core::credential::ProvideAwsCredentials;

    /// Returns the fixed credentials served by `rusoto_mock`'s mock provider.
    pub(crate) async fn credentials() -> Result<AwsCredentials, crate::Error> {
        let mock = rusoto_mock::MockCredentialsProvider;
        let creds = mock.credentials().await?;
        Ok(creds)
    }

    /// Builds an [`AwsAuthIamPayload`] signed with the mock credentials.
    pub(crate) async fn post_aws_iam_payload(
        region: Option<Region>,
        header: HashMap<&str, &str>,
    ) -> Result<AwsAuthIamPayload, crate::Error> {
        let creds = credentials().await?;
        let payload = AwsAuthIamPayload::new(&creds, region, header);
        Ok(payload)
    }

    /// Builds a pre-signed URL with the default expiry, signed with the mock credentials.
    pub(crate) async fn get_presigned_url(
        region: Option<Region>,
        header: HashMap<&str, &str>,
    ) -> Result<String, crate::Error> {
        let creds = credentials().await?;
        let url = presigned_url(&creds, region, header, None);
        Ok(url)
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn post_aws_iam_payload_has_expected_values() -> Result<(), crate::Error> {
        let region = Region::UsEast1;
        // Compute the expected URL up front so `region` can be moved into the call.
        let expected_url = format!("https://sts.{}.amazonaws.com/", region.name());
        let headers: HashMap<&str, &str> =
            vec![("X-Vault-AWS-IAM-Server-ID", "vault.example.com")]
                .into_iter()
                .collect();

        let payload = post_aws_iam_payload(Some(region), headers).await?;

        assert_eq!(payload.iam_http_request_method, "POST");
        assert_eq!(payload.iam_request_url, base64::encode(&expected_url));
        assert_eq!(
            payload.iam_request_body,
            base64::encode("Action=GetCallerIdentity&Version=2011-06-15")
        );

        let signed_headers = &payload.iam_request_headers;
        assert!(signed_headers.contains_key("authorization"));
        // Signed header names come back lower-cased.
        let server_id_key = "X-Vault-AWS-IAM-Server-ID".to_lowercase();
        assert_eq!(
            signed_headers.get(&server_id_key),
            Some(&vec!["vault.example.com".to_string()])
        );
        Ok(())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn presigned_url_has_expected_values() -> Result<(), crate::Error> {
        let headers: HashMap<&str, &str> =
            vec![("X-K8S-AWS-ID", "example")].into_iter().collect();

        let raw = get_presigned_url(Some(Region::UsEast1), headers).await?;
        let parsed = url::Url::parse(&raw).unwrap();

        let host = parsed.host().unwrap().to_string();
        assert_eq!(host, "sts.us-east-1.amazonaws.com");

        let query: HashMap<_, _> = parsed.query_pairs().collect();
        let expected = [
            ("Action", "GetCallerIdentity"),
            ("Version", "2011-06-15"),
            ("X-Amz-SignedHeaders", "host;x-k8s-aws-id"),
            ("X-Amz-Expires", "60"),
            ("X-Amz-Algorithm", "AWS4-HMAC-SHA256"),
        ];
        for (key, value) in expected.iter() {
            assert_eq!(query[*key], *value);
        }
        for required in ["X-Amz-Signature", "X-Amz-Credential", "X-Amz-Date"].iter() {
            assert!(query.contains_key(*required));
        }
        Ok(())
    }
}