Skip to main content

aster/providers/
bedrock.rs

1use std::collections::HashMap;
2
3use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
4use super::errors::ProviderError;
5use super::retry::{ProviderRetry, RetryConfig};
6use crate::conversation::message::Message;
7use crate::model::ModelConfig;
8use crate::providers::utils::RequestLog;
9use anyhow::Result;
10use async_trait::async_trait;
11use aws_sdk_bedrockruntime::config::ProvideCredentials;
12use aws_sdk_bedrockruntime::operation::converse::ConverseError;
13use aws_sdk_bedrockruntime::{types as bedrock, Client};
14use rmcp::model::Tool;
15use serde_json::Value;
16
17// Import the migrated helper functions from providers/formats/bedrock.rs
18use super::formats::bedrock::{
19    from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config,
20};
21
/// Documentation page listing the models supported by Amazon Bedrock.
pub const BEDROCK_DOC_LINK: &str =
    "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";

/// Model id used when no other model is selected.
pub const BEDROCK_DEFAULT_MODEL: &str = "us.anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Model ids advertised for this provider; currently only the default model.
pub const BEDROCK_KNOWN_MODELS: &[&str] = &[BEDROCK_DEFAULT_MODEL];

/// Fallback for `BEDROCK_MAX_RETRIES` when that parameter cannot be read.
pub const BEDROCK_DEFAULT_MAX_RETRIES: usize = 6;
/// Fallback for `BEDROCK_INITIAL_RETRY_INTERVAL_MS` (milliseconds).
pub const BEDROCK_DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 2_000;
/// Fallback for `BEDROCK_BACKOFF_MULTIPLIER`.
pub const BEDROCK_DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Fallback for `BEDROCK_MAX_RETRY_INTERVAL_MS` (milliseconds).
pub const BEDROCK_DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 120_000;
32
33#[derive(Debug, serde::Serialize)]
34pub struct BedrockProvider {
35    #[serde(skip)]
36    client: Client,
37    model: ModelConfig,
38    #[serde(skip)]
39    retry_config: RetryConfig,
40    #[serde(skip)]
41    name: String,
42}
43
44impl BedrockProvider {
45    pub async fn from_env(model: ModelConfig) -> Result<Self> {
46        let config = crate::config::Config::global();
47
48        // Attempt to load config and secrets to get AWS_ prefixed keys
49        // to re-export them into the environment for aws_config to use as fallback
50        let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
51            if let Ok(map) = res {
52                map.into_iter()
53                    .filter(|(key, _)| key.starts_with("AWS_"))
54                    .filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
55                    .for_each(|(key, s)| std::env::set_var(key, s));
56            }
57        };
58
59        set_aws_env_vars(config.all_values());
60        set_aws_env_vars(config.all_secrets());
61
62        // Use load_defaults() which supports AWS SSO, profiles, and environment variables
63        let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
64
65        if let Ok(profile_name) = config.get_param::<String>("AWS_PROFILE") {
66            if !profile_name.is_empty() {
67                loader = loader.profile_name(&profile_name);
68            }
69        }
70
71        // Check for AWS_REGION configuration
72        if let Ok(region) = config.get_param::<String>("AWS_REGION") {
73            if !region.is_empty() {
74                loader = loader.region(aws_config::Region::new(region));
75            }
76        }
77
78        let sdk_config = loader.load().await;
79
80        // Validate credentials or return error back up
81        sdk_config
82            .credentials_provider()
83            .ok_or_else(|| anyhow::anyhow!("No AWS credentials provider configured"))?
84            .provide_credentials()
85            .await
86            .map_err(|e| anyhow::anyhow!("Failed to load AWS credentials: {}. Make sure to run 'aws sso login --profile <your-profile>' if using SSO", e))?;
87
88        let client = Client::new(&sdk_config);
89
90        let retry_config = Self::load_retry_config(config);
91
92        Ok(Self {
93            client,
94            model,
95            retry_config,
96            name: Self::metadata().name,
97        })
98    }
99
100    fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
101        let max_retries = config
102            .get_param::<usize>("BEDROCK_MAX_RETRIES")
103            .unwrap_or(BEDROCK_DEFAULT_MAX_RETRIES);
104
105        let initial_interval_ms = config
106            .get_param::<u64>("BEDROCK_INITIAL_RETRY_INTERVAL_MS")
107            .unwrap_or(BEDROCK_DEFAULT_INITIAL_RETRY_INTERVAL_MS);
108
109        let backoff_multiplier = config
110            .get_param::<f64>("BEDROCK_BACKOFF_MULTIPLIER")
111            .unwrap_or(BEDROCK_DEFAULT_BACKOFF_MULTIPLIER);
112
113        let max_interval_ms = config
114            .get_param::<u64>("BEDROCK_MAX_RETRY_INTERVAL_MS")
115            .unwrap_or(BEDROCK_DEFAULT_MAX_RETRY_INTERVAL_MS);
116
117        RetryConfig {
118            max_retries,
119            initial_interval_ms,
120            backoff_multiplier,
121            max_interval_ms,
122        }
123    }
124
125    async fn converse(
126        &self,
127        system: &str,
128        messages: &[Message],
129        tools: &[Tool],
130    ) -> Result<(bedrock::Message, Option<bedrock::TokenUsage>), ProviderError> {
131        let model_name = &self.model.model_name;
132
133        let mut request = self
134            .client
135            .converse()
136            .system(bedrock::SystemContentBlock::Text(system.to_string()))
137            .model_id(model_name.to_string())
138            .set_messages(Some(
139                messages
140                    .iter()
141                    .filter(|m| m.is_agent_visible())
142                    .map(to_bedrock_message)
143                    .collect::<Result<_>>()?,
144            ));
145
146        if !tools.is_empty() {
147            request = request.tool_config(to_bedrock_tool_config(tools)?);
148        }
149
150        let response = request
151            .send()
152            .await
153            .map_err(|err| match err.into_service_error() {
154                ConverseError::ThrottlingException(throttle_err) => {
155                    ProviderError::RateLimitExceeded {
156                        details: format!("Bedrock throttling error: {:?}", throttle_err),
157                        retry_delay: None,
158                    }
159                }
160                ConverseError::AccessDeniedException(err) => {
161                    ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
162                }
163                ConverseError::ValidationException(err)
164                    if err
165                        .message()
166                        .unwrap_or_default()
167                        .contains("Input is too long for requested model.") =>
168                {
169                    ProviderError::ContextLengthExceeded(format!(
170                        "Failed to call Bedrock: {:?}",
171                        err
172                    ))
173                }
174                ConverseError::ModelErrorException(err) => {
175                    ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
176                }
177                err => ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err)),
178            })?;
179
180        match response.output {
181            Some(bedrock::ConverseOutput::Message(message)) => Ok((message, response.usage)),
182            _ => Err(ProviderError::RequestFailed(
183                "No output from Bedrock".to_string(),
184            )),
185        }
186    }
187}
188
189#[async_trait]
190impl Provider for BedrockProvider {
191    fn metadata() -> ProviderMetadata {
192        ProviderMetadata::new(
193            "aws_bedrock",
194            "Amazon Bedrock",
195            "Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile <profile-name>' before using. Configure with AWS_PROFILE and AWS_REGION, or use environment variables/credentials.",
196            BEDROCK_DEFAULT_MODEL,
197            BEDROCK_KNOWN_MODELS.to_vec(),
198            BEDROCK_DOC_LINK,
199            vec![
200                ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
201                ConfigKey::new("AWS_REGION", true, false, None),
202            ],
203        )
204    }
205
206    fn get_name(&self) -> &str {
207        &self.name
208    }
209
210    fn retry_config(&self) -> RetryConfig {
211        self.retry_config.clone()
212    }
213
214    fn get_model_config(&self) -> ModelConfig {
215        self.model.clone()
216    }
217
218    #[tracing::instrument(
219        skip(self, model_config, system, messages, tools),
220        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
221    )]
222    async fn complete_with_model(
223        &self,
224        model_config: &ModelConfig,
225        system: &str,
226        messages: &[Message],
227        tools: &[Tool],
228    ) -> Result<(Message, ProviderUsage), ProviderError> {
229        let model_name = model_config.model_name.clone();
230
231        let (bedrock_message, bedrock_usage) = self
232            .with_retry(|| self.converse(system, messages, tools))
233            .await?;
234
235        let usage = bedrock_usage
236            .as_ref()
237            .map(from_bedrock_usage)
238            .unwrap_or_default();
239
240        let message = from_bedrock_message(&bedrock_message)?;
241
242        // Add debug trace with input context
243        let debug_payload = serde_json::json!({
244            "system": system,
245            "messages": messages,
246            "tools": tools
247        });
248        let mut log = RequestLog::start(&self.model, &debug_payload)?;
249        log.write(
250            &serde_json::to_value(&message).unwrap_or_default(),
251            Some(&usage),
252        )?;
253
254        let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
255        Ok((message, provider_usage))
256    }
257}