salvo_proxy/
lib.rs

1//! Provide HTTP proxy capabilities for the Salvo web framework.
2//!
3//! This crate allows you to easily forward requests to upstream servers,
4//! supporting both HTTP and HTTPS protocols. It's useful for creating API gateways,
5//! load balancers, and reverse proxies.
6//!
7//! # Example
8//!
9//! In this example, requests to different hosts are proxied to different upstream servers:
10//! - Requests to <http://127.0.0.1:8698/> are proxied to <https://www.rust-lang.org>
11//! - Requests to <http://localhost:8698/> are proxied to <https://crates.io>
12//!
13//! ```no_run
14//! use salvo_core::prelude::*;
15//! use salvo_proxy::Proxy;
16//!
17//! #[tokio::main]
18//! async fn main() {
19//!     let router = Router::new()
20//!         .push(
21//!             Router::new()
22//!                 .host("127.0.0.1")
23//!                 .path("{**rest}")
24//!                 .goal(Proxy::use_hyper_client("https://www.rust-lang.org")),
25//!         )
26//!         .push(
27//!             Router::new()
28//!                 .host("localhost")
29//!                 .path("{**rest}")
30//!                 .goal(Proxy::use_hyper_client("https://crates.io")),
31//!         );
32//!
33//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
34//!     Server::new(acceptor).serve(router).await;
35//! }
36//! ```
37#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44#[cfg(test)]
45use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
46
47use hyper::upgrade::OnUpgrade;
48#[cfg(not(test))]
49use local_ip_address::{local_ip, local_ipv6};
50use percent_encoding::{AsciiSet, CONTROLS, utf8_percent_encode};
51use salvo_core::conn::SocketAddr;
52use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
53use salvo_core::http::uri::Uri;
54use salvo_core::http::{ReqBody, ResBody, StatusCode};
55use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
56
57#[macro_use]
58mod cfg;
59
60cfg_feature! {
61    #![feature = "hyper-client"]
62    mod hyper_client;
63    pub use hyper_client::*;
64}
65cfg_feature! {
66    #![feature = "reqwest-client"]
67    mod reqwest_client;
68    pub use reqwest_client::*;
69}
70
71cfg_feature! {
72    #![feature = "unix-sock-client"]
73    #[cfg(unix)]
74    mod unix_sock_client;
75    #[cfg(unix)]
76    pub use unix_sock_client::*;
77}
78
79type HyperRequest = hyper::Request<ReqBody>;
80type HyperResponse = hyper::Response<ResBody>;
81
82const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
83
84const QUERY_ENCODE_SET: &AsciiSet = &CONTROLS
85    .add(b' ')
86    .add(b'"')
87    .add(b'#')
88    .add(b'<')
89    .add(b'>')
90    .add(b'`');
91const PATH_ENCODE_SET: &AsciiSet = &QUERY_ENCODE_SET
92    .add(b'?')
93    .add(b'^')
94    .add(b'`')
95    .add(b'{')
96    .add(b'}');
97
98/// Encode url path. This can be used when build your custom url path getter.
99#[inline]
100pub(crate) fn encode_url_path(path: &str) -> String {
101    path.split('/')
102        .map(|s| utf8_percent_encode(s, PATH_ENCODE_SET).to_string())
103        .collect::<Vec<_>>()
104        .join("/")
105}
106
/// Client trait for implementing different HTTP clients for proxying.
///
/// Implement this trait to create custom proxy clients with different
/// backends or configurations. The [`Proxy`] handler builds the outgoing request and hands it
/// to [`Client::execute`]; the returned response's status, headers and body are copied into
/// the response sent back to the caller.
pub trait Client: Send + Sync + 'static {
    /// Error type returned by the client.
    type Error: StdError + Send + Sync + 'static;

    /// Execute a request through the proxy client.
    ///
    /// `req` is the fully built proxied request (upstream uri, forwarded headers, body).
    /// `upgraded` is the `OnUpgrade` handle removed from the incoming request's extensions, if
    /// one was present; presumably implementations use it to bridge protocol upgrades
    /// (NOTE(review): confirm per client implementation).
    ///
    /// An `Err` is logged by the proxy handler, which then answers with
    /// `500 Internal Server Error`.
    fn execute(
        &self,
        req: HyperRequest,
        upgraded: Option<OnUpgrade>,
    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
}
122
/// Upstreams trait for selecting target servers.
///
/// Implement this trait to customize how target servers are selected
/// for proxying requests. This can be used to implement load balancing,
/// failover, or other server selection strategies.
pub trait Upstreams: Send + Sync + 'static {
    /// Error type returned when selecting a server fails.
    type Error: StdError + Send + Sync + 'static;

    /// Elect a server to handle the current request.
    ///
    /// Returns the base url of the chosen upstream (e.g. `https://example.com`); the proxied
    /// path and query are appended to it. The proxy treats an empty string as an error.
    fn elect(
        &self,
        req: &Request,
        depot: &Depot,
    ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
}
139impl Upstreams for &'static str {
140    type Error = Infallible;
141
142    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
143        Ok(*self)
144    }
145}
146impl Upstreams for String {
147    type Error = Infallible;
148    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
149        Ok(self.as_str())
150    }
151}
152
153impl<const N: usize> Upstreams for [&'static str; N] {
154    type Error = Error;
155    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
156        if self.is_empty() {
157            return Err(Error::other("upstreams is empty"));
158        }
159        let index = fastrand::usize(..self.len());
160        Ok(self[index])
161    }
162}
163
164impl<T> Upstreams for Vec<T>
165where
166    T: AsRef<str> + Send + Sync + 'static,
167{
168    type Error = Error;
169    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
170        if self.is_empty() {
171            return Err(Error::other("upstreams is empty"));
172        }
173        let index = fastrand::usize(..self.len());
174        Ok(self[index].as_ref())
175    }
176}
177
/// Url part getter. You can use this to get the proxied url path or query.
///
/// Called with the incoming request and depot. For the path getter, `None` is treated as an
/// empty path; for the query getter, `None` means the proxied url has no query.
pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;

