1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, ProviderAuthMode, ProviderConnectionConfig, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model ID assigned to every newly constructed provider until
/// `with_model` replaces it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Initial `max_tokens` value for newly constructed providers; sent in the
/// inference configuration of each request unless `with_max_tokens` replaces it.
const DEFAULT_MAX_TOKENS: i32 = 16_384;
/// Region `build_client` falls back to when the caller supplies none.
const DEFAULT_REGION: &str = "us-east-1";
19
/// Static AWS credentials handed to [`BedrockProvider::from_config`].
///
/// `Debug` is implemented by hand (not derived) so that the secret access key
/// and the session token never end up in logs or panic messages.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key. Redacted in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token for temporary credentials. Redacted in `Debug` output.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";

        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            // Keep the Some/None shape visible without revealing the token itself.
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
27
/// Model provider that talks to the AWS Bedrock runtime through the
/// `converse_stream` operation.
#[derive(Clone)]
pub struct BedrockProvider {
    /// Bedrock runtime SDK client used to send requests.
    client: Client,
    /// Model identifier passed as `model_id` on every request.
    model: String,
    /// `max_tokens` value placed in the request's inference configuration.
    max_tokens: i32,
    /// Sampling temperature; only added to the inference configuration when set.
    temperature: Option<f32>,
}
35
36impl BedrockProvider {
37 pub async fn new() -> Self {
40 Self::new_with_connection(ProviderConnectionConfig::default()).await
41 }
42
43 pub async fn new_with_connection(connection: ProviderConnectionConfig) -> Self {
44 if connection == ProviderConnectionConfig::default() {
45 let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
46 let client = Client::new(&config);
47 return Self {
48 client,
49 model: DEFAULT_MODEL.to_string(),
50 max_tokens: DEFAULT_MAX_TOKENS,
51 temperature: None,
52 };
53 }
54
55 let mut loader = aws_config::defaults(BehaviorVersion::latest());
56 if connection.auth_mode == ProviderAuthMode::None {
57 loader = loader.no_credentials();
58 }
59
60 if let Some(url) = &connection.base_url {
61 loader = loader.endpoint_url(url.clone());
62 }
63
64 let sdk_config = loader.load().await;
65 let mut builder = aws_sdk_bedrockruntime::config::Builder::from(&sdk_config);
66 if connection.auth_mode == ProviderAuthMode::None {
67 builder = builder.allow_no_auth();
68 }
69
70 let client = Client::from_conf(builder.build());
71
72 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
73 }
74
75 pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
77 let client = build_client(credentials, region);
78
79 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
80 }
81
82 pub fn with_model(mut self, model: &str) -> Self {
83 self.model = model.to_string();
84 self
85 }
86
87 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
88 self.max_tokens = max_tokens;
89 self
90 }
91
92 pub fn with_temperature(mut self, temperature: f32) -> Self {
93 self.temperature = Some(temperature);
94 self
95 }
96
97 async fn send_converse_stream(
98 &self,
99 context: &Context,
100 ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
101 let (system_blocks, messages) = map_messages(context.messages())?;
102 let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
103
104 if let Some(temp) = self.temperature {
105 inference_config = inference_config.temperature(temp);
106 }
107
108 let inference_config = inference_config.build();
109
110 let mut request = self
111 .client
112 .converse_stream()
113 .model_id(&self.model)
114 .set_messages(Some(messages))
115 .inference_config(inference_config);
116
117 if !system_blocks.is_empty() {
118 request = request.set_system(Some(system_blocks));
119 }
120
121 if !context.tools().is_empty() {
122 let tool_config = map_tools(context.tools())?;
123 request = request.tool_config(tool_config);
124 }
125
126 info!(model = %self.model, "Sending Bedrock converse_stream request");
127
128 let response = request.send().await.map_err(|e| {
129 error!(model = %self.model, error = ?e, "Bedrock API error");
130 LlmError::from(e)
131 })?;
132
133 Ok(response.stream)
134 }
135}
136
137impl ProviderFactory for BedrockProvider {
138 async fn from_env() -> Result<Self> {
139 Ok(Self::new().await)
140 }
141
142 async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
143 Ok(Self::new_with_connection(connection).await)
144 }
145
146 fn with_model(self, model: &str) -> Self {
147 self.with_model(model)
148 }
149}
150
151impl StreamingModelProvider for BedrockProvider {
152 fn model(&self) -> Option<crate::LlmModel> {
153 format!("bedrock:{}", self.model).parse().ok()
154 }
155
156 fn context_window(&self) -> Option<u32> {
157 get_context_window("bedrock", &self.model)
158 }
159
160 fn stream_response(&self, context: &Context) -> LlmResponseStream {
161 let provider = self.clone();
162 let context = context.clone();
163
164 Box::pin(async_stream::stream! {
165 match provider.send_converse_stream(&context).await {
166 Ok(receiver) => {
167 let mut stream = Box::pin(process_bedrock_stream(receiver));
168 while let Some(result) = stream.next().await {
169 yield result;
170 }
171 }
172 Err(e) => {
173 yield Err(e);
174 }
175 }
176 })
177 }
178
179 fn display_name(&self) -> String {
180 format!("Bedrock ({})", self.model)
181 }
182}
183
184impl From<SdkError<ConverseStreamError>> for LlmError {
185 fn from(e: SdkError<ConverseStreamError>) -> Self {
186 let message = format!("Bedrock API error: {e}");
187 match e {
188 SdkError::TimeoutError(_) => LlmError::Timeout(message),
189 SdkError::DispatchFailure(_) => LlmError::Network(message),
190 SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
191 SdkError::ServiceError(svc) => {
192 let inner = svc.err();
193 if inner.is_throttling_exception() {
194 LlmError::RateLimited(message)
195 } else if inner.is_service_unavailable_exception()
196 || inner.is_internal_server_exception()
197 || inner.is_model_stream_error_exception()
198 {
199 LlmError::ServerError { status: None, message }
200 } else {
201 LlmError::ApiError(message)
202 }
203 }
204 _ => LlmError::ApiError(message),
205 }
206 }
207}
208
209fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
210 let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
211
212 if let Some(creds) = credentials {
213 config = config.credentials_provider(Credentials::new(
214 creds.access_key_id,
215 creds.secret_access_key,
216 creds.session_token,
217 None,
218 "aether-bedrock-provider",
219 ));
220 }
221
222 config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
223
224 Client::from_conf(config.build())
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Alternative model ID used by the builder tests.
    const OPUS_4: &str = "anthropic.claude-opus-4-20250514-v1:0";

    /// Provider built without explicit credentials or region.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Example credentials (AWS documentation sample keys) with an optional session token.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(|token| token.to_string()),
        }
    }

    #[test]
    fn test_display_name() {
        let name = test_provider().display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let name = test_provider().with_model(OPUS_4).display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let configured = test_provider().with_max_tokens(8192);
        assert_eq!(configured.max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let configured = test_provider().with_temperature(0.7);
        assert_eq!(configured.temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = test_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert!(temperature.is_none());
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));

        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model(OPUS_4)
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn catalog_foundation_id_resolves_context_window() {
        let provider = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert!(provider.context_window().is_some());

        let parsed = provider.model().unwrap();
        assert_eq!(parsed.to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn cross_region_profile_id_in_catalog_resolves() {
        let provider = test_provider().with_model("us.anthropic.claude-opus-4-6-v1");
        assert!(provider.context_window().is_some());
    }

    #[test]
    fn unknown_cross_region_profile_id_falls_through_to_profile() {
        let profile_id = "us.anthropic.claude-future-model-v99:0";
        let provider = test_provider().with_model(profile_id);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{profile_id}"));
        assert_eq!(provider.display_name(), format!("Bedrock ({profile_id})"));
    }

    #[test]
    fn inference_profile_arn_is_passed_through_as_profile() {
        let profile_arn =
            "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let provider = test_provider().with_model(profile_arn);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model, profile_arn);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{profile_arn}"));
    }

    #[test]
    fn application_inference_profile_arn_is_passed_through_as_profile() {
        let profile_arn =
            "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let provider = test_provider().with_model(profile_arn);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model, profile_arn);
        assert_eq!(provider.display_name(), format!("Bedrock ({profile_arn})"));
    }
}