1use std::collections::HashMap;
2
3use super::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
4use super::errors::ProviderError;
5use super::retry::{ProviderRetry, RetryConfig};
6use crate::conversation::message::Message;
7use crate::model::ModelConfig;
8use crate::providers::utils::RequestLog;
9use anyhow::Result;
10use async_trait::async_trait;
11use aws_sdk_bedrockruntime::config::ProvideCredentials;
12use aws_sdk_bedrockruntime::operation::converse::ConverseError;
13use aws_sdk_bedrockruntime::{types as bedrock, Client};
14use rmcp::model::Tool;
15use serde_json::Value;
16
17use super::formats::bedrock::{
19 from_bedrock_message, from_bedrock_usage, to_bedrock_message, to_bedrock_tool_config,
20};
21
/// AWS user-guide page that lists the models available through Amazon Bedrock.
pub const BEDROCK_DOC_LINK: &str =
    "https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html";

/// Default model identifier reported in the provider metadata.
pub const BEDROCK_DEFAULT_MODEL: &str = "us.anthropic.claude-sonnet-4-5-20250929-v1:0";
/// Model identifiers reported in the provider metadata; currently just the default model.
pub const BEDROCK_KNOWN_MODELS: &[&str] = &[BEDROCK_DEFAULT_MODEL];

