1use std::{ops::RangeBounds, sync::OnceLock};
16
17use http::{header::HeaderName, Request, Response, StatusCode};
18use tower::Service;
19use tower_http::cors::CorsLayer;
20
21use crate::layers::{
22 body_to_bytes_response::BodyToBytesResponse, bytes_to_body_request::BytesToBodyRequest,
23 catch_http_codes::CatchHttpCodes, form_urlencoded_request::FormUrlencodedRequest,
24 json_request::JsonRequest, json_response::JsonResponse,
25};
26
27static PROPAGATOR_HEADERS: OnceLock<Vec<HeaderName>> = OnceLock::new();
28
29pub fn set_propagator(propagator: &dyn opentelemetry::propagation::TextMapPropagator) {
36 let headers = propagator
37 .fields()
38 .map(|h| HeaderName::try_from(h).unwrap())
39 .collect();
40
41 tracing::debug!(
42 ?headers,
43 "Headers allowed in CORS requests for trace propagators set"
44 );
45 PROPAGATOR_HEADERS
46 .set(headers)
47 .expect(concat!(module_path!(), "::set_propagator was called twice"));
48}
49
50pub trait CorsLayerExt {
51 #[must_use]
52 fn allow_otel_headers<H>(self, headers: H) -> Self
53 where
54 H: IntoIterator<Item = HeaderName>;
55}
56
57impl CorsLayerExt for CorsLayer {
58 fn allow_otel_headers<H>(self, headers: H) -> Self
59 where
60 H: IntoIterator<Item = HeaderName>,
61 {
62 let base = PROPAGATOR_HEADERS.get().cloned().unwrap_or_default();
63 let headers: Vec<_> = headers.into_iter().chain(base).collect();
64 self.allow_headers(headers)
65 }
66}
67
68pub trait ServiceExt<Body>: Sized {
69 fn request_bytes_to_body(self) -> BytesToBodyRequest<Self> {
70 BytesToBodyRequest::new(self)
71 }
72
73 fn response_body_to_bytes(self) -> BodyToBytesResponse<Self> {
77 BodyToBytesResponse::new(self)
78 }
79
80 fn json_response<T>(self) -> JsonResponse<Self, T> {
81 JsonResponse::new(self)
82 }
83
84 fn json_request<T>(self) -> JsonRequest<Self, T> {
85 JsonRequest::new(self)
86 }
87
88 fn form_urlencoded_request<T>(self) -> FormUrlencodedRequest<Self, T> {
89 FormUrlencodedRequest::new(self)
90 }
91
92 fn catch_http_code<M, ResBody, E>(
95 self,
96 status_code: StatusCode,
97 mapper: M,
98 ) -> CatchHttpCodes<Self, M>
99 where
100 M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
101 {
102 self.catch_http_codes(status_code..=status_code, mapper)
103 }
104
105 fn catch_http_codes<B, M, ResBody, E>(self, bounds: B, mapper: M) -> CatchHttpCodes<Self, M>
108 where
109 B: RangeBounds<StatusCode>,
110 M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
111 {
112 CatchHttpCodes::new(self, bounds, mapper)
113 }
114
115 fn catch_http_errors<M, ResBody, E>(self, mapper: M) -> CatchHttpCodes<Self, M>
118 where
119 M: Fn(Response<ResBody>) -> E + Send + Clone + 'static,
120 {
121 self.catch_http_codes(
122 StatusCode::from_u16(400).unwrap()..StatusCode::from_u16(600).unwrap(),
123 mapper,
124 )
125 }
126}
127
128impl<S, B> ServiceExt<B> for S where S: Service<Request<B>> {}