fregate/middleware/proxy_layer/
shared.rs1use crate::middleware::ProxyError;
2use axum::response::IntoResponse;
3use core::any::type_name;
4use hyper::body::{Bytes, HttpBody};
5use hyper::http::uri::PathAndQuery;
6use hyper::service::Service;
7use hyper::{Request, Response, Uri};
8use std::error::Error;
9use std::fmt::{Debug, Formatter};
10use std::future::Future;
11use std::marker::PhantomData;
12use std::pin::Pin;
13use std::str::FromStr;
14
15pub(crate) struct Shared<
16    TBody,
17    TRespBody,
18    ShouldProxyCallback,
19    OnProxyErrorCallback,
20    OnProxyRequestCallback,
21    OnProxyResponseCallback,
22    TExtension = (),
23> {
24    pub(crate) destination: Uri,
25    pub(crate) should_proxy: ShouldProxyCallback,
26    pub(crate) on_proxy_error: OnProxyErrorCallback,
27    pub(crate) on_proxy_request: OnProxyRequestCallback,
28    pub(crate) on_proxy_response: OnProxyResponseCallback,
29    pub(crate) phantom: PhantomData<(TExtension, TBody, TRespBody)>,
30}
31
32impl<
33        TBody,
34        TRespBody,
35        ShouldProxyCallback,
36        OnProxyErrorCallback,
37        OnProxyRequestCallback,
38        OnProxyResponseCallback,
39        TExtension,
40    > Debug
41    for Shared<
42        TBody,
43        TRespBody,
44        ShouldProxyCallback,
45        OnProxyErrorCallback,
46        OnProxyRequestCallback,
47        OnProxyResponseCallback,
48        TExtension,
49    >
50{
51    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
52        f.debug_struct("Shared")
53            .field("destination", &self.destination)
54            .field(
55                "on_proxy_error",
56                &format_args!("{}", type_name::<OnProxyErrorCallback>()),
57            )
58            .field(
59                "on_proxy_request",
60                &format_args!("{}", type_name::<OnProxyRequestCallback>()),
61            )
62            .field(
63                "on_proxy_response",
64                &format_args!("{}", type_name::<OnProxyResponseCallback>()),
65            )
66            .field("extension", &format_args!("{}", type_name::<TExtension>()))
67            .finish()
68    }
69}
70
71#[allow(clippy::type_complexity)]
72impl<
73        TBody,
74        TRespBody,
75        ShouldProxyCallback,
76        OnProxyErrorCallback,
77        OnProxyRequestCallback,
78        OnProxyResponseCallback,
79    >
80    Shared<
81        TBody,
82        TRespBody,
83        ShouldProxyCallback,
84        OnProxyErrorCallback,
85        OnProxyRequestCallback,
86        OnProxyResponseCallback,
87    >
88{
89    pub(crate) fn new_with_ext<TExtension>(
90        destination: impl Into<String>,
91        should_proxy: ShouldProxyCallback,
92        on_proxy_error: OnProxyErrorCallback,
93        on_proxy_request: OnProxyRequestCallback,
94        on_proxy_response: OnProxyResponseCallback,
95    ) -> Result<
96        Shared<
97            TBody,
98            TRespBody,
99            ShouldProxyCallback,
100            OnProxyErrorCallback,
101            OnProxyRequestCallback,
102            OnProxyResponseCallback,
103            TExtension,
104        >,
105        String,
106    > {
107        let destination = Uri::from_str(&destination.into()).map_err(|err| err.to_string())?;
108
109        let _ = destination
110            .scheme()
111            .ok_or("destination Uri has no scheme!".to_string())?;
112        let _ = destination
113            .authority()
114            .ok_or("destination Uri has no authority!".to_string())?;
115
116        let shared = Shared::<
117            TBody,
118            TRespBody,
119            ShouldProxyCallback,
120            OnProxyErrorCallback,
121            OnProxyRequestCallback,
122            OnProxyResponseCallback,
123            TExtension,
124        > {
125            destination,
126            should_proxy,
127            on_proxy_error,
128            on_proxy_request,
129            on_proxy_response,
130            phantom: PhantomData,
131        };
132
133        Ok(shared)
134    }
135}
136
137impl<
138        TBody,
139        TRespBody,
140        ShouldProxyCallback,
141        OnProxyErrorCallback,
142        OnProxyRequestCallback,
143        OnProxyResponseCallback,
144        TExtension,
145    >
146    Shared<
147        TBody,
148        TRespBody,
149        ShouldProxyCallback,
150        OnProxyErrorCallback,
151        OnProxyRequestCallback,
152        OnProxyResponseCallback,
153        TExtension,
154    >
155where
156    TExtension: Default + Clone + Send + Sync + 'static,
157    ShouldProxyCallback: for<'any> Fn(
158            &'any Request<TBody>,
159            &'any TExtension,
160        ) -> Pin<
161            Box<dyn Future<Output = Result<bool, axum::response::Response>> + Send + 'any>,
162        > + Send
163        + Sync
164        + 'static,
165    OnProxyErrorCallback:
166        Fn(ProxyError, &TExtension) -> axum::response::Response + Send + Sync + 'static,
167    OnProxyRequestCallback: Fn(&mut Request<TBody>, &TExtension) + Send + Sync + 'static,
168    OnProxyResponseCallback: Fn(&mut Response<TRespBody>, &TExtension) + Send + Sync + 'static,
169    TBody: Sync + Send + 'static,
170    TRespBody: HttpBody<Data = Bytes> + Sync + Send + 'static,
171    TRespBody::Error: Into<Box<(dyn Error + Send + Sync + 'static)>>,
172{
173    pub(crate) async fn proxy<TClient>(
174        &self,
175        mut req: Request<TBody>,
176        client: TClient,
177        extension: TExtension,
178        poll_error: Option<Box<(dyn Error + Send + Sync + 'static)>>,
179    ) -> axum::response::Response
180    where
181        TClient: Service<Request<TBody>, Response = Response<TRespBody>>,
182        TClient: Clone + Send + Sync + 'static,
183        <TClient as Service<Request<TBody>>>::Future: Send + 'static,
184        <TClient as Service<Request<TBody>>>::Error:
185            Into<Box<(dyn Error + Send + Sync + 'static)>> + Send,
186    {
187        if let Some(err) = poll_error {
188            return (self.on_proxy_error)(ProxyError::SendRequest(err), &extension);
189        }
190
191        let build_uri = |req: &Request<TBody>| {
192            let p_and_q = req
193                .uri()
194                .path_and_query()
195                .map_or_else(|| req.uri().path(), PathAndQuery::as_str);
196
197            let destination_parts = self.destination.clone().into_parts();
198
199            #[allow(clippy::expect_used)]
200            let authority = destination_parts
202                .authority
203                .expect("Destination uri must have [Authority]");
204
205            #[allow(clippy::expect_used)]
206            let scheme = destination_parts
208                .scheme
209                .expect("Destination uri must have [Scheme]");
210
211            Uri::builder()
212                .authority(authority)
213                .scheme(scheme)
214                .path_and_query(p_and_q)
215                .build()
216                .map_err(ProxyError::UriBuilder)
217        };
218
219        match build_uri(&req) {
220            Ok(new_uri) => {
221                *req.uri_mut() = new_uri;
222
223                (self.on_proxy_request)(&mut req, &extension);
224                let result = send_request(client, req).await;
225
226                match result {
227                    Ok(mut response) => {
228                        (self.on_proxy_response)(&mut response, &extension);
229                        response.into_response()
230                    }
231                    Err(err) => (self.on_proxy_error)(err, &extension),
232                }
233            }
234            Err(err) => (self.on_proxy_error)(err, &extension),
235        }
236    }
237}
238
239pub(crate) fn get_extension<TBody, TExtension>(request: &Request<TBody>) -> TExtension
240where
241    TExtension: Default + Clone + Send + Sync + 'static,
242{
243    request
244        .extensions()
245        .get::<TExtension>()
246        .cloned()
247        .unwrap_or_default()
248}
249
250#[allow(clippy::needless_question_mark)]
251async fn send_request<TClient, TBody, TRespBody>(
252    mut service: TClient,
253    request: Request<TBody>,
254) -> Result<Response<TRespBody>, ProxyError>
255where
256    TClient: Service<Request<TBody>, Response = Response<TRespBody>>,
257    TClient: Clone + Send + Sync + 'static,
258    <TClient as Service<Request<TBody>>>::Future: Send + 'static,
259    <TClient as Service<Request<TBody>>>::Error:
260        Into<Box<(dyn Error + Send + Sync + 'static)>> + Send,
261{
262    Ok(service
263        .call(request)
264        .await
265        .map_err(|err| ProxyError::SendRequest(err.into()))?)
266}