opentelemetry_http/
lib.rs

1use async_trait::async_trait;
2use std::fmt::Debug;
3
4#[doc(no_inline)]
5pub use bytes::Bytes;
6#[doc(no_inline)]
7pub use http::{Request, Response};
8use opentelemetry::propagation::{Extractor, Injector};
9
10/// Helper for injecting headers into HTTP Requests. This is used for OpenTelemetry context
11/// propagation over HTTP.
12/// See [this](https://github.com/open-telemetry/opentelemetry-rust/blob/main/examples/tracing-http-propagator/README.md)
13/// for example usage.
14pub struct HeaderInjector<'a>(pub &'a mut http::HeaderMap);
15
16impl Injector for HeaderInjector<'_> {
17    /// Set a key and value in the HeaderMap.  Does nothing if the key or value are not valid inputs.
18    fn set(&mut self, key: &str, value: String) {
19        if let Ok(name) = http::header::HeaderName::from_bytes(key.as_bytes()) {
20            if let Ok(val) = http::header::HeaderValue::from_str(&value) {
21                self.0.insert(name, val);
22            }
23        }
24    }
25}
26
27/// Helper for extracting headers from HTTP Requests. This is used for OpenTelemetry context
28/// propagation over HTTP.
29/// See [this](https://github.com/open-telemetry/opentelemetry-rust/blob/main/examples/tracing-http-propagator/README.md)
30/// for example usage.
31pub struct HeaderExtractor<'a>(pub &'a http::HeaderMap);
32
33impl Extractor for HeaderExtractor<'_> {
34    /// Get a value for a key from the HeaderMap.  If the value is not valid ASCII, returns None.
35    fn get(&self, key: &str) -> Option<&str> {
36        self.0.get(key).and_then(|value| value.to_str().ok())
37    }
38
39    /// Collect all the keys from the HeaderMap.
40    fn keys(&self) -> Vec<&str> {
41        self.0
42            .keys()
43            .map(|value| value.as_str())
44            .collect::<Vec<_>>()
45    }
46
47    /// Get all the values for a key from the HeaderMap
48    fn get_all(&self, key: &str) -> Option<Vec<&str>> {
49        let all_iter = self.0.get_all(key).iter();
50        if let (0, Some(0)) = all_iter.size_hint() {
51            return None;
52        }
53
54        Some(all_iter.filter_map(|value| value.to_str().ok()).collect())
55    }
56}
57
/// Boxed, thread-safe error type returned by the HTTP helpers in this crate.
///
/// Any `std::error::Error + Send + Sync + 'static` value converts into it via `?`,
/// and string messages convert via `.into()`.
pub type HttpError = Box<dyn std::error::Error + Sync + Send + 'static>;
59
60/// A minimal interface necessary for sending requests over HTTP.
61/// Used primarily for exporting telemetry over HTTP. Also used for fetching
62/// sampling strategies for JaegerRemoteSampler
63///
64/// Users sometime choose HTTP clients that relay on a certain async runtime. This trait allows
65/// users to bring their choice of HTTP client.
66#[async_trait]
67pub trait HttpClient: Debug + Send + Sync {
68    /// Send the specified HTTP request with `Vec<u8>` payload
69    ///
70    /// Returns the HTTP response including the status code and body.
71    ///
72    /// Returns an error if it can't connect to the server or the request could not be completed,
73    /// e.g. because of a timeout, infinite redirects, or a loss of connection.
74    #[deprecated(note = "Use `send_bytes` with `Bytes` payload instead.")]
75    async fn send(&self, request: Request<Vec<u8>>) -> Result<Response<Bytes>, HttpError> {
76        self.send_bytes(request.map(Into::into)).await
77    }
78
79    /// Send the specified HTTP request with `Bytes` payload.
80    ///
81    /// Returns the HTTP response including the status code and body.
82    ///
83    /// Returns an error if it can't connect to the server or the request could not be completed,
84    /// e.g. because of a timeout, infinite redirects, or a loss of connection.
85    async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError>;
86}
87
#[cfg(feature = "reqwest")]
mod reqwest {
    // `HttpClient` implementations for the `reqwest` crate's async and blocking clients.

