google-cloud-auth 0.17.2

Google Cloud Platform server application authentication library.
Documentation
use std::env::var;
use std::path::PathBuf;

use async_trait::async_trait;
use hmac::Mac;
use path_clean::PathClean;
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use time::macros::format_description;
use time::OffsetDateTime;
use url::Url;

use crate::credentials::CredentialSource;
use crate::misc::UnwrapOrEmpty;
use crate::token_source::default_http_client;
use crate::token_source::external_account_source::error::Error;
use crate::token_source::external_account_source::subject_token_source::SubjectTokenSource;

/// Signing algorithm label written into the string-to-sign and the `Authorization` header.
const AWS_ALGORITHM: &str = "AWS4-HMAC-SHA256";
/// Last component of the credential scope and fourth input of the signing-key chain.
const AWS_REQUEST_TYPE: &str = "aws4_request";
/// Name of the environment variable read for the access key id.
const AWS_ACCESS_KEY_ID: &str = "AWS_ACCESS_KEY_ID";
/// Name of the fallback environment variable read for the region when `AWS_REGION` is not readable.
const AWS_DEFAULT_REGION: &str = "AWS_DEFAULT_REGION";
/// Name of the environment variable read first for the region.
const AWS_REGION: &str = "AWS_REGION";
/// Name of the environment variable read for the secret access key.
const AWS_SECRET_ACCESS_KEY: &str = "AWS_SECRET_ACCESS_KEY";
/// Name of the environment variable read for the optional session token.
const AWS_SESSION_TOKEN: &str = "AWS_SESSION_TOKEN";
/// Request header that carries the IMDSv2 session token on metadata-server GET requests.
const AWS_IMDS_V2_SESSION_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";

/// Subject token source that produces a signed, serialized AWS request
/// (see `create_subject_token`) from AWS credentials and a region.
pub struct AWSSubjectTokenSource {
    /// Regional credential verification URL with `{region}` already substituted;
    /// its host, path and query are the ones that get signed.
    subject_token_url: Url,
    /// Sent and signed as the `x-goog-cloud-target-resource` header when present;
    /// populated from the `audience` argument of `new`.
    target_resource: Option<String>,
    /// AWS keys (and optional session token) used to sign the request.
    credentials: AWSSecurityCredentials,
    /// Region used in the credential scope and in the signing-key derivation.
    region: String,
}

impl AWSSubjectTokenSource {
    pub async fn new(audience: Option<String>, value: CredentialSource) -> Result<Self, Error> {
        if !validate_metadata_server(&value.region_url) {
            return Err(Error::InvalidRegionURL(value.region_url.unwrap_or_empty()));
        }
        // Not value.cred_verification_url but value.url
        if !validate_metadata_server(&value.url) {
            return Err(Error::InvalidSecurityCredentialsURL(value.url.unwrap_or_empty()));
        }
        if !validate_metadata_server(&value.imdsv2_session_token_url) {
            return Err(Error::InvalidIMDSv2SessionTokenURL(
                value.imdsv2_session_token_url.unwrap_or_empty(),
            ));
        }

        let aws_session_token = if should_use_metadata_server() {
            get_aws_session_token(&value.imdsv2_session_token_url).await?
        } else {
            None
        };

        let credentials = get_security_credentials(&aws_session_token, &value.url).await?;
        let region = get_region(&aws_session_token, &value.region_url).await?;

        let url = value
            .regional_cred_verification_url
            .as_ref()
            .ok_or(Error::MissingRegionalCredVerificationURL)?;
        let subject_token_url = Url::parse(&url.replace("{region}", &region))?;

        Ok(Self {
            subject_token_url,
            target_resource: audience,
            credentials,
            region,
        })
    }

