Skip to main content

axum_proxy/
oneshot.rs

1use std::convert::Infallible;
2use std::task::{Context, Poll};
3
4use client::HttpConnector;
5#[cfg(feature = "__rustls")]
6use client::RustlsConnector;
7use http::uri::{Authority, Scheme};
8use http::{Error as HttpError, Request, Response};
9//use hyper::body::{Body, HttpBody};
10use hyper::body::{Body as HttpBody, Incoming};
11#[cfg(feature = "nativetls")]
12use hyper_tls::HttpsConnector as NativeTlsConnector;
13use hyper_util::client::legacy::Client;
14use hyper_util::client::legacy::connect::Connect;
15use tower_service::Service;
16
17use crate::future::RevProxyFuture;
18use crate::rewrite::PathRewriter;
19use crate::{ProxyError, client};
20
/// Boxed, thread-safe error type; request bodies handled by the service must have errors
/// convertible into it.
type BoxErr = Box<dyn std::error::Error + Sync + Send>;
22
23/// A [`Service<Request<B>>`] that sends a request and returns the response, owning a [`Client`].
24///
25/// ```
26/// # async fn run_test() {
27/// # use axum_proxy::OneshotService;
28/// # use axum_proxy::Static;
29/// # use tower_service::Service;
30/// # use http_body_util::Empty;
31/// # use http::Request;
32/// # use hyper::body::Bytes;
33/// let mut svc = OneshotService::http_default("example.com:1234", Static("bar")).unwrap();
34/// let req = Request::builder()
35///     .uri("https://myserver.com/foo")
36///     .body(Empty::<Bytes>::new())
37///     .unwrap();
38/// // http://example.com:1234/bar
39/// let _res = svc.call(req).await.unwrap();
40/// # }
41/// ```
42pub struct OneshotService<Pr, C = HttpConnector, B = Incoming> {
43    client: Client<C, B>,
44    scheme: Scheme,
45    authority: Authority,
46    path: Pr,
47}
48
49impl<Pr: Clone, C: Clone, B> Clone for OneshotService<Pr, C, B> {
50    #[inline]
51    fn clone(&self) -> Self {
52        Self {
53            client: self.client.clone(),
54            scheme: self.scheme.clone(),
55            authority: self.authority.clone(),
56            path: self.path.clone(),
57        }
58    }
59}
60
61impl<Pr, C, B> OneshotService<Pr, C, B> {
62    /// Initializes a service with a general [`Client`].
63    ///
64    /// A client can be built by functions in [`client`].
65    ///
66    /// For the meaning of "scheme" and "authority", refer to the documentation of
67    /// [`Uri`](http::uri::Uri).
68    ///
69    /// The `path` should implement [`PathRewriter`].
70    ///
71    /// # Errors
72    ///
73    /// When `scheme` or `authority` cannot be converted into a [`Scheme`] or [`Authority`].
74    pub fn from<S, A>(
75        client: Client<C, B>,
76        scheme: S,
77        authority: A,
78        path: Pr,
79    ) -> Result<Self, HttpError>
80    where
81        Scheme: TryFrom<S>,
82        <Scheme as TryFrom<S>>::Error: Into<HttpError>,
83        Authority: TryFrom<A>,
84        <Authority as TryFrom<A>>::Error: Into<HttpError>,
85    {
86        let scheme = scheme.try_into().map_err(Into::into)?;
87        let authority = authority.try_into().map_err(Into::into)?;
88        Ok(Self {
89            client,
90            scheme,
91            authority,
92            path,
93        })
94    }
95}
96
97impl<Pr, B> OneshotService<Pr, HttpConnector, B>
98where
99    B: HttpBody + Send,
100    B::Data: Send,
101{
102    /// Use [`client::http_default()`] to build a client.
103    ///
104    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
105    ///
106    /// The `path` should implement [`PathRewriter`].
107    ///
108    /// # Errors
109    ///
110    /// When `authority` cannot be converted into an [`Authority`].
111    pub fn http_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
112    where
113        Authority: TryFrom<A>,
114        <Authority as TryFrom<A>>::Error: Into<HttpError>,
115    {
116        let authority = authority.try_into().map_err(Into::into)?;
117        Ok(Self {
118            client: client::http_default(),
119            scheme: Scheme::HTTP,
120            authority,
121            path,
122        })
123    }
124}
125
#[cfg(feature = "nativetls")]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Use [`client::https_default()`] to build a client.
    ///
    /// This is the same as [`Self::nativetls_default()`]. The scheme is always `https`.
    ///
    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// When `authority` cannot be converted into an [`Authority`].
    // The enclosing impl is already gated on `nativetls`, so no per-method `#[cfg]` is needed;
    // the `doc(cfg)` below only tells rustdoc which features expose this method.
    #[cfg_attr(docsrs, doc(cfg(any(feature = "https", feature = "nativetls"))))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        let authority = authority.try_into().map_err(Into::into)?;
        Ok(Self {
            client: client::https_default(),
            scheme: Scheme::HTTPS,
            authority,
            path,
        })
    }

    /// Use [`client::nativetls_default()`] to build a client.
    ///
    /// The scheme is always `https`.
    ///
    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// When `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(feature = "nativetls")))]
    pub fn nativetls_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        let authority = authority.try_into().map_err(Into::into)?;
        Ok(Self {
            client: client::nativetls_default(),
            scheme: Scheme::HTTPS,
            authority,
            path,
        })
    }
}
183
#[cfg(feature = "__rustls")]
impl<Pr, B> OneshotService<Pr, RustlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Use [`client::rustls_default()`] to build a client.
    ///
    /// The scheme is always `https`.
    ///
    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// When `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        let authority = authority.try_into().map_err(Into::into)?;
        Ok(Self {
            client: client::rustls_default(),
            scheme: Scheme::HTTPS,
            authority,
            path,
        })
    }
}
213
214impl<C, B, Pr> Service<Request<B>> for OneshotService<Pr, C, B>
215where
216    C: Connect + Clone + Send + Sync + 'static,
217    B: HttpBody + Send + 'static + Unpin,
218    B::Data: Send,
219    B::Error: Into<BoxErr>,
220    Pr: PathRewriter,
221{
222    type Response = Result<Response<Incoming>, ProxyError>;
223    type Error = Infallible;
224    type Future = RevProxyFuture;
225
226    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227        Poll::Ready(Ok(()))
228    }
229
230    fn call(&mut self, req: Request<B>) -> Self::Future {
231        RevProxyFuture::new(
232            &self.client,
233            req,
234            &self.scheme,
235            &self.authority,
236            &mut self.path,
237        )
238    }
239}
240
#[cfg(test)]
mod test {
    use http::uri::{Parts, Uri};
    use mockito::ServerGuard;

