opentelemetry_http/
lib.rs1use async_trait::async_trait;
2use std::fmt::Debug;
3
4#[doc(no_inline)]
5pub use bytes::Bytes;
6#[doc(no_inline)]
7pub use http::{Request, Response};
8use opentelemetry::propagation::{Extractor, Injector};
9
10pub struct HeaderInjector<'a>(pub &'a mut http::HeaderMap);
15
16impl Injector for HeaderInjector<'_> {
17 fn set(&mut self, key: &str, value: String) {
19 if let Ok(name) = http::header::HeaderName::from_bytes(key.as_bytes()) {
20 if let Ok(val) = http::header::HeaderValue::from_str(&value) {
21 self.0.insert(name, val);
22 }
23 }
24 }
25}
26
27pub struct HeaderExtractor<'a>(pub &'a http::HeaderMap);
32
33impl Extractor for HeaderExtractor<'_> {
34 fn get(&self, key: &str) -> Option<&str> {
36 self.0.get(key).and_then(|value| value.to_str().ok())
37 }
38
39 fn keys(&self) -> Vec<&str> {
41 self.0
42 .keys()
43 .map(|value| value.as_str())
44 .collect::<Vec<_>>()
45 }
46}
47
/// Boxed, thread-safe error type returned by [`HttpClient`] implementations.
///
/// Any `std::error::Error + Send + Sync` can be converted into it with `?`/`.into()`.
/// (The trait object's lifetime defaults to `'static` inside a `Box`.)
pub type HttpError = Box<dyn std::error::Error + Sync + Send>;
49
50#[async_trait]
57pub trait HttpClient: Debug + Send + Sync {
58 #[deprecated(note = "Use `send_bytes` with `Bytes` payload instead.")]
65 async fn send(&self, request: Request<Vec<u8>>) -> Result<Response<Bytes>, HttpError> {
66 self.send_bytes(request.map(Into::into)).await
67 }
68
69 async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError>;
76}
77
#[cfg(feature = "reqwest")]
mod reqwest {
    use opentelemetry::otel_debug;

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};

    /// [`HttpClient`] for the asynchronous `reqwest` client.
    ///
    /// Non-success statuses are turned into errors via `reqwest`'s `error_for_status`
    /// before the body is read.
    #[async_trait]
    impl HttpClient for reqwest::Client {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestClient.Send");
            let outgoing = request.try_into()?;
            let mut incoming = self.execute(outgoing).await?.error_for_status()?;

            // Move the headers out instead of cloning them; `bytes()` consumes the
            // response, so status and headers must be captured first.
            let headers = std::mem::take(incoming.headers_mut());
            let status = incoming.status();
            let payload = incoming.bytes().await?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = headers;
            Ok(converted)
        }
    }

    /// [`HttpClient`] for the blocking `reqwest` client (not available on wasm32).
    ///
    /// NOTE(review): this performs blocking I/O inside an `async fn`, so it will block
    /// whatever executor polls it.
    #[cfg(not(target_arch = "wasm32"))]
    #[cfg(feature = "reqwest-blocking")]
    #[async_trait]
    impl HttpClient for reqwest::blocking::Client {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestBlockingClient.Send");
            let outgoing = request.try_into()?;
            let mut incoming = self.execute(outgoing)?.error_for_status()?;

            let headers = std::mem::take(incoming.headers_mut());
            let status = incoming.status();
            let payload = incoming.bytes()?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = headers;
            Ok(converted)
        }
    }
}
118
#[cfg(feature = "hyper")]
pub mod hyper {
    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use crate::ResponseExt;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};
    use hyper::body::{Body as HttpBody, Frame};
    use hyper_util::client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    };
    use opentelemetry::otel_debug;
    use std::fmt::Debug;
    use std::pin::Pin;
    use std::task::{self, Poll};
    use std::time::Duration;
    use tokio::time;

    /// [`HttpClient`] implementation backed by a `hyper_util` legacy [`Client`].
    ///
    /// Each request is bounded by `timeout`, and when `authorization` is set it is sent
    /// as the `Authorization` header of every request.
    #[derive(Debug, Clone)]
    pub struct HyperClient<C = HttpConnector>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        inner: Client<C, Body>,
        timeout: Duration,
        authorization: Option<HeaderValue>,
    }

    impl<C> HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        /// Builds a client over `connector`, driven by a Tokio executor.
        ///
        /// `timeout` bounds each complete request/response exchange; `authorization`,
        /// if provided, is attached to every request.
        pub fn new(connector: C, timeout: Duration, authorization: Option<HeaderValue>) -> Self {
            let inner = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);
            Self {
                inner,
                timeout,
                authorization,
            }
        }
    }

    impl HyperClient<HttpConnector> {
        /// Convenience constructor using a default [`HttpConnector`].
        pub fn with_default_connector(
            timeout: Duration,
            authorization: Option<HeaderValue>,
        ) -> Self {
            Self::new(HttpConnector::new(), timeout, authorization)
        }
    }

    #[async_trait]
    impl<C> HttpClient for HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
        HyperClient<C>: Debug,
    {
        /// Sends `request` and buffers the full response body.
        ///
        /// # Errors
        /// - transport/connection errors from hyper,
        /// - `tokio::time::error::Elapsed` if the whole exchange (headers *and* body)
        ///   does not finish within `self.timeout`,
        /// - body read or response-building errors,
        /// - a non-2xx status (via [`ResponseExt::error_for_status`]).
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "HyperClient.Send");
            let (parts, body) = request.into_parts();
            let mut request = Request::from_parts(parts, Body(Full::from(body)));
            if let Some(ref authorization) = self.authorization {
                request
                    .headers_mut()
                    .insert(http::header::AUTHORIZATION, authorization.clone());
            }

            // `Client::request` resolves once the response *headers* arrive. The body is
            // streamed afterwards, so the timeout must also cover collecting it;
            // otherwise a stalled body would hang this call indefinitely.
            let exchange = async {
                let mut response = self.inner.request(request).await?;
                let headers = std::mem::take(response.headers_mut());
                let status = response.status();
                let body = response.into_body().collect().await?.to_bytes();

                let mut http_response = Response::builder().status(status).body(body)?;
                *http_response.headers_mut() = headers;
                Ok::<_, HttpError>(http_response)
            };
            let http_response = time::timeout(self.timeout, exchange).await??;

            http_response.error_for_status()
        }
    }

    /// Request body type used by [`HyperClient`]: a single in-memory [`Bytes`] chunk.
    pub struct Body(Full<Bytes>);

    impl HttpBody for Body {
        type Data = Bytes;
        type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

        #[inline]
        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // `Full<Bytes>` is `Unpin`, hence so is `Body`: we can take a plain `&mut`
            // and re-pin the inner body safely instead of an `unsafe` pin projection.
            Pin::new(&mut self.get_mut().0)
                .poll_frame(cx)
                .map_err(Into::into)
        }

        #[inline]
        fn is_end_stream(&self) -> bool {
            self.0.is_end_stream()
        }

        #[inline]
        fn size_hint(&self) -> hyper::body::SizeHint {
            self.0.size_hint()
        }
    }
}
225
226pub trait ResponseExt: Sized {
228 fn error_for_status(self) -> Result<Self, HttpError>;
230}
231
232impl<T> ResponseExt for Response<T> {
233 fn error_for_status(self) -> Result<Self, HttpError> {
234 if self.status().is_success() {
235 Ok(self)
236 } else {
237 Err(format!("request failed with status {}", self.status()).into())
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Header names are normalized on insert, so lookups ignore case.
    #[test]
    fn http_headers_get() {
        let mut headers = http::HeaderMap::new();
        let mut injector = HeaderInjector(&mut headers);
        injector.set("headerName", String::from("value"));

        let extracted = HeaderExtractor(&headers).get("HEADERNAME");
        assert_eq!(extracted, Some("value"), "case insensitive extraction");
    }

    /// `keys` reports every injected header, lowercased.
    #[test]
    fn http_headers_keys() {
        let mut headers = http::HeaderMap::new();
        for (name, value) in [("headerName1", "value1"), ("headerName2", "value2")] {
            HeaderInjector(&mut headers).set(name, value.to_string());
        }

        let keys = HeaderExtractor(&headers).keys();
        assert_eq!(keys.len(), 2);
        for expected in ["headername1", "headername2"] {
            assert!(keys.contains(&expected));
        }
    }
}