    /// Computes the `Authorization` header value for a body-less request to
    /// `subject_token_url`, signed with HMAC-SHA256.
    ///
    /// * `method`  - HTTP method placed on the first line of the canonical request.
    /// * `now`     - instant used for the short date and the long timestamp.
    /// * `headers` - `(name, value)` pairs to sign, used in the order given.
    ///
    /// Returns `"<algorithm> Credential=<key id>/<scope>, SignedHeaders=<names>, Signature=<hex>"`.
    ///
    /// # Errors
    /// Propagates time-formatting failures and HMAC key-length errors.
    fn create_auth_header(
        &self,
        method: &str,
        now: &OffsetDateTime,
        headers: &[(&str, &str)],
    ) -> Result<String, Error> {
        let short_date = now.format(&format_description!("[year][month][day]"))?;
        let long_date = now.format(&format_description!("[year][month][day]T[hour][minute][second]Z"))?;

        // The service is the first dot-separated label of the host (e.g. `sts`).
        let service = self
            .subject_token_url
            .host_str()
            .unwrap_or_default()
            .split('.')
            .next()
            .unwrap_or_default();
        let scope = format!("{}/{}/{}/{}", short_date, self.region, service, AWS_REQUEST_TYPE);

        let (signed_header_names, canonical_header_lines) = canonical_headers(headers);

        // The query string is taken verbatim from the URL; it is not re-sorted or re-encoded here.
        let canonical_query = self.subject_token_url.query().unwrap_or_default();
        let canonical_path = match self.subject_token_url.path() {
            "" => String::from("/"),
            raw => PathBuf::from(raw).clean().to_string_lossy().into_owned(),
        };

        // Canonical request; the payload hash is the SHA-256 of an empty body.
        let payload_hash = hex::encode(Sha256::digest(b""));
        let canonical_request = [
            method,
            canonical_path.as_str(),
            canonical_query,
            canonical_header_lines.as_str(),
            signed_header_names.as_str(),
            payload_hash.as_str(),
        ]
        .join("\n");
        let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
        let string_to_sign = [
            AWS_ALGORITHM,
            long_date.as_str(),
            scope.as_str(),
            canonical_request_hash.as_str(),
        ]
        .join("\n");

        // Chained HMAC: each output becomes the key for the next message. The
        // first four steps derive the signing key, the last one signs.
        let mut key = format!("AWS4{}", self.credentials.secret_access_key).into_bytes();
        let chain = [
            short_date.as_str(),
            self.region.as_str(),
            service,
            AWS_REQUEST_TYPE,
            string_to_sign.as_str(),
        ];
        for message in chain {
            let mut mac = hmac::Hmac::<Sha256>::new_from_slice(&key)?;
            mac.update(message.as_bytes());
            key = mac.finalize().into_bytes().to_vec();
        }
        let signature = hex::encode(key);

        Ok(format!(
            "{} Credential={}/{}, SignedHeaders={}, Signature={}",
            AWS_ALGORITHM, self.credentials.access_key_id, scope, signed_header_names, signature
        ))
    }

    /// Builds the subject token: a JSON description of a signed `POST` request
    /// to `subject_token_url` (`url`, `method`, `headers`), percent-encoded
    /// with the `NON_ALPHANUMERIC` set.
    ///
    /// `now` is used for the `x-amz-date` header and for the signature.
    ///
    /// # Errors
    /// Propagates time-formatting, signing and JSON serialization failures.
    fn create_subject_token(&self, now: OffsetDateTime) -> Result<String, Error> {
        let amz_date = now.format(&format_description!("[year][month][day]T[hour][minute][second]Z"))?;
        let host = self.subject_token_url.host_str().unwrap_or("");

        // Headers that get signed, in the order they are emitted.
        let mut signed_headers: Vec<(&str, &str)> = vec![("host", host), ("x-amz-date", amz_date.as_str())];
        if let Some(session_token) = self.credentials.token.as_deref() {
            signed_headers.push(("x-amz-security-token", session_token));
        }
        // Full canonical resource name of the workload identity pool provider,
        // with or without the HTTPS prefix. Signing it is recommended so the
        // value is integrity-protected.
        if let Some(resource) = self.target_resource.as_deref() {
            signed_headers.push(("x-goog-cloud-target-resource", resource));
        }

        let method = "POST";
        let authorization = self.create_auth_header(method, &now, &signed_headers)?;

        // `Authorization` goes first, followed by every signed header.
        let mut headers = Vec::with_capacity(signed_headers.len() + 1);
        headers.push(AWSRequestHeader {
            key: String::from("Authorization"),
            value: authorization,
        });
        headers.extend(signed_headers.into_iter().map(|(name, content)| AWSRequestHeader {
            key: name.to_owned(),
            value: content.to_owned(),
        }));

        let request = AWSRequest {
            url: self.subject_token_url.to_string(),
            method,
            headers,
        };
        let json = serde_json::to_string(&request)?;
        Ok(utf8_percent_encode(&json, NON_ALPHANUMERIC).to_string())
    }
}

