claude_agent/client/adapter/
bedrock.rs1use std::sync::Arc;
8use std::time::SystemTime;
9
10use async_trait::async_trait;
11use aws_config::BehaviorVersion;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
14use aws_sigv4::sign::v4::SigningParams;
15use aws_smithy_runtime_api::client::identity::Identity;
16
17use super::base::RequestExecutor;
18use super::config::{BetaFeature, ProviderConfig};
19use super::request::{add_beta_features, build_messages_body};
20use super::token_cache::{AwsCredentialsCache, CachedAwsCredentials, new_aws_credentials_cache};
21use super::traits::ProviderAdapter;
22use crate::client::messages::CreateMessageRequest;
23use crate::types::ApiResponse;
24use crate::{Error, Result};
25
/// API version passed to `build_messages_body` for Bedrock requests; it is sent
/// in the request body as `anthropic_version`.
const ANTHROPIC_VERSION: &str = "bedrock-2023-05-31";
27
28#[derive(Debug)]
30pub struct BedrockAdapter {
31 config: ProviderConfig,
32 region: String,
33 small_model_region: Option<String>,
34 use_global_endpoint: bool,
35 enable_1m_context: bool,
36 auth: BedrockAuth,
37 credentials_cache: AwsCredentialsCache,
38}
39
40#[derive(Debug)]
41enum BedrockAuth {
42 SigV4(Arc<dyn ProvideCredentials>),
43 BearerToken(String),
44}
45
46impl BedrockAdapter {
47 pub async fn from_env(config: ProviderConfig) -> Result<Self> {
49 let bedrock_config = crate::config::BedrockConfig::from_env();
50 Self::from_config(config, bedrock_config).await
51 }
52
53 pub async fn from_config(
55 config: ProviderConfig,
56 bedrock: crate::config::BedrockConfig,
57 ) -> Result<Self> {
58 let region = bedrock.region.unwrap_or_else(|| "us-east-1".into());
59
60 let auth = if let Some(token) = bedrock.bearer_token {
61 BedrockAuth::BearerToken(token)
62 } else {
63 let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
64 let credentials = aws_config
65 .credentials_provider()
66 .ok_or_else(|| Error::auth("No AWS credentials found"))?;
67 BedrockAuth::SigV4(Arc::from(credentials))
68 };
69
70 Ok(Self {
71 config,
72 region,
73 small_model_region: bedrock.small_model_region,
74 use_global_endpoint: bedrock.use_global_endpoint,
75 enable_1m_context: bedrock.enable_1m_context,
76 auth,
77 credentials_cache: new_aws_credentials_cache(),
78 })
79 }
80
81 pub fn with_region(mut self, region: impl Into<String>) -> Self {
83 self.region = region.into();
84 self
85 }
86
87 pub fn with_small_model_region(mut self, region: impl Into<String>) -> Self {
89 self.small_model_region = Some(region.into());
90 self
91 }
92
93 pub fn with_global_endpoint(mut self, enable: bool) -> Self {
95 self.use_global_endpoint = enable;
96 self
97 }
98
99 pub fn with_1m_context(mut self, enable: bool) -> Self {
101 self.enable_1m_context = enable;
102 self
103 }
104
105 pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
107 self.auth = BedrockAuth::BearerToken(token.into());
108 self
109 }
110
111 fn region_for_model(&self, model: &str) -> &str {
113 if let Some(ref small_region) = self.small_model_region
114 && model.contains("haiku")
115 {
116 return small_region;
117 }
118 &self.region
119 }
120
121 fn build_invoke_url(&self, model: &str, stream: bool) -> String {
122 let region = self.region_for_model(model);
123 let endpoint = if stream {
124 "invoke-with-response-stream"
125 } else {
126 "invoke"
127 };
128 let encoded_model = urlencoding::encode(model);
129
130 format!(
131 "https://bedrock-runtime.{}.amazonaws.com/model/{}/{}",
132 region, encoded_model, endpoint
133 )
134 }
135
136 fn build_request_body(&self, request: &CreateMessageRequest) -> serde_json::Value {
138 let mut body = build_messages_body(
139 request,
140 Some(ANTHROPIC_VERSION),
141 self.config.thinking_budget,
142 );
143
144 if let Some(obj) = body.as_object_mut() {
146 obj.remove("model");
147 }
148
149 if self.enable_1m_context {
150 add_beta_features(&mut body, &[BetaFeature::Context1M.header_value()]);
151 }
152
153 body
154 }
155
156 async fn get_credentials(&self) -> Result<CachedAwsCredentials> {
158 let provider = match &self.auth {
159 BedrockAuth::SigV4(p) => p,
160 BedrockAuth::BearerToken(_) => {
161 return Err(Error::auth("Bearer token mode does not use credentials"));
162 }
163 };
164
165 {
166 let cache = self.credentials_cache.read().await;
167 if let Some(ref creds) = *cache
168 && !creds.is_expired()
169 {
170 return Ok(creds.clone());
171 }
172 }
173
174 let creds = provider
175 .provide_credentials()
176 .await
177 .map_err(|e| Error::auth(e.to_string()))?;
178
179 let cached = CachedAwsCredentials::new(
180 creds.access_key_id().to_string(),
181 creds.secret_access_key().to_string(),
182 creds.session_token().map(|s| s.to_string()),
183 creds.expiry(),
184 );
185
186 *self.credentials_cache.write().await = Some(cached.clone());
187 Ok(cached)
188 }
189
190 async fn get_auth_headers(
192 &self,
193 method: &str,
194 url: &str,
195 body: &[u8],
196 region: &str,
197 ) -> Result<Vec<(String, String)>> {
198 match &self.auth {
199 BedrockAuth::BearerToken(token) => {
200 Ok(vec![("Authorization".into(), format!("Bearer {}", token))])
201 }
202 BedrockAuth::SigV4(_) => self.sign_request(method, url, body, region).await,
203 }
204 }
205
206 async fn sign_request(
208 &self,
209 method: &str,
210 url: &str,
211 body: &[u8],
212 region: &str,
213 ) -> Result<Vec<(String, String)>> {
214 let creds = self.get_credentials().await?;
215
216 let aws_creds = aws_credential_types::Credentials::new(
217 &creds.access_key_id,
218 &creds.secret_access_key,
219 creds.session_token.clone(),
220 creds.expiry(),
221 "bedrock-adapter",
222 );
223
224 let identity = Identity::new(aws_creds, creds.expiry());
225
226 let signing_params = SigningParams::builder()
227 .identity(&identity)
228 .region(region)
229 .name("bedrock")
230 .time(SystemTime::now())
231 .settings(SigningSettings::default())
232 .build()
233 .map_err(|e| Error::auth(e.to_string()))?;
234
235 let signable_request = SignableRequest::new(
236 method,
237 url,
238 std::iter::empty::<(&str, &str)>(),
239 SignableBody::Bytes(body),
240 )
241 .map_err(|e| Error::auth(e.to_string()))?;
242
243 let (signing_instructions, _) = sign(signable_request, &signing_params.into())
244 .map_err(|e| Error::auth(e.to_string()))?
245 .into_parts();
246
247 Ok(signing_instructions
248 .headers()
249 .map(|(name, value)| (name.to_string(), value.to_string()))
250 .collect())
251 }
252
253 async fn execute_request(
254 &self,
255 http: &reqwest::Client,
256 url: &str,
257 body_bytes: Vec<u8>,
258 region: &str,
259 ) -> Result<reqwest::Response> {
260 let headers = self
261 .get_auth_headers("POST", url, &body_bytes, region)
262 .await?;
263 RequestExecutor::post_bytes(http, url, body_bytes, headers).await
264 }
265}
266
267#[async_trait]
268impl ProviderAdapter for BedrockAdapter {
269 fn config(&self) -> &ProviderConfig {
270 &self.config
271 }
272
273 fn name(&self) -> &'static str {
274 "bedrock"
275 }
276
277 async fn build_url(&self, model: &str, stream: bool) -> String {
278 self.build_invoke_url(model, stream)
279 }
280
281 async fn transform_request(&self, request: CreateMessageRequest) -> Result<serde_json::Value> {
282 Ok(self.build_request_body(&request))
283 }
284
285 fn transform_response(&self, response: serde_json::Value) -> Result<ApiResponse> {
286 serde_json::from_value(response).map_err(|e| Error::Parse(e.to_string()))
288 }
289
290 async fn send(
291 &self,
292 http: &reqwest::Client,
293 request: CreateMessageRequest,
294 ) -> Result<ApiResponse> {
295 let model = request.model.clone();
296 let region = self.region_for_model(&model);
297 let url = self.build_invoke_url(&model, false);
298 let body = self.build_request_body(&request);
299 let body_bytes = serde_json::to_vec(&body)?;
300
301 let response = self.execute_request(http, &url, body_bytes, region).await?;
302 let json: serde_json::Value = response.json().await?;
303 self.transform_response(json)
304 }
305
306 async fn send_stream(
307 &self,
308 http: &reqwest::Client,
309 mut request: CreateMessageRequest,
310 ) -> Result<reqwest::Response> {
311 request.stream = Some(true);
312 let model = request.model.clone();
313 let region = self.region_for_model(&model);
314 let url = self.build_invoke_url(&model, true);
315 let body = self.build_request_body(&request);
316 let body_bytes = serde_json::to_vec(&body)?;
317
318 self.execute_request(http, &url, body_bytes, region).await
319 }
320
321 async fn refresh_credentials(&self) -> Result<()> {
322 if matches!(self.auth, BedrockAuth::SigV4(_)) {
323 *self.credentials_cache.write().await = None;
324 self.get_credentials().await?;
325 }
326 Ok(())
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::adapter::ModelConfig;
    use serde_json::json;

    /// Global Sonnet model id; its `:` must be percent-encoded in URLs.
    const GLOBAL_SONNET: &str = "global.anthropic.claude-sonnet-4-5-20250929-v1:0";

    #[test]
    fn test_url_encoding() {
        let encoded = urlencoding::encode(GLOBAL_SONNET);
        for needle in ["%3A", "global.anthropic"] {
            assert!(encoded.contains(needle), "`{needle}` missing from {encoded}");
        }
    }

    #[test]
    fn test_invoke_url_format() {
        let url = format!(
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/{}/invoke",
            urlencoding::encode(GLOBAL_SONNET)
        );
        for needle in ["bedrock-runtime", "/model/", "/invoke", "%3A"] {
            assert!(url.contains(needle), "`{needle}` missing from {url}");
        }
    }

    #[test]
    fn test_stream_url_format() {
        let url = format!(
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/{}/invoke-with-response-stream",
            urlencoding::encode(GLOBAL_SONNET)
        );
        assert!(url.contains("/invoke-with-response-stream"));
    }

    #[test]
    fn test_model_config() {
        let config = ModelConfig::bedrock();
        let primary = &config.primary;
        for part in ["anthropic", "global"] {
            assert!(primary.contains(part), "`{part}` missing from primary model");
        }
    }

    #[test]
    fn test_request_body() {
        let body = json!({
            "anthropic_version": ANTHROPIC_VERSION,
            "max_tokens": 1024,
            "messages": [{"role": "user", "content": "Hello"}],
        });
        assert_eq!(body["anthropic_version"], "bedrock-2023-05-31");
        assert_eq!(body["max_tokens"], 1024);
    }

    #[test]
    fn test_beta_header() {
        let beta_value = BetaFeature::Context1M.header_value();
        let mut body = json!({
            "anthropic_version": ANTHROPIC_VERSION,
            "max_tokens": 1024,
            "messages": [],
        });
        // `body` is an object, so indexing by a new key inserts it.
        body["anthropic_beta"] = json!([beta_value]);
        assert_eq!(body["anthropic_beta"][0], beta_value);
    }
}