Skip to main content

llm/providers/bedrock/
provider.rs

1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model ID used until a caller overrides it with `with_model`.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Value sent as `max_tokens` in the inference configuration unless overridden (16 Ki tokens).
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Fallback AWS region used when no region is supplied.
const DEFAULT_REGION: &str = "us-east-1";
19
/// AWS credentials for explicit authentication with Bedrock.
///
/// `Debug` is implemented by hand rather than derived so that `{:?}`-formatting
/// (for example in tracing output) never reveals the secret access key or the
/// session token; only the access key ID is shown.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID identifying the key pair.
    pub access_key_id: String,
    /// Secret half of the key pair; never printed by `Debug`.
    pub secret_access_key: String,
    /// Optional session token (used with temporary credentials); never printed by `Debug`.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            // Keep the Some/None distinction visible without exposing the token itself.
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30    client: Client,
31    model: String,
32    max_tokens: i32,
33    temperature: Option<f32>,
34}
35
36impl BedrockProvider {
37    /// Create a provider using the default AWS credential chain
38    /// (env vars, `~/.aws/credentials`, IAM roles, SSO).
39    pub async fn new() -> Self {
40        let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
41        let client = Client::new(&config);
42
43        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
44    }
45
46    /// Create a provider from explicit configuration without async credential discovery.
47    pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
48        let client = build_client(credentials, region);
49
50        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
51    }
52
53    pub fn with_model(mut self, model: &str) -> Self {
54        self.model = model.to_string();
55        self
56    }
57
58    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
59        self.max_tokens = max_tokens;
60        self
61    }
62
63    pub fn with_temperature(mut self, temperature: f32) -> Self {
64        self.temperature = Some(temperature);
65        self
66    }
67
68    async fn send_converse_stream(
69        &self,
70        context: &Context,
71    ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
72        let (system_blocks, messages) = map_messages(context.messages())?;
73        let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
74
75        if let Some(temp) = self.temperature {
76            inference_config = inference_config.temperature(temp);
77        }
78
79        let inference_config = inference_config.build();
80
81        let mut request = self
82            .client
83            .converse_stream()
84            .model_id(&self.model)
85            .set_messages(Some(messages))
86            .inference_config(inference_config);
87
88        if !system_blocks.is_empty() {
89            request = request.set_system(Some(system_blocks));
90        }
91
92        if !context.tools().is_empty() {
93            let tool_config = map_tools(context.tools())?;
94            request = request.tool_config(tool_config);
95        }
96
97        info!(model = %self.model, "Sending Bedrock converse_stream request");
98
99        let response = request.send().await.map_err(|e| {
100            error!(model = %self.model, error = ?e, "Bedrock API error");
101            LlmError::from(e)
102        })?;
103
104        Ok(response.stream)
105    }
106}
107
108impl ProviderFactory for BedrockProvider {
109    async fn from_env() -> Result<Self> {
110        Ok(Self::new().await)
111    }
112
113    fn with_model(self, model: &str) -> Self {
114        self.with_model(model)
115    }
116}
117
118impl StreamingModelProvider for BedrockProvider {
119    fn model(&self) -> Option<crate::LlmModel> {
120        format!("bedrock:{}", self.model).parse().ok()
121    }
122
123    fn context_window(&self) -> Option<u32> {
124        get_context_window("bedrock", &self.model)
125    }
126
127    fn stream_response(&self, context: &Context) -> LlmResponseStream {
128        let provider = self.clone();
129        let context = context.clone();
130
131        Box::pin(async_stream::stream! {
132            match provider.send_converse_stream(&context).await {
133                Ok(receiver) => {
134                    let mut stream = Box::pin(process_bedrock_stream(receiver));
135                    while let Some(result) = stream.next().await {
136                        yield result;
137                    }
138                }
139                Err(e) => {
140                    yield Err(e);
141                }
142            }
143        })
144    }
145
146    fn display_name(&self) -> String {
147        format!("Bedrock ({})", self.model)
148    }
149}
150
151impl From<SdkError<ConverseStreamError>> for LlmError {
152    fn from(e: SdkError<ConverseStreamError>) -> Self {
153        let message = format!("Bedrock API error: {e}");
154        match e {
155            SdkError::TimeoutError(_) => LlmError::Timeout(message),
156            SdkError::DispatchFailure(_) => LlmError::Network(message),
157            SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
158            SdkError::ServiceError(svc) => {
159                let inner = svc.err();
160                if inner.is_throttling_exception() {
161                    LlmError::RateLimited(message)
162                } else if inner.is_service_unavailable_exception()
163                    || inner.is_internal_server_exception()
164                    || inner.is_model_stream_error_exception()
165                {
166                    LlmError::ServerError { status: None, message }
167                } else {
168                    LlmError::ApiError(message)
169                }
170            }
171            _ => LlmError::ApiError(message),
172        }
173    }
174}
175
176fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
177    let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
178
179    if let Some(creds) = credentials {
180        config = config.credentials_provider(Credentials::new(
181            creds.access_key_id,
182            creds.secret_access_key,
183            creds.session_token,
184            None,
185            "aether-bedrock-provider",
186        ));
187    }
188
189    config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
190
191    Client::from_conf(config.build())
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    const SONNET_4_5: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
    const OPUS_4: &str = "anthropic.claude-opus-4-20250514-v1:0";

    /// Provider built via `from_config`, bypassing the async discovery done by `new()`.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// AWS's documented example key pair, with an optional session token.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(String::from),
        }
    }

    #[test]
    fn test_display_name() {
        assert_eq!(test_provider().display_name(), format!("Bedrock ({SONNET_4_5})"));
    }

    #[test]
    fn test_with_model() {
        let renamed = test_provider().with_model(OPUS_4);
        assert_eq!(renamed.display_name(), format!("Bedrock ({OPUS_4})"));
    }

    #[test]
    fn test_with_max_tokens() {
        assert_eq!(test_provider().with_max_tokens(8192).max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        assert_eq!(test_provider().with_temperature(0.7).temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let defaults = test_provider();
        assert_eq!(defaults.model, SONNET_4_5);
        assert_eq!(defaults.max_tokens, 16_384);
        assert_eq!(defaults.temperature, None);
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));
        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model(OPUS_4)
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, OPUS_4);
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        assert_eq!(BedrockProvider::from_config(None, Some("eu-west-1")).model, DEFAULT_MODEL);
    }

    #[test]
    fn catalog_foundation_id_resolves_context_window() {
        let provider = test_provider().with_model(SONNET_4_5);
        assert!(provider.context_window().is_some());
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{SONNET_4_5}"));
    }

    #[test]
    fn cross_region_profile_id_in_catalog_resolves() {
        let provider = test_provider().with_model("us.anthropic.claude-opus-4-6-v1");
        assert!(provider.context_window().is_some());
    }

    #[test]
    fn unknown_cross_region_profile_id_falls_through_to_profile() {
        let profile_id = "us.anthropic.claude-future-model-v99:0";
        let provider = test_provider().with_model(profile_id);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{profile_id}"));
        assert_eq!(provider.display_name(), format!("Bedrock ({profile_id})"));
    }

    #[test]
    fn inference_profile_arn_is_passed_through_as_profile() {
        let profile_arn = "arn:aws:bedrock:us-west-2:000000000000:inference-profile/us.anthropic.claude-opus-4-7";
        let provider = test_provider().with_model(profile_arn);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model, profile_arn);
        assert_eq!(provider.model().unwrap().to_string(), format!("bedrock:{profile_arn}"));
    }

    #[test]
    fn application_inference_profile_arn_is_passed_through_as_profile() {
        let profile_arn = "arn:aws:bedrock:us-west-2:000000000000:application-inference-profile/000000000000";
        let provider = test_provider().with_model(profile_arn);

        assert_eq!(provider.context_window(), None);
        assert_eq!(provider.model, profile_arn);
        assert_eq!(provider.display_name(), format!("Bedrock ({profile_arn})"));
    }
}