1use async_trait::async_trait;
49use aws_credential_types::Credentials;
50use bytes::Bytes;
51use http::{HeaderMap, HeaderName, HeaderValue};
52use http::{Request, Response};
53use opentelemetry_http::{HttpClient, HttpError};
54use std::fmt::Debug;
55use thiserror::Error;
56
57pub type SigningPredicate = Box<dyn Fn(&Request<Bytes>) -> bool + Send + Sync>;
59
60mod builder;
61pub mod signing;
62
63pub use builder::SigV4ClientBuilder;
64pub use signing::sign_request;
65
66#[derive(Error, Debug)]
68pub enum SigV4Error {
69 #[error("AWS credentials not provided")]
70 MissingCredentials,
71
72 #[error("HTTP client not provided")]
73 MissingClient,
74
75 #[error("Failed to sign request: {0}")]
76 SigningError(#[from] Box<dyn std::error::Error + Send + Sync>),
77
78 #[error("HTTP error: {0}")]
79 HttpError(#[from] http::Error),
80}
81
82pub struct SigV4Client<T: HttpClient> {
84 inner: T,
85 credentials: Credentials,
86 region: String,
87 service: String,
88 should_sign_predicate: Option<SigningPredicate>,
89}
90
91impl<T: HttpClient + Debug> Debug for SigV4Client<T> {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("SigV4Client")
94 .field("inner", &self.inner)
95 .field("credentials", &self.credentials)
96 .field("region", &self.region)
97 .field("service", &self.service)
98 .field("should_sign_predicate", &format_args!("<function>"))
99 .finish()
100 }
101}
102
103impl<T: HttpClient> SigV4Client<T> {
104 pub(crate) fn new(
106 inner: T,
107 credentials: Credentials,
108 region: impl Into<String>,
109 service: impl Into<String>,
110 should_sign_predicate: Option<SigningPredicate>,
111 ) -> Self {
112 Self {
113 inner,
114 credentials,
115 region: region.into(),
116 service: service.into(),
117 should_sign_predicate,
118 }
119 }
120
121 async fn sign_request(&self, request: Request<Bytes>) -> Result<Request<Bytes>, SigV4Error> {
122 if let Some(predicate) = &self.should_sign_predicate {
124 if !predicate(&request) {
125 return Ok(request);
126 }
127 }
128
129 let (parts, body) = request.into_parts();
130
131 let endpoint = format!(
132 "{}://{}{}",
133 parts.uri.scheme_str().unwrap_or("https"),
134 parts.uri.authority().map(|a| a.as_str()).unwrap_or(""),
135 parts.uri.path()
136 );
137
138 let mut reqwest_headers = HeaderMap::new();
140 for (name, value) in parts.headers.iter() {
141 if let (Ok(header_name), Ok(header_value)) = (
142 HeaderName::from_bytes(name.as_ref()),
143 HeaderValue::from_bytes(value.as_bytes()),
144 ) {
145 reqwest_headers.insert(header_name, header_value);
146 }
147 }
148
149 let signed_headers = signing::sign_request(
150 &self.credentials,
151 &endpoint,
152 parts.method.as_str(),
153 &reqwest_headers,
154 body.as_ref(),
155 &self.region,
156 &self.service,
157 )
158 .map_err(SigV4Error::SigningError)?;
159
160 let mut builder = Request::builder().method(parts.method).uri(parts.uri);
162
163 let mut http_headers = parts.headers;
165 for (name, value) in signed_headers.iter() {
166 if let Ok(header_name) = HeaderName::from_bytes(name.as_ref()) {
167 if let Ok(header_value) = HeaderValue::from_bytes(value.as_bytes()) {
168 http_headers.insert(header_name, header_value);
169 }
170 }
171 }
172
173 match builder.headers_mut() {
175 Some(headers_mut) => {
176 *headers_mut = http_headers;
177 }
178 None => {
179 let err = http::Response::builder().status(500).body(()).unwrap_err();
182 return Err(SigV4Error::HttpError(err));
183 }
184 }
185
186 Ok(builder.body(body)?)
187 }
188}
189
190#[async_trait]
191impl<T: HttpClient> HttpClient for SigV4Client<T> {
192 async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
193 let signed_request = self
194 .sign_request(request)
195 .await
196 .map_err(|e| Box::new(e) as HttpError)?;
197
198 self.inner.send_bytes(signed_request).await
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use aws_credential_types::Credentials;
    use http::{Request, Response, StatusCode};
    use std::sync::Arc;

    /// Inner-client stub that answers every request with a copy of the status
    /// and body of a canned response.
    #[derive(Debug)]
    struct MockHttpClient {
        response: Arc<Response<Bytes>>,
    }

    impl MockHttpClient {
        fn new(response: Response<Bytes>) -> Self {
            Self {
                response: Arc::new(response),
            }
        }
    }

    #[async_trait]
    impl HttpClient for MockHttpClient {
        async fn send_bytes(&self, _request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            let canned = &self.response;
            let reply = Response::builder()
                .status(canned.status())
                .body(canned.body().clone())
                .unwrap();
            Ok(reply)
        }
    }

    /// Canned `200 OK` response returned by the mock client.
    fn ok_response() -> Response<Bytes> {
        Response::builder()
            .status(StatusCode::OK)
            .body(Bytes::from("test response"))
            .unwrap()
    }

    /// Client under test: mock transport, dummy credentials, no predicate.
    fn xray_client() -> SigV4Client<MockHttpClient> {
        let transport = MockHttpClient::new(ok_response());
        let credentials = Credentials::new("test_key", "test_secret", None, None, "test");
        SigV4Client::new(transport, credentials, "us-east-1", "xray", None)
    }

    #[tokio::test]
    async fn test_sigv4_client_signs_request() {
        let client = xray_client();

        let outgoing = Request::builder()
            .method("POST")
            .uri("https://xray.us-east-1.amazonaws.com/")
            .body(Bytes::new())
            .unwrap();

        let reply = client.send_bytes(outgoing).await.unwrap();

        // The mock's response must come back through the wrapper unchanged.
        assert_eq!(reply.status(), StatusCode::OK);
        assert_eq!(reply.body(), &Bytes::from("test response"));
    }

    #[tokio::test]
    async fn test_sigv4_client_preserves_headers() {
        let client = xray_client();

        let outgoing = Request::builder()
            .method("POST")
            .uri("https://xray.us-east-1.amazonaws.com/")
            .header("X-Custom-Header", "test-value")
            .body(Bytes::new())
            .unwrap();

        let signed = client.sign_request(outgoing).await.unwrap();
        let headers = signed.headers();

        // The caller's own header survives signing...
        assert!(headers.contains_key("X-Custom-Header"));
        assert_eq!(headers.get("X-Custom-Header").unwrap(), "test-value");

        // ...and the signature headers have been added.
        assert!(headers.contains_key("x-amz-date"));
        assert!(headers.contains_key("authorization"));
    }
}