pub mod pricing;
pub mod tool_use;
pub use pricing::{estimate_bedrock_cost_usd, normalize_model_family};
pub use tool_use::{build_tool_config, document_to_json_string, json_to_document};
use std::time::Instant;
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::Client as BedrockClient;
use aws_sdk_bedrockruntime::types::{
ContentBlock, ConversationRole, InferenceConfiguration, Message, SystemContentBlock,
};
use tracing::{debug, warn};
use super::{LlmProvider, LlmRequest, LlmResponse, error::LlmError};
/// Number of extra attempts made for retryable errors, on top of the initial call.
const MAX_RETRIES: u32 = 3;
/// Region used when neither an explicit value nor a region environment variable is set.
const DEFAULT_REGION: &str = "us-east-1";
/// Environment variable consulted first when no explicit region is supplied.
const ENV_REGION_TRUSTY: &str = "TRUSTY_AWS_REGION";
/// Standard AWS region environment variable, consulted after `ENV_REGION_TRUSTY`.
const ENV_REGION_AWS: &str = "AWS_REGION";
/// Cross-region inference-profile prefixes; a Bedrock model id must begin with one of these.
pub(crate) const INFERENCE_PROFILE_PREFIXES: &[&'static str] =
    &["us.", "eu.", "ap.", "jp.", "global."];
/// Resolves the AWS region used for Bedrock calls.
///
/// Sources are consulted in this order, and the first one that is non-blank after
/// trimming surrounding whitespace wins (the trimmed value is returned):
/// 1. `explicit`
/// 2. the `TRUSTY_AWS_REGION` environment variable
/// 3. the `AWS_REGION` environment variable
/// 4. the built-in default `us-east-1`
pub fn resolve_bedrock_region(explicit: Option<&str>) -> String {
    // Trim the explicit value exactly like the environment values below, so a padded
    // or whitespace-only argument does not become an invalid region string.
    if let Some(r) = explicit.map(str::trim).filter(|s| !s.is_empty()) {
        return r.to_string();
    }
    for var in [ENV_REGION_TRUSTY, ENV_REGION_AWS] {
        if let Ok(val) = std::env::var(var) {
            let val = val.trim();
            if !val.is_empty() {
                return val.to_string();
            }
        }
    }
    DEFAULT_REGION.to_string()
}
/// Checks that `model_id` is a Bedrock cross-region inference-profile id.
///
/// The id must start with one of `INFERENCE_PROFILE_PREFIXES`
/// (e.g. `"us.anthropic.claude-sonnet-4-6"`); bare foundation-model ids are rejected.
///
/// # Errors
///
/// Returns [`LlmError::Validation`] listing the accepted prefixes when none matches.
fn validate_model_id(model_id: &str) -> Result<(), LlmError> {
    let has_profile_prefix = INFERENCE_PROFILE_PREFIXES
        .iter()
        .any(|pfx| model_id.starts_with(pfx));
    if has_profile_prefix {
        return Ok(());
    }
    // Render the list from the constant so the message can never drift from the
    // prefixes that are actually accepted above.
    let accepted = match INFERENCE_PROFILE_PREFIXES.split_last() {
        Some((last, [])) => (*last).to_string(),
        Some((last, [only])) => format!("{only} or {last}"),
        Some((last, rest)) => format!("{}, or {last}", rest.join(", ")),
        None => String::new(),
    };
    Err(LlmError::Validation(format!(
        "Bedrock model id {model_id:?} must start with a cross-region inference-profile \
         prefix ({accepted}). \
         Example: \"us.anthropic.claude-sonnet-4-6\". \
         Bare foundation-model ids are not supported."
    )))
}
/// LLM provider that sends requests to the AWS Bedrock Runtime Converse API.
pub struct BedrockProvider {
    /// Model id supplied at construction; `new` validates it as an inference-profile id.
    pub model: String,
    /// Region the client was configured for; echoed in diagnostics and error messages.
    region: String,
    /// AWS SDK client used to issue Converse calls.
    client: BedrockClient,
}
impl BedrockProvider {
pub async fn new(model: impl Into<String>, region: Option<&str>) -> Result<Self, LlmError> {
let model = model.into();
validate_model_id(&model)?;
let region_str = resolve_bedrock_region(region);
let config = aws_config::defaults(BehaviorVersion::latest())
.region(aws_config::meta::region::RegionProviderChain::first_try(
aws_types::region::Region::new(region_str.clone()),
))
.load()
.await;
let client = BedrockClient::new(&config);
Ok(Self {
client,
model,
region: region_str,
})
}
#[cfg(test)]
pub fn from_client(
client: BedrockClient,
model: impl Into<String>,
region: impl Into<String>,
) -> Self {
Self {
client,
model: model.into(),
region: region.into(),
}
}
pub fn region(&self) -> &str {
&self.region
}
async fn call_once(&self, req: &LlmRequest) -> Result<LlmResponse, LlmError> {
let start = Instant::now();
let mut system_blocks: Vec<SystemContentBlock> = Vec::new();
if !req.system.is_empty() {
system_blocks.push(SystemContentBlock::Text(req.system.clone()));
}
let mut converse_messages: Vec<Message> = Vec::new();
for msg in &req.messages {
let role = if msg.role == "assistant" {
ConversationRole::Assistant
} else {
ConversationRole::User
};
let bedrock_msg = Message::builder()
.role(role)
.content(ContentBlock::Text(msg.content.clone()))
.build()
.map_err(|e| LlmError::Validation(format!("build Bedrock Message: {e}")))?;
converse_messages.push(bedrock_msg);
}
if converse_messages.is_empty() {
return Err(LlmError::Validation(
"LlmRequest contains no user/assistant messages".to_string(),
));
}
let inference = InferenceConfiguration::builder()
.max_tokens(req.max_tokens as i32)
.temperature(req.temperature)
.build();
let mut sdk_req = self
.client
.converse()
.model_id(&req.model)
.inference_config(inference)
.set_messages(Some(converse_messages));
if !system_blocks.is_empty() {
sdk_req = sdk_req.set_system(Some(system_blocks));
}
if let Some(ref schema) = req.response_schema {
let tool_config = build_tool_config(&schema.name, &schema.schema)?;
sdk_req = sdk_req.tool_config(tool_config);
}
let resp = sdk_req.send().await.map_err(|sdk_err| {
let msg = sdk_err.to_string();
let lower = msg.to_lowercase();
if lower.contains("resourcenotfound") || lower.contains("no such model") {
LlmError::ModelNotFound(format!("model={}: {msg}", req.model))
} else if lower.contains("accessdenied")
|| lower.contains("unauthorized")
|| lower.contains("credential")
|| lower.contains("not authorized")
{
LlmError::AccessDenied(format!(
"AWS Bedrock access denied (model={}, region={}): {msg}. \
Ensure AWS credentials are configured and the account has \
bedrock:InvokeModel permission.",
req.model, self.region
))
} else if lower.contains("validationexception") || lower.contains("validation") {
LlmError::Validation(msg)
} else if lower.contains("throttlingexception")
|| lower.contains("throttled")
|| lower.contains("rate")
{
LlmError::RateLimited
} else if lower.contains("serviceunavailable")
|| lower.contains("internalserver")
|| lower.contains("modelnotready")
&& (lower.contains("creating") || lower.contains("failed"))
{
LlmError::Upstream {
status: 503,
body: msg,
}
} else if lower.contains("modelnotready") || lower.contains("not in active") {
LlmError::ModelNotReady(msg)
} else {
LlmError::Transport(format!(
"Bedrock Converse SDK error (model={}, region={}): {msg}",
req.model, self.region
))
}
})?;
let latency_ms = start.elapsed().as_millis() as u64;
let text = if req.response_schema.is_some() {
tool_use::extract_tool_use_json(&resp)
.or_else(|| extract_converse_text(&resp))
.unwrap_or_default()
} else {
extract_converse_text(&resp).unwrap_or_default()
};
let (input_tokens, output_tokens) = extract_token_usage(&resp);
let cost_usd = estimate_bedrock_cost_usd(&req.model, input_tokens, output_tokens);
let finish_reason = Some(resp.stop_reason().as_str().trim().to_ascii_lowercase());
Ok(LlmResponse {
text,
model: req.model.clone(),
input_tokens,
output_tokens,
latency_ms,
cost_usd,
finish_reason,
})
}
}
#[async_trait]
impl LlmProvider for BedrockProvider {
    /// Identifier of this provider.
    fn name(&self) -> &str {
        "bedrock"
    }

    /// Runs the Converse call and retries failures that `is_retryable` reports as transient.
    ///
    /// After the initial attempt, up to `MAX_RETRIES` further attempts are made. Before
    /// retry `k` (1-based) the task sleeps `500ms * 2^k`, with the exponent capped at 6.
    /// Non-retryable errors, and the error after the final retry, are returned unchanged.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, LlmError> {
        debug!(
            model = %req.model,
            provider = "bedrock",
            region = %self.region,
            structured = req.response_schema.is_some(),
            "bedrock complete request"
        );
        let mut retries_done = 0u32;
        loop {
            let failure = match self.call_once(&req).await {
                Ok(resp) => {
                    debug!(
                        model = %resp.model,
                        input_tokens = resp.input_tokens,
                        output_tokens = resp.output_tokens,
                        latency_ms = resp.latency_ms,
                        cost_usd = resp.cost_usd,
                        "bedrock complete response"
                    );
                    return Ok(resp);
                }
                Err(failure) => failure,
            };

            let give_up = !failure.is_retryable() || retries_done >= MAX_RETRIES;
            if give_up {
                return Err(failure);
            }

            retries_done += 1;
            let backoff_ms = 500u64 * (1u64 << retries_done.min(6));
            warn!(
                attempt = retries_done,
                backoff_ms,
                model = %req.model,
                "bedrock transient error — retrying: {failure}"
            );
            tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)).await;
        }
    }
}
/// Concatenates the text blocks of the Converse output message, separated by `'\n'`.
///
/// Non-text content blocks are skipped. No separator is inserted while the accumulated
/// text is still empty, so leading empty text blocks add nothing. Returns `None` when the
/// output is absent, is not a message, or yields no text at all.
fn extract_converse_text(
    resp: &aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
) -> Option<String> {
    let message = resp.output()?.as_message().ok()?;
    let joined = message
        .content()
        .iter()
        .filter_map(|block| match block {
            ContentBlock::Text(text) => Some(text.as_str()),
            _ => None,
        })
        .fold(String::new(), |mut acc, piece| {
            if !acc.is_empty() {
                acc.push('\n');
            }
            acc.push_str(piece);
            acc
        });
    (!joined.is_empty()).then_some(joined)
}
/// Returns `(input_tokens, output_tokens)` from the response's usage section.
///
/// Yields `(0, 0)` when no usage is reported; negative counts are clamped to 0.
fn extract_token_usage(
    resp: &aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
) -> (u32, u32) {
    match resp.usage() {
        Some(usage) => (
            u32::try_from(usage.input_tokens()).unwrap_or(0),
            u32::try_from(usage.output_tokens()).unwrap_or(0),
        ),
        None => (0, 0),
    }
}
#[cfg(test)]
#[path = "tests.rs"]
mod tests;