axum_proxy/
oneshot.rs

1use std::convert::Infallible;
2use std::task::{Context, Poll};
3
4use client::HttpConnector;
5#[cfg(feature = "__rustls")]
6use client::RustlsConnector;
7use http::uri::{Authority, Scheme};
8use http::{Error as HttpError, Request, Response};
9//use hyper::body::{Body, HttpBody};
10use hyper::body::{Body as HttpBody, Incoming};
11#[cfg(feature = "nativetls")]
12use hyper_tls::HttpsConnector as NativeTlsConnector;
13use hyper_util::client::legacy::connect::Connect;
14use hyper_util::client::legacy::Client;
15use tower_service::Service;
16
17use crate::future::RevProxyFuture;
18use crate::rewrite::PathRewriter;
19use crate::{client, Error};
20
21type BoxErr = Box<dyn std::error::Error + Send + Sync>;
22
/// A [`Service<Request<B>>`] that sends a request and returns the response, owning a [`Client`].
///
/// Type parameters:
///
/// * `Pr` - the path rewriter; it must implement [`PathRewriter`] for the service to be callable.
/// * `C` - the connector used by the owned [`Client`] (defaults to [`HttpConnector`]).
/// * `B` - the body type of the requests accepted by the service (defaults to [`Incoming`]).
///
/// # Examples
///
/// ```
/// # async fn run_test() {
/// # use axum_proxy::OneshotService;
/// # use axum_proxy::Static;
/// # use tower_service::Service;
/// # use http_body_util::Empty;
/// # use http::Request;
/// # use hyper::body::Bytes;
/// let mut svc = OneshotService::http_default("example.com:1234", Static("bar")).unwrap();
/// let req = Request::builder()
///     .uri("https://myserver.com/foo")
///     .body(Empty::<Bytes>::new())
///     .unwrap();
/// // http://example.com:1234/bar
/// let _res = svc.call(req).await.unwrap();
/// # }
/// ```
// The lint is expected to fire here, presumably because the type name repeats the name of
// the `oneshot` module it lives in.
#[expect(clippy::module_name_repetitions)]
pub struct OneshotService<Pr, C = HttpConnector, B = Incoming> {
    /// Client owned by the service; borrowed by every call to build the outgoing request.
    client: Client<C, B>,
    /// Scheme handed to the proxy future on every call (`http` or `https` when one of the
    /// `*_default` constructors is used).
    scheme: Scheme,
    /// Authority of the upstream server, e.g. `example.com:1234`.
    authority: Authority,
    /// Path rewriter; borrowed mutably on every call.
    path: Pr,
}
49
50impl<Pr: Clone, C: Clone, B> Clone for OneshotService<Pr, C, B> {
51    #[inline]
52    fn clone(&self) -> Self {
53        Self {
54            client: self.client.clone(),
55            scheme: self.scheme.clone(),
56            authority: self.authority.clone(),
57            path: self.path.clone(),
58        }
59    }
60}
61
62impl<Pr, C, B> OneshotService<Pr, C, B> {
63    /// Initializes a service with a general [`Client`].
64    ///
65    /// A client can be built by functions in [`client`].
66    ///
67    /// For the meaning of "scheme" and "authority", refer to the documentation of
68    /// [`Uri`](http::uri::Uri).
69    ///
70    /// The `path` should implement [`PathRewriter`].
71    ///
72    /// # Errors
73    ///
74    /// When `scheme` or `authority` cannot be converted into a [`Scheme`] or [`Authority`].
75    pub fn from<S, A>(
76        client: Client<C, B>,
77        scheme: S,
78        authority: A,
79        path: Pr,
80    ) -> Result<Self, HttpError>
81    where
82        Scheme: TryFrom<S>,
83        <Scheme as TryFrom<S>>::Error: Into<HttpError>,
84        Authority: TryFrom<A>,
85        <Authority as TryFrom<A>>::Error: Into<HttpError>,
86    {
87        let scheme = scheme.try_into().map_err(Into::into)?;
88        let authority = authority.try_into().map_err(Into::into)?;
89        Ok(Self {
90            client,
91            scheme,
92            authority,
93            path,
94        })
95    }
96}
97
98impl<Pr, B> OneshotService<Pr, HttpConnector, B>
99where
100    B: HttpBody + Send,
101    B::Data: Send,
102{
103    /// Use [`client::http_default()`] to build a client.
104    ///
105    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
106    ///
107    /// The `path` should implement [`PathRewriter`].
108    ///
109    /// # Errors
110    ///
111    /// When `authority` cannot be converted into an [`Authority`].
112    pub fn http_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
113    where
114        Authority: TryFrom<A>,
115        <Authority as TryFrom<A>>::Error: Into<HttpError>,
116    {
117        let authority = authority.try_into().map_err(Into::into)?;
118        Ok(Self {
119            client: client::http_default(),
120            scheme: Scheme::HTTP,
121            authority,
122            path,
123        })
124    }
125}
126
// NOTE(review): `NativeTlsConnector` is imported in this file under `feature = "nativetls"`
// only, while this impl is also compiled for `feature = "https"` — presumably `https` enables
// `nativetls` in Cargo.toml; confirm.
#[cfg(any(feature = "https", feature = "nativetls"))]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Creates a service that uses the `https` scheme and the client returned by
    /// [`client::https_default()`].
    ///
    /// This is the same as [`Self::nativetls_default()`].
    ///
    /// See the documentation of [`Uri`](http::uri::Uri) for what "authority" means.
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// Fails when `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(any(feature = "https", feature = "nativetls"))))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        // The client is only built once the authority has been accepted.
        Authority::try_from(authority)
            .map_err(Into::into)
            .map(|authority| Self {
                client: client::https_default(),
                scheme: Scheme::HTTPS,
                authority,
                path,
            })
    }
}
159
#[cfg(feature = "nativetls")]
impl<Pr, B> OneshotService<Pr, NativeTlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Use [`client::nativetls_default()`] to build a client.
    ///
    /// The resulting service uses the `https` scheme.
    ///
    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// When `authority` cannot be converted into an [`Authority`].
    #[cfg_attr(docsrs, doc(cfg(feature = "nativetls")))]
    pub fn nativetls_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        // Convert first so that no client is built for an invalid authority.
        let authority = authority.try_into().map_err(Into::into)?;
        Ok(Self {
            client: client::nativetls_default(),
            scheme: Scheme::HTTPS,
            authority,
            path,
        })
    }
}
189
#[cfg(feature = "__rustls")]
impl<Pr, B> OneshotService<Pr, RustlsConnector<HttpConnector>, B>
where
    B: HttpBody + Send,
    B::Data: Send,
{
    /// Use [`client::rustls_default()`] to build a client.
    ///
    /// The resulting service uses the `https` scheme.
    ///
    /// For the meaning of "authority", refer to the documentation of [`Uri`](http::uri::Uri).
    ///
    /// The `path` should implement [`PathRewriter`].
    ///
    /// # Errors
    ///
    /// When `authority` cannot be converted into an [`Authority`].
    // NOTE(review): the impl is compiled under `__rustls` but documented as `rustls` —
    // presumably `__rustls` is an internal feature switched on by the public one; confirm.
    #[cfg_attr(docsrs, doc(cfg(feature = "rustls")))]
    pub fn https_default<A>(authority: A, path: Pr) -> Result<Self, HttpError>
    where
        Authority: TryFrom<A>,
        <Authority as TryFrom<A>>::Error: Into<HttpError>,
    {
        // Convert first so that no client is built for an invalid authority.
        let authority = authority.try_into().map_err(Into::into)?;
        Ok(Self {
            client: client::rustls_default(),
            scheme: Scheme::HTTPS,
            authority,
            path,
        })
    }
}
219
220impl<C, B, Pr> Service<Request<B>> for OneshotService<Pr, C, B>
221where
222    C: Connect + Clone + Send + Sync + 'static,
223    B: HttpBody + Send + 'static + Unpin,
224    B::Data: Send,
225    B::Error: Into<BoxErr>,
226    Pr: PathRewriter,
227{
228    type Response = Result<Response<Incoming>, Error>;
229    type Error = Infallible;
230    type Future = RevProxyFuture;
231
232    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
233        Poll::Ready(Ok(()))
234    }
235
236    fn call(&mut self, req: Request<B>) -> Self::Future {
237        RevProxyFuture::new(
238            &self.client,
239            req,
240            &self.scheme,
241            &self.authority,
242            &mut self.path,
243        )
244    }
245}
246
#[cfg(test)]
mod test {
    use http::uri::{Parts, Uri};
    use mockito::ServerGuard;

