tsafe-cli 1.0.26

Secrets runtime for developers — inject credentials into processes via exec, never into shell history or .env files
//! AWS runtime config and credential loading.

use super::error::AwsError;

/// Runtime config for the AWS Secrets Manager client.
///
/// Built either from the process environment ([`AwsConfig::from_env`]) or from
/// explicit values ([`AwsConfig::with_endpoint`]).
#[derive(Clone, Debug)]
pub struct AwsConfig {
    /// AWS region identifier, e.g. `us-east-1`.
    pub region: String,
    /// Endpoint URL. Defaults to `https://secretsmanager.{region}.amazonaws.com`
    /// (`.amazonaws.com.cn` for the China `cn-*` regions).
    /// Can be overridden for testing or custom endpoints (e.g. LocalStack).
    pub endpoint: String,
}

impl AwsConfig {
    /// Load from `AWS_DEFAULT_REGION` or `AWS_REGION` (checked in that order).
    ///
    /// A variable that is unset, empty, or whitespace-only is skipped, so an
    /// accidentally blank `AWS_DEFAULT_REGION` no longer hides a valid
    /// `AWS_REGION` or produces a malformed endpoint. Surrounding whitespace
    /// is trimmed from the chosen value.
    ///
    /// # Errors
    ///
    /// Returns [`AwsError::Config`] when neither variable holds a usable
    /// region. Fails fast rather than guessing a default.
    pub fn from_env() -> Result<Self, AwsError> {
        let region = ["AWS_DEFAULT_REGION", "AWS_REGION"]
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .map(|raw| raw.trim().to_owned())
            .find(|candidate| !candidate.is_empty())
            .ok_or_else(|| {
                AwsError::Config("AWS_DEFAULT_REGION or AWS_REGION is not set".into())
            })?;
        let endpoint = Self::default_endpoint(&region);
        Ok(Self { region, endpoint })
    }

    /// Construct with an explicit endpoint URL (for testing / LocalStack).
    ///
    /// Both values are stored as given; no validation is performed.
    pub fn with_endpoint(region: impl Into<String>, endpoint: impl Into<String>) -> Self {
        Self {
            region: region.into(),
            endpoint: endpoint.into(),
        }
    }

    /// Default regional Secrets Manager endpoint for `region`.
    fn default_endpoint(region: &str) -> String {
        // The China partition is served from a different DNS suffix.
        let dns_suffix = if region.starts_with("cn-") {
            "amazonaws.com.cn"
        } else {
            "amazonaws.com"
        };
        format!("https://secretsmanager.{region}.{dns_suffix}")
    }
}

/// AWS credentials loaded from environment variables, ECS task role, or IMDSv2.
///
/// `Debug` is implemented by hand so that the secret key and session token
/// are never written to logs or panic messages; only the access key ID is shown.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID (`AWS_ACCESS_KEY_ID`, or `AccessKeyId` in a metadata response).
    pub access_key_id: String,
    /// Secret access key (`AWS_SECRET_ACCESS_KEY`, or `SecretAccessKey` in a
    /// metadata response). Redacted in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token (`AWS_SESSION_TOKEN`, or `Token` in a metadata
    /// response). Redacted in `Debug` output.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Keep the Some/None distinction visible without exposing the value.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl AwsCredentials {
    /// Load credentials. Strategy (in order):
    /// 1. Static env vars: `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY`
    ///    (plus optional `AWS_SESSION_TOKEN`)
    /// 2. ECS task role: `AWS_CONTAINER_CREDENTIALS_RELATIVE_URI`
    /// 3. IMDSv2: EC2 instance role
    ///
    /// An environment variable that is set but empty is treated as unset, so
    /// a blank `AWS_ACCESS_KEY_ID` falls through to the next strategy instead
    /// of yielding unusable credentials, and a blank `AWS_SESSION_TOKEN`
    /// becomes `None` rather than `Some("")`.
    ///
    /// # Errors
    ///
    /// Propagates the error from the ECS or IMDSv2 fetch when static
    /// credentials are not available and the selected fallback fails.
    pub fn from_env_or_imds() -> Result<Self, AwsError> {
        // Read a variable, discarding unset, non-Unicode, and empty values.
        let non_empty =
            |name: &str| std::env::var(name).ok().filter(|value| !value.is_empty());

        // 1. Static env vars
        if let (Some(access_key_id), Some(secret_access_key)) = (
            non_empty("AWS_ACCESS_KEY_ID"),
            non_empty("AWS_SECRET_ACCESS_KEY"),
        ) {
            return Ok(Self {
                access_key_id,
                secret_access_key,
                session_token: non_empty("AWS_SESSION_TOKEN"),
            });
        }

        // 2. ECS task role credentials
        if let Some(relative) = non_empty("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") {
            return fetch_ecs_credentials(&relative);
        }

        // 3. IMDSv2 (EC2 instance role)
        fetch_imdsv2_credentials()
    }
}

