Skip to main content

llm/providers/bedrock/
provider.rs

1use super::mappers::{map_messages, map_tools};
2use super::streaming::process_bedrock_stream;
3use crate::provider::{
4    LlmResponseStream, ProviderFactory, StreamingModelProvider, get_context_window,
5};
6use crate::{Context, LlmError, Result};
7use aws_config::Region;
8use aws_sdk_bedrockruntime::config::{BehaviorVersion, Credentials};
9use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
10use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
11use aws_sdk_bedrockruntime::types::{ConverseStreamOutput, InferenceConfiguration};
12use aws_sdk_bedrockruntime::{Client, Config};
13use futures::StreamExt;
14use tracing::{error, info};
15
/// Bedrock model identifier assigned by the constructors until `with_model` overrides it.
const DEFAULT_MODEL: &str = "anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Default cap on generated tokens per request (16,384).
const DEFAULT_MAX_TOKENS: i32 = 16 * 1024;
/// Region used by `build_client` when the caller does not pass one.
const DEFAULT_REGION: &str = "us-east-1";
19
/// AWS credentials for explicit authentication with Bedrock.
///
/// `Debug` is implemented by hand so that the secret access key and the session
/// token are never written to logs or panic messages; only the access key id is
/// shown.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key id (shown in `Debug` output).
    pub access_key_id: String,
    /// Secret access key (redacted in `Debug` output).
    pub secret_access_key: String,
    /// Optional session token for temporary credentials (redacted in `Debug` output).
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "** redacted **";
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &REDACTED)
            // Preserve whether a token is present without revealing it.
            .field("session_token", &self.session_token.as_ref().map(|_| REDACTED))
            .finish()
    }
}
27
28#[derive(Clone)]
29pub struct BedrockProvider {
30    client: Client,
31    model: String,
32    max_tokens: i32,
33    temperature: Option<f32>,
34}
35
36impl BedrockProvider {
37    /// Create a provider using the default AWS credential chain
38    /// (env vars, `~/.aws/credentials`, IAM roles, SSO).
39    pub async fn new() -> Self {
40        let config = aws_config::defaults(BehaviorVersion::latest()).load().await;
41        let client = Client::new(&config);
42
43        Self {
44            client,
45            model: DEFAULT_MODEL.to_string(),
46            max_tokens: DEFAULT_MAX_TOKENS,
47            temperature: None,
48        }
49    }
50
51    /// Create a provider from explicit configuration without async credential discovery.
52    pub fn from_config(credentials: Option<AwsCredentials>, region: Option<&str>) -> Self {
53        let client = build_client(credentials, region);
54
55        Self {
56            client,
57            model: DEFAULT_MODEL.to_string(),
58            max_tokens: DEFAULT_MAX_TOKENS,
59            temperature: None,
60        }
61    }
62
63    pub fn with_model(mut self, model: &str) -> Self {
64        self.model = model.to_string();
65        self
66    }
67
68    pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
69        self.max_tokens = max_tokens;
70        self
71    }
72
73    pub fn with_temperature(mut self, temperature: f32) -> Self {
74        self.temperature = Some(temperature);
75        self
76    }
77
78    async fn send_converse_stream(
79        &self,
80        context: &Context,
81    ) -> Result<EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>> {
82        let (system_blocks, messages) = map_messages(context.messages())?;
83        let mut inference_config = InferenceConfiguration::builder().max_tokens(self.max_tokens);
84
85        if let Some(temp) = self.temperature {
86            inference_config = inference_config.temperature(temp);
87        }
88
89        let inference_config = inference_config.build();
90
91        let mut request = self
92            .client
93            .converse_stream()
94            .model_id(&self.model)
95            .set_messages(Some(messages))
96            .inference_config(inference_config);
97
98        if !system_blocks.is_empty() {
99            request = request.set_system(Some(system_blocks));
100        }
101
102        if !context.tools().is_empty() {
103            let tool_config = map_tools(context.tools())?;
104            request = request.tool_config(tool_config);
105        }
106
107        info!(model = %self.model, "Sending Bedrock converse_stream request");
108
109        let response = request.send().await.map_err(|e| {
110            error!(model = %self.model, error = ?e, "Bedrock API error");
111            LlmError::ApiError(format!("Bedrock API error: {e}"))
112        })?;
113
114        Ok(response.stream)
115    }
116}
117
118impl ProviderFactory for BedrockProvider {
119    fn from_env() -> Result<Self> {
120        let handle = tokio::runtime::Handle::try_current()
121            .map_err(|e| LlmError::Other(format!("No Tokio runtime available: {e}")))?;
122
123        Ok(tokio::task::block_in_place(|| handle.block_on(Self::new())))
124    }
125
126    fn with_model(self, model: &str) -> Self {
127        self.with_model(model)
128    }
129}
130
131impl StreamingModelProvider for BedrockProvider {
132    fn model(&self) -> Option<crate::LlmModel> {
133        format!("bedrock:{}", self.model).parse().ok()
134    }
135
136    fn context_window(&self) -> Option<u32> {
137        get_context_window("bedrock", &self.model)
138    }
139
140    fn stream_response(&self, context: &Context) -> LlmResponseStream {
141        let provider = self.clone();
142        let context = context.clone();
143
144        Box::pin(async_stream::stream! {
145            match provider.send_converse_stream(&context).await {
146                Ok(receiver) => {
147                    let mut stream = Box::pin(process_bedrock_stream(receiver));
148                    while let Some(result) = stream.next().await {
149                        yield result;
150                    }
151                }
152                Err(e) => {
153                    yield Err(e);
154                }
155            }
156        })
157    }
158
159    fn display_name(&self) -> String {
160        format!("Bedrock ({})", self.model)
161    }
162}
163
164fn build_client(credentials: Option<AwsCredentials>, region: Option<&str>) -> Client {
165    let mut config = Config::builder().behavior_version(BehaviorVersion::latest());
166
167    if let Some(creds) = credentials {
168        config = config.credentials_provider(Credentials::new(
169            creds.access_key_id,
170            creds.secret_access_key,
171            creds.session_token,
172            None,
173            "aether-bedrock-provider",
174        ));
175    }
176
177    config = config.region(Region::new(region.unwrap_or(DEFAULT_REGION).to_string()));
178
179    Client::from_conf(config.build())
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    const OPUS_MODEL: &str = "anthropic.claude-opus-4-20250514-v1:0";

    fn test_provider() -> BedrockProvider {
        BedrockProvider::from_config(None, None)
    }

    // AWS documentation placeholder keys; not real credentials.
    fn example_credentials(session_token: Option<&str>) -> AwsCredentials {
        AwsCredentials {
            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            session_token: session_token.map(str::to_string),
        }
    }

    // Region actually configured on the provider's SDK client, so the region
    // tests check real wiring rather than only the model field.
    fn client_region(provider: &BedrockProvider) -> Option<String> {
        provider
            .client
            .config()
            .region()
            .map(|region| region.to_string())
    }

    #[test]
    fn test_display_name() {
        assert_eq!(
            test_provider().display_name(),
            "Bedrock (anthropic.claude-sonnet-4-5-20250929-v1:0)"
        );
    }

    #[test]
    fn test_with_model() {
        let provider = test_provider().with_model(OPUS_MODEL);
        assert_eq!(
            provider.display_name(),
            "Bedrock (anthropic.claude-opus-4-20250514-v1:0)"
        );
    }

    #[test]
    fn test_with_max_tokens() {
        let provider = test_provider().with_max_tokens(8192);
        assert_eq!(provider.max_tokens, 8192);
    }

    #[test]
    fn test_with_temperature() {
        let provider = test_provider().with_temperature(0.7);
        assert_eq!(provider.temperature, Some(0.7));
    }

    #[test]
    fn test_default_values() {
        let provider = test_provider();
        assert_eq!(provider.model, "anthropic.claude-sonnet-4-5-20250929-v1:0");
        assert_eq!(provider.max_tokens, 16_384);
        assert!(provider.temperature.is_none());
        assert_eq!(client_region(&provider).as_deref(), Some("us-east-1"));
    }

    #[test]
    fn test_from_config_with_credentials() {
        let provider = BedrockProvider::from_config(Some(example_credentials(None)), None);
        assert_eq!(provider.model, DEFAULT_MODEL);
        assert_eq!(client_region(&provider).as_deref(), Some(DEFAULT_REGION));
    }

    #[test]
    fn test_from_config_with_credentials_and_region() {
        let credentials = example_credentials(Some("FwoGZXIvYXdzEBYaD..."));

        let provider = BedrockProvider::from_config(Some(credentials), Some("us-west-2"))
            .with_model(OPUS_MODEL)
            .with_max_tokens(4096)
            .with_temperature(0.5);

        assert_eq!(provider.model, OPUS_MODEL);
        assert_eq!(provider.max_tokens, 4096);
        assert_eq!(provider.temperature, Some(0.5));
        assert_eq!(client_region(&provider).as_deref(), Some("us-west-2"));
    }

    #[test]
    fn test_from_config_with_region_only() {
        let provider = BedrockProvider::from_config(None, Some("eu-west-1"));
        assert_eq!(provider.model, DEFAULT_MODEL);
        assert_eq!(client_region(&provider).as_deref(), Some("eu-west-1"));
    }
}