otlp_sigv4_client/
lib.rs

//! A SigV4-compatible HTTP client wrapper for OpenTelemetry OTLP exporters.
//!
//! This crate provides a wrapper that adds AWS SigV4 signing capabilities to any OpenTelemetry
//! HTTP client implementation. It is particularly useful when sending telemetry data to AWS services
//! that require SigV4 authentication. This crate is part of the
//! [serverless-otlp-forwarder](https://github.com/dev7a/serverless-otlp-forwarder/) project, which provides
//! a comprehensive solution for OpenTelemetry telemetry collection in AWS Lambda environments.
//!
//! # Features
//!
//! - `reqwest` - Includes reqwest as a dependency (enabled by default)
//!
//! The client works with any HTTP client that implements the `opentelemetry_http::HttpClient` trait.
//!
//! # Example
//!
//! ```no_run
//! use aws_credential_types::Credentials;
//! use otlp_sigv4_client::SigV4ClientBuilder;
//! use opentelemetry_otlp::{HttpExporterBuilder, WithHttpConfig};
//! use reqwest::Client as ReqwestClient;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let credentials = Credentials::new(
//!         "access_key",
//!         "secret_key",
//!         None,
//!         None,
//!         "example",
//!     );
//!
//!     let sigv4_client = SigV4ClientBuilder::new()
//!         .with_client(ReqwestClient::new())
//!         .with_credentials(credentials)
//!         .with_region("us-west-2")
//!         .with_service("xray")
//!         .build()?;
//!
//!     let _exporter = HttpExporterBuilder::default()
//!         .with_http_client(sigv4_client)
//!         .build_span_exporter()?;
//!
//!     Ok(())
//! }
//! ```

