Skip to main content

llm/providers/bedrock/
provider.rs

1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, ProviderAuthMode, ProviderConnectionConfig, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Model ID a freshly constructed provider uses until `with_model` overrides it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// `max_tokens` value a freshly constructed provider uses until `with_max_tokens` overrides it.
const DEFAULT_MAX_TOKENS: i32 = 16_384;
/// Region `build_client` falls back to when the caller passes no region.
const DEFAULT_REGION: &str = "us-east-1";
19
/// AWS credentials for explicit authentication with Bedrock.
///
/// `Debug` is implemented by hand rather than derived so that the secret
/// access key and session token are never written to logs or panic messages.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID identifying the credential pair.
    pub access_key_id: String,
    /// Secret access key paired with `access_key_id`.
    pub secret_access_key: String,
    /// Optional session token supplied alongside the key pair.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    /// Formats the credentials with the secret key and session token redacted.
    ///
    /// The access key ID is shown in full; whether a session token is present
    /// is still visible (`Some`/`None`), but its value is not.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &self.session_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
27
/// Model provider that talks to Bedrock through the runtime `converse_stream` call.
#[derive(Clone)]
pub struct BedrockProvider {
    /// Bedrock runtime client used to send requests.
    client: Client,
    /// Model identifier sent as the request's `model_id`.
    model: String,
    /// `max_tokens` value placed in the request's inference configuration.
    max_tokens: i32,
    /// Sampling temperature; only added to the inference configuration when `Some`.
    temperature: Option<f32>,
}
35
36impl BedrockProvider {
37    /// Create a provider using the default AWS credential chain
38    /// (env vars, `~/.aws/credentials`, IAM roles, SSO).
39    pub async fn new() -> Self {
40        Self::new_with_connection(ProviderConnectionConfig::default()).await
41    }
42
43    pub async fn new_with_connection(connection: ProviderConnectionConfig) -> Self {
44        if connection == ProviderConnectionConfig::default() {
45            let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
46            let client = Client::new(&config);
47            return Self {
48                client,
49                model: DEFAULT_MODEL.to_string(),
50                max_tokens: DEFAULT_MAX_TOKENS,
51                temperature: None,
52            };
53        }
54
55        let mut loader = aws_config::defaults(BehaviorVersion::latest());
56        if connection.auth_mode == ProviderAuthMode::None {
57            loader = loader.no_credentials();
58        }
59
60        if let Some(url) = &connection.base_url {
61            loader = loader.endpoint_url(url.clone());
62        }
63
64        let sdk_config = loader.load().await;
65        let mut builder = aws_sdk_bedrockruntime::config::Builder::from(&sdk_config);
66        if connection.auth_mode == ProviderAuthMode::None {
67            builder = builder.allow_no_auth();
68        }
69
70        let client = Client::from_conf(builder.build());
71
72        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
73    }
74
75    /// Create a provider from explicit configuration without async credential discovery.
76    pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
77        let client = build_client(credentials, region);
78
79        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
80    }
81
82    pub fn with_model(mut self, model: &str) -> Self {
83        self.model = model.to_string();
84        self
85    }
86
87    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
88        self.max_tokens = max_tokens;
89        self
90    }
91
92    pub fn with_temperature(mut self, temperature: f32) -> Self {
93        self.temperature = Some(temperature);
94        self
95    }
96
97    async fn send_converse_stream(
98        &self,
99        context: &Context,
100    ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
101        let (system_blocks, messages) = map_messages(context.messages())?;
102        let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
103
104        if let Some(temp) = self.temperature {
105            inference_config = inference_config.temperature(temp);
106        }
107
108        let inference_config = inference_config.build();
109
110        let mut request = self
111            .client
112            .converse_stream()
113            .model_id(&self.model)
114            .set_messages(Some(messages))
115            .inference_config(inference_config);
116
117        if !system_blocks.is_empty() {
118            request = request.set_system(Some(system_blocks));
119        }
120
121        if !context.tools().is_empty() {
122            let tool_config = map_tools(context.tools())?;
123            request = request.tool_config(tool_config);
124        }
125
126        info!(model = %self.model, "Sending Bedrock converse_stream request");
127
128        let response = request.send().await.map_err(|e| {
129            error!(model = %self.model, error = ?e, "Bedrock API error");
130            LlmError::from(e)
131        })?;
132
133        Ok(response.stream)
134    }
135}
136
137impl ProviderFactory for BedrockProvider {
138    async fn from_env() -> Result<Self> {
139        Ok(Self::new().await)
140    }
141
142    async fn from_env_with_connection(connection: ProviderConnectionConfig) -> Result<Self> {
143        Ok(Self::new_with_connection(connection).await)
144    }
145
146    fn with_model(self, model: &str) -> Self {
147        self.with_model(model)
148    }
149}
150
151impl StreamingModelProvider for BedrockProvider {
152    fn model(&self) -> Option<crate::LlmModel> {
153        format!("bedrock:{}", self.model).parse().ok()
154    }
155
156    fn context_window(&self) -> Option<u32> {
157        get_context_window("bedrock", &self.model)
158    }
159
160    fn stream_response(&self, context: &Context) -> LlmResponseStream {
161        let provider = self.clone();
162        let context = context.clone();
163
164        Box::pin(async_stream::stream! {
165            match provider.send_converse_stream(&context).await {
166                Ok(receiver) => {
167                    let mut stream = Box::pin(process_bedrock_stream(receiver));
168                    while let Some(result) = stream.next().await {
169                        yield result;
170                    }
171                }
172                Err(e) => {
173                    yield Err(e);
174                }
175            }
176        })
177    }
178
179    fn display_name(&self) -> String {
180        format!("Bedrock ({})", self.model)
181    }
182}
183
184impl From<SdkError<ConverseStreamError>> for LlmError {
185    fn from(e: SdkError<ConverseStreamError>) -> Self {
186        let message = format!("Bedrock API error: {e}");
187        match e {
188            SdkError::TimeoutError(_) => LlmError::Timeout(message),
189            SdkError::DispatchFailure(_) => LlmError::Network(message),
190            SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
191            SdkError::ServiceError(svc) => {
192                let inner = svc.err();
193                if inner.is_throttling_exception() {
194                    LlmError::RateLimited(message)
195                } else if inner.is_service_unavailable_exception()
196                    || inner.is_internal_server_exception()
197                    || inner.is_model_stream_error_exception()
198                {
199                    LlmError::ServerError { status: None, message }
200                } else {
201                    LlmError::ApiError(message)
202                }
203            }
204            _ => LlmError::ApiError(message),
205        }
206    }
207}
208
209fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
210    let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
211
212    if let Some(creds) = credentials {
213        config = config.credentials_provider(Credentials::new(
214            creds.access_key_id,
215            creds.secret_access_key,
216            creds.session_token,
217            None,
218            "aether-bedrock-provider",
219        ));
220    }
221
222    config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
223
224    Client::from_conf(config.build())
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider built from explicit config with neither credentials nor region.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Example key pair used by the credential tests, with an optional session token.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(str::to_string),
        }
    }

    #[test]
    fn test_display_name() {
        let name = test_provider().display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let name = test_provider().with_model("anthropic.claude-opus-4-20250514-v1:0").display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let BedrockProvider { max_tokens, .. } = test_provider().with_max_tokens(8192);
        assert_eq!(max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let BedrockProvider { temperature, .. } = test_provider().with_temperature(0.7);
        assert_eq!(temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = test_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert!(temperature.is_none());
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));

        let BedrockProvider { model, max_tokens, temperature, .. } =
            BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
                .with_model("anthropic.claude-opus-4-20250514-v1:0")
                .with_max_tokens(4096)
                .with_temperature(0.5);

        assert_eq!(model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(max_tokens, 4096);
        assert_eq!(temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn catalog_foundation_id_resolves_context_window() {
        let provider = test_provider().with_model("anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert!(provider.context_window().is_some());

        let qualified = provider.model().unwrap().to_string();
        assert_eq!(qualified, "bedrock:anthropic.claude-sonnet-4-5-20250929-v1:0");
    }

    #[test]
    fn cross_region_profile_id_in_catalog_resolves() {
        let window = test_provider().with_model("us.anthropic.claude-opus-4-6-v1").context_window();
        assert!(window.is_some());
    }

    #[test]
    fn unknown_cross_region_profile_id_falls_through_to_profile() {
        let id = "us.anthropic.claude-future-model-v99:0";
        let provider = test_provider().with_model(id);

        assert!(provider.context_window().is_none());
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{id}"));
        assert_eq!(provider.display_name(), format!("Bedrock ({id})"));
    }

    #[test]
    fn inference_profile_arn_is_passed_through_as_profile() {
        let arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let provider = test_provider().with_model(arn);

        assert!(provider.context_window().is_none());
        assert_eq!(provider.model, arn);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{arn}"));
    }

    #[test]
    fn application_inference_profile_arn_is_passed_through_as_profile() {
        let arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let provider = test_provider().with_model(arn);

        assert!(provider.context_window().is_none());
        assert_eq!(provider.model, arn);
        assert_eq!(provider.display_name(), format!("Bedrock ({arn})"));
    }
}