1#![cfg(feature = "bedrock")]
20#![cfg_attr(docsrs, doc(cfg(feature = "bedrock")))]
21
22use std::str::FromStr;
23use std::time::SystemTime;
24
25use aws_credential_types::Credentials;
26use aws_sigv4::http_request::{SignableBody, SignableRequest, SigningSettings, sign};
27use aws_sigv4::sign::v4::SigningParams;
28
29use crate::auth::{RequestSigner, SignerResult};
30
/// AWS credentials used by [`BedrockSigner`] to sign requests.
///
/// Holds an access key id, a secret access key and an optional session token.
/// Only `Clone` is derived here; `Debug` is implemented by hand in this file so
/// that none of the stored values are printed.
#[derive(Clone)]
pub struct BedrockCredentials {
    /// AWS access key id.
    access_key_id: String,
    /// AWS secret access key.
    secret_access_key: String,
    /// Optional session token; `None` unless one was supplied via
    /// `with_session_token` or found by `from_env`.
    session_token: Option<String>,
}
41
42impl BedrockCredentials {
43 #[must_use]
46 pub fn new(access_key_id: impl Into<String>, secret_access_key: impl Into<String>) -> Self {
47 Self {
48 access_key_id: access_key_id.into(),
49 secret_access_key: secret_access_key.into(),
50 session_token: None,
51 }
52 }
53
54 #[must_use]
57 pub fn with_session_token(mut self, token: impl Into<String>) -> Self {
58 self.session_token = Some(token.into());
59 self
60 }
61
62 #[must_use]
67 pub fn from_env() -> Option<Self> {
68 let access = std::env::var("AWS_ACCESS_KEY_ID").ok()?;
69 let secret = std::env::var("AWS_SECRET_ACCESS_KEY").ok()?;
70 let mut creds = Self::new(access, secret);
71 if let Ok(token) = std::env::var("AWS_SESSION_TOKEN") {
72 creds = creds.with_session_token(token);
73 }
74 Some(creds)
75 }
76
77 fn to_aws(&self) -> Credentials {
78 Credentials::new(
79 self.access_key_id.clone(),
80 self.secret_access_key.clone(),
81 self.session_token.clone(),
82 None,
83 "claude-api-bedrock-signer",
84 )
85 }
86}
87
88impl std::fmt::Debug for BedrockCredentials {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("BedrockCredentials")
91 .field("access_key_id", &"<redacted>")
92 .field("secret_access_key", &"<redacted>")
93 .field(
94 "session_token",
95 &self.session_token.as_ref().map(|_| "<redacted>"),
96 )
97 .finish()
98 }
99}
100
/// Request signer that applies an AWS SigV4 signature using stored
/// credentials, a region and a service name.
///
/// Construct it with `BedrockSigner::new`; the service name can be changed
/// afterwards with `with_service`.
#[derive(Debug, Clone)]
pub struct BedrockSigner {
    /// Credentials the signature is computed with.
    credentials: BedrockCredentials,
    /// Region name passed to the signing parameters.
    region: String,
    /// Service name passed to the signing parameters; `new` sets it to `"bedrock"`.
    service: String,
}
114
115impl BedrockSigner {
116 #[must_use]
118 pub fn new(credentials: BedrockCredentials, region: impl Into<String>) -> Self {
119 Self {
120 credentials,
121 region: region.into(),
122 service: "bedrock".into(),
123 }
124 }
125
126 #[must_use]
128 pub fn with_service(mut self, service: impl Into<String>) -> Self {
129 self.service = service.into();
130 self
131 }
132}
133
134impl RequestSigner for BedrockSigner {
135 fn sign(&self, request: &mut reqwest::Request) -> SignerResult {
136 let identity = self.credentials.to_aws().into();
137
138 let settings = SigningSettings::default();
139 let params: aws_sigv4::http_request::SigningParams = SigningParams::builder()
140 .identity(&identity)
141 .region(&self.region)
142 .name(&self.service)
143 .time(SystemTime::now())
144 .settings(settings)
145 .build()
146 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?
147 .into();
148
149 let header_strings: Vec<(String, String)> = request
153 .headers()
154 .iter()
155 .filter_map(|(name, value)| {
156 value
157 .to_str()
158 .ok()
159 .map(|v| (name.as_str().to_owned(), v.to_owned()))
160 })
161 .collect();
162 let headers_iter = header_strings.iter().map(|(k, v)| (k.as_str(), v.as_str()));
163
164 let body_bytes = request.body().and_then(|b| b.as_bytes()).unwrap_or(&[]);
165 let signable_body = SignableBody::Bytes(body_bytes);
166
167 let url = request.url().as_str().to_owned();
168 let signable =
169 SignableRequest::new(request.method().as_str(), &url, headers_iter, signable_body)
170 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
171
172 let signing_output = sign(signable, ¶ms)
173 .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })?;
174 let (instructions, _signature) = signing_output.into_parts();
175
176 for (name, value) in instructions.headers() {
177 let header_name = http::HeaderName::from_str(name)?;
178 let header_value = http::HeaderValue::from_str(value)?;
179 request.headers_mut().insert(header_name, header_value);
180 }
181 Ok(())
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint every test request is addressed to.
    const INVOKE_URL: &str = "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-5-sonnet-20240620-v1:0/invoke";

    /// JSON payload carried by every test request.
    const REQUEST_BODY: &str = r#"{"messages":[{"role":"user","content":"hi"}]}"#;

    /// Builds an unsigned POST request with an in-memory JSON body.
    fn make_request() -> reqwest::Request {
        reqwest::Client::new()
            .post(INVOKE_URL)
            .body(REQUEST_BODY)
            .build()
            .unwrap()
    }

    /// Signer with fixed example credentials for `us-east-1`.
    fn fixed_signer() -> BedrockSigner {
        let credentials =
            BedrockCredentials::new("AKIDEXAMPLE", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY");
        BedrockSigner::new(credentials, "us-east-1")
    }

    /// Signs a fresh request and hands back both the request and the outcome,
    /// leaving it to each test to decide how to unwrap the result.
    fn sign_fresh_request(signer: &BedrockSigner) -> (reqwest::Request, SignerResult) {
        let mut request = make_request();
        let outcome = signer.sign(&mut request);
        (request, outcome)
    }

    #[test]
    fn bedrock_signer_adds_authorization_header() {
        let (req, outcome) = sign_fresh_request(&fixed_signer());
        outcome.expect("sign succeeds");

        let auth_str = req
            .headers()
            .get("authorization")
            .expect("Authorization header set by signer")
            .to_str()
            .expect("Authorization is ASCII");

        assert!(
            auth_str.starts_with("AWS4-HMAC-SHA256 "),
            "expected sigv4 algorithm prefix: {auth_str}"
        );
        assert!(
            auth_str.contains("Credential=AKIDEXAMPLE/"),
            "expected access key in credential scope: {auth_str}"
        );
        assert!(
            auth_str.contains("/us-east-1/bedrock/aws4_request"),
            "expected region+service in credential scope: {auth_str}"
        );
        assert!(
            auth_str.contains("SignedHeaders="),
            "expected SignedHeaders component: {auth_str}"
        );
        assert!(
            auth_str.contains("Signature="),
            "expected Signature component: {auth_str}"
        );
    }

    #[test]
    fn bedrock_signer_adds_x_amz_date_header() {
        let (req, outcome) = sign_fresh_request(&fixed_signer());
        outcome.unwrap();

        let s = req
            .headers()
            .get("x-amz-date")
            .expect("X-Amz-Date header set by signer")
            .to_str()
            .unwrap();
        assert_eq!(s.len(), 16, "date should be 16-char ISO 8601 basic: {s}");
        assert!(s.ends_with('Z'), "date should be UTC: {s}");
    }

    #[test]
    fn bedrock_signer_includes_session_token_when_present() {
        let signer = BedrockSigner::new(
            BedrockCredentials::new("AKID", "SECRET").with_session_token("session-token-value"),
            "us-west-2",
        );
        let (req, outcome) = sign_fresh_request(&signer);
        outcome.unwrap();

        let token = req
            .headers()
            .get("x-amz-security-token")
            .expect("X-Amz-Security-Token forwarded by signer");
        assert_eq!(token.to_str().unwrap(), "session-token-value");
    }

    #[test]
    fn bedrock_credentials_redact_secret_in_debug() {
        let creds =
            BedrockCredentials::new("AKID", "VERY-SECRET").with_session_token("ALSO-SECRET");
        let dbg = format!("{creds:?}");

        assert!(!dbg.contains("VERY-SECRET"), "{dbg}");
        assert!(!dbg.contains("ALSO-SECRET"), "{dbg}");
        assert!(dbg.contains("redacted"), "{dbg}");
    }

    #[test]
    fn from_env_returns_none_when_missing() {
        // Smoke test only: the result depends on the ambient environment, so
        // nothing is asserted beyond the call completing.
        let _ = BedrockCredentials::from_env();
    }

    #[test]
    fn signer_default_service_name_is_bedrock() {
        assert_eq!(fixed_signer().service, "bedrock");
    }

    #[test]
    fn signer_with_service_override() {
        let overridden = fixed_signer().with_service("bedrock-runtime");
        assert_eq!(overridden.service, "bedrock-runtime");
    }
}