/// Used when `BEDROCK_MAX_RETRIES` cannot be read from the configuration.
pub const BEDROCK_DEFAULT_MAX_RETRIES: usize = 6;
/// Used when `BEDROCK_INITIAL_RETRY_INTERVAL_MS` cannot be read from the configuration.
pub const BEDROCK_DEFAULT_INITIAL_RETRY_INTERVAL_MS: u64 = 2_000;
/// Used when `BEDROCK_BACKOFF_MULTIPLIER` cannot be read from the configuration.
pub const BEDROCK_DEFAULT_BACKOFF_MULTIPLIER: f64 = 2.0;
/// Used when `BEDROCK_MAX_RETRY_INTERVAL_MS` cannot be read from the configuration
/// (two minutes, expressed in milliseconds).
pub const BEDROCK_DEFAULT_MAX_RETRY_INTERVAL_MS: u64 = 2 * 60 * 1_000;
32
33#[derive(Debug, serde::Serialize)]
34pub struct BedrockProvider {
35 #[serde(skip)]
36 client: Client,
37 model: ModelConfig,
38 #[serde(skip)]
39 retry_config: RetryConfig,
40 #[serde(skip)]
41 name: String,
42}
43
44impl BedrockProvider {
45 pub async fn from_env(model: ModelConfig) -> Result<Self> {
46 let config = crate::config::Config::global();
47
48 let set_aws_env_vars = |res: Result<HashMap<String, Value>, _>| {
51 if let Ok(map) = res {
52 map.into_iter()
53 .filter(|(key, _)| key.starts_with("AWS_"))
54 .filter_map(|(key, value)| value.as_str().map(|s| (key, s.to_string())))
55 .for_each(|(key, s)| std::env::set_var(key, s));
56 }
57 };
58
59 set_aws_env_vars(config.all_values());
60 set_aws_env_vars(config.all_secrets());
61
62 let mut loader = aws_config::defaults(aws_config::BehaviorVersion::latest());
64
65 if let Ok(profile_name) = config.get_param::<String>("AWS_PROFILE") {
66 if !profile_name.is_empty() {
67 loader = loader.profile_name(&profile_name);
68 }
69 }
70
71 if let Ok(region) = config.get_param::<String>("AWS_REGION") {
73 if !region.is_empty() {
74 loader = loader.region(aws_config::Region::new(region));
75 }
76 }
77
78 let sdk_config = loader.load().await;
79
80 sdk_config
82 .credentials_provider()
83 .ok_or_else(|| anyhow::anyhow!("No AWS credentials provider configured"))?
84 .provide_credentials()
85 .await
86 .map_err(|e| anyhow::anyhow!("Failed to load AWS credentials: {}. Make sure to run 'aws sso login --profile <your-profile>' if using SSO", e))?;
87
88 let client = Client::new(&sdk_config);
89
90 let retry_config = Self::load_retry_config(config);
91
92 Ok(Self {
93 client,
94 model,
95 retry_config,
96 name: Self::metadata().name,
97 })
98 }
99
100 fn load_retry_config(config: &crate::config::Config) -> RetryConfig {
101 let max_retries = config
102 .get_param::<usize>("BEDROCK_MAX_RETRIES")
103 .unwrap_or(BEDROCK_DEFAULT_MAX_RETRIES);
104
105 let initial_interval_ms = config
106 .get_param::<u64>("BEDROCK_INITIAL_RETRY_INTERVAL_MS")
107 .unwrap_or(BEDROCK_DEFAULT_INITIAL_RETRY_INTERVAL_MS);
108
109 let backoff_multiplier = config
110 .get_param::<f64>("BEDROCK_BACKOFF_MULTIPLIER")
111 .unwrap_or(BEDROCK_DEFAULT_BACKOFF_MULTIPLIER);
112
113 let max_interval_ms = config
114 .get_param::<u64>("BEDROCK_MAX_RETRY_INTERVAL_MS")
115 .unwrap_or(BEDROCK_DEFAULT_MAX_RETRY_INTERVAL_MS);
116
117 RetryConfig {
118 max_retries,
119 initial_interval_ms,
120 backoff_multiplier,
121 max_interval_ms,
122 }
123 }
124
125 async fn converse(
126 &self,
127 system: &str,
128 messages: &[Message],
129 tools: &[Tool],
130 ) -> Result<(bedrock::Message, Option<bedrock::TokenUsage>), ProviderError> {
131 let model_name = &self.model.model_name;
132
133 let mut request = self
134 .client
135 .converse()
136 .system(bedrock::SystemContentBlock::Text(system.to_string()))
137 .model_id(model_name.to_string())
138 .set_messages(Some(
139 messages
140 .iter()
141 .filter(|m| m.is_agent_visible())
142 .map(to_bedrock_message)
143 .collect::<Result<_>>()?,
144 ));
145
146 if !tools.is_empty() {
147 request = request.tool_config(to_bedrock_tool_config(tools)?);
148 }
149
150 let response = request
151 .send()
152 .await
153 .map_err(|err| match err.into_service_error() {
154 ConverseError::ThrottlingException(throttle_err) => {
155 ProviderError::RateLimitExceeded {
156 details: format!("Bedrock throttling error: {:?}", throttle_err),
157 retry_delay: None,
158 }
159 }
160 ConverseError::AccessDeniedException(err) => {
161 ProviderError::Authentication(format!("Failed to call Bedrock: {:?}", err))
162 }
163 ConverseError::ValidationException(err)
164 if err
165 .message()
166 .unwrap_or_default()
167 .contains("Input is too long for requested model.") =>
168 {
169 ProviderError::ContextLengthExceeded(format!(
170 "Failed to call Bedrock: {:?}",
171 err
172 ))
173 }
174 ConverseError::ModelErrorException(err) => {
175 ProviderError::ExecutionError(format!("Failed to call Bedrock: {:?}", err))
176 }
177 err => ProviderError::ServerError(format!("Failed to call Bedrock: {:?}", err)),
178 })?;
179
180 match response.output {
181 Some(bedrock::ConverseOutput::Message(message)) => Ok((message, response.usage)),
182 _ => Err(ProviderError::RequestFailed(
183 "No output from Bedrock".to_string(),
184 )),
185 }
186 }
187}
188
189#[async_trait]
190impl Provider for BedrockProvider {
191 fn metadata() -> ProviderMetadata {
192 ProviderMetadata::new(
193 "aws_bedrock",
194 "Amazon Bedrock",
195 "Run models through Amazon Bedrock. Supports AWS SSO profiles - run 'aws sso login --profile <profile-name>' before using. Configure with AWS_PROFILE and AWS_REGION, or use environment variables/credentials.",
196 BEDROCK_DEFAULT_MODEL,
197 BEDROCK_KNOWN_MODELS.to_vec(),
198 BEDROCK_DOC_LINK,
199 vec![
200 ConfigKey::new("AWS_PROFILE", true, false, Some("default")),
201 ConfigKey::new("AWS_REGION", true, false, None),
202 ],
203 )
204 }
205
206 fn get_name(&self) -> &str {
207 &self.name
208 }
209
210 fn retry_config(&self) -> RetryConfig {
211 self.retry_config.clone()
212 }
213
214 fn get_model_config(&self) -> ModelConfig {
215 self.model.clone()
216 }
217
218 #[tracing::instrument(
219 skip(self, model_config, system, messages, tools),
220 fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
221 )]
222 async fn complete_with_model(
223 &self,
224 model_config: &ModelConfig,
225 system: &str,
226 messages: &[Message],
227 tools: &[Tool],
228 ) -> Result<(Message, ProviderUsage), ProviderError> {
229 let model_name = model_config.model_name.clone();
230
231 let (bedrock_message, bedrock_usage) = self
232 .with_retry(|| self.converse(system, messages, tools))
233 .await?;
234
235 let usage = bedrock_usage
236 .as_ref()
237 .map(from_bedrock_usage)
238 .unwrap_or_default();
239
240 let message = from_bedrock_message(&bedrock_message)?;
241
242 let debug_payload = serde_json::json!({
244 "system": system,
245 "messages": messages,
246 "tools": tools
247 });
248 let mut log = RequestLog::start(&self.model, &debug_payload)?;
249 log.write(
250 &serde_json::to_value(&message).unwrap_or_default(),
251 Some(&usage),
252 )?;
253
254 let provider_usage = ProviderUsage::new(model_name.to_string(), usage);
255 Ok((message, provider_usage))
256 }
257}