/// Build the blocking HTTP agent used by the ECS and IMDSv2 credential
/// fetchers: 5-second connect timeout, 10-second overall timeout.
fn http_agent() -> ureq::Agent {
    use std::time::Duration;

    let connect_limit = Duration::from_secs(5);
    let overall_limit = Duration::from_secs(10);

    let builder = ureq::AgentBuilder::new();
    let builder = builder.timeout_connect(connect_limit);
    let builder = builder.timeout(overall_limit);
    builder.build()
}

/// Fetch credentials from the container credentials endpoint.
///
/// `relative_uri` is appended verbatim to `http://169.254.170.2`, and the JSON
/// body of the response is handed to `parse_credentials_response`.
///
/// A failed request is reported as `AwsError::Auth`; a body that cannot be
/// read as JSON is reported as `AwsError::Transport`.
fn fetch_ecs_credentials(relative_uri: &str) -> Result<AwsCredentials, AwsError> {
    let endpoint = format!("http://169.254.170.2{relative_uri}");
    let client = http_agent();

    let response = client
        .get(&endpoint)
        .call()
        .map_err(|cause| AwsError::Auth(format!("ECS credentials request failed: {cause}")))?;

    let body: serde_json::Value = response
        .into_json()
        .map_err(|cause| AwsError::Transport(cause.to_string()))?;

    parse_credentials_response(&body)
}

/// Fetch EC2 instance-role credentials from the metadata service at
/// `169.254.169.254` using three requests: obtain a session token, list the
/// IAM role name, then read that role's credentials document.
///
/// Request failures are reported as `AwsError::Auth`; failures while reading
/// a response body are reported as `AwsError::Transport`.
fn fetch_imdsv2_credentials() -> Result<AwsCredentials, AwsError> {
    // Every body-read failure maps to the same transport error.
    fn transport<E: std::fmt::Display>(cause: E) -> AwsError {
        AwsError::Transport(cause.to_string())
    }

    let client = http_agent();

    // 1. PUT for a session token; the TTL header requests 21600 seconds.
    let token_response = client
        .put("http://169.254.169.254/latest/api/token")
        .set("X-aws-ec2-metadata-token-ttl-seconds", "21600")
        .call()
        .map_err(|cause| {
            AwsError::Auth(format!(
                "IMDSv2 unreachable and no AWS credentials set \
                 (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY): {cause}"
            ))
        })?;
    let token_body = token_response.into_string().map_err(transport)?;
    let session = token_body.trim();

    // 2. The security-credentials listing yields the IAM role name.
    let listing_response = client
        .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
        .set("X-aws-ec2-metadata-token", session)
        .call()
        .map_err(|cause| AwsError::Auth(format!("IMDS role name request failed: {cause}")))?;
    let listing_body = listing_response.into_string().map_err(transport)?;
    let role = listing_body.trim();

    // 3. Read the credentials document for that role.
    let credentials_url =
        format!("http://169.254.169.254/latest/meta-data/iam/security-credentials/{role}");
    let credentials_response = client
        .get(&credentials_url)
        .set("X-aws-ec2-metadata-token", session)
        .call()
        .map_err(|cause| AwsError::Auth(format!("IMDS credentials request failed: {cause}")))?;
    let document: serde_json::Value = credentials_response.into_json().map_err(transport)?;

    parse_credentials_response(&document)
}