    use opentelemetry::otel_debug;

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};

    #[async_trait]
    impl HttpClient for reqwest::Client {
        /// Executes `request` on the async `reqwest` client.
        ///
        /// Error statuses are rejected by reqwest's `error_for_status` before the body is read;
        /// otherwise the status, headers and full body are copied into an `http::Response`.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestClient.Send");
            let mut upstream = self
                .execute(request.try_into()?)
                .await?
                .error_for_status()?;
            let status = upstream.status();
            let header_map = std::mem::take(upstream.headers_mut());
            let payload = upstream.bytes().await?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[cfg(feature = "reqwest-blocking")]
    #[async_trait]
    impl HttpClient for reqwest::blocking::Client {
        /// Executes `request` on the blocking `reqwest` client.
        ///
        /// The request is executed synchronously (nothing is awaited), so the thread polling
        /// this future is blocked until the whole response body has been read.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestBlockingClient.Send");
            let mut upstream = self.execute(request.try_into()?)?.error_for_status()?;
            let status = upstream.status();
            let header_map = std::mem::take(upstream.headers_mut());
            let payload = upstream.bytes()?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }
}
128
#[cfg(feature = "hyper")]
pub mod hyper {
    //! [`HttpClient`] implementation backed by `hyper` through `hyper-util`'s legacy client.

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use crate::ResponseExt;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};
    use hyper::body::{Body as HttpBody, Frame};
    use hyper_util::client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    };
    use opentelemetry::otel_debug;
    use std::fmt::Debug;
    use std::pin::Pin;
    use std::task::{self, Poll};
    use std::time::Duration;
    use tokio::time;

    /// An [`HttpClient`] that sends requests with a `hyper-util` legacy [`Client`].
    ///
    /// Each call to [`HttpClient::send_bytes`] is bounded by `timeout`. When configured,
    /// `authorization` is sent as the `Authorization` header of every request.
    #[derive(Debug, Clone)]
    pub struct HyperClient<C = HttpConnector>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        inner: Client<C, Body>,
        timeout: Duration,
        authorization: Option<HeaderValue>,
    }

    impl<C> HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        /// Creates a client that opens connections with `connector`.
        ///
        /// `timeout` bounds each request, covering both the response head and the response
        /// body. When `authorization` is `Some`, it replaces any `Authorization` header already
        /// present on outgoing requests. Connection tasks are spawned with
        /// `hyper_util::rt::TokioExecutor`.
        pub fn new(connector: C, timeout: Duration, authorization: Option<HeaderValue>) -> Self {
            // TODO - support custom executor
            let inner = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);
            Self {
                inner,
                timeout,
                authorization,
            }
        }
    }

    impl HyperClient<HttpConnector> {
        /// Creates a new `HyperClient` with a default `HttpConnector`.
        pub fn with_default_connector(
            timeout: Duration,
            authorization: Option<HeaderValue>,
        ) -> Self {
            Self::new(HttpConnector::new(), timeout, authorization)
        }
    }

    #[async_trait]
    impl<C> HttpClient for HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
        HyperClient<C>: Debug,
    {
        /// Sends `request`, reads the full response body and returns it.
        ///
        /// # Errors
        ///
        /// Fails on connection/transport errors, when the whole exchange (head and body)
        /// exceeds the configured timeout, or when the status is not in the 2xx range.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "HyperClient.Send");
            let (parts, body) = request.into_parts();
            let mut request = Request::from_parts(parts, Body(Full::from(body)));
            if let Some(ref authorization) = self.authorization {
                request
                    .headers_mut()
                    .insert(http::header::AUTHORIZATION, authorization.clone());
            }

            // `Client::request` resolves as soon as the response head arrives; the body is
            // streamed afterwards. Collecting the body inside the same deadline keeps a
            // stalled body from hanging the caller indefinitely.
            let exchange = async {
                let mut response = self.inner.request(request).await?;
                let headers = std::mem::take(response.headers_mut());
                let status = response.status();
                let body = response.into_body().collect().await?.to_bytes();

                let mut http_response = Response::builder().status(status).body(body)?;
                *http_response.headers_mut() = headers;
                Ok::<_, HttpError>(http_response)
            };
            let http_response = time::timeout(self.timeout, exchange).await??;

            Ok(http_response.error_for_status()?)
        }
    }

    /// Request body used by [`HyperClient`]: a single in-memory chunk of [`Bytes`].
    #[derive(Debug)]
    pub struct Body(Full<Bytes>);

    impl HttpBody for Body {
        type Data = Bytes;
        type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

        #[inline]
        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // `Full<Bytes>` is `Unpin`, hence so is `Body`: the inner body can be re-pinned
            // safely without an `unsafe` pin projection.
            Pin::new(&mut self.get_mut().0)
                .poll_frame(cx)
                .map_err(Into::into)
        }

        #[inline]
        fn is_end_stream(&self) -> bool {
            self.0.is_end_stream()
        }

        #[inline]
        fn size_hint(&self) -> hyper::body::SizeHint {
            self.0.size_hint()
        }
    }
}
235
236/// Methods to make working with responses from the [`HttpClient`] trait easier.
237pub trait ResponseExt: Sized {
238    /// Turn a response into an error if the HTTP status does not indicate success (200 - 299).
239    fn error_for_status(self) -> Result<Self, HttpError>;
240}
241
242impl<T> ResponseExt for Response<T> {
243    fn error_for_status(self) -> Result<Self, HttpError> {
244        if self.status().is_success() {
245            Ok(self)
246        } else {
247            Err(format!("request failed with status {}", self.status()).into())
248        }
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn http_headers_get() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName", "value".to_string());

        assert_eq!(
            HeaderExtractor(&carrier).get("HEADERNAME"),
            Some("value"),
            "case insensitive extraction"
        )
    }

    #[test]
    fn http_headers_get_all() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));
        carrier.append("headerName", HeaderValue::from_static("value2"));
        carrier.append("headerName", HeaderValue::from_static("value3"));

        assert_eq!(
            HeaderExtractor(&carrier).get_all("HEADERNAME"),
            Some(vec!["value", "value2", "value3"]),
            "all values from a key extraction"
        )
    }

    #[test]
    fn http_headers_get_all_missing_key() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));

        assert_eq!(
            HeaderExtractor(&carrier).get_all("not_existing"),
            None,
            "all values from a missing key extraction"
        )
    }

    #[test]
    fn http_headers_keys() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName1", "value1".to_string());
        HeaderInjector(&mut carrier).set("headerName2", "value2".to_string());

        let extractor = HeaderExtractor(&carrier);
        let got = extractor.keys();
        assert_eq!(got.len(), 2);
        assert!(got.contains(&"headername1"));
        assert!(got.contains(&"headername2"));
    }

    #[test]
    fn http_headers_set_replaces_existing_value() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName", "value1".to_string());
        HeaderInjector(&mut carrier).set("headerName", "value2".to_string());

        assert_eq!(
            HeaderExtractor(&carrier).get_all("headername"),
            Some(vec!["value2"]),
            "set overwrites previously injected values"
        )
    }

    #[test]
    fn http_headers_set_ignores_invalid_input() {
        let mut carrier = http::HeaderMap::new();
        // A space is not allowed in a header name.
        HeaderInjector(&mut carrier).set("invalid name", "value".to_string());
        // A newline is not allowed in a header value.
        HeaderInjector(&mut carrier).set("valid-name", "invalid\nvalue".to_string());

        assert!(
            carrier.is_empty(),
            "invalid header names and values are silently skipped"
        );
    }

    #[test]
    fn http_headers_non_ascii_values_are_skipped() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("mixed", HeaderValue::from_static("value"));
        carrier.append("mixed", HeaderValue::from_bytes(&[0xC3, 0xA9]).unwrap());
        carrier.append("opaque", HeaderValue::from_bytes(&[0xC3, 0xA9]).unwrap());

        let extractor = HeaderExtractor(&carrier);
        assert_eq!(
            extractor.get_all("mixed"),
            Some(vec!["value"]),
            "non-ASCII values are filtered out of get_all"
        );
        assert_eq!(
            extractor.get("opaque"),
            None,
            "get returns None for a non-ASCII value"
        );
    }
}
306}