1use super::mappers::{default_cache_point, map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, ProviderAuthMode, ProviderConnectionConfig, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model id used by every constructor until overridden with `with_model`.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// `maxTokens` sent in the inference configuration unless overridden with `with_max_tokens`.
const DEFAULT_MAX_TOKENS: i32 = 16_384;
/// Region used when no explicit region (or, for the no-auth path, no `AWS_REGION` /
/// `AWS_DEFAULT_REGION` value) is available.
const DEFAULT_REGION: &str = "us-east-1";
19
/// Static AWS credentials used to sign Bedrock requests (see `BedrockProvider::from_config`).
///
/// `Debug` is implemented by hand rather than derived so that formatting a value (for
/// example in a log line or a panic message) never exposes the secret access key or the
/// session token. Only the access key id is printed verbatim.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key id.
    pub access_key_id: String,
    /// Secret access key paired with `access_key_id`.
    pub secret_access_key: String,
    /// Optional session token, for temporary credentials.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Keep the Some/None distinction visible without revealing the token itself.
            .field("session_token", &self.session_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30 client: Client,
31 model: String,
32 inference_profile_arn: Option<String>,
33 max_tokens: i32,
34 temperature: Option<f32>,
35}
36
37impl BedrockProvider {
38 pub async fn new() -> Self {
41 Self::new_with_connection(ProviderConnectionConfig::default()).await
42 }
43
44 pub async fn new_with_connection(connection: ProviderConnectionConfig) -> Self {
45 let client = if connection.auth_mode == ProviderAuthMode::None {
46 build_no_auth_client(connection.base_url.as_deref(), region_from_env().as_deref())
47 } else {
48 let mut loader = aws_config::defaults(BehaviorVersion::latest());
49 if let Some(url) = &connection.base_url {
50 loader = loader.endpoint_url(url.clone());
51 }
52 let config = loader.load().await;
53 Client::new(&config)
54 };
55
56 Self {
57 client,
58 model: DEFAULT_MODEL.to_string(),
59 inference_profile_arn: connection.inference_profile_arn,
60 max_tokens: DEFAULT_MAX_TOKENS,
61 temperature: None,
62 }
63 }
64
65 pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
67 let client = build_client(credentials, region);
68
69 Self {
70 client,
71 model: DEFAULT_MODEL.to_string(),
72 inference_profile_arn: None,
73 max_tokens: DEFAULT_MAX_TOKENS,
74 temperature: None,
75 }
76 }
77
78 pub fn with_model(mut self, model: &str) -> Self {
79 self.model = model.to_string();
80 self
81 }
82
83 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
84 self.max_tokens = max_tokens;
85 self
86 }
87
88 pub fn with_temperature(mut self, temperature: f32) -> Self {
89 self.temperature = Some(temperature);
90 self
91 }
92
93 pub fn with_inference_profile_arn(mut self, arn: impl Into<String>) -> Self {
94 self.inference_profile_arn = Some(arn.into());
95 self
96 }
97
98 fn request_model_id(&self) -> &str {
99 self.inference_profile_arn.as_deref().unwrap_or(&self.model)
100 }
101
102 async fn send_converse_stream(
103 &self,
104 context: &Context,
105 ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
106 let cache_point =
107 self.model().is_some_and(|m| m.supports_prompt_caching()).then(default_cache_point).transpose()?;
108 let (system_blocks, messages) = map_messages(context.messages(), cache_point.as_ref())?;
109 let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
110
111 if let Some(temp) = self.temperature {
112 inference_config = inference_config.temperature(temp);
113 }
114
115 let inference_config = inference_config.build();
116
117 let mut request = self
118 .client
119 .converse_stream()
120 .model_id(self.request_model_id())
121 .set_messages(Some(messages))
122 .inference_config(inference_config);
123
124 if !system_blocks.is_empty() {
125 request = request.set_system(Some(system_blocks));
126 }
127
128 if !context.tools().is_empty() {
129 let tool_config = map_tools(context.tools(), cache_point.as_ref())?;
130 request = request.tool_config(tool_config);
131 }
132
133 if let Some(arn) = self.inference_profile_arn.as_deref() {
134 info!(model = %self.model, inference_profile_arn = %arn, "Sending Bedrock converse_stream request");
135 } else {
136 info!(model = %self.model, "Sending Bedrock converse_stream request");
137 }
138
139 let response = request.send().await.map_err(|e| {
140 error!(model = %self.model, error = ?e, "Bedrock API error");
141 LlmError::from(e)
142 })?;
143
144 Ok(response.stream)
145 }
146}
147
148impl ProviderFactory for BedrockProvider {
149 async fn from_env() -> Result<Self> {
150 Ok(Self::new().await)
151 }
152
153 async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
154 Ok(Self::new_with_connection(connection).await)
155 }
156
157 fn with_model(self, model: &str) -> Self {
158 self.with_model(model)
159 }
160}
161
162impl StreamingModelProvider for BedrockProvider {
163 fn model(&self) -> Option<crate::LlmModel> {
164 format!("bedrock:{}", self.model).parse().ok()
165 }
166
167 fn context_window(&self) -> Option<u32> {
168 get_context_window("bedrock", &self.model)
169 }
170
171 fn stream_response(&self, context: &Context) -> LlmResponseStream {
172 let provider = self.clone();
173 let context = context.clone();
174
175 Box::pin(async_stream::stream! {
176 match provider.send_converse_stream(&context).await {
177 Ok(receiver) => {
178 let mut stream = Box::pin(process_bedrock_stream(receiver));
179 while let Some(result) = stream.next().await {
180 yield result;
181 }
182 }
183 Err(e) => {
184 yield Err(e);
185 }
186 }
187 })
188 }
189
190 fn display_name(&self) -> String {
191 format!("Bedrock ({})", self.model)
192 }
193}
194
195impl From<SdkError<ConverseStreamError>> for LlmError {
196 fn from(e: SdkError<ConverseStreamError>) -> Self {
197 let message = format!("Bedrock API error: {e}");
198 match e {
199 SdkError::TimeoutError(_) => LlmError::Timeout(message),
200 SdkError::DispatchFailure(_) => LlmError::Network(message),
201 SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
202 SdkError::ServiceError(svc) => {
203 let inner = svc.err();
204 if inner.is_throttling_exception() {
205 LlmError::RateLimited(message)
206 } else if inner.is_service_unavailable_exception()
207 || inner.is_internal_server_exception()
208 || inner.is_model_stream_error_exception()
209 {
210 LlmError::ServerError { status: None, message }
211 } else {
212 LlmError::ApiError(message)
213 }
214 }
215 _ => LlmError::ApiError(message),
216 }
217 }
218}
219
220fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
221 let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
222
223 if let Some(creds) = credentials {
224 config = config.credentials_provider(Credentials::new(
225 creds.access_key_id,
226 creds.secret_access_key,
227 creds.session_token,
228 None,
229 "aether-bedrock-provider",
230 ));
231 }
232
233 config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
234
235 Client::from_conf(config.build())
236}
237
238fn build_no_auth_client(base_url: Option<&str>, region: Option<&str>) -> Client {
239 let mut config = Config::builder()
240 .behavior_version(BehaviorVersion::latest())
241 .allow_no_auth()
242 .region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
243
244 if let Some(url) = base_url {
245 config = config.endpoint_url(url);
246 }
247
248 Client::from_conf(config.build())
249}
250
/// Returns the first non-empty value of `AWS_REGION`, then `AWS_DEFAULT_REGION`.
///
/// Variables that are unset, empty, or not valid Unicode are skipped. Returns `None` if
/// neither variable yields a value.
fn region_from_env() -> Option<String> {
    const REGION_VARS: [&str; 2] = ["AWS_REGION", "AWS_DEFAULT_REGION"];

    for var_name in REGION_VARS {
        match std::env::var(var_name) {
            Ok(value) if !value.is_empty() => return Some(value),
            _ => continue,
        }
    }

    None
}
257
// Unit tests for `BedrockProvider`. Network-facing tests talk to an in-process fake HTTP
// endpoint (`FakeBedrockEndpoint`) instead of AWS.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::IsoString;
    use crate::{ChatMessage, ContentBlock};
    use axum::Router;
    use axum::body::Body;
    use axum::extract::State;
    use axum::http::{HeaderMap, Method, Request, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::any;
    use std::sync::Arc;
    use tokio::net::TcpListener;
    use tokio::sync::{Mutex, oneshot};

    /// Builds a synthetic cross-region inference-profile ARN for `model` (test data only).
    fn inference_profile_arn(model: &str) -> String {
        format!("arn:aws:bedrock:us-west-2:000000000000:inference-profile/{model}")
    }

    /// Synthetic application-inference-profile ARN (test data only).
    fn application_inference_profile_arn() -> &'static str {
        "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000"
    }

    /// Provider with no explicit credentials and the default region; no requests are sent.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    #[test]
    fn test_display_name() {
        assert_eq!(test_provider().display_name(), "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let provider = test_provider().with_model("anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(provider.display_name(), "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let provider = test_provider().with_max_tokens(8192);
        assert_eq!(provider.max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let provider = test_provider().with_temperature(0.7);
        assert_eq!(provider.temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let provider = test_provider();
        assert_eq!(provider.model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(provider.max_tokens, 16_384);
        assert!(provider.temperature.is_none());
    }

    // With `ProviderAuthMode::None`, the request must reach the custom endpoint without a
    // SigV4 `authorization` header or an `x-amz-security-token` header. The fake endpoint
    // answers 403, so the call itself is expected to fail.
    #[tokio::test]
    async fn auth_none_sends_unsigned_request_to_custom_endpoint() {
        let endpoint = FakeBedrockEndpoint::start().await;
        let provider = BedrockProvider::new_with_connection(ProviderConnectionConfig {
            base_url: Some(endpoint.url.clone()),
            auth_mode: ProviderAuthMode::None,
            ..Default::default()
        })
        .await;

        let context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("hello")], timestamp: IsoString::now() }],
            vec![],
        );

        let result = provider.send_converse_stream(&context).await;
        let request = endpoint.request.await.expect("fake Bedrock endpoint received no request");

        assert!(result.is_err());
        assert_eq!(request.method, Method::POST);
        assert!(request.path.starts_with("/model/"), "{}", request.path);
        assert!(!request.headers.contains_key("authorization"), "request was signed: {:?}", request.headers);
        assert!(
            !request.headers.contains_key("x-amz-security-token"),
            "request included session token: {:?}",
            request.headers
        );
    }

    #[test]
    fn test_from_config_with_credentials() {
        // AWS documentation example credentials; not real secrets.
        let credentials = AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: None,
        };

        let provider = BedrockProvider::from_config(Some(credentials), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: Some("FwoGZXIvYXdzEBYaD...".to_string()),
        };

        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model("anthropic.claude-opus-4-20250514-v1:0")
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    // A foundation-model id present in the catalog resolves both the context window and the
    // `bedrock:`-prefixed model identity.
    #[test]
    fn catalog_foundation_id_resolves_context_window() {
        let provider = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert!(provider.context_window().is_some());
        assert_eq!(provider.model().unwrap().to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn cross_region_profile_id_in_catalog_resolves() {
        let provider = test_provider().with_model("us.anthropic.claude-opus-4-6-v1");
        assert!(provider.context_window().is_some());
    }

    // An id unknown to the catalog has no context window but still parses as a model and
    // still appears in the display name.
    #[test]
    fn unknown_cross_region_profile_id_falls_through_to_profile() {
        let id = "us.anthropic.claude-future-model-v99:0";
        let provider = test_provider().with_model(id);
        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{id}"));
        assert_eq!(provider.display_name(), format!("Bedrock ({id})"));
    }

    // The inference profile ARN appears (percent-encoded) in the request path, while the
    // model identity and context window still come from the canonical model id.
    #[tokio::test]
    async fn separate_inference_profile_arn_is_used_as_request_model_id() {
        let endpoint = FakeBedrockEndpoint::start().await;
        let provider = BedrockProvider::new_with_connection(ProviderConnectionConfig {
            base_url: Some(endpoint.url.clone()),
            auth_mode: ProviderAuthMode::None,
            inference_profile_arn: Some(application_inference_profile_arn().to_string()),
        })
        .await
        .with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        let context = Context::new(
            vec![ChatMessage::User { content: vec![ContentBlock::text("hello")], timestamp: IsoString::now() }],
            vec![],
        );

        let result = provider.send_converse_stream(&context).await;
        let request = endpoint.request.await.expect("fake Bedrock endpoint received no request");

        assert!(result.is_err());
        assert!(
            request.path.contains("arn%3Aaws%3Abedrock%3Aus-west-2%3A000000000000%3Aapplication-inference-profile"),
            "{}",
            request.path
        );
        assert_eq!(provider.context_window(), Some(200_000));
        assert_eq!(provider.model().unwrap().to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn with_inference_profile_arn_keeps_canonical_model_identity() {
        let arn = inference_profile_arn("us.anthropic.claude-sonnet-4-5-20250929-v1:0");
        let provider =
            test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0").with_inference_profile_arn(&arn);

        assert_eq!(provider.request_model_id(), arn);
        assert_eq!(provider.context_window(), Some(200_000));
        assert_eq!(provider.model().unwrap().to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn prompt_caching_support_comes_from_canonical_model() {
        let cached = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert!(cached.model().unwrap().supports_prompt_caching());

        let unknown_profile = test_provider().with_model("us.anthropic.claude-future-model-v99:0");
        assert!(!unknown_profile.model().unwrap().supports_prompt_caching());
    }

    /// Local HTTP server that records the first request it receives.
    struct FakeBedrockEndpoint {
        /// `http://127.0.0.1:<port>` base URL of the server.
        url: String,
        /// Resolves with the first captured request.
        request: oneshot::Receiver<CapturedRequest>,
    }

    /// The parts of an incoming request that the tests assert on.
    struct CapturedRequest {
        method: Method,
        path: String,
        headers: HeaderMap,
    }

    /// Shared handler state. The senders sit in `Option`s so that only the first request
    /// triggers them (`take()` leaves `None` behind).
    #[derive(Clone)]
    struct FakeBedrockState {
        request_tx: Arc<Mutex<Option<oneshot::Sender<CapturedRequest>>>>,
        shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    }

    impl FakeBedrockEndpoint {
        /// Binds an ephemeral port on 127.0.0.1 and serves every path with
        /// `capture_bedrock_request` on a spawned task until the handler signals shutdown.
        async fn start() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind fake Bedrock endpoint");
            let url = format!("http://{}", listener.local_addr().expect("fake Bedrock endpoint address"));
            let (request_tx, request) = oneshot::channel();
            let (shutdown_tx, shutdown) = oneshot::channel();
            let state = FakeBedrockState {
                request_tx: Arc::new(Mutex::new(Some(request_tx))),
                shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
            };

            let app = Router::new().fallback(any(capture_bedrock_request)).with_state(state);
            tokio::spawn(async move {
                axum::serve(listener, app)
                    .with_graceful_shutdown(async {
                        let _ = shutdown.await;
                    })
                    .await
                    .expect("serve fake Bedrock endpoint");
            });

            Self { url, request }
        }
    }

    /// Records method, path and headers of the first request, signals graceful shutdown,
    /// and always answers `403 Forbidden` with body `{}`.
    async fn capture_bedrock_request(
        State(state): State<FakeBedrockState>,
        request: Request<Body>,
    ) -> impl IntoResponse {
        let (parts, _) = request.into_parts();
        if let Some(tx) = state.request_tx.lock().await.take() {
            let _ = tx.send(CapturedRequest {
                method: parts.method,
                path: parts.uri.path().to_string(),
                headers: parts.headers,
            });
        }
        if let Some(tx) = state.shutdown_tx.lock().await.take() {
            let _ = tx.send(());
        }
        (StatusCode::FORBIDDEN, "{}")
    }
}