/// Extract credentials from a JSON credentials document.
///
/// `AccessKeyId` and `SecretAccessKey` must be present as strings; a missing
/// or non-string value yields `AwsError::Auth` (checked in that order).
/// `Token` is optional and becomes `session_token`.
fn parse_credentials_response(resp: &serde_json::Value) -> Result<AwsCredentials, AwsError> {
    // Owned copy of a top-level field, or None when absent or not a string.
    let text_field = |key: &str| -> Option<String> {
        resp.get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };

    let access_key_id = match text_field("AccessKeyId") {
        Some(value) => value,
        None => {
            return Err(AwsError::Auth(
                "credentials response missing 'AccessKeyId'".into(),
            ))
        }
    };

    let secret_access_key = match text_field("SecretAccessKey") {
        Some(value) => value,
        None => {
            return Err(AwsError::Auth(
                "credentials response missing 'SecretAccessKey'".into(),
            ))
        }
    };

    Ok(AwsCredentials {
        access_key_id,
        secret_access_key,
        session_token: text_field("Token"),
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Override value that leaves a variable unset while a test closure runs.
    const UNSET: Option<&str> = None;

    // ---- AwsConfig::from_env ----

    #[test]
    fn from_env_missing_region_returns_config_error() {
        let overrides = [("AWS_DEFAULT_REGION", UNSET), ("AWS_REGION", UNSET)];
        let outcome = temp_env::with_vars(overrides, AwsConfig::from_env);
        assert!(matches!(outcome, Err(AwsError::Config(_))));
    }

    #[test]
    fn from_env_uses_aws_default_region() {
        let cfg = temp_env::with_var("AWS_DEFAULT_REGION", Some("us-east-1"), AwsConfig::from_env)
            .unwrap();
        assert_eq!(cfg.region, "us-east-1");
        assert_eq!(
            cfg.endpoint,
            "https://secretsmanager.us-east-1.amazonaws.com"
        );
    }

    #[test]
    fn from_env_falls_back_to_aws_region() {
        let overrides = [
            ("AWS_DEFAULT_REGION", UNSET),
            ("AWS_REGION", Some("eu-west-1")),
        ];
        let cfg = temp_env::with_vars(overrides, AwsConfig::from_env).unwrap();
        assert_eq!(cfg.region, "eu-west-1");
    }

    // ---- AwsCredentials::from_env_or_imds (static env path only) ----

    #[test]
    fn static_credentials_from_env() {
        let overrides = [
            ("AWS_ACCESS_KEY_ID", Some("AKIAIOSFODNN7EXAMPLE")),
            ("AWS_SECRET_ACCESS_KEY", Some("secret")),
            ("AWS_SESSION_TOKEN", Some("tok")),
        ];
        let loaded = temp_env::with_vars(overrides, AwsCredentials::from_env_or_imds).unwrap();
        assert_eq!(loaded.access_key_id, "AKIAIOSFODNN7EXAMPLE");
        assert_eq!(loaded.secret_access_key, "secret");
        assert_eq!(loaded.session_token.as_deref(), Some("tok"));
    }

    #[test]
    fn static_credentials_no_session_token() {
        let overrides = [
            ("AWS_ACCESS_KEY_ID", Some("AKIAIOSFODNN7EXAMPLE")),
            ("AWS_SECRET_ACCESS_KEY", Some("secret")),
            ("AWS_SESSION_TOKEN", UNSET),
        ];
        let loaded = temp_env::with_vars(overrides, AwsCredentials::from_env_or_imds).unwrap();
        assert!(loaded.session_token.is_none());
    }

    // ---- parse_credentials_response ----

    #[test]
    fn parse_credentials_response_success() {
        let document = serde_json::json!({
            "AccessKeyId": "ASIA...",
            "SecretAccessKey": "wJalrXUtn",
            "Token": "session-tok"
        });
        let parsed = parse_credentials_response(&document).unwrap();
        assert_eq!(parsed.access_key_id, "ASIA...");
        assert_eq!(parsed.session_token.as_deref(), Some("session-tok"));
    }

    #[test]
    fn parse_credentials_response_missing_key_returns_auth_error() {
        // No "AccessKeyId" field at all.
        let document = serde_json::json!({"SecretAccessKey": "secret"});
        let outcome = parse_credentials_response(&document);
        assert!(matches!(outcome, Err(AwsError::Auth(_))));
    }
}