axum_proxy/
oneshot.rs

1use std::convert::Infallible;
2use std::task::{Context, Poll};
3
4use client::HttpConnector;
5#[cfg(feature = "__rustls")]
6use client::RustlsConnector;
7use http::uri::{Authority, Scheme};
8use http::{Error as HttpError, Request, Response};
9//use hyper::body::{Body, HttpBody};
10use hyper::body::{Body as HttpBody, Incoming};
11#[cfg(feature = "nativetls")]
12use hyper_tls::HttpsConnector as NativeTlsConnector;
13use hyper_util::client::legacy::Client;
14use hyper_util::client::legacy::connect::Connect;
15use tower_service::Service;
16
17use crate::future::RevProxyFuture;
18use crate::rewrite::PathRewriter;
19use crate::{ProxyError, client};
20
21type BoxErr = Box<dyn std::error::Error + Send + Sync>;
22
23/// A [`Service<Request<B>>`] that sends a request and returns the response, owning a [`Client`].
24///
25/// ```
26/// # async fn run_test() {
27/// # use axum_proxy::OneshotService;
28/// # use axum_proxy::Static;
29/// # use tower_service::Service;
30/// # use http_body_util::Empty;
31/// # use http::Request;
32/// # use hyper::body::Bytes;
33/// let mut svc = OneshotService::http_default("example.com:1234", Static("bar")).unwrap();
34/// let req = Request::builder()
35///     .uri("https://myserver.com/foo")
36///     .body(Empty::<Bytes>::new())
37///     .unwrap();
38/// // http://example.com:1234/bar
39/// let _res = svc.call(req).await.unwrap();
40/// # }
41/// ```
42pub struct OneshotService<Pr, C = HttpConnector, B = Incoming> {
43    client: Client<C, B>,
44    scheme: Scheme,
45    authority: Authority,
46    path: Pr,
47}
48
49impl<Pr: Clone, C: Clone, B> Clone for OneshotService<Pr, C, B> {
50    #[inline]
51    fn clone(&self) -> Self {
52        Self {
53            client: self.client.clone(),
54            scheme: self.scheme.clone(),
55            authority: self.authority.clone(),
56            path: self.path.clone(),
57        }
58    }
59}
60
61impl<Pr, C, B> OneshotService<Pr, C, B> {
62    /// Initializes a service with a general [`Client`].
63    ///
64    /// A client can be built by functions in [`client`].
65    ///
66    /// For the meaning of "scheme" and "authority", refer to the documentation of
67    /// [`Uri`](http::uri::Uri).
68    ///
69    /// The `path` should implement [`PathRewriter`].
70    ///
71    /// # Errors
72    ///
73    /// When `scheme` or `authority` cannot be converted into a [`Scheme`] or [`Authority`].
74    pub fn from<S, A>(
75        client: Client<C, B>,
76        scheme: S,
77        authority: A,
78        path: Pr,
79    ) -> Result<Self, HttpError>
80    where
81        Scheme: TryFrom<S>,
82        <Scheme as TryFrom<S>>::Error: Into<HttpError>,
83        Authority: TryFrom<A>,
84        <Authority as TryFrom<A>>::Error: Into<HttpError>,
85    {
86        let scheme = scheme.try_into().map_err(Into::into)?;
87        let authority = authority.try_into().map_err(Into::into)?;
88        Ok(Self {
89            client,
90            scheme,
91            authority,
92            path,
93        })
94    }
95}
96
97impl<Pr, B> OneshotService<Pr, HttpConnector, B>
98where
99    B: HttpBody + Send,
100    B::Data: Send,
101{
102    /// Use [`client::http_default()`] to build a client.
103    ///
104    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
105    ///
106    /// The `path` should implement [`PathRewriter`].
107    ///
108    /// # Errors
109    ///
110    /// When `authority` cannot be converted into an [`Authority`].
111    pub fn http_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
112    where
113        Authority: TryFrom<A>,
114        <Authority as TryFrom<A>>::Error: Into<HttpError>,
115    {
116        let authority = authority.try_into().map_err(Into::into)?;
117        Ok(Self {
118            client: client::http_default(),
119            scheme: Scheme::HTTP,
120            authority,
121            path,
122        })
123    }
124}
125
#[cfg(any(feature = "https", feature = "nativetls"))]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Creates an HTTPS service whose client comes from [`client::https_default()`].
    ///
    /// This is the same as [`Self::nativetls_default()`]. The scheme is always
    /// [`Scheme::HTTPS`]. See [`Uri`](http::uri::Uri) for what an "authority" is. `path`
    /// should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Fails when `authority` cannot be converted into an [`Authority`]. No client is built
    /// in that case.
    #[cfg_attr(docsrs, doc(cfg(any(feature = "https", feature = "nativetls"))))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        Authority::try_from(authority)
            .map_err(Into::<HttpError>::into)
            .map(|authority| Self {
                client: client::https_default(),
                scheme: Scheme::HTTPS,
                authority,
                path,
            })
    }
}
158
#[cfg(feature = "nativetls")]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Creates an HTTPS service whose client comes from [`client::nativetls_default()`].
    ///
    /// The scheme is always [`Scheme::HTTPS`]. See [`Uri`](http::uri::Uri) for what an
    /// "authority" is. `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Fails when `authority` cannot be converted into an [`Authority`]. No client is built
    /// in that case.
    #[cfg_attr(docsrs, doc(cfg(feature = "nativetls")))]
    pub fn nativetls_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        Authority::try_from(authority)
            .map_err(Into::<HttpError>::into)
            .map(|authority| Self {
                client: client::nativetls_default(),
                scheme: Scheme::HTTPS,
                authority,
                path,
            })
    }
}
188
#[cfg(feature = "__rustls")]
impl<Pr, B> OneshotService<Pr, RustlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Creates an HTTPS service whose client comes from [`client::rustls_default()`].
    ///
    /// The scheme is always [`Scheme::HTTPS`]. See [`Uri`](http::uri::Uri) for what an
    /// "authority" is. `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Fails when `authority` cannot be converted into an [`Authority`]. No client is built
    /// in that case.
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        Authority::try_from(authority)
            .map_err(Into::<HttpError>::into)
            .map(|authority| Self {
                client: client::rustls_default(),
                scheme: Scheme::HTTPS,
                authority,
                path,
            })
    }
}
218
219impl<C, B, Pr> Service<Request<B>> for OneshotService<Pr, C, B>
220where
221    C: Connect + Clone + Send + Sync + 'static,
222    B: HttpBody + Send + 'static + Unpin,
223    B::Data: Send,
224    B::Error: Into<BoxErr>,
225    Pr: PathRewriter,
226{
227    type Response = Result<Response<Incoming>, ProxyError>;
228    type Error = Infallible;
229    type Future = RevProxyFuture;
230
231    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
232        Poll::Ready(Ok(()))
233    }
234
235    fn call(&mut self, req: Request<B>) -> Self::Future {
236        RevProxyFuture::new(
237            &self.client,
238            req,
239            &self.scheme,
240            &self.authority,
241            &mut self.path,
242        )
243    }
244}
245
#[cfg(test)]
mod test {
    use http::uri::{Parts, Uri};
    use mockito::ServerGuard;

