claude_agent/client/adapter/
bedrock.rs

1//! AWS Bedrock adapter using InvokeModel API (Messages API compatible).
2//!
3//! Uses the official Anthropic Messages API format with SigV4 signing.
4//! Supports global and regional endpoints as documented at:
5//! <https://platform.claude.com/docs/en/build-with-claude/claude-on-amazon-bedrock>
6
7use std::sync::Arc;
8use std::time::SystemTime;
9
10use async_trait::async_trait;
11use aws_config::BehaviorVersion;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
14use aws_sigv4::sign::v4::SigningParams;
15use aws_smithy_runtime_api::client::identity::Identity;
16
17use super::base::RequestExecutor;
18use super::config::{BetaFeature, ProviderConfig};
19use super::request::{add_beta_features, build_messages_body};
20use super::token_cache::{AwsCredentialsCache, CachedAwsCredentials, new_aws_credentials_cache};
21use super::traits::ProviderAdapter;
22use crate::client::messages::CreateMessageRequest;
23use crate::types::ApiResponse;
24use crate::{Error, Result};
25
/// Value sent as `anthropic_version` in every Bedrock InvokeModel request body.
const ANTHROPIC_VERSION: &'static str = "bedrock-2023-05-31";
27
28/// Bedrock adapter using InvokeModel API with Messages API format.
29#[derive(Debug)]
30pub struct BedrockAdapter {
31    config: ProviderConfig,
32    region: String,
33    small_model_region: Option<String>,
34    use_global_endpoint: bool,
35    enable_1m_context: bool,
36    auth: BedrockAuth,
37    credentials_cache: AwsCredentialsCache,
38}
39
40#[derive(Debug)]
41enum BedrockAuth {
42    SigV4(Arc<dyn ProvideCredentials>),
43    BearerToken(String),
44}
45
46impl BedrockAdapter {
47    /// Create adapter from environment variables.
48    pub async fn from_env(config: ProviderConfig) -> Result<Self> {
49        let bedrock_config = crate::config::BedrockConfig::from_env();
50        Self::from_config(config, bedrock_config).await
51    }
52
53    /// Create adapter from explicit configuration.
54    pub async fn from_config(
55        config: ProviderConfig,
56        bedrock: crate::config::BedrockConfig,
57    ) -> Result<Self> {
58        let region = bedrock.region.unwrap_or_else(|| "us-east-1".into());
59
60        let auth = if let Some(token) = bedrock.bearer_token {
61            BedrockAuth::BearerToken(token)
62        } else {
63            let aws_config = aws_config::load_defaults(BehaviorVersion::latest()).await;
64            let credentials = aws_config
65                .credentials_provider()
66                .ok_or_else(|| Error::auth("No AWS credentials found"))?;
67            BedrockAuth::SigV4(Arc::from(credentials))
68        };
69
70        Ok(Self {
71            config,
72            region,
73            small_model_region: bedrock.small_model_region,
74            use_global_endpoint: bedrock.use_global_endpoint,
75            enable_1m_context: bedrock.enable_1m_context,
76            auth,
77            credentials_cache: new_aws_credentials_cache(),
78        })
79    }
80
81    /// Set the AWS region.
82    pub fn with_region(mut self, region: impl Into<String>) -> Self {
83        self.region = region.into();
84        self
85    }
86
87    /// Set a separate region for small/fast models (e.g., Haiku).
88    pub fn with_small_model_region(mut self, region: impl Into<String>) -> Self {
89        self.small_model_region = Some(region.into());
90        self
91    }
92
93    /// Enable or disable global endpoint (default: true).
94    pub fn with_global_endpoint(mut self, enable: bool) -> Self {
95        self.use_global_endpoint = enable;
96        self
97    }
98
99    /// Enable 1M context window beta feature.
100    pub fn with_1m_context(mut self, enable: bool) -> Self {
101        self.enable_1m_context = enable;
102        self
103    }
104
105    /// Set bearer token authentication.
106    pub fn with_bearer_token(mut self, token: impl Into<String>) -> Self {
107        self.auth = BedrockAuth::BearerToken(token.into());
108        self
109    }
110
111    /// Get the effective region for a given model.
112    fn region_for_model(&self, model: &str) -> &str {
113        if let Some(ref small_region) = self.small_model_region
114            && model.contains("haiku")
115        {
116            return small_region;
117        }
118        &self.region
119    }
120
121    fn build_invoke_url(&self, model: &str, stream: bool) -> String {
122        let region = self.region_for_model(model);
123        let endpoint = if stream {
124            "invoke-with-response-stream"
125        } else {
126            "invoke"
127        };
128        let encoded_model = urlencoding::encode(model);
129
130        format!(
131            "https://bedrock-runtime.{}.amazonaws.com/model/{}/{}",
132            region, encoded_model, endpoint
133        )
134    }
135
136    /// Build Messages API compatible request body.
137    fn build_request_body(&self, request: &CreateMessageRequest) -> serde_json::Value {
138        let mut body = build_messages_body(
139            request,
140            Some(ANTHROPIC_VERSION),
141            self.config.thinking_budget,
142        );
143
144        // Bedrock doesn't include model in body (it's in the URL)
145        if let Some(obj) = body.as_object_mut() {
146            obj.remove("model");
147        }
148
149        if self.enable_1m_context {
150            add_beta_features(&mut body, &[BetaFeature::Context1M.header_value()]);
151        }
152
153        body
154    }
155
156    /// Get cached or fresh AWS credentials.
157    async fn get_credentials(&self) -> Result<CachedAwsCredentials> {
158        let provider = match &self.auth {
159            BedrockAuth::SigV4(p) => p,
160            BedrockAuth::BearerToken(_) => {
161                return Err(Error::auth("Bearer token mode does not use credentials"));
162            }
163        };
164
165        {
166            let cache = self.credentials_cache.read().await;
167            if let Some(ref creds) = *cache
168                && !creds.is_expired()
169            {
170                return Ok(creds.clone());
171            }
172        }
173
174        let creds = provider
175            .provide_credentials()
176            .await
177            .map_err(|e| Error::auth(e.to_string()))?;
178
179        let cached = CachedAwsCredentials::new(
180            creds.access_key_id().to_string(),
181            creds.secret_access_key().to_string(),
182            creds.session_token().map(|s| s.to_string()),
183            creds.expiry(),
184        );
185
186        *self.credentials_cache.write().await = Some(cached.clone());
187        Ok(cached)
188    }
189
190    /// Get authorization headers for a request.
191    async fn get_auth_headers(
192        &self,
193        method: &str,
194        url: &str,
195        body: &[u8],
196        region: &str,
197    ) -> Result<Vec<(String, String)>> {
198        match &self.auth {
199            BedrockAuth::BearerToken(token) => {
200                Ok(vec![("Authorization".into(), format!("Bearer {}", token))])
201            }
202            BedrockAuth::SigV4(_) => self.sign_request(method, url, body, region).await,
203        }
204    }
205
206    /// Sign request with SigV4.
207    async fn sign_request(
208        &self,
209        method: &str,
210        url: &str,
211        body: &[u8],
212        region: &str,
213    ) -> Result<Vec<(String, String)>> {
214        let creds = self.get_credentials().await?;
215
216        let aws_creds = aws_credential_types::Credentials::new(
217            &creds.access_key_id,
218            &creds.secret_access_key,
219            creds.session_token.clone(),
220            creds.expiry(),
221            "bedrock-adapter",
222        );
223
224        let identity = Identity::new(aws_creds, creds.expiry());
225
226        let signing_params = SigningParams::builder()
227            .identity(&identity)
228            .region(region)
229            .name("bedrock")
230            .time(SystemTime::now())
231            .settings(SigningSettings::default())
232            .build()
233            .map_err(|e| Error::auth(e.to_string()))?;
234
235        let signable_request = SignableRequest::new(
236            method,
237            url,
238            std::iter::empty::<(&str, &str)>(),
239            SignableBody::Bytes(body),
240        )
241        .map_err(|e| Error::auth(e.to_string()))?;
242
243        let (signing_instructions, _) = sign(signable_request, &signing_params.into())
244            .map_err(|e| Error::auth(e.to_string()))?
245            .into_parts();
246
247        Ok(signing_instructions
248            .headers()
249            .map(|(name, value)| (name.to_string(), value.to_string()))
250            .collect())
251    }
252
253    async fn execute_request(
254        &self,
255        http: &reqwest::Client,
256        url: &str,
257        body_bytes: Vec<u8>,
258        region: &str,
259    ) -> Result<reqwest::Response> {
260        let headers = self
261            .get_auth_headers("POST", url, &body_bytes, region)
262            .await?;
263        RequestExecutor::post_bytes(http, url, body_bytes, headers).await
264    }
265}
266
267#[async_trait]
268impl ProviderAdapter for BedrockAdapter {
269    fn config(&self) -> &ProviderConfig {
270        &self.config
271    }
272
273    fn name(&self) -> &'static str {
274        "bedrock"
275    }
276
277    async fn build_url(&self, model: &str, stream: bool) -> String {
278        self.build_invoke_url(model, stream)
279    }
280
281    async fn transform_request(&self, request: CreateMessageRequest) -> Result<serde_json::Value> {
282        Ok(self.build_request_body(&request))
283    }
284
285    fn transform_response(&self, response: serde_json::Value) -> Result<ApiResponse> {
286        // InvokeModel returns Messages API format directly
287        serde_json::from_value(response).map_err(|e| Error::Parse(e.to_string()))
288    }
289
290    async fn send(
291        &self,
292        http: &reqwest::Client,
293        request: CreateMessageRequest,
294    ) -> Result<ApiResponse> {
295        let model = request.model.clone();
296        let region = self.region_for_model(&model);
297        let url = self.build_invoke_url(&model, false);
298        let body = self.build_request_body(&request);
299        let body_bytes = serde_json::to_vec(&body)?;
300
301        let response = self.execute_request(http, &url, body_bytes, region).await?;
302        let json: serde_json::Value = response.json().await?;
303        self.transform_response(json)
304    }
305
306    async fn send_stream(
307        &self,
308        http: &reqwest::Client,
309        mut request: CreateMessageRequest,
310    ) -> Result<reqwest::Response> {
311        request.stream = Some(true);
312        let model = request.model.clone();
313        let region = self.region_for_model(&model);
314        let url = self.build_invoke_url(&model, true);
315        let body = self.build_request_body(&request);
316        let body_bytes = serde_json::to_vec(&body)?;
317
318        self.execute_request(http, &url, body_bytes, region).await
319    }
320
321    async fn refresh_credentials(&self) -> Result<()> {
322        if matches!(self.auth, BedrockAuth::SigV4(_)) {
323            *self.credentials_cache.write().await = None;
324            self.get_credentials().await?;
325        }
326        Ok(())
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::adapter::ModelConfig;
    use serde_json::json;

    /// Model ID with a `global.` prefix and a `:` that must be percent-encoded.
    const GLOBAL_SONNET: &str = "global.anthropic.claude-sonnet-4-5-20250929-v1:0";

    #[test]
    fn test_url_encoding() {
        let encoded = urlencoding::encode(GLOBAL_SONNET);
        // The colon is escaped while the dotted prefix survives untouched.
        for fragment in ["%3A", "global.anthropic"] {
            assert!(encoded.contains(fragment));
        }
    }

    #[test]
    fn test_invoke_url_format() {
        let url = format!(
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/{}/invoke",
            urlencoding::encode(GLOBAL_SONNET)
        );
        for fragment in ["bedrock-runtime", "/model/", "/invoke", "%3A"] {
            assert!(url.contains(fragment));
        }
    }

    #[test]
    fn test_stream_url_format() {
        let url = format!(
            "https://bedrock-runtime.us-east-1.amazonaws.com/model/{}/invoke-with-response-stream",
            urlencoding::encode(GLOBAL_SONNET)
        );
        assert!(url.contains("/invoke-with-response-stream"));
    }

    #[test]
    fn test_model_config() {
        let bedrock_models = ModelConfig::bedrock();
        let primary = &bedrock_models.primary;
        assert!(primary.contains("anthropic"));
        assert!(primary.contains("global"));
    }

    #[test]
    fn test_request_body() {
        let payload = json!({
            "anthropic_version": ANTHROPIC_VERSION,
            "max_tokens": 1024,
            "messages": [{"role": "user", "content": "Hello"}],
        });
        assert_eq!(payload["max_tokens"], 1024);
        assert_eq!(payload["anthropic_version"], "bedrock-2023-05-31");
    }

    #[test]
    fn test_beta_header() {
        let context_beta = BetaFeature::Context1M.header_value();
        let mut payload = json!({
            "anthropic_version": ANTHROPIC_VERSION,
            "max_tokens": 1024,
            "messages": [],
        });
        // Indexing a JSON object by a missing key inserts that key.
        payload["anthropic_beta"] = json!([context_beta]);
        assert_eq!(payload["anthropic_beta"][0], context_beta);
    }
}