Skip to main content

llm/providers/bedrock/
provider.rs

1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window};
4use crate::{Context, LlmError, Result};
5use aws_config::Region;
6use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
7use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
8use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
9use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
10use aws_sdk_bedrockruntime::{Client, Config};
11use futures::StreamExt;
12use tracing::{error, info};
13
/// Bedrock model ID used until `with_model` overrides it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// `max_tokens` sent in the inference configuration unless `with_max_tokens` overrides it.
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Region applied by `from_config` when the caller does not supply one.
const DEFAULT_REGION: &str = "us-east-1";
17
/// AWS credentials for explicit authentication with Bedrock.
///
/// `Debug` is implemented by hand rather than derived so that formatting a value
/// (e.g. in logs or error messages) never exposes the secret access key or the
/// session token; only the access key ID is printed verbatim.
#[derive(Clone)]
pub struct AwsCredentials {
    /// AWS access key ID.
    pub access_key_id: String,
    /// AWS secret access key. Redacted in `Debug` output.
    pub secret_access_key: String,
    /// Optional session token (used with temporary credentials). Redacted in `Debug` output.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Keep the Some/None distinction visible without revealing the token.
            .field("session_token", &self.session_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
25
26#[derive(Clone)]
27pub struct BedrockProvider {
28    client: Client,
29    model: String,
30    max_tokens: i32,
31    temperature: Option<f32>,
32}
33
34impl BedrockProvider {
35    /// Create a provider using the default AWS credential chain
36    /// (env vars, `~/.aws/credentials`, IAM roles, SSO).
37    pub async fn new() -> Self {
38        let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
39        let client = Client::new(&config);
40
41        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
42    }
43
44    /// Create a provider from explicit configuration without async credential discovery.
45    pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
46        let client = build_client(credentials, region);
47
48        Self { client, model: DEFAULT_MODEL.to_string(), max_tokens: DEFAULT_MAX_TOKENS, temperature: None }
49    }
50
51    pub fn with_model(mut self, model: &str) -> Self {
52        self.model = model.to_string();
53        self
54    }
55
56    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
57        self.max_tokens = max_tokens;
58        self
59    }
60
61    pub fn with_temperature(mut self, temperature: f32) -> Self {
62        self.temperature = Some(temperature);
63        self
64    }
65
66    async fn send_converse_stream(
67        &self,
68        context: &Context,
69    ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
70        let (system_blocks, messages) = map_messages(context.messages())?;
71        let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
72
73        if let Some(temp) = self.temperature {
74            inference_config = inference_config.temperature(temp);
75        }
76
77        let inference_config = inference_config.build();
78
79        let mut request = self
80            .client
81            .converse_stream()
82            .model_id(&self.model)
83            .set_messages(Some(messages))
84            .inference_config(inference_config);
85
86        if !system_blocks.is_empty() {
87            request = request.set_system(Some(system_blocks));
88        }
89
90        if !context.tools().is_empty() {
91            let tool_config = map_tools(context.tools())?;
92            request = request.tool_config(tool_config);
93        }
94
95        info!(model = %self.model, "Sending Bedrock converse_stream request");
96
97        let response = request.send().await.map_err(|e| {
98            error!(model = %self.model, error = ?e, "Bedrock API error");
99            LlmError::ApiError(format!("Bedrock API error: {e}"))
100        })?;
101
102        Ok(response.stream)
103    }
104}
105
106impl ProviderFactory for BedrockProvider {
107    async fn from_env() -> Result<Self> {
108        Ok(Self::new().await)
109    }
110
111    fn with_model(self, model: &str) -> Self {
112        self.with_model(model)
113    }
114}
115
116impl StreamingModelProvider for BedrockProvider {
117    fn model(&self) -> Option<crate::LlmModel> {
118        format!("bedrock:{}", self.model).parse().ok()
119    }
120
121    fn context_window(&self) -> Option<u32> {
122        get_context_window("bedrock", &self.model)
123    }
124
125    fn stream_response(&self, context: &Context) -> LlmResponseStream {
126        let provider = self.clone();
127        let context = context.clone();
128
129        Box::pin(async_stream::stream! {
130            match provider.send_converse_stream(&context).await {
131                Ok(receiver) => {
132                    let mut stream = Box::pin(process_bedrock_stream(receiver));
133                    while let Some(result) = stream.next().await {
134                        yield result;
135                    }
136                }
137                Err(e) => {
138                    yield Err(e);
139                }
140            }
141        })
142    }
143
144    fn display_name(&self) -> String {
145        format!("Bedrock ({})", self.model)
146    }
147}
148
149fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
150    let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
151
152    if let Some(creds) = credentials {
153        config = config.credentials_provider(Credentials::new(
154            creds.access_key_id,
155            creds.secret_access_key,
156            creds.session_token,
157            None,
158            "aether-bedrock-provider",
159        ));
160    }
161
162    config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
163
164    Client::from_conf(config.build())
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    const OPUS_MODEL: &str = "anthropic.claude-opus-4-20250514-v1:0";

    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    /// Region configured on the provider's underlying client, if any.
    fn configured_region(provider: &BedrockProvider) -> Option<&str> {
        provider.client.config().region().map(|region| region.as_ref())
    }

    /// AWS documentation example credentials (not real secrets).
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(str::to_string),
        }
    }

    #[test]
    fn test_display_name() {
        assert_eq!(test_provider().display_name(), "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)");
    }

    #[test]
    fn test_with_model() {
        let provider = test_provider().with_model(OPUS_MODEL);
        assert_eq!(provider.display_name(), "Bedrock (anthropic.claude-opus-4-20250514-v1:0)");
    }

    #[test]
    fn test_with_max_tokens() {
        let provider = test_provider().with_max_tokens(8192);
        assert_eq!(provider.max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let provider = test_provider().with_temperature(0.7);
        assert_eq!(provider.temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let provider = test_provider();
        assert_eq!(provider.model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(provider.max_tokens, 16_384);
        assert!(provider.temperature.is_none());
        assert_eq!(configured_region(&provider), Some("us-east-1"));
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
        assert_eq!(configured_region(&provider), Some(DEFAULT_REGION));
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));

        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model(OPUS_MODEL)
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, OPUS_MODEL);
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
        assert_eq!(configured_region(&provider), Some("us-west-2"));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
        assert_eq!(configured_region(&provider), Some("eu-west-1"));
    }
}