salvo_proxy/
lib.rs

//! Provides HTTP proxy capabilities for the Salvo web framework.
//!
//! This crate allows you to easily forward requests to upstream servers,
//! supporting both HTTP and HTTPS protocols. It's useful for creating API gateways,
//! load balancers, and reverse proxies.
//!
//! # Example
//!
//! In this example, requests to different hosts are proxied to different upstream servers:
//! - Requests to <http://127.0.0.1:8698/> are proxied to <https://www.rust-lang.org>
//! - Requests to <http://localhost:8698/> are proxied to <https://crates.io>
//!
//! ```no_run
//! use salvo_core::prelude::*;
//! use salvo_proxy::Proxy;
//!
//! #[tokio::main]
//! async fn main() {
//!     let router = Router::new()
//!         .push(
//!             Router::new()
//!                 .host("127.0.0.1")
//!                 .path("{**rest}")
//!                 .goal(Proxy::use_hyper_client("https://www.rust-lang.org")),
//!         )
//!         .push(
//!             Router::new()
//!                 .host("localhost")
//!                 .path("{**rest}")
//!                 .goal(Proxy::use_hyper_client("https://crates.io")),
//!         );
//!
//!     let acceptor = TcpListener::new("0.0.0.0:8698").bind().await;
//!     Server::new(acceptor).serve(router).await;
//! }
//! ```
37#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
38#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
39#![cfg_attr(docsrs, feature(doc_cfg))]
40
41use std::convert::Infallible;
42use std::error::Error as StdError;
43use std::fmt::{self, Debug, Formatter};
44
45use hyper::upgrade::OnUpgrade;
46use percent_encoding::{CONTROLS, utf8_percent_encode};
47use salvo_core::conn::SocketAddr;
48use salvo_core::http::header::{CONNECTION, HOST, HeaderMap, HeaderName, HeaderValue, UPGRADE};
49use salvo_core::http::uri::Uri;
50use salvo_core::http::{ReqBody, ResBody, StatusCode};
51use salvo_core::{BoxedError, Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
52
53#[cfg(test)]
54use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
55
56#[cfg(not(test))]
57use local_ip_address::{local_ip, local_ipv6};
58
59#[macro_use]
60mod cfg;
61
62cfg_feature! {
63    #![feature = "hyper-client"]
64    mod hyper_client;
65    pub use hyper_client::*;
66}
67cfg_feature! {
68    #![feature = "reqwest-client"]
69    mod reqwest_client;
70    pub use reqwest_client::*;
71}
72
73cfg_feature! {
74    #![feature = "unix-sock-client"]
75    #[cfg(unix)]
76    mod unix_sock_client;
77    #[cfg(unix)]
78    pub use unix_sock_client::*;
79}
80
81type HyperRequest = hyper::Request<ReqBody>;
82type HyperResponse = hyper::Response<ResBody>;
83
/// Header written on forwarded requests when client IP forwarding is enabled.
///
/// Kept in lowercase because it is turned into a header name with `HeaderName::from_static`.
const X_FORWARDER_FOR_HEADER_NAME: &str = "x-forwarded-for";
85
86/// Encode url path. This can be used when build your custom url path getter.
87#[inline]
88pub(crate) fn encode_url_path(path: &str) -> String {
89    path.split('/')
90        .map(|s| utf8_percent_encode(s, CONTROLS).to_string())
91        .collect::<Vec<_>>()
92        .join("/")
93}
94
95/// Client trait for implementing different HTTP clients for proxying.
96///
97/// Implement this trait to create custom proxy clients with different
98/// backends or configurations.
99pub trait Client: Send + Sync + 'static {
100    /// Error type returned by the client.
101    type Error: StdError + Send + Sync + 'static;
102
103    /// Execute a request through the proxy client.
104    fn execute(
105        &self,
106        req: HyperRequest,
107        upgraded: Option<OnUpgrade>,
108    ) -> impl Future<Output = Result<HyperResponse, Self::Error>> + Send;
109}
110
111/// Upstreams trait for selecting target servers.
112///
113/// Implement this trait to customize how target servers are selected
114/// for proxying requests. This can be used to implement load balancing,
115/// failover, or other server selection strategies.
116pub trait Upstreams: Send + Sync + 'static {
117    /// Error type returned when selecting a server fails.
118    type Error: StdError + Send + Sync + 'static;
119
120    /// Elect a server to handle the current request.
121    fn elect(
122        &self,
123        req: &Request,
124        depot: &Depot,
125    ) -> impl Future<Output = Result<&str, Self::Error>> + Send;
126}
127impl Upstreams for &'static str {
128    type Error = Infallible;
129
130    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
131        Ok(*self)
132    }
133}
134impl Upstreams for String {
135    type Error = Infallible;
136    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
137        Ok(self.as_str())
138    }
139}
140
141impl<const N: usize> Upstreams for [&'static str; N] {
142    type Error = Error;
143    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
144        if self.is_empty() {
145            return Err(Error::other("upstreams is empty"));
146        }
147        let index = fastrand::usize(..self.len());
148        Ok(self[index])
149    }
150}
151
152impl<T> Upstreams for Vec<T>
153where
154    T: AsRef<str> + Send + Sync + 'static,
155{
156    type Error = Error;
157    async fn elect(&self, _: &Request, _: &Depot) -> Result<&str, Self::Error> {
158        if self.is_empty() {
159            return Err(Error::other("upstreams is empty"));
160        }
161        let index = fastrand::usize(..self.len());
162        Ok(self[index].as_ref())
163    }
164}
165
166/// Url part getter. You can use this to get the proxied url path or query.
167pub type UrlPartGetter = Box<dyn Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static>;
168
169/// Host header getter. You can use this to get the host header for the proxied request.
170pub type HostHeaderGetter =
171    Box<dyn Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static>;
172
173/// Default url path getter.
174///
175/// This getter will get the last param as the rest url path from request.
176/// In most case you should use wildcard param, like `{**rest}`, `{*+rest}`.
177pub fn default_url_path_getter(req: &Request, _depot: &Depot) -> Option<String> {
178    req.params().tail().map(encode_url_path)
179}
180/// Default url query getter. This getter just return the query string from request uri.
181pub fn default_url_query_getter(req: &Request, _depot: &Depot) -> Option<String> {
182    req.uri().query().map(Into::into)
183}
184
185/// Default host header getter. This getter will get the host header from request uri
186pub fn default_host_header_getter(
187    forward_uri: &Uri,
188    _req: &Request,
189    _depot: &Depot,
190) -> Option<String> {
191    if let Some(host) = forward_uri.host() {
192        return Some(String::from(host));
193    }
194
195    None
196}
197
198/// RFC2616 complieant host header getter. This getter will get the host header from request uri, and add port if
199/// it's not default port. Falls back to default upon any forward URI parse error.
200pub fn rfc2616_host_header_getter(
201    forward_uri: &Uri,
202    req: &Request,
203    _depot: &Depot,
204) -> Option<String> {
205    let mut parts: Vec<String> = Vec::with_capacity(2);
206
207    if let Some(host) = forward_uri.host() {
208        parts.push(host.to_owned());
209
210        if let Some(scheme) = forward_uri.scheme_str()
211            && let Some(port) = forward_uri.port_u16()
212            && (scheme == "http" && port != 80 || scheme == "https" && port != 443)
213        {
214            parts.push(port.to_string());
215        }
216    }
217
218    if parts.is_empty() {
219        default_host_header_getter(forward_uri, req, _depot)
220    } else {
221        Some(parts.join(":"))
222    }
223}
224
225/// Preserve original host header getter. Propagates the original request host header to the proxied request.
226pub fn preserve_original_host_header_getter(
227    forward_uri: &Uri,
228    req: &Request,
229    _depot: &Depot,
230) -> Option<String> {
231    if let Some(host_header) = req.headers().get(HOST)
232        && let Ok(host) = String::from_utf8(host_header.as_bytes().to_vec())
233    {
234        return Some(host);
235    }
236
237    default_host_header_getter(forward_uri, req, _depot)
238}
239
240/// Handler that can proxy request to other server.
241#[non_exhaustive]
242pub struct Proxy<U, C>
243where
244    U: Upstreams,
245    C: Client,
246{
247    /// Upstreams list.
248    pub upstreams: U,
249    /// [`Client`] for proxy.
250    pub client: C,
251    /// Url path getter.
252    pub url_path_getter: UrlPartGetter,
253    /// Url query getter.
254    pub url_query_getter: UrlPartGetter,
255    /// Host header getter
256    pub host_header_getter: HostHeaderGetter,
257    /// Flag to enable x-forwarded-for header.
258    pub client_ip_forwarding_enabled: bool,
259}
260
261impl<U, C> Debug for Proxy<U, C>
262where
263    U: Upstreams,
264    C: Client,
265{
266    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
267        f.debug_struct("Proxy").finish()
268    }
269}
270
271impl<U, C> Proxy<U, C>
272where
273    U: Upstreams,
274    U::Error: Into<BoxedError>,
275    C: Client,
276{
277    /// Create new `Proxy` with upstreams list.
278    #[must_use]
279    pub fn new(upstreams: U, client: C) -> Self {
280        Self {
281            upstreams,
282            client,
283            url_path_getter: Box::new(default_url_path_getter),
284            url_query_getter: Box::new(default_url_query_getter),
285            host_header_getter: Box::new(default_host_header_getter),
286            client_ip_forwarding_enabled: false,
287        }
288    }
289
290    /// Create new `Proxy` with upstreams list and enable x-forwarded-for header.
291    pub fn with_client_ip_forwarding(upstreams: U, client: C) -> Self {
292        Self {
293            upstreams,
294            client,
295            url_path_getter: Box::new(default_url_path_getter),
296            url_query_getter: Box::new(default_url_query_getter),
297            host_header_getter: Box::new(default_host_header_getter),
298            client_ip_forwarding_enabled: true,
299        }
300    }
301
302    /// Set url path getter.
303    #[inline]
304    #[must_use]
305    pub fn url_path_getter<G>(mut self, url_path_getter: G) -> Self
306    where
307        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
308    {
309        self.url_path_getter = Box::new(url_path_getter);
310        self
311    }
312
313    /// Set url query getter.
314    #[inline]
315    #[must_use]
316    pub fn url_query_getter<G>(mut self, url_query_getter: G) -> Self
317    where
318        G: Fn(&Request, &Depot) -> Option<String> + Send + Sync + 'static,
319    {
320        self.url_query_getter = Box::new(url_query_getter);
321        self
322    }
323
324    /// Set host header query getter.
325    #[inline]
326    #[must_use]
327    pub fn host_header_getter<G>(mut self, host_header_getter: G) -> Self
328    where
329        G: Fn(&Uri, &Request, &Depot) -> Option<String> + Send + Sync + 'static,
330    {
331        self.host_header_getter = Box::new(host_header_getter);
332        self
333    }
334
335    /// Get upstreams list.
336    #[inline]
337    pub fn upstreams(&self) -> &U {
338        &self.upstreams
339    }
340    /// Get upstreams mutable list.
341    #[inline]
342    pub fn upstreams_mut(&mut self) -> &mut U {
343        &mut self.upstreams
344    }
345
346    /// Get client reference.
347    #[inline]
348    pub fn client(&self) -> &C {
349        &self.client
350    }
351    /// Get client mutable reference.
352    #[inline]
353    pub fn client_mut(&mut self) -> &mut C {
354        &mut self.client
355    }
356
357    /// Enable x-forwarded-for header prepending.
358    #[inline]
359    #[must_use]
360    pub fn client_ip_forwarding(mut self, enable: bool) -> Self {
361        self.client_ip_forwarding_enabled = enable;
362        self
363    }
364
365    async fn build_proxied_request(
366        &self,
367        req: &mut Request,
368        depot: &Depot,
369    ) -> Result<HyperRequest, Error> {
370        let upstream = self
371            .upstreams
372            .elect(req, depot)
373            .await
374            .map_err(Error::other)?;
375
376        if upstream.is_empty() {
377            tracing::error!("upstreams is empty");
378            return Err(Error::other("upstreams is empty"));
379        }
380
381        let path = encode_url_path(&(self.url_path_getter)(req, depot).unwrap_or_default());
382        let query = (self.url_query_getter)(req, depot);
383        let rest = if let Some(query) = query {
384            if query.starts_with('?') {
385                format!("{path}{query}")
386            } else {
387                format!("{path}?{query}")
388            }
389        } else {
390            path
391        };
392        let forward_url = if upstream.ends_with('/') && rest.starts_with('/') {
393            format!("{}{}", upstream.trim_end_matches('/'), rest)
394        } else if upstream.ends_with('/') || rest.starts_with('/') {
395            format!("{upstream}{rest}")
396        } else if rest.is_empty() {
397            upstream.to_owned()
398        } else {
399            format!("{upstream}/{rest}")
400        };
401        let forward_url = url::Url::parse(&forward_url).map_err(|e| {
402            Error::other(format!("url::Url::parse failed for '{forward_url}': {e}"))
403        })?;
404        let forward_url: Uri = forward_url
405            .as_str()
406            .parse()
407            .map_err(|e| Error::other(format!("Uri::parse failed for '{forward_url}': {e}")))?;
408        let mut build = hyper::Request::builder()
409            .method(req.method())
410            .uri(&forward_url);
411        for (key, value) in req.headers() {
412            if key != HOST {
413                build = build.header(key, value);
414            }
415        }
416        if let Some(host_value) = (self.host_header_getter)(&forward_url, req, depot) {
417            match HeaderValue::from_str(&host_value) {
418                Ok(host_value) => {
419                    build = build.header(HOST, host_value);
420                }
421                Err(e) => {
422                    tracing::error!(error = ?e, "invalid host header value");
423                }
424            }
425        }
426
427        if self.client_ip_forwarding_enabled {
428            let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);
429            let current_xff = req.headers().get(&xff_header_name);
430
431            #[cfg(test)]
432            let system_ip_addr = match req.remote_addr() {
433                SocketAddr::IPv6(_) => Some(IpAddr::from(Ipv6Addr::new(
434                    0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8,
435                ))),
436                _ => Some(IpAddr::from(Ipv4Addr::new(101, 102, 103, 104))),
437            };
438
439            #[cfg(not(test))]
440            let system_ip_addr = match req.remote_addr() {
441                SocketAddr::IPv6(_) => local_ipv6().ok(),
442                _ => local_ip().ok(),
443            };
444
445            if let Some(system_ip_addr) = system_ip_addr {
446                let forwarded_addr = system_ip_addr.to_string();
447
448                let xff_value = match current_xff {
449                    Some(current_xff) => match current_xff.to_str() {
450                        Ok(current_xff) => format!("{forwarded_addr}, {current_xff}"),
451                        _ => forwarded_addr.clone(),
452                    },
453                    None => forwarded_addr.clone(),
454                };
455
456                let xff_header_halue = match HeaderValue::from_str(xff_value.as_str()) {
457                    Ok(xff_header_halue) => Some(xff_header_halue),
458                    Err(_) => match HeaderValue::from_str(forwarded_addr.as_str()) {
459                        Ok(xff_header_halue) => Some(xff_header_halue),
460                        Err(e) => {
461                            tracing::error!(error = ?e, "invalid x-forwarded-for header value");
462                            None
463                        }
464                    },
465                };
466
467                if let Some(xff) = xff_header_halue
468                    && let Some(headers) = build.headers_mut()
469                {
470                    headers.insert(&xff_header_name, xff);
471                }
472            }
473        }
474
475        build.body(req.take_body()).map_err(Error::other)
476    }
477}
478
479#[async_trait]
480impl<U, C> Handler for Proxy<U, C>
481where
482    U: Upstreams,
483    U::Error: Into<BoxedError>,
484    C: Client,
485{
486    async fn handle(
487        &self,
488        req: &mut Request,
489        depot: &mut Depot,
490        res: &mut Response,
491        _ctrl: &mut FlowCtrl,
492    ) {
493        match self.build_proxied_request(req, depot).await {
494            Ok(proxied_request) => {
495                match self
496                    .client
497                    .execute(proxied_request, req.extensions_mut().remove())
498                    .await
499                {
500                    Ok(response) => {
501                        let (
502                            salvo_core::http::response::Parts {
503                                status,
504                                // version,
505                                headers,
506                                // extensions,
507                                ..
508                            },
509                            body,
510                        ) = response.into_parts();
511                        res.status_code(status);
512                        for name in headers.keys() {
513                            for value in headers.get_all(name) {
514                                res.headers.append(name, value.to_owned());
515                            }
516                        }
517                        res.body(body);
518                    }
519                    Err(e) => {
520                        tracing::error!( error = ?e, uri = ?req.uri(), "get response data failed: {}", e);
521                        res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
522                    }
523                }
524            }
525            Err(e) => {
526                tracing::error!(error = ?e, "build proxied request failed");
527            }
528        }
529    }
530}
531#[inline]
532#[allow(dead_code)]
533fn get_upgrade_type(headers: &HeaderMap) -> Option<&str> {
534    if headers
535        .get(&CONNECTION)
536        .map(|value| {
537            value
538                .to_str()
539                .unwrap_or_default()
540                .split(',')
541                .any(|e| e.trim() == UPGRADE)
542        })
543        .unwrap_or(false)
544        && let Some(upgrade_value) = headers.get(&UPGRADE)
545    {
546        tracing::debug!(
547            "found upgrade header with value: {:?}",
548            upgrade_value.to_str()
549        );
550        return upgrade_value.to_str().ok();
551    }
552
553    None
554}
555
// Unit tests for Proxy
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6};
    use std::str::FromStr;

    #[test]
    fn test_encode_url_path() {
        let path = "/test/path";
        let encoded_path = encode_url_path(path);
        assert_eq!(encoded_path, "/test/path");
    }

    #[test]
    fn test_get_upgrade_type() {
        let mut headers = HeaderMap::new();
        headers.insert(CONNECTION, HeaderValue::from_static("upgrade"));
        headers.insert(UPGRADE, HeaderValue::from_static("websocket"));
        let upgrade_type = get_upgrade_type(&headers);
        assert_eq!(upgrade_type, Some("websocket"));
    }

    #[test]
    fn test_host_header_handling() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let uri = Uri::from_str("http://host.tld/test").unwrap();
        let mut req = Request::new();
        let depot = Depot::new();

        assert_eq!(
            default_host_header_getter(&uri, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_port = Uri::from_str("http://host.tld:8080/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_port, &req, &depot),
            Some("host.tld:8080".to_string())
        );

        let uri_with_http_port = Uri::from_str("http://host.tld:80/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_http_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_https_port = Uri::from_str("https://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_https_port, &req, &depot),
            Some("host.tld".to_string())
        );

        let uri_with_non_https_scheme_and_https_port =
            Uri::from_str("http://host.tld:443/test").unwrap();
        assert_eq!(
            rfc2616_host_header_getter(&uri_with_non_https_scheme_and_https_port, &req, &depot),
            Some("host.tld:443".to_string())
        );

        req.headers_mut()
            .insert(HOST, HeaderValue::from_static("test.host.tld"));
        assert_eq!(
            preserve_original_host_header_getter(&uri, &req, &depot),
            Some("test.host.tld".to_string())
        );
    }

    // `HyperClient` only exists when the `hyper-client` feature is enabled.
    #[cfg(feature = "hyper-client")]
    #[tokio::test]
    async fn test_client_ip_forwarding() {
        let xff_header_name = HeaderName::from_static(X_FORWARDER_FOR_HEADER_NAME);

        let mut request = Request::new();
        let depot = Depot::new();

        // Forwarding is off by default and can be toggled with the builder method.
        let proxy_without_forwarding =
            Proxy::new(vec!["http://example.com"], HyperClient::default());
        assert!(!proxy_without_forwarding.client_ip_forwarding_enabled);

        let proxy_with_forwarding = proxy_without_forwarding.client_ip_forwarding(true);
        assert!(proxy_with_forwarding.client_ip_forwarding_enabled);

        let proxy =
            Proxy::with_client_ip_forwarding(vec!["http://example.com"], HyperClient::default());
        assert!(proxy.client_ip_forwarding_enabled);

        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // The IP version follows the remote address.
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV6::new(Ipv6Addr::LOCALHOST, 12345, 0, 0));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("1:2:3:4:5:6:7:8"))
        );

        *request.remote_addr_mut() = SocketAddr::Unknown;
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static("101.102.103.104"))
        );

        // The address is prepended when the incoming request already has an XFF header.
        request.headers_mut().insert(
            &xff_header_name,
            HeaderValue::from_static("10.72.0.1, 127.0.0.1"),
        );
        *request.remote_addr_mut() =
            SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 12345));
        let proxied = proxy
            .build_proxied_request(&mut request, &depot)
            .await
            .expect("building the proxied request should succeed");
        assert_eq!(
            proxied.headers().get(&xff_header_name),
            Some(&HeaderValue::from_static(
                "101.102.103.104, 10.72.0.1, 127.0.0.1"
            ))
        );
    }
}