llm/providers/bedrock/
provider.rs1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::error::SdkError;
8use aws_sdk_bedrockruntime::operation::converse_stream::ConverseStreamError;
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model ID a freshly constructed provider uses until `with_model` overrides it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";

/// `max_tokens` value a freshly constructed provider sends in its inference configuration.
const DEFAULT_MAX_TOKENS: i32 = 16_384;

/// Region `build_client` falls back to when the caller supplies none.
const DEFAULT_REGION: &str = "us-east-1";
19
/// Explicit AWS credentials that can be handed to `BedrockProvider::from_config`.
///
/// `Debug` is implemented by hand (not derived) so that formatting or logging a
/// value of this type never prints the secret key or the session token.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key; redacted in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token; redacted in `Debug` output when present.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    /// Formats the credentials with every secret replaced by `"<redacted>"`.
    ///
    /// The access key ID is printed as-is; whether a session token is present is
    /// still visible (`Some(..)` vs `None`), but its value is not.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field("session_token", &self.session_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30 client: Client,
31 model: String,
32 max_tokens: i32,
33 temperature: Option<f32>,
34}
35
36impl BedrockProvider {
37 pub async fn new() -> Self {
40 let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
41 let client = Client::new(&config);
42
43 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
44 }
45
46 pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
48 let client = build_client(credentials, region);
49
50 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
51 }
52
53 pub fn with_model(mut self, model: &str) -> Self {
54 self.model = model.to_string();
55 self
56 }
57
58 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
59 self.max_tokens = max_tokens;
60 self
61 }
62
63 pub fn with_temperature(mut self, temperature: f32) -> Self {
64 self.temperature = Some(temperature);
65 self
66 }
67
68 async fn send_converse_stream(
69 &self,
70 context: &Context,
71 ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
72 let (system_blocks, messages) = map_messages(context.messages())?;
73 let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
74
75 if let Some(temp) = self.temperature {
76 inference_config = inference_config.temperature(temp);
77 }
78
79 let inference_config = inference_config.build();
80
81 let mut request = self
82 .client
83 .converse_stream()
84 .model_id(&self.model)
85 .set_messages(Some(messages))
86 .inference_config(inference_config);
87
88 if !system_blocks.is_empty() {
89 request = request.set_system(Some(system_blocks));
90 }
91
92 if !context.tools().is_empty() {
93 let tool_config = map_tools(context.tools())?;
94 request = request.tool_config(tool_config);
95 }
96
97 info!(model = %self.model, "Sending Bedrock converse_stream request");
98
99 let response = request.send().await.map_err(|e| {
100 error!(model = %self.model, error = ?e, "Bedrock API error");
101 LlmError::from(e)
102 })?;
103
104 Ok(response.stream)
105 }
106}
107
108impl ProviderFactory for BedrockProvider {
109 async fn from_env() -> Result<Self> {
110 Ok(Self::new().await)
111 }
112
113 fn with_model(self, model: &str) -> Self {
114 self.with_model(model)
115 }
116}
117
118impl StreamingModelProvider for BedrockProvider {
119 fn model(&self) -> Option<crate::LlmModel> {
120 format!("bedrock:{}", self.model).parse().ok()
121 }
122
123 fn context_window(&self) -> Option<u32> {
124 get_context_window("bedrock", &self.model)
125 }
126
127 fn stream_response(&self, context: &Context) -> LlmResponseStream {
128 let provider = self.clone();
129 let context = context.clone();
130
131 Box::pin(async_stream::stream! {
132 match provider.send_converse_stream(&context).await {
133 Ok(receiver) => {
134 let mut stream = Box::pin(process_bedrock_stream(receiver));
135 while let Some(result) = stream.next().await {
136 yield result;
137 }
138 }
139 Err(e) => {
140 yield Err(e);
141 }
142 }
143 })
144 }
145
146 fn display_name(&self) -> String {
147 format!("Bedrock ({})", self.model)
148 }
149}
150
151impl From<SdkError<ConverseStreamError>> for LlmError {
152 fn from(e: SdkError<ConverseStreamError>) -> Self {
153 let message = format!("Bedrock API error: {e}");
154 match e {
155 SdkError::TimeoutError(_) => LlmError::Timeout(message),
156 SdkError::DispatchFailure(_) => LlmError::Network(message),
157 SdkError::ResponseError(_) => LlmError::ServerError { status: None, message },
158 SdkError::ServiceError(svc) => {
159 let inner = svc.err();
160 if inner.is_throttling_exception() {
161 LlmError::RateLimited(message)
162 } else if inner.is_service_unavailable_exception()
163 || inner.is_internal_server_exception()
164 || inner.is_model_stream_error_exception()
165 {
166 LlmError::ServerError { status: None, message }
167 } else {
168 LlmError::ApiError(message)
169 }
170 }
171 _ => LlmError::ApiError(message),
172 }
173 }
174}
175
176fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
177 let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
178
179 if let Some(creds) = credentials {
180 config = config.credentials_provider(Credentials::new(
181 creds.access_key_id,
182 creds.secret_access_key,
183 creds.session_token,
184 None,
185 "aether-bedrock-provider",
186 ));
187 }
188
189 config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
190
191 Client::from_conf(config.build())
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Non-default model ID used to check that overrides take effect.
    const OPUS_MODEL: &str = "anthropic.claude-opus-4-20250514-v1:0";

    /// Provider built without explicit credentials or region.
    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Example credentials with an optional session token.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(String::from),
        }
    }

    #[test]
    fn test_display_name() {
        let shown = test_provider().display_name();
        assert_eq!(shown, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let shown = test_provider().with_model(OPUS_MODEL).display_name();
        assert_eq!(shown, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let BedrockProvider { max_tokens, .. } = test_provider().with_max_tokens(8192);
        assert_eq!(max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let BedrockProvider { temperature, .. } = test_provider().with_temperature(0.7);
        assert_eq!(temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = test_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert!(temperature.is_none());
    }

    #[test]
    fn test_from_config_with_credentials() {
        let built = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(built.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));

        let built = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model(OPUS_MODEL)
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(built.model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(built.max_tokens, 4096);
        assert_eq!(built.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let built = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(built.model, DEFAULT_MODEL);
    }
}