1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, ProviderAuthMode, ProviderConnectionConfig, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model identifier a freshly built provider uses until `with_model` replaces it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Output-token cap a freshly built provider uses until `with_max_tokens` replaces it (16 Ki tokens).
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Region the client builders fall back to when the caller supplies none.
const DEFAULT_REGION: &str = "us-east-1";
19
/// Explicit AWS credentials accepted by `BedrockProvider::from_config`.
///
/// `Debug` is implemented by hand rather than derived so that the secret key and the session
/// token are never rendered into logs or panic messages.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key identifier; shown as-is in `Debug` output.
    pub access_key_id: String,
    /// Secret access key paired with `access_key_id`; redacted in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token forwarded to the SDK credentials; redacted in `Debug` output.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    /// Formats the credentials with every secret value replaced by a fixed placeholder.
    ///
    /// For `session_token` only its presence (`Some`/`None`) is visible.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";

        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30 client: Client,
31 model: String,
32 max_tokens: i32,
33 temperature: Option<f32>,
34}
35
36impl BedrockProvider {
37 pub async fn new() -> Self {
40 Self::new_with_connection(ProviderConnectionConfig::default()).await
41 }
42
43 pub async fn new_with_connection(connection: ProviderConnectionConfig) -> Self {
44 let client = if connection.auth_mode == ProviderAuthMode::None {
45 build_no_auth_client(connection.base_url.as_deref(), region_from_env().as_deref())
46 } else {
47 let mut loader = aws_config::defaults(BehaviorVersion::latest());
48 if let Some(url) = &connection.base_url {
49 loader = loader.endpoint_url(url.clone());
50 }
51 let config = loader.load().await;
52 Client::new(&config)
53 };
54
55 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
56 }
57
58 pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
60 let client = build_client(credentials, region);
61
62 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
63 }
64
65 pub fn with_model(mut self, model: &str) -> Self {
66 self.model = model.to_string();
67 self
68 }
69
70 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
71 self.max_tokens = max_tokens;
72 self
73 }
74
75 pub fn with_temperature(mut self, temperature: f32) -> Self {
76 self.temperature = Some(temperature);
77 self
78 }
79
80 async fn send_converse_stream(
81 &self,
82 context: &Context,
83 ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
84 let (system_blocks, messages) = map_messages(context.messages())?;
85 let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
86
87 if let Some(temp) = self.temperature {
88 inference_config = inference_config.temperature(temp);
89 }
90
91 let inference_config = inference_config.build();
92
93 let mut request = self
94 .client
95 .converse_stream()
96 .model_id(&self.model)
97 .set_messages(Some(messages))
98 .inference_config(inference_config);
99
100 if !system_blocks.is_empty() {
101 request = request.set_system(Some(system_blocks));
102 }
103
104 if !context.tools().is_empty() {
105 let tool_config = map_tools(context.tools())?;
106 request = request.tool_config(tool_config);
107 }
108
109 info!(model = %self.model, "Sending Bedrock converse_stream request");
110
111 let response = request.send().await.map_err(|e| {
112 error!(model = %self.model, error = ?e, "Bedrock API error");
113 LlmError::from(e)
114 })?;
115
116 Ok(response.stream)
117 }
118}
119
120impl ProviderFactory for BedrockProvider {
121 async fn from_env() -> Result<Self> {
122 Ok(Self::new().await)
123 }
124
125 async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
126 Ok(Self::new_with_connection(connection).await)
127 }
128
129 fn with_model(self, model: &str) -> Self {
130 self.with_model(model)
131 }
132}
133
134impl StreamingModelProvider for BedrockProvider {
135 fn model(&self) -> Option<crate::LlmModel> {
136 format!("bedrock:{}", self.model).parse().ok()
137 }
138
139 fn context_window(&self) -> Option<u32> {
140 get_context_window("bedrock", &self.model)
141 }
142
143 fn stream_response(&self, context: &Context) -> LlmResponseStream {
144 let provider = self.clone();
145 let context = context.clone();
146
147 Box::pin(async_stream::stream! {
148 match provider.send_converse_stream(&context).await {
149 Ok(receiver) => {
150 let mut stream = Box::pin(process_bedrock_stream(receiver));
151 while let Some(result) = stream.next().await {
152 yield result;
153 }
154 }
155 Err(e) => {
156 yield Err(e);
157 }
158 }
159 })
160 }
161
162 fn display_name(&self) -> String {
163 format!("Bedrock ({})", self.model)
164 }
165}
166
167impl From<SdkError<ConverseStreamError>> for LlmError {
168 fn from(e: SdkError<ConverseStreamError>) -> Self {
169 let message = format!("Bedrock API error: {e}");
170 match e {
171 SdkError::TimeoutError(_) => LlmError::Timeout(message),
172 SdkError::DispatchFailure(_) => LlmError::Network(message),
173 SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
174 SdkError::ServiceError(svc) => {
175 let inner = svc.err();
176 if inner.is_throttling_exception() {
177 LlmError::RateLimited(message)
178 } else if inner.is_service_unavailable_exception()
179 || inner.is_internal_server_exception()
180 || inner.is_model_stream_error_exception()
181 {
182 LlmError::ServerError { status: None, message }
183 } else {
184 LlmError::ApiError(message)
185 }
186 }
187 _ => LlmError::ApiError(message),
188 }
189 }
190}
191
192fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
193 let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
194
195 if let Some(creds) = credentials {
196 config = config.credentials_provider(Credentials::new(
197 creds.access_key_id,
198 creds.secret_access_key,
199 creds.session_token,
200 None,
201 "aether-bedrock-provider",
202 ));
203 }
204
205 config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
206
207 Client::from_conf(config.build())
208}
209
210fn build_no_auth_client(base_url: Option<&str>, region: Option<&str>) -> Client {
211 let mut config = Config::builder()
212 .behavior_version(BehaviorVersion::latest())
213 .allow_no_auth()
214 .region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
215
216 if let Some(url) = base_url {
217 config = config.endpoint_url(url);
218 }
219
220 Client::from_conf(config.build())
221}
222
/// Returns the first non-empty value among `AWS_REGION` and `AWS_DEFAULT_REGION`, in that order.
///
/// A variable that is unset, empty, or not valid Unicode is skipped.
fn region_from_env() -> Option<String> {
    ["AWS_REGION", "AWS_DEFAULT_REGION"]
        .iter()
        .filter_map(|key| std::env::var(key).ok())
        .find(|region| !region.is_empty())
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::IsoString;
    use crate::{ChatMessage, ContentBlock};
    use axum::Router;
    use axum::body::Body;
    use axum::extract::State;
    use axum::http::{HeaderMap, Method, Request, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::any;
    use std::sync::Arc;
    use tokio::net::TcpListener;
    use tokio::sync::{Mutex, oneshot};

    /// Provider built without credentials or an explicit region.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Example credentials shared by the `from_config` tests.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(String::from),
        }
    }

    #[test]
    fn test_display_name() {
        let shown = test_provider().display_name();
        assert_eq!(shown, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let shown = test_provider().with_model("anthropic.claude-opus-4-20250514-v1:0").display_name();
        assert_eq!(shown, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let BedrockProvider { max_tokens, .. } = test_provider().with_max_tokens(8192);
        assert_eq!(max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let BedrockProvider { temperature, .. } = test_provider().with_temperature(0.7);
        assert_eq!(temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = test_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert_eq!(temperature, None);
    }

    /// With `ProviderAuthMode::None` the request must reach the custom endpoint without any
    /// signing headers.
    #[tokio::test]
    async fn auth_none_sends_unsigned_request_to_custom_endpoint() {
        let FakeBedrockEndpoint { url, request: captured } = FakeBedrockEndpoint::start().await;
        let connection = ProviderConnectionConfig { base_url: Some(url), auth_mode: ProviderAuthMode::None };
        let provider = BedrockProvider::new_with_connection(connection).await;

        let user_message =
            ChatMessage::User { content: vec![ContentBlock::text("hello")], timestamp: IsoString::now() };
        let context = Context::new(vec![user_message], vec![]);

        let outcome = provider.send_converse_stream(&context).await;
        let seen = captured.await.expect("fake Bedrock endpoint received no request");

        // The fake endpoint always answers 403, so the call itself is expected to fail.
        assert!(outcome.is_err());
        assert_eq!(seen.method, Method::POST);
        assert!(seen.path.starts_with("/model/"), "{}", seen.path);
        assert!(!seen.headers.contains_key("authorization"), "request was signed: {:?}", seen.headers);
        assert!(
            !seen.headers.contains_key("x-amz-security-token"),
            "request included session token: {:?}",
            seen.headers
        );
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));
        let base = BedrockProvider::from_config(Some(credentials), Some("us-west-2"));

        let BedrockProvider { model, max_tokens, temperature, .. } =
            base.with_model("anthropic.claude-opus-4-20250514-v1:0").with_max_tokens(4096).with_temperature(0.5);

        assert_eq!(model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(max_tokens, 4096);
        assert_eq!(temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let BedrockProvider { model, .. } = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(model, DEFAULT_MODEL);
    }

    #[test]
    fn catalog_foundation_id_resolves_context_window() {
        let provider = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        let parsed = provider.model().unwrap();

        assert!(provider.context_window().is_some());
        assert_eq!(parsed.to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn cross_region_profile_id_in_catalog_resolves() {
        let window = test_provider().with_model("us.anthropic.claude-opus-4-6-v1").context_window();
        assert!(window.is_some());
    }

    #[test]
    fn unknown_cross_region_profile_id_falls_through_to_profile() {
        let id = "us.anthropic.claude-future-model-v99:0";
        let provider = test_provider().with_model(id);

        assert!(provider.context_window().is_none());
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{id}"));
        assert_eq!(provider.display_name(), format!("Bedrock ({id})"));
    }

    #[test]
    fn inference_profile_arn_is_passed_through_as_profile() {
        let arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let provider = test_provider().with_model(arn);

        assert!(provider.context_window().is_none());
        assert_eq!(provider.model, arn);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{arn}"));
    }

    #[test]
    fn application_inference_profile_arn_is_passed_through_as_profile() {
        let arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let provider = test_provider().with_model(arn);

        assert!(provider.context_window().is_none());
        assert_eq!(provider.model, arn);
        assert_eq!(provider.display_name(), format!("Bedrock ({arn})"));
    }

    /// Handle to a local HTTP server that records the first request it receives.
    struct FakeBedrockEndpoint {
        /// Base URL (`http://127.0.0.1:<port>`) of the server.
        url: String,
        /// Resolves with the first captured request.
        request: oneshot::Receiver<CapturedRequest>,
    }

    /// The parts of an incoming request the tests assert on.
    struct CapturedRequest {
        method: Method,
        path: String,
        headers: HeaderMap,
    }

    /// Shared handler state; each sender is taken out on first use so it fires at most once.
    #[derive(Clone)]
    struct FakeBedrockState {
        request_tx: Arc<Mutex<Option<oneshot::Sender<CapturedRequest>>>>,
        shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    }

    impl FakeBedrockEndpoint {
        /// Binds an ephemeral local port and serves until the first request has been handled.
        async fn start() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind fake Bedrock endpoint");
            let address = listener.local_addr().expect("fake Bedrock endpoint address");
            let url = format!("http://{}", address);

            let (request_tx, request) = oneshot::channel();
            let (shutdown_tx, shutdown_rx) = oneshot::channel();
            let state = FakeBedrockState {
                request_tx: Arc::new(Mutex::new(Some(request_tx))),
                shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
            };

            // Every path and method is routed to the capturing handler.
            let router = Router::new().fallback(any(capture_bedrock_request)).with_state(state);
            tokio::spawn(async move {
                let stop_signal = async move {
                    let _ = shutdown_rx.await;
                };
                axum::serve(listener, router)
                    .with_graceful_shutdown(stop_signal)
                    .await
                    .expect("serve fake Bedrock endpoint");
            });

            Self { url, request }
        }
    }

    /// Records the request, signals shutdown and answers `403 Forbidden` with an empty JSON body.
    async fn capture_bedrock_request(
        State(state): State<FakeBedrockState>,
        request: Request<Body>,
    ) -> impl IntoResponse {
        let (parts, _body) = request.into_parts();
        let captured =
            CapturedRequest { method: parts.method, path: parts.uri.path().to_string(), headers: parts.headers };

        let pending_capture = state.request_tx.lock().await.take();
        if let Some(sender) = pending_capture {
            // The receiver may already be gone; that is not an error for the handler.
            let _ = sender.send(captured);
        }

        let pending_shutdown = state.shutdown_tx.lock().await.take();
        if let Some(sender) = pending_shutdown {
            let _ = sender.send(());
        }

        (StatusCode::FORBIDDEN, "{}")
    }
}