opentelemetry_http/
lib.rs1use async_trait::async_trait;
2use std::fmt::Debug;
3
4#[doc(no_inline)]
5pub use bytes::Bytes;
6#[doc(no_inline)]
7pub use http::{Request, Response};
8use opentelemetry::propagation::{Extractor, Injector};
9
10pub struct HeaderInjector<'a>(pub &'a mut http::HeaderMap);
15
16impl Injector for HeaderInjector<'_> {
17 fn set(&mut self, key: &str, value: String) {
19 if let Ok(name) = http::header::HeaderName::from_bytes(key.as_bytes()) {
20 if let Ok(val) = http::header::HeaderValue::from_str(&value) {
21 self.0.insert(name, val);
22 }
23 }
24 }
25}
26
27pub struct HeaderExtractor<'a>(pub &'a http::HeaderMap);
32
33impl Extractor for HeaderExtractor<'_> {
34 fn get(&self, key: &str) -> Option<&str> {
36 self.0.get(key).and_then(|value| value.to_str().ok())
37 }
38
39 fn keys(&self) -> Vec<&str> {
41 self.0
42 .keys()
43 .map(|value| value.as_str())
44 .collect::<Vec<_>>()
45 }
46
47 fn get_all(&self, key: &str) -> Option<Vec<&str>> {
49 let all_iter = self.0.get_all(key).iter();
50 if let (0, Some(0)) = all_iter.size_hint() {
51 return None;
52 }
53
54 Some(all_iter.filter_map(|value| value.to_str().ok()).collect())
55 }
56}
57
58pub type HttpError = Box<dyn std::error::Error + Send + Sync + 'static>;
59
60#[async_trait]
67pub trait HttpClient: Debug + Send + Sync {
68 #[deprecated(note = "Use `send_bytes` with `Bytes` payload instead.")]
75 async fn send(&self, request: Request<Vec<u8>>) -> Result<Response<Bytes>, HttpError> {
76 self.send_bytes(request.map(Into::into)).await
77 }
78
79 async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError>;
86}
87
#[cfg(feature = "reqwest")]
mod reqwest {
    //! [`HttpClient`] implementations for `reqwest`'s async and blocking clients.

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use opentelemetry::otel_debug;

