use std::time::{Duration, Instant};
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::Client as BedrockClient;
use aws_sdk_bedrockruntime::types::{
ContentBlock, ConversationRole, InferenceConfiguration, Message, SystemContentBlock,
};
use tracing::{debug, warn};
use super::{LlmProvider, LlmRequest, LlmResponse, error::SmLlmError, pricing};
/// Environment variable consulted first by [`resolve_bedrock_region`] when no explicit region is given.
const ENV_REGION_TRUSTY: &str = "TRUSTY_AWS_REGION";
/// Environment variable consulted second, after [`ENV_REGION_TRUSTY`].
const ENV_REGION_AWS: &str = "AWS_REGION";
/// Region returned by [`resolve_bedrock_region`] when neither the caller nor the environment supplies one.
const DEFAULT_REGION: &str = "us-east-1";
/// Model-id prefixes accepted by [`validate_model_id`]; any other id is rejected.
const INFERENCE_PROFILE_PREFIXES: &[&str] = &["us.", "eu.", "ap.", "jp.", "global."];
/// Maximum number of retries performed by `complete` after the initial attempt fails with a retryable error.
const MAX_RETRIES: u32 = 3;
/// Picks the AWS region to use for Bedrock calls.
///
/// Resolution order:
/// 1. `explicit`, when it is non-blank;
/// 2. the `TRUSTY_AWS_REGION` environment variable, when set and non-blank;
/// 3. the `AWS_REGION` environment variable, when set and non-blank;
/// 4. the built-in default region.
///
/// Every candidate is trimmed of surrounding whitespace before it is checked
/// and returned, so a whitespace-only value is treated the same as an absent
/// one.
pub fn resolve_bedrock_region(explicit: Option<&str>) -> String {
    // Trim the explicit value exactly like the environment values below;
    // otherwise `Some("  ")` would be handed to the SDK as a region name.
    if let Some(region) = explicit.map(str::trim).filter(|r| !r.is_empty()) {
        return region.to_string();
    }

    // The chain is lazy: the second variable is only read when the first one
    // is unset or blank.
    [ENV_REGION_TRUSTY, ENV_REGION_AWS]
        .iter()
        .filter_map(|var| std::env::var(*var).ok())
        .map(|val| val.trim().to_string())
        .find(|val| !val.is_empty())
        .unwrap_or_else(|| DEFAULT_REGION.to_string())
}
/// Checks that `model_id` begins with one of the accepted inference-profile
/// prefixes.
///
/// # Errors
/// Returns [`SmLlmError::Validation`] with an explanatory message when no
/// prefix in `INFERENCE_PROFILE_PREFIXES` matches.
fn validate_model_id(model_id: &str) -> Result<(), SmLlmError> {
    let has_profile_prefix = INFERENCE_PROFILE_PREFIXES
        .iter()
        .any(|prefix| model_id.starts_with(prefix));

    if !has_profile_prefix {
        let reason = format!(
            "Bedrock model id {model_id:?} must start with a cross-region inference-profile \
             prefix (us., eu., ap., jp., or global.). Example: \"us.anthropic.claude-sonnet-4-6\"."
        );
        return Err(SmLlmError::Validation(reason));
    }

    Ok(())
}
/// [`LlmProvider`] implementation that talks to AWS Bedrock through the
/// Converse operation of the Bedrock runtime SDK client.
pub struct BedrockProvider {
/// SDK client used to issue Converse requests.
client: BedrockClient,
/// Model id supplied at construction (validated by `validate_model_id` in
/// `new`). Note that requests are sent with the model id carried by each
/// `LlmRequest`, not with this field.
pub model: String,
/// Region name recorded at construction; returned by `region()` and
/// included in log lines and error messages.
region: String,
}
impl BedrockProvider {
    /// Creates a provider for `model`, loading the default AWS SDK
    /// configuration with the region overridden by the resolved region.
    ///
    /// The region is chosen by [`resolve_bedrock_region`] from `region` and
    /// the environment.
    ///
    /// # Errors
    /// Returns [`SmLlmError::Validation`] when `model` is rejected by
    /// `validate_model_id`.
    pub async fn new(model: impl Into<String>, region: Option<&str>) -> Result<Self, SmLlmError> {
        let model = model.into();
        validate_model_id(&model)?;
        let region_str = resolve_bedrock_region(region);
        let config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_config::meta::region::RegionProviderChain::first_try(
                aws_types::region::Region::new(region_str.clone()),
            ))
            .load()
            .await;
        let client = BedrockClient::new(&config);
        Ok(Self {
            client,
            model,
            region: region_str,
        })
    }

    /// Test-only constructor that wraps a pre-built client without loading
    /// AWS configuration and without validating the model id.
    #[cfg(test)]
    pub fn from_client(
        client: BedrockClient,
        model: impl Into<String>,
        region: impl Into<String>,
    ) -> Self {
        Self {
            client,
            model: model.into(),
            region: region.into(),
        }
    }

    /// Region name recorded when the provider was constructed.
    pub fn region(&self) -> &str {
        &self.region
    }

    /// Renders `err` followed by every error in its `source()` chain,
    /// joined with `": "`.
    ///
    /// The SDK's top-level error only displays a generic category (for
    /// example "service error"); the service error code and message live in
    /// the source chain, which is what `map_sdk_error` needs to see.
    /// A cause whose text merely repeats the previous entry is skipped.
    fn error_chain_message(err: &(dyn std::error::Error + 'static)) -> String {
        let mut rendered = err.to_string();
        let mut previous = rendered.clone();
        let mut cause = std::error::Error::source(err);
        while let Some(inner) = cause {
            let text = inner.to_string();
            if text != previous {
                rendered.push_str(": ");
                rendered.push_str(&text);
                previous = text;
            }
            cause = std::error::Error::source(inner);
        }
        rendered
    }

    /// Performs exactly one Converse call for `req` (no retries).
    ///
    /// The request is sent with `req.model`, not `self.model`. A message
    /// whose role is `"assistant"` is sent as an assistant turn; every other
    /// role is sent as a user turn. A non-empty `req.system` becomes the
    /// single system block.
    ///
    /// # Errors
    /// * [`SmLlmError::Validation`] when a message cannot be built or when
    ///   `req.messages` is empty.
    /// * Whatever `map_sdk_error` produces for a failed SDK call.
    async fn call_once(&self, req: &LlmRequest) -> Result<LlmResponse, SmLlmError> {
        let start = Instant::now();

        // `None` leaves the request without a system prompt.
        let system_blocks = (!req.system.is_empty())
            .then(|| vec![SystemContentBlock::Text(req.system.clone())]);

        let messages = req
            .messages
            .iter()
            .map(|m| {
                let role = if m.role == "assistant" {
                    ConversationRole::Assistant
                } else {
                    ConversationRole::User
                };
                Message::builder()
                    .role(role)
                    .content(ContentBlock::Text(m.content.clone()))
                    .build()
                    .map_err(|e| SmLlmError::Validation(format!("build Bedrock Message: {e}")))
            })
            .collect::<Result<Vec<Message>, SmLlmError>>()?;
        if messages.is_empty() {
            return Err(SmLlmError::Validation(
                "LlmRequest contains no user/assistant messages".to_string(),
            ));
        }

        let inference = InferenceConfiguration::builder()
            // Saturate rather than wrap if the requested limit exceeds i32.
            .max_tokens(i32::try_from(req.max_tokens).unwrap_or(i32::MAX))
            .temperature(req.temperature)
            .build();

        let resp = self
            .client
            .converse()
            .model_id(&req.model)
            .inference_config(inference)
            .set_messages(Some(messages))
            .set_system(system_blocks)
            .send()
            .await
            // Pass the whole error chain: the top-level Display text alone
            // carries no error code for the classifier to match on.
            .map_err(|e| map_sdk_error(Self::error_chain_message(&e), &req.model, &self.region))?;

        // Saturating conversion instead of a silently truncating `as` cast.
        let latency_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
        // A response without any text block yields an empty string.
        let text = extract_converse_text(&resp).unwrap_or_default();
        let (input_tokens, output_tokens) = extract_token_usage(&resp);
        let cost_usd = pricing::estimate_cost_usd(&req.model, input_tokens, output_tokens);

        Ok(LlmResponse {
            text,
            model: req.model.clone(),
            input_tokens,
            output_tokens,
            latency_ms,
            cost_usd,
        })
    }
}
/// Classifies an SDK error message into an [`SmLlmError`] variant by
/// case-insensitive keyword matching.
///
/// Checks run in this order and the first hit wins: model not found, access
/// denied / credentials, validation, rate limiting, service unavailable
/// (reported as upstream status 503), model not ready. Anything else becomes
/// [`SmLlmError::Transport`].
///
/// `model` and `region` are only used to enrich the produced messages.
fn map_sdk_error(msg: String, model: &str, region: &str) -> SmLlmError {
    let lower = msg.to_lowercase();
    let has_any = |needles: &[&str]| needles.iter().any(|needle| lower.contains(*needle));
    // "rate" must be a whole word: a plain substring test also fires on
    // unrelated words such as "operation" or "generate".
    let mentions_rate = || {
        lower.contains("ratelimit")
            || lower
                .split(|c: char| !c.is_alphanumeric())
                .any(|word| word == "rate")
    };

    if has_any(&["resourcenotfound", "no such model"]) {
        SmLlmError::ModelNotFound(format!("model={model}: {msg}"))
    } else if has_any(&["accessdenied", "unauthorized", "credential", "not authorized"]) {
        // "credential" also covers "no credentials".
        SmLlmError::AccessDenied(format!(
            "AWS Bedrock access denied (model={model}, region={region}): {msg}"
        ))
    } else if lower.contains("validation") {
        // Also covers "validationexception".
        SmLlmError::Validation(msg)
    } else if has_any(&["throttling", "throttled"]) || mentions_rate() {
        SmLlmError::RateLimited
    } else if has_any(&["serviceunavailable", "internalserver"]) {
        SmLlmError::Upstream {
            status: 503,
            body: msg,
        }
    } else if has_any(&["modelnotready", "not in active"]) {
        SmLlmError::ModelNotReady(msg)
    } else {
        SmLlmError::Transport(format!(
            "Bedrock Converse SDK error (model={model}, region={region}): {msg}"
        ))
    }
}
#[async_trait]
impl LlmProvider for BedrockProvider {
    /// Name under which this provider identifies itself.
    fn name(&self) -> &str {
        "bedrock"
    }

