aws-mfa 1.0.25

Authenticate to AWS with MFA 🔐
Documentation
use anyhow::{anyhow, Result};
use aws_config::SdkConfig;
use aws_credential_types::provider::ProvideCredentials;
use aws_sdk_sts::config::Builder;
use aws_sdk_sts::Client;

use crate::error::Error;
use crate::error::Error::{
    GetCallerIdentityError, GetSessionTokenError, InvalidIdentity, InvalidSession, Other,
};
use crate::Credentials;

/// Creates an STS [`Client`] from the shared SDK `config`, replacing its
/// credentials provider with `provider`.
///
/// All other settings (region, retry, HTTP client, ...) are taken from
/// `config` unchanged.
pub fn get_client(config: &SdkConfig, provider: impl ProvideCredentials + 'static) -> Client {
    Client::from_conf(
        Builder::from(config)
            .credentials_provider(provider)
            .build(),
    )
}

/// Builds the ARN of the caller's MFA device from their STS caller identity.
///
/// Calls `sts:GetCallerIdentity` through `client`, then combines the caller's
/// partition and account with `identifier` into
/// `arn:{partition}:iam::{account}:mfa/{identifier}`.
///
/// When `identifier` is `None`, the text after the last `/` of the caller ARN
/// is used (for an IAM user ARN such as
/// `arn:aws:iam::123456789012:user/alice` that is the user name `alice`).
///
/// # Errors
///
/// - `Error::GetCallerIdentityError` if the STS call fails.
/// - `Error::InvalidIdentity` if the response lacks the account or the ARN.
/// - `Error::Other` if the partition cannot be parsed from the caller ARN, or
///   if `identifier` is `None` and the caller ARN contains no `/`-separated
///   name to derive one from (e.g. an ARN ending in `:root`).
pub async fn get_mfa_device_arn(
    client: &Client,
    identifier: Option<String>,
) -> Result<String, Error> {
    let identity = client
        .get_caller_identity()
        .send()
        .await
        .map_err(|e| GetCallerIdentityError(Box::new(e)))?;

    let account = identity
        .account()
        .ok_or_else(|| InvalidIdentity(String::from("account")))?;

    let arn = identity
        .arn()
        .ok_or_else(|| InvalidIdentity(String::from("arn")))?;

    // Reuse the caller's partition rather than assuming `aws`, so accounts in
    // other partitions (e.g. `aws-cn`, `aws-us-gov`) get a matching device ARN.
    let partition = arn_partition(arn)
        .ok_or_else(|| Other(anyhow!("could not extract partition in arn `{}`", arn)))?;

    // Only derive a name from the caller ARN when none was supplied, so an
    // explicit identifier works even for ARNs without a `/` segment.
    let identifier = match identifier {
        Some(i) => i,
        None => arn_resource_name(arn)
            .map(String::from)
            .ok_or_else(|| Other(anyhow!("could not extract user in arn `{}`", arn)))?,
    };

    Ok(format!("arn:{partition}:iam::{account}:mfa/{identifier}"))
}

/// Returns the partition field of an ARN (`arn:<partition>:...`), if present and non-empty.
fn arn_partition(arn: &str) -> Option<&str> {
    let (partition, _) = arn.strip_prefix("arn:")?.split_once(':')?;
    (!partition.is_empty()).then_some(partition)
}

/// Returns the text after the last `/` in `arn`; `None` if there is no `/` or nothing follows it.
fn arn_resource_name(arn: &str) -> Option<&str> {
    let (_, name) = arn.rsplit_once('/')?;
    (!name.is_empty()).then_some(name)
}

