// llm/providers/bedrock/provider.rs
use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{
4 LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window,
5};
6use crate::{Context, LlmError, Result};
7use aws_config::Region;
8use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Model id used when the caller does not configure one explicitly.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Default upper bound on tokens generated per response.
const DEFAULT_MAX_TOKENS: i32 = 16_384;
/// AWS region used when `from_config` receives no region.
const DEFAULT_REGION: &str = "us-east-1";
19
/// Static AWS credentials supplied by the caller instead of relying on the
/// default AWS credential provider chain.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key id.
    pub access_key_id: String,
    /// Secret key paired with `access_key_id`.
    pub secret_access_key: String,
    /// Session token for temporary (STS) credentials, if any.
    pub session_token: Option<String>,
}
27
/// Streaming LLM provider backed by the AWS Bedrock `converse_stream` API.
#[derive(Clone)]
pub struct BedrockProvider {
    /// Configured Bedrock runtime client.
    client: Client,
    /// Bedrock model id sent as `model_id` on each request.
    model: String,
    /// Per-response generated-token cap passed in the inference config.
    max_tokens: i32,
    /// Sampling temperature; omitted from the request when `None`.
    temperature: Option<f32>,
}
35
36impl BedrockProvider {
37 pub async fn new() -> Self {
40 let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
41 let client = Client::new(&config);
42
43 Self {
44 client,
45 model: DEFAULT_MODEL.to_string(),
46 max_tokens: DEFAULT_MAX_TOKENS,
47 temperature: None,
48 }
49 }
50
51 pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
53 let client = build_client(credentials, region);
54
55 Self {
56 client,
57 model: DEFAULT_MODEL.to_string(),
58 max_tokens: DEFAULT_MAX_TOKENS,
59 temperature: None,
60 }
61 }
62
63 pub fn with_model(mut self, model: &str) -> Self {
64 self.model = model.to_string();
65 self
66 }
67
68 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
69 self.max_tokens = max_tokens;
70 self
71 }
72
73 pub fn with_temperature(mut self, temperature: f32) -> Self {
74 self.temperature = Some(temperature);
75 self
76 }
77
78 async fn send_converse_stream(
79 &self,
80 context: &Context,
81 ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
82 let (system_blocks, messages) = map_messages(context.messages())?;
83 let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
84
85 if let Some(temp) = self.temperature {
86 inference_config = inference_config.temperature(temp);
87 }
88
89 let inference_config = inference_config.build();
90
91 let mut request = self
92 .client
93 .converse_stream()
94 .model_id(&self.model)
95 .set_messages(Some(messages))
96 .inference_config(inference_config);
97
98 if !system_blocks.is_empty() {
99 request = request.set_system(Some(system_blocks));
100 }
101
102 if !context.tools().is_empty() {
103 let tool_config = map_tools(context.tools())?;
104 request = request.tool_config(tool_config);
105 }
106
107 info!(model = %self.model, "Sending Bedrock converse_stream request");
108
109 let response = request.send().await.map_err(|e| {
110 error!(model = %self.model, error = ?e, "Bedrock API error");
111 LlmError::ApiError(format!("Bedrock API error: {e}"))
112 })?;
113
114 Ok(response.stream)
115 }
116}
117
impl ProviderFactory for BedrockProvider {
    /// Synchronous bridge to the async `Self::new`, for callers constructing
    /// providers outside an async context but inside a Tokio runtime.
    ///
    /// NOTE(review): `tokio::task::block_in_place` panics when invoked on a
    /// current-thread runtime — this assumes a multi-thread runtime; confirm
    /// with callers.
    fn from_env() -> Result<Self> {
        let handle = tokio::runtime::Handle::try_current()
            .map_err(|e| LlmError::Other(format!("No Tokio runtime available: {e}")))?;

        Ok(tokio::task::block_in_place(|| handle.block_on(Self::new())))
    }

    // Delegates to the inherent builder method of the same name (inherent
    // methods take precedence over trait methods, so this does not recurse).
    fn with_model(self, model: &str) -> Self {
        self.with_model(model)
    }
}
130
131impl StreamingModelProvider for BedrockProvider {
132 fn model(&self) -> Option<crate::LlmModel> {
133 format!("bedrock:{}", self.model).parse().ok()
134 }
135
136 fn context_window(&self) -> Option<u32> {
137 get_context_window("bedrock", &self.model)
138 }
139
140 fn stream_response(&self, context: &Context) -> LlmResponseStream {
141 let provider = self.clone();
142 let context = context.clone();
143
144 Box::pin(async_stream::stream! {
145 match provider.send_converse_stream(&context).await {
146 Ok(receiver) => {
147 let mut stream = Box::pin(process_bedrock_stream(receiver));
148 while let Some(result) = stream.next().await {
149 yield result;
150 }
151 }
152 Err(e) => {
153 yield Err(e);
154 }
155 }
156 })
157 }
158
159 fn display_name(&self) -> String {
160 format!("Bedrock ({})", self.model)
161 }
162}
163
164fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
165 let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
166
167 if let Some(creds) = credentials {
168 config = config.credentials_provider(Credentials::new(
169 creds.access_key_id,
170 creds.secret_access_key,
171 creds.session_token,
172 None,
173 "aether-bedrock-provider",
174 ));
175 }
176
177 config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
178
179 Client::from_conf(config.build())
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider built with no explicit credentials or region — all defaults.
    fn default_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    #[test]
    fn test_display_name() {
        let name = default_provider().display_name();
        assert_eq!(name, "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let provider = default_provider().with_model("anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(
            provider.display_name(),
            "Bedrock (anthropic.claude-opus-4-20250514-v1:0)"
        );
    }

    #[test]
    fn test_with_max_tokens() {
        assert_eq!(default_provider().with_max_tokens(8192).max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        assert_eq!(
            default_provider().with_temperature(0.7).temperature,
            Some(0.7)
        );
    }

    #[test]
    fn test_default_values() {
        let provider = default_provider();
        assert_eq!(provider.model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(provider.max_tokens, 16_384);
        assert!(provider.temperature.is_none());
    }

    #[test]
    fn test_from_config_with_credentials() {
        let creds = AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: None,
        };

        let provider = BedrockProvider::from_config(Some(creds), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let creds = AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: Some("FwoGZXIvYXdzEBYaD...".to_string()),
        };

        // Exercise the full builder chain on top of an explicit config.
        let provider = BedrockProvider::from_config(Some(creds), Some("us-west-2"))
            .with_model("anthropic.claude-opus-4-20250514-v1:0")
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, "anthropic.claude-opus-4-20250514-v1:0");
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
    }
}