use std::collections::HashMap;
use std::time::Duration;
use hmac::{Hmac, Mac};
use keyhog_core::VerificationResult;
use reqwest::Client;
use sha2::{Digest, Sha256};
use crate::verify::request::execute_request;
use crate::verify::response::read_response_body;
/// Access-key-ID prefixes accepted by `valid_aws_format`.
///
/// NOTE(review): the AWS IAM identifiers documentation lists `AKIA` (long-term)
/// and `ASIA` (temporary) as access-key-ID prefixes, while `AROA`/`AIDA`/`AGPA`
/// appear there as unique-ID prefixes for roles/users/groups, not signing keys.
/// Confirm whether accepting those three here is intentional.
const AWS_VALID_ACCESS_KEY_PREFIXES: &[&str] = &["AKIA", "ASIA", "AROA", "AIDA", "AGPA"];
/// Exact length, in bytes, that `valid_aws_format` requires of an access key ID.
const AWS_ACCESS_KEY_LEN: usize = 20;
/// Minimum length, in bytes, that `valid_aws_format` requires of a secret key.
const AWS_MIN_SECRET_KEY_LEN: usize = 40;
/// Verifies an AWS key pair by sending a signed STS `GetCallerIdentity` request.
///
/// `access_key`, `secret_key` and the optional `session_token_template` are
/// resolved through `crate::interpolate::resolve_field` against the matched
/// `credential` and its `companions`. An empty resolved session token is
/// treated as absent. `region` is used verbatim to build the endpoint host
/// `sts.{region}.amazonaws.com`.
///
/// Always returns `RequestBuildResult::Final`:
/// * empty resolved secret → `Unverifiable`;
/// * key pair fails `valid_aws_format` → `Dead` with `format_valid=false`;
/// * temporary (`ASIA…`) access key without a session token → `Unverifiable`
///   with `format_valid=true`;
/// * region empty, longer than 30 chars, or containing anything other than
///   ASCII letters/digits/`-` → `Error("invalid AWS region")`;
/// * otherwise the outcome of `build_sigv4_request`. A signing or transport
///   failure becomes a transient `Error` with `format_valid=true`.
pub(crate) async fn build_aws_probe(
    access_key: &str,
    secret_key: &str,
    session_token_template: &Option<String>,
    region: &str,
    credential: &str,
    companions: &HashMap<String, String>,
    timeout: Duration,
    client: &Client,
) -> super::request::RequestBuildResult {
    let access_key = crate::interpolate::resolve_field(access_key, credential, companions);
    let secret_key = crate::interpolate::resolve_field(secret_key, credential, companions);
    let session_token = session_token_template
        .as_ref()
        .map(|t| crate::interpolate::resolve_field(t, credential, companions))
        .filter(|t| !t.is_empty());

    // Without a secret there is nothing to sign with; we cannot say anything
    // about the key.
    if secret_key.is_empty() {
        return super::request::RequestBuildResult::Final {
            result: VerificationResult::Unverifiable,
            metadata: HashMap::new(),
            transient: false,
        };
    }
    if !valid_aws_format(&access_key, &secret_key) {
        return super::request::RequestBuildResult::Final {
            result: VerificationResult::Dead,
            metadata: HashMap::from([("format_valid".into(), "false".into())]),
            transient: false,
        };
    }
    // AWS requires the session token for temporary (`ASIA…`) credentials.
    // Signing without it can only produce an authentication failure, which
    // would otherwise be misreported as `Dead` for a possibly-live key.
    if session_token.is_none() && access_key.starts_with("ASIA") {
        return super::request::RequestBuildResult::Final {
            result: VerificationResult::Unverifiable,
            metadata: HashMap::from([("format_valid".into(), "true".into())]),
            transient: false,
        };
    }

    // The region is interpolated into a hostname, so restrict it to a small,
    // safe character set and length.
    let region_ok = !region.is_empty()
        && region.len() <= 30
        && region.chars().all(|c| c.is_ascii_alphanumeric() || c == '-');
    if !region_ok {
        return super::request::RequestBuildResult::Final {
            result: VerificationResult::Error("invalid AWS region".into()),
            metadata: HashMap::new(),
            transient: false,
        };
    }

    let host = format!("sts.{region}.amazonaws.com");
    let url = format!("https://{host}/");
    let body = "Action=GetCallerIdentity&Version=2011-06-15";
    let outcome = build_sigv4_request(
        client,
        &url,
        &host,
        body,
        &access_key,
        &secret_key,
        session_token.as_deref(),
        region,
        "sts",
        timeout,
    )
    .await;

    match outcome {
        Ok((result, metadata, transient)) => super::request::RequestBuildResult::Final {
            result,
            metadata,
            transient,
        },
        Err(error_msg) => super::request::RequestBuildResult::Final {
            result: VerificationResult::Error(error_msg),
            metadata: HashMap::from([("format_valid".into(), "true".into())]),
            transient: true,
        },
    }
}
/// Offline shape check for an AWS key pair, run before any network request.
///
/// Returns `true` only when both of these hold:
/// * the access key ID is exactly `AWS_ACCESS_KEY_LEN` bytes of ASCII letters
///   and digits and starts with one of `AWS_VALID_ACCESS_KEY_PREFIXES`;
/// * the secret is at least `AWS_MIN_SECRET_KEY_LEN` bytes, each an ASCII
///   letter, digit, `+`, `/` or `=`.
pub(crate) fn valid_aws_format(access_key: &str, secret_key: &str) -> bool {
    let has_known_prefix = AWS_VALID_ACCESS_KEY_PREFIXES
        .iter()
        .any(|prefix| access_key.starts_with(prefix));
    let access_key_ok = has_known_prefix
        && access_key.len() == AWS_ACCESS_KEY_LEN
        && access_key.bytes().all(|b| b.is_ascii_alphanumeric());
    if !access_key_ok {
        return false;
    }

    secret_key.len() >= AWS_MIN_SECRET_KEY_LEN
        && secret_key
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'/' | b'='))
}
async fn build_sigv4_request(
client: &Client,
url: &str,
host: &str,
body: &str,
access_key: &str,
secret_key: &str,
session_token: Option<&str>,
region: &str,
service: &str,
timeout: Duration,
) -> std::result::Result<(VerificationResult, HashMap<String, String>, bool), String> {
use std::time::{SystemTime, UNIX_EPOCH};
let now_secs = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| e.to_string())?
.as_secs();
let (date_stamp, amz_date) = format_sigv4_timestamps(now_secs);
let date_stamp = &date_stamp;
let amz_date = &amz_date;
let canonical_uri = "/";
let canonical_querystring = "";
let canonical_headers = format!("host:{host}\nx-amz-date:{amz_date}\n");
let signed_headers = "host;x-amz-date";
let payload_hash = hex::encode(Sha256::digest(body.as_bytes()));
let canonical_request = format!(
"POST\n{canonical_uri}\n{canonical_querystring}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
);
let algorithm = "AWS4-HMAC-SHA256";
let credential_scope = format!("{date_stamp}/{region}/{service}/aws4_request");
let string_to_sign = format!(
"{algorithm}\n{amz_date}\n{credential_scope}\n{}",
hex::encode(Sha256::digest(canonical_request.as_bytes()))
);
let signing_key = get_signature_key(secret_key, date_stamp, region, service)?;
let signature = hex::encode(hmac_sha256(&signing_key, &string_to_sign)?);
let mut auth_header = format!(
"{algorithm} Credential={access_key}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}"
);
if let Some(token) = session_token {
auth_header.push_str(&format!(", X-Amz-Security-Token={token}"));
}
let mut request = client
.post(url)
.header("Authorization", auth_header)
.header("x-amz-date", amz_date)
.header("Content-Type", "application/x-www-form-urlencoded")
.body(body.to_string())
.timeout(timeout);
if let Some(token) = session_token {
request = request.header("x-amz-security-token", token);
}
crate::rate_limit::get_rate_limiter().wait(service).await;
let response = execute_request(request)
.await
.map_err(|e| format!("{:?}", e.result))?;
let status = response.status().as_u16();
let resp_body = read_response_body(response)
.await
.map_err(|e| format!("{:?}", e.result))?;
if resp_body.contains("RequestTimeTooSkewed") || resp_body.contains("SignatureDoesNotMatch") {
tracing::warn!(
status,
"AWS verification failure indicates clock skew or invalid signature. Check system time."
);
}
if status == 200 {
let mut metadata = HashMap::new();
if let Ok(json) = serde_json::from_str::<serde_json::Value>(&resp_body) {
if let Some(arn) =
json.pointer("/GetCallerIdentityResponse/GetCallerIdentityResult/Arn")
{
metadata.insert("arn".into(), arn.as_str().unwrap_or("").into());
}
if let Some(account) =
json.pointer("/GetCallerIdentityResponse/GetCallerIdentityResult/Account")
{
metadata.insert("account_id".into(), account.as_str().unwrap_or("").into());
}
}
Ok((VerificationResult::Live, metadata, false))
} else if status == 403 {
Ok((VerificationResult::Dead, HashMap::new(), false))
} else {
Ok((VerificationResult::RateLimited, HashMap::new(), true))
}
}
fn hmac_sha256(key: &[u8], data: &str) -> std::result::Result<Vec<u8>, String> {
type HmacSha256 = Hmac<sha2::Sha256>;
let mut mac = HmacSha256::new_from_slice(key)
.map_err(|error| format!("failed to create AWS HMAC signer: {error}"))?;
mac.update(data.as_bytes());
Ok(mac.finalize().into_bytes().to_vec())
}
fn get_signature_key(
key: &str,
date_stamp: &str,
region_name: &str,
service_name: &str,
) -> std::result::Result<Vec<u8>, String> {
let k_date = hmac_sha256(format!("AWS4{key}").as_bytes(), date_stamp)?;
let k_region = hmac_sha256(&k_date, region_name)?;
let k_service = hmac_sha256(&k_region, service_name)?;
hmac_sha256(&k_service, "aws4_request")
}
/// Formats a Unix timestamp (seconds, UTC) into the two SigV4 time strings:
/// the date stamp `YYYYMMDD` and the request timestamp `YYYYMMDDTHHMMSSZ`.
///
/// The calendar conversion is the days-to-civil-date algorithm: day 0 is
/// shifted to 0000-03-01 so the leap day falls at the end of each computed
/// year, and dates are resolved within 400-year eras.
fn format_sigv4_timestamps(unix_secs: u64) -> (String, String) {
    const SECS_PER_DAY: u64 = 86_400;
    const DAYS_PER_ERA: i64 = 146_097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT_DAYS: i64 = 719_468;

    let day_count = (unix_secs / SECS_PER_DAY) as i64;
    let time_of_day = (unix_secs % SECS_PER_DAY) as u32;
    let (hh, mm, ss) = (time_of_day / 3600, time_of_day / 60 % 60, time_of_day % 60);

    let shifted = day_count + EPOCH_SHIFT_DAYS;
    // Floor division, so negative day counts would land in the earlier era.
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = (shifted - era * DAYS_PER_ERA) as u32;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March (0 = March, ..., 11 = February).
    let month_index = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * month_index + 2) / 5 + 1;
    let month = if month_index < 10 {
        month_index + 3
    } else {
        month_index - 9
    };
    // January and February belong to the following civil year.
    let year = i64::from(year_of_era) + era * 400 + i64::from(month <= 2);

    let date_stamp = format!("{year:04}{month:02}{day:02}");
    let amz_date = format!("{date_stamp}T{hh:02}{mm:02}{ss:02}Z");
    (date_stamp, amz_date)
}
#[cfg(test)]
mod tests {
    use super::format_sigv4_timestamps;

