Skip to main content

llm/providers/bedrock/provider.rs

1use super::mappers::{default_cache_point, map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, ProviderAuthMode, ProviderConnectionConfig, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model id used until `with_model` overrides it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Default `max_tokens` sent in each request's inference configuration (16 KiB of tokens).
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Region used when neither the caller nor the environment supplies one.
const DEFAULT_REGION: &str = "us-east-1";
19
/// AWS credentials for explicit authentication with Bedrock.
///
/// `Debug` is implemented by hand rather than derived so that formatting a value
/// (in logs, panic messages, `dbg!`) never prints the secret key or session token.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID; shown as-is in `Debug` output.
    pub access_key_id: String,
    /// Secret access key; rendered as `<redacted>` in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token for temporary credentials; its value is rendered as
    /// `<redacted>` in `Debug` output (presence vs. absence is still visible).
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &self.session_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30    client: Client,
31    model: String,
32    inference_profile_arn: Option<String>,
33    max_tokens: i32,
34    temperature: Option<f32>,
35}
36
37impl BedrockProvider {
38    /// Create a provider using the default AWS credential chain
39    /// (env vars, `~/.aws/credentials`, IAM roles, SSO).
40    pub async fn new() -> Self {
41        Self::new_with_connection(ProviderConnectionConfig::default()).await
42    }
43
44    pub async fn new_with_connection(connection: ProviderConnectionConfig) -> Self {
45        let client = if connection.auth_mode == ProviderAuthMode::None {
46            build_no_auth_client(connection.base_url.as_deref(), region_from_env().as_deref())
47        } else {
48            let mut loader = aws_config::defaults(BehaviorVersion::latest());
49            if let Some(url) = &connection.base_url {
50                loader = loader.endpoint_url(url.clone());
51            }
52            let config = loader.load().await;
53            Client::new(&config)
54        };
55
56        Self {
57            client,
58            model: DEFAULT_MODEL.to_string(),
59            inference_profile_arn: connection.inference_profile_arn,
60            max_tokens: DEFAULT_MAX_TOKENS,
61            temperature: None,
62        }
63    }
64
65    /// Create a provider from explicit configuration without async credential discovery.
66    pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
67        let client = build_client(credentials, region);
68
69        Self {
70            client,
71            model: DEFAULT_MODEL.to_string(),
72            inference_profile_arn: None,
73            max_tokens: DEFAULT_MAX_TOKENS,
74            temperature: None,
75        }
76    }
77
78    pub fn with_model(mut self, model: &str) -> Self {
79        self.model = model.to_string();
80        self
81    }
82
83    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
84        self.max_tokens = max_tokens;
85        self
86    }
87
88    pub fn with_temperature(mut self, temperature: f32) -> Self {
89        self.temperature = Some(temperature);
90        self
91    }
92
93    pub fn with_inference_profile_arn(mut self, arn: impl Into<String>) -> Self {
94        self.inference_profile_arn = Some(arn.into());
95        self
96    }
97
98    fn request_model_id(&self) -> &str {
99        self.inference_profile_arn.as_deref().unwrap_or(&self.model)
100    }
101
102    async fn send_converse_stream(
103        &self,
104        context: &Context,
105    ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
106        let cache_point =
107            self.model().is_some_and(|m| m.supports_prompt_caching()).then(default_cache_point).transpose()?;
108        let (system_blocks, messages) = map_messages(context.messages(), cache_point.as_ref())?;
109        let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
110
111        if let Some(temp) = self.temperature {
112            inference_config = inference_config.temperature(temp);
113        }
114
115        let inference_config = inference_config.build();
116
117        let mut request = self
118            .client
119            .converse_stream()
120            .model_id(self.request_model_id())
121            .set_messages(Some(messages))
122            .inference_config(inference_config);
123
124        if !system_blocks.is_empty() {
125            request = request.set_system(Some(system_blocks));
126        }
127
128        if !context.tools().is_empty() {
129            let tool_config = map_tools(context.tools(), cache_point.as_ref())?;
130            request = request.tool_config(tool_config);
131        }
132
133        if let Some(arn) = self.inference_profile_arn.as_deref() {
134            info!(model = %self.model, inference_profile_arn = %arn, "Sending Bedrock converse_stream request");
135        } else {
136            info!(model = %self.model, "Sending Bedrock converse_stream request");
137        }
138
139        let response = request.send().await.map_err(|e| {
140            error!(model = %self.model, error = ?e, "Bedrock API error");
141            LlmError::from(e)
142        })?;
143
144        Ok(response.stream)
145    }
146}
147
148impl ProviderFactory for BedrockProvider {
149    async fn from_env() -> Result<Self> {
150        Ok(Self::new().await)
151    }
152
153    async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
154        Ok(Self::new_with_connection(connection).await)
155    }
156
157    fn with_model(self, model: &str) -> Self {
158        self.with_model(model)
159    }
160}
161
162impl StreamingModelProvider for BedrockProvider {
163    fn model(&self) -> Option<crate::LlmModel> {
164        format!("bedrock:{}", self.model).parse().ok()
165    }
166
167    fn context_window(&self) -> Option<u32> {
168        get_context_window("bedrock", &self.model)
169    }
170
171    fn stream_response(&self, context: &Context) -> LlmResponseStream {
172        let provider = self.clone();
173        let context = context.clone();
174
175        Box::pin(async_stream::stream! {
176            match provider.send_converse_stream(&context).await {
177                Ok(receiver) => {
178                    let mut stream = Box::pin(process_bedrock_stream(receiver));
179                    while let Some(result) = stream.next().await {
180                        yield result;
181                    }
182                }
183                Err(e) => {
184                    yield Err(e);
185                }
186            }
187        })
188    }
189
190    fn display_name(&self) -> String {
191        format!("Bedrock ({})", self.model)
192    }
193}
194
195impl From<SdkError<ConverseStreamError>> for LlmError {
196    fn from(e: SdkError<ConverseStreamError>) -> Self {
197        let message = format!("Bedrock API error: {e}");
198        match e {
199            SdkError::TimeoutError(_) => LlmError::Timeout(message),
200            SdkError::DispatchFailure(_) => LlmError::Network(message),
201            SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
202            SdkError::ServiceError(svc) => {
203                let inner = svc.err();
204                if inner.is_throttling_exception() {
205                    LlmError::RateLimited(message)
206                } else if inner.is_service_unavailable_exception()
207                    || inner.is_internal_server_exception()
208                    || inner.is_model_stream_error_exception()
209                {
210                    LlmError::ServerError { status: None, message }
211                } else {
212                    LlmError::ApiError(message)
213                }
214            }
215            _ => LlmError::ApiError(message),
216        }
217    }
218}
219
220fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
221    let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
222
223    if let Some(creds) = credentials {
224        config = config.credentials_provider(Credentials::new(
225            creds.access_key_id,
226            creds.secret_access_key,
227            creds.session_token,
228            None,
229            "aether-bedrock-provider",
230        ));
231    }
232
233    config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
234
235    Client::from_conf(config.build())
236}
237
238fn build_no_auth_client(base_url: Option<&str>, region: Option<&str>) -> Client {
239    let mut config = Config::builder()
240        .behavior_version(BehaviorVersion::latest())
241        .allow_no_auth()
242        .region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
243
244    if let Some(url) = base_url {
245        config = config.endpoint_url(url);
246    }
247
248    Client::from_conf(config.build())
249}
250
251fn region_from_env() -> Option<String> {
252    ["AWS_REGION", "AWS_DEFAULT_REGION"].into_iter().find_map(|name| match std::env::var(name) {
253        Ok(value) if !value.is_empty() => Some(value),
254        _ => None,
255    })
256}
257
258#[cfg(test)]
259mod tests {
260    use super::*;
261    use crate::types::IsoString;
262    use crate::{ChatMessage, ContentBlock};
263    use axum::Router;
264    use axum::body::Body;
265    use axum::extract::State;
266    use axum::http::{HeaderMap, Method, Request, StatusCode};
267    use axum::response::IntoResponse;
268    use axum::routing::any;
269    use std::sync::Arc;
270    use tokio::net::TcpListener;
271    use tokio::sync::{Mutex, oneshot};
272
273    fn inference_profile_arn(model: &str) -> String {
274        format!("arn:aws:bedrock:us-west-2:000000000000:inference-profile/{model}")
275    }
276
277    fn application_inference_profile_arn() -> &'static str {
278        "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000"
279    }
280
281    fn test_provider() -> BedrockProvider {
282        BedrockProvider::from_config(None, None)
283    }
284
285    #[test]
286    fn test_display_name() {
287        assert_eq!(test_provider().display_name(), "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
288    }
289
290    #[test]
291    fn test_with_model() {
292        let provider = test_provider().with_model("anthropic.claude-opus-4-20250514-v1:0");
293        assert_eq!(provider.display_name(), "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
294    }
295
296    #[test]
297    fn test_with_max_tokens() {
298        let provider = test_provider().with_max_tokens(8192);
299        assert_eq!(provider.max_tokens, 8192);
300    }
301
302    #[test]
303    fn test_with_temperature() {
304        let provider = test_provider().with_temperature(0.7);
305        assert_eq!(provider.temperature, Some(0.7));
306    }
307
308    #[test]
309    fn test_default_values() {
310        let provider = test_provider();
311        assert_eq!(provider.model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
312        assert_eq!(provider.max_tokens, 16_384);
313        assert!(provider.temperature.is_none());
314    }
315
316    #[tokio::test]
317    async fn auth_none_sends_unsigned_request_to_custom_endpoint() {
318        let endpoint = FakeBedrockEndpoint::start().await;
319        let provider = BedrockProvider::new_with_connection(ProviderConnectionConfig {
320            base_url: Some(endpoint.url.clone()),
321            auth_mode: ProviderAuthMode::None,
322            ..Default::default()
323        })
324        .await;
325
326        let context = Context::new(
327            vec![ChatMessage::User { content: vec![ContentBlock::text("hello")], timestamp: IsoString::now() }],
328            vec![],
329        );
330
331        let result = provider.send_converse_stream(&context).await;
332        let request = endpoint.request.await.expect("fake Bedrock endpoint received no request");
333
334        assert!(result.is_err());
335        assert_eq!(request.method, Method::POST);
336        assert!(request.path.starts_with("/model/"), "{}", request.path);
337        assert!(!request.headers.contains_key("authorization"), "request was signed: {:?}", request.headers);
338        assert!(
339            !request.headers.contains_key("x-amz-security-token"),
340            "request included session token: {:?}",
341            request.headers
342        );
343    }
344
345    #[test]
346    fn test_from_config_with_credentials() {
347        let credentials = AwsCredentials {
348            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
349            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
350            session_token: None,
351        };
352
353        let provider = BedrockProvider::from_config(Some(credentials), None);
354        assert_eq!(provider.model, DEFAULT_MODEL);
355    }
356
357    #[test]
358    fn test_from_config_with_credentials_and_region() {
359        let credentials = AwsCredentials {
360            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
361            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
362            session_token: Some("FwoGZXIvYXdzEBYaD...".to_string()),
363        };
364
365        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
366            .with_model("anthropic.claude-opus-4-20250514-v1:0")
367            .with_max_tokens(4096)
368            .with_temperature(0.5);
369
370        assert_eq!(provider.model, "anthropic.claude-opus-4-20250514-v1:0");
371        assert_eq!(provider.max_tokens, 4096);
372        assert_eq!(provider.temperature, Some(0.5));
373    }
374
375    #[test]
376    fn test_from_config_with_region_only() {
377        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
378        assert_eq!(provider.model, DEFAULT_MODEL);
379    }
380
381    #[test]
382    fn catalog_foundation_id_resolves_context_window() {
383        let provider = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
384        assert!(provider.context_window().is_some());
385        assert_eq!(provider.model().unwrap().to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
386    }
387
388    #[test]
389    fn cross_region_profile_id_in_catalog_resolves() {
390        let provider = test_provider().with_model("us.anthropic.claude-opus-4-6-v1");
391        assert!(provider.context_window().is_some());
392    }
393
394    #[test]
395    fn unknown_cross_region_profile_id_falls_through_to_profile() {
396        let id = "us.anthropic.claude-future-model-v99:0";
397        let provider = test_provider().with_model(id);
398        assert_eq!(provider.context_window(), None);
399        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{id}"));
400        assert_eq!(provider.display_name(), format!("Bedrock ({id})"));
401    }
402
403    #[tokio::test]
404    async fn separate_inference_profile_arn_is_used_as_request_model_id() {
405        let endpoint = FakeBedrockEndpoint::start().await;
406        let provider = BedrockProvider::new_with_connection(ProviderConnectionConfig {
407            base_url: Some(endpoint.url.clone()),
408            auth_mode: ProviderAuthMode::None,
409            inference_profile_arn: Some(application_inference_profile_arn().to_string()),
410        })
411        .await
412        .with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
413        let context = Context::new(
414            vec![ChatMessage::User { content: vec![ContentBlock::text("hello")], timestamp: IsoString::now() }],
415            vec![],
416        );
417
418        let result = provider.send_converse_stream(&context).await;
419        let request = endpoint.request.await.expect("fake Bedrock endpoint received no request");
420
421        assert!(result.is_err());
422        assert!(
423            request.path.contains("arn%3Aaws%3Abedrock%3Aus-west-2%3A000000000000%3Aapplication-inference-profile"),
424            "{}",
425            request.path
426        );
427        assert_eq!(provider.context_window(), Some(200_000));
428        assert_eq!(provider.model().unwrap().to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
429    }
430
431    #[test]
432    fn with_inference_profile_arn_keeps_canonical_model_identity() {
433        let arn = inference_profile_arn("us.anthropic.claude-sonnet-4-5-20250929-v1:0");
434        let provider =
435            test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0").with_inference_profile_arn(&arn);
436
437        assert_eq!(provider.request_model_id(), arn);
438        assert_eq!(provider.context_window(), Some(200_000));
439        assert_eq!(provider.model().unwrap().to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
440    }
441
442    #[test]
443    fn prompt_caching_support_comes_from_canonical_model() {
444        let cached = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
445        assert!(cached.model().unwrap().supports_prompt_caching());
446
447        let unknown_profile = test_provider().with_model("us.anthropic.claude-future-model-v99:0");
448        assert!(!unknown_profile.model().unwrap().supports_prompt_caching());
449    }
450
451    struct FakeBedrockEndpoint {
452        url: String,
453        request: oneshot::Receiver<CapturedRequest>,
454    }
455
456    struct CapturedRequest {
457        method: Method,
458        path: String,
459        headers: HeaderMap,
460    }
461
462    #[derive(Clone)]
463    struct FakeBedrockState {
464        request_tx: Arc<Mutex<Option<oneshot::Sender<CapturedRequest>>>>,
465        shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
466    }
467
468    impl FakeBedrockEndpoint {
469        async fn start() -> Self {
470            let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind fake Bedrock endpoint");
471            let url = format!("http://{}", listener.local_addr().expect("fake Bedrock endpoint address"));
472            let (request_tx, request) = oneshot::channel();
473            let (shutdown_tx, shutdown) = oneshot::channel();
474            let state = FakeBedrockState {
475                request_tx: Arc::new(Mutex::new(Some(request_tx))),
476                shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
477            };
478
479            let app = Router::new().fallback(any(capture_bedrock_request)).with_state(state);
480            tokio::spawn(async move {
481                axum::serve(listener, app)
482                    .with_graceful_shutdown(async {
483                        let _ = shutdown.await;
484                    })
485                    .await
486                    .expect("serve fake Bedrock endpoint");
487            });
488
489            Self { url, request }
490        }
491    }
492
493    async fn capture_bedrock_request(
494        State(state): State<FakeBedrockState>,
495        request: Request<Body>,
496    ) -> impl IntoResponse {
497        let (parts, _) = request.into_parts();
498        if let Some(tx) = state.request_tx.lock().await.take() {
499            let _ = tx.send(CapturedRequest {
500                method: parts.method,
501                path: parts.uri.path().to_string(),
502                headers: parts.headers,
503            });
504        }
505        if let Some(tx) = state.shutdown_tx.lock().await.take() {
506            let _ = tx.send(());
507        }
508        (StatusCode::FORBIDDEN, "{}")
509    }
510}