use super::mappers::{map_messages, map_tools};
use super::streaming::process_bedrock_stream;
use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
use crate::{Context, LlmError, Result};
use aws_config::Region;
use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
use aws_sdk_bedrockruntime::{Client, Config};
use futures::StreamExt;
use tracing::{error, info};
/// Bedrock model identifier a freshly constructed provider starts with.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// `max_tokens` value a freshly constructed provider starts with.
const DEFAULT_MAX_TOKENS: i32 = 16_384;
/// Region `build_client` falls back to when the caller supplies none.
const DEFAULT_REGION: &str = "us-east-1";
/// Explicit AWS credentials that can be handed to [`BedrockProvider::from_config`].
///
/// The three fields are forwarded, in this order, to `Credentials::new` when the
/// client is built.
#[derive(Clone)]
pub struct AwsCredentials {
/// AWS access key id.
pub access_key_id: String,
/// AWS secret access key paired with `access_key_id`.
pub secret_access_key: String,
/// Optional session token; `None` when the credentials carry no token.
pub session_token: Option<String>,
}
/// Streaming model provider backed by the Bedrock runtime `converse_stream` API.
///
/// The provider is `Clone`; `stream_response` clones it so the returned stream
/// does not borrow from `self`.
#[derive(Clone)]
pub struct BedrockProvider {
// Bedrock runtime client used to issue requests.
client: Client,
// Model id sent as `model_id` on every request.
model: String,
// Upper bound passed as `max_tokens` in the inference configuration.
max_tokens: i32,
// Sampling temperature; only sent when `Some`.
temperature: Option<f32>,
}
impl BedrockProvider {
pub async fn new() -> Self {
let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
let client = Client::new(&config);
Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
}
pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
let client = build_client(credentials, region);
Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
}
pub fn with_model(mut self, model: &str) -> Self {
self.model = model.to_string();
self
}
pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
async fn send_converse_stream(
&self,
context: &Context,
) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
let (system_blocks, messages) = map_messages(context.messages())?;
let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
if let Some(temp) = self.temperature {
inference_config = inference_config.temperature(temp);
}
let inference_config = inference_config.build();
let mut request = self
.client
.converse_stream()
.model_id(&self.model)
.set_messages(Some(messages))
.inference_config(inference_config);
if !system_blocks.is_empty() {
request = request.set_system(Some(system_blocks));
}
if !context.tools().is_empty() {
let tool_config = map_tools(context.tools())?;
request = request.tool_config(tool_config);
}
info!(model = %self.model, "Sending Bedrock converse_stream request");
let response = request.send().await.map_err(|e| {
error!(model = %self.model, error = ?e, "Bedrock API error");
LlmError::ApiError(format!("Bedrock API error: {e}"))
})?;
Ok(response.stream)
}
}
impl ProviderFactory for BedrockProvider {
async fn from_env() -> Result<Self> {
Ok(Self::new().await)
}
fn with_model(self, model: &str) -> Self {
self.with_model(model)
}
}
impl StreamingModelProvider for BedrockProvider {
    /// Parses `bedrock:<model id>` into an `LlmModel`, yielding `None` when the
    /// string does not parse.
    fn model(&self) -> Option<crate::LlmModel> {
        let qualified = format!("bedrock:{}", self.model);
        qualified.parse().ok()
    }

    /// Looks up the context window registered for this model under the
    /// `"bedrock"` provider key.
    fn context_window(&self) -> Option<u32> {
        let provider_key = "bedrock";
        get_context_window(provider_key, &self.model)
    }

    /// Sends `context` to Bedrock and returns a stream of the processed events.
    ///
    /// The provider and the context are cloned so the stream owns its data. A
    /// failure to open the Bedrock stream is surfaced as a single `Err` item.
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        let provider = self.clone();
        let context = context.clone();

        let responses = async_stream::stream! {
            match provider.send_converse_stream(&context).await {
                Err(err) => {
                    yield Err(err);
                }
                Ok(receiver) => {
                    let mut events = Box::pin(process_bedrock_stream(receiver));
                    while let Some(event) = events.next().await {
                        yield event;
                    }
                }
            }
        };

        Box::pin(responses)
    }

    /// Human-readable label of the form `Bedrock (<model id>)`.
    fn display_name(&self) -> String {
        let model = &self.model;
        format!("Bedrock ({model})")
    }
}
/// Builds a Bedrock runtime client from explicit settings.
///
/// * `credentials` - when `Some`, installed as static credentials under the
///   provider name `"aether-bedrock-provider"`; when `None`, no credentials
///   provider is configured on the builder.
/// * `region` - region name; `DEFAULT_REGION` is used when `None`.
fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
    let region_name = region.unwrap_or(DEFAULT_REGION).to_string();

    let builder = Config::builder()
        .behavior_version(BehaviorVersion::latest())
        .region(Region::new(region_name));

    let builder = match credentials {
        Some(creds) => {
            let static_credentials = Credentials::new(
                creds.access_key_id,
                creds.secret_access_key,
                creds.session_token,
                None,
                "aether-bedrock-provider",
            );
            builder.credentials_provider(static_credentials)
        }
        None => builder,
    };

    Client::from_conf(builder.build())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider built without credentials or region, used as the baseline in tests.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// AWS documentation example credentials with an optional session token.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(str::to_string),
        }
    }

    #[test]
    fn test_display_name() {
        let name = test_provider().display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let provider = test_provider().with_model("anthropic.claude-opus-4-20250514-v1:0");
        let name = provider.display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let BedrockProvider { max_tokens, .. } = test_provider().with_max_tokens(8192);
        assert_eq!(max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let BedrockProvider { temperature, .. } = test_provider().with_temperature(0.7);
        assert_eq!(temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = test_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert!(temperature.is_none());
    }

    #[test]
    fn test_from_config_with_credentials() {
        let credentials = example_credentials(None);
        let provider = BedrockProvider::from_config(Some(credentials), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));
        let base = BedrockProvider::from_config(Some(credentials), Some("us-west-2"));
        let provider = base
            .with_model("anthropic.claude-opus-4-20250514-v1:0")
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }
}