1#![cfg(feature = "bedrock")]
37#![cfg_attr(docsrs, doc(cfg(feature = "bedrock")))]
38
39use std::str::FromStr;
40use std::time::SystemTime;
41
42use aws_credential_types::Credentials;
43use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
44use aws_sigv4::sign::v4::SigningParams;
45
46use crate::auth::{RequestSigner, SignerResult};
47
/// Static AWS credentials used to SigV4-sign Bedrock requests.
///
/// Construct them with [`BedrockCredentials::new`] (adding
/// [`BedrockCredentials::with_session_token`] for temporary credentials) or load
/// them with [`BedrockCredentials::from_env`]. The `Debug` output of this type
/// redacts every secret value.
#[derive(Clone)]
pub struct BedrockCredentials {
    /// AWS access key id; it appears in the `Credential=` part of the signature.
    access_key_id: String,
    /// Secret access key used to compute request signatures.
    secret_access_key: String,
    /// Session token for temporary credentials; `None` for long-lived keys.
    session_token: Option<String>,
}
58
59impl BedrockCredentials {
60 #[must_use]
63 pub fn new(access_key_id: impl Into<String>, secret_access_key: impl Into<String>) -> Self {
64 Self {
65 access_key_id: access_key_id.into(),
66 secret_access_key: secret_access_key.into(),
67 session_token: None,
68 }
69 }
70
71 #[must_use]
74 pub fn with_session_token(mut self, token: impl Into<String>) -> Self {
75 self.session_token = Some(token.into());
76 self
77 }
78
79 #[must_use]
84 pub fn from_env() -> Option<Self> {
85 let access = std::env::var("AWS_ACCESS_KEY_ID").ok()?;
86 let secret = std::env::var("AWS_SECRET_ACCESS_KEY").ok()?;
87 let mut creds = Self::new(access, secret);
88 if let Ok(token) = std::env::var("AWS_SESSION_TOKEN") {
89 creds = creds.with_session_token(token);
90 }
91 Some(creds)
92 }
93
94 fn to_aws(&self) -> Credentials {
95 Credentials::new(
96 self.access_key_id.clone(),
97 self.secret_access_key.clone(),
98 self.session_token.clone(),
99 None,
100 "claude-api-bedrock-signer",
101 )
102 }
103}
104
105impl std::fmt::Debug for BedrockCredentials {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 f.debug_struct("BedrockCredentials")
108 .field("access_key_id", &"<redacted>")
109 .field("secret_access_key", &"<redacted>")
110 .field(
111 "session_token",
112 &self.session_token.as_ref().map(|_| "<redacted>"),
113 )
114 .finish()
115 }
116}
117
118#[derive(Debug, Clone)]
123pub struct BedrockSigner {
124 credentials: BedrockCredentials,
125 region: String,
126 service: String,
130}
131
132impl BedrockSigner {
133 #[must_use]
135 pub fn new(credentials: BedrockCredentials, region: impl Into<String>) -> Self {
136 Self {
137 credentials,
138 region: region.into(),
139 service: "bedrock".into(),
140 }
141 }
142
143 #[must_use]
145 pub fn with_service(mut self, service: impl Into<String>) -> Self {
146 self.service = service.into();
147 self
148 }
149}
150
151impl RequestSigner for BedrockSigner {
152 fn sign(&self, request: &mut reqwest::Request) -> SignerResult {
153 let identity = self.credentials.to_aws().into();
154
155 let settings = SigningSettings::default();
156 let params: aws_sigv4::http_request::SigningParams = SigningParams::builder()
157 .identity(&identity)
158 .region(&self.region)
159 .name(&self.service)
160 .time(SystemTime::now())
161 .settings(settings)
162 .build()
163 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?
164 .into();
165
166 let header_strings: Vec<(String, String)> = request
170 .headers()
171 .iter()
172 .filter_map(|(name, value)| {
173 value
174 .to_str()
175 .ok()
176 .map(|v| (name.as_str().to_owned(), v.to_owned()))
177 })
178 .collect();
179 let headers_iter = header_strings.iter().map(|(k, v)| (k.as_str(), v.as_str()));
180
181 let body_bytes = request.body().and_then(|b| b.as_bytes()).unwrap_or(&[]);
182 let signable_body = SignableBody::Bytes(body_bytes);
183
184 let url = request.url().as_str().to_owned();
185 let signable =
186 SignableRequest::new(request.method().as_str(), &url, headers_iter, signable_body)
187 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
188
189 let signing_output = sign(signable, ¶ms)
190 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
191 let (instructions, _signature) = signing_output.into_parts();
192
193 for (name, value) in instructions.headers() {
194 let header_name = http::HeaderName::from_str(name)?;
195 let header_value = http::HeaderValue::from_str(value)?;
196 request.headers_mut().insert(header_name, header_value);
197 }
198 Ok(())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a buffered POST request to a Bedrock `invoke` endpoint.
    fn make_request() -> reqwest::Request {
        let client = reqwest::Client::new();
        client
            .post("https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20240620-v1:0/invoke")
            .body(r#"{"messages":[{"role":"user","content":"hi"}]}"#)
            .build()
            .unwrap()
    }

    /// Signer using AWS's documentation example credentials in `us-east-1`.
    fn fixed_signer() -> BedrockSigner {
        BedrockSigner::new(
            BedrockCredentials::new("AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"),
            "us-east-1",
        )
    }

    /// Returns the `Authorization` header of a signed request as a string.
    fn authorization(req: &reqwest::Request) -> &str {
        req.headers()
            .get("authorization")
            .expect("Authorization header set by signer")
            .to_str()
            .expect("Authorization is ASCII")
    }

    #[test]
    fn bedrock_signer_adds_authorization_header() {
        let signer = fixed_signer();
        let mut req = make_request();
        signer.sign(&mut req).expect("sign succeeds");

        let auth_str = authorization(&req);
        assert!(
            auth_str.starts_with("AWS4-HMAC-SHA256 "),
            "expected sigv4 algorithm prefix: {auth_str}"
        );
        assert!(
            auth_str.contains("Credential=AKIDEXAMPLE/"),
            "expected access key in credential scope: {auth_str}"
        );
        assert!(
            auth_str.contains("/us-east-1/bedrock/aws4_request"),
            "expected region+service in credential scope: {auth_str}"
        );
        assert!(
            auth_str.contains("SignedHeaders="),
            "expected SignedHeaders component: {auth_str}"
        );
        assert!(
            auth_str.contains("Signature="),
            "expected Signature component: {auth_str}"
        );
    }

    #[test]
    fn bedrock_signer_adds_x_amz_date_header() {
        let signer = fixed_signer();
        let mut req = make_request();
        signer.sign(&mut req).unwrap();
        let date = req
            .headers()
            .get("x-amz-date")
            .expect("X-Amz-Date header set by signer");
        let s = date.to_str().unwrap();
        // ISO 8601 basic format: YYYYMMDD 'T' HHMMSS 'Z'.
        assert_eq!(s.len(), 16, "date should be 16-char ISO 8601 basic: {s}");
        assert!(
            s[..8].bytes().all(|b| b.is_ascii_digit()),
            "date part should be digits: {s}"
        );
        assert_eq!(&s[8..9], "T", "date and time should be separated by 'T': {s}");
        assert!(
            s[9..15].bytes().all(|b| b.is_ascii_digit()),
            "time part should be digits: {s}"
        );
        assert!(s.ends_with('Z'), "date should be UTC: {s}");
    }

    #[test]
    fn bedrock_signer_includes_session_token_when_present() {
        let creds =
            BedrockCredentials::new("AKID", "SECRET").with_session_token("session-token-value");
        let signer = BedrockSigner::new(creds, "us-west-2");
        let mut req = make_request();
        signer.sign(&mut req).unwrap();
        let token = req
            .headers()
            .get("x-amz-security-token")
            .expect("X-Amz-Security-Token forwarded by signer");
        assert_eq!(token.to_str().unwrap(), "session-token-value");
    }

    #[test]
    fn bedrock_signer_omits_session_token_when_absent() {
        let mut req = make_request();
        fixed_signer().sign(&mut req).unwrap();
        assert!(
            req.headers().get("x-amz-security-token").is_none(),
            "no security token header expected for long-lived credentials"
        );
    }

    #[test]
    fn bedrock_signer_credential_scope_uses_overridden_service() {
        let signer = fixed_signer().with_service("bedrock-runtime");
        let mut req = make_request();
        signer.sign(&mut req).unwrap();
        let auth = authorization(&req);
        assert!(
            auth.contains("/us-east-1/bedrock-runtime/aws4_request"),
            "expected overridden service in credential scope: {auth}"
        );
    }

    #[test]
    fn bedrock_signer_resigning_replaces_signature_headers() {
        let signer = fixed_signer();
        let mut req = make_request();
        signer.sign(&mut req).expect("first sign succeeds");
        signer.sign(&mut req).expect("second sign succeeds");
        for name in ["authorization", "x-amz-date"] {
            assert_eq!(
                req.headers().get_all(name).iter().count(),
                1,
                "exactly one {name} header expected after re-signing"
            );
        }
    }

    #[test]
    fn bedrock_signer_signs_request_without_body() {
        let mut req = reqwest::Client::new()
            .get("https://bedrock.us-east-1.amazonaws.com/foundation-models")
            .build()
            .unwrap();
        fixed_signer()
            .sign(&mut req)
            .expect("sign succeeds without a body");
        assert!(authorization(&req).starts_with("AWS4-HMAC-SHA256 "));
    }

    #[test]
    fn bedrock_credentials_redact_secret_in_debug() {
        let creds =
            BedrockCredentials::new("AKID", "VERY-SECRET").with_session_token("ALSO-SECRET");
        let dbg = format!("{creds:?}");
        assert!(!dbg.contains("VERY-SECRET"), "{dbg}");
        assert!(!dbg.contains("ALSO-SECRET"), "{dbg}");
        assert!(dbg.contains("redacted"), "{dbg}");
    }

    #[test]
    fn bedrock_signer_debug_redacts_credentials() {
        let dbg = format!("{:?}", fixed_signer());
        assert!(!dbg.contains("AKIDEXAMPLE"), "{dbg}");
        assert!(!dbg.contains("wJalrXUtnFEMI"), "{dbg}");
        assert!(dbg.contains("us-east-1"), "{dbg}");
    }

    #[test]
    fn from_env_does_not_panic() {
        // Smoke test only: whether credentials are found depends on the
        // environment the tests run in. Setting or clearing variables here would
        // race with other tests, because the harness runs tests on parallel
        // threads by default.
        let _: Option<BedrockCredentials> = BedrockCredentials::from_env();
    }

    #[test]
    fn signer_default_service_name_is_bedrock() {
        let signer = fixed_signer();
        assert_eq!(signer.service, "bedrock");
    }

    #[test]
    fn signer_with_service_override() {
        let signer = fixed_signer().with_service("bedrock-runtime");
        assert_eq!(signer.service, "bedrock-runtime");
    }
}