use chrono::{DateTime, Utc};
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
/// HMAC keyed with SHA-256; every step of the SigV4 key derivation and the final
/// signature in `SigV4Signer` is computed with this MAC.
type HmacSha256 = Hmac<Sha256>;
/// Credentials and scope used to sign requests with AWS Signature Version 4.
///
/// `Debug` is implemented by hand so that the secret access key and the
/// session token never end up in logs or panic messages.
#[derive(Clone)]
pub struct SigV4Signer {
    /// Access key ID; it appears in clear text in the `Authorization` header.
    access_key: String,
    /// Secret access key; only ever used as HMAC key material.
    secret_key: String,
    /// Optional session token for temporary credentials.
    session_token: Option<String>,
    /// Region component of the credential scope.
    region: String,
    /// Service component of the credential scope.
    service: String,
}

impl std::fmt::Debug for SigV4Signer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SigV4Signer")
            .field("access_key", &self.access_key)
            .field("secret_key", &"<redacted>")
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .field("region", &self.region)
            .field("service", &self.service)
            .finish()
    }
}
impl SigV4Signer {
pub fn new(
access_key: String,
secret_key: String,
session_token: Option<String>,
region: String,
) -> Self {
Self {
access_key,
secret_key,
session_token,
region,
service: "bedrock".to_string(),
}
}
pub fn sign_request(
&self,
method: &str,
url: &str,
headers: &HashMap<String, String>,
body: &str,
timestamp: DateTime<Utc>,
) -> Result<HashMap<String, String>, String> {
let parsed_url = url::Url::parse(url).map_err(|e| format!("Invalid URL: {}", e))?;
let host = parsed_url.host_str().ok_or("Missing host in URL")?;
let path = parsed_url.path();
let query = parsed_url.query().unwrap_or("");
let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
let date_stamp = timestamp.format("%Y%m%d").to_string();
let mut canonical_headers = headers.clone();
canonical_headers.insert("host".to_string(), host.to_string());
canonical_headers.insert("x-amz-date".to_string(), amz_date.clone());
if let Some(ref token) = self.session_token {
canonical_headers.insert("x-amz-security-token".to_string(), token.clone());
}
let mut sorted_headers: Vec<_> = canonical_headers.iter().collect();
sorted_headers.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase()));
let canonical_headers_str = sorted_headers
.iter()
.map(|(k, v)| format!("{}:{}", k.to_lowercase(), v.trim()))
.collect::<Vec<_>>()
.join("\n");
let signed_headers = sorted_headers
.iter()
.map(|(k, _)| k.to_lowercase())
.collect::<Vec<_>>()
.join(";");
let payload_hash = hex::encode(Sha256::digest(body.as_bytes()));
let canonical_request = format!(
"{}\n{}\n{}\n{}\n\n{}\n{}",
method.to_uppercase(),
path,
query,
canonical_headers_str,
signed_headers,
payload_hash
);
let algorithm = "AWS4-HMAC-SHA256";
let credential_scope = format!(
"{}/{}/{}/aws4_request",
date_stamp, self.region, self.service
);
let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
let string_to_sign = format!(
"{}\n{}\n{}\n{}",
algorithm, amz_date, credential_scope, canonical_request_hash
);
let signature = self.calculate_signature(&string_to_sign, &date_stamp)?;
let authorization = format!(
"{} Credential={}/{}, SignedHeaders={}, Signature={}",
algorithm, self.access_key, credential_scope, signed_headers, signature
);
let mut final_headers = canonical_headers;
final_headers.insert("Authorization".to_string(), authorization);
Ok(final_headers)
}
fn calculate_signature(
&self,
string_to_sign: &str,
date_stamp: &str,
) -> Result<String, String> {
let k_date = self.hmac_sha256(
format!("AWS4{}", self.secret_key).as_bytes(),
date_stamp.as_bytes(),
)?;
let k_region = self.hmac_sha256(&k_date, self.region.as_bytes())?;
let k_service = self.hmac_sha256(&k_region, self.service.as_bytes())?;
let k_signing = self.hmac_sha256(&k_service, b"aws4_request")?;
let signature = self.hmac_sha256(&k_signing, string_to_sign.as_bytes())?;
Ok(hex::encode(signature))
}
fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Result<Vec<u8>, String> {
let mut mac =
HmacSha256::new_from_slice(key).map_err(|e| format!("HMAC key error: {}", e))?;
mac.update(data);
Ok(mac.finalize().into_bytes().to_vec())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{TimeZone, Utc};

    /// Fixed signing time shared by the Bedrock-shaped tests.
    fn bedrock_timestamp() -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2024, 1, 1, 12, 0, 0).unwrap()
    }

    /// Signer with the placeholder test credentials used throughout this module.
    fn test_signer(session_token: Option<&str>) -> SigV4Signer {
        SigV4Signer::new(
            "AKIATEST".to_string(),
            "testsecret".to_string(),
            session_token.map(str::to_string),
            "us-east-1".to_string(),
        )
    }

    const BEDROCK_URL: &str = "https://bedrock-runtime.us-east-1.amazonaws.com/model/test/invoke";

    #[test]
    fn test_sigv4_signer_creation() {
        let signer = test_signer(None);
        assert_eq!(signer.access_key, "AKIATEST");
        assert_eq!(signer.region, "us-east-1");
        assert_eq!(signer.service, "bedrock");
        assert!(signer.session_token.is_none());
    }

    #[test]
    fn test_hmac_sha256() {
        let signer = SigV4Signer::new(
            "test".to_string(),
            "test".to_string(),
            None,
            "us-east-1".to_string(),
        );
        let result = signer.hmac_sha256(b"key", b"message");
        assert!(result.is_ok());
        let expected = "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a";
        assert_eq!(hex::encode(result.unwrap()), expected);
    }

    /// `get-vanilla` case from the AWS SigV4 test suite: a full known-answer check
    /// of the canonical request, string to sign, key derivation and signature.
    #[test]
    fn test_aws_get_vanilla_vector() {
        let signer = SigV4Signer {
            access_key: "AKIDEXAMPLE".to_string(),
            secret_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: None,
            region: "us-east-1".to_string(),
            service: "service".to_string(),
        };
        let timestamp = Utc.with_ymd_and_hms(2015, 8, 30, 12, 36, 0).unwrap();
        let signed = signer
            .sign_request(
                "GET",
                "https://example.amazonaws.com/",
                &HashMap::new(),
                "",
                timestamp,
            )
            .unwrap();
        assert_eq!(
            signed["Authorization"],
            "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \
             SignedHeaders=host;x-amz-date, \
             Signature=5fa00fa31553b73ebf1942676e86291e8372ff2a2260956d9b8aae1d763fbf31"
        );
        assert_eq!(signed["x-amz-date"], "20150830T123600Z");
        assert_eq!(signed["host"], "example.amazonaws.com");
    }

    #[test]
    fn test_sign_request() {
        let signer = test_signer(None);
        let headers = HashMap::new();
        let result = signer.sign_request("POST", BEDROCK_URL, &headers, "{}", bedrock_timestamp());
        assert!(result.is_ok());
        let signed_headers = result.unwrap();
        assert!(signed_headers.contains_key("Authorization"));
        assert!(signed_headers.contains_key("x-amz-date"));
        assert_eq!(signed_headers["x-amz-date"], "20240101T120000Z");
        assert_eq!(
            signed_headers["host"],
            "bedrock-runtime.us-east-1.amazonaws.com"
        );

        let auth = &signed_headers["Authorization"];
        assert!(auth.starts_with(
            "AWS4-HMAC-SHA256 Credential=AKIATEST/20240101/us-east-1/bedrock/aws4_request, \
             SignedHeaders=host;x-amz-date, Signature="
        ));
        let signature = auth.rsplit("Signature=").next().unwrap();
        assert_eq!(signature.len(), 64);
        assert!(signature
            .chars()
            .all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn test_sign_request_with_session_token() {
        let signer = test_signer(Some("session-token"));
        let signed = signer
            .sign_request("POST", BEDROCK_URL, &HashMap::new(), "{}", bedrock_timestamp())
            .unwrap();
        assert_eq!(signed["x-amz-security-token"], "session-token");
        assert!(signed["Authorization"]
            .contains("SignedHeaders=host;x-amz-date;x-amz-security-token, "));
    }

    #[test]
    fn test_caller_headers_are_signed_and_returned() {
        let signer = test_signer(None);
        let mut headers = HashMap::new();
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        let signed = signer
            .sign_request("POST", BEDROCK_URL, &headers, "{}", bedrock_timestamp())
            .unwrap();
        assert_eq!(signed["Content-Type"], "application/json");
        assert!(signed["Authorization"]
            .contains("SignedHeaders=content-type;host;x-amz-date, "));
    }

    #[test]
    fn test_signature_is_deterministic_and_time_dependent() {
        let signer = test_signer(None);
        let headers = HashMap::new();
        let first = signer
            .sign_request("POST", BEDROCK_URL, &headers, "{}", bedrock_timestamp())
            .unwrap();
        let second = signer
            .sign_request("POST", BEDROCK_URL, &headers, "{}", bedrock_timestamp())
            .unwrap();
        assert_eq!(first["Authorization"], second["Authorization"]);

        let later = Utc.with_ymd_and_hms(2024, 1, 1, 12, 0, 1).unwrap();
        let third = signer
            .sign_request("POST", BEDROCK_URL, &headers, "{}", later)
            .unwrap();
        assert_ne!(first["Authorization"], third["Authorization"]);
    }

    #[test]
    fn test_invalid_url_is_rejected() {
        let signer = test_signer(None);
        let result =
            signer.sign_request("POST", "not a url", &HashMap::new(), "", bedrock_timestamp());
        assert!(result.is_err());
    }
}