opentelemetry_http/
lib.rs

1use async_trait::async_trait;
2use std::fmt::Debug;
3
4#[doc(no_inline)]
5pub use bytes::Bytes;
6#[doc(no_inline)]
7pub use http::{Request, Response};
8use opentelemetry::propagation::{Extractor, Injector};
9
/// Helper for injecting headers into HTTP Requests. This is used for OpenTelemetry context
/// propagation over HTTP.
///
/// The single public field is a mutable borrow of the [`http::HeaderMap`] that receives the
/// injected headers.
///
/// See [this](https://github.com/open-telemetry/opentelemetry-rust/blob/main/examples/tracing-http-propagator/README.md)
/// for example usage.
pub struct HeaderInjector<'a>(pub &'a mut http::HeaderMap);
15
16impl Injector for HeaderInjector<'_> {
17    /// Set a key and value in the HeaderMap.  Does nothing if the key or value are not valid inputs.
18    fn set(&mut self, key: &str, value: String) {
19        if let Ok(name) = http::header::HeaderName::from_bytes(key.as_bytes()) {
20            if let Ok(val) = http::header::HeaderValue::from_str(&value) {
21                self.0.insert(name, val);
22            }
23        }
24    }
25}
26
/// Helper for extracting headers from HTTP Requests. This is used for OpenTelemetry context
/// propagation over HTTP.
///
/// The single public field is a shared borrow of the [`http::HeaderMap`] to read headers from.
///
/// See [this](https://github.com/open-telemetry/opentelemetry-rust/blob/main/examples/tracing-http-propagator/README.md)
/// for example usage.
pub struct HeaderExtractor<'a>(pub &'a http::HeaderMap);
32
33impl Extractor for HeaderExtractor<'_> {
34    /// Get a value for a key from the HeaderMap.  If the value is not valid ASCII, returns None.
35    fn get(&self, key: &str) -> Option<&str> {
36        self.0.get(key).and_then(|value| value.to_str().ok())
37    }
38
39    /// Collect all the keys from the HeaderMap.
40    fn keys(&self) -> Vec<&str> {
41        self.0
42            .keys()
43            .map(|value| value.as_str())
44            .collect::<Vec<_>>()
45    }
46}
47
/// Error type used by [`HttpClient`] and [`ResponseExt`]: a boxed, `Send + Sync + 'static`
/// [`std::error::Error`] trait object.
pub type HttpError = Box<dyn std::error::Error + Send + Sync + 'static>;
49
50/// A minimal interface necessary for sending requests over HTTP.
51/// Used primarily for exporting telemetry over HTTP. Also used for fetching
52/// sampling strategies for JaegerRemoteSampler
53///
54/// Users sometime choose HTTP clients that relay on a certain async runtime. This trait allows
55/// users to bring their choice of HTTP client.
56#[async_trait]
57pub trait HttpClient: Debug + Send + Sync {
58    /// Send the specified HTTP request with `Vec<u8>` payload
59    ///
60    /// Returns the HTTP response including the status code and body.
61    ///
62    /// Returns an error if it can't connect to the server or the request could not be completed,
63    /// e.g. because of a timeout, infinite redirects, or a loss of connection.
64    #[deprecated(note = "Use `send_bytes` with `Bytes` payload instead.")]
65    async fn send(&self, request: Request<Vec<u8>>) -> Result<Response<Bytes>, HttpError> {
66        self.send_bytes(request.map(Into::into)).await
67    }
68
69    /// Send the specified HTTP request with `Bytes` payload.
70    ///
71    /// Returns the HTTP response including the status code and body.
72    ///
73    /// Returns an error if it can't connect to the server or the request could not be completed,
74    /// e.g. because of a timeout, infinite redirects, or a loss of connection.
75    async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError>;
76}
77
#[cfg(feature = "reqwest")]
mod reqwest {
    use opentelemetry::otel_debug;

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};

    /// [`HttpClient`] backed by the asynchronous `reqwest::Client`.
    #[async_trait]
    impl HttpClient for reqwest::Client {
        /// Execute `request` and rebuild the reply as an `http::Response<Bytes>`, carrying over
        /// the status, headers and fully-read body.
        ///
        /// Conversion, transport and `error_for_status` failures are all returned as
        /// [`HttpError`].
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestClient.Send");
            let outgoing = request.try_into()?;
            let mut reply = self.execute(outgoing).await?.error_for_status()?;

            // Status and headers are captured first because reading the body consumes `reply`.
            let status = reply.status();
            let headers = std::mem::take(reply.headers_mut());
            let body = reply.bytes().await?;

            let mut http_response = Response::builder().status(status).body(body)?;
            *http_response.headers_mut() = headers;

            Ok(http_response)
        }
    }

    /// [`HttpClient`] backed by the blocking `reqwest::blocking::Client`.
    #[cfg(not(target_arch = "wasm32"))]
    #[cfg(feature = "reqwest-blocking")]
    #[async_trait]
    impl HttpClient for reqwest::blocking::Client {
        /// Execute `request` and rebuild the reply as an `http::Response<Bytes>`, carrying over
        /// the status, headers and fully-read body.
        ///
        /// The request is performed with the blocking client's synchronous calls; nothing in
        /// this body is awaited.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestBlockingClient.Send");
            let outgoing = request.try_into()?;
            let mut reply = self.execute(outgoing)?.error_for_status()?;

            // Status and headers are captured first because reading the body consumes `reply`.
            let status = reply.status();
            let headers = std::mem::take(reply.headers_mut());
            let body = reply.bytes()?;

            let mut http_response = Response::builder().status(status).body(body)?;
            *http_response.headers_mut() = headers;

            Ok(http_response)
        }
    }
}
118
#[cfg(feature = "hyper")]
pub mod hyper {
    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use crate::ResponseExt;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};
    use hyper::body::{Body as HttpBody, Frame};
    use hyper_util::client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    };
    use opentelemetry::otel_debug;
    use std::fmt::Debug;
    use std::pin::Pin;
    use std::task::{self, Poll};
    use std::time::Duration;
    use tokio::time;

    /// [`HttpClient`] implementation built on hyper-util's legacy `Client`.
    ///
    /// Holds the underlying client, a timeout applied when sending a request, and an optional
    /// `Authorization` header value that is set on every outgoing request.
    #[derive(Debug, Clone)]
    pub struct HyperClient<C = HttpConnector>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        inner: Client<C, Body>,
        timeout: Duration,
        authorization: Option<HeaderValue>,
    }

    impl<C> HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        /// Creates a new `HyperClient` that connects through `connector`.
        ///
        /// `timeout` bounds the wait for the response in `send_bytes`; `authorization`, when
        /// present, is inserted as the `Authorization` header of every request.
        pub fn new(connector: C, timeout: Duration, authorization: Option<HeaderValue>) -> Self {
            // TODO - support custom executor
            let inner = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);
            Self {
                inner,
                timeout,
                authorization,
            }
        }
    }

    impl HyperClient<HttpConnector> {
        /// Creates a new `HyperClient` with a default `HttpConnector`.
        pub fn with_default_connector(
            timeout: Duration,
            authorization: Option<HeaderValue>,
        ) -> Self {
            Self::new(HttpConnector::new(), timeout, authorization)
        }
    }

    #[async_trait]
    impl<C> HttpClient for HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
        HyperClient<C>: Debug,
    {
        /// Send `request` through the wrapped hyper client.
        ///
        /// Returns the response with its status, headers and fully collected body. Returns an
        /// error when the timeout elapses, the request or body collection fails, or the
        /// response status is not a success (see [`ResponseExt::error_for_status`]).
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "HyperClient.Send");
            let (parts, body) = request.into_parts();
            let mut request = Request::from_parts(parts, Body(Full::from(body)));
            if let Some(ref authorization) = self.authorization {
                request
                    .headers_mut()
                    .insert(http::header::AUTHORIZATION, authorization.clone());
            }
            // NOTE(review): the timeout wraps only `Client::request`; collecting the body
            // below is not covered by it.
            let mut response = time::timeout(self.timeout, self.inner.request(request)).await??;
            let headers = std::mem::take(response.headers_mut());

            let mut http_response = Response::builder()
                .status(response.status())
                .body(response.into_body().collect().await?.to_bytes())?;
            *http_response.headers_mut() = headers;

            http_response.error_for_status()
        }
    }

    /// Request body type used by [`HyperClient`]: a thin wrapper around `Full<Bytes>` whose
    /// error type is the boxed [`HttpError`].
    pub struct Body(Full<Bytes>);

    impl HttpBody for Body {
        type Data = Bytes;
        type Error = HttpError;

        /// Delegate to the wrapped `Full<Bytes>`, converting its error type into `Self::Error`.
        #[inline]
        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // `Body` only contains a `Full<Bytes>`, which is `Unpin`, so the inner body can be
            // re-pinned safely without an `unsafe` pin projection.
            let inner_body = Pin::new(&mut self.get_mut().0);
            inner_body.poll_frame(cx).map_err(Into::into)
        }

        /// Report end-of-stream as seen by the wrapped body.
        #[inline]
        fn is_end_stream(&self) -> bool {
            self.0.is_end_stream()
        }

        /// Report the size hint of the wrapped body.
        #[inline]
        fn size_hint(&self) -> hyper::body::SizeHint {
            self.0.size_hint()
        }
    }
}
225
/// Methods to make working with responses from the [`HttpClient`] trait easier.
pub trait ResponseExt: Sized {
    /// Turn a response into an error if the HTTP status does not indicate success (200 - 299).
    ///
    /// Returns the response unchanged on success, otherwise an [`HttpError`].
    fn error_for_status(self) -> Result<Self, HttpError>;
}
231
232impl<T> ResponseExt for Response<T> {
233    fn error_for_status(self) -> Result<Self, HttpError> {
234        if self.status().is_success() {
235            Ok(self)
236        } else {
237            Err(format!("request failed with status {}", self.status()).into())
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// A header injected with mixed-case name is found again with an upper-case lookup.
    #[test]
    fn http_headers_get() {
        let mut headers = http::HeaderMap::new();
        let mut injector = HeaderInjector(&mut headers);
        injector.set("headerName", "value".to_string());

        let extractor = HeaderExtractor(&headers);
        assert_eq!(
            extractor.get("HEADERNAME"),
            Some("value"),
            "case insensitive extraction"
        );
    }

    /// `keys` reports every injected header, with names in lower case.
    #[test]
    fn http_headers_keys() {
        let mut headers = http::HeaderMap::new();
        for (name, value) in [("headerName1", "value1"), ("headerName2", "value2")] {
            HeaderInjector(&mut headers).set(name, value.to_string());
        }

        let extractor = HeaderExtractor(&headers);
        let names = extractor.keys();
        assert_eq!(names.len(), 2);
        for expected in ["headername1", "headername2"] {
            assert!(names.contains(&expected));
        }
    }
}