#[async_trait]
impl SubjectTokenSource for AWSSubjectTokenSource {
    /// Produces a subject token signed for the current UTC time.
    async fn subject_token(&self) -> Result<String, Error> {
        let now = OffsetDateTime::now_utc();
        self.create_subject_token(now)
    }
}

/// AWS credentials used for signing. Deserialized with PascalCase keys
/// (`AccessKeyId`, `SecretAccessKey`, `Token`), or filled from environment
/// variables in `get_security_credentials`.
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct AWSSecurityCredentials {
    /// Written into the `Credential=` part of the `Authorization` header.
    access_key_id: String,
    /// Seeds the signing key (prefixed with `AWS4`).
    secret_access_key: String,
    /// When present, sent and signed as the `x-amz-security-token` header.
    token: Option<String>,
}

/// One header of the serialized request, emitted as `{"key": ..., "value": ...}`.
#[derive(Serialize)]
struct AWSRequestHeader {
    /// Header name.
    key: String,
    /// Header value.
    value: String,
}

/// Signed request description that is serialized to JSON and percent-encoded
/// to form the subject token.
#[derive(Serialize)]
struct AWSRequest {
    /// Target URL of the signed request.
    url: String,
    /// HTTP method of the signed request.
    method: &'static str,
    /// `Authorization` header followed by every signed header.
    headers: Vec<AWSRequestHeader>,
}

/// Hosts accepted as metadata endpoints. The IPv6 entry is stored without brackets.
const VALID_HOST_NAMES: [&str; 2] = ["169.254.169.254", "fd00:ec2::254"];

/// Checks that `metadata_url`, when provided, points at one of `VALID_HOST_NAMES`.
///
/// Returns `true` for a missing or empty URL (nothing to validate), `false`
/// when the URL does not parse or its host is not in the allow-list.
fn validate_metadata_server(metadata_url: &Option<String>) -> bool {
    let metadata_url = metadata_url.unwrap_or_empty();
    if metadata_url.is_empty() {
        return true;
    }
    let parsed = match Url::parse(&metadata_url) {
        Ok(parsed) => parsed,
        Err(_) => return false,
    };

    // `Url::host_str` renders IPv6 hosts inside `[` and `]`. Strip the brackets
    // so `http://[fd00:ec2::254]/...` matches the bare address in the allow-list;
    // without this the IPv6 endpoint could never be accepted.
    let host = parsed.host_str().unwrap_or("");
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    VALID_HOST_NAMES.contains(&host)
}

/// Returns `true` unless both the region and the security credentials can be
/// read from the environment.
fn should_use_metadata_server() -> bool {
    let environment_is_sufficient =
        can_retrieve_region_from_environment() && can_retrieve_security_credential_from_environment();
    !environment_is_sufficient
}