    /// Asserts both the `YYYYMMDD` date stamp and the full `YYYYMMDDTHHMMSSZ` timestamp.
    fn expect_stamps(unix_secs: u64, date_stamp: &str, amz_date: &str) {
        let (d, a) = format_sigv4_timestamps(unix_secs);
        assert_eq!(d, date_stamp);
        assert_eq!(a, amz_date);
    }

    /// Asserts only the `YYYYMMDD` date stamp.
    fn expect_date(unix_secs: u64, date_stamp: &str) {
        let (d, _) = format_sigv4_timestamps(unix_secs);
        assert_eq!(d, date_stamp);
    }

    #[test]
    fn epoch_zero() {
        expect_stamps(0, "19700101", "19700101T000000Z");
    }

    #[test]
    fn known_aws_example() {
        expect_stamps(1_704_067_200, "20240101", "20240101T000000Z");
    }

    #[test]
    fn leap_year_feb_29() {
        expect_stamps(1_709_210_096, "20240229", "20240229T123456Z");
    }

    #[test]
    fn year_end_to_year_start() {
        expect_stamps(1_767_225_599, "20251231", "20251231T235959Z");
        expect_stamps(1_767_225_600, "20260101", "20260101T000000Z");
    }

    #[test]
    fn non_leap_year_feb_28_to_mar_1() {
        expect_stamps(1_740_787_199, "20250228", "20250228T235959Z");
        expect_stamps(1_740_787_200, "20250301", "20250301T000000Z");
    }

    #[test]
    fn century_year_2100_is_not_leap() {
        const FEB_28_2100: u64 = 4_107_456_000;
        expect_date(FEB_28_2100, "21000228");
        expect_date(FEB_28_2100 + 86_400, "21000301");
    }

    #[test]
    fn year_2000_was_leap() {
        const FEB_29_2000: u64 = 951_782_400;
        expect_date(FEB_29_2000, "20000229");
        expect_date(FEB_29_2000 + 86_400, "20000301");
    }
}