    /// Runs `req` against Bedrock, retrying retryable failures.
    ///
    /// After the first attempt, up to `MAX_RETRIES` further attempts are made
    /// while the error reports itself as retryable. The wait before retry
    /// number `k` is `500 ms * 2^min(k, 6)`, i.e. 1 s, 2 s, 4 s for the
    /// first three retries. The last error is returned once the error is not
    /// retryable or the retry budget is spent.
    async fn complete(&self, req: LlmRequest) -> Result<LlmResponse, SmLlmError> {
        debug!(
            provider = "bedrock",
            model = %req.model,
            region = %self.region,
            "sm bedrock complete request"
        );

        let mut attempt = 0u32;
        loop {
            let err = match self.call_once(&req).await {
                Ok(resp) => {
                    debug!(
                        provider = "bedrock",
                        model = %resp.model,
                        input_tokens = resp.input_tokens,
                        output_tokens = resp.output_tokens,
                        latency_ms = resp.latency_ms,
                        cost_usd = resp.cost_usd,
                        "sm bedrock complete response"
                    );
                    return Ok(resp);
                }
                Err(err) => err,
            };

            // Give up on permanent failures or once the budget is exhausted.
            if !err.is_retryable() || attempt >= MAX_RETRIES {
                return Err(err);
            }

            attempt += 1;
            // Shifting doubles the base delay per attempt; the shift is capped at 6.
            let backoff_ms = 500u64 << attempt.min(6);
            warn!(attempt, backoff_ms, model = %req.model, "sm bedrock retry: {err}");
            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
        }
    }
}
/// Collects the text blocks of the response message into one string,
/// separated by newlines.
///
/// Returns `None` when the response has no output, the output is not a
/// message, or no text was collected.
fn extract_converse_text(
    resp: &aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
) -> Option<String> {
    let message = resp.output()?.as_message().ok()?;
    let joined = message
        .content()
        .iter()
        .filter_map(|block| match block {
            ContentBlock::Text(text) => Some(text.as_str()),
            _ => None,
        })
        // A separator is only emitted once some text has been written, so
        // empty blocks at the front contribute nothing at all.
        .skip_while(|text| text.is_empty())
        .collect::<Vec<&str>>()
        .join("\n");
    (!joined.is_empty()).then_some(joined)
}
/// Reads `(input_tokens, output_tokens)` from the response usage block.
///
/// Negative counts are clamped to zero; a response without usage data
/// yields `(0, 0)`.
fn extract_token_usage(
    resp: &aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
) -> (u32, u32) {
    match resp.usage() {
        Some(usage) => {
            // Clamp before casting so a negative count cannot wrap around.
            let prompt = usage.input_tokens().max(0) as u32;
            let completion = usage.output_tokens().max(0) as u32;
            (prompt, completion)
        }
        None => (0, 0),
    }
}
// Unit tests are compiled only for test builds and are loaded from the
// file named in the `#[path]` attribute below.
#[cfg(test)]
#[path = "bedrock_tests.rs"]
mod tests;