/// Exchanges an MFA code for temporary session credentials.
///
/// Calls `sts:GetSessionToken` through `client` with the MFA device `arn`
/// as serial number, the one-time `code`, and the requested `duration` in
/// seconds. The returned [`Credentials`] carry the access key id, secret
/// access key, session token and the expiration as whole seconds
/// (`DateTime::secs`).
///
/// # Errors
///
/// - `Error::GetSessionTokenError` if the STS call fails.
/// - `Error::InvalidSession` if the response contains no credentials.
pub async fn get_auth_credentials(
    client: &Client,
    arn: &str,
    code: &str,
    duration: i32,
) -> Result<Credentials, Error> {
    let request = client
        .get_session_token()
        .serial_number(arn)
        .token_code(code)
        .duration_seconds(duration);

    let output = request
        .send()
        .await
        .map_err(|err| GetSessionTokenError(Box::new(err)))?;

    let Some(issued) = output.credentials() else {
        return Err(InvalidSession(String::from("credentials")));
    };

    let expires_at = issued.expiration().secs();

    Ok(Credentials::new(
        issued.access_key_id(),
        issued.secret_access_key(),
        issued.session_token(),
        expires_at,
    ))
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use aws_credential_types::Credentials;
    use aws_sdk_sts::config::Region;
    use aws_sdk_sts::{Client, Config};
    use aws_smithy_http_client::test_util::{ReplayEvent, StaticReplayClient};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::date_time::Format;
    use aws_smithy_types::DateTime;
    use http::{Method, Request, Response};

    use crate::sts::{get_auth_credentials, get_mfa_device_arn};

    /// Regional STS endpoint the mocked client is configured for.
    const STS_ENDPOINT: &str = "https://sts.eu-west-1.amazonaws.com/";

    /// Form body the SDK sends for `GetCallerIdentity`.
    const CALLER_IDENTITY_REQUEST: &str = "Action=GetCallerIdentity&Version=2011-06-15";

    /// Canned `GetCallerIdentity` response for IAM user `user_name` in account `account`.
    const CALLER_IDENTITY_RESPONSE: &str = "
        <GetCallerIdentityResponse>
            <GetCallerIdentityResult>
                <UserId>user_id</UserId>
                <Account>account</Account>
                <Arn>arn:aws:iam::account:user/user_name</Arn>
            </GetCallerIdentityResult>
        </GetCallerIdentityResponse>";

    /// Canned `GetSessionToken` response with fixed credentials.
    const SESSION_TOKEN_RESPONSE: &str = "
        <GetSessionTokenResponse>
            <GetSessionTokenResult>
                <Credentials>
                    <AccessKeyId>access_key_id</AccessKeyId>
                    <SecretAccessKey>secret_access_key</SecretAccessKey>
                    <SessionToken>session_token</SessionToken>
                    <Expiration>2022-08-31T19:55:58Z</Expiration>
                </Credentials>
            </GetSessionTokenResult>
        </GetSessionTokenResponse>";

    /// Builds one replay event: an expected POST of `request_body` to the STS
    /// endpoint, answered with HTTP 200 and `response_body`.
    fn replay_event(request_body: String, response_body: &'static str) -> Result<ReplayEvent> {
        let request = Request::builder()
            .method(Method::POST)
            .uri(STS_ENDPOINT)
            .body(SdkBody::from(request_body))?;
        let response = Response::builder()
            .status(200)
            .body(SdkBody::from(response_body))?;
        Ok(ReplayEvent::new(request, response))
    }

    /// Builds an STS client whose HTTP layer replays `event` instead of calling
    /// AWS. Returns the replay client too so tests can assert on the requests.
    fn mock_client(event: ReplayEvent) -> (Client, StaticReplayClient) {
        let replay_client = StaticReplayClient::new(vec![event]);
        let conf = Config::builder()
            .behavior_version_latest()
            .region(Region::new("eu-west-1"))
            .credentials_provider(Credentials::new("", "", None, None, ""))
            .http_client(replay_client.clone())
            .build();
        (Client::from_conf(conf), replay_client)
    }

    #[tokio::test]
    async fn test_get_mfa_device_arn_without_identifier() -> Result<()> {
        let event = replay_event(
            String::from(CALLER_IDENTITY_REQUEST),
            CALLER_IDENTITY_RESPONSE,
        )?;
        let (client, replay_client) = mock_client(event);

        let arn = get_mfa_device_arn(&client, None).await?;

        replay_client.assert_requests_match(&[]);
        assert_eq!(arn, "arn:aws:iam::account:mfa/user_name");

        Ok(())
    }

    #[tokio::test]
    async fn test_get_mfa_device_arn_with_identifier() -> Result<()> {
        let event = replay_event(
            String::from(CALLER_IDENTITY_REQUEST),
            CALLER_IDENTITY_RESPONSE,
        )?;
        let (client, replay_client) = mock_client(event);

        let arn = get_mfa_device_arn(&client, Some(String::from("device_id"))).await?;

        replay_client.assert_requests_match(&[]);
        assert_eq!(arn, "arn:aws:iam::account:mfa/device_id");

        Ok(())
    }

    #[tokio::test]
    async fn test_get_auth_credentials() -> Result<()> {
        let duration = 0;
        let arn = "arn";
        let code = "code";
        let event = replay_event(
            format!(
                "Action=GetSessionToken&Version=2011-06-15&DurationSeconds={}&SerialNumber={}&TokenCode={}",
                duration, arn, code
            ),
            SESSION_TOKEN_RESPONSE,
        )?;
        let (client, replay_client) = mock_client(event);

        let credentials = get_auth_credentials(&client, arn, code, duration).await?;

        replay_client.assert_requests_match(&[]);
        assert_eq!(credentials.access_key_id(), "access_key_id");
        assert_eq!(credentials.secret_access_key(), "secret_access_key");
        assert_eq!(credentials.session_token(), "session_token");
        assert_eq!(
            credentials.session_expiration_timestamp(),
            DateTime::from_str("2022-08-31T19:55:58Z", Format::DateTime)?.secs()
        );

        Ok(())
    }
}