use crate::Headers;
use crate::{Error, Result};
use aws_credential_types::Credentials;
use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
use aws_sigv4::sign::v4;
use std::time::SystemTime;
use tokio::sync::OnceCell;
pub(super) const BEDROCK_SERVICE: &str = "bedrock";
static CREDS_CACHE: OnceCell<CachedCreds> = OnceCell::const_new();
#[derive(Clone)]
pub(super) struct CachedCreds {
pub creds: Credentials,
pub region: String,
}
pub(super) async fn get_credentials() -> Result<CachedCreds> {
let cached = CREDS_CACHE.get_or_try_init(load_credentials_uncached).await?;
Ok(cached.clone())
}
async fn load_credentials_uncached() -> Result<CachedCreds> {
use aws_config::BehaviorVersion;
use aws_credential_types::provider::ProvideCredentials;
let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
let region = config
.region()
.map(|r| r.as_ref().to_string())
.or_else(|| std::env::var("AWS_REGION").ok())
.or_else(|| std::env::var("AWS_DEFAULT_REGION").ok())
.unwrap_or_else(|| "us-east-1".to_string());
let provider = config.credentials_provider().ok_or_else(|| Error::AdapterNotSupported {
adapter_kind: crate::adapter::AdapterKind::BedrockSigv4,
feature: "AWS credentials (no provider found in default chain)".to_string(),
})?;
let creds = provider.provide_credentials().await.map_err(|err| Error::AdapterNotSupported {
adapter_kind: crate::adapter::AdapterKind::BedrockSigv4,
feature: format!("AWS credential resolution failed: {err}"),
})?;
Ok(CachedCreds { creds, region })
}
/// Borrows the region stored alongside the cached credentials.
pub(super) fn cached_region(cached: &CachedCreds) -> &str {
    cached.region.as_str()
}
/// Signs a JSON `POST` to `url` for the Bedrock service with AWS SigV4.
///
/// The returned [`Headers`] start with `content-type: application/json`, followed by every
/// header emitted by the signer's signing instructions.
///
/// # Errors
/// Returns the `sign_err` error when:
/// - `url` has no host;
/// - the signing parameters or the signable request cannot be built;
/// - signing itself fails.
pub(super) fn sign_request(creds: &Credentials, region: &str, url: &str, body: &[u8]) -> Result<Headers> {
    let identity = creds.clone().into();
    let signing_params = v4::SigningParams::builder()
        .identity(&identity)
        .region(region)
        .name(BEDROCK_SERVICE)
        .time(SystemTime::now())
        .settings(SigningSettings::default())
        .build()
        .map_err(|err| sign_err(format!("signing params: {err}")))?
        .into();

    // Reject host-less URLs up front with a descriptive error; the signer expects the URI
    // to carry an authority.
    if url_host(url).is_none() {
        return Err(sign_err(format!("could not extract host from url: {url}")));
    }

    // `host` is intentionally NOT passed here. The signer derives it from the URI authority,
    // which keeps a non-default port (e.g. `localhost:4566`). The signed value then matches
    // the `Host` header the HTTP client actually sends. A port-stripped host would produce
    // a signature mismatch against such endpoints.
    let signing_headers = std::iter::once(("content-type", "application/json"));
    let signable = SignableRequest::new("POST", url, signing_headers, SignableBody::Bytes(body))
        .map_err(|err| sign_err(format!("signable request: {err}")))?;
    let (signing_instructions, _signature) = sign(signable, &signing_params)
        .map_err(|err| sign_err(format!("sign: {err}")))?
        .into_parts();

    let mut header_pairs: Vec<(String, String)> =
        vec![("content-type".to_string(), "application/json".to_string())];
    header_pairs.extend(
        signing_instructions
            .headers()
            .map(|(name, value)| (name.to_string(), value.to_string())),
    );
    Ok(Headers::from(header_pairs))
}
/// Builds the adapter error for a SigV4 signing failure, prefixing `msg` with
/// `SigV4 signing failed: `.
fn sign_err(msg: String) -> Error {
    let feature = format!("SigV4 signing failed: {msg}");
    Error::AdapterNotSupported {
        adapter_kind: crate::adapter::AdapterKind::BedrockSigv4,
        feature,
    }
}
/// Extracts the bare hostname from `url`.
///
/// The scheme, userinfo, port, path, query and fragment are all removed. Bracketed IPv6
/// literals are returned with their brackets (e.g. `[::1]`). Input without a `://` scheme is
/// treated as starting directly with the authority.
///
/// Returns `None` when the host is empty or an IPv6 literal is missing its closing `]`.
fn url_host(url: &str) -> Option<&str> {
    let rest = url.split_once("://").map_or(url, |(_, rest)| rest);
    // The authority ends at the first path, query, or fragment delimiter.
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // Drop optional `user[:password]@` userinfo; the last '@' is the delimiter.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host_port)| host_port);
    let host = if host_port.starts_with('[') {
        // IPv6 literal: the ':' characters inside the brackets are not a port separator.
        let close = host_port.find(']')?;
        &host_port[..=close]
    } else {
        host_port.split_once(':').map_or(host_port, |(host, _)| host)
    };
    if host.is_empty() { None } else { Some(host) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `url_host` returns only the hostname: no scheme, path, or port.
    #[test]
    fn extracts_host_from_url() {
        let cases = [
            (
                "https://bedrock-runtime.us-east-1.amazonaws.com/model/foo/converse",
                Some("bedrock-runtime.us-east-1.amazonaws.com"),
            ),
            (
                "https://bedrock-runtime.us-east-1.amazonaws.com",
                Some("bedrock-runtime.us-east-1.amazonaws.com"),
            ),
            ("http://localhost:4566/model/x/converse", Some("localhost")),
        ];
        for (url, expected) in cases {
            assert_eq!(url_host(url), expected, "url: {url}");
        }
    }
}