48use async_trait::async_trait;
49use aws_credential_types::Credentials;
50use bytes::Bytes;
51use http::{HeaderMap, HeaderName, HeaderValue};
52use http::{Request, Response};
53use opentelemetry_http::{HttpClient, HttpError};
54use std::fmt::Debug;
55use thiserror::Error;
56
57/// Type alias for the request signing predicate
58pub type SigningPredicate = Box<dyn Fn(&Request<Bytes>) -> bool + Send + Sync>;
59
60mod builder;
61pub mod signing;
62
63pub use builder::SigV4ClientBuilder;
64pub use signing::sign_request;
65
66/// Errors that can occur during SigV4 client operations
67#[derive(Error, Debug)]
68pub enum SigV4Error {
69    #[error("AWS credentials not provided")]
70    MissingCredentials,
71
72    #[error("HTTP client not provided")]
73    MissingClient,
74
75    #[error("Failed to sign request: {0}")]
76    SigningError(#[from] Box<dyn std::error::Error + Send + Sync>),
77
78    #[error("HTTP error: {0}")]
79    HttpError(#[from] http::Error),
80}
81
82/// A decorator that adds SigV4 signing capabilities to any HttpClient implementation
83pub struct SigV4Client<T: HttpClient> {
84    inner: T,
85    credentials: Credentials,
86    region: String,
87    service: String,
88    should_sign_predicate: Option<SigningPredicate>,
89}
90
91impl<T: HttpClient + Debug> Debug for SigV4Client<T> {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("SigV4Client")
94            .field("inner", &self.inner)
95            .field("credentials", &self.credentials)
96            .field("region", &self.region)
97            .field("service", &self.service)
98            .field("should_sign_predicate", &format_args!("<function>"))
99            .finish()
100    }
101}
102
103impl<T: HttpClient> SigV4Client<T> {
104    /// Creates a new SigV4Client with the given parameters
105    pub(crate) fn new(
106        inner: T,
107        credentials: Credentials,
108        region: impl Into<String>,
109        service: impl Into<String>,
110        should_sign_predicate: Option<SigningPredicate>,
111    ) -> Self {
112        Self {
113            inner,
114            credentials,
115            region: region.into(),
116            service: service.into(),
117            should_sign_predicate,
118        }
119    }
120
121    async fn sign_request(&self, request: Request<Bytes>) -> Result<Request<Bytes>, SigV4Error> {
122        // Check if we should sign this request
123        if let Some(predicate) = &self.should_sign_predicate {
124            if !predicate(&request) {
125                return Ok(request);
126            }
127        }
128
129        let (parts, body) = request.into_parts();
130
131        let endpoint = format!(
132            "{}://{}{}",
133            parts.uri.scheme_str().unwrap_or("https"),
134            parts.uri.authority().map(|a| a.as_str()).unwrap_or(""),
135            parts.uri.path()
136        );
137
138        // Convert http::HeaderMap to reqwest::HeaderMap
139        let mut reqwest_headers = HeaderMap::new();
140        for (name, value) in parts.headers.iter() {
141            if let (Ok(header_name), Ok(header_value)) = (
142                HeaderName::from_bytes(name.as_ref()),
143                HeaderValue::from_bytes(value.as_bytes()),
144            ) {
145                reqwest_headers.insert(header_name, header_value);
146            }
147        }
148
149        let signed_headers = signing::sign_request(
150            &self.credentials,
151            &endpoint,
152            parts.method.as_str(),
153            &reqwest_headers,
154            body.as_ref(),
155            &self.region,
156            &self.service,
157        )
158        .map_err(SigV4Error::SigningError)?;
159
160        // Rebuild request with signed headers
161        let mut builder = Request::builder().method(parts.method).uri(parts.uri);
162
163        // Convert reqwest::HeaderMap to http::HeaderMap and preserve original headers
164        let mut http_headers = parts.headers;
165        for (name, value) in signed_headers.iter() {
166            if let Ok(header_name) = HeaderName::from_bytes(name.as_ref()) {
167                if let Ok(header_value) = HeaderValue::from_bytes(value.as_bytes()) {
168                    http_headers.insert(header_name, header_value);
169                }
170            }
171        }
172
173        // Set the headers on the builder
174        match builder.headers_mut() {
175            Some(headers_mut) => {
176                *headers_mut = http_headers;
177            }
178            None => {
179                // If we can't get the headers, we create a minimal HTTP error
180                // Note: http::Error doesn't have public constructors in recent versions
181                let err = http::Response::builder().status(500).body(()).unwrap_err();
182                return Err(SigV4Error::HttpError(err));
183            }
184        }
185
186        Ok(builder.body(body)?)
187    }
188}
189
190#[async_trait]
191impl<T: HttpClient> HttpClient for SigV4Client<T> {
192    async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
193        let signed_request = self
194            .sign_request(request)
195            .await
196            .map_err(|e| Box::new(e) as HttpError)?;
197
198        self.inner.send_bytes(signed_request).await
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use aws_credential_types::Credentials;
    use http::{HeaderMap, Request, Response, StatusCode};
    use std::sync::Mutex;

    /// Test double for the wrapped client. It answers every request with a canned
    /// response and records the headers of each request it receives, so tests can
    /// check what was actually handed to the inner client.
    #[derive(Debug)]
    struct MockHttpClient {
        status: StatusCode,
        body: Bytes,
        sent_headers: Mutex<Vec<HeaderMap>>,
    }

    impl MockHttpClient {
        fn new(response: Response<Bytes>) -> Self {
            let (parts, body) = response.into_parts();
            Self {
                status: parts.status,
                body,
                sent_headers: Mutex::new(Vec::new()),
            }
        }

        /// Headers of the most recent request passed to `send_bytes`, if any.
        fn last_sent_headers(&self) -> Option<HeaderMap> {
            self.sent_headers.lock().unwrap().last().cloned()
        }
    }

    #[async_trait]
    impl HttpClient for MockHttpClient {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            // The guard is a temporary dropped at the end of this statement, so it
            // is never held across an await point.
            self.sent_headers
                .lock()
                .unwrap()
                .push(request.headers().clone());
            Ok(Response::builder()
                .status(self.status)
                .body(self.body.clone())
                .unwrap())
        }
    }

    fn ok_response() -> Response<Bytes> {
        Response::builder()
            .status(StatusCode::OK)
            .body(Bytes::from("test response"))
            .unwrap()
    }

    fn test_credentials() -> Credentials {
        Credentials::new("test_key", "test_secret", None, None, "test")
    }

    #[tokio::test]
    async fn test_sigv4_client_signs_request() {
        let sigv4_client = SigV4Client::new(
            MockHttpClient::new(ok_response()),
            test_credentials(),
            "us-east-1",
            "xray",
            None,
        );

        let request = Request::builder()
            .method("POST")
            .uri("https://xray.us-east-1.amazonaws.com/")
            .body(Bytes::new())
            .unwrap();

        let response = sigv4_client.send_bytes(request).await.unwrap();

        // The inner client's response is passed through untouched...
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.body(), &Bytes::from("test response"));

        // ...and the request the inner client received carried the SigV4 headers.
        let sent = sigv4_client
            .inner
            .last_sent_headers()
            .expect("inner client was not called");
        assert!(sent.contains_key("authorization"));
        assert!(sent.contains_key("x-amz-date"));
    }

    #[tokio::test]
    async fn test_sigv4_client_preserves_headers() {
        let sigv4_client = SigV4Client::new(
            MockHttpClient::new(ok_response()),
            test_credentials(),
            "us-east-1",
            "xray",
            None,
        );

        let request = Request::builder()
            .method("POST")
            .uri("https://xray.us-east-1.amazonaws.com/")
            .header("X-Custom-Header", "test-value")
            .body(Bytes::new())
            .unwrap();

        let signed_request = sigv4_client.sign_request(request).await.unwrap();

        // The original header is preserved...
        assert!(signed_request.headers().contains_key("X-Custom-Header"));
        assert_eq!(
            signed_request.headers().get("X-Custom-Header").unwrap(),
            "test-value"
        );

        // ...and the AWS SigV4 headers are added.
        assert!(signed_request.headers().contains_key("x-amz-date"));
        assert!(signed_request.headers().contains_key("authorization"));
    }

    #[tokio::test]
    async fn test_sigv4_client_skips_requests_rejected_by_predicate() {
        let skip_all: SigningPredicate = Box::new(|_request: &Request<Bytes>| false);
        let sigv4_client = SigV4Client::new(
            MockHttpClient::new(ok_response()),
            test_credentials(),
            "us-east-1",
            "xray",
            Some(skip_all),
        );

        let request = Request::builder()
            .method("POST")
            .uri("https://xray.us-east-1.amazonaws.com/")
            .header("X-Custom-Header", "test-value")
            .body(Bytes::new())
            .unwrap();

        let result = sigv4_client.sign_request(request).await.unwrap();

        // The request is returned as-is: no signature headers are added.
        assert_eq!(
            result.headers().get("X-Custom-Header").unwrap(),
            "test-value"
        );
        assert!(!result.headers().contains_key("authorization"));
        assert!(!result.headers().contains_key("x-amz-date"));
    }
}