    use super::*;
    use crate::{ReplaceAll, test_helper};

    /// Starts a mock server and builds a service pointed at it, using the path rewriter
    /// `ReplaceAll("foo", "goo")`.
    ///
    /// Uses `expect` rather than `assert!(x.is_ok())` + `unwrap()` so a failure reports the
    /// underlying error instead of a bare assertion message.
    async fn make_svc() -> (
        ServerGuard,
        OneshotService<ReplaceAll<'static>, HttpConnector, String>,
    ) {
        let server = mockito::Server::new_async().await;
        let uri = Uri::try_from(&server.url()).expect("mock server URL should be a valid URI");

        let Parts {
            scheme, authority, ..
        } = uri.into_parts();

        let svc = OneshotService::from(
            client::http_default(),
            scheme.expect("mock server URL should contain a scheme"),
            authority.expect("mock server URL should contain an authority"),
            ReplaceAll("foo", "goo"),
        )
        .expect("building the service from a valid URI should succeed");
        (server, svc)
    }

    #[tokio::test]
    async fn match_path() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_path(&mut server, &mut svc).await;
    }

    #[tokio::test]
    async fn match_query() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_query(&mut server, &mut svc).await;
    }

    #[tokio::test]
    async fn match_post() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_post(&mut server, &mut svc).await;
    }

    #[tokio::test]
    async fn match_header() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_header(&mut server, &mut svc).await;
    }
}