    use super::*;
    use crate::{ReplaceAll, test_helper};

    /// Starts a mock upstream server and builds a service targeting it with a
    /// `ReplaceAll("foo", "goo")` path rewriter.
    ///
    /// Uses `expect` instead of `assert!(x.is_ok())` followed by `unwrap()`. On failure, the
    /// panic message then includes the underlying error and the invariant that was violated.
    async fn make_svc() -> (
        ServerGuard,
        OneshotService<ReplaceAll<'static>, HttpConnector, String>,
    ) {
        let server = mockito::Server::new_async().await;
        let uri = Uri::try_from(&server.url()).expect("mock server URL should be a valid URI");

        let Parts {
            scheme, authority, ..
        } = uri.into_parts();

        let svc = OneshotService::from(
            client::http_default(),
            scheme.expect("mock server URL should contain a scheme"),
            authority.expect("mock server URL should contain an authority"),
            ReplaceAll("foo", "goo"),
        )
        .expect("scheme and authority taken from a parsed URI should convert");
        (server, svc)
    }

    #[tokio::test]
    async fn match_path() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_path(&mut server, &mut svc).await;
    }

    #[tokio::test]
    async fn match_query() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_query(&mut server, &mut svc).await;
    }

    #[tokio::test]
    async fn match_post() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_post(&mut server, &mut svc).await;
    }

    #[tokio::test]
    async fn match_header() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_header(&mut server, &mut svc).await;
    }
}