Skip to main content

cloudillo_proxy/
handler.rs

1//! HTTP forwarding and WebSocket tunneling for reverse proxy
2
3use axum::http::{header, HeaderMap, HeaderName, HeaderValue, Uri};
4use hyper::body::Incoming;
5use hyper_util::{client::legacy::Client, rt::TokioExecutor};
6use std::sync::Arc;
7use std::time::Duration;
8
9use crate::prelude::*;
10use crate::ProxySiteEntry;
11
/// Hop-by-hop headers that describe a single transport-level connection and therefore
/// must not be forwarded between client and backend (RFC 2616 §13.5.1, RFC 7230 §6.1).
///
/// `upgrade` is listed so it is stripped from ordinary HTTP traffic; the WebSocket path
/// explicitly keeps it. `proxy-connection` is non-standard but still sent by some clients.
/// All entries are lowercase, matching `HeaderName::as_str()`.
const HOP_BY_HOP_HEADERS: &[&str] = &[
	"connection",
	"keep-alive",
	"proxy-authenticate",
	"proxy-authorization",
	"proxy-connection",
	"te",
	"trailers",
	"transfer-encoding",
	"upgrade",
];
22
23/// Check if a header is a hop-by-hop header that should be stripped
24fn is_hop_by_hop(name: &HeaderName) -> bool {
25	HOP_BY_HOP_HEADERS.iter().any(|h| name.as_str().eq_ignore_ascii_case(h))
26}
27
28/// Check if a request is a WebSocket upgrade request
29fn is_websocket_upgrade(headers: &HeaderMap) -> bool {
30	headers
31		.get(header::UPGRADE)
32		.and_then(|v| v.to_str().ok())
33		.map(|v| v.eq_ignore_ascii_case("websocket"))
34		.unwrap_or(false)
35}
36
37/// Build the backend URI from the proxy site entry and the original request URI
38fn build_backend_uri(entry: &ProxySiteEntry, original_uri: &Uri) -> ClResult<Uri> {
39	let mut backend = entry.backend_url.clone();
40	let combined_path = format!("{}{}", backend.path().trim_end_matches('/'), original_uri.path());
41	backend.set_path(&combined_path);
42	backend.set_query(original_uri.query());
43	debug!("Proxy backend URI: {} (combined_path={:?})", backend.as_str(), combined_path);
44	backend
45		.as_str()
46		.parse::<Uri>()
47		.map_err(|e| Error::Internal(format!("failed to build backend URI: {}", e)))
48}
49
50/// Copy non-hop-by-hop headers from source to destination
51fn copy_headers(src: &HeaderMap, dst: &mut HeaderMap, is_websocket: bool) {
52	for (name, value) in src.iter() {
53		// Skip hop-by-hop headers (but keep Upgrade for WebSocket)
54		if is_hop_by_hop(name) {
55			if is_websocket && name == header::UPGRADE {
56				dst.insert(name.clone(), value.clone());
57			}
58			continue;
59		}
60		dst.append(name.clone(), value.clone());
61	}
62}
63
64/// Handle a proxy request - main entry point for the proxy handler
65pub async fn handle_proxy_request(
66	entry: Arc<ProxySiteEntry>,
67	req: hyper::Request<Incoming>,
68	peer_addr: &str,
69) -> Result<hyper::Response<Incoming>, Error> {
70	let is_ws = is_websocket_upgrade(req.headers()) && entry.config.websocket.unwrap_or(true);
71
72	if is_ws {
73		return handle_websocket_proxy(entry, req, peer_addr).await;
74	}
75
76	let backend_uri = build_backend_uri(&entry, req.uri())?;
77
78	// Build the backend request
79	let mut backend_headers = HeaderMap::new();
80	copy_headers(req.headers(), &mut backend_headers, false);
81
82	// Host header handling
83	let preserve_host = entry.config.preserve_host.unwrap_or(true);
84	if preserve_host {
85		// Keep original Host header
86		if let Some(host) = req.headers().get(header::HOST) {
87			backend_headers.insert(header::HOST, host.clone());
88		}
89	} else if let Some(host) = entry.backend_url.host_str() {
90		// Rewrite to backend host
91		let host_val = if let Some(port) = entry.backend_url.port() {
92			format!("{}:{}", host, port)
93		} else {
94			host.to_string()
95		};
96		if let Ok(hv) = HeaderValue::from_str(&host_val) {
97			backend_headers.insert(header::HOST, hv);
98		}
99	}
100
101	// Add forwarding headers (always on for "basic" type)
102	let forward_headers = if entry.proxy_type.as_ref() == "basic" {
103		true
104	} else {
105		entry.config.forward_headers.unwrap_or(true)
106	};
107	if forward_headers {
108		if let Ok(hv) = HeaderValue::from_str(peer_addr) {
109			backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv.clone());
110			backend_headers.insert(HeaderName::from_static("x-real-ip"), hv);
111		}
112		backend_headers.insert(
113			HeaderName::from_static("x-forwarded-proto"),
114			HeaderValue::from_static("https"),
115		);
116		if let Ok(hv) = HeaderValue::from_str(&entry.domain) {
117			backend_headers.insert(HeaderName::from_static("x-forwarded-host"), hv);
118		}
119	}
120
121	// Add custom headers
122	if let Some(custom_headers) = &entry.config.custom_headers {
123		for (name, value) in custom_headers {
124			if let (Ok(hn), Ok(hv)) =
125				(HeaderName::from_bytes(name.as_bytes()), HeaderValue::from_str(value))
126			{
127				backend_headers.insert(hn, hv);
128			}
129		}
130	}
131
132	// Build the request
133	let method = req.method().clone();
134	let body = req.into_body();
135
136	let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
137
138	if let Some(headers) = backend_req.headers_mut() {
139		*headers = backend_headers;
140	}
141
142	let backend_req = backend_req
143		.body(body)
144		.map_err(|e| Error::Internal(format!("failed to build backend request: {}", e)))?;
145
146	// Set up timeouts
147	let connect_timeout =
148		Duration::from_secs(entry.config.connect_timeout_secs.unwrap_or(5) as u64);
149	let read_timeout = Duration::from_secs(entry.config.read_timeout_secs.unwrap_or(30) as u64);
150
151	// Send the request to the backend
152	let scheme = entry.backend_url.scheme();
153	match send_backend_request(scheme, connect_timeout, read_timeout, backend_req).await {
154		Ok(mut backend_resp) => {
155			// Strip hop-by-hop headers from response
156			let headers_to_remove: Vec<HeaderName> = backend_resp
157				.headers()
158				.keys()
159				.filter(|name| is_hop_by_hop(name))
160				.cloned()
161				.collect();
162			for name in headers_to_remove {
163				backend_resp.headers_mut().remove(&name);
164			}
165			Ok(backend_resp)
166		}
167		Err(e @ Error::Timeout) => {
168			warn!("Proxy backend timeout for {}", entry.domain);
169			Err(e)
170		}
171		Err(e) => {
172			warn!("Proxy backend error for {}: {}", entry.domain, e);
173			Err(e)
174		}
175	}
176}
177
178/// Handle a WebSocket proxy request via upgrade tunneling
179async fn handle_websocket_proxy(
180	entry: Arc<ProxySiteEntry>,
181	req: hyper::Request<Incoming>,
182	peer_addr: &str,
183) -> Result<hyper::Response<Incoming>, Error> {
184	// For WebSocket upgrade, we use hyper's low-level connection handling
185	// to establish a bidirectional tunnel
186	let backend_uri = build_backend_uri(&entry, req.uri())?;
187
188	let mut backend_headers = HeaderMap::new();
189	// Copy all headers including WebSocket-specific ones
190	for (name, value) in req.headers().iter() {
191		if is_hop_by_hop(name) && name != header::UPGRADE {
192			continue;
193		}
194		backend_headers.append(name.clone(), value.clone());
195	}
196
197	// Host header
198	let preserve_host = entry.config.preserve_host.unwrap_or(true);
199	if !preserve_host {
200		if let Some(host) = entry.backend_url.host_str() {
201			let host_val = if let Some(port) = entry.backend_url.port() {
202				format!("{}:{}", host, port)
203			} else {
204				host.to_string()
205			};
206			if let Ok(hv) = HeaderValue::from_str(&host_val) {
207				backend_headers.insert(header::HOST, hv);
208			}
209		}
210	}
211
212	// Add forwarding headers (always on for "basic" type)
213	let forward_headers = if entry.proxy_type.as_ref() == "basic" {
214		true
215	} else {
216		entry.config.forward_headers.unwrap_or(true)
217	};
218	if forward_headers {
219		if let Ok(hv) = HeaderValue::from_str(peer_addr) {
220			backend_headers.insert(HeaderName::from_static("x-forwarded-for"), hv);
221		}
222		backend_headers.insert(
223			HeaderName::from_static("x-forwarded-proto"),
224			HeaderValue::from_static("https"),
225		);
226	}
227
228	// Ensure Connection: Upgrade is present
229	backend_headers.insert(header::CONNECTION, HeaderValue::from_static("Upgrade"));
230
231	let method = req.method().clone();
232	let body = req.into_body();
233
234	let mut backend_req = hyper::Request::builder().method(method).uri(backend_uri);
235
236	if let Some(headers) = backend_req.headers_mut() {
237		*headers = backend_headers;
238	}
239
240	let backend_req = backend_req
241		.body(body)
242		.map_err(|e| Error::Internal(format!("failed to build ws backend request: {}", e)))?;
243
244	// Connect to backend
245	let connect_timeout =
246		Duration::from_secs(entry.config.connect_timeout_secs.unwrap_or(5) as u64);
247
248	let scheme = entry.backend_url.scheme();
249	match send_backend_request(scheme, connect_timeout, connect_timeout, backend_req).await {
250		Ok(backend_resp) => Ok(backend_resp),
251		Err(e @ Error::Timeout) => {
252			warn!("WebSocket proxy backend timeout for {}", entry.domain);
253			Err(e)
254		}
255		Err(e) => {
256			warn!("WebSocket proxy backend error for {}: {}", entry.domain, e);
257			Err(e)
258		}
259	}
260}
261
262/// Send a request to a backend, choosing HTTP or HTTPS connector based on scheme
263async fn send_backend_request(
264	scheme: &str,
265	connect_timeout: Duration,
266	timeout: Duration,
267	req: hyper::Request<Incoming>,
268) -> Result<hyper::Response<Incoming>, Error> {
269	let result = if scheme == "https" {
270		let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
271			.with_native_roots()
272			.map_err(|_| Error::ConfigError("no native root CA certificates found".into()))?
273			.https_only()
274			.enable_http1()
275			.build();
276		let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
277			.pool_idle_timeout(connect_timeout)
278			.build(https_connector);
279		tokio::time::timeout(timeout, client.request(req)).await
280	} else {
281		let http_connector = hyper_util::client::legacy::connect::HttpConnector::new();
282		let client: Client<_, Incoming> = Client::builder(TokioExecutor::new())
283			.pool_idle_timeout(connect_timeout)
284			.build(http_connector);
285		tokio::time::timeout(timeout, client.request(req)).await
286	};
287	match result {
288		Ok(Ok(resp)) => Ok(resp),
289		Ok(Err(_)) => Err(Error::NetworkError("bad gateway".into())),
290		Err(_) => Err(Error::Timeout),
291	}
292}
293
#[cfg(test)]
mod tests {
	use super::*;