/// Returns `true` when `AWS_REGION` or, failing that, `AWS_DEFAULT_REGION`
/// can be read from the environment.
fn can_retrieve_region_from_environment() -> bool {
    [AWS_REGION, AWS_DEFAULT_REGION]
        .iter()
        .any(|&name| var(name).is_ok())
}

/// Returns `true` when both `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`
/// can be read from the environment.
fn can_retrieve_security_credential_from_environment() -> bool {
    [AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY]
        .iter()
        .all(|&name| var(name).is_ok())
}

async fn get_aws_session_token(imds_v2_session_token_url: &Option<String>) -> Result<Option<String>, Error> {
    let url = match imds_v2_session_token_url {
        Some(url) => url,
        None => return Ok(None),
    };

    let client = default_http_client();
    let response = client
        .put(url)
        .header("X-aws-ec2-metadata-token-ttl-seconds", "300")
        .send()
        .await?;
    if !response.status().is_success() {
        return Err(Error::UnexpectedStatusOnGetSessionToken(response.status().as_u16()));
    }
    Ok(response.text().await.map(Some)?)
}

/// Resolves the AWS credentials used for signing.
///
/// When both `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` are readable from
/// the environment they are returned, together with `AWS_SESSION_TOKEN` if it
/// is set. Otherwise two GET requests are made: one to `url` for the role
/// name, then one to `{url}/{role_name}` whose JSON body is deserialized into
/// the credentials. `temporary_session_token`, when present, is attached to
/// both requests as the IMDSv2 session token header.
///
/// # Errors
/// Returns `MissingSecurityCredentialsURL` when `url` is `None` and the
/// environment does not supply credentials,
/// `UnexpectedStatusOnGetRoleName` / `UnexpectedStatusOnGetCredentials` for
/// non-success responses, and propagates HTTP and decoding errors.
async fn get_security_credentials(
    temporary_session_token: &Option<String>,
    url: &Option<String>,
) -> Result<AWSSecurityCredentials, Error> {
    // Read each variable once and bind the value. The previous form checked
    // `is_ok()` and then unwrapped a second lookup, which could panic if the
    // environment changed in between.
    if let (Ok(access_key_id), Ok(secret_access_key)) = (var(AWS_ACCESS_KEY_ID), var(AWS_SECRET_ACCESS_KEY)) {
        return Ok(AWSSecurityCredentials {
            access_key_id,
            secret_access_key,
            token: var(AWS_SESSION_TOKEN).ok(),
        });
    }

    // get metadata role name
    let url = url.as_ref().ok_or(Error::MissingSecurityCredentialsURL)?;
    let client = default_http_client();
    let mut builder = client.get(url);
    if let Some(token) = temporary_session_token {
        builder = builder.header(AWS_IMDS_V2_SESSION_TOKEN_HEADER, token);
    }
    let response = builder.send().await?;
    if !response.status().is_success() {
        return Err(Error::UnexpectedStatusOnGetRoleName(response.status().as_u16()));
    }
    let role_name = response.text().await?;

    // get metadata security credentials
    let url = format!("{}/{}", url, role_name);
    let mut builder = client.get(url);
    if let Some(token) = temporary_session_token {
        builder = builder.header(AWS_IMDS_V2_SESSION_TOKEN_HEADER, token);
    }
    let response = builder.send().await?;
    if !response.status().is_success() {
        return Err(Error::UnexpectedStatusOnGetCredentials(response.status().as_u16()));
    }
    let cred: AWSSecurityCredentials = response.json().await?;
    Ok(cred)
}

