llm/providers/bedrock/
provider.rs1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
8use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
9use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
10use aws_sdk_bedrockruntime::{Client, Config};
11use futures::StreamExt;
12use tracing::{error, info};
13
/// Bedrock model identifier used when the caller does not pick one.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Upper bound on generated tokens sent in every request's inference configuration.
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Region used by `build_client` when no region is supplied.
const DEFAULT_REGION: &str = "us-east-1";
17
/// Static AWS credentials passed explicitly to `BedrockProvider::from_config`.
///
/// `Debug` is implemented by hand so that secrets never end up in logs or
/// error messages: only the access key id is printed, while the secret key
/// and session token are redacted.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key id (an identifier, printed by `Debug`).
    pub access_key_id: String,
    /// AWS secret access key (redacted by `Debug`).
    pub secret_access_key: String,
    /// Optional session token for temporary credentials (redacted by `Debug`).
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            // Keep the Some/None distinction visible without revealing the token.
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
25
26#[derive(Clone)]
27pub struct BedrockProvider {
28 client: Client,
29 model: String,
30 max_tokens: i32,
31 temperature: Option<f32>,
32}
33
34impl BedrockProvider {
35 pub async fn new() -> Self {
38 let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
39 let client = Client::new(&config);
40
41 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
42 }
43
44 pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
46 let client = build_client(credentials, region);
47
48 Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
49 }
50
51 pub fn with_model(mut self, model: &str) -> Self {
52 self.model = model.to_string();
53 self
54 }
55
56 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
57 self.max_tokens = max_tokens;
58 self
59 }
60
61 pub fn with_temperature(mut self, temperature: f32) -> Self {
62 self.temperature = Some(temperature);
63 self
64 }
65
66 async fn send_converse_stream(
67 &self,
68 context: &Context,
69 ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
70 let (system_blocks, messages) = map_messages(context.messages())?;
71 let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
72
73 if let Some(temp) = self.temperature {
74 inference_config = inference_config.temperature(temp);
75 }
76
77 let inference_config = inference_config.build();
78
79 let mut request = self
80 .client
81 .converse_stream()
82 .model_id(&self.model)
83 .set_messages(Some(messages))
84 .inference_config(inference_config);
85
86 if !system_blocks.is_empty() {
87 request = request.set_system(Some(system_blocks));
88 }
89
90 if !context.tools().is_empty() {
91 let tool_config = map_tools(context.tools())?;
92 request = request.tool_config(tool_config);
93 }
94
95 info!(model = %self.model, "Sending Bedrock converse_stream request");
96
97 let response = request.send().await.map_err(|e| {
98 error!(model = %self.model, error = ?e, "Bedrock API error");
99 LlmError::ApiError(format!("Bedrock API error: {e}"))
100 })?;
101
102 Ok(response.stream)
103 }
104}
105
106impl ProviderFactory for BedrockProvider {
107 async fn from_env() -> Result<Self> {
108 Ok(Self::new().await)
109 }
110
111 fn with_model(self, model: &str) -> Self {
112 self.with_model(model)
113 }
114}
115
116impl StreamingModelProvider for BedrockProvider {
117 fn model(&self) -> Option<crate::LlmModel> {
118 format!("bedrock:{}", self.model).parse().ok()
119 }
120
121 fn context_window(&self) -> Option<u32> {
122 get_context_window("bedrock", &self.model)
123 }
124
125 fn stream_response(&self, context: &Context) -> LlmResponseStream {
126 let provider = self.clone();
127 let context = context.clone();
128
129 Box::pin(async_stream::stream! {
130 match provider.send_converse_stream(&context).await {
131 Ok(receiver) => {
132 let mut stream = Box::pin(process_bedrock_stream(receiver));
133 while let Some(result) = stream.next().await {
134 yield result;
135 }
136 }
137 Err(e) => {
138 yield Err(e);
139 }
140 }
141 })
142 }
143
144 fn display_name(&self) -> String {
145 format!("Bedrock ({})", self.model)
146 }
147}
148
149fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
150 let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
151
152 if let Some(creds) = credentials {
153 config = config.credentials_provider(Credentials::new(
154 creds.access_key_id,
155 creds.secret_access_key,
156 creds.session_token,
157 None,
158 "aether-bedrock-provider",
159 ));
160 }
161
162 config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
163
164 Client::from_conf(config.build())
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    const OPUS_MODEL: &str = "anthropic.claude-opus-4-20250514-v1:0";

    /// Provider with no explicit credentials and no explicit region.
    fn baseline_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Placeholder static credentials (the well-known AWS documentation example keys).
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: String::from("AKIAIOSFODNN7EXAMPLE"),
            secret_access_key: String::from("wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"),
            session_token: session_token.map(String::from),
        }
    }

    #[test]
    fn test_display_name() {
        let label = baseline_provider().display_name();
        assert_eq!(label, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let label = baseline_provider().with_model(OPUS_MODEL).display_name();
        assert_eq!(label, "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        assert_eq!(baseline_provider().with_max_tokens(8192).max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        assert_eq!(baseline_provider().with_temperature(0.7).temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let BedrockProvider { model, max_tokens, temperature, .. } = baseline_provider();
        assert_eq!(model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(max_tokens, 16_384);
        assert!(temperature.is_none());
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));
        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model(OPUS_MODEL)
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, OPUS_MODEL);
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }
}