use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
/// HMAC instantiated with SHA-256 as the underlying hash.
type HmacSha256 = Hmac<Sha256>;
/// AWS credentials used to sign requests.
pub struct AwsCredentials {
    /// Access key id; embedded in the `Credential=` part of the authorization header.
    pub access_key_id: String,
    /// Secret access key; used only to derive the signing key.
    pub secret_access_key: String,
    /// Optional session token; when present it is sent (and signed) as
    /// `x-amz-security-token`.
    pub session_token: Option<String>,
}

impl AwsCredentials {
    /// Builds credentials from the process environment.
    ///
    /// Reads `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY` and
    /// `AWS_SESSION_TOKEN`. A variable that is unset, not valid Unicode, or
    /// empty is treated as absent.
    ///
    /// Returns `None` when either the access key id or the secret access key
    /// is absent; the session token is optional.
    pub fn from_env() -> Option<Self> {
        // One lookup rule shared by all three variables: present AND non-empty.
        let non_empty = |name: &str| match std::env::var(name) {
            Ok(value) if !value.is_empty() => Some(value),
            _ => None,
        };

        // Field initialisers run top to bottom, so a missing access key id
        // short-circuits before the secret is even looked up.
        Some(Self {
            access_key_id: non_empty("AWS_ACCESS_KEY_ID")?,
            secret_access_key: non_empty("AWS_SECRET_ACCESS_KEY")?,
            session_token: non_empty("AWS_SESSION_TOKEN"),
        })
    }
}
/// Computes HMAC-SHA256 of `data` under `key` and returns the raw tag bytes.
fn hmac(key: &[u8], data: &[u8]) -> Vec<u8> {
    let mut signer = HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
    signer.update(data);
    let tag = signer.finalize().into_bytes();
    tag.to_vec()
}
/// Returns the SHA-256 digest of `data`, hex-encoded via `hex::encode`.
fn sha256_hex(data: &[u8]) -> String {
    // One-shot digest: equivalent to new + update + finalize.
    hex::encode(Sha256::digest(data))
}
/// Percent-encodes a single URI path segment.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are copied through unchanged.
/// Every other byte of the UTF-8 encoding of `s` (including `/` and `:`) is
/// emitted as `%XX` with uppercase hex digits.
pub fn uri_encode_segment(s: &str) -> String {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(s.len());
    for b in s.bytes() {
        match b {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                out.push(char::from(b));
            }
            _ => {
                // Write the escape directly instead of formatting into a
                // temporary `String` for every escaped byte.
                out.push('%');
                out.push(char::from(UPPER_HEX[usize::from(b >> 4)]));
                out.push(char::from(UPPER_HEX[usize::from(b & 0x0F)]));
            }
        }
    }
    out
}
/// Signs a request with AWS Signature Version 4 and returns the headers to send.
///
/// # Parameters
/// - `creds`: access key id, secret key and optional session token.
/// - `region`, `service`: components of the credential scope.
/// - `method`: HTTP method, placed verbatim in the canonical request.
/// - `canonical_path`, `canonical_query`: already-canonicalised path and query
///   string; they are inserted verbatim, no encoding is applied here.
/// - `base_headers`: headers the caller will send; all of them are signed.
/// - `body`: request payload; its SHA-256 becomes the payload hash.
/// - `amz_date`: timestamp used for `x-amz-date` and the string to sign
///   (e.g. `20150830T123600Z`).
/// - `date_stamp`: date used in the credential scope (e.g. `20150830`).
///
/// # Returns
/// `base_headers` unchanged, followed by `x-amz-date`, then
/// `x-amz-security-token` (only when a session token is present), then
/// `authorization`.
///
/// # Canonicalisation
/// For signing only, header names are lowercased, values are trimmed with
/// internal whitespace runs collapsed to a single space, and headers sharing a
/// name are merged into one entry whose values are comma-joined in order of
/// appearance. The returned header values are not modified.
// The parameters are the independent inputs of the SigV4 algorithm; the public
// signature is kept flat so existing callers are unaffected.
#[allow(clippy::too_many_arguments)]
pub fn signed_headers(
    creds: &AwsCredentials,
    region: &str,
    service: &str,
    method: &str,
    canonical_path: &str,
    canonical_query: &str,
    base_headers: &[(String, String)],
    body: &[u8],
    amz_date: &str,
    date_stamp: &str,
) -> Vec<(String, String)> {
    let mut headers: Vec<(String, String)> = base_headers.to_vec();
    headers.push(("x-amz-date".into(), amz_date.to_string()));
    if let Some(tok) = &creds.session_token {
        headers.push(("x-amz-security-token".into(), tok.clone()));
    }

    // The map keeps names sorted and unique; each entry collects the
    // normalised values for that name in the order they were supplied.
    let mut canon: std::collections::BTreeMap<String, Vec<String>> =
        std::collections::BTreeMap::new();
    for (name, value) in &headers {
        // Trim both ends and collapse interior whitespace runs to one space.
        let normalized = value.split_whitespace().collect::<Vec<_>>().join(" ");
        canon.entry(name.to_lowercase()).or_default().push(normalized);
    }
    let canonical_headers: String = canon
        .iter()
        .map(|(name, values)| format!("{name}:{}\n", values.join(",")))
        .collect();
    let signed_header_names = canon
        .keys()
        .map(String::as_str)
        .collect::<Vec<_>>()
        .join(";");

    let payload_hash = sha256_hex(body);
    // `canonical_headers` already ends in '\n'; the extra '\n' in the template
    // produces the blank line that separates it from the signed-header list.
    let canonical_request = format!(
        "{method}\n{canonical_path}\n{canonical_query}\n{canonical_headers}\n{signed_header_names}\n{payload_hash}"
    );

    let scope = format!("{date_stamp}/{region}/{service}/aws4_request");
    let string_to_sign = format!(
        "AWS4-HMAC-SHA256\n{amz_date}\n{scope}\n{}",
        sha256_hex(canonical_request.as_bytes())
    );

    // Signing key derivation: each HMAC output keys the next step.
    let k_date = hmac(
        format!("AWS4{}", creds.secret_access_key).as_bytes(),
        date_stamp.as_bytes(),
    );
    let k_region = hmac(&k_date, region.as_bytes());
    let k_service = hmac(&k_region, service.as_bytes());
    let k_signing = hmac(&k_service, b"aws4_request");
    let signature = hex::encode(hmac(&k_signing, string_to_sign.as_bytes()));

    let authorization = format!(
        "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
        creds.access_key_id, scope, signed_header_names, signature
    );
    headers.push(("authorization".into(), authorization));
    headers
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the value of the `authorization` header; panics if it is missing.
    fn authorization_value(headers: &[(String, String)]) -> &str {
        headers
            .iter()
            .find(|(name, _)| name == "authorization")
            .map(|(_, value)| value.as_str())
            .unwrap()
    }

    /// Checks the produced authorization header against a fixed expected
    /// signature for a body-less `GET /` request.
    #[test]
    fn matches_aws_get_vanilla_vector() {
        let creds = AwsCredentials {
            access_key_id: "AKIDEXAMPLE".into(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".into(),
            session_token: None,
        };
        let base = vec![("host".to_string(), "example.amazonaws.com".to_string())];

        let signed = signed_headers(
            &creds,
            "us-east-1",
            "service",
            "GET",
            "/",
            "",
            &base,
            b"",
            "20150830T123600Z",
            "20150830",
        );

        let expected = "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \
            SignedHeaders=host;x-amz-date, \
            Signature=5fa00fa31553b73ebf1942676e86291e8372ff2a2260956d9b8aae1d763fbf31";
        assert_eq!(authorization_value(&signed), expected);
    }

    /// The `:` in a Bedrock model id must be percent-encoded; `.` and `-` must not.
    #[test]
    fn uri_encode_handles_bedrock_model_ids() {
        let encoded = uri_encode_segment("anthropic.claude-3-5-sonnet-20241022-v2:0");
        assert_eq!(encoded, "anthropic.claude-3-5-sonnet-20241022-v2%3A0");
    }

    /// A session token must be emitted as a header and listed among the signed headers.
    #[test]
    fn session_token_is_added_to_signed_headers() {
        let creds = AwsCredentials {
            access_key_id: "AKID".into(),
            secret_access_key: "secret".into(),
            session_token: Some("token123".into()),
        };
        let base = vec![("host".to_string(), "h".to_string())];

        let signed = signed_headers(
            &creds,
            "us-east-1",
            "bedrock",
            "POST",
            "/",
            "",
            &base,
            b"{}",
            "20150830T123600Z",
            "20150830",
        );

        let token_header_present = signed
            .iter()
            .any(|(name, value)| name == "x-amz-security-token" && value == "token123");
        assert!(token_header_present);
        assert!(authorization_value(&signed).contains("x-amz-security-token"));
    }
}