// salvo_proxy/reqwest_client.rs

1use futures_util::TryStreamExt;
2use hyper::upgrade::OnUpgrade;
3use reqwest::Client as InnerClient;
4use salvo_core::Error;
5use salvo_core::http::{ResBody, StatusCode};
6use salvo_core::rt::tokio::TokioIo;
7use tokio::io::copy_bidirectional;
8
9use crate::{BoxedError, Client, HyperRequest, HyperResponse, Proxy, Upstreams};
10
11/// A [`Client`] implementation based on [`reqwest::Client`].
12///
13/// This client provides proxy capabilities using the Reqwest HTTP client.
14/// It supports all features of Reqwest including automatic redirect handling,
15/// connection pooling, and other HTTP client features.
16#[derive(Clone, Debug)]
17pub struct ReqwestClient {
18    inner: InnerClient,
19}
20
21impl<U> Proxy<U, ReqwestClient>
22where
23    U: Upstreams,
24    U::Error: Into<BoxedError>,
25{
26    /// Create a new `Proxy` using the default Reqwest client.
27    ///
28    /// This is a convenient way to create a proxy with standard configuration.
29    pub fn use_reqwest_client(upstreams: U) -> Self {
30        Self::new(upstreams, ReqwestClient::default())
31    }
32}
33
34impl Default for ReqwestClient {
35    fn default() -> Self {
36        #[cfg(feature = "ring")]
37        let _ = rustls::crypto::ring::default_provider().install_default();
38        Self::new(InnerClient::builder().build().expect("failed to build reqwest client"))
39    }
40}
41
42impl ReqwestClient {
43    /// Create a new `ReqwestClient` with the given [`reqwest::Client`].
44    #[must_use]
45    pub fn new(inner: InnerClient) -> Self {
46        Self { inner }
47    }
48}
49
50impl Client for ReqwestClient {
51    type Error = salvo_core::Error;
52
53    async fn execute(
54        &self,
55        proxied_request: HyperRequest,
56        request_upgraded: Option<OnUpgrade>,
57    ) -> Result<HyperResponse, Self::Error> {
58        let request_upgrade_type =
59            crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());
60
61        let proxied_request = proxied_request
62            .map(|s| reqwest::Body::wrap_stream(s.map_ok(|s| s.into_data().unwrap_or_default())));
63        let response = self
64            .inner
65            .execute(proxied_request.try_into().map_err(Error::other)?)
66            .await
67            .map_err(Error::other)?;
68
69        let res_headers = response.headers().clone();
70        let hyper_response = hyper::Response::builder()
71            .status(response.status())
72            .version(response.version());
73
74        let mut hyper_response = if response.status() == StatusCode::SWITCHING_PROTOCOLS {
75            let response_upgrade_type = crate::get_upgrade_type(response.headers());
76
77            if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
78                let mut response_upgraded = response.upgrade().await.map_err(|e| {
79                    Error::other(format!("response does not have an upgrade extension. {e}"))
80                })?;
81                if let Some(request_upgraded) = request_upgraded {
82                    tokio::spawn(async move {
83                        match request_upgraded.await {
84                            Ok(request_upgraded) => {
85                                let mut request_upgraded = TokioIo::new(request_upgraded);
86                                if let Err(e) = copy_bidirectional(
87                                    &mut response_upgraded,
88                                    &mut request_upgraded,
89                                )
90                                .await
91                                {
92                                    tracing::error!(error = ?e, "coping between upgraded connections failed");
93                                }
94                            }
95                            Err(e) => {
96                                tracing::error!(error = ?e, "upgrade request failed");
97                            }
98                        }
99                    });
100                } else {
101                    return Err(Error::other("request does not have an upgrade extension"));
102                }
103            } else {
104                return Err(Error::other("upgrade type mismatch"));
105            }
106            hyper_response.body(ResBody::None).map_err(Error::other)?
107        } else {
108            hyper_response
109                .body(ResBody::stream(response.bytes_stream()))
110                .map_err(Error::other)?
111        };
112        *hyper_response.headers_mut() = res_headers;
113        Ok(hyper_response)
114    }
115}
116
// Unit tests for `Proxy` backed by `ReqwestClient`.
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// The elected upstream must be one of the configured candidates.
    #[tokio::test]
    async fn test_upstreams_elect() {
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), ReqwestClient::default());
        let (request, depot) = (Request::new(), Depot::new());

        let elected = proxy.upstreams().elect(&request, &depot).await.unwrap();

        assert!(candidates.contains(&elected));
    }

    /// Proxies a live page from https://salvo.rs, so this needs network access.
    #[tokio::test]
    async fn test_reqwest_client() {
        let proxy = Proxy::new(vec!["https://salvo.rs"], ReqwestClient::default());
        let router = Router::new().push(Router::with_path("rust/{**rest}").goal(proxy));

        let content = TestClient::get("http://127.0.0.1:5801/rust/guide/index.html")
            .send(router)
            .await
            .take_string()
            .await
            .unwrap();

        assert!(content.contains("Salvo"));
    }

    /// Both upstream accessors see the single configured upstream.
    #[test]
    fn test_others() {
        let mut proxy = Proxy::new(["https://www.bing.com"], ReqwestClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}
158}