// reverse_proxy_service/oneshot.rs

1use crate::client;
2use crate::future::RevProxyFuture;
3use crate::rewrite::PathRewriter;
4use crate::Error;
5
6use client::HttpConnector;
7#[cfg(feature = "__rustls")]
8use client::RustlsConnector;
9#[cfg(feature = "nativetls")]
10use hyper_tls::HttpsConnector as NativeTlsConnector;
11
12use http::uri::{Authority, Scheme};
13use http::Error as HttpError;
14use http::{Request, Response};
15
16use hyper::body::{Body, HttpBody};
17use hyper::client::{connect::Connect, Client};
18
19use tower_service::Service;
20
21use std::convert::Infallible;
22use std::task::{Context, Poll};
23
/// Type-erased, thread-safe error used as the bound for request-body errors.
type BoxErr = Box<dyn std::error::Error + Send + Sync + 'static>;
25
26/// A [`Service<Request<B>>`] that sends a request and returns the response, owning a [`Client`].
27///
28/// ```
29/// # async fn run_test() {
30/// # use reverse_proxy_service::OneshotService;
31/// # use reverse_proxy_service::Static;
32/// # use tower_service::Service;
33/// # use hyper::body::Body;
34/// # use http::Request;
35/// let mut svc = OneshotService::http_default("example.com:1234", Static("bar")).unwrap();
36/// let req = Request::builder()
37///     .uri("https://myserver.com/foo")
38///     .body(Body::empty())
39///     .unwrap();
40/// // http://example.com:1234/bar
41/// let _res = svc.call(req).await.unwrap();
42/// # }
43/// ```
44pub struct OneshotService<Pr, C = HttpConnector, B = Body> {
45    client: Client<C, B>,
46    scheme: Scheme,
47    authority: Authority,
48    path: Pr,
49}
50
51impl<Pr: Clone, C: Clone, B> Clone for OneshotService<Pr, C, B> {
52    #[inline]
53    fn clone(&self) -> Self {
54        Self {
55            client: self.client.clone(),
56            scheme: self.scheme.clone(),
57            authority: self.authority.clone(),
58            path: self.path.clone(),
59        }
60    }
61}
62
63impl<Pr, C, B> OneshotService<Pr, C, B> {
64    /// Initializes a service with a general `Client`.
65    ///
66    /// A client can be built by functions in [`client`].
67    ///
68    /// For the meaning of "scheme" and "authority", refer to the documentation of
69    /// [`Uri`](http::uri::Uri).
70    ///
71    /// The `path` should implement [`PathRewriter`].
72    pub fn from<S, A>(
73        client: Client<C, B>,
74        scheme: S,
75        authority: A,
76        path: Pr,
77    ) -> Result<Self, HttpError>
78    where
79        Scheme: TryFrom<S>,
80        <Scheme as TryFrom<S>>::Error: Into<HttpError>,
81        Authority: TryFrom<A>,
82        <Authority as TryFrom<A>>::Error: Into<HttpError>,
83    {
84        let scheme = scheme.try_into().map_err(Into::into)?;
85        let authority = authority.try_into().map_err(Into::into)?;
86        Ok(Self {
87            client,
88            scheme,
89            authority,
90            path,
91        })
92    }
93}
94
95impl<Pr, B> OneshotService<Pr, HttpConnector, B>
96where
97    B: HttpBody + Send,
98    B::Data: Send,
99{
100    /// Use [`client::http_default()`] to build a client.
101    ///
102    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
103    ///
104    /// The `path` should implement [`PathRewriter`].
105    pub fn http_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
106    where
107        Authority: TryFrom<A>,
108        <Authority as TryFrom<A>>::Error: Into<HttpError>,
109    {
110        let authority = authority.try_into().map_err(Into::into)?;
111        Ok(Self {
112            client: client::http_default(),
113            scheme: Scheme::HTTP,
114            authority,
115            path,
116        })
117    }
118}
119
#[cfg(any(feature = "https", feature = "nativetls"))]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Builds an HTTPS service whose client comes from [`client::https_default()`].
    ///
    /// This is the same as [`Self::nativetls_default()`].
    ///
    /// The scheme is always `https`. See [`Uri`](http::uri::Uri) for what "authority" means.
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Returns an [`http::Error`] if `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(any(feature = "https", feature = "nativetls"))))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        let upstream = Authority::try_from(authority).map_err(Into::<HttpError>::into)?;
        let client = client::https_default();
        Ok(Self {
            client,
            scheme: Scheme::HTTPS,
            authority: upstream,
            path,
        })
    }
}
148
#[cfg(feature = "nativetls")]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Builds an HTTPS service whose client comes from [`client::nativetls_default()`].
    ///
    /// The scheme is always `https`. See [`Uri`](http::uri::Uri) for what "authority" means.
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Returns an [`http::Error`] if `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(feature = "nativetls")))]
    pub fn nativetls_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        let upstream = Authority::try_from(authority).map_err(Into::<HttpError>::into)?;
        let client = client::nativetls_default();
        Ok(Self {
            client,
            scheme: Scheme::HTTPS,
            authority: upstream,
            path,
        })
    }
}
175
#[cfg(feature = "__rustls")]
impl<Pr, B> OneshotService<Pr, RustlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Use [`client::rustls_default()`] to build a client.
    ///
    /// This is the same as [`Self::rustls_default()`]. Prefer `rustls_default` when several
    /// TLS features are enabled. Other impls also define an `https_default`, so a call to
    /// `OneshotService::https_default(..)` can be ambiguous there.
    ///
    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Returns an [`http::Error`] if `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        Self::rustls_default(authority, path)
    }

    /// Use [`client::rustls_default()`] to build a client.
    ///
    /// The scheme is always `https`. For the meaning of "authority", refer to the
    /// documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Returns an [`http::Error`] if `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    pub fn rustls_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        // Convert the authority before building the client, so a bad authority
        // never constructs a TLS client.
        let authority = Authority::try_from(authority).map_err(Into::<HttpError>::into)?;
        Ok(Self {
            client: client::rustls_default(),
            scheme: Scheme::HTTPS,
            authority,
            path,
        })
    }
}
202
203impl<C, B, Pr> Service<Request<B>> for OneshotService<Pr, C, B>
204where
205    C: Connect + Clone + Send + Sync + 'static,
206    B: HttpBody + Send + 'static,
207    B::Data: Send,
208    B::Error: Into<BoxErr>,
209    Pr: PathRewriter,
210{
211    type Response = Result<Response<Body>, Error>;
212    type Error = Infallible;
213    type Future = RevProxyFuture;
214
215    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
216        Poll::Ready(Ok(()))
217    }
218
219    fn call(&mut self, req: Request<B>) -> Self::Future {
220        RevProxyFuture::new(
221            &self.client,
222            req,
223            &self.scheme,
224            &self.authority,
225            &mut self.path,
226        )
227    }
228}
229
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_helper;
    use crate::ReplaceAll;

    use http::uri::{Parts, Uri};

    /// Builds a service that targets the mockito server's scheme and authority and uses
    /// `ReplaceAll("foo", "goo")` as its path rewriter.
    ///
    /// Each step uses `expect` rather than `assert!(x.is_ok())` followed by `unwrap()`, so a
    /// failure reports the underlying error instead of just "assertion failed".
    fn make_svc() -> OneshotService<ReplaceAll<'static>, HttpConnector, String> {
        let uri = Uri::try_from(&mockito::server_url())
            .expect("mockito server URL should parse as a URI");

        let Parts {
            scheme, authority, ..
        } = uri.into_parts();

        OneshotService::from(
            client::http_default(),
            scheme.expect("mockito server URL should contain a scheme"),
            authority.expect("mockito server URL should contain an authority"),
            ReplaceAll("foo", "goo"),
        )
        .expect("building a OneshotService for the mockito server should succeed")
    }

    #[tokio::test]
    async fn match_path() {
        let mut svc = make_svc();
        test_helper::match_path(&mut svc).await;
    }

    #[tokio::test]
    async fn match_query() {
        let mut svc = make_svc();
        test_helper::match_query(&mut svc).await;
    }

    #[tokio::test]
    async fn match_post() {
        let mut svc = make_svc();
        test_helper::match_post(&mut svc).await;
    }

    #[tokio::test]
    async fn match_header() {
        let mut svc = make_svc();
        test_helper::match_header(&mut svc).await;
    }
}