    use super::*;
    use crate::{test_helper, ReplaceAll};

    /// Starts a mock server and builds a [`OneshotService`] that targets the server's scheme
    /// and authority, using a `ReplaceAll("foo", "goo")` path rewriter.
    ///
    /// The server guard is returned together with the service so the caller keeps it alive
    /// and can register expectations on it.
    async fn make_svc() -> (
        ServerGuard,
        OneshotService<ReplaceAll<'static>, HttpConnector, String>,
    ) {
        let server = mockito::Server::new_async().await;
        // `expect` reports the underlying error, unlike a bare `assert!(.. .is_ok())`.
        let uri = Uri::try_from(&server.url()).expect("mock server URL should be a valid URI");

        let Parts {
            scheme, authority, ..
        } = uri.into_parts();

        let svc = OneshotService::from(
            client::http_default(),
            scheme.expect("mock server URL should contain a scheme"),
            authority.expect("mock server URL should contain an authority"),
            ReplaceAll("foo", "goo"),
        )
        .expect("scheme and authority of a parsed URI should be accepted");

        (server, svc)
    }

    /// Runs the shared `test_helper::match_path` scenario against a fresh service.
    #[tokio::test]
    async fn match_path() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_path(&mut server, &mut svc).await;
    }

    /// Runs the shared `test_helper::match_query` scenario against a fresh service.
    #[tokio::test]
    async fn match_query() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_query(&mut server, &mut svc).await;
    }

    /// Runs the shared `test_helper::match_post` scenario against a fresh service.
    #[tokio::test]
    async fn match_post() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_post(&mut server, &mut svc).await;
    }

    /// Runs the shared `test_helper::match_header` scenario against a fresh service.
    #[tokio::test]
    async fn match_header() {
        let (mut server, mut svc) = make_svc().await;
        test_helper::match_header(&mut server, &mut svc).await;
    }
}