plane_dynamic_proxy/
proxy.rs

1use crate::{
2    body::{simple_empty_body, to_simple_body, SimpleBody},
3    connector::TimeoutHttpConnector,
4    request::should_upgrade,
5    upgrade::{split_request, split_response, UpgradeHandler},
6};
7use http::StatusCode;
8use hyper::{Request, Response};
9use hyper_util::{client::legacy::Client, rt::TokioExecutor};
10use std::{convert::Infallible, time::Duration};
11
/// Upstream timeout stored by `ProxyClient::new` (and therefore `ProxyClient::default`):
/// ten seconds.
const DEFAULT_TIMEOUT: Duration = Duration::new(10, 0);
13
14/// A client for proxying HTTP requests to an upstream server.
15#[derive(Clone)]
16pub struct ProxyClient {
17    client: Client<TimeoutHttpConnector, SimpleBody>,
18    #[allow(unused)] // TODO: implement this.
19    timeout: Duration,
20}
21
22impl Default for ProxyClient {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl ProxyClient {
29    pub fn new() -> Self {
30        let client = Client::builder(TokioExecutor::new()).build(TimeoutHttpConnector::default());
31        Self {
32            client,
33            timeout: DEFAULT_TIMEOUT,
34        }
35    }
36
37    /// Sends an HTTP request to the upstream server and returns the response.
38    /// If the request establishes a websocket connection, an upgrade handler is returned.
39    /// In this case, you must call and await `.run()` on the upgrade handler (i.e. in a tokio task)
40    /// to ensure that messages are properly sent and received.
41    pub async fn request(
42        &self,
43        request: Request<SimpleBody>,
44    ) -> Result<(Response<SimpleBody>, Option<UpgradeHandler>), Infallible> {
45        let url = request.uri().to_string();
46
47        let res = self.handle_request(request).await;
48
49        let res = match res {
50            Ok(res) => res,
51            Err(ProxyError::Timeout) => {
52                tracing::warn!(url, "Upstream request failed");
53                return Ok((
54                    Response::builder()
55                        .status(StatusCode::GATEWAY_TIMEOUT)
56                        .body(simple_empty_body())
57                        .expect("Failed to build response"),
58                    None,
59                ));
60            }
61            Err(e) => {
62                tracing::warn!(url, ?e, "Upstream request failed");
63                return Ok((
64                    Response::builder()
65                        .status(StatusCode::BAD_GATEWAY)
66                        .body(simple_empty_body())
67                        .expect("Failed to build response"),
68                    None,
69                ));
70            }
71        };
72
73        let (res, upgrade_handler) = res;
74        let (parts, body) = res.into_parts();
75        let res = Response::from_parts(parts, to_simple_body(body));
76
77        Ok((res, upgrade_handler))
78    }
79
80    async fn handle_request(
81        &self,
82        request: Request<SimpleBody>,
83    ) -> Result<(Response<SimpleBody>, Option<UpgradeHandler>), ProxyError> {
84        if should_upgrade(&request) {
85            let (response, upgrade_handler) = self.handle_upgrade(request).await?;
86            Ok((response, Some(upgrade_handler)))
87        } else {
88            let result = self.upstream_request(request).await?;
89            Ok((result, None))
90        }
91    }
92
93    async fn handle_upgrade(
94        &self,
95        request: Request<SimpleBody>,
96    ) -> Result<(Response<SimpleBody>, UpgradeHandler), ProxyError> {
97        let (upstream_request, request_with_body) = split_request(request);
98        let res = self.upstream_request(upstream_request).await?;
99        let (upstream_response, response_with_body) = split_response(res);
100
101        let upgrade_handler = UpgradeHandler::new(request_with_body, response_with_body);
102
103        Ok((upstream_response, upgrade_handler))
104    }
105
106    async fn upstream_request(
107        &self,
108        request: Request<SimpleBody>,
109    ) -> Result<Response<SimpleBody>, ProxyError> {
110        let res = match self.client.request(request).await {
111            Ok(res) => res,
112            Err(e) => {
113                return Err(ProxyError::RequestFailed(e.into()));
114            }
115        };
116
117        let (parts, body) = res.into_parts();
118        let res = Response::from_parts(parts, to_simple_body(body));
119
120        Ok(res)
121    }
122}
123
124#[derive(thiserror::Error, Debug)]
125pub enum ProxyError {
126    #[error("Upstream request timed out.")]
127    Timeout,
128
129    #[error("Upstream request failed: {0}")]
130    RequestFailed(#[from] Box<dyn std::error::Error + Send + Sync>),
131
132    #[error("Failed to upgrade response: {0}")]
133    UpgradeError(#[from] hyper::Error),
134
135    #[error("IO error: {0}")]
136    IoError(#[from] tokio::io::Error),
137}