use super::{ChatEvent, ChatProvider, ToolDef};
use crate::ChatMessage;
use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use aws_config::BehaviorVersion;
use aws_sdk_bedrockruntime::Client as BedrockClient;
use aws_sdk_bedrockruntime::types::{
ContentBlock, ConversationRole, InferenceConfiguration, Message, SystemContentBlock,
};
use tokio::sync::mpsc::Sender;
/// Default Bedrock model id used when the caller does not pick one.
/// NOTE(review): the `us.` prefix looks like a US cross-region inference profile id —
/// confirm it is enabled in the target account/region.
pub const DEFAULT_BEDROCK_MODEL: &str = "us.anthropic.claude-sonnet-4-6";
/// Project-specific environment variable consulted first when resolving the region.
pub const ENV_REGION_TRUSTY: &str = "TRUSTY_AWS_REGION";
/// Standard AWS environment variable consulted after [`ENV_REGION_TRUSTY`].
pub const ENV_REGION_AWS: &str = "AWS_REGION";
/// Region used when neither an explicit value nor an environment override is set.
pub const DEFAULT_BEDROCK_REGION: &str = "us-east-1";
/// Picks the AWS region for Bedrock calls.
///
/// Precedence, first match wins:
/// 1. `explicit`, when it is `Some` and non-empty;
/// 2. the first non-empty value of [`ENV_REGION_TRUSTY`], then [`ENV_REGION_AWS`];
/// 3. [`DEFAULT_BEDROCK_REGION`].
///
/// Environment variables that are unset or not valid Unicode count as empty.
pub fn resolve_bedrock_region(explicit: Option<&str>) -> String {
    match explicit {
        Some(region) if !region.is_empty() => region.to_owned(),
        _ => [ENV_REGION_TRUSTY, ENV_REGION_AWS]
            .iter()
            // `var` fails for unset / non-Unicode values; treat both as "not provided".
            .filter_map(|name| std::env::var(name).ok())
            .find(|value| !value.is_empty())
            .unwrap_or_else(|| DEFAULT_BEDROCK_REGION.to_owned()),
    }
}
/// Chat provider backed by the AWS Bedrock Runtime `Converse` API.
///
/// Bundles a configured Bedrock client with the model id sent on every request
/// and the region the client was configured for (kept for diagnostics).
pub struct BedrockProvider {
    /// Bedrock Runtime client used to issue `Converse` requests.
    client: BedrockClient,
    /// Model id passed as `model_id` on each request.
    model: String,
    /// Region the client was configured with; surfaced in error messages.
    region: String,
}
impl BedrockProvider {
    /// Creates a provider from the default AWS SDK configuration
    /// (`aws_config::defaults`), pinned to the region chosen by
    /// [`resolve_bedrock_region`] for the given `region` override.
    ///
    /// # Errors
    /// The current implementation never returns `Err`; the `Result` keeps the
    /// signature open for fallible construction.
    pub async fn new(model: impl Into<String>, region: Option<&str>) -> Result<Self> {
        let region = resolve_bedrock_region(region);
        // `Region` is itself a region provider, so no provider chain is needed.
        let sdk_config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_types::region::Region::new(region.clone()))
            .load()
            .await;
        Ok(Self {
            client: BedrockClient::new(&sdk_config),
            model: model.into(),
            region,
        })
    }
    /// Test-only constructor that wraps an already-built client.
    #[cfg(test)]
    pub fn from_client(
        client: BedrockClient,
        model: impl Into<String>,
        region: impl Into<String>,
    ) -> Self {
        let (model, region) = (model.into(), region.into());
        Self {
            client,
            model,
            region,
        }
    }
    /// Region this provider's client was configured with.
    pub fn region(&self) -> &str {
        self.region.as_str()
    }
}
#[async_trait]
impl ChatProvider for BedrockProvider {
fn name(&self) -> &str {
"bedrock"
}
fn model(&self) -> &str {
&self.model
}
async fn chat_stream(
&self,
messages: Vec<ChatMessage>,
_tools: Vec<ToolDef>,
tx: Sender<ChatEvent>,
) -> Result<()> {
let mut system_blocks: Vec<SystemContentBlock> = Vec::new();
let mut converse_messages: Vec<Message> = Vec::new();
for msg in &messages {
if msg.role == "system" {
system_blocks.push(SystemContentBlock::Text(msg.content.clone()));
} else {
let role = if msg.role == "assistant" {
ConversationRole::Assistant
} else {
ConversationRole::User
};
let bedrock_msg = Message::builder()
.role(role)
.content(ContentBlock::Text(msg.content.clone()))
.build()
.context("build Bedrock Message")?;
converse_messages.push(bedrock_msg);
}
}
if converse_messages.is_empty() {
return Err(anyhow!(
"BedrockProvider::chat_stream: no user/assistant messages provided"
));
}
let inference = InferenceConfiguration::builder().max_tokens(4096).build();
let mut req = self
.client
.converse()
.model_id(&self.model)
.inference_config(inference)
.set_messages(Some(converse_messages));
if !system_blocks.is_empty() {
req = req.set_system(Some(system_blocks));
}
let resp = req.send().await.with_context(|| {
format!(
"AWS Bedrock Converse request failed (model={}, region={}). \
Ensure AWS credentials are configured for Bedrock deep analysis \
(AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_PROFILE / IAM role).",
self.model, self.region
)
})?;
let text = extract_converse_text(&resp);
let text = text.unwrap_or_default();
if tx.send(ChatEvent::Delta(text)).await.is_err() {
return Ok(());
}
let _ = tx.send(ChatEvent::Done).await;
Ok(())
}
}
/// Joins the text content blocks of a `Converse` reply.
///
/// Returns `None` when the output is absent, is not a message, or yields no
/// text at all. Non-text content blocks are ignored. A `'\n'` separator is
/// inserted only once some text has already been collected.
fn extract_converse_text(
    resp: &aws_sdk_bedrockruntime::operation::converse::ConverseOutput,
) -> Option<String> {
    let message = match resp.output() {
        Some(output) => output.as_message().ok()?,
        None => return None,
    };
    let joined = message.content().iter().fold(String::new(), |mut acc, block| {
        if let ContentBlock::Text(text) = block {
            if !acc.is_empty() {
                acc.push('\n');
            }
            acc.push_str(text);
        }
        acc
    });
    (!joined.is_empty()).then_some(joined)
}
#[cfg(test)]
mod tests {
    use super::*;
    #[tokio::test]
    async fn bedrock_provider_reports_metadata() {
        let config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_types::region::Region::new("us-east-1"))
            .no_credentials()
            .load()
            .await;
        let client = BedrockClient::new(&config);
        let provider = BedrockProvider::from_client(client, DEFAULT_BEDROCK_MODEL, "us-east-1");
        assert_eq!(provider.name(), "bedrock");
        assert_eq!(provider.model(), DEFAULT_BEDROCK_MODEL);
        assert_eq!(provider.region(), "us-east-1");
    }
    #[test]
    fn bedrock_region_resolution() {
        assert_eq!(
            resolve_bedrock_region(Some("eu-west-1")),
            "eu-west-1",
            "explicit should win"
        );
        // TRUSTY_AWS_REGION / AWS_REGION are commonly exported on developer machines
        // and CI runners, so the fallback must not be assumed to be the built-in
        // default. Derive it from the current environment instead of mutating the
        // environment, which would race with tests running in parallel.
        let expected_fallback = [ENV_REGION_TRUSTY, ENV_REGION_AWS]
            .iter()
            .find_map(|var| std::env::var(var).ok().filter(|v| !v.is_empty()))
            .unwrap_or_else(|| DEFAULT_BEDROCK_REGION.to_string());
        assert_eq!(
            resolve_bedrock_region(Some("")),
            expected_fallback,
            "empty explicit should fall through to env/default"
        );
        assert_eq!(
            resolve_bedrock_region(None),
            expected_fallback,
            "None should fall through to env/default"
        );
    }
    #[tokio::test]
    async fn bedrock_no_credentials_returns_clear_error() {
        let config = aws_config::defaults(BehaviorVersion::latest())
            .region(aws_types::region::Region::new("us-east-1"))
            .no_credentials()
            .load()
            .await;
        let client = BedrockClient::new(&config);
        let provider = BedrockProvider::from_client(client, DEFAULT_BEDROCK_MODEL, "us-east-1");
        let (tx, _rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let result = provider
            .chat_stream(
                vec![crate::ChatMessage {
                    role: "user".into(),
                    content: "hello".into(),
                    tool_call_id: None,
                    tool_calls: None,
                }],
                vec![],
                tx,
            )
            .await;
        let err = result.expect_err("should fail without real credentials");
        let msg = format!("{err:#}");
        assert!(
            msg.to_lowercase().contains("bedrock")
                || msg.to_lowercase().contains("credential")
                || msg.to_lowercase().contains("aws"),
            "error message should mention Bedrock/credentials; got: {msg}"
        );
    }
    #[tokio::test]
    #[ignore = "requires real AWS credentials with bedrock:InvokeModel permission"]
    async fn bedrock_live_converse_smoke_test() {
        let provider = BedrockProvider::new(DEFAULT_BEDROCK_MODEL, None)
            .await
            .expect("BedrockProvider::new failed");
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let handle = tokio::spawn(async move {
            provider
                .chat_stream(
                    vec![
                        crate::ChatMessage {
                            role: "system".into(),
                            content: "You are a concise assistant. Reply in plain text.".into(),
                            tool_call_id: None,
                            tool_calls: None,
                        },
                        crate::ChatMessage {
                            role: "user".into(),
                            content: "Say hello in exactly 3 words.".into(),
                            tool_call_id: None,
                            tool_calls: None,
                        },
                    ],
                    vec![],
                    tx,
                )
                .await
        });
        let mut text = String::new();
        let mut saw_done = false;
        while let Some(ev) = rx.recv().await {
            match ev {
                ChatEvent::Delta(s) => text.push_str(&s),
                ChatEvent::Done => saw_done = true,
                ChatEvent::Error(e) => panic!("stream error: {e}"),
                ChatEvent::ToolCall(_) => {}
            }
        }
        handle
            .await
            .expect("task panicked")
            .expect("chat_stream failed");
        assert!(!text.is_empty(), "expected non-empty response");
        assert!(saw_done, "expected ChatEvent::Done");
        eprintln!("bedrock_live_converse_smoke_test response: {text:?}");
    }
}