/// Resolves the AWS region.
///
/// `AWS_REGION` is preferred, then `AWS_DEFAULT_REGION`. When neither is
/// readable, a GET request is made to `url` and the response body minus its
/// last byte is returned. `temporary_session_token`, when present, is attached
/// as the IMDSv2 session token header.
///
/// # Errors
/// Returns `MissingRegionURL` when `url` is `None` and the environment does
/// not supply a region, `UnexpectedStatusOnGetRegion` for a non-success
/// response, and propagates HTTP errors.
async fn get_region(temporary_session_token: &Option<String>, url: &Option<String>) -> Result<String, Error> {
    // Bind whichever variable is readable in a single lookup each, rather than
    // checking first and unwrapping a second lookup that could panic.
    if let Ok(region) = var(AWS_REGION).or_else(|_| var(AWS_DEFAULT_REGION)) {
        return Ok(region);
    }

    let url = url.as_ref().ok_or(Error::MissingRegionURL)?;
    let client = default_http_client();
    let mut builder = client.get(url);
    if let Some(token) = temporary_session_token {
        builder = builder.header(AWS_IMDS_V2_SESSION_TOKEN_HEADER, token);
    }
    let response = builder.send().await?;
    if !response.status().is_success() {
        return Err(Error::UnexpectedStatusOnGetRegion(response.status().as_u16()));
    }
    let body = response.bytes().await?;

    // This endpoint will return the region in format: us-east-2b.
    // Only the us-east-2 part should be used, so the trailing zone letter is
    // dropped. `saturating_sub` keeps an empty body at length 0.
    let region_len = body.len().saturating_sub(1);
    Ok(String::from_utf8_lossy(&body[..region_len]).to_string())
}

/// Splits `(name, value)` header pairs into the two strings used for signing.
///
/// Returns `(names, lines)`: `names` is every header name joined with `;`,
/// and `lines` is the concatenation of `name:value\n` for every pair. The
/// input order is kept as-is; no sorting or normalization is done here.
fn canonical_headers<'a>(sorted_headers: &[(&'a str, &'a str)]) -> (String, String) {
    let names = sorted_headers
        .iter()
        .map(|(name, _)| *name)
        .collect::<Vec<&str>>()
        .join(";");
    let lines: String = sorted_headers
        .iter()
        .map(|(name, value)| format!("{}:{}\n", name, value))
        .collect();
    (names, lines)
}

#[cfg(test)]
mod tests {
    use time::macros::{datetime, format_description};
    use url::Url;

    use crate::credentials::CredentialsFile;
    use crate::token_source::external_account_source::aws_subject_token_source::{
        AWSSecurityCredentials, AWSSubjectTokenSource,
    };

    /// Builds a fixture `AWSSubjectTokenSource` from an inline credentials
    /// file, fixed dummy AWS keys and the region string `ap-northeast-1b`,
    /// without touching the environment or the network.
    fn create_token_source() -> AWSSubjectTokenSource {
        let raw_json = r#"{
            "type": "external_account",
            "audience": "//iam.googleapis.com/projects/myprojectnumber/locations/global/workloadIdentityPools/aws-test/providers/aws-test",
            "subject_token_type": "urn:ietf:params:aws:token-type:aws4_request",
            "service_account_impersonation_url": "https://iamcredentials.googleapis.com/test",
            "token_url": "https://sts.googleapis.com/v1/token",
            "credential_source": {
                "environment_id": "aws1",
                "region_url": "http://169.254.169.254/latest/meta-data/placement/availability-zone",
                "url": "http://169.254.169.254/latest/meta-data/iam/security-credentials",
                "regional_cred_verification_url": "https://sts.{region}.amazonaws.com?Action=GetCallerIdentity&Version=2011-06-15"
            }
        }"#;
        let parsed: CredentialsFile = serde_json::from_str(raw_json).unwrap();
        let region = String::from("ap-northeast-1b");

        // Same `{region}` substitution that `AWSSubjectTokenSource::new` performs.
        let url_template = parsed
            .credential_source
            .unwrap()
            .regional_cred_verification_url
            .unwrap();
        let subject_token_url = Url::parse(&url_template.replace("{region}", &region)).unwrap();

        let credentials = AWSSecurityCredentials {
            access_key_id: String::from("AccessKeyId"),
            secret_access_key: String::from("SecretAccessKey"),
            token: Some(String::from("SecurityToken")),
        };

