salvo_proxy/
reqwest_client.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use futures_util::TryStreamExt;
use hyper::upgrade::OnUpgrade;
use reqwest::Client as InnerClient;
use salvo_core::http::{ResBody, StatusCode};
use salvo_core::rt::tokio::TokioIo;
use salvo_core::Error;
use tokio::io::copy_bidirectional;

use crate::{Client, HyperRequest, BoxedError, Proxy, Upstreams, HyperResponse};

/// Proxy [`Client`] backed by a [`reqwest::Client`].
///
/// The wrapped reqwest client is stored in `inner`; a default-constructed
/// value wraps `reqwest::Client::default()`.
#[derive(Clone, Debug, Default)]
pub struct ReqwestClient {
    /// The reqwest client used to send proxied requests upstream.
    inner: InnerClient,
}

impl<U> Proxy<U, ReqwestClient>
where
    U: Upstreams,
    U::Error: Into<BoxedError>,
{
    /// Creates a new `Proxy` for `upstreams` that sends requests through a
    /// default-constructed [`ReqwestClient`].
    pub fn use_reqwest_client(upstreams: U) -> Self {
        let client = ReqwestClient::default();
        Proxy::new(upstreams, client)
    }
}

impl ReqwestClient {
    /// Create a new `ReqwestClient` with the given [`reqwest::Client`].
    pub fn new(inner: InnerClient) -> Self {
        Self { inner }
    }
}

impl Client for ReqwestClient {
    type Error = salvo_core::Error;

    async fn execute(
        &self,
        proxied_request: HyperRequest,
        request_upgraded: Option<OnUpgrade>,
    ) -> Result<HyperResponse, Self::Error> {
        let request_upgrade_type = crate::get_upgrade_type(proxied_request.headers()).map(|s| s.to_owned());

        let proxied_request =
            proxied_request.map(|s| reqwest::Body::wrap_stream(s.map_ok(|s| s.into_data().unwrap_or_default())));
        let response = self
            .inner
            .execute(proxied_request.try_into().map_err(Error::other)?)
            .await
            .map_err(Error::other)?;

        let res_headers = response.headers().clone();
        let hyper_response = hyper::Response::builder()
            .status(response.status())
            .version(response.version());

        let mut hyper_response = if response.status() == StatusCode::SWITCHING_PROTOCOLS {
            let response_upgrade_type = crate::get_upgrade_type(response.headers());

            if request_upgrade_type == response_upgrade_type.map(|s| s.to_lowercase()) {
                let mut response_upgraded = response
                    .upgrade()
                    .await
                    .map_err(|e| Error::other(format!("response does not have an upgrade extension. {}", e)))?;
                if let Some(request_upgraded) = request_upgraded {
                    tokio::spawn(async move {
                        match request_upgraded.await {
                            Ok(request_upgraded) => {
                                let mut request_upgraded = TokioIo::new(request_upgraded);
                                if let Err(e) = copy_bidirectional(&mut response_upgraded, &mut request_upgraded).await
                                {
                                    tracing::error!(error = ?e, "coping between upgraded connections failed");
                                }
                            }
                            Err(e) => {
                                tracing::error!(error = ?e, "upgrade request failed");
                            }
                        }
                    });
                } else {
                    return Err(Error::other("request does not have an upgrade extension"));
                }
            } else {
                return Err(Error::other("upgrade type mismatch"));
            }
            hyper_response.body(ResBody::None).map_err(Error::other)?
        } else {
            hyper_response
                .body(ResBody::stream(response.bytes_stream()))
                .map_err(Error::other)?
        };
        *hyper_response.headers_mut() = res_headers;
        Ok(hyper_response)
    }
}

// Unit tests for `Proxy` when driven by `ReqwestClient`.
#[cfg(test)]
mod tests {
    use salvo_core::prelude::*;
    use salvo_core::test::*;

    use super::*;
    use crate::{Proxy, Upstreams};

    /// The elected upstream must be one of the configured candidates.
    #[tokio::test]
    async fn test_upstreams_elect() {
        let candidates = vec!["https://www.example.com", "https://www.example2.com"];
        let proxy = Proxy::new(candidates.clone(), ReqwestClient::default());
        let chosen = proxy.upstreams().elect().await.unwrap();
        assert!(candidates.contains(&chosen));
    }

    /// Proxies a request to the configured upstream and checks the returned page.
    /// NOTE(review): the upstream is a public URL, so this presumably needs
    /// network access to pass.
    #[tokio::test]
    async fn test_reqwest_client() {
        let proxy = Proxy::new(vec!["https://www.rust-lang.org"], ReqwestClient::default());
        let router = Router::new().push(Router::with_path("rust/<**rest>").goal(proxy));

        let mut response = TestClient::get("http://127.0.0.1:5801/rust/tools/install")
            .send(router)
            .await;
        let content = response.take_string().await.unwrap();
        assert!(content.contains("Install Rust"));
    }

    /// Both the shared and the mutable upstream accessors expose the single
    /// configured upstream.
    #[test]
    fn test_others() {
        let mut proxy = Proxy::new(["https://www.bing.com"], ReqwestClient::default());
        assert_eq!(proxy.upstreams().len(), 1);
        assert_eq!(proxy.upstreams_mut().len(), 1);
    }
}