	/// Minimal "basic" proxy site entry pointing at `backend`.
	fn site(backend: &str) -> ProxySiteEntry {
		ProxySiteEntry {
			site_id: 1,
			domain: "test.example.com".into(),
			proxy_type: "basic".into(),
			backend_url: url::Url::parse(backend).unwrap(),
			config: Default::default(),
		}
	}

	/// Resolve `request` (path plus optional query) against `entry` and render it.
	fn resolve(entry: &ProxySiteEntry, request: &str) -> String {
		let uri = request.parse::<Uri>().unwrap();
		build_backend_uri(entry, &uri).unwrap().to_string()
	}

	#[test]
	fn test_is_hop_by_hop() {
		for name in ["connection", "keep-alive", "transfer-encoding"] {
			assert!(is_hop_by_hop(&HeaderName::from_static(name)), "{} should be hop-by-hop", name);
		}
		for name in ["content-type", "host"] {
			assert!(!is_hop_by_hop(&HeaderName::from_static(name)), "{} should be forwarded", name);
		}
	}

	#[test]
	fn test_build_backend_uri() {
		let entry = site("http://localhost:3000");
		assert_eq!(resolve(&entry, "/api/test?foo=bar"), "http://localhost:3000/api/test?foo=bar");
	}

	#[test]
	fn test_build_backend_uri_root_path() {
		let entry = site("http://localhost:3000");
		assert_eq!(resolve(&entry, "/"), "http://localhost:3000/");
	}

	#[test]
	fn test_build_backend_uri_with_path_prefix() {
		let entry = site("http://backend:3000/a/");
		let cases = [
			// Root request should preserve the base path
			("/", "http://backend:3000/a/"),
			// Subpath request should join with base path
			("/foo", "http://backend:3000/a/foo"),
			// Subpath with query should work too
			("/api/test?key=val", "http://backend:3000/a/api/test?key=val"),
		];
		for (request, expected) in cases {
			assert_eq!(resolve(&entry, request), expected, "request {:?}", request);
		}
	}
}
361
362// vim: ts=4