salvo_proxy/
lib.rs

1//! Provide HTTP proxy capabilities for the Salvo web framework.
2//!
3//! This crate allows you to easily forward requests to upstream servers,
4//! supporting both HTTP and HTTPS protocols. It's useful for creating API gateways,
5//! load balancers, and reverse proxies.
6//!
7//! # Example
8//!
9//! In this example, requests to different hosts are proxied to different upstream servers:
10//! - Requests to http://127.0.0.1:5800/ are proxied to https://www.rust-lang.org
11//! - Requests to http://localhost:5800/ are proxied to https://crates.io
12//!
13//! ```no_run
14//! use salvo_core::prelude::*;
15//! use salvo_proxy::Proxy;
16//!
17//! #[tokio::main]
18//! async fn main() {
19//!     let router = Router::new()
20//!         .push(
21//!             Router::new()
22//!                 .host("127.0.0.1")
23//!                 .path("{**rest}")
24//!                 .goal(Proxy::use_hyper_client("https://www.rust-lang.org")),
25//!         )
26//!         .push(
27//!             Router::new()
28//!                 .host("localhost")
29//!                 .path("{**rest}")
30//!                 .goal(Proxy::use_hyper_client("https://crates.io")),
31//!         );
32//!
33//!     let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
34//!     Server::new(acceptor).serve(router).await;
35//! }
36//! ```
37#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43
44use hyper::upgrade::OnUpgrade;
45use percent_encoding::{CONTROLS, utf8_percent_encode};
46use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
47use salvo_core::http::uri::Uri;
48use salvo_core::http::{ReqBody, ResBody, StatusCode};
49use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
50
51#[macro_use]
52mod cfg;
53
54cfg_feature! {
55    #![feature = "hyper-client"]
56    mod hyper_client;
57    pub use hyper_client::*;
58}
59cfg_feature! {
60    #![feature = "reqwest-client"]
61    mod reqwest_client;
62    pub use reqwest_client::*;
63}
64
65cfg_feature! {
66    #![feature = "unix-sock-client"]
67    #[cfg(unix)]
68    mod unix_sock_client;
69    #[cfg(unix)]
70    pub use unix_sock_client::*;
71}
72
/// Request type handed to a [`Client`]: a `hyper` request carrying Salvo's `ReqBody`.
type HyperRequest = hyper::Request<ReqBody>;
/// Response type produced by a [`Client`]: a `hyper` response carrying Salvo's `ResBody`.
type HyperResponse = hyper::Response<ResBody>;
75
76/// Encode url path. This can be used when build your custom url path getter.
77#[inline]
78pub(crate) fn encode_url_path(path: &str) -> String {
79    path.split('/')
80        .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
81        .collect::<Vec<_>>()
82        .join("/")
83}
84
85/// Client trait for implementing different HTTP clients for proxying.
86///
87/// Implement this trait to create custom proxy clients with different
88/// backends or configurations.
89pub trait Client: Send + Sync + 'static {
90    /// Error type returned by the client.
91    type Error: StdError + Send + Sync + 'static;
92
93    /// Execute a request through the proxy client.
94    fn execute(
95        &self,
96        req: HyperRequest,
97        upgraded: Option<OnUpgrade>,
98    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
99}
100
/// Strategy for picking the upstream server a request is forwarded to.
///
/// Implement this trait to customize server selection, for example to provide
/// load balancing or failover.
pub trait Upstreams: Send + Sync + 'static {
    /// Error produced when no server can be selected.
    type Error: StdError + Send + Sync + 'static;

    /// Pick the upstream for the current request.
    ///
    /// The returned string is used by [`Proxy`] as the base URL onto which the
    /// proxied path and query are appended.
    fn elect(&self) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
