use std::collections::BTreeMap;
use chrono::{DateTime, Utc};
use sha2::{Digest, Sha256};
use url::{Host, Url};
use zeroize::Zeroizing;
/// AWS credentials consumed by `sign`.
///
/// `Debug` is implemented by hand for this type so that none of the fields
/// are ever printed verbatim.
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct AwsSigV4Credentials {
/// Access key id; `sign` places it in the `Credential=` part of the
/// `Authorization` header.
pub access_key_id: String,
/// Secret access key; `sign` only uses it to derive the request signature.
pub secret_access_key: String,
/// Optional session token. When present and not blank, `sign` emits and
/// signs it as the `X-Amz-Security-Token` header.
pub session_token: Option<String>,
}
/// Debug formatting that never reveals credential material.
impl std::fmt::Debug for AwsSigV4Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        // Keep the Some/None shape so the output still shows whether a
        // session token was configured, without showing the token itself.
        let session_token = self.session_token.as_ref().map(|_| REDACTED);
        let mut out = f.debug_struct("AwsSigV4Credentials");
        out.field("access_key_id", &REDACTED);
        out.field("secret_access_key", &REDACTED);
        out.field("session_token", &session_token);
        out.finish()
    }
}
/// Borrowed description of the request that `sign` should sign.
pub(crate) struct AwsSigV4Input<'a> {
/// Credentials used to produce the signature.
pub credentials: &'a AwsSigV4Credentials,
/// HTTP method; `sign` trims it and upper-cases it before use.
pub method: &'a str,
/// Request URL; `sign` parses it with `Url::parse` and requires a host.
pub url: &'a str,
/// Service name placed in the credential scope (trimmed by `sign`).
pub service: &'a str,
/// Region name placed in the credential scope (trimmed by `sign`).
pub region: &'a str,
/// Caller-supplied headers to sign. `sign` rejects `Authorization`,
/// `X-Amz-Date`, `X-Amz-Content-Sha256` and `X-Amz-Security-Token` because
/// it generates those itself; a `Host` entry must match the URL host.
pub headers: &'a BTreeMap<String, String>,
/// Request body; its SHA-256 digest becomes `X-Amz-Content-Sha256`.
pub body: &'a [u8],
/// Signing time; formatted into `X-Amz-Date` and the credential scope date.
pub timestamp: DateTime<Utc>,
}
/// Result of `sign`: the headers to send plus the intermediate signing
/// artifacts.
///
/// `Debug` is implemented by hand for this type so that the authorization
/// value and the session token are redacted.
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct AwsSigV4SignedRequest {
/// Headers to send: the accepted caller headers plus `Host`,
/// `X-Amz-Content-Sha256`, `X-Amz-Date`, `Authorization` and, when a session
/// token is configured, `X-Amz-Security-Token`.
pub headers: BTreeMap<String, String>,
/// Value of the `Authorization` header (also present in `headers`).
pub authorization: String,
/// Timestamp formatted as `%Y%m%dT%H%M%SZ`.
pub amz_date: String,
/// Lowercase hex SHA-256 digest of the request body.
pub content_sha256: String,
/// Whitespace-normalized session token that was signed, if any.
pub security_token: Option<String>,
/// Semicolon-separated, sorted list of the lowercase signed header names.
pub signed_headers: String,
/// Canonical request text that was hashed into `string_to_sign`.
pub canonical_request: String,
/// String that was signed with the derived signing key.
pub string_to_sign: String,
/// Credential scope in the form `<date>/<region>/<service>/aws4_request`.
pub credential_scope: String,
}
impl std::fmt::Debug for AwsSigV4SignedRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut headers = self.headers.clone();
for (name, value) in headers.iter_mut() {
let lower = name.to_ascii_lowercase();
if lower == "authorization" || lower.contains("token") {
*value = "<redacted>".to_string();
}
}
f.debug_struct("AwsSigV4SignedRequest")
.field("headers", &headers)
.field("authorization", &"<redacted>")
.field("amz_date", &self.amz_date)
.field("content_sha256", &self.content_sha256)
.field(
"security_token",
&self.security_token.as_ref().map(|_| "<redacted>"),
)
.field("signed_headers", &self.signed_headers)
.field(
"canonical_request",
&redact_debug_text(&self.canonical_request, self.security_token.as_deref()),
)
.field("string_to_sign", &self.string_to_sign)
.field("credential_scope", &self.credential_scope)
.finish()
}
}
/// Signs a request with AWS Signature Version 4 (`AWS4-HMAC-SHA256`).
///
/// Builds the canonical request from the method, URL path, query string,
/// headers and body hash, derives the signature from the secret access key,
/// and returns the headers to send together with the intermediate artifacts.
///
/// `Host`, `X-Amz-Content-Sha256`, `X-Amz-Date`, `Authorization` and (when a
/// non-blank session token is configured) `X-Amz-Security-Token` are generated
/// here. Caller headers whose names differ only in ASCII case or surrounding
/// whitespace are treated as one header: their values are joined with `,` in
/// input order, both in the canonical request and in the returned headers, so
/// the headers that are sent always match what was signed.
///
/// # Errors
///
/// Returns a message (never containing credential material) when:
/// - `access_key_id`, `secret_access_key`, `method`, `service` or `region`
///   is empty or blank;
/// - `url` cannot be parsed as an absolute URL or has no host;
/// - a header name is empty or contains a byte that is not an HTTP token byte;
/// - the caller supplies `Authorization`, `X-Amz-Date`,
///   `X-Amz-Content-Sha256` or `X-Amz-Security-Token`;
/// - a caller-supplied `Host` header does not equal the URL host.
pub(crate) fn sign(input: AwsSigV4Input<'_>) -> Result<AwsSigV4SignedRequest, String> {
    validate_required("access_key_id", &input.credentials.access_key_id)?;
    validate_required("secret_access_key", &input.credentials.secret_access_key)?;
    let method = input.method.trim().to_ascii_uppercase();
    validate_required("method", &method)?;
    validate_required("service", input.service)?;
    validate_required("region", input.region)?;
    let service = input.service.trim();
    let region = input.region.trim();

    let parsed = Url::parse(input.url)
        .map_err(|_| "url must be an absolute URL with a scheme and host".to_string())?;
    let host = host_header(&parsed)?;
    let canonical_uri = canonical_uri(parsed.path());
    let canonical_query = canonical_query(parsed.query().unwrap_or_default());

    let amz_date = input.timestamp.format("%Y%m%dT%H%M%SZ").to_string();
    let date = input.timestamp.format("%Y%m%d").to_string();
    let content_sha256 = sha256_hex(input.body);

    let mut canonical_headers = CanonicalHeaders::default();
    let mut output_headers: BTreeMap<String, String> = BTreeMap::new();
    // Lowercase header name -> the name it is emitted under. Input keys that
    // normalize to the same lowercase name share one output entry, so the
    // emitted value is exactly the comma-joined value that gets signed.
    let mut emitted_names: BTreeMap<String, String> = BTreeMap::new();
    for (name, value) in input.headers {
        let lower = normalize_header_name(name)?;
        match lower.as_str() {
            "authorization" => {
                return Err("headers.Authorization is generated by aws_sigv4_headers".to_string());
            }
            "x-amz-date" | "x-amz-content-sha256" | "x-amz-security-token" => {
                return Err(format!("headers.{name} is generated by aws_sigv4_headers"));
            }
            "host" => {
                // The signed host always comes from the URL; a caller-supplied
                // value is only checked for agreement.
                let value = canonical_header_value(value);
                if value != host {
                    return Err("headers.Host must match the URL host".to_string());
                }
            }
            _ => {
                let value = canonical_header_value(value);
                canonical_headers.insert(&lower, &value);
                let emitted = emitted_names
                    .entry(lower.clone())
                    .or_insert_with(|| output_header_name(&lower, name))
                    .clone();
                // Merge instead of overwrite: the canonical form above already
                // joins repeated values with ',', and the sent header must
                // carry the same value or the signature will not verify.
                output_headers
                    .entry(emitted)
                    .and_modify(|existing| {
                        existing.push(',');
                        existing.push_str(&value);
                    })
                    .or_insert_with(|| value.clone());
            }
        }
    }

    canonical_headers.insert("host", &host);
    canonical_headers.insert("x-amz-content-sha256", &content_sha256);
    canonical_headers.insert("x-amz-date", &amz_date);
    output_headers.insert("Host".to_string(), host);
    output_headers.insert("X-Amz-Content-Sha256".to_string(), content_sha256.clone());
    output_headers.insert("X-Amz-Date".to_string(), amz_date.clone());

    // A blank session token is treated the same as no token at all.
    let mut security_token = None;
    if let Some(token) = input
        .credentials
        .session_token
        .as_deref()
        .filter(|token| !token.trim().is_empty())
    {
        let value = canonical_header_value(token);
        canonical_headers.insert("x-amz-security-token", &value);
        output_headers.insert("X-Amz-Security-Token".to_string(), value.clone());
        security_token = Some(value);
    }

    let signed_headers = canonical_headers.signed_headers();
    let canonical_headers_text = canonical_headers.canonical_headers();
    // `canonical_headers_text` already ends with '\n', so the extra '\n' in
    // the template produces the blank line that separates the header block
    // from the signed-header list.
    let canonical_request = format!(
        "{method}\n{canonical_uri}\n{canonical_query}\n{canonical_headers_text}\n{signed_headers}\n{content_sha256}"
    );
    let credential_scope = format!("{date}/{region}/{service}/aws4_request");
    let string_to_sign = format!(
        "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
        sha256_hex(canonical_request.as_bytes())
    );
    let signature = signature_hex(
        &input.credentials.secret_access_key,
        &date,
        region,
        service,
        &string_to_sign,
    );
    let authorization = format!(
        "AWS4-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
        input.credentials.access_key_id
    );
    output_headers.insert("Authorization".to_string(), authorization.clone());

    Ok(AwsSigV4SignedRequest {
        headers: output_headers,
        authorization,
        amz_date,
        content_sha256,
        security_token,
        signed_headers,
        canonical_request,
        string_to_sign,
        credential_scope,
    })
}
/// Returns `text` with every occurrence of `token` replaced by `<redacted>`.
///
/// When `token` is `None` or empty the text is returned unchanged.
fn redact_debug_text(text: &str, token: Option<&str>) -> String {
    token
        .filter(|secret| !secret.is_empty())
        .map_or_else(
            || text.to_owned(),
            |secret| text.replace(secret, "<redacted>"),
        )
}
/// Fails with `"<label> is required"` when `value` is empty or only whitespace.
fn validate_required(label: &str, value: &str) -> Result<(), String> {
    match value.trim() {
        "" => Err(format!("{label} is required")),
        _ => Ok(()),
    }
}
/// Builds the `Host` header value for `url`.
///
/// IPv6 addresses are wrapped in brackets, and `:<port>` is appended only
/// when `Url::port` reports a port.
///
/// # Errors
///
/// Returns `"url must include a host"` when the URL has no host component.
fn host_header(url: &Url) -> Result<String, String> {
    let name = match url
        .host()
        .ok_or_else(|| "url must include a host".to_string())?
    {
        Host::Domain(domain) => domain.to_string(),
        Host::Ipv4(addr) => addr.to_string(),
        Host::Ipv6(addr) => format!("[{addr}]"),
    };
    Ok(match url.port() {
        Some(port) => format!("{name}:{port}"),
        None => name,
    })
}
/// Headers that take part in the signature, keyed by lowercase name.
///
/// The `BTreeMap` keeps the names sorted, which is the order used for both
/// the signed-header list and the canonical header block.
#[derive(Default)]
struct CanonicalHeaders {
    headers: BTreeMap<String, String>,
}