    #[async_trait]
    impl HttpClient for reqwest::Client {
        /// Sends `request` through the async `reqwest` client and buffers the whole body.
        ///
        /// Error statuses are rejected by reqwest's `error_for_status` before the body is
        /// read. For successful replies, the status, headers and body are copied into an
        /// `http::Response`.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestClient.Send");
            let outgoing: reqwest::Request = request.try_into()?;
            let mut reply = self.execute(outgoing).await?.error_for_status()?;
            let status = reply.status();
            // Move the header map out instead of cloning it; `bytes()` consumes `reply`.
            let header_map = std::mem::take(reply.headers_mut());
            let payload = reply.bytes().await?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[cfg(feature = "reqwest-blocking")]
    #[async_trait]
    impl HttpClient for reqwest::blocking::Client {
        /// Sends `request` through the blocking `reqwest` client and buffers the whole body.
        ///
        /// NOTE(review): `reqwest::blocking` performs its I/O synchronously. Polling this
        /// future therefore blocks the current thread until the exchange finishes. Confirm
        /// it is not driven on an async-runtime worker thread.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestBlockingClient.Send");
            let outgoing: reqwest::blocking::Request = request.try_into()?;
            let mut reply = self.execute(outgoing)?.error_for_status()?;
            let status = reply.status();
            // Move the header map out instead of cloning it; `bytes()` consumes `reply`.
            let header_map = std::mem::take(reply.headers_mut());
            let payload = reply.bytes()?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }
}
128
#[cfg(feature = "hyper")]
pub mod hyper {
    //! [`HttpClient`] implementation built on `hyper-util`'s legacy client and the Tokio runtime.

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use crate::ResponseExt;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};
    use hyper::body::{Body as HttpBody, Frame};
    use hyper_util::client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    };
    use opentelemetry::otel_debug;
    use std::fmt::Debug;
    use std::pin::Pin;
    use std::task::{self, Poll};
    use std::time::Duration;
    use tokio::time;

    /// [`HttpClient`] backed by a `hyper-util` legacy [`Client`].
    ///
    /// Each [`HttpClient::send_bytes`] call is bounded by `timeout`. When `authorization` is
    /// set, it is attached to every request as the `Authorization` header.
    #[derive(Debug, Clone)]
    pub struct HyperClient<C = HttpConnector>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        inner: Client<C, Body>,
        timeout: Duration,
        authorization: Option<HeaderValue>,
    }

    impl<C> HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        /// Creates a client that opens connections through `connector`.
        ///
        /// * `timeout` - upper bound for a whole request/response exchange, including the
        ///   time spent reading the response body.
        /// * `authorization` - if `Some`, inserted as the `Authorization` header on every
        ///   request, replacing any value the caller already set.
        pub fn new(connector: C, timeout: Duration, authorization: Option<HeaderValue>) -> Self {
            let inner = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);
            Self {
                inner,
                timeout,
                authorization,
            }
        }
    }

    impl HyperClient<HttpConnector> {
        /// Creates a client that uses a default [`HttpConnector`].
        ///
        /// Equivalent to `HyperClient::new(HttpConnector::new(), timeout, authorization)`.
        pub fn with_default_connector(
            timeout: Duration,
            authorization: Option<HeaderValue>,
        ) -> Self {
            Self::new(HttpConnector::new(), timeout, authorization)
        }
    }

    #[async_trait]
    impl<C> HttpClient for HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
        HyperClient<C>: Debug,
    {
        /// Sends `request`, buffers the full response body, and rejects responses whose
        /// status is not a success.
        ///
        /// # Errors
        ///
        /// Returns an error in any of these cases:
        /// * the configured timeout elapses before the response body has been fully read;
        /// * the transport or the body stream fails;
        /// * the response status is not a success.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "HyperClient.Send");
            let (parts, body) = request.into_parts();
            let mut request = Request::from_parts(parts, Body(Full::from(body)));
            if let Some(ref authorization) = self.authorization {
                request
                    .headers_mut()
                    .insert(http::header::AUTHORIZATION, authorization.clone());
            }

            // The timeout wraps the body read as well as the request itself. Otherwise a peer
            // that sends response headers and then stalls mid-body could hang the caller.
            let exchange = async {
                let mut response = self.inner.request(request).await?;
                let headers = std::mem::take(response.headers_mut());
                let status = response.status();
                let body = response.into_body().collect().await?.to_bytes();

                let mut http_response = Response::builder().status(status).body(body)?;
                *http_response.headers_mut() = headers;
                Ok::<_, HttpError>(http_response)
            };
            let http_response = time::timeout(self.timeout, exchange).await??;

            Ok(http_response.error_for_status()?)
        }
    }

    /// Request body used by [`HyperClient`]: a single in-memory chunk of [`Bytes`].
    pub struct Body(Full<Bytes>);

    impl HttpBody for Body {
        type Data = Bytes;
        type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

        #[inline]
        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // `Full<Bytes>` is `Unpin`, so the inner body can be re-pinned safely without an
            // `unsafe` pin projection.
            Pin::new(&mut self.get_mut().0)
                .poll_frame(cx)
                .map_err(Into::into)
        }

        #[inline]
        fn is_end_stream(&self) -> bool {
            self.0.is_end_stream()
        }

        #[inline]
        fn size_hint(&self) -> hyper::body::SizeHint {
            self.0.size_hint()
        }
    }
}
235
236pub trait ResponseExt: Sized {
238 fn error_for_status(self) -> Result<Self, HttpError>;
240}
241
242impl<T> ResponseExt for Response<T> {
243 fn error_for_status(self) -> Result<Self, HttpError> {
244 if self.status().is_success() {
245 Ok(self)
246 } else {
247 Err(format!("request failed with status {}", self.status()).into())
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn http_headers_get() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName", "value".to_string());

        assert_eq!(
            HeaderExtractor(&carrier).get("HEADERNAME"),
            Some("value"),
            "case insensitive extraction"
        );
    }

    #[test]
    fn http_headers_get_missing_key() {
        let carrier = http::HeaderMap::new();

        assert_eq!(
            HeaderExtractor(&carrier).get("not_existing"),
            None,
            "missing key extraction"
        );
    }

    #[test]
    fn http_headers_get_all() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));
        carrier.append("headerName", HeaderValue::from_static("value2"));
        carrier.append("headerName", HeaderValue::from_static("value3"));

        assert_eq!(
            HeaderExtractor(&carrier).get_all("HEADERNAME"),
            Some(vec!["value", "value2", "value3"]),
            "all values from a key extraction"
        );
    }

    #[test]
    fn http_headers_get_all_missing_key() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));

        assert_eq!(
            HeaderExtractor(&carrier).get_all("not_existing"),
            None,
            "all values from a missing key extraction"
        );
    }

    #[test]
    fn http_headers_keys() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName1", "value1".to_string());
        HeaderInjector(&mut carrier).set("headerName2", "value2".to_string());

        let extractor = HeaderExtractor(&carrier);
        let got = extractor.keys();
        assert_eq!(got.len(), 2);
        assert!(got.contains(&"headername1"));
        assert!(got.contains(&"headername2"));
    }

    #[test]
    fn http_headers_set_replaces_existing_value() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName", "first".to_string());
        HeaderInjector(&mut carrier).set("headerName", "second".to_string());

        assert_eq!(
            HeaderExtractor(&carrier).get_all("headername"),
            Some(vec!["second"]),
            "set replaces rather than appends"
        );
    }

    #[test]
    fn http_headers_set_skips_invalid_input() {
        let mut carrier = http::HeaderMap::new();
        // A space is not allowed in a header name.
        HeaderInjector(&mut carrier).set("bad name", "value".to_string());
        // A newline is not allowed in a header value.
        HeaderInjector(&mut carrier).set("goodName", "bad\nvalue".to_string());

        assert!(carrier.is_empty(), "invalid names and values are dropped");
    }

    #[test]
    fn error_for_status_passes_success() {
        let response = Response::builder()
            .status(http::StatusCode::NO_CONTENT)
            .body(())
            .unwrap();

        assert!(response.error_for_status().is_ok());
    }

    #[test]
    fn error_for_status_rejects_failure() {
        let response = Response::builder()
            .status(http::StatusCode::SERVICE_UNAVAILABLE)
            .body(())
            .unwrap();

        let err = response.error_for_status().unwrap_err();
        assert_eq!(
            err.to_string(),
            "request failed with status 503 Service Unavailable"
        );
    }
}