Skip to main content

llm/providers/bedrock/
provider.rs

1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, ProviderAuthMode, ProviderConnectionConfig, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model ID a freshly constructed provider uses until `with_model` overrides it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Token limit a freshly constructed provider sends as `max_tokens` (16 Ki tokens).
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Region the client builders fall back to when no region is supplied.
const DEFAULT_REGION: &str = "us-east-1";
19
/// AWS credentials for explicit authentication with Bedrock.
///
/// `Debug` is implemented by hand instead of derived so that the secret key and
/// session token can never end up in logs or panic messages.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID identifying the credentials.
    pub access_key_id: String,
    /// Secret access key paired with `access_key_id`.
    pub secret_access_key: String,
    /// Optional session token; `None` when the credentials carry no token.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    /// Formats the credentials with every secret value replaced by a placeholder.
    ///
    /// The access key ID is printed as-is (it is an identifier, not a secret);
    /// the session token only reveals whether it is present.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30    client: Client,
31    model: String,
32    max_tokens: i32,
33    temperature: Option<f32>,
34}
35
36impl BedrockProvider {
37    /// Create a provider using the default AWS credential chain
38    /// (env vars, `~/.aws/credentials`, IAM roles, SSO).
39    pub async fn new() -> Self {
40        Self::new_with_connection(ProviderConnectionConfig::default()).await
41    }
42
43    pub async fn new_with_connection(connection: ProviderConnectionConfig) -> Self {
44        let client = if connection.auth_mode == ProviderAuthMode::None {
45            build_no_auth_client(connection.base_url.as_deref(), region_from_env().as_deref())
46        } else {
47            let mut loader = aws_config::defaults(BehaviorVersion::latest());
48            if let Some(url) = &connection.base_url {
49                loader = loader.endpoint_url(url.clone());
50            }
51            let config = loader.load().await;
52            Client::new(&config)
53        };
54
55        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
56    }
57
58    /// Create a provider from explicit configuration without async credential discovery.
59    pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
60        let client = build_client(credentials, region);
61
62        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
63    }
64
65    pub fn with_model(mut self, model: &str) -> Self {
66        self.model = model.to_string();
67        self
68    }
69
70    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
71        self.max_tokens = max_tokens;
72        self
73    }
74
75    pub fn with_temperature(mut self, temperature: f32) -> Self {
76        self.temperature = Some(temperature);
77        self
78    }
79
80    async fn send_converse_stream(
81        &self,
82        context: &Context,
83    ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
84        let (system_blocks, messages) = map_messages(context.messages())?;
85        let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
86
87        if let Some(temp) = self.temperature {
88            inference_config = inference_config.temperature(temp);
89        }
90
91        let inference_config = inference_config.build();
92
93        let mut request = self
94            .client
95            .converse_stream()
96            .model_id(&self.model)
97            .set_messages(Some(messages))
98            .inference_config(inference_config);
99
100        if !system_blocks.is_empty() {
101            request = request.set_system(Some(system_blocks));
102        }
103
104        if !context.tools().is_empty() {
105            let tool_config = map_tools(context.tools())?;
106            request = request.tool_config(tool_config);
107        }
108
109        info!(model = %self.model, "Sending Bedrock converse_stream request");
110
111        let response = request.send().await.map_err(|e| {
112            error!(model = %self.model, error = ?e, "Bedrock API error");
113            LlmError::from(e)
114        })?;
115
116        Ok(response.stream)
117    }
118}
119
120impl ProviderFactory for BedrockProvider {
121    async fn from_env() -> Result<Self> {
122        Ok(Self::new().await)
123    }
124
125    async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
126        Ok(Self::new_with_connection(connection).await)
127    }
128
129    fn with_model(self, model: &str) -> Self {
130        self.with_model(model)
131    }
132}
133
134impl StreamingModelProvider for BedrockProvider {
135    fn model(&self) -> Option<crate::LlmModel> {
136        format!("bedrock:{}", self.model).parse().ok()
137    }
138
139    fn context_window(&self) -> Option<u32> {
140        get_context_window("bedrock", &self.model)
141    }
142
143    fn stream_response(&self, context: &Context) -> LlmResponseStream {
144        let provider = self.clone();
145        let context = context.clone();
146
147        Box::pin(async_stream::stream! {
148            match provider.send_converse_stream(&context).await {
149                Ok(receiver) => {
150                    let mut stream = Box::pin(process_bedrock_stream(receiver));
151                    while let Some(result) = stream.next().await {
152                        yield result;
153                    }
154                }
155                Err(e) => {
156                    yield Err(e);
157                }
158            }
159        })
160    }
161
162    fn display_name(&self) -> String {
163        format!("Bedrock ({})", self.model)
164    }
165}
166
167impl From<SdkError<ConverseStreamError>> for LlmError {
168    fn from(e: SdkError<ConverseStreamError>) -> Self {
169        let message = format!("Bedrock API error: {e}");
170        match e {
171            SdkError::TimeoutError(_) => LlmError::Timeout(message),
172            SdkError::DispatchFailure(_) => LlmError::Network(message),
173            SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
174            SdkError::ServiceError(svc) => {
175                let inner = svc.err();
176                if inner.is_throttling_exception() {
177                    LlmError::RateLimited(message)
178                } else if inner.is_service_unavailable_exception()
179                    || inner.is_internal_server_exception()
180                    || inner.is_model_stream_error_exception()
181                {
182                    LlmError::ServerError { status: None, message }
183                } else {
184                    LlmError::ApiError(message)
185                }
186            }
187            _ => LlmError::ApiError(message),
188        }
189    }
190}
191
192fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
193    let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
194
195    if let Some(creds) = credentials {
196        config = config.credentials_provider(Credentials::new(
197            creds.access_key_id,
198            creds.secret_access_key,
199            creds.session_token,
200            None,
201            "aether-bedrock-provider",
202        ));
203    }
204
205    config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
206
207    Client::from_conf(config.build())
208}
209
210fn build_no_auth_client(base_url: Option<&str>, region: Option<&str>) -> Client {
211    let mut config = Config::builder()
212        .behavior_version(BehaviorVersion::latest())
213        .allow_no_auth()
214        .region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
215
216    if let Some(url) = base_url {
217        config = config.endpoint_url(url);
218    }
219
220    Client::from_conf(config.build())
221}
222
/// Read the AWS region from the environment.
///
/// Checks `AWS_REGION` first and `AWS_DEFAULT_REGION` second, returning the first
/// value that is set, valid Unicode and non-empty; `None` when neither qualifies.
fn region_from_env() -> Option<String> {
    const REGION_VARS: [&str; 2] = ["AWS_REGION", "AWS_DEFAULT_REGION"];

    REGION_VARS.iter().filter_map(|var| std::env::var(var).ok()).find(|region| !region.is_empty())
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::IsoString;
    use crate::{ChatMessage, ContentBlock};
    use axum::Router;
    use axum::body::Body;
    use axum::extract::State;
    use axum::http::{HeaderMap, Method, Request, StatusCode};
    use axum::response::IntoResponse;
    use axum::routing::any;
    use std::sync::Arc;
    use tokio::net::TcpListener;
    use tokio::sync::{Mutex, oneshot};

    /// Provider built from explicit configuration with no credentials and no region.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Example credentials (AWS documentation sample keys) with the given session token.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(str::to_string),
        }
    }

    #[test]
    fn test_display_name() {
        let shown = test_provider().display_name();
        assert_eq!(shown, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let shown = test_provider().with_model("anthropic.claude-opus-4-20250514-v1:0").display_name();
        assert_eq!(shown, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let configured = test_provider().with_max_tokens(8192);
        assert_eq!(configured.max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let configured = test_provider().with_temperature(0.7);
        assert_eq!(configured.temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = test_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert!(temperature.is_none());
    }

    /// With `ProviderAuthMode::None` the request must reach the custom endpoint
    /// without an `authorization` header or a session token header.
    #[tokio::test]
    async fn auth_none_sends_unsigned_request_to_custom_endpoint() {
        let FakeBedrockEndpoint { url, request: captured } = FakeBedrockEndpoint::start().await;
        let connection = ProviderConnectionConfig { base_url: Some(url), auth_mode: ProviderAuthMode::None };
        let provider = BedrockProvider::new_with_connection(connection).await;

        let user_message =
            ChatMessage::User { content: vec![ContentBlock::text("hello")], timestamp: IsoString::now() };
        let context = Context::new(vec![user_message], vec![]);

        let result = provider.send_converse_stream(&context).await;
        let request = captured.await.expect("fake Bedrock endpoint received no request");

        // The fake endpoint always answers 403, so the call itself has to fail.
        assert!(result.is_err());
        assert_eq!(request.method, Method::POST);
        assert!(request.path.starts_with("/model/"), "{}", request.path);
        assert!(!request.headers.contains_key("authorization"), "request was signed: {:?}", request.headers);
        assert!(
            !request.headers.contains_key("x-amz-security-token"),
            "request included session token: {:?}",
            request.headers
        );
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));

        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model("anthropic.claude-opus-4-20250514-v1:0")
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn catalog_foundation_id_resolves_context_window() {
        let provider = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert!(provider.context_window().is_some());

        let parsed = provider.model().unwrap();
        assert_eq!(parsed.to_string(), "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn cross_region_profile_id_in_catalog_resolves() {
        let provider = test_provider().with_model("us.anthropic.claude-opus-4-6-v1");
        assert!(provider.context_window().is_some());
    }

    #[test]
    fn unknown_cross_region_profile_id_falls_through_to_profile() {
        let id = "us.anthropic.claude-future-model-v99:0";
        let provider = test_provider().with_model(id);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{id}"));
        assert_eq!(provider.display_name(), format!("Bedrock ({id})"));
    }

    #[test]
    fn inference_profile_arn_is_passed_through_as_profile() {
        let arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let provider = test_provider().with_model(arn);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model, arn);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{arn}"));
    }

    #[test]
    fn application_inference_profile_arn_is_passed_through_as_profile() {
        let arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let provider = test_provider().with_model(arn);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model, arn);
        assert_eq!(provider.display_name(), format!("Bedrock ({arn})"));
    }

    /// Handle to a local HTTP server standing in for Bedrock.
    struct FakeBedrockEndpoint {
        /// Base URL (`http://<addr>`) the server listens on.
        url: String,
        /// Resolves with the first request the server receives.
        request: oneshot::Receiver<CapturedRequest>,
    }

    /// The parts of an incoming request the tests assert on.
    struct CapturedRequest {
        method: Method,
        path: String,
        headers: HeaderMap,
    }

    /// Shared handler state: one-shot senders that are taken on first use.
    #[derive(Clone)]
    struct FakeBedrockState {
        request_tx: Arc<Mutex<Option<oneshot::Sender<CapturedRequest>>>>,
        shutdown_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
    }

    impl FakeBedrockEndpoint {
        /// Bind an ephemeral localhost port and serve until the first request arrives.
        async fn start() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind fake Bedrock endpoint");
            let addr = listener.local_addr().expect("fake Bedrock endpoint address");

            let (request_tx, request) = oneshot::channel();
            let (shutdown_tx, shutdown) = oneshot::channel();
            let state = FakeBedrockState {
                request_tx: Arc::new(Mutex::new(Some(request_tx))),
                shutdown_tx: Arc::new(Mutex::new(Some(shutdown_tx))),
            };
            let app = Router::new().fallback(any(capture_bedrock_request)).with_state(state);

            tokio::spawn(async move {
                let served = axum::serve(listener, app)
                    .with_graceful_shutdown(async {
                        let _ = shutdown.await;
                    })
                    .await;
                served.expect("serve fake Bedrock endpoint");
            });

            Self { url: format!("http://{addr}"), request }
        }
    }

    /// Record the first request, signal the server to shut down, and answer 403.
    async fn capture_bedrock_request(
        State(state): State<FakeBedrockState>,
        request: Request<Body>,
    ) -> impl IntoResponse {
        let (head, _) = request.into_parts();

        let pending_capture = state.request_tx.lock().await.take();
        if let Some(sender) = pending_capture {
            let captured =
                CapturedRequest { method: head.method, path: head.uri.path().to_string(), headers: head.headers };
            let _ = sender.send(captured);
        }

        let pending_shutdown = state.shutdown_tx.lock().await.take();
        if let Some(sender) = pending_shutdown {
            let _ = sender.send(());
        }

        (StatusCode::FORBIDDEN, "{}")
    }
}
443}