impl CanonicalHeaders {
    /// Records `value` under `name`; a repeated name has the new value
    /// appended after a comma.
    fn insert(&mut self, name: &str, value: &str) {
        match self.headers.get_mut(name) {
            Some(joined) => {
                joined.push(',');
                joined.push_str(value);
            }
            None => {
                self.headers.insert(name.to_owned(), value.to_owned());
            }
        }
    }

    /// Header names joined with `;`, in sorted order.
    fn signed_headers(&self) -> String {
        let mut list = String::new();
        for (index, name) in self.headers.keys().enumerate() {
            if index > 0 {
                list.push(';');
            }
            list.push_str(name);
        }
        list
    }

    /// One `name:value\n` line per header, in sorted order.
    fn canonical_headers(&self) -> String {
        let mut text = String::new();
        for (name, value) in &self.headers {
            text.push_str(name);
            text.push(':');
            text.push_str(value);
            text.push('\n');
        }
        text
    }
}
fn normalize_header_name(name: &str) -> Result<String, String> {
let trimmed = name.trim();
if trimmed.is_empty() {
return Err("header names cannot be empty".to_string());
}
if !trimmed.bytes().all(is_http_token_byte) {
return Err(format!("invalid header name `{trimmed}`"));
}
Ok(trimmed.to_ascii_lowercase())
}
/// Reports whether `byte` is accepted in a header name: an ASCII letter or
/// digit, or one of ``! # $ % & ' * + - . ^ _ ` | ~``.
fn is_http_token_byte(byte: u8) -> bool {
    const PUNCTUATION: &[u8] = b"!#$%&'*+-.^_`|~";
    byte.is_ascii_alphanumeric() || PUNCTUATION.contains(&byte)
}
/// Strips leading and trailing whitespace and collapses every internal
/// whitespace run to a single space.
fn canonical_header_value(value: &str) -> String {
    let mut collapsed = String::with_capacity(value.len());
    for word in value.split_whitespace() {
        // Words from `split_whitespace` are never empty, so an empty buffer
        // means this is the first word and no separator is needed.
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    collapsed
}
/// Chooses the name a caller header is emitted under.
///
/// `content-type`, `accept` and `x-amz-target` get fixed capitalization;
/// every other header keeps the caller's spelling, trimmed.
fn output_header_name(lower: &str, original: &str) -> String {
    let name = match lower {
        "content-type" => "Content-Type",
        "accept" => "Accept",
        "x-amz-target" => "X-Amz-Target",
        _ => original.trim(),
    };
    name.to_string()
}
/// Canonicalizes a URL path for the canonical request.
///
/// An empty path becomes `/`. Otherwise each `/`-separated segment is
/// percent-decoded and then re-encoded with `uri_encode`, which normalizes
/// existing escapes while keeping the segment boundaries.
fn canonical_uri(path: &str) -> String {
    if path.is_empty() {
        return "/".to_string();
    }
    let mut canonical = String::with_capacity(path.len());
    for (index, segment) in path.split('/').enumerate() {
        if index > 0 {
            canonical.push('/');
        }
        canonical.push_str(&uri_encode(&percent_decode(segment)));
    }
    canonical
}
/// Canonicalizes a raw query string for the canonical request.
///
/// Each `&`-separated part is split at its first `=` (a part without `=`
/// gets an empty value), key and value are percent-decoded and re-encoded
/// with `uri_encode`, and the encoded pairs are sorted by key and then value
/// before being joined as `key=value` with `&`.
fn canonical_query(raw: &str) -> String {
    if raw.is_empty() {
        return String::new();
    }

    let mut encoded: Vec<(String, String)> = Vec::new();
    for part in raw.split('&') {
        let (name, value) = match part.find('=') {
            Some(at) => (&part[..at], &part[at + 1..]),
            None => (part, ""),
        };
        encoded.push((
            uri_encode(&percent_decode(name)),
            uri_encode(&percent_decode(value)),
        ));
    }
    encoded.sort();

    let mut query = String::new();
    for (index, (name, value)) in encoded.iter().enumerate() {
        if index > 0 {
            query.push('&');
        }
        query.push_str(name);
        query.push('=');
        query.push_str(value);
    }
    query
}
/// Decodes `%XX` escapes in `input` into raw bytes.
///
/// A `%` that is not followed by two hexadecimal digits is copied through
/// unchanged, as is every other byte.
fn percent_decode(input: &str) -> Vec<u8> {
    let raw = input.as_bytes();
    let mut out = Vec::with_capacity(raw.len());
    let mut pos = 0;
    while let Some(&current) = raw.get(pos) {
        // Only a '%' with two bytes after it can be an escape.
        let escaped = match (current, raw.get(pos + 1), raw.get(pos + 2)) {
            (b'%', Some(&high), Some(&low)) => hex_pair(high, low),
            _ => None,
        };
        match escaped {
            Some(byte) => {
                out.push(byte);
                pos += 3;
            }
            None => {
                out.push(current);
                pos += 1;
            }
        }
    }
    out
}
/// Combines two ASCII hexadecimal digits into the byte they spell, or
/// returns `None` when either one is not a hexadecimal digit.
fn hex_pair(high: u8, low: u8) -> Option<u8> {
    let upper = hex_digit(high)?;
    let lower = hex_digit(low)?;
    Some((upper << 4) | lower)
}
/// Returns the value (0-15) of an ASCII hexadecimal digit in either case,
/// or `None` for any other byte.
fn hex_digit(byte: u8) -> Option<u8> {
    // Radix-16 digits are always below 16, so narrowing to u8 cannot truncate.
    char::from(byte).to_digit(16).map(|digit| digit as u8)
}
/// Percent-encodes `bytes` for the canonical request.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` are kept as they are; every
/// other byte becomes `%XX` with uppercase hexadecimal digits.
fn uri_encode(bytes: &[u8]) -> String {
    const UPPER_HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(bytes.len());
    for &byte in bytes {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(UPPER_HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(UPPER_HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
fn signature_hex(
secret_access_key: &str,
date: &str,
region: &str,
service: &str,
string_to_sign: &str,
) -> String {
let secret = Zeroizing::new(format!("AWS4{secret_access_key}"));
let date_key = Zeroizing::new(crate::connectors::hmac::hmac_sha256(
secret.as_bytes(),
date.as_bytes(),
));
let region_key = Zeroizing::new(crate::connectors::hmac::hmac_sha256(
&date_key,
region.as_bytes(),
));
let service_key = Zeroizing::new(crate::connectors::hmac::hmac_sha256(
®ion_key,
service.as_bytes(),
));
let signing_key = Zeroizing::new(crate::connectors::hmac::hmac_sha256(
&service_key,
b"aws4_request",
));
hex::encode(crate::connectors::hmac::hmac_sha256(
&signing_key,
string_to_sign.as_bytes(),
))
}
/// Returns the hex-encoded SHA-256 digest of `data`.
fn sha256_hex(data: &[u8]) -> String {
    let mut hasher = Sha256::new();
    hasher.update(data);
    hex::encode(hasher.finalize())
}
// Unit tests for the signer. Every test signs at the fixed timestamp returned
// by `fixed_time`, so the expected values below are deterministic.
#[cfg(test)]
mod tests {
use super::*;
use chrono::TimeZone;
// Fixture credentials without a session token.
fn credentials() -> AwsSigV4Credentials {
AwsSigV4Credentials {
access_key_id: "AKIDEXAMPLE".to_string(),
secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
session_token: None,
}
}
// Fixed signing time: 2015-08-30 12:36:00 UTC.
fn fixed_time() -> DateTime<Utc> {
Utc.with_ymd_and_hms(2015, 8, 30, 12, 36, 0).unwrap()
}
// A POST with one caller header: pins the date, signed-header list, body
// hash and the exact signature for the fixture inputs.
#[test]
fn signs_standard_headers() {
let headers = BTreeMap::from([(
"Content-Type".to_string(),
"application/x-amz-json-1.1".to_string(),
)]);
let signed = sign(AwsSigV4Input {
credentials: &credentials(),
method: "POST",
url: "https://service.us-east-1.amazonaws.com/",
service: "service",
region: "us-east-1",
headers: &headers,
body: br#"{"hello":"world"}"#,
timestamp: fixed_time(),
})
.expect("signed request");
assert_eq!(signed.amz_date, "20150830T123600Z");
assert_eq!(
signed.signed_headers,
"content-type;host;x-amz-content-sha256;x-amz-date"
);
assert_eq!(
signed.content_sha256,
"93a23971a914e5eacbf0a8d25154cda309c3c1c72fbb9914d47c60f3cb681588"
);
assert!(signed
.authorization
.contains("Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request"));
assert!(signed
.authorization
.contains("SignedHeaders=content-type;host;x-amz-content-sha256;x-amz-date"));
assert!(signed.authorization.contains(
"Signature=bbd206a978e4215b047ffcb9df0601d8173a68a1d314fbb5ce382e7556b147be"
));
}
// A configured session token must be emitted as X-Amz-Security-Token and
// included in both the signed-header list and the canonical request.
#[test]
fn signs_session_token() {
let mut creds = credentials();
creds.session_token = Some("session-token".to_string());
let headers = BTreeMap::new();
let signed = sign(AwsSigV4Input {
credentials: &creds,
method: "GET",
url: "https://service.us-east-1.amazonaws.com/",
service: "service",
region: "us-east-1",
headers: &headers,
body: b"",
timestamp: fixed_time(),
})
.expect("signed request");
assert_eq!(
signed
.headers
.get("X-Amz-Security-Token")
.map(String::as_str),
Some("session-token")
);
assert_eq!(
signed.signed_headers,
"host;x-amz-content-sha256;x-amz-date;x-amz-security-token"
);
assert!(signed
.canonical_request
.contains("x-amz-security-token:session-token\n"));
}
// Path and query canonicalization: escapes are normalized (`%7e` becomes
// `~`, `+` becomes `%2B`), a key without `=` gets an empty value, and the
// pairs are sorted by key and then by value.
#[test]
fn canonicalizes_query_and_path_edges() {
let headers = BTreeMap::from([("X-Amz-Target".to_string(), "Example.Action".to_string())]);
let signed = sign(AwsSigV4Input {
credentials: &credentials(),
method: "GET",
url: "https://example.amazonaws.com/a%2Fb/space%20here/tilde%7e?Param=value%2Fwith%20space&plus=a+b&empty&Param=second",
service: "execute-api",
region: "us-east-1",
headers: &headers,
body: b"",
timestamp: fixed_time(),
})
.expect("signed request");
// Line 0 of the canonical request is the method, line 1 the path and
// line 2 the query string.
let lines = signed.canonical_request.lines().collect::<Vec<_>>();
assert_eq!(lines[1], "/a%2Fb/space%20here/tilde~");
assert_eq!(
lines[2],
"Param=second&Param=value%2Fwith%20space&empty=&plus=a%2Bb"
);
assert!(signed
.canonical_request
.contains("x-amz-target:Example.Action\n"));
}
// An invalid URL must produce an error message that contains none of the
// credential values.
#[test]
fn validation_errors_do_not_include_credentials() {
let creds = AwsSigV4Credentials {
access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
secret_access_key: "secret-that-must-not-leak".to_string(),
session_token: Some("session-that-must-not-leak".to_string()),
};
let headers = BTreeMap::new();
let error = sign(AwsSigV4Input {
credentials: &creds,
method: "POST",
url: "not a url",
service: "service",
region: "us-east-1",
headers: &headers,
body: b"",
timestamp: fixed_time(),
})
.expect_err("invalid url should fail");
assert!(!error.contains("AKIAIOSFODNN7EXAMPLE"));
assert!(!error.contains("secret-that-must-not-leak"));
assert!(!error.contains("session-that-must-not-leak"));
}
// The Debug output of a signed request must not contain the access key id,
// the session token, or the signature.
#[test]
fn debug_output_redacts_credentials() {
let mut creds = credentials();
creds.session_token = Some("session-token".to_string());
let headers = BTreeMap::new();
let signed = sign(AwsSigV4Input {
credentials: &creds,
method: "GET",
url: "https://service.us-east-1.amazonaws.com/",
service: "service",
region: "us-east-1",
headers: &headers,
body: b"",
timestamp: fixed_time(),
})
.expect("signed request");
let debug = format!("{signed:?}");
assert!(!debug.contains("AKIDEXAMPLE"));
assert!(!debug.contains("session-token"));
assert!(!debug.contains("Signature="));
}
}