113impl Upstreams for &'static str {
114    type Error = Infallible;
115
116    async fn elect(&self) -> Result<&str, Self::Error> {
117        Ok(*self)
118    }
119}
120impl Upstreams for String {
121    type Error = Infallible;
122    async fn elect(&self) -> Result<&str, Self::Error> {
123        Ok(self.as_str())
124    }
125}
126
127impl<const N: usize> Upstreams for [&'static str; N] {
128    type Error = Error;
129    async fn elect(&self) -> Result<&str, Self::Error> {
130        if self.is_empty() {
131            return Err(Error::other("upstreams is empty"));
132        }
133        let index = fastrand::usize(..self.len());
134        Ok(self[index])
135    }
136}
137
138impl<T> Upstreams for Vec<T>
139where
140    T: AsRef<str> + Send + Sync + 'static,
141{
142    type Error = Error;
143    async fn elect(&self) -> Result<&str, Self::Error> {
144        if self.is_empty() {
145            return Err(Error::other("upstreams is empty"));
146        }
147        let index = fastrand::usize(..self.len());
148        Ok(self[index].as_ref())
149    }
150}
151
/// Boxed callback that extracts one part of the proxied URL (its path or its
/// query) from the incoming request and depot.
///
/// Returning `None` means the part is absent from the forwarded URL.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
154
155/// Default url path getter.
156///
157/// This getter will get the last param as the rest url path from request.
158/// In most case you should use wildcard param, like `{**rest}`, `{*+rest}`.
159pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
160    req.params().tail().map(encode_url_path)
161}
162/// Default url query getter. This getter just return the query string from request uri.
163pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
164    req.uri().query().map(Into::into)
165}
166
167/// Handler that can proxy request to other server.
168#[non_exhaustive]
169pub struct Proxy<U, C>
170where
171    U: Upstreams,
172    C: Client,
173{
174    /// Upstreams list.
175    pub upstreams: U,
176    /// [`Client`] for proxy.
177    pub client: C,
178    /// Url path getter.
179    pub url_path_getter: UrlPartGetter,
180    /// Url query getter.
181    pub url_query_getter: UrlPartGetter,
182}
183
184impl<U, C> Proxy<U, C>
185where
186    U: Upstreams,
187    U::Error: Into<BoxedError>,
188    C: Client,
189{
190    /// Create new `Proxy` with upstreams list.
191    pub fn new(upstreams: U, client: C) -> Self {
192        Proxy {
193            upstreams,
194            client,
195            url_path_getter: Box::new(default_url_path_getter),
196            url_query_getter: Box::new(default_url_query_getter),
197        }
198    }
199
200    /// Set url path getter.
201    #[inline]
202    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
203    where
204        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
205    {
206        self.url_path_getter = Box::new(url_path_getter);
207        self
208    }
209
210    /// Set url query getter.
211    #[inline]
212    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
213    where
214        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
215    {
216        self.url_query_getter = Box::new(url_query_getter);
217        self
218    }
219
220    /// Get upstreams list.
221    #[inline]
222    pub fn upstreams(&self) -> &U {
223        &self.upstreams
224    }
225    /// Get upstreams mutable list.
226    #[inline]
227    pub fn upstreams_mut(&mut self) -> &mut U {
228        &mut self.upstreams
229    }
230
231    /// Get client reference.
232    #[inline]
233    pub fn client(&self) -> &C {
234        &self.client
235    }
236    /// Get client mutable reference.
237    #[inline]
238    pub fn client_mut(&mut self) -> &mut C {
239        &mut self.client
240    }
241
242    async fn build_proxied_request(
243        &self,
244        req: &mut Request,
245        depot: &Depot,
246    ) -> Result<HyperRequest, Error> {
247        let upstream = self.upstreams.elect().await.map_err(Error::other)?;
248        if upstream.is_empty() {
249            tracing::error!("upstreams is empty");
250            return Err(Error::other("upstreams is empty"));
251        }
252
253        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
254        let query = (self.url_query_getter)(req, depot);
255        let rest = if let Some(query) = query {
256            if query.starts_with('?') {
257                format!("{path}{query}")
258            } else {
259                format!("{path}?{query}")
260            }
261        } else {
262            path
263        };
264        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
265            format!("{}{}", upstream.trim_end_matches('/'), rest)
266        } else if upstream.ends_with('/') || rest.starts_with('/') {
267            format!("{upstream}{rest}")
268        } else if rest.is_empty() {
269            upstream.to_string()
270        } else {
271            format!("{upstream}/{rest}")
272        };
273        let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
274        let mut build = hyper::Request::builder()
275            .method(req.method())
276            .uri(&forward_url);
277        for (key, value) in req.headers() {
278            if key != HOST {
279                build = build.header(key, value);
280            }
281        }
282        if let Some(host) = forward_url
283            .host()
284            .and_then(|host| HeaderValue::from_str(host).ok())
285        {
286            build = build.header(HeaderName::from_static("host"), host);
287        }
288        // let x_forwarded_for_header_name = "x-forwarded-for";
289        // // Add forwarding information in the headers
290        // match request.headers_mut().entry(x_forwarded_for_header_name) {
291        //     Ok(header_entry) => {
292        //         match header_entry {
293        //             hyper::header::Entry::Vacant(entry) => {
294        //                 let addr = format!("{}", client_ip);
295        //                 entry.insert(addr.parse().unwrap());
296        //             },
297        //             hyper::header::Entry::Occupied(mut entry) => {
298        //                 let addr = format!("{}, {}", entry.get().to_str().unwrap(), client_ip);
299        //                 entry.insert(addr.parse().unwrap());
300        //             }
301        //         }
302        //     }
303        //     // shouldn't happen...
304        //     Err(_) => panic!("Invalid header name: {}", x_forwarded_for_header_name),
305        // }
306        build.body(req.take_body()).map_err(Error::other)
307    }
308}
309
310#[async_trait]
311impl<U, C> Handler for Proxy<U, C>
312where
313    U: Upstreams,
314    U::Error: Into<BoxedError>,
315    C: Client,
316{
317    async fn handle(
318        &self,
319        req: &mut Request,
320        depot: &mut Depot,
321        res: &mut Response,
322        _ctrl: &mut FlowCtrl,
323    ) {
324        match self.build_proxied_request(req, depot).await {
325            Ok(proxied_request) => {
326                match self
327                    .client
328                    .execute(proxied_request, req.extensions_mut().remove())
329                    .await
330                {
331                    Ok(response) => {
332                        let (
333                            salvo_core::http::response::Parts {
334                                status,
335                                // version,
336                                headers,
337                                // extensions,
338                                ..
339                            },
340                            body,
341                        ) = response.into_parts();
342                        res.status_code(status);
343                        for name in headers.keys() {
344                            for value in headers.get_all(name) {
345                                res.headers.append(name, value.to_owned());
346                            }
347                        }
348                        res.body(body);
349                    }
350                    Err(e) => {
351                        tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
352                        res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
353                    }
354                }
355            }
356            Err(e) => {
357                tracing::error!(error = ?e, "build proxied request failed");
358            }
359        }
360    }
361}
362#[inline]
363#[allow(dead_code)]
364fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
365    if headers
366        .get(&CONNECTION)
367        .map(|value| {
368            value
369                .to_str()
370                .unwrap_or_default()
371                .split(',')
372                .any(|e| e.trim() == UPGRADE)
373        })
374        .unwrap_or(false)
375    {
376        if let Some(upgrade_value) = headers.get(&UPGRADE) {
377            tracing::debug!(
378                "Found upgrade header with value: {:?}",
379                upgrade_value.to_str()
380            );
381            return upgrade_value.to_str().ok();
382        }
383    }
384
385    None
386}
387
// Unit tests for the free helper functions of the proxy module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain ASCII paths pass through unchanged.
    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    /// Control characters and non-ASCII bytes are escaped; `/` and spaces are kept.
    #[test]
    fn test_encode_url_path_escapes_controls_and_non_ascii() {
        assert_eq!(encode_url_path(""), "");
        assert_eq!(encode_url_path("/a\nb/c"), "/a%0Ab/c");
        assert_eq!(encode_url_path("/caf\u{e9}"), "/caf%C3%A9");
        assert_eq!(encode_url_path("/a b"), "/a b");
    }

    /// `Connection: upgrade` plus an `Upgrade` header yields the protocol name.
    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    /// The `upgrade` token is found among several comma-separated tokens.
    #[test]
    fn test_get_upgrade_type_with_multiple_connection_tokens() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("keep-alive, upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&headers), Some("websocket"));
    }

    /// Without an upgrade token in `Connection`, or without `Upgrade`, the result is `None`.
    #[test]
    fn test_get_upgrade_type_none() {
        let empty = HeaderMap::new();
        assert_eq!(get_upgrade_type(&empty), None);

        let mut no_token = HeaderMap::new();
        no_token.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
        no_token.insert(UPGRADE, HeaderValue::from_static("websocket"));
        assert_eq!(get_upgrade_type(&no_token), None);

        let mut no_upgrade_header = HeaderMap::new();
        no_upgrade_header.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        assert_eq!(get_upgrade_type(&no_upgrade_header), None);
    }
}