1use async_trait::async_trait;
2use std::fmt::Debug;
3
4#[doc(no_inline)]
5pub use bytes::Bytes;
6#[doc(no_inline)]
7pub use http::{Request, Response};
8use opentelemetry::propagation::{Extractor, Injector};
9
10pub struct HeaderInjector<'a>(pub &'a mut http::HeaderMap);
15
16impl Injector for HeaderInjector<'_> {
17 fn set(&mut self, key: &str, value: String) {
19 if let Ok(name) = http::header::HeaderName::from_bytes(key.as_bytes()) {
20 if let Ok(val) = http::header::HeaderValue::from_str(&value) {
21 self.0.insert(name, val);
22 }
23 }
24 }
25
26 fn reserve(&mut self, additional: usize) {
28 self.0.reserve(additional);
29 }
30}
31
32pub struct HeaderExtractor<'a>(pub &'a http::HeaderMap);
37
38impl Extractor for HeaderExtractor<'_> {
39 fn get(&self, key: &str) -> Option<&str> {
41 self.0.get(key).and_then(|value| value.to_str().ok())
42 }
43
44 fn keys(&self) -> Vec<&str> {
46 self.0
47 .keys()
48 .map(|value| value.as_str())
49 .collect::<Vec<_>>()
50 }
51
52 fn get_all(&self, key: &str) -> Option<Vec<&str>> {
54 let all_iter = self.0.get_all(key).iter();
55 if let (0, Some(0)) = all_iter.size_hint() {
56 return None;
57 }
58
59 Some(all_iter.filter_map(|value| value.to_str().ok()).collect())
60 }
61}
62
63pub type HttpError = Box<dyn std::error::Error + Send + Sync + 'static>;
64
65#[async_trait]
72pub trait HttpClient: Debug + Send + Sync {
73 #[deprecated(note = "Use `send_bytes` with `Bytes` payload instead.")]
80 async fn send(&self, request: Request<Vec<u8>>) -> Result<Response<Bytes>, HttpError> {
81 self.send_bytes(request.map(Into::into)).await
82 }
83
84 async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError>;
91}
92
#[cfg(feature = "reqwest")]
mod reqwest {
    use opentelemetry::otel_debug;

    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};

    /// Sends requests through an asynchronous `reqwest` client.
    ///
    /// A non-success status is turned into an error by reqwest's
    /// `error_for_status` before the body is read. Otherwise the status,
    /// headers and complete body are copied into a [`Response`].
    #[async_trait]
    impl HttpClient for reqwest::Client {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestClient.Send");
            let outgoing = request.try_into()?;
            let mut incoming = self.execute(outgoing).await?.error_for_status()?;

            let status = incoming.status();
            let header_map = std::mem::take(incoming.headers_mut());
            let payload = incoming.bytes().await?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }

    /// Sends requests through a blocking `reqwest` client.
    ///
    /// `execute` and `bytes` are synchronous here, so polling this future
    /// blocks the current thread until the exchange completes.
    /// NOTE(review): presumably meant to be driven from a dedicated thread
    /// rather than an async executor — confirm with callers.
    #[cfg(not(target_arch = "wasm32"))]
    #[cfg(feature = "reqwest-blocking")]
    #[async_trait]
    impl HttpClient for reqwest::blocking::Client {
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "ReqwestBlockingClient.Send");
            let outgoing = request.try_into()?;
            let mut incoming = self.execute(outgoing)?.error_for_status()?;

            let status = incoming.status();
            let header_map = std::mem::take(incoming.headers_mut());
            let payload = incoming.bytes()?;

            let mut converted = Response::builder().status(status).body(payload)?;
            *converted.headers_mut() = header_map;
            Ok(converted)
        }
    }
}
133
#[cfg(feature = "hyper")]
pub mod hyper {
    use super::{async_trait, Bytes, HttpClient, HttpError, Request, Response};
    use crate::ResponseExt;
    use http::HeaderValue;
    use http_body_util::{BodyExt, Full};
    use hyper::body::{Body as HttpBody, Frame};
    use hyper_util::client::legacy::{
        connect::{Connect, HttpConnector},
        Client,
    };
    use opentelemetry::otel_debug;
    use std::fmt::Debug;
    use std::pin::Pin;
    use std::task::{self, Poll};
    use std::time::Duration;
    use tokio::time;

