#![cfg(feature = "bedrock")]
#![cfg_attr(docsrs, doc(cfg(feature = "bedrock")))]
use std::str::FromStr;
use std::time::SystemTime;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
use aws_sigv4::sign::v4::SigningParams;
use crate::auth::{RequestSigner, SignerResult};
#[derive(Clone)]
pub struct BedrockCredentials {
access_key_id: String,
secret_access_key: String,
session_token: Option<String>,
}
impl BedrockCredentials {
#[must_use]
pub fn new(access_key_id: impl Into<String>, secret_access_key: impl Into<String>) -> Self {
Self {
access_key_id: access_key_id.into(),
secret_access_key: secret_access_key.into(),
session_token: None,
}
}
#[must_use]
pub fn with_session_token(mut self, token: impl Into<String>) -> Self {
self.session_token = Some(token.into());
self
}
#[must_use]
pub fn from_env() -> Option<Self> {
let access = std::env::var("AWS_ACCESS_KEY_ID").ok()?;
let secret = std::env::var("AWS_SECRET_ACCESS_KEY").ok()?;
let mut creds = Self::new(access, secret);
if let Ok(token) = std::env::var("AWS_SESSION_TOKEN") {
creds = creds.with_session_token(token);
}
Some(creds)
}
fn to_aws(&self) -> Credentials {
Credentials::new(
self.access_key_id.clone(),
self.secret_access_key.clone(),
self.session_token.clone(),
None,
"claude-api-bedrock-signer",
)
}
}
impl std::fmt::Debug for BedrockCredentials {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BedrockCredentials")
.field("access_key_id", &"<redacted>")
.field("secret_access_key", &"<redacted>")
.field(
"session_token",
&self.session_token.as_ref().map(|_| "<redacted>"),
)
.finish()
}
}
#[derive(Debug, Clone)]
pub struct BedrockSigner {
credentials: BedrockCredentials,
region: String,
service: String,
}
impl BedrockSigner {
#[must_use]
pub fn new(credentials: BedrockCredentials, region: impl Into<String>) -> Self {
Self {
credentials,
region: region.into(),
service: "bedrock".into(),
}
}
#[must_use]
pub fn with_service(mut self, service: impl Into<String>) -> Self {
self.service = service.into();
self
}
}
impl RequestSigner for BedrockSigner {
fn sign(&self, request: &mut reqwest::Request) -> SignerResult {
let identity = self.credentials.to_aws().into();
let settings = SigningSettings::default();
let params: aws_sigv4::http_request::SigningParams = SigningParams::builder()
.identity(&identity)
.region(&self.region)
.name(&self.service)
.time(SystemTime::now())
.settings(settings)
.build()
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?
.into();
let header_strings: Vec<(String, String)> = request
.headers()
.iter()
.filter_map(|(name, value)| {
value
.to_str()
.ok()
.map(|v| (name.as_str().to_owned(), v.to_owned()))
})
.collect();
let headers_iter = header_strings.iter().map(|(k, v)| (k.as_str(), v.as_str()));
let body_bytes = request.body().and_then(|b| b.as_bytes()).unwrap_or(&[]);
let signable_body = SignableBody::Bytes(body_bytes);
let url = request.url().as_str().to_owned();
let signable =
SignableRequest::new(request.method().as_str(), &url, headers_iter, signable_body)
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
let signing_output = sign(signable, ¶ms)
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
let (instructions, _signature) = signing_output.into_parts();
for (name, value) in instructions.headers() {
let header_name = http::HeaderName::from_str(name)?;
let header_value = http::HeaderValue::from_str(value)?;
request.headers_mut().insert(header_name, header_value);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_request() -> reqwest::Request {
let client = reqwest::Client::new();
client
.post("https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20240620-v1:0/invoke")
.body(r#"{"messages":[{"role":"user","content":"hi"}]}"#)
.build()
.unwrap()
}
fn fixed_signer() -> BedrockSigner {
BedrockSigner::new(
BedrockCredentials::new("AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"),
"us-east-1",
)
}
#[test]
fn bedrock_signer_adds_authorization_header() {
let signer = fixed_signer();
let mut req = make_request();
signer.sign(&mut req).expect("sign succeeds");
let auth = req
.headers()
.get("authorization")
.expect("Authorization header set by signer");
let auth_str = auth.to_str().expect("Authorization is ASCII");
assert!(
auth_str.starts_with("AWS4-HMAC-SHA256 "),
"expected sigv4 algorithm prefix: {auth_str}"
);
assert!(
auth_str.contains("Credential=AKIDEXAMPLE/"),
"expected access key in credential scope: {auth_str}"
);
assert!(
auth_str.contains("/us-east-1/bedrock/aws4_request"),
"expected region+service in credential scope: {auth_str}"
);
assert!(
auth_str.contains("SignedHeaders="),
"expected SignedHeaders component: {auth_str}"
);
assert!(
auth_str.contains("Signature="),
"expected Signature component: {auth_str}"
);
}
#[test]
fn bedrock_signer_adds_x_amz_date_header() {
let signer = fixed_signer();
let mut req = make_request();
signer.sign(&mut req).unwrap();
let date = req
.headers()
.get("x-amz-date")
.expect("X-Amz-Date header set by signer");
let s = date.to_str().unwrap();
assert_eq!(s.len(), 16, "date should be 16-char ISO 8601 basic: {s}");
assert!(s.ends_with('Z'), "date should be UTC: {s}");
}
#[test]
fn bedrock_signer_includes_session_token_when_present() {
let creds =
BedrockCredentials::new("AKID", "SECRET").with_session_token("session-token-value");
let signer = BedrockSigner::new(creds, "us-west-2");
let mut req = make_request();
signer.sign(&mut req).unwrap();
let token = req
.headers()
.get("x-amz-security-token")
.expect("X-Amz-Security-Token forwarded by signer");
assert_eq!(token.to_str().unwrap(), "session-token-value");
}
#[test]
fn bedrock_credentials_redact_secret_in_debug() {
let creds =
BedrockCredentials::new("AKID", "VERY-SECRET").with_session_token("ALSO-SECRET");
let dbg = format!("{creds:?}");
assert!(!dbg.contains("VERY-SECRET"), "{dbg}");
assert!(!dbg.contains("ALSO-SECRET"), "{dbg}");
assert!(dbg.contains("redacted"), "{dbg}");
}
#[test]
fn from_env_returns_none_when_missing() {
let _: Option<BedrockCredentials> = BedrockCredentials::from_env();
}
#[test]
fn signer_default_service_name_is_bedrock() {
let signer = fixed_signer();
assert_eq!(signer.service, "bedrock");
}
#[test]
fn signer_with_service_override() {
let signer = fixed_signer().with_service("bedrock-runtime");
assert_eq!(signer.service, "bedrock-runtime");
}
}