        AWSSubjectTokenSource {
            subject_token_url,
            target_resource: parsed.audience,
            credentials,
            region,
        }
    }
    /// Pins the `Authorization` header produced for the fixture source at a
    /// fixed instant with all four signed headers present.
    #[test]
    fn test_create_auth_header() {
        let source = create_token_source();
        let fixed_now = datetime!(2022-12-31 00:00:00).assume_utc();
        let amz_date = fixed_now
            .format(&format_description!("[year][month][day]T[hour][minute][second]Z"))
            .unwrap();

        let host = source.subject_token_url.host_str().unwrap_or("");
        let security_token = source.credentials.token.as_deref().unwrap();
        let target_resource = source.target_resource.as_deref().unwrap();
        let headers: [(&str, &str); 4] = [
            ("host", host),
            ("x-amz-date", amz_date.as_str()),
            ("x-amz-security-token", security_token),
            ("x-goog-cloud-target-resource", target_resource),
        ];

        let expected = "AWS4-HMAC-SHA256 Credential=AccessKeyId/20221231/ap-northeast-1b/sts/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token;x-goog-cloud-target-resource, Signature=168a40df8b7c11fb0588a13cada1443e31e4736de702232f9a2177b26edda21c";
        let actual = source.create_auth_header("POST", &fixed_now, &headers).unwrap();
        assert_eq!(actual, expected);
    }

    /// Pins the full percent-encoded subject token produced for the fixture
    /// source at a fixed instant.
    #[test]
    fn test_create_subject_token() {
        let source = create_token_source();
        let now = datetime!(2022-12-31 00:00:00).assume_utc();
        let expected = "%7B%22url%22%3A%22https%3A%2F%2Fsts%2Eap%2Dnortheast%2D1b%2Eamazonaws%2Ecom%2F%3FAction%3DGetCallerIdentity%26Version%3D2011%2D06%2D15%22%2C%22method%22%3A%22POST%22%2C%22headers%22%3A%5B%7B%22key%22%3A%22Authorization%22%2C%22value%22%3A%22AWS4%2DHMAC%2DSHA256%20Credential%3DAccessKeyId%2F20221231%2Fap%2Dnortheast%2D1b%2Fsts%2Faws4%5Frequest%2C%20SignedHeaders%3Dhost%3Bx%2Damz%2Ddate%3Bx%2Damz%2Dsecurity%2Dtoken%3Bx%2Dgoog%2Dcloud%2Dtarget%2Dresource%2C%20Signature%3D168a40df8b7c11fb0588a13cada1443e31e4736de702232f9a2177b26edda21c%22%7D%2C%7B%22key%22%3A%22host%22%2C%22value%22%3A%22sts%2Eap%2Dnortheast%2D1b%2Eamazonaws%2Ecom%22%7D%2C%7B%22key%22%3A%22x%2Damz%2Ddate%22%2C%22value%22%3A%2220221231T000000Z%22%7D%2C%7B%22key%22%3A%22x%2Damz%2Dsecurity%2Dtoken%22%2C%22value%22%3A%22SecurityToken%22%7D%2C%7B%22key%22%3A%22x%2Dgoog%2Dcloud%2Dtarget%2Dresource%22%2C%22value%22%3A%22%2F%2Fiam%2Egoogleapis%2Ecom%2Fprojects%2Fmyprojectnumber%2Flocations%2Fglobal%2FworkloadIdentityPools%2Faws%2Dtest%2Fproviders%2Faws%2Dtest%22%7D%5D%7D";

        // `expect` puts the error into the panic message. Logging through
        // `tracing` and then hitting `unreachable!()` hid it, because no
        // subscriber is installed in tests.
        let token = source
            .create_subject_token(now)
            .expect("create_subject_token should succeed for the fixture source");
        assert_eq!(token, expected);
    }
}