/// Host header getter. You can use this to get the host header for the proxied request.
///
/// Called with the already-built forward uri, the incoming request and the depot. When it
/// returns `None` (or a value that is not a valid header value), the proxy sets no `Host`
/// header; the incoming `Host` header is never copied.
pub type HostHeaderGetter =
    Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
184
185/// Default url path getter.
186///
187/// This getter will get the last param as the rest url path from request.
188/// In most case you should use wildcard param, like `{**rest}`, `{*+rest}`.
189pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
190    req.params().tail().map(encode_url_path)
191}
192/// Default url query getter. This getter just return the query string from request uri.
193pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
194    req.uri().query().map(Into::into)
195}
196
197/// Default host header getter. This getter will get the host header from request uri
198pub fn default_host_header_getter(
199    forward_uri: &Uri,
200    _req: &Request,
201    _depot: &Depot,
202) -> Option<String> {
203    if let Some(host) = forward_uri.host() {
204        return Some(String::from(host));
205    }
206
207    None
208}
209
210/// RFC2616 complieant host header getter. This getter will get the host header from request uri,
211/// and add port if it's not default port. Falls back to default upon any forward URI parse error.
212pub fn rfc2616_host_header_getter(
213    forward_uri: &Uri,
214    req: &Request,
215    _depot: &Depot,
216) -> Option<String> {
217    let mut parts: Vec<String> = Vec::with_capacity(2);
218
219    if let Some(host) = forward_uri.host() {
220        parts.push(host.to_owned());
221
222        if let Some(scheme) = forward_uri.scheme_str()
223            && let Some(port) = forward_uri.port_u16()
224            && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
225        {
226            parts.push(port.to_string());
227        }
228    }
229
230    if parts.is_empty() {
231        default_host_header_getter(forward_uri, req, _depot)
232    } else {
233        Some(parts.join(":"))
234    }
235}
236
237/// Preserve original host header getter. Propagates the original request host header to the proxied
238/// request.
239pub fn preserve_original_host_header_getter(
240    forward_uri: &Uri,
241    req: &Request,
242    _depot: &Depot,
243) -> Option<String> {
244    if let Some(host_header) = req.headers().get(HOST)
245        && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
246    {
247        return Some(host);
248    }
249
250    default_host_header_getter(forward_uri, req, _depot)
251}
252
/// Handler that proxies requests to upstream servers.
///
/// For every request it elects an upstream through [`Upstreams`], appends the path and query
/// produced by the configured getters, and forwards the request through the [`Client`]. The
/// upstream response's status, headers and body are then copied into the response.
#[non_exhaustive]
pub struct Proxy<U, C>
where
    U: Upstreams,
    C: Client,
{
    /// Upstreams the target server is elected from for each request.
    pub upstreams: U,
    /// [`Client`] used to send the proxied request.
    pub client: C,
    /// Url path getter. Its result is percent-encoded before being appended to the upstream
    /// url; `None` is treated as an empty path.
    pub url_path_getter: UrlPartGetter,
    /// Url query getter. A leading `?` in its result is optional; `None` means no query.
    pub url_query_getter: UrlPartGetter,
    /// Host header getter. The incoming `Host` header is never copied to the proxied request;
    /// this getter alone decides it.
    pub host_header_getter: HostHeaderGetter,
    /// When `true`, an address is prepended to the `x-forwarded-for` header of the proxied
    /// request (outside unit tests, this machine's local IP of the same family as the peer).
    pub client_ip_forwarding_enabled: bool,
}
273
274impl<U, C> Debug for Proxy<U, C>
275where
276    U: Upstreams,
277    C: Client,
278{
279    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
280        f.debug_struct("Proxy").finish()
281    }
282}
283
284impl<U, C> Proxy<U, C>
285where
286    U: Upstreams,
287    U::Error: Into<BoxedError>,
288    C: Client,
289{
290    /// Create new `Proxy` with upstreams list.
291    #[must_use]
292    pub fn new(upstreams: U, client: C) -> Self {
293        Self {
294            upstreams,
295            client,
296            url_path_getter: Box::new(default_url_path_getter),
297            url_query_getter: Box::new(default_url_query_getter),
298            host_header_getter: Box::new(default_host_header_getter),
299            client_ip_forwarding_enabled: false,
300        }
301    }
302
303    /// Create new `Proxy` with upstreams list and enable x-forwarded-for header.
304    pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
305        Self {
306            upstreams,
307            client,
308            url_path_getter: Box::new(default_url_path_getter),
309            url_query_getter: Box::new(default_url_query_getter),
310            host_header_getter: Box::new(default_host_header_getter),
311            client_ip_forwarding_enabled: true,
312        }
313    }
314
315    /// Set url path getter.
316    #[inline]
317    #[must_use]
318    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
319    where
320        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
321    {
322        self.url_path_getter = Box::new(url_path_getter);
323        self
324    }
325
326    /// Set url query getter.
327    #[inline]
328    #[must_use]
329    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
330    where
331        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
332    {
333        self.url_query_getter = Box::new(url_query_getter);
334        self
335    }
336
337    /// Set host header query getter.
338    #[inline]
339    #[must_use]
340    pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
341    where
342        G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
343    {
344        self.host_header_getter = Box::new(host_header_getter);
345        self
346    }
347
348    /// Get upstreams list.
349    #[inline]
350    pub fn upstreams(&self) -> &U {
351        &self.upstreams
352    }
353    /// Get upstreams mutable list.
354    #[inline]
355    pub fn upstreams_mut(&mut self) -> &mut U {
356        &mut self.upstreams
357    }
358
359    /// Get client reference.
360    #[inline]
361    pub fn client(&self) -> &C {
362        &self.client
363    }
364    /// Get client mutable reference.
365    #[inline]
366    pub fn client_mut(&mut self) -> &mut C {
367        &mut self.client
368    }
369
370    /// Enable x-forwarded-for header prepending.
371    #[inline]
372    #[must_use]
373    pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
374        self.client_ip_forwarding_enabled = enable;
375        self
376    }
377
378    async fn build_proxied_request(
379        &self,
380        req: &mut Request,
381        depot: &Depot,
382    ) -> Result<HyperRequest, Error> {
383        let upstream = self
384            .upstreams
385            .elect(req, depot)
386            .await
387            .map_err(Error::other)?;
388
389        if upstream.is_empty() {
390            tracing::error!("upstreams is empty");
391            return Err(Error::other("upstreams is empty"));
392        }
393
394        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
395        let query = (self.url_query_getter)(req, depot);
396        let rest = if let Some(query) = query {
397            if query.starts_with('?') {
398                format!(
399                    "{path}?{}",
400                    utf8_percent_encode(&query[1..], QUERY_ENCODE_SET)
401                )
402            } else {
403                format!("{path}?{}", utf8_percent_encode(&query, QUERY_ENCODE_SET))
404            }
405        } else {
406            path
407        };
408        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
409            format!("{}{}", upstream.trim_end_matches('/'), rest)
410        } else if upstream.ends_with('/') || rest.starts_with('/') {
411            format!("{upstream}{rest}")
412        } else if rest.is_empty() {
413            upstream.to_owned()
414        } else {
415            format!("{upstream}/{rest}")
416        };
417        let forward_url: Uri = TryFrom::try_from(forward_url).map_err(Error::other)?;
418        let mut build = hyper::Request::builder()
419            .method(req.method())
420            .uri(&forward_url);
421        for (key, value) in req.headers() {
422            if key != HOST {
423                build = build.header(key, value);
424            }
425        }
426        if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
427            match HeaderValue::from_str(&host_value) {
428                Ok(host_value) => {
429                    build = build.header(HOST, host_value);
430                }
431                Err(e) => {
432                    tracing::error!(error = ?e, "invalid host header value");
433                }
434            }
435        }
436
437        if self.client_ip_forwarding_enabled {
438            let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
439            let current_xff = req.headers().get(&xff_header_name);
440
441            #[cfg(test)]
442            let system_ip_addr = match req.remote_addr() {
443                SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
444                    0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
445                ))),
446                _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
447            };
448
449            #[cfg(not(test))]
450            let system_ip_addr = match req.remote_addr() {
451                SocketAddr::IPv6(_) => local_ipv6().ok(),
452                _ => local_ip().ok(),
453            };
454
455            if let Some(system_ip_addr) = system_ip_addr {
456                let forwarded_addr = system_ip_addr.to_string();
457
458                let xff_value = match current_xff {
459                    Some(current_xff) => match current_xff.to_str() {
460                        Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
461                        _ => forwarded_addr.clone(),
462                    },
463                    None => forwarded_addr.clone(),
464                };
465
466                let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
467                    Ok(xff_header_halue) => Some(xff_header_halue),
468                    Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
469                        Ok(xff_header_halue) => Some(xff_header_halue),
470                        Err(e) => {
471                            tracing::error!(error = ?e, "invalid x-forwarded-for header value");
472                            None
473                        }
474                    },
475                };
476
477                if let Some(xff) = xff_header_halue
478                    && let Some(headers) = build.headers_mut()
479                {
480                    headers.insert(&xff_header_name, xff);
481                }
482            }
483        }
484
485        build.body(req.take_body()).map_err(Error::other)
486    }
487}
488
489#[async_trait]
490impl<U, C> Handler for Proxy<U, C>
491where
492    U: Upstreams,
493    U::Error: Into<BoxedError>,
494    C: Client,
495{
496    async fn handle(
497        &self,
498        req: &mut Request,
499        depot: &mut Depot,
500        res: &mut Response,
501        _ctrl: &mut FlowCtrl,
502    ) {
503        match self.build_proxied_request(req, depot).await {
504            Ok(proxied_request) => {
505                match self
506                    .client
507                    .execute(proxied_request, req.extensions_mut().remove())
508                    .await
509                {
510                    Ok(response) => {
511                        let (
512                            salvo_core::http::response::Parts {
513                                status,
514                                // version,
515                                headers,
516                                // extensions,
517                                ..
518                            },
519                            body,
520                        ) = response.into_parts();
521                        res.status_code(status);
522                        for name in headers.keys() {
523                            for value in headers.get_all(name) {
524                                res.headers.append(name, value.to_owned());
525                            }
526                        }
527                        res.body(body);
528                    }
529                    Err(e) => {
530                        tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
531                        res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
532                    }
533                }
534            }
535            Err(e) => {
536                tracing::error!(error = ?e, "build proxied request failed");
537            }
538        }
539    }
540}
541#[inline]
542#[allow(dead_code)]
543fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
544    if headers
545        .get(&CONNECTION)
546        .map(|value| {
547            value
548                .to_str()
549                .unwrap_or_default()
550                .split(',')
551                .any(|e| e.trim() == UPGRADE)
552        })
553        .unwrap_or(false)
554        && let Some(upgrade_value) = headers.get(&UPGRADE)
555    {
556        tracing::debug!(
557            "found upgrade header with value: {:?}",
558            upgrade_value.to_str()
559        );
560        return upgrade_value.to_str().ok();
561    }
562
563    None
564}
565
// Unit tests for Proxy
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    use super::*;

    /// Build the proxied request and return its `x-forwarded-for` header, if any.
    async fn proxied_xff(
        proxy: &Proxy<Vec<&'static str>, HyperClient>,
        request: &mut Request,
        depot: &Depot,
    ) -> Option<HeaderValue> {
        proxy
            .build_proxied_request(request, depot)
            .await
            .expect("proxied request should build")
            .headers()
            .get(X_FORWARDER_FOR_HEADER_NAME)
            .cloned()
    }

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");

        // Characters from the encode set are escaped inside segments; `/` separators are kept.
        assert_eq!(encode_url_path("/a b/c?d"), "/a%20b/c%3Fd");
        // `%` is not in the encode set, so encoding an encoded path is a no-op.
        assert_eq!(encode_url_path("/a%20b"), "/a%20b");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_port = Uri::from_str("http://host.tld:8080/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_port, &req, &depot),
            Some("host.tld:8080".to_string())
        );

        let uri_with_http_port = Uri::from_str("http://host.tld:80/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_http_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_https_port = Uri::from_str("https://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_https_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_non_https_scheme_and_https_port =
            Uri::from_str("http://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_non_https_scheme_and_https_port, &req, &depot),
            Some("host.tld:443".to_string())
        );

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    #[tokio::test]
    async fn test_client_ip_forwarding() {
        let mut request = Request::new();
        let depot = Depot::new();

        // Forwarding is off by default and can be toggled with the builder method.
        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        // Test builds use fixed "system" addresses: 101.102.103.104 and 1:2:3:4:5:6:7:8.
        assert_eq!(
            proxied_xff(&proxy, &mut request, &depot).await,
            Some(HeaderValue::from_static("101.102.103.104"))
        );

        // The IP version follows the remote address.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        assert_eq!(
            proxied_xff(&proxy, &mut request, &depot).await,
            Some(HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        *request.remote_addr_mut() = SocketAddr::Unknown;
        assert_eq!(
            proxied_xff(&proxy, &mut request, &depot).await,
            Some(HeaderValue::from_static("101.102.103.104"))
        );

        // The address is prepended when the incoming request already has the header.
        request.headers_mut().insert(
            HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME),
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        assert_eq!(
            proxied_xff(&proxy, &mut request, &depot).await,
            Some(HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }
}