    /// [`HttpClient`] built on `hyper_util`'s legacy client.
    ///
    /// Each request is bounded by `timeout`. When `authorization` is set, it
    /// is written as the `Authorization` header of every outgoing request.
    #[derive(Debug, Clone)]
    pub struct HyperClient<C = HttpConnector>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        inner: Client<C, Body>,
        timeout: Duration,
        authorization: Option<HeaderValue>,
    }

    impl<C> HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
    {
        /// Creates a client that connects through `connector`, using
        /// `hyper_util::rt::TokioExecutor` as the client's executor.
        ///
        /// `timeout` bounds each request, from sending it through reading the
        /// full response body. `authorization`, if `Some`, replaces any
        /// `Authorization` header already present on outgoing requests.
        pub fn new(connector: C, timeout: Duration, authorization: Option<HeaderValue>) -> Self {
            let inner = Client::builder(hyper_util::rt::TokioExecutor::new()).build(connector);
            Self {
                inner,
                timeout,
                authorization,
            }
        }
    }

    impl HyperClient<HttpConnector> {
        /// Creates a client that uses a freshly constructed [`HttpConnector`].
        pub fn with_default_connector(
            timeout: Duration,
            authorization: Option<HeaderValue>,
        ) -> Self {
            Self::new(HttpConnector::new(), timeout, authorization)
        }
    }

    #[async_trait]
    impl<C> HttpClient for HyperClient<C>
    where
        C: Connect + Clone + Send + Sync + 'static,
        HyperClient<C>: Debug,
    {
        /// Sends `request` and returns the fully buffered response.
        ///
        /// # Errors
        ///
        /// Returns an error in each of these cases:
        /// - the transport fails;
        /// - the body cannot be read;
        /// - the whole exchange exceeds `timeout`;
        /// - the response status is not a success status.
        async fn send_bytes(&self, request: Request<Bytes>) -> Result<Response<Bytes>, HttpError> {
            otel_debug!(name: "HyperClient.Send");
            let (parts, body) = request.into_parts();
            let mut request = Request::from_parts(parts, Body(Full::from(body)));
            if let Some(ref authorization) = self.authorization {
                // `insert` replaces any caller-supplied Authorization header.
                request
                    .headers_mut()
                    .insert(http::header::AUTHORIZATION, authorization.clone());
            }

            // The timeout covers the whole exchange: sending the request,
            // receiving the head and collecting the body. Previously only the
            // head was covered, so a server that stalled mid-body could hang
            // the caller indefinitely.
            let exchange = async {
                let mut response = self.inner.request(request).await?;
                let status = response.status();
                let headers = std::mem::take(response.headers_mut());
                let payload = response.into_body().collect().await?.to_bytes();

                let mut http_response = Response::builder().status(status).body(payload)?;
                *http_response.headers_mut() = headers;
                Ok::<_, HttpError>(http_response)
            };
            let http_response = time::timeout(self.timeout, exchange).await??;

            http_response.error_for_status()
        }
    }

    /// Request body type used by [`HyperClient`]: a single in-memory chunk.
    pub struct Body(Full<Bytes>);

    impl HttpBody for Body {
        type Data = Bytes;
        type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

        #[inline]
        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut task::Context<'_>,
        ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
            // `Full<Bytes>` is `Unpin`, so `Body` is too. The inner body can
            // therefore be re-pinned safely, without an `unsafe` projection.
            Pin::new(&mut self.get_mut().0)
                .poll_frame(cx)
                .map_err(Into::into)
        }

        #[inline]
        fn is_end_stream(&self) -> bool {
            self.0.is_end_stream()
        }

        #[inline]
        fn size_hint(&self) -> hyper::body::SizeHint {
            self.0.size_hint()
        }
    }
}
240
241pub trait ResponseExt: Sized {
243 fn error_for_status(self) -> Result<Self, HttpError>;
245}
246
247impl<T> ResponseExt for Response<T> {
248 fn error_for_status(self) -> Result<Self, HttpError> {
249 if self.status().is_success() {
250 Ok(self)
251 } else {
252 Err(format!("request failed with status {}", self.status()).into())
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn http_headers_get() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName", "value".to_string());

        assert_eq!(
            HeaderExtractor(&carrier).get("HEADERNAME"),
            Some("value"),
            "case insensitive extraction"
        )
    }

    #[test]
    fn http_headers_get_all() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));
        carrier.append("headerName", HeaderValue::from_static("value2"));
        carrier.append("headerName", HeaderValue::from_static("value3"));

        assert_eq!(
            HeaderExtractor(&carrier).get_all("HEADERNAME"),
            Some(vec!["value", "value2", "value3"]),
            "all values from a key extraction"
        )
    }

    #[test]
    fn http_headers_get_all_missing_key() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));

        assert_eq!(
            HeaderExtractor(&carrier).get_all("not_existing"),
            None,
            "all values from a missing key extraction"
        )
    }

    #[test]
    fn http_headers_get_all_skips_non_ascii_values() {
        let mut carrier = http::HeaderMap::new();
        carrier.append("headerName", HeaderValue::from_static("value"));
        // 0xFF is accepted as a header value byte but cannot be viewed as &str.
        carrier.append(
            "headerName",
            HeaderValue::from_bytes(b"\xff").expect("0xFF is a valid header value byte"),
        );

        assert_eq!(
            HeaderExtractor(&carrier).get_all("headerName"),
            Some(vec!["value"]),
            "values that cannot be viewed as &str are skipped"
        )
    }

    #[test]
    fn http_headers_set_ignores_invalid_input() {
        let mut carrier = http::HeaderMap::new();
        // A space is not allowed in a header name.
        HeaderInjector(&mut carrier).set("bad name", "value".to_string());
        // A newline is not allowed in a header value.
        HeaderInjector(&mut carrier).set("good-name", "bad\nvalue".to_string());

        assert!(
            carrier.is_empty(),
            "invalid name/value pairs must be dropped, not inserted"
        );
    }

    #[test]
    fn http_headers_keys() {
        let mut carrier = http::HeaderMap::new();
        HeaderInjector(&mut carrier).set("headerName1", "value1".to_string());
        HeaderInjector(&mut carrier).set("headerName2", "value2".to_string());

        let extractor = HeaderExtractor(&carrier);
        let got = extractor.keys();
        assert_eq!(got.len(), 2);
        assert!(got.contains(&"headername1"));
        assert!(got.contains(&"headername2"));
    }

    #[test]
    fn http_headers_reserve() {
        let mut carrier = http::HeaderMap::new();

        {
            // Reserving before inserting must not interfere with insertion.
            let mut injector = HeaderInjector(&mut carrier);
            injector.reserve(10);

            injector.set("test-header", "test-value".to_string());
        }
        assert_eq!(
            HeaderExtractor(&carrier).get("test-header"),
            Some("test-value")
        );

        {
            // A zero reservation is accepted.
            let mut injector = HeaderInjector(&mut carrier);
            injector.reserve(0);
            injector.set("another-header", "another-value".to_string());
        }
        assert_eq!(
            HeaderExtractor(&carrier).get("another-header"),
            Some("another-value")
        );

        let mut new_carrier = http::HeaderMap::new();
        {
            let mut new_injector = HeaderInjector(&mut new_carrier);
            new_injector.reserve(5);
        }
        let initial_capacity = new_carrier.capacity();

        {
            // Inserting fewer entries than reserved keeps the capacity.
            let mut new_injector = HeaderInjector(&mut new_carrier);
            for i in 0..3 {
                new_injector.set(&format!("header-{}", i), format!("value-{}", i));
            }
        }

        assert!(new_carrier.capacity() >= initial_capacity);
        assert!(new_carrier.capacity() >= 5);
    }
}